Skip to content

Commit

Permalink
Add preprocessor to AnomalyModule and models
Browse files Browse the repository at this point in the history
Signed-off-by: Samet Akcay <samet.akcay@intel.com>
  • Loading branch information
samet-akcay committed Oct 9, 2024
1 parent c748a0d commit 03a2a2e
Show file tree
Hide file tree
Showing 21 changed files with 201 additions and 113 deletions.
12 changes: 1 addition & 11 deletions src/anomalib/data/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from pathlib import Path

from torch.utils.data.dataset import Dataset
from torchvision.transforms.v2 import Transform

from anomalib.data import ImageBatch, ImageItem
from anomalib.data.utils import get_image_filenames, read_image
Expand All @@ -18,22 +17,18 @@ class PredictDataset(Dataset):
Args:
path (str | Path): Path to an image or image-folder.
transform (A.Compose | None, optional): Transform object describing the transforms that are
applied to the inputs.
image_size (int | tuple[int, int] | None, optional): Target image size
to resize the original image. Defaults to None.
"""

def __init__(
self,
path: str | Path,
transform: Transform | None = None,
image_size: int | tuple[int, int] = (256, 256),
) -> None:
super().__init__()

self.image_filenames = get_image_filenames(path)
self.transform = transform
self.image_size = image_size

def __len__(self) -> int:
Expand All @@ -44,13 +39,8 @@ def __getitem__(self, index: int) -> ImageItem:
"""Get the image based on the `index`."""
image_filename = self.image_filenames[index]
image = read_image(image_filename, as_tensor=True)
if self.transform:
image = self.transform(image)

return ImageItem(
image=image,
image_path=str(image_filename),
)
return ImageItem(image=image, image_path=str(image_filename))

@property
def collate_fn(self) -> Callable:
Expand Down
35 changes: 22 additions & 13 deletions src/anomalib/models/components/base/anomaly_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import logging
from abc import ABC, abstractmethod
from collections import OrderedDict
from collections.abc import Sequence
from pathlib import Path
from typing import TYPE_CHECKING, Any

Expand All @@ -22,12 +23,11 @@
from anomalib.data import Batch, InferenceBatch
from anomalib.metrics.threshold import Threshold
from anomalib.post_processing import OneClassPostProcessor, PostProcessor
from anomalib.pre_processing import PreProcessor

from .export_mixin import ExportMixin

if TYPE_CHECKING:
from lightning.pytorch.callbacks import Callback

from anomalib.metrics import AnomalibMetricCollection

logger = logging.getLogger(__name__)
Expand All @@ -39,7 +39,11 @@ class AnomalyModule(ExportMixin, pl.LightningModule, ABC):
Acts as a base class for all the Anomaly Modules in the library.
"""

def __init__(self, post_processor: PostProcessor | None = None) -> None:
def __init__(
self,
pre_processor: PreProcessor | None = None,
post_processor: PostProcessor | None = None,
) -> None:
super().__init__()
logger.info("Initializing %s model.", self.__class__.__name__)

Expand All @@ -51,6 +55,7 @@ def __init__(self, post_processor: PostProcessor | None = None) -> None:
self.image_metrics: AnomalibMetricCollection
self.pixel_metrics: AnomalibMetricCollection

self.pre_processor = pre_processor or self.configure_pre_processor()
self.post_processor = post_processor or self.default_post_processor()

self._transform: Transform | None = None
Expand Down Expand Up @@ -79,6 +84,10 @@ def _setup(self) -> None:
initialization.
"""

def configure_callbacks(self) -> Sequence[Callback] | Callback:
"""Configure default callbacks for AnomalyModule."""
return [self.pre_processor]

def forward(self, batch: torch.Tensor, *args, **kwargs) -> InferenceBatch:
"""Perform the forward-pass by passing input tensor to the module.
Expand Down Expand Up @@ -183,23 +192,23 @@ def set_transform(self, transform: Transform) -> None:
"""Update the transform linked to the model instance."""
self._transform = transform

def configure_transforms(self, image_size: tuple[int, int] | None = None) -> Transform: # noqa: PLR6301
"""Default transforms.
def configure_pre_processor(self, image_size: tuple[int, int] | None = None) -> PreProcessor: # noqa: PLR6301
"""Configure the pre-processor.
The default transform is resize to 256x256 and normalize to ImageNet stats. Individual models can override
this method to provide custom transforms.
The default pre-processor is resize to 256x256 and normalize to ImageNet stats. Individual models can override
this method to provide custom transforms and pre-processing pipelines.
"""
logger.warning(
"No implementation of `configure_transforms` was provided in the Lightning model. Using default "
"No implementation of `configure_pre_processor` was provided in the Lightning model. Using default "
"transforms from the base class. This may not be suitable for your use case. Please override "
"`configure_transforms` in your model.",
"`configure_pre_processor` in your model.",
)
image_size = image_size or (256, 256)
return Compose(
[
return PreProcessor(
transform=Compose([
Resize(image_size, antialias=True),
Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
],
]),
)

def default_post_processor(self) -> PostProcessor:
Expand All @@ -220,7 +229,7 @@ def input_size(self) -> tuple[int, int] | None:
The effective input size is the size of the input tensor after the transform has been applied. If the transform
is not set, or if the transform does not change the shape of the input tensor, this method will return None.
"""
transform = self.transform or self.configure_transforms()
transform = self.transform or self.configure_pre_processor()
if transform is None:
return None
dummy_input = torch.zeros(1, 3, 1, 1)
Expand Down
7 changes: 6 additions & 1 deletion src/anomalib/models/image/cfa/lightning_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from anomalib import LearningType
from anomalib.data import Batch
from anomalib.models.components import AnomalyModule
from anomalib.pre_processing import PreProcessor

from .loss import CfaLoss
from .torch_model import CfaModel
Expand All @@ -42,6 +43,9 @@ class Cfa(AnomalyModule):
Defaults to ``3``.
radius (float): Radius of the hypersphere to search the soft boundary.
Defaults to ``1e-5``.
pre_processor (PreProcessor, optional): Pre-processor for the model.
This is used to pre-process the input data before it is passed to the model.
Defaults to ``None``.
"""

def __init__(
Expand All @@ -52,8 +56,9 @@ def __init__(
num_nearest_neighbors: int = 3,
num_hard_negative_features: int = 3,
radius: float = 1e-5,
pre_processor: PreProcessor | None = None,
) -> None:
super().__init__()
super().__init__(pre_processor=pre_processor)
self.model: CfaModel = CfaModel(
backbone=backbone,
gamma_c=gamma_c,
Expand Down
4 changes: 3 additions & 1 deletion src/anomalib/models/image/cflow/lightning_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from anomalib import LearningType
from anomalib.data import Batch
from anomalib.models.components import AnomalyModule
from anomalib.pre_processing import PreProcessor

from .torch_model import CflowModel
from .utils import get_logp, positional_encoding_2d
Expand Down Expand Up @@ -57,6 +58,7 @@ class Cflow(AnomalyModule):

def __init__(
self,
pre_processor: PreProcessor | None = None,
backbone: str = "wide_resnet50_2",
layers: Sequence[str] = ("layer2", "layer3", "layer4"),
pre_trained: bool = True,
Expand All @@ -68,7 +70,7 @@ def __init__(
permute_soft: bool = False,
lr: float = 0.0001,
) -> None:
super().__init__()
super().__init__(pre_processor=pre_processor)

self.model: CflowModel = CflowModel(
backbone=backbone,
Expand Down
4 changes: 3 additions & 1 deletion src/anomalib/models/image/csflow/lightning_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from anomalib import LearningType
from anomalib.data import Batch
from anomalib.models.components import AnomalyModule
from anomalib.pre_processing import PreProcessor

from .loss import CsFlowLoss
from .torch_model import CsFlowModel
Expand Down Expand Up @@ -44,8 +45,9 @@ def __init__(
n_coupling_blocks: int = 4,
clamp: int = 3,
num_channels: int = 3,
pre_processor: PreProcessor | None = None,
) -> None:
super().__init__()
super().__init__(pre_processor=pre_processor)

self.cross_conv_hidden_channels = cross_conv_hidden_channels
self.n_coupling_blocks = n_coupling_blocks
Expand Down
4 changes: 3 additions & 1 deletion src/anomalib/models/image/dfkde/lightning_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from anomalib.data import Batch
from anomalib.models.components import AnomalyModule, MemoryBankMixin
from anomalib.models.components.classification import FeatureScalingMethod
from anomalib.pre_processing import PreProcessor

from .torch_model import DfkdeModel

Expand Down Expand Up @@ -46,8 +47,9 @@ def __init__(
n_pca_components: int = 16,
feature_scaling_method: FeatureScalingMethod = FeatureScalingMethod.SCALE,
max_training_points: int = 40000,
pre_processor: PreProcessor | None = None,
) -> None:
super().__init__()
super().__init__(pre_processor=pre_processor)

self.model = DfkdeModel(
layers=layers,
Expand Down
7 changes: 6 additions & 1 deletion src/anomalib/models/image/dfm/lightning_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from anomalib import LearningType
from anomalib.data import Batch
from anomalib.models.components import AnomalyModule, MemoryBankMixin
from anomalib.pre_processing import PreProcessor

from .torch_model import DFMModel

Expand All @@ -37,6 +38,9 @@ class Dfm(MemoryBankMixin, AnomalyModule):
Defaults to ``0.97``.
score_type (str, optional): Scoring type. Options are `fre` and `nll`.
Defaults to ``fre``.
pre_processor (PreProcessor, optional): Pre-processor for the model.
This is used to pre-process the input data before it is passed to the model.
Defaults to ``None``.
"""

def __init__(
Expand All @@ -47,8 +51,9 @@ def __init__(
pooling_kernel_size: int = 4,
pca_level: float = 0.97,
score_type: str = "fre",
pre_processor: PreProcessor | None = None,
) -> None:
super().__init__()
super().__init__(pre_processor=pre_processor)

self.model: DFMModel = DFMModel(
backbone=backbone,
Expand Down
7 changes: 6 additions & 1 deletion src/anomalib/models/image/draem/lightning_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from anomalib.data import Batch
from anomalib.data.utils import Augmenter
from anomalib.models.components import AnomalyModule
from anomalib.pre_processing import PreProcessor

from .loss import DraemLoss
from .torch_model import DraemModel
Expand All @@ -35,6 +36,9 @@ class Draem(AnomalyModule):
anomaly_source_path (str | None): Path to folder that contains the anomaly source images. Random noise will
be used if left empty.
Defaults to ``None``.
pre_processor (PreProcessor, optional): Pre-processor for the model.
This is used to pre-process the input data before it is passed to the model.
Defaults to ``None``.
"""

def __init__(
Expand All @@ -43,8 +47,9 @@ def __init__(
sspcab_lambda: float = 0.1,
anomaly_source_path: str | None = None,
beta: float | tuple[float, float] = (0.1, 1.0),
pre_processor: PreProcessor | None = None,
) -> None:
super().__init__()
super().__init__(pre_processor=pre_processor)

self.augmenter = Augmenter(anomaly_source_path, beta=beta)
self.model = DraemModel(sspcab=enable_sspcab)
Expand Down
13 changes: 11 additions & 2 deletions src/anomalib/models/image/dsr/lightning_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from anomalib.models.image.dsr.anomaly_generator import DsrAnomalyGenerator
from anomalib.models.image.dsr.loss import DsrSecondStageLoss, DsrThirdStageLoss
from anomalib.models.image.dsr.torch_model import DsrModel
from anomalib.pre_processing import PreProcessor

__all__ = ["Dsr"]

Expand All @@ -39,10 +40,18 @@ class Dsr(AnomalyModule):
Args:
latent_anomaly_strength (float): Strength of the generated anomalies in the latent space. Defaults to 0.2
upsampling_train_ratio (float): Ratio of training steps for the upsampling module. Defaults to 0.7
pre_processor (PreProcessor, optional): Pre-processor for the model.
This is used to pre-process the input data before it is passed to the model.
Defaults to ``None``.
"""

def __init__(self, latent_anomaly_strength: float = 0.2, upsampling_train_ratio: float = 0.7) -> None:
super().__init__()
def __init__(
self,
latent_anomaly_strength: float = 0.2,
upsampling_train_ratio: float = 0.7,
pre_processor: PreProcessor | None = None,
) -> None:
super().__init__(pre_processor=pre_processor)

self.automatic_optimization = False
self.upsampling_train_ratio = upsampling_train_ratio
Expand Down
26 changes: 14 additions & 12 deletions src/anomalib/models/image/efficient_ad/lightning_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,13 @@
from lightning.pytorch.utilities.types import STEP_OUTPUT
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision.transforms.v2 import CenterCrop, Compose, Normalize, RandomGrayscale, Resize, ToTensor, Transform
from torchvision.transforms.v2 import CenterCrop, Compose, Normalize, RandomGrayscale, Resize, ToTensor

from anomalib import LearningType
from anomalib.data import Batch
from anomalib.data.utils import DownloadInfo, download_and_extract
from anomalib.models.components import AnomalyModule
from anomalib.pre_processing import PreProcessor

from .torch_model import EfficientAdModel, EfficientAdModelSize, reduce_tensor_elems

Expand Down Expand Up @@ -58,6 +59,9 @@ class EfficientAd(AnomalyModule):
pad_maps (bool): relevant if padding is set to False. In this case, pad_maps = True pads the
output anomaly maps so that their size matches the size in the padding = True case.
Defaults to ``True``.
pre_processor (PreProcessor, optional): Pre-processor for the model.
This is used to pre-process the input data before it is passed to the model.
Defaults to ``None``.
"""

def __init__(
Expand All @@ -69,8 +73,9 @@ def __init__(
weight_decay: float = 0.00001,
padding: bool = False,
pad_maps: bool = True,
pre_processor: PreProcessor | None = None,
) -> None:
super().__init__()
super().__init__(pre_processor=pre_processor)

self.imagenet_dir = Path(imagenet_dir)
if not isinstance(model_size, EfficientAdModelSize):
Expand Down Expand Up @@ -203,6 +208,13 @@ def _get_quantiles_of_maps(self, maps: list[torch.Tensor]) -> tuple[torch.Tensor
qb = torch.quantile(maps_flat, q=0.995).to(self.device)
return qa, qb

@staticmethod
def configure_pre_processor(image_size: tuple[int, int] | None = None) -> PreProcessor:
"""Default transform for EfficientAd. Imagenet normalization applied in forward."""
image_size = image_size or (256, 256)
transform = Compose([Resize(image_size, antialias=True)])
return PreProcessor(transform=transform)

def configure_optimizers(self) -> torch.optim.Optimizer:
"""Configure optimizers."""
optimizer = torch.optim.Adam(
Expand Down Expand Up @@ -318,13 +330,3 @@ def learning_type(self) -> LearningType:
LearningType: Learning type of the model.
"""
return LearningType.ONE_CLASS

@staticmethod
def configure_transforms(image_size: tuple[int, int] | None = None) -> Transform:
"""Default transform for EfficientAd. Imagenet normalization applied in forward."""
image_size = image_size or (256, 256)
return Compose(
[
Resize(image_size, antialias=True),
],
)
Loading

0 comments on commit 03a2a2e

Please sign in to comment.