-
Notifications
You must be signed in to change notification settings - Fork 21
/
initialization.py
65 lines (53 loc) · 1.75 KB
/
initialization.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import torch
import numpy as np
import scipy.linalg as la
def henaff_init_(A):
    """Fill ``A`` with a 2x2-block skew-symmetric matrix with uniform angles.

    ``A.size(0) // 2`` angles are drawn from U[-pi, pi) and written onto the
    block structure by ``create_diag_``.

    Args:
        A: square tensor; overwritten in place.

    Returns:
        The value returned by ``create_diag_`` (``A`` itself).
    """
    n_blocks = A.size(0) // 2
    angles = A.new_empty(n_blocks)
    angles.uniform_(-np.pi, np.pi)
    return create_diag_(A, angles)
def cayley_init_(A):
    """Fill ``A`` with a 2x2-block skew-symmetric matrix suited to a Cayley map.

    For each of the ``A.size(0) // 2`` blocks an angle ``t`` is drawn from
    U[0, pi/2). The stored entry is ``-sqrt((1 - cos t) / (1 + cos t))``,
    which equals ``-tan(t / 2)`` and so lies in (-1, 0].

    Args:
        A: square tensor; overwritten in place.

    Returns:
        The value returned by ``create_diag_`` (``A`` itself).
    """
    n_blocks = A.size(0) // 2
    theta = A.new_empty(n_blocks).uniform_(0., np.pi / 2.)
    cos_theta = torch.cos(theta)
    entries = -torch.sqrt((1. - cos_theta) / (1. + cos_theta))
    return create_diag_(A, entries)
# We include a few more initializations that could be useful for other problems
def haar_init_(A):
    """Haar initialization on SO(n), stored as a skew-symmetric logarithm.

    ``A`` is first filled with a random orthogonal matrix by
    ``torch.nn.init.orthogonal_``. If its determinant is negative, one
    randomly chosen row is negated, which moves it into SO(n). ``A`` is then
    overwritten in place with the real part of the principal matrix logarithm
    of that rotation, projected onto the skew-symmetric matrices.

    All randomness comes from torch's RNG, so ``torch.manual_seed`` alone
    makes the result reproducible.

    Args:
        A: square floating-point tensor; overwritten in place.

    Returns:
        ``A``.
    """
    torch.nn.init.orthogonal_(A)
    with torch.no_grad():
        if A.det() < 0.:
            # Negating one row flips the determinant's sign: O^-(n) -> O^+(n) = SO(n).
            # Draw the row from torch's RNG (not numpy's) so one seed controls everything.
            idx = int(torch.randint(A.size(0), (1,)).item())
            A[idx] *= -1.
        # detach(): .numpy() rejects tensors that require grad, even under no_grad.
        # Compute the matrix logarithm in double precision.
        rotation = A.detach().cpu().double().numpy()
        log_rot = la.logm(rotation).real
        # logm is only skew-symmetric up to round-off; project explicitly.
        log_rot = .5 * (log_rot - log_rot.T)
        A.copy_(torch.from_numpy(log_rot))
    return A
def haar_diag_init_(A):
    """Block-diagonal skew-symmetric init whose angles match a Haar sample on SO(n).

    A Haar-distributed skew-symmetric matrix is drawn with ``haar_init_``.
    Its eigenvalues come in pairs +-i*theta_k, plus a single 0 when n is odd.
    The ``n // 2`` angles theta_k are then written onto the 2x2-block
    structure by ``create_diag_``.

    Args:
        A: square floating-point tensor; overwritten in place.

    Returns:
        The value returned by ``create_diag_`` (``A`` itself).
    """
    haar_init_(A)
    n = A.size(0)
    with torch.no_grad():
        imag = la.eigvals(A.detach().cpu().numpy()).imag
        # LAPACK lists each conjugate pair consecutively (positive imaginary
        # part first), but the real eigenvalue 0 (odd n) may appear at ANY
        # position. So striding by 2 and dropping the last entry is wrong.
        # Instead, keep the n // 2 largest imaginary parts (the +theta_k),
        # preserving their original order.
        keep = np.sort(np.argsort(imag)[::-1][: n // 2])
        theta = torch.tensor(imag[keep])
    return create_diag_(A, theta)
def normal_squeeze_diag_init_(A):
    """Fill ``A`` with a 2x2-block skew-symmetric matrix with small angles.

    ``A.size(0) // 2`` standard-normal samples are reduced with
    ``fmod(., pi / 8)``. The resulting angles keep their sign and lie in
    (-pi/8, pi/8).

    Args:
        A: square tensor; overwritten in place.

    Returns:
        The value returned by ``create_diag_`` (``A`` itself).
    """
    half = A.size(0) // 2
    angles = A.new_empty(half)
    angles.normal_(0, 1)
    angles.fmod_(np.pi / 8.)
    return create_diag_(A, angles)
def normal_diag_init_(A):
    """Fill ``A`` with a 2x2-block skew-symmetric matrix with normal angles.

    ``A.size(0) // 2`` standard-normal samples are reduced with
    ``fmod(., pi)``. The resulting angles keep their sign and lie in (-pi, pi).

    Args:
        A: square tensor; overwritten in place.

    Returns:
        The value returned by ``create_diag_`` (``A`` itself).
    """
    half = A.size(0) // 2
    angles = A.new_empty(half)
    angles.normal_(0, 1)
    angles.fmod_(np.pi)
    return create_diag_(A, angles)
def create_diag_(A, diag):
    """Overwrite square ``A`` with a block-diagonal skew-symmetric matrix.

    ``diag[k]`` is written to ``A[2k, 2k+1]`` and ``-diag[k]`` to
    ``A[2k+1, 2k]``, so each consecutive 2x2 block is ``[[0, d], [-d, 0]]``.
    Every other entry is set to zero. When ``A`` has odd size, its last row
    and column are all zero.

    Args:
        A: square ``n x n`` tensor; overwritten in place.
        diag: ``n // 2`` values (tensor, array or sequence). They are
            converted to A's dtype and device.

    Returns:
        ``A``.
    """
    n = A.size(0)
    with torch.no_grad():
        # Build in A's dtype/device. A CPU float32 scratch buffer would
        # truncate double-precision values and force a host round-trip for
        # CUDA tensors.
        diag_z = A.new_zeros(n - 1)
        # Every other superdiagonal slot; the odd slots stay 0, which
        # decouples neighbouring 2x2 blocks.
        diag_z[::2] = torch.as_tensor(diag, dtype=diag_z.dtype, device=diag_z.device)
        A_init = torch.diag(diag_z, diagonal=1)
        A_init = A_init - A_init.T
        A.copy_(A_init)
    return A