
Training #3

Open
Bozorgtabar opened this issue Jan 2, 2018 · 4 comments

@Bozorgtabar

I think you must train the discriminator first and then train the generator. For each step of generator training, you must run the discriminator training five times.
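
A minimal sketch of that schedule, assuming the WGAN-GP-style critic ratio of n_critic = 5 used in the StarGAN paper; `data_loader`, `d_train_step`, and `g_train_step` are hypothetical placeholders, not names from this repo:

```python
# Minimal sketch of the alternating schedule (WGAN-GP style, n_critic = 5).
# `data_loader`, `d_train_step`, and `g_train_step` are hypothetical
# placeholders, not this repo's actual names.
n_critic = 5

for step, (real_x, real_label) in enumerate(data_loader):
    # Update the discriminator (critic) at every iteration.
    d_train_step(real_x, real_label)

    # Update the generator only once per n_critic discriminator updates.
    if (step + 1) % n_critic == 0:
        g_train_step(real_x, real_label)
```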

@goldkim92 (Owner)

Thanks for the comment. I also found that I should change my code to add data augmentation and learning rate decay.
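
For the decay part, a minimal sketch of the linear schedule from the StarGAN paper: keep the learning rate constant for the first half of training, then decay it linearly to zero over the second half. The defaults below are assumptions, not this repo's settings:

```python
# Linear learning-rate decay sketch: constant for the first half of training,
# then decayed linearly to 0 over the second half. The defaults
# (base_lr=1e-4, num_iters=200000) are assumptions, not this repo's values.
def decayed_lr(step, base_lr=1e-4, num_iters=200000):
    decay_start = num_iters // 2
    if step < decay_start:
        return base_lr
    return base_lr * (num_iters - step) / float(num_iters - decay_start)
```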

@Bozorgtabar (Author)

Great!
When do you plan to change your code?
Please keep me updated.

@goldkim92 (Owner)

I just changed the code.
Before revising it, the result was not good enough to show in the README. Now, after changing the few things you told me, I'm looking forward to a nicer result. Thanks.

@Bozorgtabar (Author)

One more thing: I noticed that the discriminator and generator losses are not correct.
Please take a look at the PyTorch implementation:

```python
        # ================== Train D ==================

        # Real images (CelebA)
        out_real, out_cls = self.D(real_x1)
        out_cls1 = out_cls[:, :self.c_dim]      # celebA part
        d_loss_real = - torch.mean(out_real)
        d_loss_cls = F.binary_cross_entropy_with_logits(out_cls1, real_label1, size_average=False) / real_x1.size(0)

        # Real images (RaFD)
        out_real, out_cls = self.D(real_x2)
        out_cls2 = out_cls[:, self.c_dim:]      # rafd part
        d_loss_real += - torch.mean(out_real)
        d_loss_cls += F.cross_entropy(out_cls2, real_label2)

        # Compute classification accuracy of the discriminator
        if (i+1) % self.log_step == 0:
            accuracies = self.compute_accuracy(out_cls1, real_label1, 'CelebA')
            log = ["{:.2f}".format(acc) for acc in accuracies.data.cpu().numpy()]
            print('Classification Acc (Black/Blond/Brown/Gender/Aged): ')
            print(log)
            accuracies = self.compute_accuracy(out_cls2, real_label2, 'RaFD')
            log = ["{:.2f}".format(acc) for acc in accuracies.data.cpu().numpy()]
            print('Classification Acc (8 emotional expressions): ')
            print(log)

        # Fake images (CelebA)
        fake_c = torch.cat([fake_c1, zero1, mask1], dim=1)
        fake_x1 = self.G(real_x1, fake_c)
        fake_x1 = Variable(fake_x1.data)
        out_fake, _ = self.D(fake_x1)
        d_loss_fake = torch.mean(out_fake)

        # Fake images (RaFD)
        fake_c = torch.cat([zero2, fake_c2, mask2], dim=1)
        fake_x2 = self.G(real_x2, fake_c)
        out_fake, _ = self.D(fake_x2)
        d_loss_fake += torch.mean(out_fake)

        # Backward + Optimize
        d_loss = d_loss_real + d_loss_fake + self.lambda_cls * d_loss_cls
        self.reset_grad()
        d_loss.backward()
        self.d_optimizer.step()

        # Compute gradient penalty
        if (i+1) % 2 == 0:
            real_x = real_x1
            fake_x = fake_x1
        else:
            real_x = real_x2
            fake_x = fake_x2

        alpha = torch.rand(real_x.size(0), 1, 1, 1).cuda().expand_as(real_x)
        interpolated = Variable(alpha * real_x.data + (1 - alpha) * fake_x.data, requires_grad=True)
        out, out_cls = self.D(interpolated)

        if (i+1) % 2 == 0:
            out_cls = out_cls[:, :self.c_dim]  # CelebA
        else:
            out_cls = out_cls[:, self.c_dim:]  # RaFD

        grad = torch.autograd.grad(outputs=out,
                                   inputs=interpolated,
                                   grad_outputs=torch.ones(out.size()).cuda(),
                                   retain_graph=True,
                                   create_graph=True,
                                   only_inputs=True)[0]

        grad = grad.view(grad.size(0), -1)
        grad_l2norm = torch.sqrt(torch.sum(grad ** 2, dim=1))
        d_loss_gp = torch.mean((grad_l2norm - 1)**2)

        # Backward + Optimize
        d_loss = self.lambda_gp * d_loss_gp
        self.reset_grad()
        d_loss.backward()
        self.d_optimizer.step()

        # Logging
        loss = {}
        loss['D/loss_real'] = d_loss_real.data[0]
        loss['D/loss_fake'] = d_loss_fake.data[0]
        loss['D/loss_cls'] = d_loss_cls.data[0]
        loss['D/loss_gp'] = d_loss_gp.data[0]
```

For example, `Real_A` in your script is supposed to be `(real_x, fake_c)`.
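
For reference, here is a condensed sketch of the matching generator update (CelebA branch only), written in the same style as the D step above; `real_c1`, `fake_label1`, `self.lambda_rec`, and `self.g_optimizer` do not appear in the snippet and are assumptions following the usual StarGAN conventions:

```python
        # ================== Train G ==================
        # Condensed sketch of the matching generator update (CelebA branch
        # only). `real_c1`, `fake_label1`, `self.lambda_rec`, and
        # `self.g_optimizer` are assumptions following the usual StarGAN
        # conventions.

        # Adversarial loss: the generator tries to raise D's realness score.
        fake_c = torch.cat([fake_c1, zero1, mask1], dim=1)
        fake_x1 = self.G(real_x1, fake_c)
        out_fake, out_cls = self.D(fake_x1)
        g_loss_fake = - torch.mean(out_fake)

        # Domain classification loss on the fake image (CelebA labels).
        out_cls1 = out_cls[:, :self.c_dim]
        g_loss_cls = F.binary_cross_entropy_with_logits(
            out_cls1, fake_label1, size_average=False) / fake_x1.size(0)

        # Cycle-consistency (reconstruction) loss back to the source image.
        real_c = torch.cat([real_c1, zero1, mask1], dim=1)
        rec_x1 = self.G(fake_x1, real_c)
        g_loss_rec = torch.mean(torch.abs(real_x1 - rec_x1))

        # Backward + Optimize
        g_loss = g_loss_fake + self.lambda_rec * g_loss_rec + self.lambda_cls * g_loss_cls
        self.reset_grad()
        g_loss.backward()
        self.g_optimizer.step()
```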
