Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PIMO] speed up #2379

Open
wants to merge 6 commits into
base: main
Choose a base branch
from

Conversation

jpcbertoldo
Copy link
Contributor

@jpcbertoldo jpcbertoldo commented Oct 20, 2024

📝 Description

without numba, aupimo became annoyingly slow for full resolution test, so this idea show speed it up by removing unnecessayr computation

this parameter is what mostly makes it so inefficient (num_thresholds = 300_000)

num_thresholds: int = 300_000,

it has to be so big to make sure that there will be enough points in the AUC integration within the integration range

the current implementation thresholds the anomaly score maps from their min to max value, which is the inefficient because we only need it in a much smaller range

strategy to improve it:
- use binary search to find the thresholds corresponding to the fpr integration bounds
- compute the binary classification curves within those bounds
- decrease the number of thresholds from 300_000 to 300

  • review tutorial notebooks

✨ Changes

Select what type of change your PR is:

  • 🐞 Bug fix (non-breaking change which fixes an issue)
  • 🔨 Refactor (non-breaking change which refactors the code base)
  • [] 🚀 New feature (non-breaking change which adds functionality)
  • 💥 Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • 📚 Documentation update
  • 🔒 Security update

✅ Checklist

Before you submit your pull request, please make sure you have completed the following steps:

  • 📋 I have summarized my changes in the CHANGELOG and followed the guidelines for my type of change (skip for minor changes, documentation updates, and test enhancements).
  • 📚 I have made the necessary updates to the documentation (if applicable).
  • 🧪 I have written tests that support my changes and prove that my fix is effective or my feature works (if applicable).

For more information about code review checklists, see the Code Review Checklist.

Signed-off-by: jpcbertoldo <24547377+jpcbertoldo@users.noreply.github.com>
Signed-off-by: jpcbertoldo <24547377+jpcbertoldo@users.noreply.github.com>
@jpcbertoldo jpcbertoldo marked this pull request as ready for review October 21, 2024 22:33
Signed-off-by: jpcbertoldo <24547377+jpcbertoldo@users.noreply.github.com>
Signed-off-by: jpcbertoldo <24547377+jpcbertoldo@users.noreply.github.com>
Copy link

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

@@ -68,15 +69,18 @@ def pimo_curves(
_validate.has_at_least_one_normal_image(masks)

image_classes = images_classes_from_masks(masks)
anomaly_maps_normal_images = anomaly_maps[image_classes == 0]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
anomaly_maps_normal_images = anomaly_maps[image_classes == 0]
normal_anomaly_maps = anomaly_maps[image_classes == 0]

@@ -276,6 +275,73 @@ def aupimo_scores(
# =========================================== AUX ===========================================


def _binary_search_threshold_at_fpr_target(
anomaly_maps_normals: torch.Tensor,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
anomaly_maps_normals: torch.Tensor,
normal_anomaly_maps: torch.Tensor,

Copy link
Contributor

@djdameln djdameln left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, this is a nice optimization. Just out of curiosity, do you have some numbers on the amount of speed up is achieved by this change?

fpr_target: float | torch.Tensor,
maximum_iterations: int = 300,
) -> float:
"""Binary search of threshold that achieves the given shared FPR level.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be good to add a more detailed description to this docstring, explaining why we need to apply the binary search and how it is performed. This would be useful for future reference.

@samet-akcay
Copy link
Contributor

Looks like one of the pimo notebook tests are failing

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants