A mistake on gradient penalty #40

Open
doub7e opened this issue May 27, 2022 · 4 comments

doub7e commented May 27, 2022

https://github.com/Zeleni9/pytorch-wgan/blob/master/models/wgan_gradient_penalty.py#L324

grad_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * self.lambda_term

The shape of gradients is [batch_size, 1, 32, 32], so norm(2, dim=1) does not behave as intended: it reduces only over the channel dimension and returns a tensor of shape [batch_size, 32, 32]. The penalty then pushes each of those per-pixel norms toward 1 element-wise, so the true per-sample gradient norm (and hence the Lipschitz constant) ends up much larger than 1.

A potential fix is
grad_penalty = (((gradients.view(gradients.shape[0], -1) ** 2).sum(dim=1).sqrt() - 1) ** 2).mean() * self.lambda_term
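You can verify the shape issue with a minimal sketch (tensor shapes as in this repo; the values are just random for illustration):

import torch

gradients = torch.randn(4, 1, 32, 32)  # [batch_size, C, H, W], as returned by autograd.grad

per_pixel = gradients.norm(2, dim=1)  # shape [4, 32, 32]: a norm over the channel dim, per pixel
per_sample = gradients.view(gradients.shape[0], -1).norm(2, dim=1)  # shape [4]: one norm per sample

print(per_pixel.shape, per_sample.shape)  # torch.Size([4, 32, 32]) torch.Size([4])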

Correct me if I am wrong. Thx.


R-N commented Nov 24, 2023

You're right. With this fix, the gradient penalty is now stable and the generator loss no longer changes so quickly. I prefer to just use tensor.norm, though.

# flatten each sample, then take one L2 norm per sample
grad_norm = gradients.view(gradients.shape[0], -1).norm(2, dim=-1)
gradient_penalty = ((grad_norm - 1) ** 2).mean()
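For context, here is a minimal sketch of the full penalty computation with this fix (the function name and signature are just for illustration; D is any critic that outputs one scalar per sample):

import torch
from torch import autograd

def gradient_penalty(D, real, fake, lambda_term=10.0):
    batch_size = real.shape[0]
    # random interpolation between real and fake samples
    eps = torch.rand(batch_size, 1, 1, 1, device=real.device)
    interpolated = (eps * real + (1 - eps) * fake).requires_grad_(True)

    d_out = D(interpolated)
    gradients = autograd.grad(
        outputs=d_out,
        inputs=interpolated,
        grad_outputs=torch.ones_like(d_out),
        create_graph=True,
        retain_graph=True,
    )[0]

    # one L2 norm per sample, then penalize its deviation from 1
    grad_norm = gradients.view(batch_size, -1).norm(2, dim=1)
    return ((grad_norm - 1) ** 2).mean() * lambda_term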


R-N commented Nov 24, 2023

Ok, well, this adjustment actually adds artifacts to the generated images? Not sure why

With adjustment:
[image: img_generatori_iter_800]

Without:
[image: img_generatori_iter_800 (1)]


R-N commented Nov 24, 2023

Tried your code, and yeah, it has artifacts too. I wonder what's wrong.
Perhaps this is one of those "if it's not broken, don't fix it" moments.


R-N commented Nov 24, 2023

Took a look at the official WGAN-GP code and it doesn't use batch norm in the critic. So I removed it and the artifacts are gone. It doesn't seem to be better than before the adjustment, though.

EDIT: Actually, batch norm is only skipped for MNIST, but unless I disable it for the other datasets too, I'm still getting artifacts.

[image: img_generatori_iter_800 (2)]
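For reference, the usual drop-in replacement in the critic is layer norm (recommended in the WGAN-GP paper) or instance norm, since the per-sample gradient penalty conflicts with batch statistics. A minimal sketch of one critic block (channel sizes are just an example, not this repo's architecture):

import torch.nn as nn

critic_block = nn.Sequential(
    nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
    nn.InstanceNorm2d(128, affine=True),  # instead of nn.BatchNorm2d(128)
    nn.LeakyReLU(0.2, inplace=True),
)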
