LwM - no gradient in attention distillation loss #37
Comments
Hi @fszatkowski, I remember discussing this approach and the gradients/loss before, so maybe we missed something. A first change that has not yet been pushed into main is the one found in this commit, and we also had some discussion about it in this issue. Could you check whether either of those helps or tackles the issue? My first guess is that one does not want to propagate the gradients through the GradCAM, but instead generate a loss from the attention maps that is backpropagated through the ResNet model (the same weights the CE loss modifies). However, it's been a while, so if these links do not help, let me know and we'll see if we can dive into it again and fix it. When checking the experiments, LwM is better than LwF by a significant margin, which seems to indicate that some difference between both methods is indeed happening. But it could also be for some other reason, or maybe we introduced the error when cleaning the code for public release. Let me know if that helped. I'm quite interested in solving this if it is indeed an issue!
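To make that guess concrete, here is a minimal sketch of that flow with hypothetical helper names (`attention_map`, `attention_loss`); this is not the FACIL code. The old model's attention map is treated as a fixed target, while the current model's map keeps its graph, so the attention term updates the same backbone weights as the cross-entropy loss:

```python
# Sketch only: attention_map and attention_loss are hypothetical stand-ins for
# whatever produces the Grad-CAM map and the attention distillation term.
def lwm_style_step(model, old_model, images, targets, criterion,
                   attention_map, attention_loss, gamma, optimizer):
    attmap_old = attention_map(old_model, images).detach()  # fixed target, no gradient into the old net
    attmap_new = attention_map(model, images)                # must keep requires_grad=True
    outputs = model(images)
    loss = criterion(outputs, targets) + gamma * attention_loss(attmap_new, attmap_old)
    optimizer.zero_grad()
    loss.backward()   # the attention term now reaches the current backbone as well
    optimizer.step()
    return loss.item()
```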
The commit you linked changes … Here you use detach on both gradients and activations (lines 255 to 261 in e9d816c).
I think the attention loss function is okay, but since it is computed on two variables without gradients, it simply adds a scalar with zero derivative to the loss, and when you call loss.backward in the training loop the part that comes from the attention loss does not backpropagate. I work on a slightly modified fork of FACIL, so I might be getting different results than the version in this repository. Could you please run the LwM code twice with different coefficients for the attention map loss (for example 0 and some other value)? I think you will get the same final results regardless of the attention loss coefficient.
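A small, generic PyTorch example of the effect described above (not FACIL code): a term built from detached tensors adds a value to the loss but contributes nothing to the gradients, so sweeping its coefficient cannot change training.

```python
import torch

w = torch.nn.Parameter(torch.ones(3))
x = torch.tensor([1.0, 2.0, 3.0])

ce_like = ((w * x) ** 2).sum()              # has a graph back to w
att_like = ((w * x).detach() ** 2).sum()    # built from a detached tensor: no graph to w

for gamma in (0.0, 10.0):
    w.grad = None
    loss = ce_like + gamma * att_like
    loss.backward(retain_graph=True)
    print(gamma, w.grad)                     # identical gradient for both values of gamma
```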
I see, the way …
You are correct, it seems like that loss is indeed not having an effect. There are no gradients updated, and therefore changing the parameter has no effect and brings the method towards LwF. I'll need to check some of the older dev branches to see when we introduced the bug (or forgot to update the method with the fix), since the older spreadsheet files from the original experiments do show a difference when changing the gamma. Thanks for the help!
I simply tried removing …:

```python
import contextlib

import torch
import torch.nn.functional as F


class GradCAM:
    ...

    def __enter__(self):
        # register hooks to collect activations and gradients
        def forward_hook(module, input, output):
            if self.retain_graph:
                self.activations = output
            else:
                self.activations = output.detach()

        def backward_hook(module, grad_input, grad_output):
            self.gradients = grad_output[0].detach()
        ...

    def __call__(self, input, class_indices=None, return_outputs=False, adapt_bn=False):
        ...
        with torch.no_grad() if not self.retain_graph else contextlib.suppress():
            weights = F.adaptive_avg_pool2d(self.gradients, 1)
            att_map = (weights * self.activations).sum(dim=1, keepdim=True)
            att_map = F.relu(att_map)
        del self.activations
        del self.gradients
        return (att_map, model_output) if return_outputs else att_map
```

Then in the training loop:

```python
...
attmap_old, outputs_old = gradcam_old(images, return_outputs=True)
with GradCAM(self.model, self.gradcam_layer, retain_graph=True) as gradcam:
    attmap = gradcam(images)  # this uses an eval() pass
...
```

But the results I got with it were far from the scores in your paper, so I think this is still not working as intended.
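For what it's worth, one quick way to check whether the attention term now reaches the backbone at all (a debugging sketch, assuming `attmap` and `attmap_old` come from a step like the one above; not FACIL code):

```python
# the current net's map must not be detached, otherwise the loss term is silently ignored
assert attmap.requires_grad, "attention map of the current net has no graph"

attention_only = (attmap - attmap_old.detach()).abs().mean()
self.model.zero_grad()
attention_only.backward(retain_graph=True)
grad_sum = sum(p.grad.abs().sum().item()
               for p in self.model.parameters() if p.grad is not None)
print("gradient reaching the backbone from the attention term:", grad_sum)  # should be > 0
```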
Hi, when experimenting with LwM in FACIL I noticed that the method behaves the same regardless of the choice of the `gamma` parameter that controls the attention distillation loss. Upon closer investigation, I noticed that during training the attention maps returned by GradCAM have no grad, as you can check yourself with the debugger in this line: FACIL/src/approach/lwm.py, line 126 in e9d816c.

When we later use the attention maps to compute the attention distillation loss, this loss has no gradient and its contribution to the gradient update is ignored. Therefore, LwM in FACIL basically does LwF with extra unused computation.

I think the issue is in class `GradCAM`, in line 226, where the activations are detached, and later in line 255, which disables gradients when computing the attention maps. I think this class should have the option to preserve gradients when computing attention maps, and this option should be triggered for a forward pass of the current net. Then the attention maps for the current net will have `requires_grad=True`, and consequently the attention loss will contribute to the weight updates.
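For completeness, the attention distillation term itself can be written so that the only requirement is that the current net's map carries a graph. Below is a sketch roughly following the LwM paper's normalized-L1 formulation, not the FACIL implementation:

```python
import torch

def attention_distillation_loss(attmap_new: torch.Tensor,
                                attmap_old: torch.Tensor,
                                eps: float = 1e-8) -> torch.Tensor:
    # L1 distance between L2-normalized, vectorized attention maps
    v_new = attmap_new.flatten(start_dim=1)
    v_old = attmap_old.detach().flatten(start_dim=1)   # old-task map is a fixed target
    v_new = v_new / (v_new.norm(p=2, dim=1, keepdim=True) + eps)
    v_old = v_old / (v_old.norm(p=2, dim=1, keepdim=True) + eps)
    return (v_new - v_old).abs().sum(dim=1).mean()

# gamma * attention_distillation_loss(attmap_new, attmap_old) only influences training
# when attmap_new.requires_grad is True, i.e. the current net's activations were not detached.
```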