diff --git a/src/anomalib/engine/engine.py b/src/anomalib/engine/engine.py index e7612e6e57..c0de7dbea1 100644 --- a/src/anomalib/engine/engine.py +++ b/src/anomalib/engine/engine.py @@ -8,7 +8,6 @@ from pathlib import Path from typing import Any -import torch from lightning.pytorch.callbacks import Callback from lightning.pytorch.loggers import Logger from lightning.pytorch.trainer import Trainer @@ -302,60 +301,6 @@ def _setup_dataset_task( ) data.task = self.task - @staticmethod - def _setup_transform( - model: AnomalyModule, - datamodule: AnomalibDataModule | None = None, - dataloaders: EVAL_DATALOADERS | TRAIN_DATALOADERS | None = None, - ckpt_path: Path | str | None = None, - ) -> None: - """Implements the logic for setting the transform at the start of each run. - - Any transform passed explicitly to the datamodule takes precedence. Otherwise, if a checkpoint path is provided, - we can load the transform from the checkpoint. If no transform is provided, we use the default transform from - the model. - - Args: - model (AnomalyModule): The model to assign the transform to. - datamodule (AnomalibDataModule | None): The datamodule to assign the transform from. - defaults to ``None``. - dataloaders (EVAL_DATALOADERS | TRAIN_DATALOADERS | None): Dataloaders to assign the transform to. - defaults to ``None``. - ckpt_path (str): The path to the checkpoint. - defaults to ``None``. - - Returns: - Transform: The transform loaded from the checkpoint. - """ - if isinstance(dataloaders, DataLoader): - dataloaders = [dataloaders] - - # get transform - if datamodule and datamodule.transform: - # a transform passed explicitly to the datamodule takes precedence - transform = datamodule.transform - elif dataloaders and any(getattr(dl.dataset, "transform", None) for dl in dataloaders): - # if dataloaders are provided, we use the transform from the first dataloader that has a transform - transform = next(dl.dataset.transform for dl in dataloaders if getattr(dl.dataset, "transform", None)) - elif ckpt_path is not None: - # if a checkpoint path is provided, we can load the transform from the checkpoint - checkpoint = torch.load(ckpt_path, map_location=model.device) - transform = checkpoint["transform"] - elif model.transform is None: - # if no transform is provided, we use the default transform from the model - image_size = datamodule.image_size if datamodule else None - transform = model.configure_transforms(image_size) - else: - transform = model.transform - - # update transform in model - model.set_transform(transform) - # The dataloaders don't have access to the trainer and/or model, so we need to set the transforms manually - if dataloaders: - for dataloader in dataloaders: - if not getattr(dataloader.dataset, "transform", None): - dataloader.dataset.transform = transform - def _setup_anomalib_callbacks(self, model: AnomalyModule) -> None: """Set up callbacks for the trainer.""" _callbacks: list[Callback] = [] @@ -471,7 +416,6 @@ def fit( ) self._setup_trainer(model) self._setup_dataset_task(train_dataloaders, val_dataloaders, datamodule) - self._setup_transform(model, datamodule=datamodule, ckpt_path=ckpt_path) if model.learning_type in {LearningType.ZERO_SHOT, LearningType.FEW_SHOT}: # if the model is zero-shot or few-shot, we only need to run validate for normalization and thresholding self.trainer.validate(model, val_dataloaders, datamodule=datamodule, ckpt_path=ckpt_path) @@ -525,7 +469,6 @@ def validate( if model: self._setup_trainer(model) self._setup_dataset_task(dataloaders) - self._setup_transform(model or self.model, datamodule=datamodule, ckpt_path=ckpt_path) return self.trainer.validate(model, dataloaders, ckpt_path, verbose, datamodule) def test( @@ -619,7 +562,6 @@ def test( raise RuntimeError(msg) self._setup_dataset_task(dataloaders) - self._setup_transform(model or self.model, datamodule=datamodule, ckpt_path=ckpt_path) if self._should_run_validation(model or self.model, ckpt_path): logger.info("Running validation before testing to collect normalization metrics and/or thresholds.") self.trainer.validate(model, dataloaders, None, verbose=False, datamodule=datamodule) @@ -724,7 +666,6 @@ def predict( dataloaders = dataloaders or None self._setup_dataset_task(dataloaders, datamodule) - self._setup_transform(model or self.model, datamodule=datamodule, dataloaders=dataloaders, ckpt_path=ckpt_path) if self._should_run_validation(model or self.model, ckpt_path): logger.info("Running validation before predicting to collect normalization metrics and/or thresholds.") @@ -794,7 +735,6 @@ def train( test_dataloaders, datamodule, ) - self._setup_transform(model, datamodule=datamodule, ckpt_path=ckpt_path) if model.learning_type in {LearningType.ZERO_SHOT, LearningType.FEW_SHOT}: # if the model is zero-shot or few-shot, we only need to run validate for normalization and thresholding self.trainer.validate(model, val_dataloaders, None, verbose=False, datamodule=datamodule) @@ -841,8 +781,7 @@ def export( Path: Path to the exported model. Raises: - ValueError: If Dataset, Datamodule, and transform are not provided. - TypeError: If path to the transform file is not a string or Path. + ValueError: If Dataset, Datamodule are not provided. CLI Usage: 1. To export as a torch ``.pt`` file you can run the following command.