
Commit 1b95f7e

Merge pull request #1131 from huihuifan/weightnorm
adding weight norm container
2 parents: c3cdec6 + 28d67b0

4 files changed (+252, -26 lines)

WeightNorm.lua

+165
@@ -0,0 +1,165 @@
-- Weight Normalization
-- https://arxiv.org/pdf/1602.07868v3.pdf
local WeightNorm, parent = torch.class("nn.WeightNorm", "nn.Container")

function WeightNorm:__init(module, outputDim)
    -- this container will apply Weight Normalization to any module it wraps
    -- it accepts parameter ``outputDim`` that represents the dimension of the output of the weight
    -- if outputDim is not 1, the container will transpose the weight
    -- if the weight is not 2D, the container will view the weight into a 2D shape
    -- that is nOut x (nIn x kw x dw x ...)

    parent.__init(self)
    assert(module.weight)

    if module.bias then
        self.bias = module.bias
        self.gradBias = module.gradBias
    end
    self.gradWeight = module.gradWeight
    self.weight = module.weight

    self.outputDim = outputDim or 1

    -- track the non-output weight dimensions
    self.otherDims = 1
    for i = 1, self.weight:dim() do
        if i ~= self.outputDim then
            self.otherDims = self.otherDims * self.weight:size(i)
        end
    end

    -- view size for weight norm 2D calculations
    self.viewIn = torch.LongStorage({self.weight:size(self.outputDim), self.otherDims})

    -- view size back to original weight
    self.viewOut = self.weight:size()

    -- bubble outputDim size up to the front
    for i = self.outputDim - 1, 1, -1 do
        self.viewOut[i], self.viewOut[i + 1] = self.viewOut[i + 1], self.viewOut[i]
    end

    -- weight is reparametrized to decouple the length from the direction
    -- such that w = g * ( v / ||v|| )
    self.v = torch.Tensor(self.viewIn[1], self.viewIn[2])
    self.g = torch.Tensor(self.viewIn[1])

    self._norm = torch.Tensor(self.viewIn[1])
    self._scale = torch.Tensor(self.viewIn[1])

    -- gradient of g
    self.gradG = torch.Tensor(self.viewIn[1]):zero()
    -- gradient of v
    self.gradV = torch.Tensor(self.viewIn)

    self.modules[1] = module
    self:resetInit()
end

function WeightNorm:permuteIn(inpt)
    local ans = inpt
    for i = self.outputDim - 1, 1, -1 do
        ans = ans:transpose(i, i+1)
    end
    return ans
end

function WeightNorm:permuteOut(inpt)
    local ans = inpt
    for i = 1, self.outputDim - 1 do
        ans = ans:transpose(i, i+1)
    end
    return ans
end

function WeightNorm:resetInit(inputSize, outputSize)
    self.v:normal(0, math.sqrt(2/self.viewIn[2]))
    self.g:norm(self.v, 2, 2)
    if self.bias then
        self.bias:zero()
    end
end

function WeightNorm:updateOutput(input)
    -- view to 2D when weight norm container operates
    self.gradV:copy(self:permuteIn(self.weight))
    self.gradV = self.gradV:view(self.viewIn)

    -- ||w||
    self._norm:norm(self.v, 2, 2):pow(2):add(10e-5):sqrt()
    -- g * w / ||w||
    self.gradV:copy(self.v)
    self._scale:copy(self.g):cdiv(self._norm)
    self.gradV:cmul(self._scale:view(self.viewIn[1], 1)
                        :expand(self.viewIn[1], self.viewIn[2]))

    -- otherwise maintain size of original module weight
    self.gradV = self.gradV:view(self.viewOut)

    self.weight:copy(self:permuteOut(self.gradV))
    self.output:set(self.modules[1]:updateOutput(input))
    return self.output
end

function WeightNorm:accGradParameters(input, gradOutput, scale)
    scale = scale or 1
    self.modules[1]:accGradParameters(input, gradOutput, scale)

    self.weight:copy(self:permuteIn(self.weight))
    self.gradV:copy(self:permuteIn(self.gradWeight))
    self.weight = self.weight:view(self.viewIn)

    local norm = self._norm:view(self.viewIn[1], 1):expand(self.viewIn[1], self.viewIn[2])
    local scale = self._scale:view(self.viewIn[1], 1):expand(self.viewIn[1], self.viewIn[2])

    -- dL / dw * (w / ||w||)
    self.weight:copy(self.gradV)
    self.weight:cmul(self.v):cdiv(norm)
    self.gradG:sum(self.weight, 2)

    -- dL / dw * g / ||w||
    self.gradV:cmul(scale)

    -- dL / dg * (w * g / ||w||^2)
    self.weight:copy(self.v):cmul(scale):cdiv(norm)
    self.weight:cmul(self.gradG:view(self.viewIn[1], 1)
                        :expand(self.viewIn[1], self.viewIn[2]))

    -- dL / dv update
    self.gradV:add(-1, self.weight)

    self.gradV = self.gradV:view(self.viewOut)
    self.weight = self.weight:view(self.viewOut)
    self.gradWeight:copy(self:permuteOut(self.gradV))
end

function WeightNorm:updateGradInput(input, gradOutput)
    self.gradInput:set(self.modules[1]:updateGradInput(input, gradOutput))
    return self.gradInput
end

function WeightNorm:zeroGradParameters()
    self.modules[1]:zeroGradParameters()
    self.gradV:zero()
    self.gradG:zero()
end

function WeightNorm:updateParameters(lr)
    self.modules[1]:updateParameters(lr)
    self.g:add(-lr, self.gradG)
    self.v:add(-lr, self.gradV)
end

function WeightNorm:parameters()
    if self.bias then
        return {self.v, self.g, self.bias}, {self.gradV, self.gradG, self.gradBias}
    else
        return {self.v, self.g}, {self.gradV, self.gradG}
    end
end

function WeightNorm:__tostring__()
    local str = 'nn.WeightNorm [' .. tostring(self.modules[1]) .. ']'
    return str
end
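
As a quick sanity check of the reparametrization (an editor's sketch, not part of the committed file), one can wrap an nn.Linear, run a forward pass, and rebuild the effective weight from `g` and `v` by hand; because updateOutput adds a small epsilon inside the norm, the match is approximate rather than exact:

require 'nn'

local lin = nn.Linear(10, 5)
local wn  = nn.WeightNorm(lin)       -- outputDim defaults to 1

local x   = torch.randn(3, 10)
local out = wn:forward(x)            -- recomputes lin.weight from g and v

-- rebuild w = g * v / ||v|| manually and compare with the module's weight
local norm = wn.v:norm(2, 2):expandAs(wn.v)
local w = torch.cmul(wn.v, wn.g:view(wn.g:size(1), 1):expandAs(wn.v)):cdiv(norm)

print((w - lin.weight):abs():max())  -- small (on the order of 1e-5), due to the epsilon term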

doc/containers.md

+37 -26
@@ -8,14 +8,14 @@ Complex neural networks are easily built using container classes:
 * [Concat](#nn.Concat) : concatenates in one layer several modules along dimension `dim` ;
 * [DepthConcat](#nn.DepthConcat) : like Concat, but adds zero-padding when non-`dim` sizes don't match;
 * [Bottle](#nn.Bottle) : allows any dimensionality input be forwarded through a module ;
-
+
 See also the [Table Containers](#nn.TableContainers) for manipulating tables of [Tensors](https://github.com/torch/torch7/blob/master/doc/tensor.md).
 
 <a name="nn.Container"></a>
 ## Container ##
 
 This is an abstract [Module](module.md#nn.Module) class which declares methods defined in all containers.
-It reimplements many of the Module methods such that calls are propagated to the
+It reimplements many of the Module methods such that calls are propagated to the
 contained modules. For example, a call to [zeroGradParameters](module.md#nn.Module.zeroGradParameters)
 will be propagated to all contained modules.
 
@@ -37,7 +37,7 @@ Returns the number of contained modules.
 Sequential provides a means to plug layers together
 in a feed-forward fully connected manner.
 
-E.g.
+E.g.
 creating a one hidden-layer multi-layer perceptron is thus just as easy as:
 ```lua
 mlp = nn.Sequential()
@@ -104,17 +104,17 @@ nn.Sequential {
 
 `module` = `Parallel(inputDimension,outputDimension)`
 
-Creates a container module that applies its `ith` child module to the `ith` slice of the input Tensor by using [select](https://github.com/torch/torch7/blob/master/doc/tensor.md#tensor-selectdim-index)
+Creates a container module that applies its `ith` child module to the `ith` slice of the input Tensor by using [select](https://github.com/torch/torch7/blob/master/doc/tensor.md#tensor-selectdim-index)
 on dimension `inputDimension`. It concatenates the results of its contained modules together along dimension `outputDimension`.
 
 Example:
 ```lua
 mlp = nn.Parallel(2,1);   -- Parallel container will associate a module to each slice of dimension 2
                           -- (column space), and concatenate the outputs over the 1st dimension.
-
+
 mlp:add(nn.Linear(10,3)); -- Linear module (input 10, output 3), applied on 1st slice of dimension 2
 mlp:add(nn.Linear(10,2))  -- Linear module (input 10, output 2), applied on 2nd slice of dimension 2
-
+
 -- After going through the Linear module the outputs are
 -- concatenated along the unique dimension, to form 1D Tensor
 > mlp:forward(torch.randn(10,2)) -- of size 5.
@@ -131,8 +131,8 @@ A more complicated example:
 
 mlp = nn.Sequential();
 c = nn.Parallel(1,2)     -- Parallel container will associate a module to each slice of dimension 1
-                         -- (row space), and concatenate the outputs over the 2nd dimension.
-
+                         -- (row space), and concatenate the outputs over the 2nd dimension.
+
 for i=1,10 do            -- Add 10 Linear+Reshape modules in parallel (input = 3, output = 2x1)
  local t=nn.Sequential()
  t:add(nn.Linear(3,2))   -- Linear module (input = 3, output = 2)
@@ -165,7 +165,7 @@ for i = 1, 10000 do     -- Train for a few iterations
  local err = criterion:forward(pred,y)
  local gradCriterion = criterion:backward(pred,y);
  mlp:zeroGradParameters();
- mlp:backward(x, gradCriterion);
+ mlp:backward(x, gradCriterion);
  mlp:updateParameters(0.01);
  print(err)
 end
@@ -209,16 +209,16 @@ module = nn.DepthConcat(dim)
 DepthConcat concatenates the output of one layer of "parallel" modules along the
 provided dimension `dim`: they take the same inputs, and their output is
 concatenated. For dimensions other than `dim` having different sizes,
-the smaller tensors are copied in the center of the output tensor,
+the smaller tensors are copied in the center of the output tensor,
 effectively padding the borders with zeros.
 
-The module is particularly useful for concatenating the output of [Convolutions](convolution.md)
-along the depth dimension (i.e. `nOutputFrame`).
-This is used to implement the *DepthConcat* layer
+The module is particularly useful for concatenating the output of [Convolutions](convolution.md)
+along the depth dimension (i.e. `nOutputFrame`).
+This is used to implement the *DepthConcat* layer
 of the [Going deeper with convolutions](http://arxiv.org/pdf/1409.4842v1.pdf) article.
-The normal [Concat](#nn.Concat) Module can't be used since the spatial
-dimensions (height and width) of the output Tensors requiring concatenation
-may have different values. To deal with this, the output uses the largest
+The normal [Concat](#nn.Concat) Module can't be used since the spatial
+dimensions (height and width) of the output Tensors requiring concatenation
+may have different values. To deal with this, the output uses the largest
 spatial dimensions and adds zero-padding around the smaller Tensors.
 ```lua
 inputSize = 3
@@ -231,7 +231,7 @@ mlp:add(nn.SpatialConvolutionMM(inputSize, outputSize, 3, 3))
 mlp:add(nn.SpatialConvolutionMM(inputSize, outputSize, 4, 4))
 
 > print(mlp:forward(input))
-(1,.,.) =
+(1,.,.) =
  -0.2874  0.6255  1.1122  0.4768  0.9863 -0.2201 -0.1516
   0.2779  0.9295  1.1944  0.4457  1.1470  0.9693  0.1654
  -0.5769 -0.4730  0.3283  0.6729  1.3574 -0.6610  0.0265
@@ -240,7 +240,7 @@ mlp:add(nn.SpatialConvolutionMM(inputSize, outputSize, 4, 4))
  0.4147  0.5062  0.6251  0.4374  0.3252  0.3478  0.0046
  0.7845 -0.0902  0.3499  0.0342  1.0706 -0.0605  0.5525
 
-(2,.,.) =
+(2,.,.) =
 -0.7351 -0.9327 -0.3092 -1.3395 -0.4596 -0.6377 -0.5097
 -0.2406 -0.2617 -0.3400 -0.4339 -0.3648  0.1539 -0.2961
 -0.7124 -1.2228 -0.2632  0.1690  0.4836 -0.9469 -0.7003
@@ -249,7 +249,7 @@ mlp:add(nn.SpatialConvolutionMM(inputSize, outputSize, 4, 4))
 -0.3086 -0.0298 -0.2031  0.1026 -0.5785 -0.3275 -0.1630
  0.0596 -0.6097  0.1443 -0.8603 -0.2774 -0.4506 -0.5367
 
-(3,.,.) =
+(3,.,.) =
  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
  0.0000 -0.7326  0.3544  0.1821  0.4796  1.0164  0.0000
  0.0000 -0.9195 -0.0567 -0.1947  0.0169  0.1924  0.0000
@@ -258,7 +258,7 @@ mlp:add(nn.SpatialConvolutionMM(inputSize, outputSize, 4, 4))
  0.0000 -0.1911  0.2912  0.5092  0.2955  0.7171  0.0000
  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
 
-(4,.,.) =
+(4,.,.) =
  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
  0.0000 -0.8263  0.3646  0.6750  0.2062  0.2785  0.0000
  0.0000 -0.7572  0.0432 -0.0821  0.4871  1.9506  0.0000
@@ -267,7 +267,7 @@ mlp:add(nn.SpatialConvolutionMM(inputSize, outputSize, 4, 4))
  0.0000  0.2570  0.4694 -0.1262  0.5602  0.0821  0.0000
  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
 
-(5,.,.) =
+(5,.,.) =
  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
  0.0000  0.3158  0.4389 -0.0485 -0.2179  0.0000  0.0000
  0.0000  0.1966  0.6185 -0.9563 -0.3365  0.0000  0.0000
@@ -276,7 +276,7 @@ mlp:add(nn.SpatialConvolutionMM(inputSize, outputSize, 4, 4))
  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
 
-(6,.,.) =
+(6,.,.) =
  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
  0.0000  1.1148  0.2324 -0.1093  0.5024  0.0000  0.0000
  0.0000 -0.2624 -0.5863  0.3444  0.3506  0.0000  0.0000
@@ -286,11 +286,11 @@ mlp:add(nn.SpatialConvolutionMM(inputSize, outputSize, 4, 4))
  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
 [torch.DoubleTensor of dimension 6x7x7]
 ```
-Note how the last 2 of 6 filter maps have 1 column of zero-padding
-on the left and top, as well as 2 on the right and bottom.
+Note how the last 2 of 6 filter maps have 1 column of zero-padding
+on the left and top, as well as 2 on the right and bottom.
 This is inevitable when the component
-module output tensors non-`dim` sizes aren't all odd or even.
-Such that in order to keep the mappings aligned, one need
+module output tensors non-`dim` sizes aren't all odd or even.
+Such that in order to keep the mappings aligned, one need
 only ensure that these be all odd (or even).
 
 <a name="nn.Bottle"></a>
@@ -323,6 +323,17 @@ mlp = nn.Bottle(nn.Linear(10, 2))
 [torch.LongStorage of size 4]
 ```
 
+<a name="nn.WeightNorm"></a>
+## Weight Normalization ##
+
+```lua
+module = nn.WeightNorm(module)
+```
+
+WeightNorm implements the reparametrization presented in [Weight Normalization](https://arxiv.org/pdf/1602.07868v3.pdf), which decouples the length of a neural network weight vector from its direction. The weight vector `w` is instead determined by parameters `g` and `v` such that `w = g * v / ||v||`, where `||v||` is the Euclidean norm of the vector `v`. This container can wrap any nn layer that has weights.
+
+It accepts a parameter `outputDim` that represents the output dimension of the wrapped module's weight, which defaults to 1. If `outputDim` is not 1, the container will transpose the weight appropriately. If the module weight is not 2D, the container will view the weight into an appropriate 2D shape based on the specified `outputDim`.
+
 <a name="nn.TableContainers"></a>
 ## Table Containers ##
 While the above containers are used for manipulating input [Tensors](https://github.com/torch/torch7/blob/master/doc/tensor.md), table containers are used for manipulating tables :
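
For illustration only (an editor's sketch, not part of this diff), typical usage of the new container described above would look like the following; the optional second argument `outputDim` defaults to 1:

require 'nn'

-- wrap each weighted layer individually; its learnable parameters become (v, g, bias)
local mlp = nn.Sequential()
mlp:add(nn.WeightNorm(nn.Linear(10, 20)))
mlp:add(nn.ReLU())
mlp:add(nn.WeightNorm(nn.Linear(20, 1)))

local y = mlp:forward(torch.randn(4, 10))
print(y:size())   -- 4x1
print(mlp)        -- each wrapped layer prints as nn.WeightNorm [nn.Linear(...)]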

init.lua

+1
@@ -16,6 +16,7 @@ require('nn.Parallel')
 require('nn.Sequential')
 require('nn.DepthConcat')
 require('nn.Bottle')
+require('nn.WeightNorm')
 
 require('nn.Linear')
 require('nn.Bilinear')

0 commit comments
