gmm.py
import torch
import torch.nn as nn
import numpy as np
from utils import cosine_schedule


class GMMDataset(torch.utils.data.Dataset):
    """ Wraps a tensor of samples (shape N x D) as a map-style dataset. """

    def __init__(self, samples):
        self.samples = samples

    def __len__(self):
        return self.samples.shape[0]

    def __getitem__(self, idx):
        # The constant second element is a dummy label so the dataset fits
        # the (data, label) interface expected by generic training loops.
        return self.samples[idx, :], torch.ones(1).to(self.samples.device)
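
# Illustrative usage (a sketch, not part of the original module): wrap
# samples from a hand-built GMM (see below) so they can be fed to a standard
# DataLoader. The batch size here is an arbitrary example value.
#   dataset = GMMDataset(gmm.sampling(10000))
#   loader = torch.utils.data.DataLoader(dataset, batch_size=128, shuffle=True)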


class GMM(nn.Module):
    """ Gaussian mixture model.

    N.B.: the covariance of every component is assumed to be diagonal.
    """

    def __init__(self, w, mu, sigma):
        """
        p(x) = sum_k w[k] * N(x; mu[k], diag(sigma[k]^2))

        Args:
            w: shape K, mixture weights, must sum to 1
            mu: shape K x D, component means
            sigma: shape K x D, per-dimension standard deviations
        """
        super().__init__()
        # Buffers move with the module across devices but are never trained.
        self.register_buffer('w', w)
        self.register_buffer('mu', mu)
        self.register_buffer('sigma', sigma)
        self.K = w.shape[0]
        self.D = mu.shape[1]

    def log_gaussian(self, x, mu, sigma):
        """ Log density of a single diagonal-covariance multivariate Gaussian.
        N.B.: not wrapped in torch.no_grad(); langevin_sampling needs to
        differentiate through it.
        """
        # Summing log-variances equals log of their product but is stabler.
        return -0.5 * ((x - mu)**2 / sigma**2).sum(dim=1) - 0.5 * (
            self.D * np.log(2 * np.pi) + torch.log(sigma**2).sum())

    def log_prob(self, x):
        """ log p(x), computed with logsumexp over components for stability. """
        return torch.logsumexp(
            torch.stack([
                torch.log(self.w[kk]) +
                self.log_gaussian(x, self.mu[kk], self.sigma[kk])
                for kk in range(self.K)
            ]), 0)

    @torch.no_grad()
    def sampling(self, num_samples):
        """ Exact ancestral sampling: draw a component index, then add noise. """
        m = torch.distributions.Categorical(self.w)
        idx = m.sample((num_samples,))
        return self.mu[idx, :] + torch.randn(num_samples, self.D).to(
            self.w.device) * self.sigma[idx, :]

    @torch.no_grad()
    def langevin_sampling(self, x, num_steps=10, eta=1.0, is_anneal=False):
        """ Unadjusted Langevin dynamics targeting p(x):
            x <- x - eta * grad E(x) + sqrt(2 * eta) * N(0, I),
            with energy E(x) = -log p(x).
        """
        eta_list = cosine_schedule(eta_max=eta, T=num_steps) if is_anneal else None
        for ii in range(num_steps):
            eta_ii = eta_list[ii] if is_anneal else eta
            x = x.detach()
            x.requires_grad = True
            # Re-enable autograd locally; the surrounding no_grad() would
            # otherwise block the graph needed for grad(E(x), x).
            with torch.enable_grad():
                eng = -self.log_prob(x).sum()
                grad = torch.autograd.grad(eng, x)[0]
            x = x - eta_ii * grad + torch.randn_like(x) * np.sqrt(eta_ii * 2)
        return x.detach()
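

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative; not part of the original module). The
# 2-component, 2-D mixture below uses made-up example parameters: it draws
# exact ancestral samples, then refines random initial points with Langevin
# dynamics. Assumes utils.cosine_schedule is importable, as at the top.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    w = torch.tensor([0.3, 0.7])                    # K = 2 mixture weights
    mu = torch.tensor([[-2.0, 0.0], [2.0, 0.0]])    # K x D component means
    sigma = 0.5 * torch.ones(2, 2)                  # K x D standard deviations

    gmm = GMM(w, mu, sigma)
    x_exact = gmm.sampling(1000)                    # exact ancestral samples
    x_init = torch.randn(1000, 2)                   # random initial particles
    x_mcmc = gmm.langevin_sampling(x_init, num_steps=200, eta=1.0e-2)
    print("exact sample mean:   ", x_exact.mean(dim=0))
    print("langevin sample mean:", x_mcmc.mean(dim=0))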