diff --git a/src/pancreas_ai/bin/train.py b/src/pancreas_ai/bin/train.py index c067f13..d7d32c6 100644 --- a/src/pancreas_ai/bin/train.py +++ b/src/pancreas_ai/bin/train.py @@ -7,8 +7,8 @@ from pancreas_ai.dataset.craft_datasets import craft_datasets from pancreas_ai.tools.categorical_metrics import CategoricalMetric, CategoricalF1, CustomCounter, CustomReduceMetric -from pancreas_ai.tools.craft_network import craft_network -from pancreas_ai.tools.craft_network.loss import loss_func_generator +from pancreas_ai.tools.craft_network import factory +from pancreas_ai.tools.craft_network.loss import loss_func_factory from pancreas_ai import config @@ -30,8 +30,7 @@ def main(): ds_train = ds_train.prefetch(1).repeat(config.TRAIN_PASSES_PER_VALIDATION) - model = craft_network(config) - # predict_on_random_data(model) + model = factory.model_factory(config) checkpoint_cb = tf.keras.callbacks.ModelCheckpoint( config.MODEL_CHECKPOINT, @@ -55,7 +54,7 @@ def main(): learning_rate = config.INITIAL_LEARNING_RATE, # gradient_accumulation_steps = config.GRADIENT_ACCUMULATION_STEPS, ), - loss = loss_func_generator(config.LOSS_FUNCTION), + loss = loss_func_factory(config), metrics = [ 'accuracy', CategoricalMetric(tf.keras.metrics.TruePositives(), name = 'custom_tp'), diff --git a/src/pancreas_ai/dataset/craft_datasets.py b/src/pancreas_ai/dataset/craft_datasets.py index b3e4b34..81893a0 100644 --- a/src/pancreas_ai/dataset/craft_datasets.py +++ b/src/pancreas_ai/dataset/craft_datasets.py @@ -18,7 +18,7 @@ DEBUG_DATALOADER = False -DEBUG_DATA_LOADING_PERFORMANCE = True +DEBUG_DATA_LOADING_PERFORMANCE = False def fname_from_full_path(fname_src:str): if DEBUG_DATALOADER: diff --git a/src/pancreas_ai/tools/craft_network/loss.py b/src/pancreas_ai/tools/craft_network/loss.py index 36bad8a..c087c0f 100644 --- a/src/pancreas_ai/tools/craft_network/loss.py +++ b/src/pancreas_ai/tools/craft_network/loss.py @@ -51,11 +51,24 @@ def __weighted_loss(y_true, y_pred): return loss -def loss_func_generator(loss_name): - if loss_name == "dice": +def __segmentation_loss(config: dict): + if config.LOSS_FUNCTION == "dice": return __dice_loss - elif loss_name == "scce": + elif config.LOSS_FUNCTION == "scce": return __weighted_loss else: raise ValueError("Unknown loss function") + +def __classification_loss(config: dict): + return tf.keras.losses.BinaryCrossentropy(from_logits = False) + + +def loss_func_factory(config: dict): + if config.TASK_TYPE == "segmentation": + return __segmentation_loss(config) + elif config.TASK_TYPE == "classification": + return __classification_loss(config) + else: + raise ValueError("Unknown loss function") +