LwM - no gradient in attention distillation loss #37

Open
fszatkowski opened this issue May 16, 2023 · 5 comments

@fszatkowski

Hi, when experimenting with LwM in FACIL I noticed that the method behaves the same regardless of the choice of the gamma parameter that controls the attention distillation loss. Upon closer investigation, I noticed that during training the attention maps returned by GradCAM have no grad, as you can verify with a debugger at this line:

attmap = gradcam(images) # this use eval() pass

When we later use the attention maps to compute the attention distillation loss, this loss has no gradient and its contribution to the weight update is ignored. Therefore, LwM in FACIL effectively does LwF with extra unused computation.

I think the issue is in the GradCAM class: in line 226 the activations are detached, and later line 255 disables gradients when computing the attention maps. This class should have an option to preserve gradients when computing attention maps, and that option should be enabled for the forward pass of the current net. Then the attention maps for the current net will have requires_grad=True and, consequently, the attention loss will contribute to the weight updates.
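For reference, the kind of check I did in the debugger looks roughly like this (a minimal sketch of the check, not code from the repository):

attmap = gradcam(images)  # this use eval() pass
print(attmap.requires_grad)  # False -> the attention map is detached from the graph
print(attmap.grad_fn)        # None  -> nothing to backpropagate through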

@mmasana
Owner

mmasana commented May 26, 2023

Hi @fszatkowski,

I remember discussing this approach and the gradients/loss before, so maybe we missed something. A first change that has not yet been pushed to main is the one found in this commit. We also had some discussion about it in this issue.

Could you check if any of those help/tackle the issue? My first guess is that maybe one does not want to propagate the gradients through the gradcam, but instead generate a loss with the attention maps that is backpropagated through the resnet model (the same weights that the CE loss modifies). However, it's been a while, so if these links do not help, let me know and we'll see if we can dive into it again and fix it.
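Roughly what I have in mind (just a sketch, not the exact FACIL code; att_map_new, att_map_old, ce_loss and gamma are placeholder names here):

import torch.nn.functional as F

def attention_distillation_loss(att_map_new, att_map_old):
    # L2-normalize the vectorized attention maps and take an L1-style distance
    new = F.normalize(att_map_new.flatten(1), p=2, dim=1)
    old = F.normalize(att_map_old.flatten(1), p=2, dim=1)
    return (new - old).abs().mean()

# att_map_old comes from the frozen previous-task model, so it can stay detached;
# att_map_new should come from a gradient-enabled pass of the current model, so that
# backpropagating this loss updates the same weights the CE loss modifies.
loss = ce_loss + gamma * attention_distillation_loss(att_map_new, att_map_old.detach())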

When checking the experiments, LwM is better than LwF by a significant margin, which seems to indicate that there is indeed some difference between the two methods. But it could also be due to some other reason, or maybe we introduced the error when cleaning the code for the public release.

Let me know if that helped. I'm quite interested in solving this if it is indeed an issue!

@fszatkowski
Author

The commit you linked changes torch.norm to torch.nn.functional.normalize, but I don't think it addresses the main issue of there being no gradient in the attention loss. Since the hooks in GradCAM detach both the gradients and the activations, the attention maps computed later have no gradients, so there is no way to backpropagate anything through a loss computed from those attention maps.

Here you use detach on both gradients and activations:
https://github.com/mmasana/FACIL/blob/e9d816c0c649db91bde1568300a8ba3045651ffd/src/approach/lwm.py#LL223C1-L229C53
So when GradCAM is called, it returns attention maps that carry no gradient:

FACIL/src/approach/lwm.py

Lines 255 to 261 in e9d816c

with torch.no_grad():
weights = F.adaptive_avg_pool2d(self.gradients, 1)
att_map = (weights * self.activations).sum(dim=1, keepdim=True)
att_map = F.relu(att_map)
del self.activations
del self.gradients
return (att_map, model_output) if return_outputs else att_map

I think the attention loss function itself is okay, but since it is computed on two tensors without gradients, it simply adds a constant scalar to the loss, and when you call loss.backward() in the training loop the part that comes from the attention loss does not backpropagate anything.
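A tiny standalone example of that behaviour (hypothetical, not FACIL code):

import torch

layer = torch.nn.Linear(4, 4)
out = layer(torch.randn(2, 4))

detached = out.detach()                         # same situation as the detached attention maps
loss = out.sum() + 10.0 * detached.abs().sum()  # the second term is just a constant scalar
loss.backward()
# layer.weight.grad is identical to what out.sum().backward() alone would produce,
# because the detached term has no grad_fn and acts as a constant.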

I work on a slightly modified fork of FACIL, so my results might differ from the version in this repository. Could you please run the LwM code twice with different coefficients for the attention map loss (for example, 0 and some other value)? I think you will get the same final results regardless of the attention loss coefficient.

@mmasana
Owner

mmasana commented May 31, 2023

I see, the way .detach() is called could indeed block the gradients from updating. I'll first try to reproduce what you propose with the --gamma parameter to check it out.

@mmasana
Owner

mmasana commented Jun 5, 2023

You are correct, it seems that loss indeed has no effect. No gradients are updated, so changing the parameter does nothing and reduces the method to LwF. I'll need to check some of the older dev branches to see when we introduced the bug (or forgot to update the method with the fix), since the older spreadsheet files from the original experiments do show a difference when changing the gamma.

Thanks for the help!
If you happen to already have a hotfix for the issue, please do propose it to speed things up.

@fszatkowski
Author

fszatkowski commented Jun 5, 2023

I simply tried removing the activations.detach() call from the hooks and making the torch.no_grad() in the GradCAM pass conditional:

import contextlib  # needed for the no-op context manager used below

class GradCAM:
    ...
    def __enter__(self):
        # register hooks to collect activations and gradients
        def forward_hook(module, input, output):
            # keep activations attached to the graph when gradients must be preserved
            if self.retain_graph:
                self.activations = output
            else:
                self.activations = output.detach()

        def backward_hook(module, grad_input, grad_output):
            self.gradients = grad_output[0].detach()
    ...

    def __call__(self, input, class_indices=None, return_outputs=False, adapt_bn=False):
        ...
        # only disable gradients when the attention maps do not need to stay in the graph
        with torch.no_grad() if not self.retain_graph else contextlib.suppress():
            weights = F.adaptive_avg_pool2d(self.gradients, 1)
            att_map = (weights * self.activations).sum(dim=1, keepdim=True)
            att_map = F.relu(att_map)
            del self.activations
            del self.gradients
            return (att_map, model_output) if return_outputs else att_map

Then in the training loop:

                ...
                attmap_old, outputs_old = gradcam_old(images, return_outputs=True)
                with GradCAM(self.model, self.gradcam_layer, retain_graph=True) as gradcam:
                    attmap = gradcam(images)  # this use eval() pass
                ...

But the results I got with it were far from the scores in your paper, so I think this is still not working as intended.
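To at least check whether the attention term reaches the weights at all, I use something like this inside the training loop (hypothetical snippet; attention_loss and gamma name whatever variables hold that term and its coefficient):

self.model.zero_grad()
(gamma * attention_loss).backward(retain_graph=True)
has_grad = any(p.grad is not None and p.grad.abs().sum() > 0 for p in self.model.parameters())
print(has_grad)  # should become True once the attention maps keep their graph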
