 from colossalai.core import global_context as gpc
 from colossalai.logging import disable_existing_loggers, get_dist_logger
 from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
-from colossalai.nn.parallel.utils import convert_to_torch_module
-from colossalai.tensor import ProcessGroup
+from colossalai.nn.parallel.utils import get_static_torch_model
 from colossalai.utils import get_current_device
 from colossalai.utils.model.colo_init_context import ColoInitContext
 from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
@@ -356,26 +355,17 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token:


 # Gemini + ZeRO DDP
-def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy: str = "auto"):
+def gemini_zero_dpp(model: torch.nn.Module, placememt_policy: str = "auto"):
     from colossalai.nn.parallel import GeminiDDP

     model = GeminiDDP(
-        model, device=get_current_device(), placement_policy=placememt_policy, pin_memory=True, search_range_mb=32
+        model, device=get_current_device(), placement_policy=placememt_policy, pin_memory=True, search_range_mb=64
     )
     return model


 def main(args):
-    # config for colossalai
-
-    config = {
-        "BATCH": args.train_batch_size,
-        "gradient_accumulation_steps": args.gradient_accumulation_steps,
-        "clip_grad_norm": args.max_grad_norm,
-    }
-
-    colossalai.launch_from_torch(config=config)
-    pg = ProcessGroup()
+    colossalai.launch_from_torch(config={})

     if args.seed is not None:
         gpc.set_seed(args.seed)
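This hunk drops the per-run training settings (batch size, gradient accumulation, clip_grad_norm) from the colossalai.launch_from_torch config and removes the explicit ProcessGroup, so the launcher now takes an empty config and GeminiDDP wraps the bare model. A minimal sketch of the reworked wrapper in isolation follows; the toy torch.nn.Linear model and the "auto" placement value are illustrative assumptions, while the launch call and GeminiDDP keyword arguments are taken from the hunk above.

# Minimal sketch, assuming a torchrun launch and a toy stand-in model.
import torch
import colossalai
from colossalai.nn.parallel import GeminiDDP
from colossalai.utils import get_current_device

colossalai.launch_from_torch(config={})  # empty config: clipping is configured on the optimizer instead

toy_model = torch.nn.Linear(16, 16)  # stand-in for the UNet
toy_model = GeminiDDP(
    toy_model, device=get_current_device(), placement_policy="auto", pin_memory=True, search_range_mb=64
)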
@@ -472,7 +462,7 @@ def main(args):
     )

     logger.info(f"Loading UNet2DConditionModel from {args.pretrained_model_name_or_path}", ranks=[0])
-    with ColoInitContext():
+    with ColoInitContext(device=get_current_device()):
         unet = UNet2DConditionModel.from_pretrained(
             args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, low_cpu_mem_usage=False
         )
@@ -484,12 +474,19 @@ def main(args):
         unet.enable_gradient_checkpointing()

     if args.scale_lr:
-        args.learning_rate = args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * 2
+        args.learning_rate = (
+            args.learning_rate
+            * args.gradient_accumulation_steps
+            * args.train_batch_size
+            * gpc.get_world_size(ParallelMode.DATA)
+        )

-    unet = gemini_zero_dpp(unet, pg, args.placement)
+    unet = gemini_zero_dpp(unet, args.placement)

     # config optimizer for colossalai zero
-    optimizer = GeminiAdamOptimizer(unet, lr=args.learning_rate, initial_scale=2**5)
+    optimizer = GeminiAdamOptimizer(
+        unet, lr=args.learning_rate, initial_scale=2**5, clipping_norm=args.max_grad_norm
+    )

     # load noise_scheduler
     noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
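With scale_lr enabled, the learning rate is now multiplied by the data-parallel world size from gpc.get_world_size(ParallelMode.DATA) instead of a hard-coded factor of 2, and gradient clipping moves from the launch config into GeminiAdamOptimizer via clipping_norm. A small worked example of the scaling rule, with illustrative numbers that are not taken from the commit:

# Worked example of the new scale_lr rule (all values below are assumptions for illustration).
base_lr = 5e-6
gradient_accumulation_steps = 1
train_batch_size = 2
data_parallel_world_size = 4  # e.g. 4 GPUs; previously this factor was hard-coded to 2

scaled_lr = base_lr * gradient_accumulation_steps * train_batch_size * data_parallel_world_size
print(scaled_lr)  # 4e-05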
@@ -657,10 +654,11 @@ def collate_fn(examples):

             if global_step % args.save_steps == 0:
                 torch.cuda.synchronize()
+                torch_unet = get_static_torch_model(unet)
                 if gpc.get_local_rank(ParallelMode.DATA) == 0:
                     pipeline = DiffusionPipeline.from_pretrained(
                         args.pretrained_model_name_or_path,
-                        unet=convert_to_torch_module(unet),
+                        unet=torch_unet,
                         revision=args.revision,
                     )
                     save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
@@ -670,7 +668,7 @@ def collate_fn(examples):
                 break

     torch.cuda.synchronize()
-    unet = convert_to_torch_module(unet)
+    unet = get_static_torch_model(unet)

     if gpc.get_local_rank(ParallelMode.DATA) == 0:
         pipeline = DiffusionPipeline.from_pretrained(
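Both the mid-training checkpoint and the final export now go through get_static_torch_model instead of convert_to_torch_module, and the gather is placed before the rank-0 check, which suggests it is a collective call that every data-parallel rank must reach. A condensed sketch of that save pattern, assuming unet, args, and save_path are defined as in the hunks above:

# Sketch of the save path under the assumptions stated above; not the full script.
torch.cuda.synchronize()
torch_unet = get_static_torch_model(unet)  # run on every rank: rebuilds a plain torch UNet from the Gemini-managed one
if gpc.get_local_rank(ParallelMode.DATA) == 0:
    pipeline = DiffusionPipeline.from_pretrained(
        args.pretrained_model_name_or_path,
        unet=torch_unet,
        revision=args.revision,
    )
    pipeline.save_pretrained(save_path)  # only data-parallel rank 0 writes to disk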