Implement DeepFool in PyTorch #1090
base: master
Conversation
@iArunava what do you think?
perturbations = torch.clamp(x + perturbations, clip_min, clip_max) - x
perturbations /= (1 + overshoot)

# perturbations *= (1 + overshoot)
Why is this line commented out?
Sorry, that shouldn't be there, actually. I'll remove it.
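For context, a minimal sketch of the clamp-then-rescale step under discussion; the overshoot value, clip range, and tensor shapes below are illustrative assumptions, not the PR's actual settings:

import torch

# Illustrative values (assumptions, not the PR's exact code).
overshoot = 0.02
clip_min, clip_max = 0.0, 1.0
x = torch.rand(4, 3, 32, 32)                         # batch of inputs
perturbations = (1 + overshoot) * 0.01 * torch.randn_like(x)

# Clip the perturbed input into the valid range, then recover the
# perturbation relative to x ...
perturbations = torch.clamp(x + perturbations, clip_min, clip_max) - x
# ... and undo the overshoot factor, so the stored perturbation is the
# raw minimal one; it gets re-scaled by (1 + overshoot) on the next
# iteration instead of compounding the factor.
perturbations /= (1 + overshoot)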
    eps / norm
)
eta *= factor
eta = torch.renorm(eta, p=norm, dim=0, maxnorm=eps)
Does this avoid div by zero? This is important for numerical stability in our case.
In what case do you imagine div by zero could be an issue? I've tested it, and it works fine when the norm of eta is zero and when eps is 0.
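A quick sketch of why torch.renorm sidesteps the division: it only rescales slices along dim whose p-norm exceeds maxnorm, so a zero-norm eta already satisfies the bound and is returned untouched (the values below are illustrative):

import torch

eps = 0.3
eta = torch.zeros(2, 10)       # row 0 has zero norm
eta[1] = torch.randn(10)

# Slices with norm <= maxnorm are left as-is, so nothing is ever
# divided by a zero norm.
clipped = torch.renorm(eta, p=2, dim=0, maxnorm=eps)

assert clipped[0].norm() == 0            # untouched, no NaNs
assert clipped[1].norm() <= eps          # projected into the eps-ball

# eps = 0 is also fine: every nonzero slice is scaled down to zero.
assert torch.renorm(eta, p=2, dim=0, maxnorm=0.0).abs().max() == 0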
from cleverhans.future.torch.utils import clip_eta


def deepfool(model_fn, x, clip_min=-np.inf, clip_max=np.inf,
To be clear, are the advantages of this implementation that it can handle batches of inputs and early stop if all batch inputs have been misclassified?
It does early stopping separately for each element of the batch, which mirrors the original DeepFool implementation. For example, the first element of the batch could be stopped after 2 iterations and the second after 10. I believe the other implementation keeps iterating on the whole batch until every element has succeeded.
Other than that I think the pure torch implementation is a plus for being a lot faster on GPU.
I know somebody else has an open PR for DeepFool in PyTorch, but I think my implementation is better in a few ways, as described above.
Let me know what you think!
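To illustrate the per-element early stopping described above, here is a minimal sketch. The update rule is a random placeholder standing in for the actual DeepFool step, and the function name, iteration cap, and bookkeeping details are all assumptions, not the PR's exact code:

import torch

def per_element_early_stop(model_fn, x, max_iter=50):
    """Sketch: stop iterating each batch element as soon as it is fooled."""
    adv_x = x.clone()
    orig_labels = model_fn(x).argmax(dim=1)
    # True for elements the model still classifies correctly.
    active = torch.ones(x.shape[0], dtype=torch.bool, device=x.device)

    for _ in range(max_iter):
        if not active.any():          # every element fooled: stop early
            break
        # Placeholder update applied only to still-active elements;
        # already-fooled elements are frozen, rather than iterating the
        # whole batch until all of them succeed.
        adv_x[active] = adv_x[active] + 0.01 * torch.randn_like(adv_x[active])
        # Retire the elements that just became misclassified.
        still_correct = model_fn(adv_x[active]).argmax(dim=1) == orig_labels[active]
        active_idx = active.nonzero(as_tuple=True)[0]
        active[active_idx] = still_correct
    return adv_x

Freezing fooled elements keeps each per-example perturbation as small as the step it was stopped at, which is the point of DeepFool's minimal-perturbation objective.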