Skip to content

Commit

Permalink
adding classification loss
Browse files Browse the repository at this point in the history
  • Loading branch information
IvanKuchin committed Nov 11, 2024
1 parent 8096ab7 commit 59bcaf8
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 9 deletions.
9 changes: 4 additions & 5 deletions src/pancreas_ai/bin/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@

from pancreas_ai.dataset.craft_datasets import craft_datasets
from pancreas_ai.tools.categorical_metrics import CategoricalMetric, CategoricalF1, CustomCounter, CustomReduceMetric
from pancreas_ai.tools.craft_network import craft_network
from pancreas_ai.tools.craft_network.loss import loss_func_generator
from pancreas_ai.tools.craft_network import factory
from pancreas_ai.tools.craft_network.loss import loss_func_factory
from pancreas_ai import config


Expand All @@ -30,8 +30,7 @@ def main():

ds_train = ds_train.prefetch(1).repeat(config.TRAIN_PASSES_PER_VALIDATION)

model = craft_network(config)
# predict_on_random_data(model)
model = factory.model_factory(config)

checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
config.MODEL_CHECKPOINT,
Expand All @@ -55,7 +54,7 @@ def main():
learning_rate = config.INITIAL_LEARNING_RATE,
# gradient_accumulation_steps = config.GRADIENT_ACCUMULATION_STEPS,
),
loss = loss_func_generator(config.LOSS_FUNCTION),
loss = loss_func_factory(config),
metrics = [
'accuracy',
CategoricalMetric(tf.keras.metrics.TruePositives(), name = 'custom_tp'),
Expand Down
2 changes: 1 addition & 1 deletion src/pancreas_ai/dataset/craft_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@


DEBUG_DATALOADER = False
DEBUG_DATA_LOADING_PERFORMANCE = True
DEBUG_DATA_LOADING_PERFORMANCE = False

def fname_from_full_path(fname_src:str):
if DEBUG_DATALOADER:
Expand Down
19 changes: 16 additions & 3 deletions src/pancreas_ai/tools/craft_network/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,24 @@ def __weighted_loss(y_true, y_pred):

return loss

def loss_func_generator(loss_name):
if loss_name == "dice":
def __segmentation_loss(config: dict):
if config.LOSS_FUNCTION == "dice":
return __dice_loss
elif loss_name == "scce":
elif config.LOSS_FUNCTION == "scce":
return __weighted_loss
else:
raise ValueError("Unknown loss function")


def __classification_loss(config: dict):
return tf.keras.losses.BinaryCrossentropy(from_logits = False)


def loss_func_factory(config: dict):
if config.TASK_TYPE == "segmentation":
return __segmentation_loss(config)
elif config.TASK_TYPE == "classification":
return __classification_loss(config)
else:
raise ValueError("Unknown loss function")

0 comments on commit 59bcaf8

Please sign in to comment.