import torch
import torch.nn.functional as F
from diffusers import DDIMScheduler, StableDiffusionPipeline
import open_clip
import torchvision.transforms as T
class SDS:
    """
    Implements the score distillation sampling (SDS) loss using Stable Diffusion.
    """
def __init__(
self,
sd_version="2.1",
device="cpu",
t_range=[0.02, 0.98],
output_dir="output",
):
"""
Load the Stable Diffusion model and set the parameters.
Args:
sd_version (str): version for stable diffusion model
device (_type_): _description_
"""
# Set the stable diffusion model key based on the version
if sd_version == "2.1":
sd_model_key = "stabilityai/stable-diffusion-2-1-base"
else:
raise NotImplementedError(
f"Stable diffusion version {sd_version} not supported"
)
# Set parameters
self.H = 512 # default height of Stable Diffusion
self.W = 512 # default width of Stable Diffusion
self.num_inference_steps = 50
self.output_dir = output_dir
self.device = device
self.precision_t = torch.float32
# Create model
sd_pipe = StableDiffusionPipeline.from_pretrained(
sd_model_key, torch_dtype=self.precision_t
).to(device)
self.preprocess = T.Resize((self.H, self.W))
self.vae = sd_pipe.vae
self.tokenizer = sd_pipe.tokenizer
self.text_encoder = sd_pipe.text_encoder
self.unet = sd_pipe.unet
self.scheduler = DDIMScheduler.from_pretrained(
sd_model_key, subfolder="scheduler", torch_dtype=self.precision_t
)
del sd_pipe
self.num_train_timesteps = self.scheduler.config.num_train_timesteps
self.min_step = int(self.num_train_timesteps * t_range[0])
self.max_step = int(self.num_train_timesteps * t_range[1])
self.alphas = self.scheduler.alphas_cumprod.to(
self.device
) # for convenient access
print(f"[INFO] loaded stable diffusion!")
@torch.no_grad()
def get_text_embeddings(self, prompt):
"""
Get the text embeddings for the prompt.
Args:
prompt (list of string): text prompt to encode.
"""
        text_input = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,  # truncate prompts longer than the 77-token limit
            return_tensors="pt",
        )
text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
return text_embeddings
def encode_imgs(self, img):
"""
Encode the image to latent representation.
Args:
img (tensor): image to encode. shape (N, 3, H, W), range [0, 1]
Returns:
latents (tensor): latent representation. shape (1, 4, 64, 64)
"""
# check the shape of the image should be 512x512
assert img.shape[-2:] == (512, 512), "Image shape should be 512x512"
img = 2 * img - 1 # [0, 1] => [-1, 1]
img = self.preprocess(img)
posterior = self.vae.encode(img).latent_dist
latents = posterior.sample() * self.vae.config.scaling_factor
return latents
def decode_latents(self, latents):
"""
Decode the latent representation into RGB image.
Args:
latents (tensor): latent representation. shape (1, 4, 64, 64), range [-1, 1]
Returns:
imgs[0] (np.array): decoded image. shape (512, 512, 3), range [0, 255]
"""
latents = 1 / self.vae.config.scaling_factor * latents
imgs = self.vae.decode(latents.type(self.precision_t)).sample
imgs = (imgs / 2 + 0.5).clamp(0, 1) # [-1, 1] => [0, 1]
imgs = imgs.detach().cpu().permute(0, 2, 3, 1).numpy() # torch to numpy
imgs = (imgs * 255).round() # [0, 1] => [0, 255]
return imgs[0]
def sds_loss(
self,
latents,
text_embeddings,
text_embeddings_uncond=None,
guidance_scale=100,
grad_scale=1
):
"""
Compute the SDS loss.
Args:
latents (tensor): input latents, shape [1, 4, 64, 64]
text_embeddings (tensor): conditional text embedding (for positive prompt), shape [1, 77, 1024]
text_embeddings_uncond (tensor, optional): unconditional text embedding (for negative prompt), shape [1, 77, 1024]. Defaults to None.
guidance_scale (int, optional): weight scaling for guidance. Defaults to 100.
grad_scale (int, optional): gradient scaling. Defaults to 1.
Returns:
loss (tensor): SDS loss
"""
        # sample an integer timestep uniformly from [min_step, max_step]
        # (the t_range cutoffs avoid very high/low noise levels)
t = torch.randint(
self.min_step,
self.max_step + 1,
(latents.shape[0],),
dtype=torch.long,
device=self.device,
)
        # predict the noise residual with the UNet; no gradients flow through the diffusion model
        with torch.no_grad():
            # add noise to the latents at timestep t
            noise = torch.randn_like(latents)
            latents_noisy = self.scheduler.add_noise(latents, noise, t)
            # predict the noise with the conditional text embedding
            noise_pred = self.unet(
                latents_noisy, t, encoder_hidden_states=text_embeddings
            ).sample
            if text_embeddings_uncond is not None and guidance_scale != 1:
                # classifier-free guidance: extrapolate from the unconditional
                # prediction towards the conditional one
                noise_pred_uncond = self.unet(
                    latents_noisy, t, encoder_hidden_states=text_embeddings_uncond
                ).sample
                noise_pred = noise_pred_uncond + guidance_scale * (
                    noise_pred - noise_pred_uncond
                )
        # SDS weighting w(t) = 1 - alpha_bar_t, reshaped to broadcast over (N, 4, 64, 64)
        w = (1 - self.alphas[t]).view(-1, 1, 1, 1)
        gradient = w * (noise_pred - noise)
        # detach the target so the loss gradient w.r.t. the latents equals
        # grad_scale * gradient; without the detach the latents cancel out and
        # the loss has zero gradient
        latents_target = (latents - grad_scale * gradient).detach()
        loss = 0.5 * F.mse_loss(latents, latents_target, reduction="sum")
        return loss
    def sds_loss_batch(
        self,
        latents,
        text_embeddings,
        text_embeddings_uncond=None,
        guidance_scale=100,
        grad_scale=1,
    ):
        """
        Average the SDS loss over a batch of latents, one sample at a time.
        """
        loss = 0.0
        for i in range(len(latents)):
            loss += self.sds_loss(
                latents[i : i + 1, ...],
                text_embeddings,
                text_embeddings_uncond,
                guidance_scale,
                grad_scale,
            )
        if len(latents) > 0:
            loss /= len(latents)
        return loss
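
# A minimal usage sketch for the SDS class. The prompt and the randomly
# initialized image below are hypothetical placeholders; in practice
# `rendered` would come from a differentiable renderer whose parameters
# receive the gradient of the SDS loss.
#
#     sds = SDS(sd_version="2.1", device="cuda")
#     text_emb = sds.get_text_embeddings(["a photograph of a hamburger"])
#     uncond_emb = sds.get_text_embeddings([""])
#     rendered = torch.rand(1, 3, 512, 512, device="cuda", requires_grad=True)
#     latents = sds.encode_imgs(rendered)
#     loss = sds.sds_loss(latents, text_emb, uncond_emb, guidance_scale=100)
#     loss.backward()  # gradient flows through the VAE encoder into `rendered`
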
class CLIP:
    """
    Class to implement the CLIP loss function.
    """
def __init__(
self,
device="cpu",
output_dir="output",
):
"""
Load the Stable Diffusion model and set the parameters.
Args:
sd_version (str): version for stable diffusion model
device (_type_): _description_
"""
# Set parameters
self.H = 224 # default height of CLIP
self.W = 224 # default width of CLIP
self.output_dir = output_dir
self.device = device
        # Create the OpenCLIP model (ViT-B-32 pretrained on LAION-2B)
        model, _, _ = open_clip.create_model_and_transforms(
            "ViT-B-32", pretrained="laion2b_s34b_b79k"
        )
        self.preprocess = T.Compose(
            [
                T.Normalize(
                    (0.48145466, 0.4578275, 0.40821073),
                    (0.26862954, 0.26130258, 0.27577711),
                ),
                T.Resize((self.H, self.W)),
            ]
        )
        self.tokenizer = open_clip.get_tokenizer("ViT-B-32")
        self.model = model.to(device)
        print("[INFO] loaded OpenCLIP!")
@torch.no_grad()
def get_text_embeddings(self, prompt):
"""
Get the text embeddings for the prompt.
Args:
prompt (list of string): text prompt to encode.
"""
return self.model.encode_text(self.tokenizer(prompt).to(self.device))
def encode_imgs(self, image):
"""
Encode images to latent representation.
Args:
img (tensor): image to encode. shape (N, 3, H, W), range [0, 1]
Returns:
latents (tensor): latent representation. shape (N, 512)
"""
        image = self.preprocess(image)
        # Encode the image into CLIP embedding space
        image_embeddings = self.model.encode_image(image)
return image_embeddings
def clip_loss(
self,
imgs,
text_embeddings,
text_embeddings_uncond=None
):
"""
Compute the SDS loss.
Args:
imgs (tensor): input latents, shape [N, H, W, 3]
text_embeddings (tensor): conditional text embedding (for positive prompt), shape [1, 77, 1024]
text_embeddings_uncond (tensor, optional): unconditional text embedding (for negative prompt), shape [1, 77, 1024]. Defaults to None.
Returns:
loss (tensor): CLIP loss
"""
        image_embeddings = self.encode_imgs(imgs)
        # Normalize embeddings so the dot product is a cosine similarity
        image_embeddings = image_embeddings / image_embeddings.norm(dim=-1, keepdim=True)
        text_embeddings = text_embeddings / text_embeddings.norm(dim=-1, keepdim=True)
        if text_embeddings_uncond is not None:
            text_embeddings_uncond = text_embeddings_uncond / text_embeddings_uncond.norm(
                dim=-1, keepdim=True
            )
            text_embeddings = torch.cat([text_embeddings, text_embeddings_uncond])
            # mean similarity over images; entry 0 is the positive prompt
            text_sims = (image_embeddings @ text_embeddings.T).mean(0)
            # maximize similarity to the positive prompt, minimize it to the negative ones
            loss = -text_sims[0] + text_sims[1:].mean()
        else:
            text_sims = (image_embeddings @ text_embeddings.T).mean(0)
            loss = -text_sims[0]
        return loss
def clip_score(
self,
imgs,
text_embeddings
):
"""
Compute the SDS loss.
Args:
imgs (tensor): input latents, shape [N, H, W, 3]
text_embeddings (tensor): conditional text embedding (for positive prompt), shape [1, 77, 1024]
text_embeddings_uncond (tensor, optional): unconditional text embedding (for negative prompt), shape [1, 77, 1024]. Defaults to None.
Returns:
loss (tensor): CLIP loss
"""
        image_embeddings = self.encode_imgs(imgs)
        # Normalize embeddings so the dot product is a cosine similarity
        image_embeddings = image_embeddings / image_embeddings.norm(dim=-1, keepdim=True)
        text_embeddings = text_embeddings / text_embeddings.norm(dim=-1, keepdim=True)
        text_sims = (image_embeddings @ text_embeddings.T).mean(0)
        return text_sims
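
# A minimal usage sketch for the CLIP class, with a hypothetical prompt and
# random placeholder images; a real caller would pass rendered images instead.
if __name__ == "__main__":
    clip = CLIP(device="cpu")
    text_emb = clip.get_text_embeddings(["a photograph of a hamburger"])
    imgs = torch.rand(2, 3, 224, 224)  # two random RGB images in [0, 1]
    loss = clip.clip_loss(imgs, text_emb)
    score = clip.clip_score(imgs, text_emb)
    print(f"CLIP loss: {loss.item():.4f}, CLIP score: {score.item():.4f}")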