 import argparse
-import copy
 import logging
 import math
 import os
 import random
 from pathlib import Path
-from typing import Optional
+from typing import Iterable, Optional
 
 import numpy as np
 import torch
@@ -234,25 +233,17 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token:
 }
 
 
+# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
 class EMAModel:
     """
     Exponential Moving Average of models weights
     """
 
-    def __init__(
-        self,
-        model,
-        decay=0.9999,
-        device=None,
-    ):
-        self.averaged_model = copy.deepcopy(model).eval()
-        self.averaged_model.requires_grad_(False)
+    def __init__(self, parameters: Iterable[torch.nn.Parameter], decay=0.9999):
+        parameters = list(parameters)
+        self.shadow_params = [p.clone().detach() for p in parameters]
 
         self.decay = decay
-
-        if device is not None:
-            self.averaged_model = self.averaged_model.to(device=device)
-
         self.optimization_step = 0
 
     def get_decay(self, optimization_step):
@@ -263,34 +254,47 @@ def get_decay(self, optimization_step):
         return 1 - min(self.decay, value)
 
     @torch.no_grad()
-    def step(self, new_model):
-        ema_state_dict = self.averaged_model.state_dict()
+    def step(self, parameters):
+        parameters = list(parameters)
 
         self.optimization_step += 1
         self.decay = self.get_decay(self.optimization_step)
 
-        for key, param in new_model.named_parameters():
-            if isinstance(param, dict):
-                continue
-            try:
-                ema_param = ema_state_dict[key]
-            except KeyError:
-                ema_param = param.float().clone() if param.ndim == 1 else copy.deepcopy(param)
-                ema_state_dict[key] = ema_param
-
-            param = param.clone().detach().to(ema_param.dtype).to(ema_param.device)
-
+        for s_param, param in zip(self.shadow_params, parameters):
             if param.requires_grad:
-                ema_state_dict[key].sub_(self.decay * (ema_param - param))
+                tmp = self.decay * (s_param - param)
+                s_param.sub_(tmp)
             else:
-                ema_state_dict[key].copy_(param)
-
-        for key, param in new_model.named_buffers():
-            ema_state_dict[key] = param
+                s_param.copy_(param)
 
-        self.averaged_model.load_state_dict(ema_state_dict, strict=False)
         torch.cuda.empty_cache()
 
+    def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
+        """
+        Copy current averaged parameters into given collection of parameters.
+
+        Args:
+            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+                updated with the stored moving averages. If `None`, the
+                parameters with which this `ExponentialMovingAverage` was
+                initialized will be used.
+        """
+        parameters = list(parameters)
+        for s_param, param in zip(self.shadow_params, parameters):
+            param.data.copy_(s_param.data)
+
+    def to(self, device=None, dtype=None) -> None:
+        r"""Move internal buffers of the ExponentialMovingAverage to `device`.
+
+        Args:
+            device: like `device` argument to `torch.Tensor.to`
+        """
+        # .to() on the tensors handles None correctly
+        self.shadow_params = [
+            p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device)
+            for p in self.shadow_params
+        ]
+
 
 def main():
     args = parse_args()
@@ -336,9 +340,6 @@ def main():
     vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
     unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
 
-    if args.use_ema:
-        ema_unet = EMAModel(unet)
-
     # Freeze vae and text_encoder
     vae.requires_grad_(False)
     text_encoder.requires_grad_(False)
@@ -510,8 +511,9 @@ def collate_fn(examples):
     text_encoder.to(accelerator.device, dtype=weight_dtype)
     vae.to(accelerator.device, dtype=weight_dtype)
 
-    # Move the ema_unet to gpu.
-    ema_unet.averaged_model.to(accelerator.device)
+    # Create EMA for the unet.
+    if args.use_ema:
+        ema_unet = EMAModel(unet.parameters())
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -583,7 +585,7 @@ def collate_fn(examples):
            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                if args.use_ema:
-                    ema_unet.step(unet)
+                    ema_unet.step(unet.parameters())
                progress_bar.update(1)
                global_step += 1
                accelerator.log({"train_loss": train_loss}, step=global_step)
@@ -598,10 +600,14 @@ def collate_fn(examples):
     # Create the pipeline using the trained modules and save it.
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
+        unet = accelerator.unwrap_model(unet)
+        if args.use_ema:
+            ema_unet.copy_to(unet.parameters())
+
         pipeline = StableDiffusionPipeline(
             text_encoder=text_encoder,
             vae=vae,
-            unet=accelerator.unwrap_model(ema_unet.averaged_model if args.use_ema else unet),
+            unet=unet,
             tokenizer=tokenizer,
             scheduler=PNDMScheduler(
                 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
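For reference, a minimal sketch of how the parameter-based EMAModel introduced by this diff could be exercised outside the training script. The toy linear model, SGD loop, and device choice below are illustrative assumptions; only the EMAModel API (__init__, step, copy_to, to) comes from the code above:

import torch

# Toy model and optimizer, purely for illustration.
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

# Shadow copies of the current parameters are taken at construction time.
ema = EMAModel(model.parameters(), decay=0.9999)

for _ in range(10):
    loss = model(torch.randn(8, 4)).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # Update the shadow parameters after every optimizer step,
    # mirroring the ema_unet.step(unet.parameters()) call in the training loop.
    ema.step(model.parameters())

# Optionally move the shadow parameters, then overwrite the live weights
# with the averaged ones before saving or evaluation, as done for the pipeline.
ema.to(device="cpu")
ema.copy_to(model.parameters())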