
Commit db3f45a

Author: Tete Xiao
Message: add new bn statistics & results
1 parent: 56b0058

File tree: 5 files changed (+33, -12 lines)


README.md

Lines changed: 2 additions & 2 deletions
@@ -67,11 +67,11 @@ IMPORTANT: We use our self-trained base model on ImageNet. The model takes the i
 </tr>
 <tr>
 <td rowspan="2">ResNet-50_dilated8 + psp_bilinear_deepsup</td>
-<td>No</td><td>40.60</td><td>79.66</td><td>60.13</td>
+<td>No</td><td>41.26</td><td>79.73</td><td>60.50</td>
 <td rowspan="2">33.4 hours</td>
 </tr>
 <tr>
-<td>Yes</td><td>41.31</td><td>80.14</td><td>60.73</td>
+<td>Yes</td><td>42.04</td><td>80.23</td><td>61.14</td>
 </tr>
 <tr>
 <td>ResNet-101_dilated8 + c1_bilinear_deepsup</td>

lib/nn/modules/batchnorm.py

Lines changed: 19 additions & 3 deletions
@@ -36,7 +36,7 @@ def _unsqueeze_ft(tensor):


 class _SynchronizedBatchNorm(_BatchNorm):
-    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
+    def __init__(self, num_features, eps=1e-5, momentum=0.001, affine=True):
         super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)

         self._sync_master = SyncMaster(self._data_parallel_master)
@@ -45,6 +45,14 @@ def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
         self._parallel_id = None
         self._slave_pipe = None

+        # customized batch-norm statistics
+        self._iter = 1
+        self._moving_average_fraction = 1. - momentum
+        self.register_buffer('_tmp_running_mean', torch.zeros(self.num_features))
+        self.register_buffer('_tmp_running_var', torch.ones(self.num_features))
+        self._tmp_running_mean = self.running_mean.clone()
+        self._tmp_running_var = self.running_var.clone()
+
     def forward(self, input):
         # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
         if not (self._is_parallel and self.training):
@@ -108,6 +116,10 @@ def _data_parallel_master(self, intermediates):

         return outputs

+    def _add_weighted(self, dest, delta, alpha=1, beta=1, bias=0):
+        """Update *dest* via `dest := dest*alpha + delta*beta + bias`."""
+        return dest * alpha + delta * beta + bias
+
     def _compute_mean_std(self, sum_, ssum, size):
         """Compute the mean and standard-deviation with sum and square-sum. This method
         also maintains the moving average on the master device."""
@@ -117,8 +129,12 @@ def _compute_mean_std(self, sum_, ssum, size):
         unbias_var = sumvar / (size - 1)
         bias_var = sumvar / size

-        self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
-        self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data
+        self._tmp_running_mean = self._add_weighted(self._tmp_running_mean, mean.data, alpha=self._moving_average_fraction)
+        self._tmp_running_var = self._add_weighted(self._tmp_running_var, unbias_var.data, alpha=self._moving_average_fraction)
+        self._iter = self._add_weighted(self._iter, 1, alpha=self._moving_average_fraction)
+
+        self.running_mean = self._tmp_running_mean / self._iter
+        self.running_var = self._tmp_running_var / self._iter

         return mean, bias_var.clamp(self.eps) ** -0.5
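
The net effect of the new statistics is easiest to see in isolation: with f = 1 - momentum, `_tmp_running_*` accumulates an exponentially weighted sum of per-batch statistics while `_iter` accumulates the matching sum of weights, so the exposed `running_*` is their ratio, a normalized weighted average. With the new momentum of 0.001 this behaves like a long-horizon (roughly 1/momentum = 1000 recent batches) average, and the division by `_iter` corrects the cold-start bias a plain EMA would have early in training. A minimal sketch with toy numbers (plain Python, not from the commit):

# Toy demonstration of the update rule above; no torch needed.
f = 1.0 - 0.001                  # _moving_average_fraction
tmp, it = 0.0, 1.0               # initial _tmp_running_mean and _iter
for x in [2.0, 4.0, 6.0]:        # toy per-batch means
    tmp = tmp * f + x            # _add_weighted(tmp, x, alpha=f)
    it = it * f + 1              # _add_weighted(it, 1, alpha=f)
    print(tmp / it)              # exposed running_mean: ~1.00, ~2.00, ~3.00

The initial buffer value counts as one pseudo-sample (hence `_iter` starting at 1), so each printed value is close to the uniform average of the initial zero mean and the batch means seen so far.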

models.py

Lines changed: 2 additions & 2 deletions
@@ -105,7 +105,7 @@ def build_encoder(self, arch='resnet50_dilated8', fc_dim=512, weights=''):
         if len(weights) > 0:
             print('Loading weights for net_encoder')
             net_encoder.load_state_dict(
-                torch.load(weights, map_location=lambda storage, loc: storage))
+                torch.load(weights, map_location=lambda storage, loc: storage), strict=False)
         return net_encoder

     def build_decoder(self, arch='psp_bilinear_deepsup',
@@ -138,7 +138,7 @@ def build_decoder(self, arch='psp_bilinear_deepsup',
         if len(weights) > 0:
             print('Loading weights for net_decoder')
             net_decoder.load_state_dict(
-                torch.load(weights, map_location=lambda storage, loc: storage))
+                torch.load(weights, map_location=lambda storage, loc: storage), strict=False)
         return net_decoder
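
These `strict=False` loads pair with the checkpointing change in train.py below: saved checkpoints deliberately drop the `_tmp_running_*` buffers that every synchronized-BN module now registers, so a strict load would fail on missing keys. The same reasoning covers the pretrained-weight loads in resnet.py that follow. A minimal sketch of the failure mode, using a stand-in module (`TinyNet` is illustrative, not from the repo):

import torch
import torch.nn as nn

class TinyNet(nn.Module):
    def __init__(self):
        super(TinyNet, self).__init__()
        self.bn = nn.BatchNorm1d(4)
        # stand-in for the extra statistics buffer added in this commit
        self.bn.register_buffer('_tmp_running_mean', torch.zeros(4))

net = TinyNet()
# mimic checkpoint(): drop the temporary buffer before saving
ckpt = {k: v for k, v in net.state_dict().items()
        if not k.endswith('_tmp_running_mean')}

net.load_state_dict(ckpt, strict=False)  # OK: the missing buffer is skipped
# net.load_state_dict(ckpt)              # RuntimeError: missing key(s)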

resnet.py

Lines changed: 2 additions & 2 deletions
@@ -195,7 +195,7 @@ def resnet50(pretrained=False, **kwargs):
     """
     model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
     if pretrained:
-        model.load_state_dict(load_url(model_urls['resnet50']))
+        model.load_state_dict(load_url(model_urls['resnet50']), strict=False)
     return model


@@ -207,7 +207,7 @@ def resnet101(pretrained=False, **kwargs):
     """
     model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
     if pretrained:
-        model.load_state_dict(load_url(model_urls['resnet101']))
+        model.load_state_dict(load_url(model_urls['resnet101']), strict=False)
     return model

 # def resnet152(pretrained=False, **kwargs):

train.py

Lines changed: 8 additions & 3 deletions
@@ -83,11 +83,14 @@ def checkpoint(nets, history, args, epoch_num):
     dict_encoder = net_encoder.state_dict()
     dict_decoder = net_decoder.state_dict()

+    dict_encoder_save = {k: v for k, v in dict_encoder.items() if not (k.endswith('_tmp_running_mean') or k.endswith('_tmp_running_var'))}
+    dict_decoder_save = {k: v for k, v in dict_decoder.items() if not (k.endswith('_tmp_running_mean') or k.endswith('_tmp_running_var'))}
+
     torch.save(history,
                '{}/history_{}'.format(args.ckpt, suffix_latest))
-    torch.save(dict_encoder,
+    torch.save(dict_encoder_save,
                '{}/encoder_{}'.format(args.ckpt, suffix_latest))
-    torch.save(dict_decoder,
+    torch.save(dict_decoder_save,
                '{}/decoder_{}'.format(args.ckpt, suffix_latest))


@@ -174,7 +177,7 @@ def main(args):
     # Main loop
     history = {'train': {'epoch': [], 'loss': [], 'acc': []}}

-    for epoch in range(1, args.num_epoch + 1):
+    for epoch in range(args.start_epoch, args.num_epoch + 1):
         train(segmentation_module, iterator_train, optimizers, history, epoch, args)

         # checkpointing
@@ -214,6 +217,8 @@ def main(args):
                         help='input batch size')
     parser.add_argument('--num_epoch', default=20, type=int,
                         help='epochs to train for')
+    parser.add_argument('--start_epoch', default=1, type=int,
+                        help='epoch to start training; useful when continuing from a checkpoint')
     parser.add_argument('--epoch_iters', default=5000, type=int,
                         help='iterations of each epoch (irrelevant to batch size)')
     parser.add_argument('--optim', default='SGD', help='optimizer')
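
The new flag lets an interrupted run continue where it left off. A minimal sketch of how `--start_epoch` feeds the training loop (illustrative, mirroring the two hunks above):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--num_epoch', default=20, type=int)
parser.add_argument('--start_epoch', default=1, type=int)

args = parser.parse_args(['--start_epoch', '11'])
for epoch in range(args.start_epoch, args.num_epoch + 1):
    print('training epoch', epoch)   # epochs 11..20, continuing a prior run

Note that resuming also requires pointing the encoder/decoder weight arguments at the latest saved checkpoint; thanks to the `strict=False` loads above, checkpoints saved without the temporary BN buffers load cleanly.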
