Commit

Remove setup_transforms from Engine
Signed-off-by: Samet Akcay <samet.akcay@intel.com>
samet-akcay committed Oct 9, 2024
1 parent a133048 commit c748a0d
Showing 1 changed file with 1 addition and 62 deletions.
63 changes: 1 addition & 62 deletions src/anomalib/engine/engine.py
@@ -8,7 +8,6 @@
 from pathlib import Path
 from typing import Any
 
-import torch
 from lightning.pytorch.callbacks import Callback
 from lightning.pytorch.loggers import Logger
 from lightning.pytorch.trainer import Trainer
@@ -302,60 +301,6 @@ def _setup_dataset_task(
                             )
                             data.task = self.task
 
-    @staticmethod
-    def _setup_transform(
-        model: AnomalyModule,
-        datamodule: AnomalibDataModule | None = None,
-        dataloaders: EVAL_DATALOADERS | TRAIN_DATALOADERS | None = None,
-        ckpt_path: Path | str | None = None,
-    ) -> None:
-        """Implements the logic for setting the transform at the start of each run.
-
-        Any transform passed explicitly to the datamodule takes precedence. Otherwise, if a checkpoint path is provided,
-        we can load the transform from the checkpoint. If no transform is provided, we use the default transform from
-        the model.
-
-        Args:
-            model (AnomalyModule): The model to assign the transform to.
-            datamodule (AnomalibDataModule | None): The datamodule to assign the transform from.
-                defaults to ``None``.
-            dataloaders (EVAL_DATALOADERS | TRAIN_DATALOADERS | None): Dataloaders to assign the transform to.
-                defaults to ``None``.
-            ckpt_path (str): The path to the checkpoint.
-                defaults to ``None``.
-
-        Returns:
-            Transform: The transform loaded from the checkpoint.
-        """
-        if isinstance(dataloaders, DataLoader):
-            dataloaders = [dataloaders]
-
-        # get transform
-        if datamodule and datamodule.transform:
-            # a transform passed explicitly to the datamodule takes precedence
-            transform = datamodule.transform
-        elif dataloaders and any(getattr(dl.dataset, "transform", None) for dl in dataloaders):
-            # if dataloaders are provided, we use the transform from the first dataloader that has a transform
-            transform = next(dl.dataset.transform for dl in dataloaders if getattr(dl.dataset, "transform", None))
-        elif ckpt_path is not None:
-            # if a checkpoint path is provided, we can load the transform from the checkpoint
-            checkpoint = torch.load(ckpt_path, map_location=model.device)
-            transform = checkpoint["transform"]
-        elif model.transform is None:
-            # if no transform is provided, we use the default transform from the model
-            image_size = datamodule.image_size if datamodule else None
-            transform = model.configure_transforms(image_size)
-        else:
-            transform = model.transform
-
-        # update transform in model
-        model.set_transform(transform)
-        # The dataloaders don't have access to the trainer and/or model, so we need to set the transforms manually
-        if dataloaders:
-            for dataloader in dataloaders:
-                if not getattr(dataloader.dataset, "transform", None):
-                    dataloader.dataset.transform = transform
-
     def _setup_anomalib_callbacks(self, model: AnomalyModule) -> None:
         """Set up callbacks for the trainer."""
         _callbacks: list[Callback] = []
@@ -471,7 +416,6 @@ def fit(
         )
         self._setup_trainer(model)
         self._setup_dataset_task(train_dataloaders, val_dataloaders, datamodule)
-        self._setup_transform(model, datamodule=datamodule, ckpt_path=ckpt_path)
         if model.learning_type in {LearningType.ZERO_SHOT, LearningType.FEW_SHOT}:
             # if the model is zero-shot or few-shot, we only need to run validate for normalization and thresholding
             self.trainer.validate(model, val_dataloaders, datamodule=datamodule, ckpt_path=ckpt_path)
@@ -525,7 +469,6 @@ def validate(
         if model:
             self._setup_trainer(model)
         self._setup_dataset_task(dataloaders)
-        self._setup_transform(model or self.model, datamodule=datamodule, ckpt_path=ckpt_path)
         return self.trainer.validate(model, dataloaders, ckpt_path, verbose, datamodule)
 
     def test(
@@ -619,7 +562,6 @@ def test(
             raise RuntimeError(msg)
 
         self._setup_dataset_task(dataloaders)
-        self._setup_transform(model or self.model, datamodule=datamodule, ckpt_path=ckpt_path)
         if self._should_run_validation(model or self.model, ckpt_path):
             logger.info("Running validation before testing to collect normalization metrics and/or thresholds.")
             self.trainer.validate(model, dataloaders, None, verbose=False, datamodule=datamodule)
@@ -724,7 +666,6 @@ def predict(
         dataloaders = dataloaders or None
 
         self._setup_dataset_task(dataloaders, datamodule)
-        self._setup_transform(model or self.model, datamodule=datamodule, dataloaders=dataloaders, ckpt_path=ckpt_path)
 
         if self._should_run_validation(model or self.model, ckpt_path):
             logger.info("Running validation before predicting to collect normalization metrics and/or thresholds.")
@@ -794,7 +735,6 @@ def train(
             test_dataloaders,
             datamodule,
         )
-        self._setup_transform(model, datamodule=datamodule, ckpt_path=ckpt_path)
         if model.learning_type in {LearningType.ZERO_SHOT, LearningType.FEW_SHOT}:
             # if the model is zero-shot or few-shot, we only need to run validate for normalization and thresholding
             self.trainer.validate(model, val_dataloaders, None, verbose=False, datamodule=datamodule)
@@ -841,8 +781,7 @@ def export(
             Path: Path to the exported model.
 
         Raises:
-            ValueError: If Dataset, Datamodule, and transform are not provided.
-            TypeError: If path to the transform file is not a string or Path.
+            ValueError: If Dataset, Datamodule are not provided.
 
         CLI Usage:
             1. To export as a torch ``.pt`` file you can run the following command.
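For reference, the deleted _setup_transform helper resolved the run-time transform in a fixed order of precedence: an explicit transform on the datamodule, then a transform found on a dataloader's dataset, then the "transform" entry of the checkpoint, and finally the model's default from configure_transforms. With this commit, none of fit, validate, test, predict, or train performs that resolution any longer. Below is a minimal standalone sketch of the removed precedence logic; resolve_transform and its duck-typed arguments are illustrative stand-ins, not part of anomalib's API.

# Minimal sketch of the precedence the removed Engine._setup_transform implemented.
# `model`, `datamodule`, and `dataloaders` stand in for an AnomalyModule, an
# AnomalibDataModule, and torch DataLoaders; `resolve_transform` is a hypothetical name.
import torch
from torch.utils.data import DataLoader


def resolve_transform(model, datamodule=None, dataloaders=None, ckpt_path=None):
    if isinstance(dataloaders, DataLoader):
        dataloaders = [dataloaders]
    if datamodule is not None and getattr(datamodule, "transform", None):
        # an explicit datamodule transform takes precedence
        return datamodule.transform
    if dataloaders and any(getattr(dl.dataset, "transform", None) for dl in dataloaders):
        # otherwise, the first dataloader whose dataset carries a transform
        return next(dl.dataset.transform for dl in dataloaders if getattr(dl.dataset, "transform", None))
    if ckpt_path is not None:
        # otherwise, the transform stored in the checkpoint
        return torch.load(ckpt_path, map_location=model.device)["transform"]
    if model.transform is None:
        # finally, fall back to the model's default transform
        return model.configure_transforms(datamodule.image_size if datamodule else None)
    return model.transform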
