@@ -923,44 +923,47 @@ def main(args):
923923 if global_step >= args .max_train_steps :
924924 break
925925
926- if args .validation_prompt is not None and epoch % args .validation_epochs == 0 :
927- logger .info (
928- f"Running validation... \n Generating { args .num_validation_images } images with prompt:"
929- f" { args .validation_prompt } ."
930- )
931- # create pipeline
932- pipeline = DiffusionPipeline .from_pretrained (
933- args .pretrained_model_name_or_path ,
934- unet = accelerator .unwrap_model (unet ),
935- text_encoder = accelerator .unwrap_model (text_encoder ),
936- revision = args .revision ,
937- torch_dtype = weight_dtype ,
938- )
939- pipeline .scheduler = DPMSolverMultistepScheduler .from_config (pipeline .scheduler .config )
940- pipeline = pipeline .to (accelerator .device )
941- pipeline .set_progress_bar_config (disable = True )
942-
943- # run inference
944- generator = torch .Generator (device = accelerator .device ).manual_seed (args .seed )
945- prompt = args .num_validation_images * [args .validation_prompt ]
946- images = pipeline (prompt , num_inference_steps = 25 , generator = generator ).images
947-
948- for tracker in accelerator .trackers :
949- if tracker .name == "tensorboard" :
950- np_images = np .stack ([np .asarray (img ) for img in images ])
951- tracker .writer .add_images ("validation" , np_images , epoch , dataformats = "NHWC" )
952- if tracker .name == "wandb" :
953- tracker .log (
954- {
955- "validation" : [
956- wandb .Image (image , caption = f"{ i } : { args .validation_prompt } " )
957- for i , image in enumerate (images )
958- ]
959- }
960- )
961-
962- del pipeline
963- torch .cuda .empty_cache ()
926+ if accelerator .is_main_process :
927+ if args .validation_prompt is not None and epoch % args .validation_epochs == 0 :
928+ logger .info (
929+ f"Running validation... \n Generating { args .num_validation_images } images with prompt:"
930+ f" { args .validation_prompt } ."
931+ )
932+ # create pipeline
933+ pipeline = DiffusionPipeline .from_pretrained (
934+ args .pretrained_model_name_or_path ,
935+ unet = accelerator .unwrap_model (unet ),
936+ text_encoder = accelerator .unwrap_model (text_encoder ),
937+ revision = args .revision ,
938+ torch_dtype = weight_dtype ,
939+ )
940+ pipeline .scheduler = DPMSolverMultistepScheduler .from_config (pipeline .scheduler .config )
941+ pipeline = pipeline .to (accelerator .device )
942+ pipeline .set_progress_bar_config (disable = True )
943+
944+ # run inference
945+ generator = torch .Generator (device = accelerator .device ).manual_seed (args .seed )
946+ images = [
947+ pipeline (args .validation_prompt , num_inference_steps = 25 , generator = generator ).images [0 ]
948+ for _ in range (args .num_validation_images )
949+ ]
950+
951+ for tracker in accelerator .trackers :
952+ if tracker .name == "tensorboard" :
953+ np_images = np .stack ([np .asarray (img ) for img in images ])
954+ tracker .writer .add_images ("validation" , np_images , epoch , dataformats = "NHWC" )
955+ if tracker .name == "wandb" :
956+ tracker .log (
957+ {
958+ "validation" : [
959+ wandb .Image (image , caption = f"{ i } : { args .validation_prompt } " )
960+ for i , image in enumerate (images )
961+ ]
962+ }
963+ )
964+
965+ del pipeline
966+ torch .cuda .empty_cache ()
964967
965968 # Save the lora layers
966969 accelerator .wait_for_everyone ()
@@ -982,8 +985,10 @@ def main(args):
982985 # run inference
983986 if args .validation_prompt and args .num_validation_images > 0 :
984987 generator = torch .Generator (device = accelerator .device ).manual_seed (args .seed ) if args .seed else None
985- prompt = args .num_validation_images * [args .validation_prompt ]
986- images = pipeline (prompt , num_inference_steps = 25 , generator = generator ).images
988+ images = [
989+ pipeline (args .validation_prompt , num_inference_steps = 25 , generator = generator ).images [0 ]
990+ for _ in range (args .num_validation_images )
991+ ]
987992
988993 for tracker in accelerator .trackers :
989994 if tracker .name == "tensorboard" :
0 commit comments