forked from junyanz/pytorch-CycleGAN-and-pix2pix
-
Notifications
You must be signed in to change notification settings - Fork 0
/
loss_over_time.py
57 lines (54 loc) · 2.27 KB
/
loss_over_time.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
from eval_batch import EXP_MODEL_PAIRS
from tqdm import tqdm
import numpy as np
import re
import matplotlib.pyplot as plt
# Plot per-epoch average training losses for each (experiment, model) pair
# listed in eval_batch.EXP_MODEL_PAIRS, reading the loss_log.txt that training
# writes under checkpoints/<experiment>/.

# Loss names collected from each log line, in the order they are returned.
LOSS_KEYS = ("G_GAN", "G_L1", "D_real", "D_fake")

# "(epoch: 3, iters: ..." -> "3"
_EPOCH_RE = re.compile(r"epoch: (\d+),")
# Any "name: value" pair on a log line, e.g. "G_GAN: 1.234".
_FIELD_RE = re.compile(r"(\w+): (\S+)")


def epoch_mean_losses(lines, keys=LOSS_KEYS):
    """Average the per-iteration losses of a loss log, grouped by epoch.

    A line contributes only if it carries an ``epoch: N,`` field and a
    parseable float for every name in *keys*; other lines (separator
    headers, truncated or malformed entries) are skipped.

    Args:
        lines: Iterable of log lines (an open text file works).
        keys: Loss names to collect.

    Returns:
        ``(epochs, means)``: ``epochs`` is the sorted list of epoch numbers
        that had at least one usable line, and ``means`` maps each key to a
        list of per-epoch means aligned with ``epochs``. Both are empty
        (``means`` with empty lists) when no line is usable.
    """
    totals = {}  # epoch -> running sum per key, in `keys` order
    counts = {}  # epoch -> number of lines that contributed
    for line in lines:
        epoch_match = _EPOCH_RE.search(line)
        if epoch_match is None:
            continue
        fields = dict(_FIELD_RE.findall(line))
        try:
            values = [float(fields[key]) for key in keys]
        except (KeyError, ValueError):
            # Missing or non-numeric loss on this line: skip it rather than
            # letting it shift or corrupt the epoch averages.
            continue
        epoch = int(epoch_match.group(1))
        row = totals.setdefault(epoch, [0.0] * len(keys))
        for i, value in enumerate(values):
            row[i] += value
        counts[epoch] = counts.get(epoch, 0) + 1

    epochs = sorted(totals)
    means = {
        key: [totals[epoch][i] / counts[epoch] for epoch in epochs]
        for i, key in enumerate(keys)
    }
    return epochs, means


def plot_losses(epochs, series, title, out_path):
    """Plot loss curves against epoch number and save the figure.

    Args:
        epochs: x-axis values (epoch numbers).
        series: Iterable of ``(values, color, label)`` tuples, one per curve;
            each ``values`` must be aligned with ``epochs``.
        title: Figure title.
        out_path: File path passed to ``savefig``.
    """
    fig, ax = plt.subplots()
    for values, color, label in series:
        ax.plot(epochs, values, color=color, label=label)
    ax.legend()
    ax.set_title(title)
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Value")
    fig.savefig(out_path)
    # Close explicitly: this runs once per experiment and open figures
    # otherwise accumulate for the lifetime of the process.
    plt.close(fig)


if __name__ == "__main__":
    # Writes G_<experiment>.png and D_<experiment>.png to the working directory.
    for experiment, model in tqdm(EXP_MODEL_PAIRS):
        path = f'checkpoints/{experiment}/loss_log.txt'
        with open(path) as fp:
            epochs, means = epoch_mean_losses(fp)
        if not epochs:
            tqdm.write(f"No loss entries found in {path}; skipping.")
            continue
        plot_losses(
            epochs,
            [(means["G_GAN"], "black", 'GAN'), (means["G_L1"], "blue", 'L1')],
            f"Generator Losses - {model}",
            f'G_{experiment}.png',
        )
        plot_losses(
            epochs,
            [(means["D_real"], "black", 'D_REAL'), (means["D_fake"], "blue", 'D_FAKE')],
            f"Discriminator Losses - {model}",
            f'D_{experiment}.png',
        )