forked from torch/nn
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Normalize.lua
155 lines (137 loc) · 4.49 KB
/
Normalize.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
local Normalize, parent = torch.class('nn.Normalize', 'nn.Module')

--- Module that rescales each input vector (each row of a 2D input) so that
-- it has unit p-norm.
-- @param p   norm order: any positive number, or math.huge for the max norm
-- @param eps optional small constant added to the norm so that an all-zero
--            vector does not cause a division by zero (default 1e-10)
function Normalize:__init(p, eps)
  parent.__init(self)
  -- reject a missing or non-positive norm order before storing anything
  assert(p, 'p-norm not provided')
  assert(p > 0, p..'-norm not supported')
  if not eps then
    eps = 1e-10
  end
  self.p = p
  self.eps = eps
end
--- Forward pass: divide every row of `input` by its p-norm (plus eps).
-- A 1D input is handled as a single row; the output always has the same
-- shape as `input`. The intermediate norms are kept on the module
-- (self.norm, self.normp, self._indices) because updateGradInput reuses them.
-- @param input 1D or 2D tensor
-- @return self.output, the normalized tensor shaped like `input`
function Normalize:updateOutput(input)
  assert(input:dim() <= 2, 'only 1d layer supported')
  local original_size = input:size()

  -- treat a lone vector as a batch of one row so the rest is 2D-only
  local rows = input
  if rows:dim() == 1 then
    rows = rows:view(1, -1)
  end

  -- lazily allocate work tensors of the same tensor type as the input
  self._output = self._output or rows.new()
  self.norm = self.norm or rows.new()
  self.buffer = self.buffer or rows.new()
  self._output:resizeAs(rows)

  if self.p ~= math.huge then
    -- finite p: normp = sum_j |x_j|^p + eps, norm = normp^(1/p)
    self.normp = self.normp or rows.new()
    if self.p % 2 == 0 then
      -- even powers are already non-negative, so abs() can be skipped
      self.buffer:pow(rows, self.p)
    else
      self.buffer:abs(rows):pow(self.p)
    end
    self.normp:sum(self.buffer, 2):add(self.eps)
    self.norm:pow(self.normp, 1 / self.p)
  else
    -- infinity norm: norm = max_j |x_j| + eps; the argmax positions are
    -- kept in self._indices for the backward pass
    if not self._indices then
      if torch.type(self.output) == 'torch.CudaTensor' then
        self._indices = torch.CudaTensor()
      else
        self._indices = torch.LongTensor()
      end
    end
    self.buffer:abs(rows)
    torch.max(self.norm, self._indices, self.buffer, 2)
    self.norm:add(self.eps)
  end

  -- divide each row by its norm, broadcasting the n x 1 norms over columns
  self._output:cdiv(rows, self.norm:view(-1, 1):expandAs(rows))
  -- hand back a view with the caller's original shape
  self.output:view(self._output, original_size)
  return self.output
end
--- Backward pass for y = x / N, applied row by row, where N is the row's norm.
-- For finite p, with N = self.norm and N^p = self.normp (both n x 1), the
-- gradient w.r.t. x_j is (up to the eps terms)
--   g_j / N - x_j * |x_j|^(p-2) * (sum_i g_i * x_i) / N^(p+1)
-- For p = math.huge, only each row's argmax position k (self._indices) gets
-- the cross term, with x_k / N (about sign(x_k)) in place of x_k*|x_k|^(p-2),
-- and the denominators become N and N^2.
-- Both cases are computed as (numerator) / (N^p * N) or / (N * N) so that a
-- single cdiv at the end finishes the job.
-- Relies on self.norm, self.normp, self.buffer and self._indices produced by
-- the preceding updateOutput call on the same input.
-- @param input      1D or 2D tensor that was passed to updateOutput
-- @param gradOutput gradient w.r.t. the output (same number of elements)
-- @return self.gradInput, shaped like `input`
function Normalize:updateGradInput(input, gradOutput)
  assert(input:dim() <= 2, 'only 1d layer supported')
  assert(gradOutput:dim() <= 2, 'only 1d layer supported')
  local input_size = input:size()
  if input:dim() == 1 then
    input = input:view(1,-1)
  end
  -- NOTE(review): gradOutput is not reshaped. For a 1D input it stays 1D
  -- while the work tensors are 1 x d, so the cmul calls below appear to rely
  -- on matching element counts rather than shapes. Confirm.
  local n = input:size(1) -- batch size
  local d = input:size(2) -- dimensionality of vectors
  self._gradInput = self._gradInput or input.new()
  self.cross = self.cross or input.new()

  -- diagonal term (still to be divided by the common denominator)
  self._gradInput:resize(n,d)
  if self.p == math.huge then
    -- specialization for the inf case: _gradInput = N * g.
    -- Broadcast the n x 1 norms across the d columns exactly like the
    -- finite-p branch does (instead of a 3-D n x d x 1 expansion).
    self._gradInput:cmul(self.norm:view(n,1):expand(n,d), gradOutput)
    -- buffer is zero except at each row's argmax k, where it holds x_k / N:
    -- the derivative of the max-norm w.r.t. the row
    self.buffer:resizeAs(input):zero()
    self.cross:resize(n,1)
    self.cross:gather(input,2,self._indices)
    self.cross:cdiv(self.norm)
    self.buffer:scatter(2,self._indices,self.cross)
  else
    -- _gradInput = N^p * g, broadcasting the n x 1 normp across columns
    self._gradInput:cmul(self.normp:view(n,1):expand(n,d), gradOutput)
    -- small optimizations for different p
    -- buffer = input*|input|^(p-2)
    if self.p % 2 ~= 0 then
      -- for non-even p, need to add absolute value
      if self.p < 2 then
        -- add eps to avoid possible division by 0 (p-2 is negative here)
        self.buffer:abs(input):add(self.eps):pow(self.p-2):cmul(input)
      else
        self.buffer:abs(input):pow(self.p-2):cmul(input)
      end
    elseif self.p == 2 then
      -- special case for p == 2, pow(x,0) = 1
      self.buffer:copy(input)
    else
      -- p is even and > 2, pow(x,p) is always positive
      self.buffer:pow(input,self.p-2):cmul(input)
    end
  end

  -- cross term, computed in two steps
  self.cross:resize(n,1)
  -- instead of having a huge temporary matrix (b1*b2),
  -- do the computations as b1*(b2*gradOutput). This avoids redundant
  -- computation and also a huge buffer of size n*d^2
  self.buffer2 = self.buffer2 or input.new() -- nxd
  self.buffer2:cmul(input, gradOutput)
  -- cross = sum_i x_i * g_i for each row (n x 1)
  self.cross:sum(self.buffer2, 2)
  self.buffer:cmul(self.cross:expandAs(self.buffer))
  self._gradInput:add(-1, self.buffer)

  -- common denominator, reusing the cross buffer: N*N (inf) or N^p*N (finite p)
  if self.p == math.huge then
    self.cross:cmul(self.norm,self.norm)
  else
    self.cross:cmul(self.normp,self.norm)
  end
  self._gradInput:cdiv(self.cross:expand(n,d))

  -- hand back a view with the caller's original shape
  self.gradInput:view(self._gradInput, input_size)
  return self.gradInput
end
--- Human-readable description such as "nn.Normalize(2)".
-- Integral norm orders are printed with %d and all others with %f.
-- math.huge % 1 is NaN, so the infinity norm also takes the %f branch.
function Normalize:__tostring__()
  local is_integral = (self.p % 1 == 0)
  local pattern = is_integral and '%s(%d)' or '%s(%f)'
  return string.format(pattern, torch.type(self), self.p)
end
--- Convert the module's tensors to another tensor type while keeping
-- self._indices usable by the matching torch.max implementation.
-- torch.max expects a LongTensor as indices, whereas cutorch.max expects a
-- CudaTensor. Therefore _indices is converted along with everything else
-- only when the target is 'torch.CudaTensor'. For any other target it is
-- detached during the parent conversion and then restored as a LongTensor.
-- When no target type is given, the call is forwarded to the parent
-- untouched, so _indices is never force-converted to a LongTensor by a
-- mere query.
-- NOTE(review): newer nn.Module:type returns the current type when called
-- without an argument; confirm for the nn version in use.
-- @param type        target tensor type name, e.g. 'torch.FloatTensor'
-- @param tensorCache optional cache forwarded to parent.type
-- @return self after a conversion; parent.type's result when `type` is absent
function Normalize:type(type, tensorCache)
  if not type then
    -- nothing to convert: do not detach or re-type _indices
    return parent.type(self, type, tensorCache)
  end
  if type == 'torch.CudaTensor' then
    parent.type(self, type, tensorCache)
  else
    -- self._indices must be a LongTensor. Setting it to nil temporarily avoids
    -- unnecessary memory allocations.
    local indices
    indices, self._indices = self._indices, nil
    parent.type(self, type, tensorCache)
    self._indices = indices and indices:long() or nil
  end
  return self
end
--- Release the intermediate tensors used by the forward and backward passes,
-- then clear the parent's state.
-- Every listed field is lazily created with `x = x or input.new()` and
-- resized before use in updateOutput/updateGradInput, so clearing them
-- between training steps is safe.
-- @return the result of parent.clearState(self)
function Normalize:clearState()
  nn.utils.clear(self, {
    '_output',
    '_indices',
    '_gradInput',
    'buffer',
    'buffer2', -- n x d scratch for the cross term in updateGradInput
    'norm',
    'normp',
    'cross',
  })
  return parent.clearState(self)
end