
Commit fa4f363

DmitryUlyanov authored and soumith committed
Instance norm (pytorch#1283)
* instance norm
* fix whitespaces
* whitespaces
* docs
* "C" letter was cyrillic in docs, fixed
* remove force_eval, fix non contiguous case
1 parent aab30d4 commit fa4f363

4 files changed, +244 -2 lines changed

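Before the per-file diffs, a brief orientation sketch (not part of the commit): instance normalization normalizes every (sample, channel) slice with that slice's own mean and variance, whereas batch normalization pools statistics across the whole batch. Below is a minimal reference implementation in plain tensor ops, assuming an (N, C, H, W) input; the helper name `instance_norm_reference` is illustrative only.

    import torch

    def instance_norm_reference(x, eps=1e-5):
        # x: (N, C, H, W) -> normalize each (n, c) slice with its own statistics
        n, c = x.size(0), x.size(1)
        flat = x.contiguous().view(n, c, -1)
        mean = flat.mean(-1, keepdim=True)
        var = flat.var(-1, unbiased=False, keepdim=True)
        out = (flat - mean) / torch.sqrt(var + eps)
        return out.view_as(x)

    x = torch.randn(2, 3, 4, 4)
    y = instance_norm_reference(x)
    print(y.view(2, 3, -1).mean(-1))  # ~0 for every (sample, channel)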

docs/source/nn.rst

Lines changed: 17 additions & 0 deletions
@@ -302,6 +302,23 @@ Normalization layers
 .. autoclass:: BatchNorm3d
     :members:
 
+:hidden:`InstanceNorm1d`
+~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: InstanceNorm1d
+    :members:
+
+:hidden:`InstanceNorm2d`
+~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: InstanceNorm2d
+    :members:
+
+:hidden:`InstanceNorm3d`
+~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: InstanceNorm3d
+    :members:
 
 Recurrent layers
 ----------------------------------

test/test_nn.py

Lines changed: 57 additions & 0 deletions
@@ -756,6 +756,63 @@ def test_Dropout3d(self):
         input = torch.Tensor(num_features, b, d, w, h)
         self._test_dropout(nn.Dropout3d, input)
 
+    def _test_InstanceNorm(self, cls, input):
+        b, c = input.size(0), input.size(1)
+        input_var = Variable(input)
+
+        IN = cls(c, eps=0)
+
+        output = IN(input_var)
+        out_reshaped = output.transpose(1, 0).contiguous().view(c, -1)
+
+        mean = out_reshaped.mean(1)
+        var = out_reshaped.var(1, unbiased=False)
+
+        self.assertAlmostEqual(torch.abs(mean.data).mean(), 0, delta=1e-5)
+        self.assertAlmostEqual(torch.abs(var.data).mean(), 1, delta=1e-5)
+
+        # If momentum==1 running_mean/var should be
+        # equal to mean/var of the input
+        IN = cls(c, momentum=1, eps=0)
+
+        output = IN(input_var)
+
+        input_reshaped = input_var.transpose(1, 0).contiguous().view(c, -1)
+        mean = input_reshaped.mean(1)
+
+        input_reshaped = input_var.transpose(1, 0).contiguous().view(c, b, -1)
+        var = input_reshaped.var(2, unbiased=True)[:, :]
+
+        self.assertAlmostEqual(torch.abs(mean.data - IN.running_mean).mean(), 0, delta=1e-5)
+        self.assertAlmostEqual(torch.abs(var.data.mean(1) - IN.running_var).mean(), 0, delta=1e-5)
+
+    def test_InstanceNorm2d(self):
+        b = random.randint(3, 5)
+        c = random.randint(1, 5)
+        w = random.randint(2, 5)
+        h = random.randint(2, 5)
+
+        input = torch.Tensor(b, c, h, w).uniform_()
+        self._test_InstanceNorm(nn.InstanceNorm2d, input)
+
+    def test_InstanceNorm1d(self):
+        b = random.randint(3, 5)
+        c = random.randint(1, 5)
+        d = random.randint(2, 5)
+
+        input = torch.Tensor(b, c, d).uniform_()
+        self._test_InstanceNorm(nn.InstanceNorm1d, input)
+
+    def test_InstanceNorm3d(self):
+        b = random.randint(3, 5)
+        c = random.randint(1, 5)
+        w = random.randint(2, 5)
+        h = random.randint(2, 5)
+        d = random.randint(2, 5)
+
+        input = torch.Tensor(b, c, h, w, d).uniform_()
+        self._test_InstanceNorm(nn.InstanceNorm3d, input)
+
     def test_pad(self):
         inputs = Variable(torch.randn(1, 3, 4, 4), requires_grad=True)
         self.assertTrue(gradcheck(lambda x: F.pad(x, (1, 1, 1, 1)), (inputs,)))
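The helper above is written against the 2017 `Variable` API. A hedged sketch of the same statistical check with today's stable API (not part of the commit): with `affine=False`, every (sample, channel) slice of the output should have mean close to 0 and variance close to 1.

    import torch
    import torch.nn as nn

    x = torch.rand(4, 3, 8, 8)                           # (N, C, H, W)
    m = nn.InstanceNorm2d(3, affine=False)
    y = m(x).view(4, 3, -1)                              # flatten spatial dims per (n, c)

    print(y.mean(-1).abs().max())                        # ~0
    print((y.var(-1, unbiased=False) - 1).abs().max())   # ~0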

torch/nn/modules/__init__.py

Lines changed: 4 additions & 2 deletions
@@ -14,6 +14,7 @@
     MaxUnpool1d, MaxUnpool2d, MaxUnpool3d, FractionalMaxPool2d, LPPool2d, AdaptiveMaxPool1d, \
     AdaptiveMaxPool2d, AdaptiveAvgPool1d, AdaptiveAvgPool2d
 from .batchnorm import BatchNorm1d, BatchNorm2d, BatchNorm3d
+from .instancenorm import InstanceNorm1d, InstanceNorm2d, InstanceNorm3d
 from .dropout import Dropout, Dropout2d, Dropout3d
 from .padding import ReflectionPad2d, ReplicationPad2d, ReplicationPad3d
 from .normalization import CrossMapLRN2d
@@ -36,8 +37,9 @@
     'SoftMarginLoss', 'CrossEntropyLoss', 'Container', 'Sequential', 'ModuleList',
     'ParameterList', 'AvgPool1d', 'AvgPool2d', 'AvgPool3d', 'MaxPool1d', 'MaxPool2d',
     'MaxPool3d', 'MaxUnpool1d', 'MaxUnpool2d', 'MaxUnpool3d', 'FractionalMaxPool2d',
-    'LPPool2d', 'BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d', 'Dropout', 'Dropout2d',
-    'Dropout3d', 'ReflectionPad2d', 'ReplicationPad2d', 'ReplicationPad3d', 'CrossMapLRN2d',
+    'LPPool2d', 'BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d', 'InstanceNorm1d', 'InstanceNorm2d',
+    'InstanceNorm3d', 'Dropout', 'Dropout2d', 'Dropout3d', 'ReflectionPad2d',
+    'ReplicationPad2d', 'ReplicationPad3d', 'CrossMapLRN2d',
     'Embedding', 'RNNBase', 'RNN', 'LSTM', 'GRU', 'RNNCell', 'LSTMCell', 'GRUCell',
     'PixelShuffle', 'UpsamplingNearest2d', 'UpsamplingBilinear2d', 'PairwiseDistance',
     'AdaptiveMaxPool1d', 'AdaptiveMaxPool2d', 'AdaptiveAvgPool1d', 'AdaptiveAvgPool2d',
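With these imports and `__all__` entries in place, the new layers are reachable directly as `nn.InstanceNorm1d/2d/3d`. A small usage sketch (shapes are arbitrary examples, not taken from the commit):

    import torch
    import torch.nn as nn

    x1 = torch.randn(8, 16, 50)            # (N, C, L)
    x2 = torch.randn(8, 16, 32, 32)        # (N, C, H, W)
    x3 = torch.randn(8, 16, 4, 32, 32)     # (N, C, D, H, W)

    print(nn.InstanceNorm1d(16)(x1).shape)
    print(nn.InstanceNorm2d(16)(x2).shape)
    print(nn.InstanceNorm3d(16)(x3).shape)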

torch/nn/modules/instancenorm.py

Lines changed: 166 additions & 0 deletions
@@ -0,0 +1,166 @@
+from .batchnorm import _BatchNorm
+from .. import functional as F
+
+
+class _InstanceNorm(_BatchNorm):
+    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=False):
+        super(_InstanceNorm, self).__init__(
+            num_features, eps, momentum, affine)
+
+    def forward(self, input):
+        self._check_input_dim(input)
+
+        b, c = input.size(0), input.size(1)
+
+        # Repeat stored stats and affine transform params
+        running_mean = self.running_mean.repeat(b)
+        running_var = self.running_var.repeat(b)
+
+        weight, bias = None, None
+        if self.affine:
+            weight = self.weight.repeat(b)
+            bias = self.bias.repeat(b)
+
+        # Apply instance norm
+        input_reshaped = input.contiguous().view(1, b * c, *input.size()[2:])
+
+        out = F.batch_norm(
+            input_reshaped, running_mean, running_var, weight, bias,
+            self.training, self.momentum, self.eps)
+
+        # Reshape back
+        self.running_mean.copy_(running_mean.view(b, c).mean(0))
+        self.running_var.copy_(running_var.view(b, c).mean(0))
+
+        return out.view(b, c, *input.size()[2:])
+
+    def eval(self):
+        return self
+
+
+class InstanceNorm1d(_InstanceNorm):
+    r"""Applies Instance Normalization over a 2d or 3d input that is seen as a mini-batch.
+
+    .. math::
+
+        y = \frac{x - mean[x]}{ \sqrt{Var[x]} + \epsilon} * gamma + beta
+
+    The mean and standard-deviation are calculated per-dimension separately
+    for each object in a mini-batch. Gamma and beta are learnable parameter vectors
+    of size C (where C is the input size).
+
+    During training, this layer keeps a running estimate of its computed mean
+    and variance. The running sum is kept with a default momentum of 0.1.
+
+    At evaluation time (`.eval()`), the default behaviour of the InstanceNorm module stays the same,
+    i.e. running mean/variance is NOT used for normalization. One can force using stored
+    mean and variance with the `.train(False)` method.
+
+    Args:
+        num_features: num_features from an expected input of size `batch_size x num_features x width`
+        eps: a value added to the denominator for numerical stability. Default: 1e-5
+        momentum: the value used for the running_mean and running_var computation. Default: 0.1
+        affine: a boolean value that when set to true, gives the layer learnable affine parameters.
+
+    Shape:
+        - Input: :math:`(N, C, L)`
+        - Output: :math:`(N, C, L)` (same shape as input)
+
+    Examples:
+        >>> # With Learnable Parameters
+        >>> m = nn.InstanceNorm1d(100, affine=True)
+        >>> # Without Learnable Parameters
+        >>> m = nn.InstanceNorm1d(100)
+        >>> input = autograd.Variable(torch.randn(20, 100, 40))
+        >>> output = m(input)
+    """
+
+    def _check_input_dim(self, input):
+        if input.dim() != 3:
+            raise ValueError('expected 2D or 3D input (got {}D input)'
+                             .format(input.dim()))
+        super(InstanceNorm1d, self)._check_input_dim(input)
+
+
+class InstanceNorm2d(_InstanceNorm):
+    r"""Applies Instance Normalization over a 4d input that is seen as a mini-batch of 3d inputs
+    .. math::
+        y = \frac{x - mean[x]}{ \sqrt{Var[x]} + \epsilon} * gamma + beta
+    The mean and standard-deviation are calculated per-dimension separately
+    for each object in a mini-batch. Gamma and beta are learnable parameter vectors
+    of size C (where C is the input size).
+
+    During training, this layer keeps a running estimate of its computed mean
+    and variance. The running sum is kept with a default momentum of 0.1.
+
+    At evaluation time (`.eval()`), the default behaviour of the InstanceNorm module stays the same,
+    i.e. running mean/variance is NOT used for normalization. One can force using stored
+    mean and variance with the `.train(False)` method.
+
+    Args:
+        num_features: num_features from an expected input of size batch_size x num_features x height x width
+        eps: a value added to the denominator for numerical stability. Default: 1e-5
+        momentum: the value used for the running_mean and running_var computation. Default: 0.1
+        affine: a boolean value that when set to true, gives the layer learnable affine parameters.
+    Shape:
+        - Input: :math:`(N, C, H, W)`
+        - Output: :math:`(N, C, H, W)` (same shape as input)
+    Examples:
+        >>> # With Learnable Parameters
+        >>> m = nn.InstanceNorm2d(100, affine=True)
+        >>> # Without Learnable Parameters
+        >>> m = nn.InstanceNorm2d(100)
+        >>> input = autograd.Variable(torch.randn(20, 100, 35, 45))
+        >>> output = m(input)
+    """
+
+    def _check_input_dim(self, input):
+        if input.dim() != 4:
+            raise ValueError('expected 4D input (got {}D input)'
+                             .format(input.dim()))
+        super(InstanceNorm2d, self)._check_input_dim(input)
+
+
+class InstanceNorm3d(_InstanceNorm):
+    r"""Applies Instance Normalization over a 5d input that is seen as a mini-batch of 4d inputs
+
+    .. math::
+
+        y = \frac{x - mean[x]}{ \sqrt{Var[x]} + \epsilon} * gamma + beta
+
+    The mean and standard-deviation are calculated per-dimension separately for each object in a mini-batch.
+    Gamma and beta are learnable parameter vectors
+    of size C (where C is the input size).
+
+    During training, this layer keeps a running estimate of its computed mean
+    and variance. The running sum is kept with a default momentum of 0.1.
+
+    At evaluation time (`.eval()`), the default behaviour of the InstanceNorm module stays the same,
+    i.e. running mean/variance is NOT used for normalization. One can force using stored
+    mean and variance with the `.train(False)` method.
+
+
+    Args:
+        num_features: num_features from an expected input of size batch_size x num_features x depth x height x width
+        eps: a value added to the denominator for numerical stability. Default: 1e-5
+        momentum: the value used for the running_mean and running_var computation. Default: 0.1
+        affine: a boolean value that when set to true, gives the layer learnable affine parameters.
+
+    Shape:
+        - Input: :math:`(N, C, D, H, W)`
+        - Output: :math:`(N, C, D, H, W)` (same shape as input)
+
+    Examples:
+        >>> # With Learnable Parameters
+        >>> m = nn.InstanceNorm3d(100, affine=True)
+        >>> # Without Learnable Parameters
+        >>> m = nn.InstanceNorm3d(100)
+        >>> input = autograd.Variable(torch.randn(20, 100, 35, 45, 10))
+        >>> output = m(input)
+    """
+
+    def _check_input_dim(self, input):
+        if input.dim() != 5:
+            raise ValueError('expected 5D input (got {}D input)'
+                             .format(input.dim()))
+        super(InstanceNorm3d, self)._check_input_dim(input)
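The trick in `_InstanceNorm.forward` above is to reuse batch norm: the (N, C, ...) input is viewed as (1, N*C, ...), so batch norm's per-channel statistics become per-(sample, channel) statistics; the repeated running stats are then averaged back over the batch with `.view(b, c).mean(0)`, which is why `momentum=1` makes them equal to the current input's statistics, the property the test above checks. Below is a standalone sketch of that reshape trick with the modern functional API; the helper name and the omission of running statistics are illustrative assumptions, not part of the commit.

    import torch
    import torch.nn.functional as F

    def instance_norm_via_batch_norm(x, eps=1e-5):
        # Fold the batch dim into the channel dim, then let batch_norm
        # compute statistics per (sample, channel) slice.
        b, c = x.size(0), x.size(1)
        x_reshaped = x.contiguous().view(1, b * c, *x.size()[2:])
        out = F.batch_norm(
            x_reshaped, None, None,          # no running_mean / running_var here
            weight=None, bias=None,
            training=True, momentum=0.0, eps=eps)
        return out.view_as(x)

    x = torch.randn(2, 3, 5, 5)
    y = instance_norm_via_batch_norm(x)
    print(y.view(2, 3, -1).mean(-1))         # ~0 for every (sample, channel)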
