Skip to content

Commit db7b7bd

Browse files
authored
[Train unconditional] Unwrap model before EMA (huggingface#1469)
1 parent 6a0a312 commit db7b7bd

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

examples/unconditional_image_generation/train_unconditional.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -320,7 +320,12 @@ def transforms(examples):
320320

321321
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
322322

323-
ema_model = EMAModel(model, inv_gamma=args.ema_inv_gamma, power=args.ema_power, max_value=args.ema_max_decay)
323+
ema_model = EMAModel(
324+
accelerator.unwrap_model(model),
325+
inv_gamma=args.ema_inv_gamma,
326+
power=args.ema_power,
327+
max_value=args.ema_max_decay,
328+
)
324329

325330
# Handle the repository creation
326331
if accelerator.is_main_process:

0 commit comments

Comments
 (0)