# model.py
import math

import torch
import torch.nn as nn
import pytorch_lightning as pl

# U-Net building blocks defined alongside this file.
from modules import DoubleConv, Down, Up, OutConv, SAWrapper


class DiffusionModel(pl.LightningModule):
    def __init__(self, in_size, t_range, img_depth):
        super().__init__()
        # Endpoints of the linear beta schedule from Ho et al., 2020.
        self.beta_small = 1e-4
        self.beta_large = 0.02
        self.t_range = t_range
        self.in_size = in_size
        bilinear = True

        # U-Net encoder/decoder. With bilinear upsampling, `factor` halves
        # the channel count of the bottleneck and up-sampling blocks.
        self.inc = DoubleConv(img_depth, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        factor = 2 if bilinear else 1
        self.down3 = Down(256, 512 // factor)
        self.up1 = Up(512, 256 // factor, bilinear)
        self.up2 = Up(256, 128 // factor, bilinear)
        self.up3 = Up(128, 64, bilinear)
        self.outc = OutConv(64, img_depth)

        # Self-attention at three resolutions: (channels, feature-map size).
        self.sa1 = SAWrapper(256, 8)
        self.sa2 = SAWrapper(256, 4)
        self.sa3 = SAWrapper(128, 8)
    def pos_encoding(self, t, channels, embed_size):
        """Sinusoidal timestep embedding, broadcast to an embed_size x embed_size map."""
        inv_freq = 1.0 / (
            10000
            ** (torch.arange(0, channels, 2, device=self.device).float() / channels)
        )
        pos_enc_a = torch.sin(t.repeat(1, channels // 2) * inv_freq)
        pos_enc_b = torch.cos(t.repeat(1, channels // 2) * inv_freq)
        pos_enc = torch.cat([pos_enc_a, pos_enc_b], dim=-1)
        return pos_enc.view(-1, channels, 1, 1).repeat(1, 1, embed_size, embed_size)
    def forward(self, x, t):
        """
        Model is a U-Net with added positional encodings and self-attention layers.
        The hard-coded feature-map sizes (16, 8, 4, ...) assume 32x32 inputs.
        """
        x1 = self.inc(x)
        x2 = self.down1(x1) + self.pos_encoding(t, 128, 16)
        x3 = self.down2(x2) + self.pos_encoding(t, 256, 8)
        x3 = self.sa1(x3)
        x4 = self.down3(x3) + self.pos_encoding(t, 256, 4)
        x4 = self.sa2(x4)
        x = self.up1(x4, x3) + self.pos_encoding(t, 128, 8)
        x = self.sa3(x)
        x = self.up2(x, x2) + self.pos_encoding(t, 64, 16)
        x = self.up3(x, x1) + self.pos_encoding(t, 64, 32)
        output = self.outc(x)
        return output
    def beta(self, t):
        # Linear variance schedule, interpolating beta_small -> beta_large.
        return self.beta_small + (t / self.t_range) * (
            self.beta_large - self.beta_small
        )

    def alpha(self, t):
        return 1 - self.beta(t)

    def alpha_bar(self, t):
        # Cumulative product of alphas up to (but not including) step t;
        # the empty product at t = 0 gives alpha_bar(0) = 1.
        return math.prod([self.alpha(j) for j in range(t)])
    def get_loss(self, batch, batch_idx):
        """
        Corresponds to Algorithm 1 from (Ho et al., 2020).
        """
        # Sample a timestep per image, then noise each image in closed form:
        # x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps.
        ts = torch.randint(0, self.t_range, [batch.shape[0]], device=self.device)
        noise_imgs = []
        epsilons = torch.randn(batch.shape, device=self.device)
        for i in range(len(ts)):
            a_hat = self.alpha_bar(ts[i])
            noise_imgs.append(
                (math.sqrt(a_hat) * batch[i]) + (math.sqrt(1 - a_hat) * epsilons[i])
            )
        noise_imgs = torch.stack(noise_imgs, dim=0)
        # Predict the noise and regress it against the noise actually added.
        e_hat = self.forward(noise_imgs, ts.unsqueeze(-1).type(torch.float))
        loss = nn.functional.mse_loss(
            e_hat.reshape(-1, self.in_size), epsilons.reshape(-1, self.in_size)
        )
        return loss
    def denoise_sample(self, x, t):
        """
        Corresponds to the inner loop of Algorithm 2 from (Ho et al., 2020).
        """
        with torch.no_grad():
            if t > 1:
                # Fresh noise for every step except the last.
                z = torch.randn(x.shape, device=x.device)
            else:
                z = 0
            e_hat = self.forward(x, t.view(1, 1).repeat(x.shape[0], 1))
            pre_scale = 1 / math.sqrt(self.alpha(t))
            e_scale = (1 - self.alpha(t)) / math.sqrt(1 - self.alpha_bar(t))
            post_sigma = math.sqrt(self.beta(t)) * z
            x = pre_scale * (x - e_scale * e_hat) + post_sigma
            return x
    def training_step(self, batch, batch_idx):
        loss = self.get_loss(batch, batch_idx)
        self.log("train/loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self.get_loss(batch, batch_idx)
        self.log("val/loss", loss)
        return

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=2e-4)
        return optimizer
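
if __name__ == "__main__":
    # Quick smoke test (a sketch, not part of the original file). The
    # hard-coded positional-encoding sizes in forward() assume 32x32 inputs,
    # e.g. padded MNIST (img_depth=1) or CIFAR-10 (img_depth=3).
    model = DiffusionModel(in_size=32 * 32, t_range=1000, img_depth=1)
    x = torch.randn(4, 1, 32, 32)
    t = torch.randint(0, 1000, (4, 1)).float()
    e_hat = model(x, t)
    print(e_hat.shape)  # expected: torch.Size([4, 1, 32, 32])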