Skip to content

Commit

Permalink
Get stage transforms in setup of pre-processor
Browse files Browse the repository at this point in the history
Signed-off-by: Samet Akcay <samet.akcay@intel.com>
  • Loading branch information
samet-akcay committed Oct 22, 2024
1 parent 1b0483a commit a503be1
Showing 1 changed file with 19 additions and 12 deletions.
31 changes: 19 additions & 12 deletions src/anomalib/pre_processing/pre_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,18 +106,25 @@ def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> Non
stage: The stage (e.g., 'fit', 'train', 'val', 'test', 'predict').
"""
super().setup(trainer, pl_module, stage)
stage = TrainerFn(stage).value # This is to convert the stage to a string
stages = ["train", "val"] if stage == "fit" else [stage]
for current_stage in stages:
transform = getattr(self, f"{current_stage}_transform")
if transform:
if hasattr(trainer, "datamodule"):
set_datamodule_transform(trainer.datamodule, transform, current_stage)
elif hasattr(trainer, f"{current_stage}_dataloaders"):
set_dataloader_transform(getattr(trainer, f"{current_stage}_dataloaders"), transform)
else:
msg = f"Trainer does not have a datamodule or {current_stage}_dataloaders attribute"
raise ValueError(msg)
# Get stage transform
stage = TrainerFn(stage).value # Make sure ``stage`` is a str
stage_transforms = {
"fit": self.train_transform,
"validate": self.val_transform,
"test": self.test_transform,
"predict": self.predict_transform,
}
transform = stage_transforms.get(stage)

# Assign the transform to the datamodule or dataloaders
if transform:
if hasattr(trainer, "datamodule"):
set_datamodule_transform(trainer.datamodule, transform, stage)
elif hasattr(trainer, f"{stage}_dataloaders"):
set_dataloader_transform(getattr(trainer, f"{stage}_dataloaders"), transform)
else:
msg = f"Trainer does not have a datamodule or {stage}_dataloaders attribute"
raise ValueError(msg)

def forward(self, batch: torch.Tensor) -> torch.Tensor:
"""Apply transforms to the batch of tensors for inference.
Expand Down

0 comments on commit a503be1

Please sign in to comment.