diff --git a/docs/source/images/tiled_ensemble/ensemble_flow.png b/docs/source/images/tiled_ensemble/ensemble_flow.png new file mode 100644 index 0000000000..7a5a81fa79 Binary files /dev/null and b/docs/source/images/tiled_ensemble/ensemble_flow.png differ diff --git a/docs/source/markdown/guides/how_to/pipelines/custom_pipeline.md b/docs/source/markdown/guides/how_to/pipelines/custom_pipeline.md new file mode 100644 index 0000000000..ed3d66f81d --- /dev/null +++ b/docs/source/markdown/guides/how_to/pipelines/custom_pipeline.md @@ -0,0 +1,254 @@ +# Pipelines + +This guide demonstrates how to create a [Pipeline](../../reference/pipelines/index.md) for your custom task. + +A pipeline is made up of runners. These runners are responsible for running a single type of job. A job is the smallest independent unit of work, such as training a model or a statistical comparison of the outputs of two models. Each job should be designed to be independent of other jobs so that it is agnostic to the runner that is running it. This ensures that the job can be run in parallel or serially without any changes to the job itself. The runner does not directly instantiate a job but rather has a job generator that generates the job based on the configuration. This generator is responsible for parsing the config and generating the job. + +## Bird's-Eye View + +In this guide we are going to create a dummy significant parameter search pipeline. The pipeline will have two jobs. The first job trains a model and computes the metric. The second job computes the significance of the parameters to the final score using Shapley values. The final output of the pipeline is a plot that shows the contribution of each parameter to the final score. This will help teach you how to create a pipeline, a job, a job generator, and how to expose it to the `anomalib` CLI. The pipeline is going to be named `experiment`. 
So by the end of this you will be able to generate a significance plot using + +```{literalinclude} ../../../../snippets/pipelines/dummy/anomalib_cli.txt +:language: bash +``` + +The final directory structure will look as follows: + +```{literalinclude} ../../../../snippets/pipelines/dummy/src_dir_structure.txt + +``` + +```{literalinclude} ../../../../snippets/pipelines/dummy/tools_dir_structure.txt +:language: bash +``` + +## Creating the Jobs + +Let's first look at the base class for the [jobs](../../reference/pipelines/base/job.md). It has a few methods defined. + +- The `run` method is the main method that is called by the runner. This is where we will train the model and return the model metrics. +- The `collect` method is used to gather the results from all the runs and collate them. This is handy as we want to pass a single object to the next job that contains details of all the runs including the final score. +- The `save` method is used to write any artifacts to the disk. It accepts the gathered results as a parameter. This is useful in a variety of situations, for example when we want to write the results to a CSV file or write the raw anomaly maps for further processing. + +Let's create the first job that trains the model and computes the metric. Since it is a dummy example, we will just return a random number as the metric. + +```python +class TrainJob(Job): + name = "train" + + def __init__(self, lr: float, backbone: str, stride: int): + self.lr = lr + self.backbone = backbone + self.stride = stride + + def run(self, task_id: int | None = None) -> dict: + print(f"Training with lr: {self.lr}, backbone: {self.backbone}, stride: {self.stride}") + time.sleep(2) + score = np.random.uniform(0.7, 1.0) + return {"lr": self.lr, "backbone": self.backbone, "stride": self.stride, "score": score} +``` + +Ignore the `task_id` for now. It is used for parallel jobs. We will come back to it later. 
+ +````{note} +The `name` attribute is important and is used to identify the arguments in the job config file. +So, in our case the config `yaml` file will contain an entry like this: + +```yaml +... +train: + lr: + backbone: + stride: +... +```` + +Of course, it is up to us to choose what parameters should be shown under the `train` key. + +Let's also add the `collect` method so that we return a nice dict object that can be used by the next job. + +```python +def collect(results: list[dict]) -> dict: + output: dict = {} + for key in results[0]: + output[key] = [] + for result in results: + for key, value in result.items(): + output[key].append(value) + return output +``` + +We can also define a `save` method that writes the dictionary as a csv file. + +```python +@staticmethod +def save(results: dict) -> None: + """Save results in a csv file.""" + results_df = pd.DataFrame(results) + file_path = Path("runs") / TrainJob.name + file_path.mkdir(parents=True, exist_ok=True) + results_df.to_csv(file_path / "results.csv", index=False) +``` + +The entire job class is shown below. + +```{literalinclude} ../../../../snippets/pipelines/dummy/train_job.txt +:language: python +``` + +Now we need a way to generate this job when the pipeline is run. To do this we need to subclass the [JobGenerator](../../reference/pipelines/base/generator.md) class. + +The job generator is the actual object that is attached to a runner and is responsible for parsing the configuration and generating jobs. It has two methods that need to be implemented. + +- `generate_job`: This method accepts the configuration as a dictionary and, optionally, the results of the previous job. For the train job, we don't need results for previous jobs, so we will ignore it. +- `job_class`: This holds the reference to the class of the job that the generator will yield. 
It is used to inform the runner about the job that is being run, and is used to access the static attributes of the job such as its name, `collect` method, etc. + +Let's start by defining the configuration that the generator will accept. The train job requires three parameters: `lr`, `backbone`, and `stride`. We will also add another parameter that defines the number of experiments we want to run. One way to define it would be as follows: + +```yaml +train: + experiments: 10 + lr: [0.1, 0.99] + backbone: + - resnet18 + - wide_resnet50 + stride: + - 3 + - 5 +``` + +For this example the specification is defined as follows. + +1. The number of experiments is set to 10. +2. The learning rate is sampled from a uniform distribution in the range `[0.1, 0.99]`. +3. The backbone is chosen from the list `["resnet18", "wide_resnet50"]`. +4. The stride is chosen from the list `[3, 5]`. + +```{note} +While the `[ ]` and `-` syntax in `yaml` both signify a list, for visual disambiguation this example uses `[ ]` to denote a closed interval and `-` for a list of options. +``` + +With this in place, we can define the generator class as follows. + +```{literalinclude} ../../../../snippets/pipelines/dummy/train_generator.txt +:language: python +``` + +Since this is a dummy example, we generate the next experiment randomly. In practice, you would use a more sophisticated method that relies on your validation metrics to generate the next experiment. + +```{admonition} Challenge +:class: tip +For a challenge, define your own configuration and a generator to parse that configuration. +``` + +Okay, so now we can train the model. We still need a way to find out which parameters contribute the most to the final score. We will do this by computing the Shapley values, which estimate the contribution of each parameter to the final score. 
+ +Let's start by adding the `shap` library to our environment: + +```bash +pip install shap +``` + +The following listing shows the job that computes the Shapley values and saves a plot that shows the contribution of each parameter to the final score. A quick rundown without going into the details of the job (as it is irrelevant to the pipeline) is as follows. We create a `RandomForestRegressor` that is trained on the parameters to predict the final score. We then compute the Shapley values to identify the parameters that have the most significant impact on the model performance. Finally, the `save` method saves the plot so we can visually inspect the results. + +```{literalinclude} ../../../../snippets/pipelines/dummy/significance_job.txt + +``` + +Great! Now that we have the job, we need the generator, as before. Since we only need the results from the previous stage, we don't need to define the config. Let's quickly write that as well. + +```{literalinclude} ../../../../snippets/pipelines/dummy/significance_job_generator.txt + +``` + +## Experiment Pipeline + +So now we have the jobs, and a way to generate them. Let's look at how we can chain them together to achieve what we want. We will use the [Pipeline](../../reference/pipelines/base/pipeline.md) class to define the pipeline. + +When creating a custom pipeline, there is only one important method that we need to implement. That is the `_setup_runners` method. This is where we chain the runners together. + +```{literalinclude} ../../../../snippets/pipelines/dummy/pipeline_serial.txt +:language: python +``` + +In this example we use `SerialRunner` for running each job. It is a simple runner that runs the jobs serially. For more information on `SerialRunner` look [here](../../reference/pipelines/runners/serial.md). + +Okay, so we have the pipeline. How do we run it? To do this let's create a simple entrypoint in the `tools` folder of Anomalib. + +Here is how the directory looks. 
+ +```{literalinclude} ../../../../snippets/pipelines/dummy/tools_dir_structure.txt +:language: bash +``` + +As you can see, we have the `config.yaml` file in the same directory. Let's quickly populate `experiment.py`. + +```python +from anomalib.pipelines.experiment_pipeline import ExperimentPipeline + +if __name__ == "__main__": + ExperimentPipeline().run() +``` + +Alright! Time to take it on the road. + +```bash +python tools/experimental/experiment/experiment.py --config tools/experimental/experiment/config.yaml +``` + +If all goes well you should see the summary plot in `runs/significant_feature/summary_plot.png`. + +## Exposing to the CLI + +Now that you have your shiny new pipeline, you can expose it as a subcommand to `anomalib` by adding an entry to the pipeline registry in `anomalib/cli/pipelines.py`. + +```python +if try_import("anomalib.pipelines"): + ... + from anomalib.pipelines import ExperimentPipeline + +PIPELINE_REGISTRY: dict[str, type[Pipeline]] | None = { + "experiment": ExperimentPipeline, + ... +} +``` + +With this you can now call + +```{literalinclude} ../../../../snippets/pipelines/dummy/anomalib_cli.txt +:language: bash +``` + +Congratulations! You have successfully created a pipeline that trains a model and computes the significance of the parameters to the final score 🎉 + +```{admonition} Challenge +:class: tip +This example used a random model hence the scores were meaningless. Try to implement a real model and compute the scores. Look into which parameters lead to the most significant contribution to your score. +``` + +## Final Tweaks + +Before we end, let's look at a few final tweaks that you can make to the pipeline. + +First, let's run the initial model training in parallel. Since all jobs are independent, we can use the [ParallelRunner](../../reference/pipelines/runners/parallel.md). Since the `TrainJob` is a dummy job in this example, the pool of parallel jobs is set to the number of experiments. 
+ +```{literalinclude} ../../../../snippets/pipelines/dummy/pipeline_parallel.txt + +``` + +You will now notice that the entire pipeline takes less time to run. This is handy when you have a large number of experiments, and when each job takes substantial time to run. + +Now on to the second one. When running the pipeline we don't want our terminal cluttered with the outputs from each run. Anomalib provides a handy decorator that temporarily hides the output of a function. It suppresses all output to standard output and standard error unless an exception is raised. Let's add this to the `TrainJob`: + +```python +from anomalib.utils.logging import hide_output + +class TrainJob(Job): + ... + + @hide_output + def run(self, task_id: int | None = None) -> dict: + ... +``` + +You will no longer see the output of the `print` statement in the `TrainJob.run` method in the terminal. diff --git a/docs/source/markdown/guides/how_to/pipelines/index.md b/docs/source/markdown/guides/how_to/pipelines/index.md index ed3d66f81d..c7f2c44706 100644 --- a/docs/source/markdown/guides/how_to/pipelines/index.md +++ b/docs/source/markdown/guides/how_to/pipelines/index.md @@ -1,254 +1,30 @@ -# Pipelines +# Pipeline Tutorials -This guide demonstrates how to create a [Pipeline](../../reference/pipelines/index.md) for your custom task. +This section contains tutorials on how to use different pipelines of Anomalib and how to create your own. -A pipeline is made up of runners. These runners are responsible for running a single type of job. A job is the smallest unit of work that is independent, such as, training a model or statistical comparison of the outputs of two models. Each job should be designed to be independent of other jobs so that they are agnostic to the runner that is running them. This ensures that the job can be run in parallel or serially without any changes to the job itself. 
The runner does not directly instantiate a job but rather has a job generator that generates the job based on the configuration. This generator is responsible for parsing the config and generating the job. +::::{grid} +:margin: 1 1 0 0 +:gutter: 1 -## Birds Eye View +:::{grid-item-card} {octicon}`stack` Tiled Ensemble +:link: ./tiled_ensemble +:link-type: doc -In this guide we are going to create a dummy significant parameter search pipeline. The pipeline will have two jobs. The first job trains a model and computes the metric. The second job computes the significance of the parameters to the final score using shapely values. The final output of the pipeline is a plot that shows the contribution of each parameter to the final score. This will help teach you how to create a pipeline, a job, a job generator, and how to expose it to the `anomalib` CLI. The pipeline is going to be named `experiment`. So by the end of this you will be able to generate significance plot using +Learn more about how to use the tiled ensemble pipelines. +::: -```{literalinclude} ../../../../snippets/pipelines/dummy/anomalib_cli.txt -:language: bash -``` - -The final directory structure will look as follows: - -```{literalinclude} ../../../../snippets/pipelines/dummy/src_dir_structure.txt - -``` - -```{literalinclude} ../../../../snippets/pipelines/dummy/tools_dir_structure.txt -:language: bash -``` - -## Creating the Jobs - -Let's first look at the base class for the [jobs](../../reference/pipelines/base/job.md). It has a few methods defined. - -- The `run` method is the main method that is called by the runner. This is where we will train the model and return the model metrics. -- The `collect` method is used to gather the results from all the runs and collate them. This is handy as we want to pass a single object to the next job that contains details of all the runs including the final score. -- The `save` method is used to write any artifacts to the disk. 
It accepts the gathered results as a parameter. This is useful in a variety of situations. Say, when we want to write the results in a csv file or write the raw anomaly maps for further processing. - -Let's create the first job that trains the model and computes the metric. Since it is a dummy example, we will just return a random number as the metric. - -```python -class TrainJob(Job): - name = "train" +:::{grid-item-card} {octicon}`gear` Custom Pipeline +:link: ./custom_pipeline +:link-type: doc - def __init__(self, lr: float, backbone: str, stride: int): - self.lr = lr - self.backbone = backbone - self.stride = stride - - def run(self, task_id: int | None = None) -> dict: - print(f"Training with lr: {self.lr}, backbone: {self.backbone}, stride: {self.stride}") - time.sleep(2) - score = np.random.uniform(0.7, 0.1) - return {"lr": self.lr, "backbone": self.backbone, "stride": self.stride, "score": score} -``` - -Ignore the `task_id` for now. It is used for parallel jobs. We will come back to it later. - -````{note} -The `name` attribute is important and is used to identify the arguments in the job config file. -So, in our case the config `yaml` file will contain an entry like this: - -```yaml -... -train: - lr: - backbone: - stride: -... -```` - -Of course, it is up to us to choose what parameters should be shown under the `train` key. - -Let's also add the `collect` method so that we return a nice dict object that can be used by the next job. - -```python -def collect(results: list[dict]) -> dict: - output: dict = {} - for key in results[0]: - output[key] = [] - for result in results: - for key, value in result.items(): - output[key].append(value) - return output -``` - -We can also define a `save` method that writes the dictionary as a csv file. 
- -```python -@staticmethod -def save(results: dict) -> None: - """Save results in a csv file.""" - results_df = pd.DataFrame(results) - file_path = Path("runs") / TrainJob.name - file_path.mkdir(parents=True, exist_ok=True) - results_df.to_csv(file_path / "results.csv", index=False) -``` - -The entire job class is shown below. - -```{literalinclude} ../../../../snippets/pipelines/dummy/train_job.txt -:language: python -``` - -Now we need a way to generate this job when the pipeline is run. To do this we need to subclass the [JobGenerator](../../reference/pipelines/base/generator.md) class. - -The job generator is the actual object that is attached to a runner and is responsible for parsing the configuration and generating jobs. It has two methods that need to be implemented. - -- `generate_job`: This method accepts the configuration as a dictionary and, optionally, the results of the previous job. For the train job, we don't need results for previous jobs, so we will ignore it. -- `job_class`: This holds the reference to the class of the job that the generator will yield. It is used to inform the runner about the job that is being run, and is used to access the static attributes of the job such as its name, collect method, etc. - -Let's first start by defining the configuration that the generator will accept. The train job requires three parameters: `lr`, `backbone`, and `stride`. We will also add another parameter that defines the number of experiments we want to run. One way to define it would be as follows: - -```yaml -train: - experiments: 10 - lr: [0.1, 0.99] - backbone: - - resnet18 - - wide_resnet50 - stride: - - 3 - - 5 -``` - -For this example the specification is defined as follows. - -1. The number of experiments is set to 10. -2. Learning rate is sampled from a uniform distribution in the range `[0.1, 0.99]`. -3. The backbone is chosen from the list `["resnet18", "wide_resnet50"]`. -4. The stride is chosen from the list `[3, 5]`. 
- -```{note} -While the `[ ]` and `-` syntax in `yaml` both signify a list, for visual disambiguation this example uses `[ ]` to denote closed interval and `-` for a list of options. -``` - -With this defined, we can define the generator class as follows. - -```{literalinclude} ../../../../snippets/pipelines/dummy/train_generator.txt -:language: python -``` - -Since this is a dummy example, we generate the next experiment randomly. In practice, you would use a more sophisticated method that relies on your validation metrics to generate the next experiment. - -```{admonition} Challenge -:class: tip -For a challenge define your own configuration and a generator to parse that configuration. -``` - -Okay, so now we can train the model. We still need a way to find out which parameters contribute the most to the final score. We will do this by computing the shapely values to find out the contribution of each parameter to the final score. - -Let's first start by adding the library to our environment - -```bash -pip install shap -``` +Learn more about how to create a new custom pipeline. +::: -The following listing shows the job that computes the shapely values and saves a plot that shows the contribution of each parameter to the final score. A quick rundown without going into the details of the job (as it is irrelevant to the pipeline) is as follows. We create a `RandomForestRegressor` that is trained on the parameters to predict the final score. We then compute the shapely values to identify the parameters that have the most significant impact on the model performance. Finally, the `save` method saves the plot so we can visually inspect the results. +:::: -```{literalinclude} ../../../../snippets/pipelines/dummy/significance_job.txt +```{toctree} +:caption: Model Tutorials +:hidden: +./feature_extractors ``` - -Great! Now we have the job, as before, we need the generator. Since we only need the results from the previous stage, we don't need to define the config. 
Let's quickly write that as well. - -```{literalinclude} ../../../../snippets/pipelines/dummy/significance_job_generator.txt - -``` - -## Experiment Pipeline - -So now we have the jobs, and a way to generate them. Let's look at how we can chain them together to achieve what we want. We will use the [Pipeline](../../reference/pipelines/base/pipeline.md) class to define the pipeline. - -When creating a custom pipeline, there is only one important method that we need to implement. That is the `_setup_runners` method. This is where we chain the runners together. - -```{literalinclude} ../../../../snippets/pipelines/dummy/pipeline_serial.txt -:language: python -``` - -In this example we use `SerialRunner` for running each job. It is a simple runner that runs the jobs in a serial manner. For more information on `SerialRunner` look [here](../../reference/pipelines/runners/serial.md). - -Okay, so we have the pipeline. How do we run it? To do this let's create a simple entrypoint in `tools` folder of Anomalib. - -Here is how the directory looks. - -```{literalinclude} ../../../../snippets/pipelines/dummy/tools_dir_structure.txt -:language: bash -``` - -As you can see, we have the `config.yaml` file in the same directory. Let's quickly populate `experiment.py`. - -```python -from anomalib.pipelines.experiment_pipeline import ExperimentPipeline - -if __name__ == "__main__": - ExperimentPipeline().run() -``` - -Alright! Time to take it on the road. - -```bash -python tools/experimental/experiment/experiment.py --config tools/experimental/experiment/config.yaml -``` - -If all goes well you should see the summary plot in `runs/significant_feature/summary_plot.png`. - -## Exposing to the CLI - -Now that you have your shiny new pipeline, you can expose it as a subcommand to `anomalib` by adding an entry to the pipeline registry in `anomalib/cli/pipelines.py`. - -```python -if try_import("anomalib.pipelines"): - ... 
- from anomalib.pipelines import ExperimentPipeline - -PIPELINE_REGISTRY: dict[str, type[Pipeline]] | None = { - "experiment": ExperimentPipeline, - ... -} -``` - -With this you can now call - -```{literalinclude} ../../../../snippets/pipelines/dummy/anomalib_cli.txt -:language: bash -``` - -Congratulations! You have successfully created a pipeline that trains a model and computes the significance of the parameters to the final score 🎉 - -```{admonition} Challenge -:class: tip -This example used a random model hence the scores were meaningless. Try to implement a real model and compute the scores. Look into which parameters lead to the most significant contribution to your score. -``` - -## Final Tweaks - -Before we end, let's look at a few final tweaks that you can make to the pipeline. - -First, let's run the initial model training in parallel. Since all jobs are independent, we can use the [ParallelRunner](../../reference/pipelines/runners/parallel.md). Since the `TrainJob` is a dummy job in this example, the pool of parallel jobs is set to the number of experiments. - -```{literalinclude} ../../../../snippets/pipelines/dummy/pipeline_parallel.txt - -``` - -You now notice that the entire pipeline takes lesser time to run. This is handy when you have large number of experiments, and when each job takes substantial time to run. - -Now on to the second one. When running the pipeline we don't want our terminal cluttered with the outputs from each run. Anomalib provides a handy decorator that temporarily hides the output of a function. It suppresses all outputs to the standard out and the standard error unless an exception is raised. Let's add this to the `TrainJob` - -```python -from anomalib.utils.logging import hide_output - -class TrainJob(Job): - ... - - @hide_output - def run(self, task_id: int | None = None) -> dict: - ... -``` - -You will no longer see the output of the `print` statement in the `TrainJob` method in the terminal. 
diff --git a/docs/source/markdown/guides/how_to/pipelines/tiled_ensemble.md b/docs/source/markdown/guides/how_to/pipelines/tiled_ensemble.md new file mode 100644 index 0000000000..3550efb5fd --- /dev/null +++ b/docs/source/markdown/guides/how_to/pipelines/tiled_ensemble.md @@ -0,0 +1,157 @@ +# Tiled ensemble + +This guide will show you how to use **The Tiled Ensemble** method for anomaly detection. For more details, refer to the official [Paper](https://openaccess.thecvf.com/content/CVPR2024W/VAND/html/Rolih_Divide_and_Conquer_High-Resolution_Industrial_Anomaly_Detection_via_Memory_Efficient_CVPRW_2024_paper.html). + +The tiled ensemble approach reduces memory consumption by dividing input images into a grid of tiles and training a dedicated model for each tile location. +It is compatible with any existing image anomaly detection model without the need for any modification of the underlying architecture. + +![Tiled ensemble flow](../../../../images/tiled_ensemble/ensemble_flow.png) + +```{note} +This feature is experimental and may not work as expected. +For any problems, refer to [Issues](https://github.com/openvinotoolkit/anomalib/issues) and feel free to ask any questions in [Discussions](https://github.com/openvinotoolkit/anomalib/discussions). +``` + +## Training + +You can train a tiled ensemble using the training script located inside the `tools/tiled_ensemble` directory: + +```{code-block} bash + +python tools/tiled_ensemble/train_ensemble.py \ + --config tools/tiled_ensemble/ens_config.yaml +``` + +By default, the Padim model is trained on the **MVTec AD bottle** category using an image size of 256x256, divided into non-overlapping 128x128 tiles. +You can modify these parameters in the [config file](#ensemble-configuration). 
+ +## Evaluation + +After training, you can evaluate the tiled ensemble on test data using: + +```{code-block} bash + +python tools/tiled_ensemble/eval.py \ + --config tools/tiled_ensemble/ens_config.yaml \ + --root path_to_results_dir + +``` + +Ensure that `root` points to the directory containing the training results, typically `results/padim/mvtec/bottle/runX`. + +## Ensemble configuration + +The tiled ensemble is configured using the `ens_config.yaml` file in the `tools/tiled_ensemble` directory. +It contains general settings and settings specific to the tiled ensemble. + +### General + +General settings at the top of the config file set the random `seed`, the `accelerator` (device), and `default_root_dir`, the path where results will be saved. + +```{code-block} yaml +seed: 42 +accelerator: "gpu" +default_root_dir: "results" +``` + +### Tiling + +This section contains the following settings, used for image tiling: + +```{code-block} yaml + +tiling: + tile_size: 256 + stride: 256 +``` + +These settings determine the tile size and stride. Another important parameter is `image_size` from the `data` section later in the config. It determines the original size of the image. + +The input image is split into tiles, where each tile has the shape set by `tile_size` and tiles are taken with the step set by `stride`. +For example, `image_size: 512`, `tile_size: 256`, and `stride: 256` result in 4 non-overlapping tile locations. + +### Normalization and thresholding + +Next up are the normalization and thresholding settings: + +```{code-block} yaml +normalization_stage: image +thresholding: + method: F1AdaptiveThreshold + stage: image +``` + +- **Normalization**: Can be applied to each tile location separately (`tile` option), after combining predictions (`image` option), or skipped (`none` option). + +- **Thresholding**: Can also be applied at different stages, but it is limited to `tile` and `image`. Another setting for thresholding is the method used. 
It can be specified as a string or by the class path. + +### Data + +The `data` section is used to configure the input `image_size` and other parameters for the dataset used. + +```{code-block} yaml +data: + class_path: anomalib.data.MVTec + init_args: + root: ./datasets/MVTec + category: bottle + train_batch_size: 32 + eval_batch_size: 32 + num_workers: 8 + task: segmentation + transform: null + train_transform: null + eval_transform: null + test_split_mode: from_dir + test_split_ratio: 0.2 + val_split_mode: same_as_test + val_split_ratio: 0.5 + image_size: [256, 256] +``` + +Refer to [Data](../../reference/data/image/index.md) for more details on parameters. + +### SeamSmoothing + +This section contains settings for the `SeamSmoothing` block of the pipeline: + +```{code-block} yaml +SeamSmoothing: + apply: True + sigma: 2 + width: 0.1 + +``` + +The `SeamSmoothing` job smooths the regions where tiles meet, called tile seams. + +- **apply**: If `True`, smoothing is applied. +- **sigma**: Controls the sigma of the Gaussian filter used for smoothing. +- **width**: Sets the percentage of the region around the seam to be smoothed. + +### TrainModels + +The last section, `TrainModels`, contains the setup for model training: + +```{code-block} yaml +TrainModels: + model: + class_path: Fastflow + + metrics: + pixel: AUROC + image: AUROC + + trainer: + max_epochs: 500 + callbacks: + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + patience: 42 + monitor: pixel_AUROC + mode: max +``` + +- **Model**: Specifies the model used. Refer to [Models](../../reference/models/image/index.md) for more details on the model parameters. +- **Metrics**: Defines evaluation metrics at the pixel and image level. +- **Trainer**: _optional_ parameters, used to control the training process. Refer to [Engine](../../reference/engine/index.md) for more details. 
diff --git a/src/anomalib/data/base/datamodule.py b/src/anomalib/data/base/datamodule.py index cb95ca8171..a9197f6670 100644 --- a/src/anomalib/data/base/datamodule.py +++ b/src/anomalib/data/base/datamodule.py @@ -119,6 +119,8 @@ def __init__( self._is_setup = False # flag to track if setup has been called from the trainer + self.collate_fn = collate_fn + @property def name(self) -> str: """Name of the datamodule.""" @@ -224,6 +226,7 @@ def train_dataloader(self) -> TRAIN_DATALOADERS: shuffle=True, batch_size=self.train_batch_size, num_workers=self.num_workers, + collate_fn=self.collate_fn, ) def val_dataloader(self) -> EVAL_DATALOADERS: @@ -233,7 +236,7 @@ def val_dataloader(self) -> EVAL_DATALOADERS: shuffle=False, batch_size=self.eval_batch_size, num_workers=self.num_workers, - collate_fn=collate_fn, + collate_fn=self.collate_fn, ) def test_dataloader(self) -> EVAL_DATALOADERS: @@ -243,7 +246,7 @@ def test_dataloader(self) -> EVAL_DATALOADERS: shuffle=False, batch_size=self.eval_batch_size, num_workers=self.num_workers, - collate_fn=collate_fn, + collate_fn=self.collate_fn, ) def predict_dataloader(self) -> EVAL_DATALOADERS: diff --git a/src/anomalib/data/utils/tiler.py b/src/anomalib/data/utils/tiler.py index 089aeaae91..2c1e949e45 100644 --- a/src/anomalib/data/utils/tiler.py +++ b/src/anomalib/data/utils/tiler.py @@ -162,11 +162,11 @@ def __init__( remove_border_count: int = 0, mode: ImageUpscaleMode = ImageUpscaleMode.PADDING, ) -> None: - self.tile_size_h, self.tile_size_w = self.__validate_size_type(tile_size) + self.tile_size_h, self.tile_size_w = self.validate_size_type(tile_size) self.random_tile_count = 4 if stride is not None: - self.stride_h, self.stride_w = self.__validate_size_type(stride) + self.stride_h, self.stride_w = self.validate_size_type(stride) self.remove_border_count = remove_border_count self.overlapping = not (self.stride_h == self.tile_size_h and self.stride_w == self.tile_size_w) @@ -201,7 +201,15 @@ def __init__( 
self.num_patches_w: int @staticmethod - def __validate_size_type(parameter: int | Sequence) -> tuple[int, ...]: + def validate_size_type(parameter: int | Sequence) -> tuple[int, ...]: + """Validate size type and return tuple of form [tile_h, tile_w]. + + Args: + parameter (int | Sequence): input tile size parameter. + + Returns: + tuple[int, ...]: Validated tile size in tuple form. + """ if isinstance(parameter, int): output = (parameter, parameter) elif isinstance(parameter, Sequence): diff --git a/src/anomalib/models/components/base/anomaly_module.py b/src/anomalib/models/components/base/anomaly_module.py index 963ce485a3..ecd4c62d13 100644 --- a/src/anomalib/models/components/base/anomaly_module.py +++ b/src/anomalib/models/components/base/anomaly_module.py @@ -266,6 +266,8 @@ def input_size(self) -> tuple[int, int] | None: The effective input size is the size of the input tensor after the transform has been applied. If the transform is not set, or if the transform does not change the shape of the input tensor, this method will return None. """ + if self._input_size: + return self._input_size transform = self.transform or self.configure_transforms() if transform is None: return None @@ -275,6 +277,10 @@ def input_size(self) -> tuple[int, int] | None: return None return output_shape[-2:] + def set_input_size(self, input_size: tuple[int, int]) -> None: + """Update the effective input size of the model.""" + self._input_size = input_size + def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None: """Called when saving the model to a checkpoint. 
diff --git a/src/anomalib/pipelines/tiled_ensemble/__init__.py b/src/anomalib/pipelines/tiled_ensemble/__init__.py new file mode 100644 index 0000000000..1a068562b7 --- /dev/null +++ b/src/anomalib/pipelines/tiled_ensemble/__init__.py @@ -0,0 +1,12 @@ +"""Tiled ensemble pipelines.""" + +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from .test_pipeline import EvalTiledEnsemble +from .train_pipeline import TrainTiledEnsemble + +__all__ = [ + "TrainTiledEnsemble", + "EvalTiledEnsemble", +] diff --git a/src/anomalib/pipelines/tiled_ensemble/components/__init__.py b/src/anomalib/pipelines/tiled_ensemble/components/__init__.py new file mode 100644 index 0000000000..619dc2e673 --- /dev/null +++ b/src/anomalib/pipelines/tiled_ensemble/components/__init__.py @@ -0,0 +1,30 @@ +"""Tiled ensemble pipeline components.""" + +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from .merging import MergeJobGenerator +from .metrics_calculation import MetricsCalculationJobGenerator +from .model_training import TrainModelJobGenerator +from .normalization import NormalizationJobGenerator +from .prediction import PredictJobGenerator +from .smoothing import SmoothingJobGenerator +from .stats_calculation import StatisticsJobGenerator +from .thresholding import ThresholdingJobGenerator +from .utils import NormalizationStage, PredictData, ThresholdStage +from .visualization import VisualizationJobGenerator + +__all__ = [ + "NormalizationStage", + "ThresholdStage", + "PredictData", + "TrainModelJobGenerator", + "PredictJobGenerator", + "MergeJobGenerator", + "SmoothingJobGenerator", + "StatisticsJobGenerator", + "NormalizationJobGenerator", + "ThresholdingJobGenerator", + "VisualizationJobGenerator", + "MetricsCalculationJobGenerator", +] diff --git a/src/anomalib/pipelines/tiled_ensemble/components/merging.py b/src/anomalib/pipelines/tiled_ensemble/components/merging.py new file mode 100644 index 0000000000..6e8d5fc84c --- 
/dev/null +++ b/src/anomalib/pipelines/tiled_ensemble/components/merging.py @@ -0,0 +1,110 @@ +"""Tiled ensemble - prediction merging job.""" + +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import logging +from collections.abc import Generator +from typing import Any + +from tqdm import tqdm + +from anomalib.pipelines.components import Job, JobGenerator +from anomalib.pipelines.types import GATHERED_RESULTS, RUN_RESULTS + +from .utils.ensemble_tiling import EnsembleTiler +from .utils.helper_functions import get_ensemble_tiler +from .utils.prediction_data import EnsemblePredictions +from .utils.prediction_merging import PredictionMergingMechanism + +logger = logging.getLogger(__name__) + + +class MergeJob(Job): + """Job for merging tile-level predictions into image-level predictions. + + Args: + predictions (EnsemblePredictions): Object containing ensemble predictions. + tiler (EnsembleTiler): Ensemble tiler used for untiling. + """ + + name = "Merge" + + def __init__(self, predictions: EnsemblePredictions, tiler: EnsembleTiler) -> None: + super().__init__() + self.predictions = predictions + self.tiler = tiler + + def run(self, task_id: int | None = None) -> list[Any]: + """Run merging job that merges all batches of tile-level predictions into image-level predictions. + + Args: + task_id: Not used in this case. + + Returns: + list[Any]: List of merged predictions. + """ + del task_id # not needed here + + merger = PredictionMergingMechanism(self.predictions, self.tiler) + + logger.info("Merging predictions.") + + # merge all batches + merged_predictions = [ + merger.merge_tile_predictions(batch_idx) + for batch_idx in tqdm(range(merger.num_batches), desc="Prediction merging") + ] + + return merged_predictions # noqa: RET504 + + @staticmethod + def collect(results: list[RUN_RESULTS]) -> GATHERED_RESULTS: + """Nothing to collect in this job. + + Returns: + list[Any]: List of predictions. 
+ """ + # take the first element as result is list of lists here + return results[0] + + @staticmethod + def save(results: GATHERED_RESULTS) -> None: + """Nothing to save in this job.""" + + +class MergeJobGenerator(JobGenerator): + """Generate MergeJob.""" + + def __init__(self, tiling_args: dict, data_args: dict) -> None: + super().__init__() + self.tiling_args = tiling_args + self.data_args = data_args + + @property + def job_class(self) -> type: + """Return the job class.""" + return MergeJob + + def generate_jobs( + self, + args: dict | None = None, + prev_stage_result: EnsemblePredictions | None = None, + ) -> Generator[MergeJob, None, None]: + """Return a generator producing a single merging job. + + Args: + args (dict): Tiled ensemble pipeline args. + prev_stage_result (EnsemblePredictions): Ensemble predictions from predict step. + + Returns: + Generator[MergeJob, None, None]: MergeJob generator + """ + del args # args not used here + + tiler = get_ensemble_tiler(self.tiling_args, self.data_args) + if prev_stage_result is not None: + yield MergeJob(prev_stage_result, tiler) + else: + msg = "Merging job requires tile level predictions from previous step." 
+ raise ValueError(msg) diff --git a/src/anomalib/pipelines/tiled_ensemble/components/metrics_calculation.py b/src/anomalib/pipelines/tiled_ensemble/components/metrics_calculation.py new file mode 100644 index 0000000000..530662b1d3 --- /dev/null +++ b/src/anomalib/pipelines/tiled_ensemble/components/metrics_calculation.py @@ -0,0 +1,217 @@ +"""Tiled ensemble - metrics calculation job.""" + +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import logging +from collections.abc import Generator +from pathlib import Path +from typing import Any + +import pandas as pd +from tqdm import tqdm + +from anomalib import TaskType +from anomalib.metrics import AnomalibMetricCollection, create_metric_collection +from anomalib.pipelines.components import Job, JobGenerator +from anomalib.pipelines.types import GATHERED_RESULTS, PREV_STAGE_RESULT, RUN_RESULTS + +from .utils import NormalizationStage +from .utils.helper_functions import get_threshold_values + +logger = logging.getLogger(__name__) + + +class MetricsCalculationJob(Job): + """Job for image and pixel metrics calculation. + + Args: + accelerator (str): Accelerator (device) to use. + predictions (list[Any]): List of batch predictions. + root_dir (Path): Root directory to save checkpoints, stats and images. + image_metrics (AnomalibMetricCollection): Collection of all image-level metrics. + pixel_metrics (AnomalibMetricCollection): Collection of all pixel-level metrics. + """ + + name = "Metrics" + + def __init__( + self, + accelerator: str, + predictions: list[Any] | None, + root_dir: Path, + image_metrics: AnomalibMetricCollection, + pixel_metrics: AnomalibMetricCollection, + ) -> None: + super().__init__() + self.accelerator = accelerator + self.predictions = predictions + self.root_dir = root_dir + self.image_metrics = image_metrics + self.pixel_metrics = pixel_metrics + + def run(self, task_id: int | None = None) -> dict: + """Run a job that calculates image and pixel level metrics. 
+ + Args: + task_id: Not used in this case. + + Returns: + dict: Dictionary containing calculated metric values and the `save_path` used by `save`. + """ + del task_id # not needed here + + logger.info("Starting metrics calculation.") + + # add predicted data to metrics + for data in tqdm(self.predictions, desc="Calculating metrics"): + self.image_metrics.update(data["pred_scores"], data["label"].int()) + if "mask" in data and "anomaly_maps" in data: + self.pixel_metrics.update(data["anomaly_maps"], data["mask"].int()) + + # compute all metrics on specified accelerator + metrics_dict = {} + for name, metric in self.image_metrics.items(): + metric.to(self.accelerator) + metrics_dict[name] = metric.compute().item() + metric.cpu() + + if self.pixel_metrics.update_called: + for name, metric in self.pixel_metrics.items(): + metric.to(self.accelerator) + metrics_dict[name] = metric.compute().item() + metric.cpu() + + for name, value in metrics_dict.items(): + print(f"{name}: {value:.4f}") + + # save path used in `save` method + metrics_dict["save_path"] = self.root_dir / "metric_results.csv" + + return metrics_dict + + @staticmethod + def collect(results: list[RUN_RESULTS]) -> GATHERED_RESULTS: + """Nothing to collect in this job. + + Returns: + dict: Dictionary of calculated metric values. + """ + # take the first element as result is list of dict here + return results[0] + + @staticmethod + def save(results: GATHERED_RESULTS) -> None: + """Save metrics values to csv.""" + logger.info("Saving metrics to csv.") + + # get and remove path from metrics dict + results_path: Path = results.pop("save_path") + results_path.parent.mkdir(parents=True, exist_ok=True) + + df_dict = {k: [v] for k, v in results.items()} + metrics_df = pd.DataFrame(df_dict) + metrics_df.to_csv(results_path, index=False) + + +class MetricsCalculationJobGenerator(JobGenerator): + """Generate MetricsCalculationJob. + + Args: + root_dir (Path): Root directory to save checkpoints, stats and images. 
+ """ + + def __init__( + self, + accelerator: str, + root_dir: Path, + task: TaskType, + metrics: dict, + normalization_stage: NormalizationStage, + ) -> None: + self.accelerator = accelerator + self.root_dir = root_dir + self.task = task + self.metrics = metrics + self.normalization_stage = normalization_stage + + @property + def job_class(self) -> type: + """Return the job class.""" + return MetricsCalculationJob + + def configure_ensemble_metrics( + self, + image_metrics: list[str] | dict[str, dict[str, Any]] | None = None, + pixel_metrics: list[str] | dict[str, dict[str, Any]] | None = None, + ) -> tuple[AnomalibMetricCollection, AnomalibMetricCollection]: + """Configure image and pixel metrics and put them into a collection. + + Args: + image_metrics (list[str] | dict[str, dict[str, Any]] | None): Image-level metric names or configs. + pixel_metrics (list[str] | dict[str, dict[str, Any]] | None): Pixel-level metric names or configs. + + Returns: + tuple[AnomalibMetricCollection, AnomalibMetricCollection]: + Image-metrics collection and pixel-metrics collection + """ + image_metrics = [] if image_metrics is None else image_metrics + + if pixel_metrics is None: + pixel_metrics = [] + elif self.task == TaskType.CLASSIFICATION: + logger.warning( + "Cannot perform pixel-level evaluation when task type is classification. 
" + "Ignoring the following pixel-level metrics: %s", + pixel_metrics, + ) + pixel_metrics = [] # reset only after the ignored metrics have been logged + + # if a single metric is passed, transform to list to fit the creation function + if isinstance(image_metrics, str): + image_metrics = [image_metrics] + if isinstance(pixel_metrics, str): + pixel_metrics = [pixel_metrics] + + image_metrics_collection = create_metric_collection(image_metrics, "image_") + pixel_metrics_collection = create_metric_collection(pixel_metrics, "pixel_") + + return image_metrics_collection, pixel_metrics_collection + + def generate_jobs( + self, + args: dict | None = None, + prev_stage_result: PREV_STAGE_RESULT = None, + ) -> Generator[MetricsCalculationJob, None, None]: + """Make a generator that yields a single metrics calculation job. + + Args: + args: ensemble run config. + prev_stage_result: ensemble predictions from previous step. + + Returns: + Generator[MetricsCalculationJob, None, None]: MetricsCalculationJob generator + """ + del args # args not used here + + image_metrics_config = self.metrics.get("image", None) + pixel_metrics_config = self.metrics.get("pixel", None) + + image_threshold, pixel_threshold = get_threshold_values(self.normalization_stage, self.root_dir) + + image_metrics, pixel_metrics = self.configure_ensemble_metrics( + image_metrics=image_metrics_config, + pixel_metrics=pixel_metrics_config, + ) + + # set thresholds for metrics that need it + image_metrics.set_threshold(image_threshold) + pixel_metrics.set_threshold(pixel_threshold) + + yield MetricsCalculationJob( + accelerator=self.accelerator, + predictions=prev_stage_result, + root_dir=self.root_dir, + image_metrics=image_metrics, + pixel_metrics=pixel_metrics, + ) diff --git a/src/anomalib/pipelines/tiled_ensemble/components/model_training.py b/src/anomalib/pipelines/tiled_ensemble/components/model_training.py new file mode 100644 index 0000000000..6bc81c793b --- /dev/null +++ 
b/src/anomalib/pipelines/tiled_ensemble/components/model_training.py @@ -0,0 +1,192 @@ +"""Tiled ensemble 
- ensemble training job.""" + +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import logging +from collections.abc import Generator +from itertools import product +from pathlib import Path + +from lightning import seed_everything + +from anomalib.data import AnomalibDataModule +from anomalib.models import AnomalyModule +from anomalib.pipelines.components import Job, JobGenerator +from anomalib.pipelines.types import GATHERED_RESULTS, PREV_STAGE_RESULT + +from .utils import NormalizationStage +from .utils.ensemble_engine import TiledEnsembleEngine +from .utils.helper_functions import ( + get_ensemble_datamodule, + get_ensemble_engine, + get_ensemble_model, + get_ensemble_tiler, +) + +logger = logging.getLogger(__name__) + + +class TrainModelJob(Job): + """Job for training of individual models in the tiled ensemble. + + Args: + accelerator (str): Accelerator (device) to use. + seed (int): Random seed for reproducibility. + root_dir (Path): Root directory to save checkpoints, stats and images. + tile_index (tuple[int, int]): Index of tile that this model processes. + normalization_stage (str): Normalization stage flag. + metrics (dict): metrics dict with pixel and image metric names. + trainer_args (dict| None): Additional arguments to pass to the trainer class. + model (AnomalyModule): Model to train. + datamodule (AnomalibDataModule): Datamodule with all dataloaders. 
+ + """ + + name = "TrainModels" + + def __init__( + self, + accelerator: str, + seed: int, + root_dir: Path, + tile_index: tuple[int, int], + normalization_stage: str, + metrics: dict, + trainer_args: dict | None, + model: AnomalyModule, + datamodule: AnomalibDataModule, + ) -> None: + super().__init__() + self.accelerator = accelerator + self.seed = seed + self.root_dir = root_dir + self.tile_index = tile_index + self.normalization_stage = normalization_stage + self.metrics = metrics + self.trainer_args = trainer_args + self.model = model + self.datamodule = datamodule + + def run( + self, + task_id: int | None = None, + ) -> TiledEnsembleEngine: + """Run train job that fits the model for given tile location. + + Args: + task_id: Passed when job is run in parallel. + + Returns: + TiledEnsembleEngine: Engine containing trained model. + """ + devices: str | list[int] = "auto" + if task_id is not None: + devices = [task_id] + logger.info("Running job %s with device %s", self.model.__class__.__name__, task_id) + + logger.info("Start of training for tile at position %s.", self.tile_index) + seed_everything(self.seed) + + # create engine for specific tile location and fit the model + engine = get_ensemble_engine( + tile_index=self.tile_index, + accelerator=self.accelerator, + devices=devices, + root_dir=self.root_dir, + normalization_stage=self.normalization_stage, + metrics=self.metrics, + trainer_args=self.trainer_args, + ) + engine.fit(model=self.model, datamodule=self.datamodule) + # move model to cpu to avoid memory issues as the engine is returned to be used in validation phase + engine.model.cpu() + + return engine + + @staticmethod + def collect(results: list[TiledEnsembleEngine]) -> dict[tuple[int, int], TiledEnsembleEngine]: + """Collect engines from each tile location into a dict. 
+ + Returns: + dict[tuple[int, int], TiledEnsembleEngine]: Dict has form {tile_index: TiledEnsembleEngine} + """ + return {r.tile_index: r for r in results} + + @staticmethod + def save(results: GATHERED_RESULTS) -> None: + """Skip as checkpoints are already saved by callback.""" + + +class TrainModelJobGenerator(JobGenerator): + """Generator for training job that train model for each tile location. + + Args: + root_dir (Path): Root directory to save checkpoints, stats and images. + """ + + def __init__( + self, + seed: int, + accelerator: str, + root_dir: Path, + tiling_args: dict, + data_args: dict, + normalization_stage: NormalizationStage, + ) -> None: + self.seed = seed + self.accelerator = accelerator + self.root_dir = root_dir + self.tiling_args = tiling_args + self.data_args = data_args + self.normalization_stage = normalization_stage + + @property + def job_class(self) -> type: + """Return the job class.""" + return TrainModelJob + + def generate_jobs( + self, + args: dict | None = None, + prev_stage_result: PREV_STAGE_RESULT = None, + ) -> Generator[TrainModelJob, None, None]: + """Generate training jobs for each tile location. + + Args: + args (dict): Dict with config passed to training. + prev_stage_result (None): Not used here. + + Returns: + Generator[TrainModelJob, None, None]: TrainModelJob generator + """ + del prev_stage_result # Not needed for this job + if args is None: + msg = "TrainModels job requires config args" + raise ValueError(msg) + + # tiler used for splitting the image and getting the tile count + tiler = get_ensemble_tiler(self.tiling_args, self.data_args) + + logger.info( + "Tiled ensemble training started. 
Separate models will be trained for %d tile locations.", + tiler.num_tiles, + ) + # go over all tile positions + for tile_index in product(range(tiler.num_patches_h), range(tiler.num_patches_w)): + # prepare datamodule with custom collate function that only provides specific tile of image + datamodule = get_ensemble_datamodule(self.data_args, tiler, tile_index) + model = get_ensemble_model(args["model"], tiler) + + # pass root_dir to engine so all models in ensemble have the same root dir + yield TrainModelJob( + accelerator=self.accelerator, + seed=self.seed, + root_dir=self.root_dir, + tile_index=tile_index, + normalization_stage=self.normalization_stage, + metrics=args["metrics"], + trainer_args=args.get("trainer", {}), + model=model, + datamodule=datamodule, + ) diff --git a/src/anomalib/pipelines/tiled_ensemble/components/normalization.py b/src/anomalib/pipelines/tiled_ensemble/components/normalization.py new file mode 100644 index 0000000000..8c7a563506 --- /dev/null +++ b/src/anomalib/pipelines/tiled_ensemble/components/normalization.py @@ -0,0 +1,120 @@ +"""Tiled ensemble - normalization job.""" + +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import json +import logging +from collections.abc import Generator +from pathlib import Path +from typing import Any + +from tqdm import tqdm + +from anomalib.pipelines.components import Job, JobGenerator +from anomalib.pipelines.types import GATHERED_RESULTS, RUN_RESULTS +from anomalib.utils.normalization.min_max import normalize + +logger = logging.getLogger(__name__) + + +class NormalizationJob(Job): + """Job for normalization of predictions. + + Args: + predictions (list[Any]): List of predictions. + root_dir (Path): Root directory containing statistics needed for normalization. 
+ """ + + name = "Normalize" + + def __init__(self, predictions: list[Any] | None, root_dir: Path) -> None: + super().__init__() + self.predictions = predictions + self.root_dir = root_dir + + def run(self, task_id: int | None = None) -> list[Any] | None: + """Run normalization job which normalizes image-level scores and pixel-level anomaly maps. + + Args: + task_id: Not used in this case. + + Returns: + list[Any]: List of normalized predictions. + """ + del task_id # not needed here + + # load all statistics needed for normalization + stats_path = self.root_dir / "weights" / "lightning" / "stats.json" + with stats_path.open("r") as f: + stats = json.load(f) + minmax = stats["minmax"] + image_threshold = stats["image_threshold"] + pixel_threshold = stats["pixel_threshold"] + + logger.info("Starting normalization.") + + for data in tqdm(self.predictions, desc="Normalizing"): + data["pred_scores"] = normalize( + data["pred_scores"], + image_threshold, + minmax["pred_scores"]["min"], + minmax["pred_scores"]["max"], + ) + if "anomaly_maps" in data: + data["anomaly_maps"] = normalize( + data["anomaly_maps"], + pixel_threshold, + minmax["anomaly_maps"]["min"], + minmax["anomaly_maps"]["max"], + ) + + return self.predictions + + @staticmethod + def collect(results: list[RUN_RESULTS]) -> GATHERED_RESULTS: + """Nothing to collect in this job. + + Returns: + list[Any]: List of predictions. + """ + # take the first element as result is list of lists here + return results[0] + + @staticmethod + def save(results: GATHERED_RESULTS) -> None: + """Nothing is saved in this job.""" + + +class NormalizationJobGenerator(JobGenerator): + """Generate NormalizationJob. + + Args: + root_dir (Path): Root directory where statistics are saved. 
+ """ + + def __init__(self, root_dir: Path) -> None: + self.root_dir = root_dir + + @property + def job_class(self) -> type: + """Return the job class.""" + return NormalizationJob + + def generate_jobs( + self, + args: dict | None = None, + prev_stage_result: list[Any] | None = None, + ) -> Generator[NormalizationJob, None, None]: + """Return a generator producing a single normalization job. + + Args: + args: not used here. + prev_stage_result (list[Any]): Ensemble predictions from previous step. + + Returns: + Generator[NormalizationJob, None, None]: NormalizationJob generator. + """ + del args # not needed here + + yield NormalizationJob(prev_stage_result, self.root_dir) diff --git a/src/anomalib/pipelines/tiled_ensemble/components/prediction.py b/src/anomalib/pipelines/tiled_ensemble/components/prediction.py new file mode 100644 index 0000000000..792d86a497 --- /dev/null +++ b/src/anomalib/pipelines/tiled_ensemble/components/prediction.py @@ -0,0 +1,228 @@ +"""Tiled ensemble - ensemble prediction job.""" + +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import logging +from collections.abc import Generator +from itertools import product +from pathlib import Path +from typing import Any + +from lightning import seed_everything +from torch.utils.data import DataLoader + +from anomalib.models import AnomalyModule +from anomalib.pipelines.components import Job, JobGenerator +from anomalib.pipelines.types import GATHERED_RESULTS, PREV_STAGE_RESULT + +from .utils import NormalizationStage, PredictData +from .utils.ensemble_engine import TiledEnsembleEngine +from .utils.helper_functions import ( + get_ensemble_datamodule, + get_ensemble_engine, + get_ensemble_model, + get_ensemble_tiler, +) +from .utils.prediction_data import EnsemblePredictions + +logger = logging.getLogger(__name__) + + +class PredictJob(Job): + """Job for generating predictions with individual models in the tiled ensemble. 
+ + Args: + accelerator (str): Accelerator (device) to use. + seed (int): Random seed for reproducibility. + root_dir (Path): Root directory to save checkpoints, stats and images. + tile_index (tuple[int, int]): Index of tile that this model processes. + normalization_stage (str): Normalization stage flag. + dataloader (DataLoader): Dataloader to use for prediction (either val or test). + model (AnomalyModule | None): Model used for prediction. + engine (TiledEnsembleEngine | None): + engine from train job. If job is used standalone, instantiate engine and model from checkpoint. + ckpt_path (Path | None): Path to checkpoint to be loaded if engine doesn't contain correct weights. + + """ + + name = "Predict" + + def __init__( + self, + accelerator: str, + seed: int, + root_dir: Path, + tile_index: tuple[int, int], + normalization_stage: str, + dataloader: DataLoader, + model: AnomalyModule | None, + engine: TiledEnsembleEngine | None, + ckpt_path: Path | None, + ) -> None: + super().__init__() + if engine is None and ckpt_path is None: + msg = "Either engine or checkpoint must be provided to predict job." + raise ValueError(msg) + + self.accelerator = accelerator + self.seed = seed + self.root_dir = root_dir + self.tile_index = tile_index + self.normalization_stage = normalization_stage + self.dataloader = dataloader + self.model = model + self.engine = engine + self.ckpt_path = ckpt_path + + def run( + self, + task_id: int | None = None, + ) -> tuple[tuple[int, int], Any | None]: + """Predict job that predicts the data with specific model for given tile location. + + Args: + task_id: Passed when job is run in parallel. + + Returns: + tuple[tuple[int, int], list[Any]]: Tile index, List of predictions. 
+ """ + devices: str | list[int] = "auto" + if task_id is not None: + devices = [task_id] + logger.info("Running job %s with device %s", self.model.__class__.__name__, task_id) + + logger.info("Start of predicting for tile at position %s.", self.tile_index) + seed_everything(self.seed) + + if self.engine is None: + # in case predict is invoked separately from train job, make new engine instance + self.engine = get_ensemble_engine( + tile_index=self.tile_index, + accelerator=self.accelerator, + devices=devices, + root_dir=self.root_dir, + normalization_stage=self.normalization_stage, + ) + + predictions = self.engine.predict(model=self.model, dataloaders=self.dataloader, ckpt_path=self.ckpt_path) + + # also return tile index as it's needed in collect method + return self.tile_index, predictions + + @staticmethod + def collect(results: list[tuple[tuple[int, int], list[Any]]]) -> EnsemblePredictions: + """Collect predictions from each tile location into the predictions class. + + Returns: + EnsemblePredictions: Object containing all predictions in form ready for merging. + """ + storage = EnsemblePredictions() + + for tile_index, predictions in results: + storage.add_tile_prediction(tile_index, predictions) + + return storage + + @staticmethod + def save(results: GATHERED_RESULTS) -> None: + """This stage doesn't save anything.""" + + +class PredictJobGenerator(JobGenerator): + """Generator for predict job that uses individual models to predict for each tile location. + + Args: + root_dir (Path): Root directory to save checkpoints, stats and images. + data_source (PredictData): Data split to predict on. `PredictData.VAL` selects the validation set, otherwise the test set is used. 
+ """ + + def __init__( + self, + data_source: PredictData, + seed: int, + accelerator: str, + root_dir: Path, + tiling_args: dict, + data_args: dict, + model_args: dict, + normalization_stage: NormalizationStage, + ) -> None: + self.data_source = data_source + self.seed = seed + self.accelerator = accelerator + self.root_dir = root_dir + self.tiling_args = tiling_args + self.data_args = data_args + self.model_args = model_args + self.normalization_stage = normalization_stage + + @property + def job_class(self) -> type: + """Return the job class.""" + return PredictJob + + def generate_jobs( + self, + args: dict | None = None, + prev_stage_result: PREV_STAGE_RESULT = None, + ) -> Generator[PredictJob, None, None]: + """Generate predict jobs for each tile location. + + Args: + args (dict): Dict with config passed to training. + prev_stage_result (dict[tuple[int, int], TiledEnsembleEngine] | None): + if called after train job this contains engines with individual models, otherwise load from checkpoints. + + Returns: + Generator[PredictJob, None, None]: PredictJob generator. 
+ """ + del args # args not used here + + # tiler used for splitting the image and getting the tile count + tiler = get_ensemble_tiler(self.tiling_args, self.data_args) + + logger.info( + "Tiled ensemble predicting started using %s data.", + self.data_source.value, + ) + # go over all tile positions + for tile_index in product(range(tiler.num_patches_h), range(tiler.num_patches_w)): + # prepare datamodule with custom collate function that only provides specific tile of image + datamodule = get_ensemble_datamodule(self.data_args, tiler, tile_index) + + # check if predict step is positioned after training + if prev_stage_result and tile_index in prev_stage_result: + engine = prev_stage_result[tile_index] + # model is inside engine in this case + model = engine.model + ckpt_path = None + else: + # any other case - predict is called standalone + engine = None + # we need to make new model instance as it's not inside engine + model = get_ensemble_model(self.model_args, tiler) + tile_i, tile_j = tile_index + # prepare checkpoint path for model on current tile location + ckpt_path = self.root_dir / "weights" / "lightning" / f"model{tile_i}_{tile_j}.ckpt" + + # pick the dataloader based on predict data + dataloader = datamodule.test_dataloader() + if self.data_source == PredictData.VAL: + dataloader = datamodule.val_dataloader() + # TODO(blaz-r): - this is tweak to avoid problem in engine:388 + # 2254 + dataloader.dataset.transform = None + + # pass root_dir to engine so all models in ensemble have the same root dir + yield PredictJob( + accelerator=self.accelerator, + seed=self.seed, + root_dir=self.root_dir, + tile_index=tile_index, + normalization_stage=self.normalization_stage, + model=model, + dataloader=dataloader, + engine=engine, + ckpt_path=ckpt_path, + ) diff --git a/src/anomalib/pipelines/tiled_ensemble/components/smoothing.py b/src/anomalib/pipelines/tiled_ensemble/components/smoothing.py new file mode 100644 index 0000000000..b3d5a51000 --- /dev/null +++ 
b/src/anomalib/pipelines/tiled_ensemble/components/smoothing.py @@ -0,0 +1,167 @@ +"""Tiled ensemble - seam smoothing job.""" + +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import logging +from collections.abc import Generator +from typing import Any + +import torch +from tqdm import tqdm + +from anomalib.models.components import GaussianBlur2d +from anomalib.pipelines.components import Job, JobGenerator +from anomalib.pipelines.types import GATHERED_RESULTS, RUN_RESULTS + +from .utils.ensemble_tiling import EnsembleTiler +from .utils.helper_functions import get_ensemble_tiler + +logger = logging.getLogger(__name__) + + +class SmoothingJob(Job): + """Job for smoothing the area around the tile seam. + + Args: + accelerator (str): Accelerator used for processing. + predictions (list[Any]): List of image-level predictions. + width_factor (float): Factor multiplied by tile dimension to get the region around seam which will be smoothed. + filter_sigma (float): Sigma of filter used for smoothing the seams. + tiler (EnsembleTiler): Tiler object used to get tile dimension data. + """ + + name = "SeamSmoothing" + + def __init__( + self, + accelerator: str, + predictions: list[Any], + width_factor: float, + filter_sigma: float, + tiler: EnsembleTiler, + ) -> None: + super().__init__() + self.accelerator = accelerator + self.predictions = predictions + + # offset in pixels of region around tile seam that will be smoothed + self.height_offset = int(tiler.tile_size_h * width_factor) + self.width_offset = int(tiler.tile_size_w * width_factor) + self.tiler = tiler + + self.seam_mask = self.prepare_seam_mask() + + self.blur = GaussianBlur2d(sigma=filter_sigma) + + def prepare_seam_mask(self) -> torch.Tensor: + """Prepare boolean mask of regions around the part where tiles seam in ensemble. + + Returns: + torch.Tensor: Representation of boolean mask where filtered seams should be used. 
+ """ + img_h, img_w = self.tiler.image_size + stride_h, stride_w = self.tiler.stride_h, self.tiler.stride_w + + mask = torch.zeros(img_h, img_w, dtype=torch.bool) + + # prepare mask strip on vertical seams + curr_w = stride_w + while curr_w < img_w: + start_i = curr_w - self.width_offset + end_i = curr_w + self.width_offset + mask[:, start_i:end_i] = 1 + curr_w += stride_w + + # prepare mask strip on horizontal seams + curr_h = stride_h + while curr_h < img_h: + start_i = curr_h - self.height_offset + end_i = curr_h + self.height_offset + mask[start_i:end_i, :] = True + curr_h += stride_h + + return mask + + def run(self, task_id: int | None = None) -> list[Any]: + """Run smoothing job. + + Args: + task_id: Not used in this case. + + Returns: + list[Any]: List of predictions. + """ + del task_id # not needed here + + logger.info("Starting seam smoothing.") + + for data in tqdm(self.predictions, desc="Seam smoothing"): + # move to specified accelerator for faster execution + data["anomaly_maps"] = data["anomaly_maps"].to(self.accelerator) + # smooth the anomaly map and take only region around seams delimited by seam_mask + smoothed = self.blur(data["anomaly_maps"]) + data["anomaly_maps"][:, :, self.seam_mask] = smoothed[:, :, self.seam_mask] + data["anomaly_maps"] = data["anomaly_maps"].cpu() + + return self.predictions + + @staticmethod + def collect(results: list[RUN_RESULTS]) -> GATHERED_RESULTS: + """Nothing to collect in this job. + + Returns: + list[Any]: List of predictions. 
+ """ + # take the first element as result is list of lists here + return results[0] + + @staticmethod + def save(results: GATHERED_RESULTS) -> None: + """Nothing to save in this job.""" + + +class SmoothingJobGenerator(JobGenerator): + """Generate SmoothingJob.""" + + def __init__(self, accelerator: str, tiling_args: dict, data_args: dict) -> None: + super().__init__() + self.accelerator = accelerator + self.tiling_args = tiling_args + self.data_args = data_args + + @property + def job_class(self) -> type: + """Return the job class.""" + return SmoothingJob + + def generate_jobs( + self, + args: dict | None = None, + prev_stage_result: list[Any] | None = None, + ) -> Generator[SmoothingJob, None, None]: + """Return a generator producing a single seam smoothing job. + + Args: + args: Tiled ensemble pipeline args. + prev_stage_result (list[Any]): Ensemble predictions from previous step. + + Returns: + Generator[SmoothingJob, None, None]: SmoothingJob generator + """ + if args is None: + msg = "SeamSmoothing job requires config args" + raise ValueError(msg) + # tiler is used to determine where seams appear + tiler = get_ensemble_tiler(self.tiling_args, self.data_args) + if prev_stage_result is not None: + yield SmoothingJob( + accelerator=self.accelerator, + predictions=prev_stage_result, + width_factor=args["width"], + filter_sigma=args["sigma"], + tiler=tiler, + ) + else: + msg = "Join smoothing job requires tile level predictions from previous step." 
+ raise ValueError(msg) diff --git a/src/anomalib/pipelines/tiled_ensemble/components/stats_calculation.py b/src/anomalib/pipelines/tiled_ensemble/components/stats_calculation.py new file mode 100644 index 0000000000..6c48b639f7 --- /dev/null +++ b/src/anomalib/pipelines/tiled_ensemble/components/stats_calculation.py @@ -0,0 +1,180 @@ +"""Tiled ensemble - post-processing statistics calculation job.""" + +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import json +import logging +from collections.abc import Generator +from pathlib import Path +from typing import Any + +import torch +from omegaconf import DictConfig, ListConfig +from torchmetrics import MetricCollection +from tqdm import tqdm + +from anomalib.callbacks.thresholding import _ThresholdCallback +from anomalib.metrics import MinMax +from anomalib.metrics.threshold import Threshold +from anomalib.pipelines.components import Job, JobGenerator +from anomalib.pipelines.types import GATHERED_RESULTS, RUN_RESULTS + +logger = logging.getLogger(__name__) + + +class StatisticsJob(Job): + """Job for calculating min, max and threshold statistics for post-processing. + + Args: + predictions (list[Any]): List of image-level predictions. + root_dir (Path): Root directory to save checkpoints, stats and images. + """ + + name = "Stats" + + def __init__( + self, + predictions: list[Any] | None, + root_dir: Path, + image_threshold: Threshold, + pixel_threshold: Threshold, + ) -> None: + super().__init__() + self.predictions = predictions + self.root_dir = root_dir + self.image_threshold = image_threshold + self.pixel_threshold = pixel_threshold + + def run(self, task_id: int | None = None) -> dict: + """Run job that calculates statistics needed in post-processing steps. + + Args: + task_id: Not used in this case + + Returns: + dict: Statistics dict with min, max and threshold values. 
+ """ + del task_id # not needed here + + minmax = MetricCollection( + { + "anomaly_maps": MinMax().cpu(), + "pred_scores": MinMax().cpu(), + }, + ) + pixel_update_called = False + + logger.info("Starting post-processing statistics calculation.") + + for data in tqdm(self.predictions, desc="Stats calculation"): + # update minmax + if "anomaly_maps" in data: + minmax["anomaly_maps"](data["anomaly_maps"]) + if "pred_scores" in data: + minmax["pred_scores"](data["pred_scores"]) + + # update thresholds + self.image_threshold.update(data["pred_scores"], data["label"].int()) + if "mask" in data and "anomaly_maps" in data: + self.pixel_threshold.update(torch.squeeze(data["anomaly_maps"]), torch.squeeze(data["mask"].int())) + pixel_update_called = True + + self.image_threshold.compute() + if pixel_update_called: + self.pixel_threshold.compute() + else: + self.pixel_threshold.value = self.image_threshold.value + + min_max_vals = {} + for pred_name, pred_metric in minmax.items(): + min_max_vals[pred_name] = { + "min": pred_metric.min.item(), + "max": pred_metric.max.item(), + } + + # return stats with save path that is later used to save statistics. + return { + "minmax": min_max_vals, + "image_threshold": self.image_threshold.value.item(), + "pixel_threshold": self.pixel_threshold.value.item(), + "save_path": (self.root_dir / "weights" / "lightning" / "stats.json"), + } + + @staticmethod + def collect(results: list[RUN_RESULTS]) -> GATHERED_RESULTS: + """Nothing to collect in this job. + + Returns: + dict: statistics dictionary. 
+ """ + # take the first element as result is list of lists here + return results[0] + + @staticmethod + def save(results: GATHERED_RESULTS) -> None: + """Save statistics to file system.""" + # get and remove path from stats dict + stats_path: Path = results.pop("save_path") + stats_path.parent.mkdir(parents=True, exist_ok=True) + + # save statistics next to weights + with stats_path.open("w", encoding="utf-8") as stats_file: + json.dump(results, stats_file, ensure_ascii=False, indent=4) + + +class StatisticsJobGenerator(JobGenerator): + """Generate StatisticsJob. + + Args: + root_dir (Path): Root directory where statistics file will be saved (in weights folder). + """ + + def __init__( + self, + root_dir: Path, + thresholding_method: DictConfig | str | ListConfig | list[dict[str, str | float]], + ) -> None: + self.root_dir = root_dir + self.threshold = thresholding_method + + @property + def job_class(self) -> type: + """Return the job class.""" + return StatisticsJob + + def generate_jobs( + self, + args: dict | None = None, + prev_stage_result: list[Any] | None = None, + ) -> Generator[StatisticsJob, None, None]: + """Return a generator producing a single stats calculating job. + + Args: + args: Not used here. + prev_stage_result (list[Any]): Ensemble predictions from previous step. + + Returns: + Generator[StatisticsJob, None, None]: StatisticsJob generator. 
+ """ + del args # not needed here + + # get threshold class based on config + if isinstance(self.threshold, str | DictConfig): + # single method provided + image_threshold = _ThresholdCallback._get_threshold_from_config(self.threshold) # noqa: SLF001 + pixel_threshold = image_threshold.clone() + elif isinstance(self.threshold, ListConfig | list): + # image and pixel method specified separately + image_threshold = _ThresholdCallback._get_threshold_from_config(self.threshold[0]) # noqa: SLF001 + pixel_threshold = _ThresholdCallback._get_threshold_from_config(self.threshold[1]) # noqa: SLF001 + else: + msg = f"Invalid threshold config {self.threshold}" + raise TypeError(msg) + + yield StatisticsJob( + predictions=prev_stage_result, + root_dir=self.root_dir, + image_threshold=image_threshold, + pixel_threshold=pixel_threshold, + ) diff --git a/src/anomalib/pipelines/tiled_ensemble/components/thresholding.py b/src/anomalib/pipelines/tiled_ensemble/components/thresholding.py new file mode 100644 index 0000000000..733c3d99db --- /dev/null +++ b/src/anomalib/pipelines/tiled_ensemble/components/thresholding.py @@ -0,0 +1,114 @@ +"""Tiled ensemble - thresholding job.""" + +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import logging +from collections.abc import Generator +from pathlib import Path +from typing import Any + +from tqdm import tqdm + +from anomalib.pipelines.components import Job, JobGenerator +from anomalib.pipelines.types import GATHERED_RESULTS, RUN_RESULTS + +from .utils import NormalizationStage +from .utils.helper_functions import get_threshold_values + +logger = logging.getLogger(__name__) + + +class ThresholdingJob(Job): + """Job used to threshold predictions, producing labels from scores. + + Args: + predictions (list[Any]): List of predictions. + image_threshold (float): Threshold used for image-level thresholding. + pixel_threshold (float): Threshold used for pixel-level thresholding.
+ """ + + name = "Threshold" + + def __init__(self, predictions: list[Any] | None, image_threshold: float, pixel_threshold: float) -> None: + super().__init__() + self.predictions = predictions + self.image_threshold = image_threshold + self.pixel_threshold = pixel_threshold + + def run(self, task_id: int | None = None) -> list[Any] | None: + """Run job that produces prediction labels from scores. + + Args: + task_id: Not used in this case. + + Returns: + list[Any]: List of thresholded predictions. + """ + del task_id # not needed here + + logger.info("Starting thresholding.") + + for data in tqdm(self.predictions, desc="Thresholding"): + if "pred_scores" in data: + data["pred_labels"] = data["pred_scores"] >= self.image_threshold + if "anomaly_maps" in data: + data["pred_masks"] = data["anomaly_maps"] >= self.pixel_threshold + + return self.predictions + + @staticmethod + def collect(results: list[RUN_RESULTS]) -> GATHERED_RESULTS: + """Nothing to collect in this job. + + Returns: + list[Any]: List of predictions. + """ + # take the first element as result is list of lists here + return results[0] + + @staticmethod + def save(results: GATHERED_RESULTS) -> None: + """Nothing is saved in this job.""" + + +class ThresholdingJobGenerator(JobGenerator): + """Generate ThresholdingJob. + + Args: + root_dir (Path): Root directory containing post-processing stats. + """ + + def __init__(self, root_dir: Path, normalization_stage: NormalizationStage) -> None: + self.root_dir = root_dir + self.normalization_stage = normalization_stage + + @property + def job_class(self) -> type: + """Return the job class.""" + return ThresholdingJob + + def generate_jobs( + self, + args: dict | None = None, + prev_stage_result: list[Any] | None = None, + ) -> Generator[ThresholdingJob, None, None]: + """Return a generator producing a single thresholding job. + + Args: + args: ensemble run args. + prev_stage_result (list[Any]): Ensemble predictions from previous step. 
+ + Returns: + Generator[ThresholdingJob, None, None]: ThresholdingJob generator. + """ + del args # args not used here + + # get threshold values based on normalization + image_threshold, pixel_threshold = get_threshold_values(self.normalization_stage, self.root_dir) + + yield ThresholdingJob( + predictions=prev_stage_result, + image_threshold=image_threshold, + pixel_threshold=pixel_threshold, + ) diff --git a/src/anomalib/pipelines/tiled_ensemble/components/utils/__init__.py b/src/anomalib/pipelines/tiled_ensemble/components/utils/__init__.py new file mode 100644 index 0000000000..a010208908 --- /dev/null +++ b/src/anomalib/pipelines/tiled_ensemble/components/utils/__init__.py @@ -0,0 +1,44 @@ +"""Tiled ensemble utils and helper functions.""" + +from enum import Enum + +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +class NormalizationStage(str, Enum): + """Enum signaling at which stage the normalization is done. + + In case of tile, tiles are normalized for each tile position separately. + In case of image, normalization is done at the end when images are joined back together. + In case of none, output is not normalized. + """ + + TILE = "tile" + IMAGE = "image" + NONE = "none" + + +class ThresholdStage(str, Enum): + """Enum signaling at which stage the thresholding is applied. + + In case of tile, thresholding is applied for each tile location separately. + In case of image, thresholding is applied at the end when images are joined back together.
+ """ + + TILE = "tile" + IMAGE = "image" + + +class PredictData(Enum): + """Enum indicating which data to use in prediction job.""" + + VAL = "val" + TEST = "test" + + +__all__ = [ + "NormalizationStage", + "ThresholdStage", + "PredictData", +] diff --git a/src/anomalib/pipelines/tiled_ensemble/components/utils/ensemble_engine.py b/src/anomalib/pipelines/tiled_ensemble/components/utils/ensemble_engine.py new file mode 100644 index 0000000000..449109ed3f --- /dev/null +++ b/src/anomalib/pipelines/tiled_ensemble/components/utils/ensemble_engine.py @@ -0,0 +1,92 @@ +"""Implements custom Anomalib engine for tiled ensemble training.""" + +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import logging +from pathlib import Path + +from lightning.pytorch.callbacks import Callback, RichModelSummary + +from anomalib.callbacks import ModelCheckpoint, TimerCallback +from anomalib.callbacks.metrics import _MetricsCallback +from anomalib.callbacks.normalization import get_normalization_callback +from anomalib.callbacks.post_processor import _PostProcessorCallback +from anomalib.callbacks.thresholding import _ThresholdCallback +from anomalib.engine import Engine +from anomalib.models import AnomalyModule +from anomalib.utils.path import create_versioned_dir + +logger = logging.getLogger(__name__) + + +class TiledEnsembleEngine(Engine): + """Engine used for training and evaluating tiled ensemble. + + Most of the logic stays the same, but workspace creation and callbacks are adjusted for ensemble. + + Args: + tile_index (tuple[int, int]): index of tile that this engine instance processes. + **kwargs: Engine arguments. 
+ """ + + def __init__(self, tile_index: tuple[int, int], **kwargs) -> None: + self.tile_index = tile_index + super().__init__(**kwargs) + + def _setup_workspace(self, *args, **kwargs) -> None: + """Skip since in case of tiled ensemble, workspace is only setup once at the beginning of training.""" + + @staticmethod + def setup_ensemble_workspace(args: dict, versioned_dir: bool = True) -> Path: + """Set up the workspace at the beginning of tiled ensemble training. + + Args: + args (dict): Tiled ensemble config dict. + versioned_dir (bool, optional): Whether to create a versioned directory. + Defaults to ``True``. + + Returns: + Path: path to new workspace root dir + """ + model_name = args["TrainModels"]["model"]["class_path"].split(".")[-1] + dataset_name = args["data"]["class_path"].split(".")[-1] + category = args["data"]["init_args"]["category"] + root_dir = Path(args["default_root_dir"]) / model_name / dataset_name / category + return create_versioned_dir(root_dir) if versioned_dir else root_dir / "latest" + + def _setup_anomalib_callbacks(self, model: AnomalyModule) -> None: + """Modified method to enable individual model training. It's called when Trainer is being set up.""" + del model # not used here + + _callbacks: list[Callback] = [RichModelSummary()] + + # Add ModelCheckpoint if it is not in the callbacks list. + has_checkpoint_callback = any(isinstance(c, ModelCheckpoint) for c in self._cache.args["callbacks"]) + if not has_checkpoint_callback: + tile_i, tile_j = self.tile_index + _callbacks.append( + ModelCheckpoint( + dirpath=self._cache.args["default_root_dir"] / "weights" / "lightning", + filename=f"model{tile_i}_{tile_j}", + auto_insert_metric_name=False, + ), + ) + + # Add the post-processor callbacks. Used for thresholding and label calculation. + _callbacks.append(_PostProcessorCallback()) + + # Add the normalization callback if tile level normalization was specified (is not none). 
+ normalization_callback = get_normalization_callback(self.normalization) + if normalization_callback is not None: + _callbacks.append(normalization_callback) + + # Add the thresholding and metrics callbacks in all cases, + # because individual model might still need this for early stop. + _callbacks.append(_ThresholdCallback(self.threshold)) + _callbacks.append(_MetricsCallback(self.task, self.image_metric_names, self.pixel_metric_names)) + + _callbacks.append(TimerCallback()) + + # Combine the callbacks, and update the trainer callbacks. + self._cache.args["callbacks"] = _callbacks + self._cache.args["callbacks"] diff --git a/src/anomalib/pipelines/tiled_ensemble/components/utils/ensemble_tiling.py b/src/anomalib/pipelines/tiled_ensemble/components/utils/ensemble_tiling.py new file mode 100644 index 0000000000..db56f88b47 --- /dev/null +++ b/src/anomalib/pipelines/tiled_ensemble/components/utils/ensemble_tiling.py @@ -0,0 +1,147 @@ +"""Tiler used with ensemble of models.""" + +# Copyright (C) 2023-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from collections.abc import Sequence +from typing import Any + +from torch import Tensor + +from anomalib.data.base.datamodule import collate_fn +from anomalib.data.utils.tiler import Tiler, compute_new_image_size + + +class EnsembleTiler(Tiler): + """Tile Image into (non)overlapping Patches which are then used for ensemble training. + + Args: + tile_size (int | Sequence): Tile dimension for each patch. + stride (int | Sequence): Stride length between patches. + image_size (int | Sequence): Size of input image that will be tiled. 
+ + Examples: + >>> import torch + >>> tiler = EnsembleTiler(tile_size=256, stride=128, image_size=512) + >>> + >>> # random images, shape: [B, C, H, W] + >>> images = torch.rand(32, 5, 512, 512) + >>> # once tiled, the shape is [tile_count_H, tile_count_W, B, C, tile_H, tile_W] + >>> tiled = tiler.tile(images) + >>> tiled.shape + torch.Size([3, 3, 32, 5, 256, 256]) + + >>> # assemble the tiles back together + >>> untiled = tiler.untile(tiled) + >>> untiled.shape + torch.Size([32, 5, 512, 512]) + """ + + def __init__(self, tile_size: int | Sequence, stride: int | Sequence, image_size: int | Sequence) -> None: + super().__init__( + tile_size=tile_size, + stride=stride, + ) + + # calculate final image size + self.image_size = self.validate_size_type(image_size) + self.input_h, self.input_w = self.image_size + self.resized_h, self.resized_w = compute_new_image_size( + image_size=(self.input_h, self.input_w), + tile_size=(self.tile_size_h, self.tile_size_w), + stride=(self.stride_h, self.stride_w), + ) + + # get number of patches in both dimensions + self.num_patches_h = int((self.resized_h - self.tile_size_h) / self.stride_h) + 1 + self.num_patches_w = int((self.resized_w - self.tile_size_w) / self.stride_w) + 1 + self.num_tiles = self.num_patches_h * self.num_patches_w + + def tile(self, image: Tensor, use_random_tiling: bool = False) -> Tensor: + """Tiles an input image to either overlapping or non-overlapping patches. + + Args: + image (Tensor): Input images. + use_random_tiling (bool): Random tiling, which is part of original tiler but is unused here. + + Returns: + Tensor: Tiles generated from images. + Returned shape: [num_h, num_w, batch, channel, tile_height, tile_width]. 
+ """ + # tiles are returned in order [tile_count * batch, channels, tile_height, tile_width] + combined_tiles = super().tile(image, use_random_tiling) + + # rearrange to [num_h, num_w, batch, channel, tile_height, tile_width] + tiles = combined_tiles.contiguous().view( + self.batch_size, + self.num_patches_h, + self.num_patches_w, + self.num_channels, + self.tile_size_h, + self.tile_size_w, + ) + tiles = tiles.permute(1, 2, 0, 3, 4, 5) + + return tiles # noqa: RET504 + + def untile(self, tiles: Tensor) -> Tensor: + """Reassemble the tiled tensor into image level representation. + + Args: + tiles (Tensor): Tiles in shape: [num_h, num_w, batch, channel, tile_height, tile_width]. + + Returns: + Tensor: Image constructed from input tiles. Shape: [B, C, H, W]. + """ + # tiles have shape [num_h, num_w, batch, channel, tile_height, tile_width] + _, _, batch, channels, tile_size_h, tile_size_w = tiles.shape + + # set tilers batch size as it might have been changed by previous tiling + self.batch_size = batch + + # rearrange the tiles in order [tile_count * batch, channels, tile_height, tile_width] + # the required shape for untiling + tiles = tiles.permute(2, 0, 1, 3, 4, 5) + tiles = tiles.contiguous().view(-1, channels, tile_size_h, tile_size_w) + + untiled = super().untile(tiles) + + return untiled # noqa: RET504 + + +class TileCollater: + """Class serving as collate function to perform tiling on batch of images from Dataloader. + + Args: + tiler (EnsembleTiler): Tiler used to split the images to tiles. + tile_index (tuple[int, int]): Index of tile we want to return. + """ + + def __init__(self, tiler: EnsembleTiler, tile_index: tuple[int, int]) -> None: + self.tiler = tiler + self.tile_index = tile_index + + def __call__(self, batch: list) -> dict[str, Any]: + """Collate batch and tile images + masks from batch. + + Args: + batch (list): Batch of elements from data, also including images. + + Returns: + dict[str, Any]: Collated batch dictionary with tiled images. 
+ """ + # use default collate + coll_batch = collate_fn(batch) + + tiled_images = self.tiler.tile(coll_batch["image"]) + # return only tiles at given index + coll_batch["image"] = tiled_images[self.tile_index] + + if "mask" in coll_batch: + # insert channel (as mask has just one) + tiled_masks = self.tiler.tile(coll_batch["mask"].unsqueeze(1)) + + # return only tiled at given index, squeeze to remove previously added channel + coll_batch["mask"] = tiled_masks[self.tile_index].squeeze(1) + + return coll_batch diff --git a/src/anomalib/pipelines/tiled_ensemble/components/utils/helper_functions.py b/src/anomalib/pipelines/tiled_ensemble/components/utils/helper_functions.py new file mode 100644 index 0000000000..bc1e5f4f55 --- /dev/null +++ b/src/anomalib/pipelines/tiled_ensemble/components/utils/helper_functions.py @@ -0,0 +1,179 @@ +"""Helper functions for the tiled ensemble training.""" + +import json + +# Copyright (C) 2023-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +from pathlib import Path + +from jsonargparse import ArgumentParser, Namespace +from lightning import Trainer + +from anomalib.data import AnomalibDataModule, get_datamodule +from anomalib.models import AnomalyModule, get_model +from anomalib.utils.normalization import NormalizationMethod + +from . import NormalizationStage +from .ensemble_engine import TiledEnsembleEngine +from .ensemble_tiling import EnsembleTiler, TileCollater + + +def get_ensemble_datamodule(data_args: dict, tiler: EnsembleTiler, tile_index: tuple[int, int]) -> AnomalibDataModule: + """Get Anomaly Datamodule adjusted for use in ensemble. + + Datamodule collate function gets replaced by TileCollater in order to tile all images before they are passed on. + + Args: + data_args: tiled ensemble data configuration. + tiler (EnsembleTiler): Tiler used to split the images to tiles for use in ensemble. + tile_index (tuple[int, int]): Index of the tile in the split image. 
+ + Returns: + AnomalibDataModule: Anomalib Lightning DataModule + """ + datamodule = get_datamodule(data_args) + # set custom collate function that does the tiling + datamodule.collate_fn = TileCollater(tiler, tile_index) + datamodule.setup() + + return datamodule + + +def get_ensemble_model(model_args: dict, tiler: EnsembleTiler) -> AnomalyModule: + """Get model prepared for ensemble training. + + Args: + model_args: tiled ensemble model configuration. + tiler (EnsembleTiler): tiler used to get tile dimensions. + + Returns: + AnomalyModule: model with input_size setup + """ + model = get_model(model_args) + # set model input size match tile size + model.set_input_size((tiler.tile_size_h, tiler.tile_size_w)) + + return model + + +def get_ensemble_tiler(tiling_args: dict, data_args: dict) -> EnsembleTiler: + """Get tiler used for image tiling and to obtain tile dimensions. + + Args: + tiling_args: tiled ensemble tiling configuration. + data_args: tiled ensemble data configuration. + + Returns: + EnsembleTiler: tiler object. + """ + tiler = EnsembleTiler( + tile_size=tiling_args["tile_size"], + stride=tiling_args["stride"], + image_size=data_args["init_args"]["image_size"], + ) + + return tiler # noqa: RET504 + + +def parse_trainer_kwargs(trainer_args: dict | None) -> Namespace | dict: + """Parse trainer args and instantiate all needed elements. + + Transforms config into kwargs ready for Trainer, including instantiation of callback etc. + + Args: + trainer_args (dict): Trainer args dictionary. + + Returns: + dict: parsed kwargs with instantiated elements. 
+ """ + if not trainer_args: + return {} + + # try to get trainer args, if not present return empty + parser = ArgumentParser() + + parser.add_class_arguments(Trainer, fail_untyped=False, instantiate=False, sub_configs=True) + config = parser.parse_object(trainer_args) + objects = parser.instantiate_classes(config) + + return objects # noqa: RET504 + + +def get_ensemble_engine( + tile_index: tuple[int, int], + accelerator: str, + devices: list[int] | str | int, + root_dir: Path, + normalization_stage: str, + metrics: dict | None = None, + trainer_args: dict | None = None, +) -> TiledEnsembleEngine: + """Prepare engine for ensemble training or prediction. + + This method makes sure correct normalization is used, prepares metrics and additional trainer kwargs. + + Args: + tile_index (tuple[int, int]): Index of tile that this model processes. + accelerator (str): Accelerator (device) to use. + devices (list[int] | str | int): device IDs used for training. + root_dir (Path): Root directory to save checkpoints, stats and images. + normalization_stage (str): Stage at which normalization is applied (tile, image or none). + metrics (dict): Dict containing pixel and image metrics names. + trainer_args (dict): Trainer args dictionary. Empty dict if not present. + + Returns: + TiledEnsembleEngine: set up engine for ensemble training/prediction.
+ """ + # if we want tile level normalization we set it here, otherwise it's done later on joined images + if normalization_stage == NormalizationStage.TILE: + normalization = NormalizationMethod.MIN_MAX + else: + normalization = NormalizationMethod.NONE + + # parse additional trainer args and callbacks if present in config + trainer_kwargs = parse_trainer_kwargs(trainer_args) + # remove keys that we already have + trainer_kwargs.pop("accelerator", None) + trainer_kwargs.pop("default_root_dir", None) + trainer_kwargs.pop("devices", None) + + # create engine for specific tile location + engine = TiledEnsembleEngine( + tile_index=tile_index, + normalization=normalization, + accelerator=accelerator, + devices=devices, + default_root_dir=root_dir, + image_metrics=metrics.get("image", None) if metrics else None, + pixel_metrics=metrics.get("pixel", None) if metrics else None, + **trainer_kwargs, + ) + + return engine # noqa: RET504 + + +def get_threshold_values(normalization_stage: NormalizationStage, root_dir: Path) -> tuple[float, float]: + """Get threshold values for image and pixel level predictions. + + If normalization is not used, get values based on statistics obtained from validation set. + If normalization is used, both image and pixel threshold are 0.5. + + Args: + normalization_stage (NormalizationStage): stage at which normalization is applied. + root_dir (Path): path to run root where stats file is saved. + + Returns: + tuple[float, float]: image and pixel threshold.
+ """ + if normalization_stage == NormalizationStage.NONE: + stats_path = root_dir / "weights" / "lightning" / "stats.json" + with stats_path.open("r") as f: + stats = json.load(f) + image_threshold = stats["image_threshold"] + pixel_threshold = stats["pixel_threshold"] + else: + # normalization transforms the scores so that threshold is at 0.5 + image_threshold = 0.5 + pixel_threshold = 0.5 + + return image_threshold, pixel_threshold diff --git a/src/anomalib/pipelines/tiled_ensemble/components/utils/prediction_data.py b/src/anomalib/pipelines/tiled_ensemble/components/utils/prediction_data.py new file mode 100644 index 0000000000..4fe45e9c4a --- /dev/null +++ b/src/anomalib/pipelines/tiled_ensemble/components/utils/prediction_data.py @@ -0,0 +1,45 @@ +"""Classes used to store ensemble predictions.""" + +# Copyright (C) 2023-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from torch import Tensor + + +class EnsemblePredictions: + """Basic implementation of EnsemblePredictionData that keeps all predictions in main memory.""" + + def __init__(self) -> None: + super().__init__() + self.all_data: dict[tuple[int, int], list] = {} + + def add_tile_prediction(self, tile_index: tuple[int, int], tile_prediction: list[dict[str, Tensor | list]]) -> None: + """Add tile prediction data at provided index to class dictionary in main memory. + + Args: + tile_index (tuple[int, int]): Index of tile that we are adding in form (row, column). + tile_prediction (list[dict[str, Tensor | list]]): + List of batches containing all predicted data for current tile position. + + """ + self.num_batches = len(tile_prediction) + + self.all_data[tile_index] = tile_prediction + + def get_batch_tiles(self, batch_index: int) -> dict[tuple[int, int], dict]: + """Get all tiles of current batch from class dictionary. + + Called by merging mechanism. + + Args: + batch_index (int): Index of current batch of tiles to be returned. 
+ + Returns: + dict[tuple[int, int], dict]: Dictionary mapping tile index to predicted data, for provided batch index. + """ + batch_data = {} + + for index, batches in self.all_data.items(): + batch_data[index] = batches[batch_index] + + return batch_data diff --git a/src/anomalib/pipelines/tiled_ensemble/components/utils/prediction_merging.py b/src/anomalib/pipelines/tiled_ensemble/components/utils/prediction_merging.py new file mode 100644 index 0000000000..7337cc4ffe --- /dev/null +++ b/src/anomalib/pipelines/tiled_ensemble/components/utils/prediction_merging.py @@ -0,0 +1,167 @@ +"""Class used as mechanism to merge ensemble predictions from each tile into complete whole-image representation.""" + +# Copyright (C) 2023-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import torch +from torch import Tensor + +from .ensemble_tiling import EnsembleTiler +from .prediction_data import EnsemblePredictions + + +class PredictionMergingMechanism: + """Class used for merging the data predicted by each separate model of tiled ensemble. + + Tiles are stacked in one tensor and untiled using Ensemble Tiler. + Boxes from tiles are either stacked or generated anew from anomaly map. + Labels are combined with OR operator, meaning one anomalous tile -> anomalous image. + Scores are averaged across all tiles. + + Args: + ensemble_predictions (EnsemblePredictions): Object containing predictions on tile level. + tiler (EnsembleTiler): Tiler used to transform tiles back to image level representation. 
+ + Example: + >>> from anomalib.pipelines.tiled_ensemble.components.utils.ensemble_tiling import EnsembleTiler + >>> from anomalib.pipelines.tiled_ensemble.components.utils.prediction_data import EnsemblePredictions + >>> + >>> tiler = EnsembleTiler(tile_size=256, stride=128, image_size=512) + >>> data = EnsemblePredictions() + >>> merger = PredictionMergingMechanism(data, tiler) + >>> + >>> # we can then start merging procedure for each batch + >>> merger.merge_tile_predictions(0) + """ + + def __init__(self, ensemble_predictions: EnsemblePredictions, tiler: EnsembleTiler) -> None: + assert ensemble_predictions.num_batches > 0, "There should be at least one batch for each tile prediction." + assert (0, 0) in ensemble_predictions.get_batch_tiles( + 0, + ), "Tile prediction dictionary should always have at least one tile" + + self.ensemble_predictions = ensemble_predictions + self.num_batches = self.ensemble_predictions.num_batches + + self.tiler = tiler + + def merge_tiles(self, batch_data: dict, tile_key: str) -> Tensor: + """Merge tiles back into one tensor and perform untiling with tiler. + + Args: + batch_data (dict): Dictionary containing all tile predictions of current batch. + tile_key (str): Key used in prediction dictionary for tiles that we want to merge. + + Returns: + Tensor: Tensor of tiles in original (stitched) shape. 
+ """ + # batch of tiles with index (0, 0) always exists, so we use it to get some basic information + first_tiles = batch_data[0, 0][tile_key] + batch_size = first_tiles.shape[0] + device = first_tiles.device + + if tile_key == "mask": + # in case of ground truth masks, we don't have channels + merged_size = [ + self.tiler.num_patches_h, + self.tiler.num_patches_w, + batch_size, + self.tiler.tile_size_h, + self.tiler.tile_size_w, + ] + else: + # all tiles beside masks also have channels + num_channels = first_tiles.shape[1] + merged_size = [ + self.tiler.num_patches_h, + self.tiler.num_patches_w, + batch_size, + int(num_channels), + self.tiler.tile_size_h, + self.tiler.tile_size_w, + ] + + # create new empty tensor for merged tiles + merged_masks = torch.zeros(size=merged_size, device=device) + + # insert tile into merged tensor at right locations + for (tile_i, tile_j), tile_data in batch_data.items(): + merged_masks[tile_i, tile_j, ...] = tile_data[tile_key] + + if tile_key == "mask": + # add channel as tiler needs it + merged_masks = merged_masks.unsqueeze(3) + + # stitch tiles back into whole, output is [B, C, H, W] + merged_output = self.tiler.untile(merged_masks) + + if tile_key == "mask": + # remove previously added channels + merged_output = merged_output.squeeze(1) + + return merged_output + + def merge_labels_and_scores(self, batch_data: dict) -> dict[str, Tensor]: + """Join scores and their corresponding label predictions from all tiles for each image. + + Label merging is done by rule where one anomalous tile in image results in whole image being anomalous. + Scores are averaged over tiles. + + Args: + batch_data (dict): Dictionary containing all tile predictions of current batch. 
+ + Returns: + dict[str, Tensor]: Dictionary with "pred_labels" and "pred_scores" + """ + # create accumulator with same shape as original + labels = torch.zeros(batch_data[0, 0]["pred_labels"].shape, dtype=torch.bool) + scores = torch.zeros(batch_data[0, 0]["pred_scores"].shape) + + for curr_tile_data in batch_data.values(): + curr_labels = curr_tile_data["pred_labels"] + curr_scores = curr_tile_data["pred_scores"] + + labels = labels.logical_or(curr_labels) + scores += curr_scores + + scores /= self.tiler.num_tiles + + return {"pred_labels": labels, "pred_scores": scores} + + def merge_tile_predictions(self, batch_index: int) -> dict[str, Tensor | list]: + """Join predictions from ensemble into whole image level representation for batch at index batch_index. + + Args: + batch_index (int): Index of current batch. + + Returns: + dict[str, Tensor | list]: Dictionary of merged predictions for specified batch. + """ + current_batch_data = self.ensemble_predictions.get_batch_tiles(batch_index) + + # take first tile as base prediction, keep items that are the same over all tiles: + # image_path, label, mask_path + merged_predictions = { + "image_path": current_batch_data[0, 0]["image_path"], + "label": current_batch_data[0, 0]["label"], + } + if "mask_path" in current_batch_data[0, 0]: + merged_predictions["mask_path"] = current_batch_data[0, 0]["mask_path"] + if "boxes" in current_batch_data[0, 0]: + merged_predictions["boxes"] = current_batch_data[0, 0]["boxes"] + + tiled_data = ["image", "mask"] + if "anomaly_maps" in current_batch_data[0, 0]: + tiled_data += ["anomaly_maps", "pred_masks"] + + # merge all tiled data + for t_key in tiled_data: + if t_key in current_batch_data[0, 0]: + merged_predictions[t_key] = self.merge_tiles(current_batch_data, t_key) + + # label and score merging + merged_scores_and_labels = self.merge_labels_and_scores(current_batch_data) + merged_predictions["pred_labels"] = merged_scores_and_labels["pred_labels"] + merged_predictions["pred_scores"]
= merged_scores_and_labels["pred_scores"] + + return merged_predictions diff --git a/src/anomalib/pipelines/tiled_ensemble/components/visualization.py b/src/anomalib/pipelines/tiled_ensemble/components/visualization.py new file mode 100644 index 0000000000..1298ece89f --- /dev/null +++ b/src/anomalib/pipelines/tiled_ensemble/components/visualization.py @@ -0,0 +1,125 @@ +"""Tiled ensemble - visualization job.""" + +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import logging +from collections.abc import Generator +from pathlib import Path +from typing import Any + +from tqdm import tqdm + +from anomalib import TaskType +from anomalib.data.utils.image import save_image +from anomalib.pipelines.components import Job, JobGenerator +from anomalib.pipelines.tiled_ensemble.components.utils import NormalizationStage +from anomalib.pipelines.types import GATHERED_RESULTS, RUN_RESULTS +from anomalib.utils.visualization import ImageVisualizer + +logger = logging.getLogger(__name__) + + +class VisualizationJob(Job): + """Job for visualization of predictions. + + Args: + predictions (list[Any]): list of image-level predictions. + root_dir (Path): Root directory to save checkpoints, stats and images. + task (TaskType): type of task the predictions represent. + normalize (bool): if predictions need to be normalized + """ + + name = "Visualize" + + def __init__(self, predictions: list[Any], root_dir: Path, task: TaskType, normalize: bool) -> None: + super().__init__() + self.predictions = predictions + self.root_dir = root_dir / "images" + self.task = task + self.normalize = normalize + + def run(self, task_id: int | None = None) -> list[Any]: + """Run job that visualizes all prediction data. + + Args: + task_id: Not used in this case. + + Returns: + list[Any]: Unchanged predictions. 
+ """ + del task_id # not needed here + + visualizer = ImageVisualizer(task=self.task, normalize=self.normalize) + + logger.info("Starting visualization.") + + for data in tqdm(self.predictions, desc="Visualizing"): + for result in visualizer(outputs=data): + # Finally image path is root/defect_type/image_name + if result.file_name is not None: + file_path = Path(result.file_name) + else: + msg = "file_path should exist in returned Visualizer." + raise ValueError(msg) + + root = self.root_dir / file_path.parent.name + filename = file_path.name + + save_image(image=result.image, root=root, filename=filename) + + return self.predictions + + @staticmethod + def collect(results: list[RUN_RESULTS]) -> GATHERED_RESULTS: + """Nothing to collect in this job. + + Returns: + list[Any]: Unchanged list of predictions. + """ + # take the first element as result is list of lists here + return results[0] + + @staticmethod + def save(results: GATHERED_RESULTS) -> None: + """This job doesn't save anything.""" + + +class VisualizationJobGenerator(JobGenerator): + """Generate VisualizationJob. + + Args: + root_dir (Path): Root directory where images will be saved (root/images). + """ + + def __init__(self, root_dir: Path, task: TaskType, normalization_stage: NormalizationStage) -> None: + self.root_dir = root_dir + self.task = task + self.normalize = normalization_stage == NormalizationStage.NONE + + @property + def job_class(self) -> type: + """Return the job class.""" + return VisualizationJob + + def generate_jobs( + self, + args: dict | None = None, + prev_stage_result: list[Any] | None = None, + ) -> Generator[VisualizationJob, None, None]: + """Return a generator producing a single visualization job. + + Args: + args: Ensemble run args. + prev_stage_result (list[Any]): Ensemble predictions from previous step. 
+ + Returns: + Generator[VisualizationJob, None, None]: VisualizationJob generator + """ + del args # args not used here + + if prev_stage_result is not None: + yield VisualizationJob(prev_stage_result, self.root_dir, self.task, self.normalize) + else: + msg = "Visualization job requires tile level predictions from previous step." + raise ValueError(msg) diff --git a/src/anomalib/pipelines/tiled_ensemble/test_pipeline.py b/src/anomalib/pipelines/tiled_ensemble/test_pipeline.py new file mode 100644 index 0000000000..7fdd61e9ff --- /dev/null +++ b/src/anomalib/pipelines/tiled_ensemble/test_pipeline.py @@ -0,0 +1,124 @@ +"""Tiled ensemble test pipeline.""" + +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import logging +from pathlib import Path + +import torch + +from anomalib.data.utils import TestSplitMode +from anomalib.pipelines.components.base import Pipeline, Runner +from anomalib.pipelines.components.runners import ParallelRunner, SerialRunner +from anomalib.pipelines.tiled_ensemble.components import ( + MergeJobGenerator, + MetricsCalculationJobGenerator, + NormalizationJobGenerator, + PredictJobGenerator, + SmoothingJobGenerator, + ThresholdingJobGenerator, + VisualizationJobGenerator, +) +from anomalib.pipelines.tiled_ensemble.components.utils import NormalizationStage, PredictData, ThresholdStage + +logger = logging.getLogger(__name__) + + +class EvalTiledEnsemble(Pipeline): + """Tiled ensemble evaluation pipeline. + + Args: + root_dir (Path): Path to root dir of run that contains checkpoints. + """ + + def __init__(self, root_dir: Path) -> None: + self.root_dir = Path(root_dir) + + def _setup_runners(self, args: dict) -> list[Runner]: + """Set up the runners for the pipeline. 
+ + This pipeline consists of jobs used to test/evaluate tiled ensemble: + Prediction on test data > merging of predictions > (optional) seam smoothing + > (optional) Normalization > (optional) Thresholding + > Visualisation of predictions > Metrics calculation. + + Returns: + list[Runner]: List of runners executing tiled ensemble testing jobs. + """ + runners: list[Runner] = [] + + if args["data"]["init_args"]["test_split_mode"] == TestSplitMode.NONE: + logger.info("Test split mode set to `none`, skipping test phase.") + return runners + + seed = args["seed"] + accelerator = args["accelerator"] + tiling_args = args["tiling"] + data_args = args["data"] + normalization_stage = NormalizationStage(args["normalization_stage"]) + threshold_stage = ThresholdStage(args["thresholding"]["stage"]) + model_args = args["TrainModels"]["model"] + task = args["data"]["init_args"]["task"] + metrics = args["TrainModels"]["metrics"] + + predict_job_generator = PredictJobGenerator( + PredictData.TEST, + seed=seed, + accelerator=accelerator, + root_dir=self.root_dir, + tiling_args=tiling_args, + data_args=data_args, + model_args=model_args, + normalization_stage=normalization_stage, + ) + # 1. predict using test data + if accelerator == "cuda": + runners.append( + ParallelRunner( + predict_job_generator, + n_jobs=torch.cuda.device_count(), + ), + ) + else: + runners.append( + SerialRunner( + predict_job_generator, + ), + ) + # 2. merge predictions + runners.append(SerialRunner(MergeJobGenerator(tiling_args=tiling_args, data_args=data_args))) + + # 3. (optional) smooth seams + if args["SeamSmoothing"]["apply"]: + runners.append( + SerialRunner( + SmoothingJobGenerator(accelerator=accelerator, tiling_args=tiling_args, data_args=data_args), + ), + ) + + # 4. (optional) normalize + if normalization_stage == NormalizationStage.IMAGE: + runners.append(SerialRunner(NormalizationJobGenerator(self.root_dir))) + # 5. 
(optional) threshold to get labels from scores + if threshold_stage == ThresholdStage.IMAGE: + runners.append(SerialRunner(ThresholdingJobGenerator(self.root_dir, normalization_stage))) + + # 6. visualize predictions + runners.append( + SerialRunner(VisualizationJobGenerator(self.root_dir, task=task, normalization_stage=normalization_stage)), + ) + # 7. calculate metrics + runners.append( + SerialRunner( + MetricsCalculationJobGenerator( + accelerator=accelerator, + root_dir=self.root_dir, + task=task, + metrics=metrics, + normalization_stage=normalization_stage, + ), + ), + ) + + return runners diff --git a/src/anomalib/pipelines/tiled_ensemble/train_pipeline.py b/src/anomalib/pipelines/tiled_ensemble/train_pipeline.py new file mode 100644 index 0000000000..38e4e34e4b --- /dev/null +++ b/src/anomalib/pipelines/tiled_ensemble/train_pipeline.py @@ -0,0 +1,123 @@ +"""Tiled ensemble training pipeline.""" + +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from typing import TYPE_CHECKING + +from anomalib.data.utils import ValSplitMode + +if TYPE_CHECKING: + from pathlib import Path + +import logging + +import torch + +from anomalib.pipelines.components.base import Pipeline, Runner +from anomalib.pipelines.components.runners import ParallelRunner, SerialRunner + +from .components import ( + MergeJobGenerator, + PredictJobGenerator, + SmoothingJobGenerator, + StatisticsJobGenerator, + TrainModelJobGenerator, +) +from .components.utils import NormalizationStage, PredictData +from .components.utils.ensemble_engine import TiledEnsembleEngine + +logger = logging.getLogger(__name__) + + +class TrainTiledEnsemble(Pipeline): + """Tiled ensemble training pipeline.""" + + def __init__(self) -> None: + self.root_dir: Path + + def _setup_runners(self, args: dict) -> list[Runner]: + """Set up the runners for the pipeline.
+ + This pipeline consists of training and validation steps: + Training models > prediction on val data > merging val data + > (optional) smoothing seams > calculation of post-processing statistics + + Returns: + list[Runner]: List of runners executing tiled ensemble train + val jobs. + """ + runners: list[Runner] = [] + self.root_dir = TiledEnsembleEngine.setup_ensemble_workspace(args) + + seed = args["seed"] + accelerator = args["accelerator"] + tiling_args = args["tiling"] + data_args = args["data"] + normalization_stage = NormalizationStage(args["normalization_stage"]) + thresholding_method = args["thresholding"]["method"] + model_args = args["TrainModels"]["model"] + + train_job_generator = TrainModelJobGenerator( + seed=seed, + accelerator=accelerator, + root_dir=self.root_dir, + tiling_args=tiling_args, + data_args=data_args, + normalization_stage=normalization_stage, + ) + + predict_job_generator = PredictJobGenerator( + data_source=PredictData.VAL, + seed=seed, + accelerator=accelerator, + root_dir=self.root_dir, + tiling_args=tiling_args, + data_args=data_args, + model_args=model_args, + normalization_stage=normalization_stage, + ) + + # 1. train + if accelerator == "cuda": + runners.append( + ParallelRunner( + train_job_generator, + n_jobs=torch.cuda.device_count(), + ), + ) + else: + runners.append( + SerialRunner( + train_job_generator, + ), + ) + + if data_args["init_args"]["val_split_mode"] == ValSplitMode.NONE: + logger.warning("No validation set provided, skipping statistics calculation.") + return runners + + # 2. predict using validation data + if accelerator == "cuda": + runners.append( + ParallelRunner(predict_job_generator, n_jobs=torch.cuda.device_count()), + ) + else: + runners.append( + SerialRunner(predict_job_generator), + ) + + # 3. merge predictions + runners.append(SerialRunner(MergeJobGenerator(tiling_args=tiling_args, data_args=data_args))) + + # 4.
(optional) smooth seams + if args["SeamSmoothing"]["apply"]: + runners.append( + SerialRunner( + SmoothingJobGenerator(accelerator=accelerator, tiling_args=tiling_args, data_args=data_args), + ), + ) + + # 5. calculate statistics used for inference + runners.append(SerialRunner(StatisticsJobGenerator(self.root_dir, thresholding_method))) + + return runners diff --git a/tests/integration/pipelines/test_tiled_ensemble.py b/tests/integration/pipelines/test_tiled_ensemble.py new file mode 100644 index 0000000000..2909311276 --- /dev/null +++ b/tests/integration/pipelines/test_tiled_ensemble.py @@ -0,0 +1,62 @@ +"""Test tiled ensemble training and prediction.""" + +# Copyright (C) 2023-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from pathlib import Path + +import pytest +import yaml + +from anomalib.pipelines.tiled_ensemble import EvalTiledEnsemble, TrainTiledEnsemble + + +@pytest.fixture(scope="session") +def get_mock_environment(dataset_path: Path, project_path: Path) -> Path: + """Return mock directory for testing with datapath setup to dummy data.""" + ens_temp_dir = project_path / "ens_tmp" + ens_temp_dir.mkdir(exist_ok=True) + + with Path("tests/integration/pipelines/tiled_ensemble.yaml").open(encoding="utf-8") as file: + config = yaml.safe_load(file) + + # use separate project temp dir to avoid messing with other tests + config["default_root_dir"] = str(ens_temp_dir) + config["data"]["init_args"]["root"] = str(dataset_path / "mvtec") + + with (Path(ens_temp_dir) / "tiled_ensemble.yaml").open("w", encoding="utf-8") as file: + yaml.safe_dump(config, file) + + return Path(ens_temp_dir) + + +def test_train(get_mock_environment: Path, capsys: pytest.CaptureFixture) -> None: + """Test training of the tiled ensemble.""" + train_pipeline = TrainTiledEnsemble() + train_parser = train_pipeline.get_parser() + args = train_parser.parse_args(["--config", str(get_mock_environment / "tiled_ensemble.yaml")]) + train_pipeline.run(args) + # check that no 
errors were printed -> all stages were successful + out = capsys.readouterr().out + assert not any(line.startswith("There were some errors") for line in out.split("\n")) + + +def test_predict(get_mock_environment: Path, capsys: pytest.CaptureFixture) -> None: + """Test prediction with the tiled ensemble.""" + predict_pipeline = EvalTiledEnsemble(root_dir=get_mock_environment / "Padim" / "MVTec" / "dummy" / "v0") + predict_parser = predict_pipeline.get_parser() + args = predict_parser.parse_args(["--config", str(get_mock_environment / "tiled_ensemble.yaml")]) + predict_pipeline.run(args) + # check that no errors were printed -> all stages were successful + out = capsys.readouterr().out + assert not any(line.startswith("There were some errors") for line in out.split("\n")) + + +def test_visualisation(get_mock_environment: Path) -> None: + """Test that images were produced.""" + assert (get_mock_environment / "Padim/MVTec/dummy/v0/images/bad/000.png").exists() + + +def test_metric_results(get_mock_environment: Path) -> None: + """Test that metrics were saved.""" + assert (get_mock_environment / "Padim/MVTec/dummy/v0/metric_results.csv").exists() diff --git a/tests/integration/pipelines/tiled_ensemble.yaml b/tests/integration/pipelines/tiled_ensemble.yaml new file mode 100644 index 0000000000..8d35be8297 --- /dev/null +++ b/tests/integration/pipelines/tiled_ensemble.yaml @@ -0,0 +1,43 @@ +seed: 42 +accelerator: "cpu" +default_root_dir: "results" + +tiling: + tile_size: [50, 50] + stride: 50 + +normalization_stage: image # on what level we normalize, options: [tile, image, none] +thresholding: + method: F1AdaptiveThreshold # refer to documentation for thresholding methods + stage: image # stage at which we apply threshold, options: [tile, image] + +data: + class_path: anomalib.data.MVTec + init_args: + root: toBeSetup + category: dummy + train_batch_size: 32 + eval_batch_size: 32 + num_workers: 0 + task: segmentation + transform: null + train_transform: null + 
eval_transform: null + test_split_mode: from_dir + test_split_ratio: 0.2 + val_split_mode: same_as_test + val_split_ratio: 0.5 + image_size: [50, 100] + +SeamSmoothing: + apply: True # if this is applied, area around tile seams is smoothed + sigma: 2 # sigma of gaussian filter used to smooth this area + width: 0.1 # width factor, multiplied by tile dimension gives the region width around seam which will be smoothed + +TrainModels: + model: + class_path: Padim + + metrics: + pixel: AUROC + image: AUROC diff --git a/tests/unit/pipelines/__init__.py b/tests/unit/pipelines/__init__.py new file mode 100644 index 0000000000..46de40af76 --- /dev/null +++ b/tests/unit/pipelines/__init__.py @@ -0,0 +1,4 @@ +"""Pipeline unit tests.""" + +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 diff --git a/tests/unit/pipelines/tiled_ensemble/__init__.py b/tests/unit/pipelines/tiled_ensemble/__init__.py new file mode 100644 index 0000000000..a78a1ad659 --- /dev/null +++ b/tests/unit/pipelines/tiled_ensemble/__init__.py @@ -0,0 +1,4 @@ +"""Tiled ensemble unit tests.""" + +# Copyright (C) 2023-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 diff --git a/tests/unit/pipelines/tiled_ensemble/conftest.py b/tests/unit/pipelines/tiled_ensemble/conftest.py new file mode 100644 index 0000000000..b4fad61ebb --- /dev/null +++ b/tests/unit/pipelines/tiled_ensemble/conftest.py @@ -0,0 +1,151 @@ +"""Fixtures that are used in tiled ensemble testing.""" + +# Copyright (C) 2023-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import json +from pathlib import Path +from tempfile import TemporaryDirectory + +import pytest +import torch +import yaml + +from anomalib.data import AnomalibDataModule +from anomalib.models import AnomalyModule +from anomalib.pipelines.tiled_ensemble.components.utils.ensemble_tiling import EnsembleTiler +from anomalib.pipelines.tiled_ensemble.components.utils.helper_functions import ( + get_ensemble_datamodule, +
get_ensemble_model, + get_ensemble_tiler, +) +from anomalib.pipelines.tiled_ensemble.components.utils.prediction_data import EnsemblePredictions +from anomalib.pipelines.tiled_ensemble.components.utils.prediction_merging import PredictionMergingMechanism + + +@pytest.fixture(scope="module") +def get_ensemble_config(dataset_path: Path) -> dict: + """Return ensemble dummy config dict with corrected dataset path to dummy temp dir.""" + with Path("tests/unit/pipelines/tiled_ensemble/dummy_config.yaml").open(encoding="utf-8") as file: + config = yaml.safe_load(file) + # dummy dataset + config["data"]["init_args"]["root"] = dataset_path / "mvtec" + + return config + + +@pytest.fixture(scope="module") +def get_tiler(get_ensemble_config: dict) -> EnsembleTiler: + """Return EnsembleTiler object based on test dummy config.""" + config = get_ensemble_config + + return get_ensemble_tiler(config["tiling"], config["data"]) + + +@pytest.fixture(scope="module") +def get_model(get_ensemble_config: dict, get_tiler: EnsembleTiler) -> AnomalyModule: + """Return model prepared for tiled ensemble training.""" + config = get_ensemble_config + tiler = get_tiler + + return get_ensemble_model(config["TrainModels"]["model"], tiler) + + +@pytest.fixture(scope="module") +def get_datamodule(get_ensemble_config: dict, get_tiler: EnsembleTiler) -> AnomalibDataModule: + """Return ensemble datamodule.""" + config = get_ensemble_config + tiler = get_tiler + datamodule = get_ensemble_datamodule(config, tiler, (0, 0)) + datamodule.setup() + + return datamodule + + +@pytest.fixture(scope="module") +def get_tile_predictions(get_datamodule: AnomalibDataModule) -> EnsemblePredictions: + """Return tile predictions inside EnsemblePredictions object.""" + datamodule = get_datamodule + + data = EnsemblePredictions() + + for tile_index in [(0, 0), (0, 1), (1, 0), (1, 1)]: + datamodule.collate_fn.tile_index = tile_index + + tile_prediction = [] + batch = next(iter(datamodule.test_dataloader())) + + # make mock 
labels and scores + batch["pred_scores"] = torch.rand(batch["label"].shape) + batch["pred_labels"] = batch["pred_scores"] > 0.5 + + # set mock maps to just one channel of image + batch["anomaly_maps"] = batch["image"].clone()[:, 0, :, :].unsqueeze(1) + # set mock pred mask to mask but add channel + batch["pred_masks"] = batch["mask"].clone().unsqueeze(1) + + tile_prediction.append(batch) + + # store to prediction storage object + data.add_tile_prediction(tile_index, tile_prediction) + + return data + + +@pytest.fixture(scope="module") +def get_batch_predictions() -> list[dict]: + """Return mock batched predictions.""" + mock_data = { + "image": torch.rand((5, 3, 100, 100)), + "mask": (torch.rand((5, 100, 100)) > 0.5).type(torch.float32), + "anomaly_maps": torch.rand((5, 1, 100, 100)), + "label": torch.Tensor([0, 1, 1, 0, 1]), + "pred_scores": torch.rand(5), + "pred_labels": torch.ones(5), + "pred_masks": torch.zeros((5, 100, 100)), + } + + return [mock_data, mock_data] + + +@pytest.fixture(scope="module") +def get_merging_mechanism( + get_tile_predictions: EnsemblePredictions, + get_tiler: EnsembleTiler, +) -> PredictionMergingMechanism: + """Return ensemble prediction merging mechanism object.""" + tiler = get_tiler + predictions = get_tile_predictions + return PredictionMergingMechanism(predictions, tiler) + + +@pytest.fixture(scope="module") +def get_mock_stats_dir() -> Path: + """Get temp dir containing statistics.""" + with TemporaryDirectory() as temp_dir: + stats = { + "minmax": { + "anomaly_maps": { + "min": 1.9403648376464844, + "max": 209.91940307617188, + }, + "box_scores": { + "min": 0.5, + "max": 0.45, + }, + "pred_scores": { + "min": 9.390382766723633, + "max": 209.91940307617188, + }, + }, + "image_threshold": 0.1111, + "pixel_threshold": 0.1111, + } + stats_path = Path(temp_dir) / "weights" / "lightning" / "stats.json" + stats_path.parent.mkdir(parents=True) + + # save mock statistics + with stats_path.open("w", encoding="utf-8") as stats_file: + 
json.dump(stats, stats_file, ensure_ascii=False, indent=4) + + yield Path(temp_dir) diff --git a/tests/unit/pipelines/tiled_ensemble/dummy_config.yaml b/tests/unit/pipelines/tiled_ensemble/dummy_config.yaml new file mode 100644 index 0000000000..fcd4b7c716 --- /dev/null +++ b/tests/unit/pipelines/tiled_ensemble/dummy_config.yaml @@ -0,0 +1,52 @@ +seed: 42 +accelerator: "cpu" +default_root_dir: "results" + +tiling: + tile_size: [50, 50] + stride: 50 + +normalization_stage: image # on what level we normalize, options: [tile, image, none] +thresholding: + method: F1AdaptiveThreshold # refer to documentation for thresholding methods + stage: image # stage at which we apply threshold, options: [tile, image] + +data: + class_path: anomalib.data.MVTec + init_args: + root: toBeSetup + category: dummy + train_batch_size: 32 + eval_batch_size: 32 + num_workers: 0 + task: segmentation + transform: null + train_transform: null + eval_transform: null + test_split_mode: from_dir + test_split_ratio: 0.2 + val_split_mode: same_as_test + val_split_ratio: 0.5 + image_size: [100, 100] + +SeamSmoothing: + apply: True # if this is applied, area around tile seams is smoothed + sigma: 2 # sigma of gaussian filter used to smooth this area + width: 0.1 # width factor, multiplied by tile dimension gives the region width around seam which will be smoothed + +TrainModels: + model: + class_path: Fastflow + + metrics: + pixel: AUROC + image: AUROC + + trainer: + max_epochs: 1 + callbacks: + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + patience: 1 + monitor: pixel_AUROC + mode: max diff --git a/tests/unit/pipelines/tiled_ensemble/test_components.py b/tests/unit/pipelines/tiled_ensemble/test_components.py new file mode 100644 index 0000000000..0e3c0dcdd4 --- /dev/null +++ b/tests/unit/pipelines/tiled_ensemble/test_components.py @@ -0,0 +1,387 @@ +"""Test working of tiled ensemble pipeline components.""" + +# Copyright (C) 2023-2024 Intel Corporation +#
SPDX-License-Identifier: Apache-2.0 + +import copy +from pathlib import Path +from tempfile import TemporaryDirectory + +import pytest +import torch + +from anomalib.data import get_datamodule +from anomalib.metrics import F1AdaptiveThreshold, ManualThreshold +from anomalib.pipelines.tiled_ensemble.components import ( + MergeJobGenerator, + MetricsCalculationJobGenerator, + NormalizationJobGenerator, + SmoothingJobGenerator, + StatisticsJobGenerator, + ThresholdingJobGenerator, +) +from anomalib.pipelines.tiled_ensemble.components.metrics_calculation import MetricsCalculationJob +from anomalib.pipelines.tiled_ensemble.components.smoothing import SmoothingJob +from anomalib.pipelines.tiled_ensemble.components.utils import NormalizationStage +from anomalib.pipelines.tiled_ensemble.components.utils.prediction_data import EnsemblePredictions +from anomalib.pipelines.tiled_ensemble.components.utils.prediction_merging import PredictionMergingMechanism + + +class TestMerging: + """Test merging mechanism and merging job.""" + + @staticmethod + def test_tile_merging(get_ensemble_config: dict, get_merging_mechanism: PredictionMergingMechanism) -> None: + """Test tiled data merging.""" + config = get_ensemble_config + merger = get_merging_mechanism + + # prepared original data + datamodule = get_datamodule(config) + datamodule.prepare_data() + datamodule.setup() + original_data = next(iter(datamodule.test_dataloader())) + + batch = merger.ensemble_predictions.get_batch_tiles(0) + + merged_image = merger.merge_tiles(batch, "image") + assert merged_image.equal(original_data["image"]) + + merged_mask = merger.merge_tiles(batch, "mask") + assert merged_mask.equal(original_data["mask"]) + + @staticmethod + def test_label_and_score_merging(get_merging_mechanism: PredictionMergingMechanism) -> None: + """Test label and score merging.""" + merger = get_merging_mechanism + scores = torch.rand(4, 10) + labels = scores > 0.5 + + mock_data = {(0, 0): {}, (0, 1): {}, (1, 0): {}, (1, 1): 
{}} + + for i, data in enumerate(mock_data.values()): + data["pred_scores"] = scores[i] + data["pred_labels"] = labels[i] + + merged = merger.merge_labels_and_scores(mock_data) + + assert merged["pred_scores"].equal(scores.mean(dim=0)) + + assert merged["pred_labels"].equal(labels.any(dim=0)) + + @staticmethod + def test_merge_job( + get_tile_predictions: EnsemblePredictions, + get_ensemble_config: dict, + get_merging_mechanism: PredictionMergingMechanism, + ) -> None: + """Test merging job execution.""" + config = get_ensemble_config + predictions = copy.deepcopy(get_tile_predictions) + merging_mechanism = get_merging_mechanism + + merging_job_generator = MergeJobGenerator(tiling_args=config["tiling"], data_args=config["data"]) + merging_job = next(merging_job_generator.generate_jobs(prev_stage_result=predictions)) + + merged_direct = merging_mechanism.merge_tile_predictions(0) + merged_with_job = merging_job.run()[0] + + # check that merging by job is same as with the mechanism directly + for key, value in merged_direct.items(): + if isinstance(value, torch.Tensor): + assert merged_with_job[key].equal(value) + elif isinstance(value, list) and isinstance(value[0], torch.Tensor): + # boxes + assert all(j.equal(d) for j, d in zip(merged_with_job[key], value, strict=False)) + else: + assert merged_with_job[key] == value + + +class TestStatsCalculation: + """Test post-processing statistics calculations.""" + + @staticmethod + @pytest.mark.parametrize( + ("threshold_str", "threshold_cls"), + [("F1AdaptiveThreshold", F1AdaptiveThreshold), ("ManualThreshold", ManualThreshold)], + ) + def test_threshold_method(threshold_str: str, threshold_cls: type, get_ensemble_config: dict) -> None: + """Test that correct thresholding method is used.""" + config = copy.deepcopy(get_ensemble_config) + config["thresholding"]["method"] = threshold_str + + stats_job_generator = StatisticsJobGenerator(Path("mock"), threshold_str) + stats_job = next(stats_job_generator.generate_jobs(None, 
None)) + + assert isinstance(stats_job.image_threshold, threshold_cls) + + @staticmethod + def test_stats_run(project_path: Path) -> None: + """Test execution of statistics calc. job.""" + mock_preds = [ + { + "pred_scores": torch.rand(4), + "label": torch.ones(4), + "anomaly_maps": torch.rand(4, 1, 50, 50), + "mask": torch.ones(4, 1, 50, 50), + }, + ] + + stats_job_generator = StatisticsJobGenerator(project_path, "F1AdaptiveThreshold") + stats_job = next(stats_job_generator.generate_jobs(None, mock_preds)) + + results = stats_job.run() + + assert "minmax" in results + assert "image_threshold" in results + assert "pixel_threshold" in results + + # save as it's removed from results + save_path = results["save_path"] + stats_job.save(results) + assert Path(save_path).exists() + + @staticmethod + @pytest.mark.parametrize( + ("key", "values"), + [ + ("anomaly_maps", [torch.rand(5, 1, 50, 50), torch.rand(5, 1, 50, 50)]), + ("pred_scores", [torch.rand(5), torch.rand(5)]), + ], + ) + def test_minmax(key: str, values: list) -> None: + """Test minmax stats calculation.""" + # add given keys to test all possible sources of minmax + data = [ + {"pred_scores": torch.rand(5), "label": torch.ones(5), key: values[0]}, + {"pred_scores": torch.rand(5), "label": torch.ones(5), key: values[1]}, + ] + + stats_job_generator = StatisticsJobGenerator(Path("mock"), "F1AdaptiveThreshold") + stats_job = next(stats_job_generator.generate_jobs(None, data)) + results = stats_job.run() + + if isinstance(values[0], list): + values[0] = torch.cat(values[0]) + values[1] = torch.cat(values[1]) + values = torch.stack(values) + + assert results["minmax"][key]["min"] == torch.min(values) + assert results["minmax"][key]["max"] == torch.max(values) + + @staticmethod + @pytest.mark.parametrize( + ("labels", "preds", "target_threshold"), + [ + (torch.Tensor([0, 0, 0, 1, 1]), torch.Tensor([2.3, 1.6, 2.6, 7.9, 3.3]), 3.3), # standard case + (torch.Tensor([1, 0, 0, 0]), torch.Tensor([4, 3, 2, 1]), 4), # 100% 
recall for all thresholds + ], + ) + def test_threshold(labels: torch.Tensor, preds: torch.Tensor, target_threshold: float) -> None: + """Test threshold calculation job.""" + data = [ + { + "label": labels, + "mask": labels, + "pred_scores": preds, + "anomaly_maps": preds, + }, + ] + + stats_job_generator = StatisticsJobGenerator(Path("mock"), "F1AdaptiveThreshold") + stats_job = next(stats_job_generator.generate_jobs(None, data)) + results = stats_job.run() + + assert round(results["image_threshold"], 5) == target_threshold + assert round(results["pixel_threshold"], 5) == target_threshold + + +class TestMetrics: + """Test ensemble metrics.""" + + @pytest.fixture(scope="class") + @staticmethod + def get_ensemble_metrics_job( + get_ensemble_config: dict, + get_batch_predictions: list[dict], + ) -> tuple[MetricsCalculationJob, str]: + """Return Metrics calculation job and path to directory where metrics csv will be saved.""" + config = get_ensemble_config + with TemporaryDirectory() as tmp_dir: + metrics = MetricsCalculationJobGenerator( + config["accelerator"], + root_dir=Path(tmp_dir), + task=config["data"]["init_args"]["task"], + metrics=config["TrainModels"]["metrics"], + normalization_stage=NormalizationStage(config["normalization_stage"]), + ) + + mock_predictions = get_batch_predictions + + return next(metrics.generate_jobs(prev_stage_result=copy.deepcopy(mock_predictions))), tmp_dir + + @staticmethod + def test_metrics_result(get_ensemble_metrics_job: tuple[MetricsCalculationJob, str]) -> None: + """Test metrics result.""" + metrics_job, _ = get_ensemble_metrics_job + + result = metrics_job.run() + + assert "pixel_AUROC" in result + assert "image_AUROC" in result + + @staticmethod + def test_metrics_saving(get_ensemble_metrics_job: tuple[MetricsCalculationJob, str]) -> None: + """Test metrics saving to csv.""" + metrics_job, tmp_dir = get_ensemble_metrics_job + + result = metrics_job.run() + metrics_job.save(result) + assert (Path(tmp_dir) / 
"metric_results.csv").exists() + + +class TestJoinSmoothing: + """Test JoinSmoothing job responsible for smoothing area at tile seams.""" + + @pytest.fixture(scope="class") + @staticmethod + def get_join_smoothing_job(get_ensemble_config: dict, get_batch_predictions: list[dict]) -> SmoothingJob: + """Make and return SmoothingJob instance.""" + config = get_ensemble_config + job_gen = SmoothingJobGenerator( + accelerator=config["accelerator"], + tiling_args=config["tiling"], + data_args=config["data"], + ) + # copy since smoothing changes data + mock_predictions = copy.deepcopy(get_batch_predictions) + return next(job_gen.generate_jobs(config["SeamSmoothing"], mock_predictions)) + + @staticmethod + def test_mask(get_join_smoothing_job: SmoothingJob) -> None: + """Test seam mask in case where tiles don't overlap.""" + smooth = get_join_smoothing_job + + join_index = smooth.tiler.tile_size_h, smooth.tiler.tile_size_w + + # seam should be covered by True + assert smooth.seam_mask[join_index] + + # non-seam region should be false + assert not smooth.seam_mask[0, 0] + assert not smooth.seam_mask[-1, -1] + + @staticmethod + def test_mask_overlapping(get_ensemble_config: dict, get_batch_predictions: list[dict]) -> None: + """Test seam mask in case where tiles overlap.""" + config = copy.deepcopy(get_ensemble_config) + # tile size = 50, stride = 25 -> overlapping + config["tiling"]["stride"] = 25 + job_gen = SmoothingJobGenerator( + accelerator=config["accelerator"], + tiling_args=config["tiling"], + data_args=config["data"], + ) + mock_predictions = copy.deepcopy(get_batch_predictions) + smooth = next(job_gen.generate_jobs(config["SeamSmoothing"], mock_predictions)) + + join_index = smooth.tiler.stride_h, smooth.tiler.stride_w + + # overlap seam should be covered by True + assert smooth.seam_mask[join_index] + assert smooth.seam_mask[-join_index[0], -join_index[1]] + + # non-seam region should be false + assert not smooth.seam_mask[0, 0] + assert not smooth.seam_mask[-1, 
-1] + + @staticmethod + def test_smoothing(get_join_smoothing_job: SmoothingJob, get_batch_predictions: list[dict]) -> None: + """Test smoothing job run.""" + original_data = get_batch_predictions + # fixture makes a copy of data + smooth = get_join_smoothing_job + + # take first batch + smoothed = smooth.run()[0] + join_index = smooth.tiler.tile_size_h, smooth.tiler.tile_size_w + + # join sections should be processed + assert not smoothed["anomaly_maps"][:, :, join_index].equal(original_data[0]["anomaly_maps"][:, :, join_index]) + + # non-join section shouldn't be changed + assert smoothed["anomaly_maps"][:, :, 0, 0].equal(original_data[0]["anomaly_maps"][:, :, 0, 0]) + + +def test_normalization(get_batch_predictions: list[dict], project_path: Path) -> None: + """Test normalization step.""" + original_predictions = copy.deepcopy(get_batch_predictions) + + for batch in original_predictions: + batch["anomaly_maps"] *= 100 + batch["pred_scores"] *= 100 + + # get and save stats using stats job on predictions + stats_job_generator = StatisticsJobGenerator(project_path, "F1AdaptiveThreshold") + stats_job = next(stats_job_generator.generate_jobs(prev_stage_result=original_predictions)) + stats = stats_job.run() + stats_job.save(stats) + + # normalize predictions based on obtained stats + norm_job_generator = NormalizationJobGenerator(root_dir=project_path) + # copy as this changes preds + norm_job = next(norm_job_generator.generate_jobs(prev_stage_result=original_predictions)) + normalized_predictions = norm_job.run() + + for batch in normalized_predictions: + assert (batch["anomaly_maps"] >= 0).all() + assert (batch["anomaly_maps"] <= 1).all() + + assert (batch["pred_scores"] >= 0).all() + assert (batch["pred_scores"] <= 1).all() + + +class TestThresholding: + """Test tiled ensemble thresholding stage.""" + + @pytest.fixture(scope="class") + @staticmethod + def get_threshold_job(get_mock_stats_dir: Path) -> callable: + """Return a function that takes prediction data 
and runs threshold job.""" + thresh_job_generator = ThresholdingJobGenerator( + root_dir=get_mock_stats_dir, + normalization_stage=NormalizationStage.IMAGE, + ) + + def thresh_helper(preds: dict) -> list | None: + thresh_job = next(thresh_job_generator.generate_jobs(prev_stage_result=preds)) + return thresh_job.run() + + return thresh_helper + + @staticmethod + def test_score_threshold(get_threshold_job: callable) -> None: + """Test anomaly score thresholding.""" + thresholding = get_threshold_job + + data = [{"pred_scores": torch.tensor([0.7, 0.8, 0.1, 0.33, 0.5])}] + + thresholded = thresholding(data)[0] + + assert thresholded["pred_labels"].equal(torch.tensor([True, True, False, False, True])) + + @staticmethod + def test_anomap_threshold(get_threshold_job: callable) -> None: + """Test anomaly map thresholding.""" + thresholding = get_threshold_job + + data = [ + { + "pred_scores": torch.tensor([0.7, 0.8, 0.1, 0.33, 0.5]), + "anomaly_maps": torch.tensor([[0.7, 0.8, 0.1], [0.33, 0.5, 0.1]]), + }, + ] + + thresholded = thresholding(data)[0] + + assert thresholded["pred_masks"].equal(torch.tensor([[True, True, False], [False, True, False]])) diff --git a/tests/unit/pipelines/tiled_ensemble/test_helper_functions.py b/tests/unit/pipelines/tiled_ensemble/test_helper_functions.py new file mode 100644 index 0000000000..06e5864cef --- /dev/null +++ b/tests/unit/pipelines/tiled_ensemble/test_helper_functions.py @@ -0,0 +1,113 @@ +"""Test ensemble helper functions.""" + +# Copyright (C) 2023-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from pathlib import Path + +import pytest +from jsonargparse import Namespace +from lightning.pytorch.callbacks import EarlyStopping + +from anomalib.callbacks.normalization import _MinMaxNormalizationCallback +from anomalib.models import AnomalyModule +from anomalib.pipelines.tiled_ensemble.components.utils import NormalizationStage +from anomalib.pipelines.tiled_ensemble.components.utils.ensemble_tiling import 
EnsembleTiler, TileCollater +from anomalib.pipelines.tiled_ensemble.components.utils.helper_functions import ( + get_ensemble_datamodule, + get_ensemble_engine, + get_ensemble_model, + get_ensemble_tiler, + get_threshold_values, + parse_trainer_kwargs, +) + + +class TestHelperFunctions: + """Test ensemble helper functions.""" + + @staticmethod + def test_ensemble_datamodule(get_ensemble_config: dict, get_tiler: EnsembleTiler) -> None: + """Test that datamodule is created and has correct collate function.""" + config = get_ensemble_config + tiler = get_tiler + datamodule = get_ensemble_datamodule(config, tiler, (0, 0)) + + assert isinstance(datamodule.collate_fn, TileCollater) + + @staticmethod + def test_ensemble_model(get_ensemble_config: dict, get_tiler: EnsembleTiler) -> None: + """Test that model is successfully created with correct input shape.""" + config = get_ensemble_config + tiler = get_tiler + model = get_ensemble_model(config["TrainModels"]["model"], tiler) + + assert model.input_size == tuple(config["tiling"]["tile_size"]) + + @staticmethod + def test_tiler(get_ensemble_config: dict) -> None: + """Test that tiler is successfully instantiated.""" + config = get_ensemble_config + + tiler = get_ensemble_tiler(config["tiling"], config["data"]) + assert isinstance(tiler, EnsembleTiler) + + @staticmethod + def test_trainer_kwargs(get_ensemble_config: dict) -> None: + """Test that objects are correctly constructed from kwargs.""" + config = get_ensemble_config + + objects = parse_trainer_kwargs(config["TrainModels"]["trainer"]) + assert isinstance(objects, Namespace) + # verify that early stopping is parsed and added to callbacks + assert isinstance(objects.callbacks[0], EarlyStopping) + + @staticmethod + @pytest.mark.parametrize( + "normalization_stage", + [NormalizationStage.NONE, NormalizationStage.IMAGE, NormalizationStage.TILE], + ) + def test_threshold_values(normalization_stage: NormalizationStage, get_mock_stats_dir: Path) -> None: + """Test that 
threshold values are correctly set based on normalization stage.""" + stats_dir = get_mock_stats_dir + + i_thresh, p_thresh = get_threshold_values(normalization_stage, stats_dir) + + if normalization_stage != NormalizationStage.NONE: + # minmax normalization sets thresholds to 0.5 + assert i_thresh == p_thresh == 0.5 + else: + assert i_thresh == p_thresh == 0.1111 + + +class TestEnsembleEngine: + """Test ensemble engine configuration.""" + + @staticmethod + @pytest.mark.parametrize( + "normalization_stage", + [NormalizationStage.NONE, NormalizationStage.IMAGE, NormalizationStage.TILE], + ) + def test_normalisation(normalization_stage: NormalizationStage, get_model: AnomalyModule) -> None: + """Test that normalization callback is correctly initialized.""" + engine = get_ensemble_engine( + tile_index=(0, 0), + accelerator="cpu", + devices="1", + root_dir=Path("mock"), + normalization_stage=normalization_stage, + ) + + engine._setup_anomalib_callbacks(get_model) # noqa: SLF001 + + # verify that only in case of tile level normalization the callback is present + if normalization_stage == NormalizationStage.TILE: + assert any( + isinstance(x, _MinMaxNormalizationCallback) + for x in engine._cache.args["callbacks"] # noqa: SLF001 + ) + else: + assert not any( + isinstance(x, _MinMaxNormalizationCallback) + for x in engine._cache.args["callbacks"] # noqa: SLF001 + ) diff --git a/tests/unit/pipelines/tiled_ensemble/test_prediction_data.py b/tests/unit/pipelines/tiled_ensemble/test_prediction_data.py new file mode 100644 index 0000000000..7185f1e2ca --- /dev/null +++ b/tests/unit/pipelines/tiled_ensemble/test_prediction_data.py @@ -0,0 +1,69 @@ +"""Test tiled prediction storage class.""" + +# Copyright (C) 2023-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import copy +from collections.abc import Callable + +import torch +from torch import Tensor + +from anomalib.data import AnomalibDataModule +from 
anomalib.pipelines.tiled_ensemble.components.utils.prediction_data import EnsemblePredictions + + +class TestPredictionData: + """Test EnsemblePredictions class, used for tiled prediction storage.""" + + @staticmethod + def store_all(data: EnsemblePredictions, datamodule: AnomalibDataModule) -> dict: + """Store the tiled predictions in the EnsemblePredictions object.""" + tile_dict = {} + for tile_index in [(0, 0), (0, 1), (1, 0), (1, 1)]: + datamodule.collate_fn.tile_index = tile_index + + tile_prediction = [] + for batch in iter(datamodule.train_dataloader()): + # set mock maps to just one channel of image + batch["anomaly_maps"] = batch["image"].clone()[:, 0, :, :].unsqueeze(1) + # set mock pred mask to mask but add channel + batch["pred_masks"] = batch["mask"].clone().unsqueeze(1) + tile_prediction.append(batch) + # save original + tile_dict[tile_index] = copy.deepcopy(tile_prediction) + # store to prediction storage object + data.add_tile_prediction(tile_index, tile_prediction) + + return tile_dict + + @staticmethod + def verify_equal(name: str, tile_dict: dict, storage: EnsemblePredictions, eq_funct: Callable) -> bool: + """Verify that all data at same tile index and same batch index matches.""" + batch_num = len(tile_dict[0, 0]) + + for batch_i in range(batch_num): + # batch is dict where key: tile index and val is batched data of that tile + curr_batch = storage.get_batch_tiles(batch_i) + + # go over all indices of current batch of stored data + for tile_index, stored_data_batch in curr_batch.items(): + stored_data = stored_data_batch[name] + # get original data dict at current tile index and batch index + original_data = tile_dict[tile_index][batch_i][name] + if isinstance(original_data, Tensor): + if not eq_funct(original_data, stored_data): + return False + elif original_data != stored_data: + return False + + return True + + def test_prediction_object(self, get_datamodule: AnomalibDataModule) -> None: + """Test prediction storage class.""" + datamodule 
= get_datamodule + storage = EnsemblePredictions() + original = self.store_all(storage, datamodule) + + for name in original[0, 0][0]: + assert self.verify_equal(name, original, storage, torch.equal), f"{name} doesn't match" diff --git a/tests/unit/pipelines/tiled_ensemble/test_tiler.py b/tests/unit/pipelines/tiled_ensemble/test_tiler.py new file mode 100644 index 0000000000..96b6c0e7bc --- /dev/null +++ b/tests/unit/pipelines/tiled_ensemble/test_tiler.py @@ -0,0 +1,119 @@ +"""Tiling related tests for tiled ensemble.""" + +# Copyright (C) 2023-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import copy + +import pytest +import torch + +from anomalib.data import AnomalibDataModule +from anomalib.pipelines.tiled_ensemble.components.utils.helper_functions import get_ensemble_tiler + +tiler_config = { + "tiling": { + "tile_size": 256, + "stride": 256, + }, + "data": {"init_args": {"image_size": 512}}, +} + +tiler_config_overlap = { + "tiling": { + "tile_size": 256, + "stride": 128, + }, + "data": {"init_args": {"image_size": 512}}, +} + + +class TestTiler: + """EnsembleTiler tests.""" + + @staticmethod + @pytest.mark.parametrize( + ("input_shape", "config", "expected_shape"), + [ + (torch.Size([5, 3, 512, 512]), tiler_config, torch.Size([2, 2, 5, 3, 256, 256])), + (torch.Size([5, 3, 512, 512]), tiler_config_overlap, torch.Size([3, 3, 5, 3, 256, 256])), + (torch.Size([5, 3, 500, 500]), tiler_config, torch.Size([2, 2, 5, 3, 256, 256])), + (torch.Size([5, 3, 500, 500]), tiler_config_overlap, torch.Size([3, 3, 5, 3, 256, 256])), + ], + ) + def test_basic_tile_for_ensemble(input_shape: torch.Size, config: dict, expected_shape: torch.Size) -> None: + """Test basic tiling of data.""" + config = copy.deepcopy(config) + config["data"]["init_args"]["image_size"] = input_shape[-1] + tiler = get_ensemble_tiler(config["tiling"], config["data"]) + + images = torch.rand(size=input_shape) + tiled = tiler.tile(images) + + assert tiled.shape == expected_shape + + 
@staticmethod + @pytest.mark.parametrize( + ("input_shape", "config"), + [ + (torch.Size([5, 3, 512, 512]), tiler_config), + (torch.Size([5, 3, 512, 512]), tiler_config_overlap), + (torch.Size([5, 3, 500, 500]), tiler_config), + (torch.Size([5, 3, 500, 500]), tiler_config_overlap), + ], + ) + def test_basic_tile_reconstruction(input_shape: torch.Size, config: dict) -> None: + """Test basic reconstruction of tiled data.""" + config = copy.deepcopy(config) + config["data"]["init_args"]["image_size"] = input_shape[-1] + + tiler = get_ensemble_tiler(config["tiling"], config["data"]) + + images = torch.rand(size=input_shape) + tiled = tiler.tile(images.clone()) + untiled = tiler.untile(tiled) + + assert images.shape == untiled.shape + assert images.equal(untiled) + + @staticmethod + @pytest.mark.parametrize( + ("input_shape", "config"), + [ + (torch.Size([5, 3, 512, 512]), tiler_config), + (torch.Size([5, 3, 500, 500]), tiler_config), + ], + ) + def test_untile_different_instance(input_shape: torch.Size, config: dict) -> None: + """Test untiling with different Tiler instance.""" + config = copy.deepcopy(config) + config["data"]["init_args"]["image_size"] = input_shape[-1] + tiler_1 = get_ensemble_tiler(config["tiling"], config["data"]) + + tiler_2 = get_ensemble_tiler(config["tiling"], config["data"]) + + images = torch.rand(size=input_shape) + tiled = tiler_1.tile(images.clone()) + + untiled = tiler_2.untile(tiled) + + # untiling should work even with different instance of tiler + assert images.shape == untiled.shape + assert images.equal(untiled) + + +class TestTileCollater: + """Test tile collater.""" + + @staticmethod + def test_collate_tile_shape(get_ensemble_config: dict, get_datamodule: AnomalibDataModule) -> None: + """Test that collate function successfully tiles the image.""" + config = get_ensemble_config + # datamodule with tile collater + datamodule = get_datamodule + + tile_w, tile_h = config["tiling"]["tile_size"] + + batch = 
next(iter(datamodule.train_dataloader())) + assert batch["image"].shape[1:] == (3, tile_w, tile_h) + assert batch["mask"].shape[1:] == (tile_w, tile_h) diff --git a/tools/tiled_ensemble/ens_config.yaml b/tools/tiled_ensemble/ens_config.yaml new file mode 100644 index 0000000000..2490b22e9a --- /dev/null +++ b/tools/tiled_ensemble/ens_config.yaml @@ -0,0 +1,43 @@ +seed: 42 +accelerator: "gpu" +default_root_dir: "results" + +tiling: + tile_size: [128, 128] + stride: 128 + +normalization_stage: image # on what level we normalize, options: [tile, image, none] +thresholding: + method: F1AdaptiveThreshold # refer to documentation for thresholding methods + stage: image # stage at which we apply threshold, options: [tile, image] + +data: + class_path: anomalib.data.MVTec + init_args: + root: ./datasets/MVTec + category: bottle + train_batch_size: 32 + eval_batch_size: 32 + num_workers: 8 + task: segmentation + transform: null + train_transform: null + eval_transform: null + test_split_mode: from_dir + test_split_ratio: 0.2 + val_split_mode: same_as_test + val_split_ratio: 0.5 + image_size: [256, 256] + +SeamSmoothing: + apply: True # if this is applied, the area around tile seams is smoothed + sigma: 2 # sigma of gaussian filter used to smooth this area + width: 0.1 # width factor, multiplied by tile dimension gives the region width around seam which will be smoothed + +TrainModels: + model: + class_path: Padim + + metrics: + pixel: AUROC + image: AUROC diff --git a/tools/tiled_ensemble/eval.py b/tools/tiled_ensemble/eval.py new file mode 100644 index 0000000000..58be27c25c --- /dev/null +++ b/tools/tiled_ensemble/eval.py @@ -0,0 +1,28 @@ +"""Run tiled ensemble prediction.""" + +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from pathlib import Path + +from jsonargparse import ArgumentParser + +from anomalib.pipelines.tiled_ensemble import EvalTiledEnsemble + + +def get_parser() -> ArgumentParser: + """Create a new parser if none is 
provided.""" + parser = ArgumentParser() + parser.add_argument("--config", type=str | Path, help="Configuration file path.", required=True) + parser.add_argument("--root", type=str | Path, help="Path to the root directory with checkpoints.", required=True) + + return parser + + +if __name__ == "__main__": + args = get_parser().parse_args() + + print("Running tiled ensemble test pipeline.") + # pass the path to root dir with checkpoints + test_pipeline = EvalTiledEnsemble(args.root) + test_pipeline.run(args) diff --git a/tools/tiled_ensemble/train.py b/tools/tiled_ensemble/train.py new file mode 100644 index 0000000000..8aed47ea0d --- /dev/null +++ b/tools/tiled_ensemble/train.py @@ -0,0 +1,17 @@ +"""Run tiled ensemble training.""" + +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from anomalib.pipelines.tiled_ensemble import EvalTiledEnsemble, TrainTiledEnsemble + +if __name__ == "__main__": + print("Running tiled ensemble train pipeline.") + train_pipeline = TrainTiledEnsemble() + # run training + train_pipeline.run() + + print("Running tiled ensemble test pipeline.") + # pass the root dir from train run to load checkpoints + test_pipeline = EvalTiledEnsemble(train_pipeline.root_dir) + test_pipeline.run()
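
Reviewer note on the tiling configs in this diff: the expected shapes in `test_basic_tile_for_ensemble` (2x2 tiles for stride 256, 3x3 for stride 128, for both 512 and 500 pixel inputs) are consistent with a simple per-axis tile count. The sketch below is an assumption inferred from those expected shapes, not the `EnsembleTiler` implementation; `tiles_per_axis` is a hypothetical helper that does not exist in anomalib.

```python
import math


def tiles_per_axis(image_size: int, tile_size: int, stride: int) -> int:
    """Hypothetical helper: tile count along one axis.

    Assumes the last tile may extend past the image edge (the input is
    padded or resized to fit), which matches the expected shapes in the tests.
    """
    if image_size <= tile_size:
        return 1
    # steps needed to cover the remainder after the first tile, plus that first tile
    return math.ceil((image_size - tile_size) / stride) + 1


# the four parametrized cases from the tiler tests (tile_size is 256 in all of them)
for image_size, stride in [(512, 256), (512, 128), (500, 256), (500, 128)]:
    n = tiles_per_axis(image_size, tile_size=256, stride=stride)
    print(f"image={image_size} stride={stride} -> grid {n}x{n}")
```

Under the same assumption, the values in `ens_config.yaml` (`image_size: [256, 256]`, `tile_size: [128, 128]`, `stride: 128`) give a 2x2 grid, which agrees with the tile indices `(0, 0)` through `(1, 1)` used by the merging and prediction-storage tests.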