A mistake on gradient penalty #40
Comments
You're right. The gradient penalty is now stable and the generator loss isn't changing as fast. I prefer to just use `tensor.norm`, though.
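For reference, a minimal sketch of that `tensor.norm` variant (assuming `gradients` has shape [batch_size, 1, 32, 32] and `self.lambda_term` as in the linked code); flattening first makes it mathematically equivalent to the `.sum(dim=1).sqrt()` form proposed below:

```python
# Flatten each sample so the L2 norm runs over all of its elements,
# not element-wise along the channel dimension.
flat = gradients.view(gradients.shape[0], -1)  # [batch_size, 1*32*32]
grad_penalty = ((flat.norm(2, dim=1) - 1) ** 2).mean() * self.lambda_term
```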
Tried your code, and yeah, it has artifacts too. I wonder what's wrong.
Took a look at the official WGAN-GP code and it doesn't use batch norm. So I removed it and the artifacts are gone. It doesn't seem to be better than before the adjustment, though. EDIT: Actually, batch norm is only omitted for MNIST, but unless I disable it for the other datasets as well, I'm still getting artifacts.
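For context, the WGAN-GP paper avoids batch normalization in the critic because the gradient penalty is computed per sample, while batch norm couples samples within a batch; the authors suggest layer normalization (or no normalization) instead. A minimal sketch of a critic block along those lines; the layer sizes and the use of `InstanceNorm2d` here are illustrative, not taken from this repo:

```python
import torch.nn as nn

def critic_block(in_ch, out_ch):
    # InstanceNorm2d normalizes each sample independently, so it does not
    # mix statistics across the batch the way BatchNorm2d does.
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1),
        nn.InstanceNorm2d(out_ch, affine=True),
        nn.LeakyReLU(0.2, inplace=True),
    )
```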
https://github.com/Zeleni9/pytorch-wgan/blob/master/models/wgan_gradient_penalty.py#L324
grad_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * self.lambda_term
The shape of `gradients` is [batch_size, 1, 32, 32], so the above `norm(2, dim=1)` is not behaving as you may want: it reduces only over the singleton channel dimension and returns a tensor of shape [batch_size, 32, 32]. This is actually enforcing `gradients` to be close to 1 element-wise, so the per-sample gradient norm, and hence the Lipschitz constant being enforced, is much bigger than 1. A potential fix is
grad_penalty = (((gradients.view(gradients.shape[0], -1) ** 2).sum(dim=1).sqrt() - 1) ** 2).mean() * self.lambda_term
Correct me if I am wrong. Thx.
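Putting it together, a self-contained sketch of the gradient penalty with the per-sample norm; the interpolation follows the usual WGAN-GP recipe, and the names `critic`, `real`, `fake`, and `lambda_term` are placeholders rather than this repo's exact API:

```python
import torch
from torch import autograd

def gradient_penalty(critic, real, fake, lambda_term=10.0):
    batch_size = real.size(0)
    # One random interpolation coefficient per sample.
    eps = torch.rand(batch_size, 1, 1, 1, device=real.device)
    interpolated = (eps * real + (1 - eps) * fake).requires_grad_(True)

    critic_out = critic(interpolated)
    # Gradients of the critic output w.r.t. the interpolated inputs.
    gradients = autograd.grad(
        outputs=critic_out,
        inputs=interpolated,
        grad_outputs=torch.ones_like(critic_out),
        create_graph=True,
        retain_graph=True,
    )[0]

    # Flatten per sample so the L2 norm covers all pixels of each example,
    # instead of acting element-wise as norm(2, dim=1) does on a 4-D tensor.
    gradients = gradients.view(batch_size, -1)
    return ((gradients.norm(2, dim=1) - 1) ** 2).mean() * lambda_term
```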