Skip to content

Commit

Permalink
Remove exportable transform from anomaly module and move to pre-proce…
Browse files Browse the repository at this point in the history
…ssor

Signed-off-by: Samet Akcay <samet.akcay@intel.com>
  • Loading branch information
samet-akcay committed Oct 15, 2024
1 parent 785d64f commit b798243
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 14 deletions.
4 changes: 3 additions & 1 deletion src/anomalib/deploy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@
from anomalib.data.transforms import ExportableCenterCrop


def make_transform_exportable(transform: Transform) -> Transform:
def get_exportable_transform(transform: Transform | None) -> Transform | None:
"""Get exportable transform.
Some transforms are not supported by ONNX/OpenVINO, so we need to replace them with exportable versions.
"""
if transform is None:
return None
transform = disable_antialiasing(transform)
return convert_centercrop(transform)

Expand Down
2 changes: 0 additions & 2 deletions src/anomalib/models/components/base/anomaly_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,6 @@ def forward(self, batch: torch.Tensor, *args, **kwargs) -> InferenceBatch:
Tensor: Output tensor from the model.
"""
del args, kwargs # These variables are not used.
if self.exportable_transform:
batch = self.exportable_transform(batch)
batch = self.model(batch)
return self.post_processor(batch) if self.post_processor else batch

Expand Down
8 changes: 0 additions & 8 deletions src/anomalib/models/components/base/export_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,10 @@
from lightning.pytorch import LightningModule
from torch import nn
from torchmetrics import Metric
from torchvision.transforms.v2 import Transform

from anomalib import TaskType
from anomalib.data import AnomalibDataModule
from anomalib.deploy.export import CompressionType, ExportType
from anomalib.deploy.utils import make_transform_exportable
from anomalib.metrics import create_metric_collection
from anomalib.pre_processing import PreProcessor
from anomalib.utils.exceptions import try_import
Expand Down Expand Up @@ -440,12 +438,6 @@ def _get_metadata(

return metadata

@property
def exportable_transform(self) -> Transform | None:
"""Return the exportable transform."""
transform = self.pre_processor.test_transform
return make_transform_exportable(transform) if transform else None


def _write_metadata_to_json(metadata: dict[str, Any], export_root: Path) -> None:
"""Write metadata to json file.
Expand Down
7 changes: 4 additions & 3 deletions src/anomalib/pre_processing/pre_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from torchvision.transforms.v2 import Transform

from anomalib.data.dataclasses.torch.base import Batch
from anomalib.deploy.utils import get_exportable_transform


class PreProcessor(nn.Module, Callback):
Expand All @@ -31,9 +32,9 @@ def __init__(
)
raise ValueError(msg)

self.train_transform = train_transform or transform
self.val_transform = val_transform or transform
self.test_transform = test_transform or transform
self.train_transform = get_exportable_transform(train_transform or transform)
self.val_transform = get_exportable_transform(val_transform or transform)
self.test_transform = get_exportable_transform(test_transform or transform)

def on_train_batch_start(
self,
Expand Down

0 comments on commit b798243

Please sign in to comment.