diff --git a/src/anomalib/pre_processing/pre_processing.py b/src/anomalib/pre_processing/pre_processing.py
index 5b8c8f46a1..4f9ff5bbc1 100644
--- a/src/anomalib/pre_processing/pre_processing.py
+++ b/src/anomalib/pre_processing/pre_processing.py
@@ -106,18 +106,25 @@ def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> Non
             stage: The stage (e.g., 'fit', 'train', 'val', 'test', 'predict').
         """
         super().setup(trainer, pl_module, stage)
-        stage = TrainerFn(stage).value  # This is to convert the stage to a string
-        stages = ["train", "val"] if stage == "fit" else [stage]
-        for current_stage in stages:
-            transform = getattr(self, f"{current_stage}_transform")
-            if transform:
-                if hasattr(trainer, "datamodule"):
-                    set_datamodule_transform(trainer.datamodule, transform, current_stage)
-                elif hasattr(trainer, f"{current_stage}_dataloaders"):
-                    set_dataloader_transform(getattr(trainer, f"{current_stage}_dataloaders"), transform)
-                else:
-                    msg = f"Trainer does not have a datamodule or {current_stage}_dataloaders attribute"
-                    raise ValueError(msg)
+        # Get stage transform
+        stage = TrainerFn(stage).value  # Make sure ``stage`` is a str
+        stage_transforms = {
+            "fit": self.train_transform,
+            "validate": self.val_transform,
+            "test": self.test_transform,
+            "predict": self.predict_transform,
+        }
+        transform = stage_transforms.get(stage)
+
+        # Assign the transform to the datamodule or dataloaders
+        if transform:
+            if hasattr(trainer, "datamodule"):
+                set_datamodule_transform(trainer.datamodule, transform, stage)
+            elif hasattr(trainer, f"{stage}_dataloaders"):
+                set_dataloader_transform(getattr(trainer, f"{stage}_dataloaders"), transform)
+            else:
+                msg = f"Trainer does not have a datamodule or {stage}_dataloaders attribute"
+                raise ValueError(msg)
 
     def forward(self, batch: torch.Tensor) -> torch.Tensor:
         """Apply transforms to the batch of tensors for inference.
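
For context, the refactor replaces the loop over derived stage names with a direct lookup keyed by `TrainerFn` values. Below is a minimal standalone sketch of that dispatch pattern; `SimplePreProcessor`, `transform_for`, and the string "transforms" are hypothetical stand-ins for anomalib's `PreProcessor` and its real torchvision transforms, while the `TrainerFn` import matches the one used in the diff.

```python
# Standalone sketch of the stage -> transform dispatch introduced above.
# ``SimplePreProcessor`` and the string "transforms" are hypothetical
# placeholders, not anomalib's actual classes.
from lightning.pytorch.trainer.states import TrainerFn


class SimplePreProcessor:
    """Holds one optional transform per trainer stage."""

    def __init__(self, train_transform=None, val_transform=None,
                 test_transform=None, predict_transform=None) -> None:
        self.train_transform = train_transform
        self.val_transform = val_transform
        self.test_transform = test_transform
        self.predict_transform = predict_transform

    def transform_for(self, stage: str):
        """Return the transform registered for ``stage``, or None."""
        stage = TrainerFn(stage).value  # normalize enum members to plain strings
        stage_transforms = {
            "fit": self.train_transform,
            "validate": self.val_transform,
            "test": self.test_transform,
            "predict": self.predict_transform,
        }
        return stage_transforms.get(stage)


pre_processor = SimplePreProcessor(train_transform="resize+normalize")
assert pre_processor.transform_for("fit") == "resize+normalize"
assert pre_processor.transform_for("test") is None  # nothing registered for "test"
```

One behavioural difference worth noting: the removed code expanded `stage == "fit"` into both the train and val transforms, whereas the new mapping resolves `"fit"` to `self.train_transform` only, with `self.val_transform` applied under the separate `"validate"` stage.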