Skip to content

Commit 8e46d97

Browse files
Add missing restore() EMA call in train SDXL script (huggingface#7599)
* Restore unet params back to normal from EMA when validation call is finished * empty commit --------- Co-authored-by: Sayak Paul <[email protected]>
1 parent 7e808e7 commit 8e46d97

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

examples/text_to_image/train_text_to_image_sdxl.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1252,6 +1252,10 @@ def compute_time_ids(original_size, crops_coords_top_left):
12521252
del pipeline
12531253
torch.cuda.empty_cache()
12541254

1255+
if args.use_ema:
1256+
# Switch back to the original UNet parameters.
1257+
ema_unet.restore(unet.parameters())
1258+
12551259
accelerator.wait_for_everyone()
12561260
if accelerator.is_main_process:
12571261
unet = unwrap_model(unet)

0 commit comments

Comments
 (0)