Skip to content

Commit

Permalink
Remove transforms from datasets
Browse files Browse the repository at this point in the history
Signed-off-by: Samet Akcay <samet.akcay@intel.com>
  • Loading branch information
samet-akcay committed Oct 9, 2024
1 parent 7738e38 commit a133048
Show file tree
Hide file tree
Showing 14 changed files with 19 additions and 113 deletions.
18 changes: 4 additions & 14 deletions src/anomalib/data/datasets/base/depth.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import torch
from PIL import Image
from torchvision.transforms.functional import to_tensor
from torchvision.transforms.v2 import Transform
from torchvision.tv_tensors import Mask

from anomalib import TaskType
Expand All @@ -24,14 +23,10 @@ class AnomalibDepthDataset(AnomalibDataset, ABC):
Args:
task (str): Task type, either 'classification' or 'segmentation'
transform (Transform, optional): Transforms that should be applied to the input images.
Defaults to ``None``.
"""

def __init__(self, task: TaskType, transform: Transform | None = None) -> None:
super().__init__(task, transform)

self.transform = transform
def __init__(self, task: TaskType) -> None:
super().__init__(task)

def __getitem__(self, index: int) -> DepthItem:
"""Return rgb image, depth image and mask.
Expand All @@ -52,9 +47,7 @@ def __getitem__(self, index: int) -> DepthItem:
item = {"image_path": image_path, "depth_path": depth_path, "label": label_index}

if self.task == TaskType.CLASSIFICATION:
item["image"], item["depth_image"] = (
self.transform(image, depth_image) if self.transform else (image, depth_image)
)
item["image"], item["depth_image"] = image, depth_image
elif self.task == TaskType.SEGMENTATION:
# Only Anomalous (1) images have masks in anomaly datasets
# Therefore, create empty mask for Normal (0) images.
Expand All @@ -63,11 +56,8 @@ def __getitem__(self, index: int) -> DepthItem:
if label_index == LabelName.NORMAL
else Mask(to_tensor(Image.open(mask_path)).squeeze())
)
item["image"], item["depth_image"], item["mask"] = (
self.transform(image, depth_image, mask) if self.transform else (image, depth_image, mask)
)
item["image"], item["depth_image"], item["mask"] = image, depth_image, mask
item["mask_path"] = mask_path

else:
msg = f"Unknown task type: {self.task}"
raise ValueError(msg)
Expand Down
8 changes: 2 additions & 6 deletions src/anomalib/data/datasets/base/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import torch
from pandas import DataFrame
from torch.utils.data import Dataset
from torchvision.transforms.v2 import Transform
from torchvision.tv_tensors import Mask

from anomalib import TaskType
Expand Down Expand Up @@ -58,14 +57,11 @@ class AnomalibDataset(Dataset, ABC):
Args:
task (str): Task type, either 'classification' or 'segmentation'
transform (Transform, optional): Transforms that should be applied to the input images.
Defaults to ``None``.
"""

def __init__(self, task: TaskType | str, transform: Transform | None = None) -> None:
def __init__(self, task: TaskType | str) -> None:
super().__init__()
self.task = TaskType(task)
self.transform = transform
self._samples: DataFrame | None = None
self._category: str | None = None

Expand Down Expand Up @@ -170,7 +166,7 @@ def __getitem__(self, index: int) -> DatasetItem:
item = {"image_path": image_path, "gt_label": label_index}

if self.task == TaskType.CLASSIFICATION:
item["image"] = self.transform(image) if self.transform else image
item["image"] = image
elif self.task == TaskType.SEGMENTATION:
# Only Anomalous (1) images have masks in anomaly datasets
# Therefore, create empty mask for Normal (0) images.
Expand Down
13 changes: 1 addition & 12 deletions src/anomalib/data/datasets/base/video.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,7 @@

import torch
from pandas import DataFrame
from torchvision.transforms.v2 import Transform
from torchvision.transforms.v2.functional import to_dtype, to_dtype_video
from torchvision.tv_tensors import Mask

from anomalib import TaskType
from anomalib.data.dataclasses import VideoBatch, VideoItem
Expand Down Expand Up @@ -39,8 +37,6 @@ class AnomalibVideoDataset(AnomalibDataset, ABC):
task (str): Task type, either 'classification' or 'segmentation'
clip_length_in_frames (int): Number of video frames in each clip.
frames_between_clips (int): Number of frames between each consecutive video clip.
transform (Transform, optional): Transforms that should be applied to the input clips.
Defaults to ``None``.
target_frame (VideoTargetFrame): Specifies the target frame in the video clip, used for ground truth retrieval.
Defaults to ``VideoTargetFrame.LAST``.
"""
Expand All @@ -50,14 +46,12 @@ def __init__(
task: TaskType,
clip_length_in_frames: int,
frames_between_clips: int,
transform: Transform | None = None,
target_frame: VideoTargetFrame = VideoTargetFrame.LAST,
) -> None:
super().__init__(task, transform)
super().__init__(task)

self.clip_length_in_frames = clip_length_in_frames
self.frames_between_clips = frames_between_clips
self.transform = transform

self.indexer: ClipsIndexer | None = None
self.indexer_cls: Callable | None = None
Expand Down Expand Up @@ -153,13 +147,8 @@ def __getitem__(self, index: int) -> VideoItem:
# include the untransformed image for visualization
item.original_image = to_dtype(item.image, torch.uint8, scale=True)

# apply transforms
if item.gt_mask is not None:
if self.transform:
item.image, item.gt_mask = self.transform(item.image, Mask(item.gt_mask))
item.gt_label = torch.Tensor([1 in frame for frame in item.gt_mask]).int().squeeze(0)
elif self.transform:
item.image = self.transform(item.image)

# squeeze temporal dimensions in case clip length is 1
item.image = item.image.squeeze(0)
Expand Down
7 changes: 1 addition & 6 deletions src/anomalib/data/datasets/depth/folder_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from pathlib import Path

from pandas import DataFrame, isna
from torchvision.transforms.v2 import Transform

from anomalib import TaskType
from anomalib.data.datasets.base.depth import AnomalibDepthDataset
Expand All @@ -24,7 +23,6 @@ class Folder3DDataset(AnomalibDepthDataset):
Args:
name (str): Name of the dataset.
task (TaskType): Task type (``classification``, ``detection`` or ``segmentation``).
transform (Transform): Transforms that should be applied to the input images.
normal_dir (str | Path): Path to the directory containing normal images.
root (str | Path | None): Root folder of the dataset.
Defaults to ``None``.
Expand All @@ -45,8 +43,6 @@ class Folder3DDataset(AnomalibDepthDataset):
normal_test_depth_dir (str | Path | None, optional): Path to the directory containing
normal depth images for the test dataset. Normal test images will be a split of `normal_dir` if `None`.
Defaults to ``None``.
transform (Transform, optional): Transforms that should be applied to the input images.
Defaults to ``None``.
split (str | Split | None): Fixed subset split that follows from folder structure on file system.
Choose from [Split.FULL, Split.TRAIN, Split.TEST]
Defaults to ``None``.
Expand All @@ -70,11 +66,10 @@ def __init__(
normal_depth_dir: str | Path | None = None,
abnormal_depth_dir: str | Path | None = None,
normal_test_depth_dir: str | Path | None = None,
transform: Transform | None = None,
split: str | Split | None = None,
extensions: tuple[str, ...] | None = None,
) -> None:
super().__init__(task, transform)
super().__init__(task)

self._name = name
self.split = split
Expand Down
6 changes: 1 addition & 5 deletions src/anomalib/data/datasets/depth/mvtec_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
from pathlib import Path

from pandas import DataFrame
from torchvision.transforms.v2 import Transform

from anomalib import TaskType
from anomalib.data.datasets.base.depth import AnomalibDepthDataset
Expand All @@ -43,8 +42,6 @@ class MVTec3DDataset(AnomalibDepthDataset):
Defaults to ``"./datasets/MVTec3D"``.
category (str): Sub-category of the dataset, e.g. 'bagel'
Defaults to ``"bagel"``.
transform (Transform, optional): Transforms that should be applied to the input images.
Defaults to ``None``.
split (str | Split | None): Split of the dataset, usually Split.TRAIN or Split.TEST
Defaults to ``None``.
"""
Expand All @@ -54,10 +51,9 @@ def __init__(
task: TaskType,
root: Path | str = "./datasets/MVTec3D",
category: str = "bagel",
transform: Transform | None = None,
split: str | Split | None = None,
) -> None:
super().__init__(task=task, transform=transform)
super().__init__(task)

self.root_category = Path(root) / Path(category)
self.split = split
Expand Down
9 changes: 1 addition & 8 deletions src/anomalib/data/datasets/image/btech.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@

import pandas as pd
from pandas.core.frame import DataFrame
from torchvision.transforms.v2 import Transform

from anomalib import TaskType
from anomalib.data.datasets.base.image import AnomalibDataset
Expand All @@ -28,19 +27,14 @@ class BTechDataset(AnomalibDataset):
Args:
root: Path to the BTech dataset
category: Name of the BTech category.
transform (Transform, optional): Transforms that should be applied to the input images.
Defaults to ``None``.
split: 'train', 'val' or 'test'
task: ``classification``, ``detection`` or ``segmentation``
create_validation_set: Create a validation subset in addition to the train and test subsets
Examples:
>>> from anomalib.data.image.btech import BTechDataset
>>> from anomalib.data.utils.transforms import get_transforms
>>> transform = get_transforms(image_size=256)
>>> dataset = BTechDataset(
... task="classification",
... transform=transform,
... root='./datasets/BTech',
... category='01',
... )
Expand Down Expand Up @@ -69,11 +63,10 @@ def __init__(
self,
root: str | Path,
category: str,
transform: Transform | None = None,
split: str | Split | None = None,
task: TaskType | str = TaskType.SEGMENTATION,
) -> None:
super().__init__(task, transform)
super().__init__(task)

self.root_category = Path(root) / category
self.split = split
Expand Down
14 changes: 2 additions & 12 deletions src/anomalib/data/datasets/image/folder.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from pathlib import Path

from pandas import DataFrame
from torchvision.transforms.v2 import Transform

from anomalib import TaskType
from anomalib.data.datasets.base.image import AnomalibDataset
Expand All @@ -27,8 +26,6 @@ class FolderDataset(AnomalibDataset):
Args:
name (str): Name of the dataset. This is used to name the datamodule, especially when logging/saving.
task (TaskType): Task type (``classification``, ``detection`` or ``segmentation``).
transform (Transform, optional): Transforms that should be applied to the input images.
Defaults to ``None``.
normal_dir (str | Path | Sequence): Path to the directory containing normal images.
root (str | Path | None): Root folder of the dataset.
Defaults to ``None``.
Expand All @@ -52,20 +49,14 @@ class FolderDataset(AnomalibDataset):
Examples:
Assume that we would like to use this ``FolderDataset`` to create a dataset from a folder for a classification
task. We could first create the transforms,
>>> from anomalib.data.utils import InputNormalizationMethod, get_transforms
>>> transform = get_transforms(image_size=256, normalization=InputNormalizationMethod.NONE)
We could then create the dataset as follows,
task.
.. code-block:: python
folder_dataset_classification_train = FolderDataset(
normal_dir=dataset_root / "good",
abnormal_dir=dataset_root / "crack",
split="train",
transform=transform,
task=TaskType.CLASSIFICATION,
)
Expand All @@ -76,15 +67,14 @@ def __init__(
name: str,
task: TaskType,
normal_dir: str | Path | Sequence[str | Path],
transform: Transform | None = None,
root: str | Path | None = None,
abnormal_dir: str | Path | Sequence[str | Path] | None = None,
normal_test_dir: str | Path | Sequence[str | Path] | None = None,
mask_dir: str | Path | Sequence[str | Path] | None = None,
split: str | Split | None = None,
extensions: tuple[str, ...] | None = None,
) -> None:
super().__init__(task, transform)
super().__init__(task)

self._name = name
self.split = split
Expand Down
6 changes: 1 addition & 5 deletions src/anomalib/data/datasets/image/kolektor.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
from cv2 import imread
from pandas import DataFrame
from sklearn.model_selection import train_test_split
from torchvision.transforms.v2 import Transform

from anomalib import TaskType
from anomalib.data.datasets import AnomalibDataset
Expand All @@ -38,8 +37,6 @@ class KolektorDataset(AnomalibDataset):
task (TaskType): Task type, ``classification``, ``detection`` or ``segmentation``
root (Path | str): Path to the root of the dataset
Defaults to ``./datasets/kolektor``.
transform (Transform, optional): Transforms that should be applied to the input images.
Defaults to ``None``.
split (str | Split | None): Split of the dataset, usually Split.TRAIN or Split.TEST
Defaults to ``None``.
"""
Expand All @@ -48,10 +45,9 @@ def __init__(
self,
task: TaskType,
root: Path | str = "./datasets/kolektor",
transform: Transform | None = None,
split: str | Split | None = None,
) -> None:
super().__init__(task=task, transform=transform)
super().__init__(task)

self.root = root
self.split = split
Expand Down
9 changes: 1 addition & 8 deletions src/anomalib/data/datasets/image/mvtec.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
from pathlib import Path

from pandas import DataFrame
from torchvision.transforms.v2 import Transform

from anomalib import TaskType
from anomalib.data.datasets.base import AnomalibDataset
Expand Down Expand Up @@ -65,21 +64,16 @@ class MVTecDataset(AnomalibDataset):
Defaults to ``./datasets/MVTec``.
category (str): Sub-category of the dataset, e.g. 'bottle'
Defaults to ``bottle``.
transform (Transform, optional): Transforms that should be applied to the input images.
Defaults to ``None``.
split (str | Split | None): Split of the dataset, usually Split.TRAIN or Split.TEST
Defaults to ``None``.
Examples:
.. code-block:: python
from anomalib.data.image.mvtec import MVTecDataset
from anomalib.data.utils.transforms import get_transforms
transform = get_transforms(image_size=256)
dataset = MVTecDataset(
task="classification",
transform=transform,
root='./datasets/MVTec',
category='zipper',
)
Expand Down Expand Up @@ -110,10 +104,9 @@ def __init__(
task: TaskType,
root: Path | str = "./datasets/MVTec",
category: str = "bottle",
transform: Transform | None = None,
split: str | Split | None = None,
) -> None:
super().__init__(task=task, transform=transform)
super().__init__(task)

self.root_category = Path(root) / Path(category)
self.category = category
Expand Down
Loading

0 comments on commit a133048

Please sign in to comment.