-
Notifications
You must be signed in to change notification settings - Fork 1
/
models.py
70 lines (60 loc) · 2.33 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
from arch import *
from utils import *
import torch
import torch.nn as nn
def reparameterize(mu, logvar):
    """
    Draw a latent sample with the reparameterization trick.

        z = mu + exp(0.5 * logvar) * epsilon,   epsilon ~ N(0, I)

    Writing the sample as a deterministic function of (mu, logvar) plus
    independent noise keeps the result differentiable w.r.t. mu and logvar.

    :param mu: mean of the Gaussian to sample from
    :param logvar: log-variance of that Gaussian (exp(0.5 * logvar) is the std)
    :return z: mu + eps * std; the noise is drawn with randn_like(std), so it
        matches std's shape, device and dtype
    """
    std = torch.exp(0.5 * logvar)
    # randn_like already allocates on std's device with std's dtype, so no
    # extra device transfer is needed.
    eps = torch.randn_like(std)
    return mu + eps * std
class SoftIntroVAE(nn.Module):
    """
    Variational autoencoder built from an ``Encoder``/``Decoder`` pair, with
    optional conditioning.

    The encoder returns ``(mu, logvar)``; a latent ``z`` is either ``mu``
    (deterministic mode) or a sample drawn with ``reparameterize``, and the
    decoder maps ``z`` to the output. When ``conditional`` is True and a
    condition is supplied, it is forwarded to both networks; otherwise both
    networks are called without one.
    """

    def __init__(self, cdim=3, zdim=512, channels=(64, 128, 256, 512, 512, 512), image_size=256, conditional=False,
                 cond_dim=10):
        """
        :param cdim: forwarded to Encoder/Decoder; presumably the number of image channels
        :param zdim: size of the latent vector z
        :param channels: per-stage channel widths, forwarded to Encoder/Decoder as-is
        :param image_size: forwarded to Encoder/Decoder; presumably the input spatial size
        :param conditional: whether the networks accept an extra condition input
        :param cond_dim: forwarded to Encoder/Decoder; presumably the condition vector size
        """
        super().__init__()
        self.zdim = zdim
        self.conditional = conditional
        self.cond_dim = cond_dim
        self.encoder = Encoder(cdim, zdim, channels, image_size, conditional=conditional, cond_dim=cond_dim)
        # The decoder is sized from the encoder's conv_output_size (presumably so
        # its first stage matches the encoder's last conv output shape).
        self.decoder = Decoder(cdim, zdim, channels, image_size, conditional=conditional,
                               conv_input_size=self.encoder.conv_output_size, cond_dim=cond_dim)

    def forward(self, x, o_cond=None, deterministic=False):
        """
        Encode ``x``, pick a latent, and decode it.

        :param x: input batch
        :param o_cond: optional condition; ignored unless ``self.conditional`` is True
        :param deterministic: if True use ``z = mu`` instead of sampling
        :return: tuple ``(mu, logvar, z, y)`` where ``y`` is the decoder output
        """
        # encode()/decode() already drop the condition when the model is
        # unconditional or o_cond is None, so one path covers both cases.
        mu, logvar = self.encode(x, o_cond=o_cond)
        z = mu if deterministic else reparameterize(mu, logvar)
        y = self.decode(z, y_cond=o_cond)
        return mu, logvar, z, y

    def sample(self, z, y_cond=None):
        """
        Decode a given latent ``z`` (optionally conditioned on ``y_cond``).

        :return: the decoder output for ``z``
        """
        return self.decode(z, y_cond=y_cond)

    def sample_with_noise(self, num_samples=1, device=torch.device("cpu"), y_cond=None):
        """
        Decode latents drawn from a standard normal N(0, I).

        :param num_samples: number of latent vectors to draw
        :param device: device on which the noise tensor is created
        :param y_cond: optional condition forwarded to decode()
        :return: decoder output for ``num_samples`` random latents of size ``zdim``
        """
        # The latent size is stored as self.zdim in __init__ (there is no z_dim).
        z = torch.randn(num_samples, self.zdim, device=device)
        return self.decode(z, y_cond=y_cond)

    def encode(self, x, o_cond=None):
        """
        Run the encoder, passing ``o_cond`` only for a conditional model.

        :return: tuple ``(mu, logvar)`` as returned by the encoder
        """
        if self.conditional and o_cond is not None:
            mu, logvar = self.encoder(x, o_cond=o_cond)
        else:
            mu, logvar = self.encoder(x)
        return mu, logvar

    def decode(self, z, y_cond=None):
        """
        Run the decoder, passing ``y_cond`` only for a conditional model.

        :return: the decoder output for ``z``
        """
        if self.conditional and y_cond is not None:
            y = self.decoder(z, y_cond=y_cond)
        else:
            y = self.decoder(z)
        return y