safe_svd_jax.py
import jax
import jax.numpy as np
# Adapted from the PyTorch implementation at
# https://github.com/wangleiphy/tensorgrad/blob/master/tensornets/adlib/svd.py
def safe_inverse(x, epsilon=1E-8):
    # Regularized reciprocal: close to 1/x when |x| >> sqrt(epsilon), but
    # stays finite as x -> 0.
    # The original PyTorch version optionally rescales epsilon:
    # epsilon = epsilon * torch.max(torch.abs(x))
    return x / (x**2 + epsilon)
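# For example, safe_inverse(0.0) == 0.0 while 1.0 / 0.0 overflows to inf, and
# safe_inverse(1e-3) equals 1e3 / (1 + 1e-2), i.e. within about 1% of 1/x.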
@jax.custom_vjp
def svd(A):
    # Thin SVD: the gradient formula below assumes u has shape (M, K) and
    # vh has shape (K, N) with K = min(M, N), hence full_matrices=False.
    u, s, vh = np.linalg.svd(A, full_matrices=False)
    return u, s, vh
# The credited tensorgrad code defines the backward pass (a VJP), so it is
# registered here through jax.custom_vjp rather than jax.custom_jvp.
def _svd_fwd(A):
    u, s, vh = svd(A)
    # The primal outputs are all the backward pass needs as residuals.
    return (u, s, vh), (u, s, vh)


def _svd_bwd(res, cotangents):
    u, s, vh = res
    du, ds, dvh = cotangents
    M = u.shape[0]
    N = vh.shape[1]
    NS = s.shape[0]
    Sinv = safe_inverse(s)

    # F[i, j] ~ 1 / (s[j] - s[i]) and G[i, j] ~ 1 / (s[j] + s[i]), both
    # regularized by safe_inverse and zeroed on the diagonal.
    F = safe_inverse(s - s[:, None])
    F = F - np.diag(np.diag(F))
    G = safe_inverse(s + s[:, None])
    G = G - np.diag(np.diag(G))

    udu = u.conj().T @ du
    vdv = vh @ dvh.conj().T

    su = (F + G) * (udu - udu.conj().T) / 2
    sv = (F - G) * (vdv - vdv.conj().T) / 2

    dA = u @ (su + sv + np.diag(ds)) @ vh
    if M > NS:
        # Directions of A outside the column span of u.
        dA = dA + (np.eye(M) - u @ u.conj().T) @ (du * Sinv) @ vh
    if N > NS:
        # Directions of A outside the row span of vh.
        dA = dA + (u * Sinv) @ dvh @ (np.eye(N) - vh.conj().T @ vh)
    return (dA,)


svd.defvjp(_svd_fwd, _svd_bwd)
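# Note: jax.numpy.linalg.svd already ships with its own differentiation rule;
# the difference here is that the divisions by s[i] - s[j] and s[i] + s[j] in
# the backward pass go through safe_inverse, so (near-)degenerate singular
# values yield a finite, regularized gradient.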
if __name__ == '__main__':
    from jax.test_util import check_grads
    print('testing autograd')
    A = jax.random.normal(jax.random.PRNGKey(0), shape=(20, 16))
    # Reverse mode only: a function with a custom VJP cannot be
    # forward-differentiated.
    check_grads(svd, (A,), order=1, modes=['rev'])
    print('passed')
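    # A minimal usage sketch: the nuclear norm below is an arbitrary example
    # loss, chosen only to show jax.grad going through the custom VJP.
    def nuclear_norm(B):
        _, sing, _ = svd(B)
        return np.sum(sing)
    g = jax.grad(nuclear_norm)(A)
    print('nuclear norm gradient shape:', g.shape)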
__all__=['svd']