Skip to content

Commit a6e2c1f

Browse files
authored
Fix ema decay (huggingface#1868)
* Fix ema decay and clarify nomenclature. * Rename var.
1 parent b28ab30 commit a6e2c1f

File tree

1 file changed

+5
-10
lines changed

1 file changed

+5
-10
lines changed

examples/text_to_image/train_text_to_image.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -278,24 +278,19 @@ def __init__(self, parameters: Iterable[torch.nn.Parameter], decay=0.9999):
278278
self.decay = decay
279279
self.optimization_step = 0
280280

281-
def get_decay(self, optimization_step):
282-
"""
283-
Compute the decay factor for the exponential moving average.
284-
"""
285-
value = (1 + optimization_step) / (10 + optimization_step)
286-
return 1 - min(self.decay, value)
287-
288281
@torch.no_grad()
def step(self, parameters):
    """Advance the EMA by one step, pulling the shadow params toward *parameters*.

    Parameters with ``requires_grad`` are blended via the exponential moving
    average; frozen parameters are copied through unchanged.
    """
    current_params = list(parameters)

    self.optimization_step += 1

    # Decay warms up over the early steps — (1 + t) / (10 + t) — and is
    # capped at the configured self.decay once warmup has passed.
    t = self.optimization_step
    effective_decay = min(self.decay, (1 + t) / (10 + t))
    one_minus_decay = 1 - effective_decay

    for shadow, param in zip(self.shadow_params, current_params):
        if not param.requires_grad:
            # Frozen parameter: track it exactly rather than averaging.
            shadow.copy_(param)
        else:
            # shadow <- shadow - (1 - decay) * (shadow - param)
            shadow.sub_(one_minus_decay * (shadow - param))
301296

0 commit comments

Comments (0)