@@ -36,7 +36,7 @@ def _unsqueeze_ft(tensor):
 
 
 class _SynchronizedBatchNorm(_BatchNorm):
-    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
+    def __init__(self, num_features, eps=1e-5, momentum=0.001, affine=True):
         super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)
 
         self._sync_master = SyncMaster(self._data_parallel_master)
@@ -45,6 +45,14 @@ def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
         self._parallel_id = None
         self._slave_pipe = None
 
+        # custom batch norm statistics
+        self._iter = 1
+        self._moving_average_fraction = 1. - momentum
+        self.register_buffer('_tmp_running_mean', torch.zeros(self.num_features))
+        self.register_buffer('_tmp_running_var', torch.ones(self.num_features))
+        self._tmp_running_mean = self.running_mean.clone()
+        self._tmp_running_var = self.running_var.clone()
+
     def forward(self, input):
         # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
         if not (self._is_parallel and self.training):
@@ -108,6 +116,10 @@ def _data_parallel_master(self, intermediates):
 
         return outputs
 
+    def _add_weighted(self, dest, delta, alpha=1, beta=1, bias=0):
+        """Return the weighted combination `dest*alpha + delta*beta + bias`."""
+        return dest * alpha + delta * beta + bias
+
     def _compute_mean_std(self, sum_, ssum, size):
         """Compute the mean and standard-deviation with sum and square-sum. This method
         also maintains the moving average on the master device."""
@@ -117,8 +129,12 @@ def _compute_mean_std(self, sum_, ssum, size):
         unbias_var = sumvar / (size - 1)
         bias_var = sumvar / size
 
-        self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
-        self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data
+        self._tmp_running_mean = self._add_weighted(self._tmp_running_mean, mean.data, alpha=self._moving_average_fraction)
+        self._tmp_running_var = self._add_weighted(self._tmp_running_var, unbias_var.data, alpha=self._moving_average_fraction)
+        self._iter = self._add_weighted(self._iter, 1, alpha=self._moving_average_fraction)
+
+        self.running_mean = self._tmp_running_mean / self._iter
+        self.running_var = self._tmp_running_var / self._iter
 
         return mean, bias_var.clamp(self.eps) ** -0.5
 
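For reference, the patch replaces the plain exponential moving average with a normalised one: it accumulates `tmp*fraction + value` alongside a matching normaliser `_iter`, and exposes their ratio as the running statistic, so the zero/one-initialised buffers count as only a single observation instead of dominating the early estimates. A minimal standalone sketch of the same scheme follows; the function and variable names are illustrative, not part of the patch.

import torch

def running_mean_demo(batch_means, momentum=0.001):
    """Sketch of the normalised moving-average update used for running_mean above."""
    fraction = 1.0 - momentum                            # plays the role of self._moving_average_fraction
    tmp_running_mean = torch.zeros_like(batch_means[0])  # buffer initialised like running_mean
    iter_weight = 1.0                                    # plays the role of self._iter
    running_mean = tmp_running_mean / iter_weight
    for mean in batch_means:
        # dest*alpha + delta*beta + bias with alpha=fraction, beta=1, bias=0
        tmp_running_mean = tmp_running_mean * fraction + mean
        iter_weight = iter_weight * fraction + 1.0
        running_mean = tmp_running_mean / iter_weight    # normalised moving average
    return running_mean

# With momentum near 0 this tends to the plain average of the initial buffer value and all
# batch means, whereas a standard EMA with the same momentum would barely move from zero.
means = [torch.tensor([1.0, 3.0]), torch.tensor([3.0, 5.0])]
print(running_mean_demo(means))   # ~ tensor([1.33, 2.67])

With the default momentum lowered to 0.001 in the same change, the decay factor per step is very close to 1, so the running statistics behave close to a cumulative average over all batches seen.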