@@ -76,20 +76,11 @@ def image_grid(imgs, rows, cols):
7676 return grid
7777
7878
79- def log_validation (controlnet , controlnet_params , tokenizer , args , rng , weight_dtype ):
80- logger .info ("Running validation... " )
79+ def log_validation (pipeline , pipeline_params , controlnet_params , tokenizer , args , rng , weight_dtype ):
80+ logger .info ("Running validation..." )
8181
82- pipeline , params = FlaxStableDiffusionControlNetPipeline .from_pretrained (
83- args .pretrained_model_name_or_path ,
84- tokenizer = tokenizer ,
85- controlnet = controlnet ,
86- safety_checker = None ,
87- dtype = weight_dtype ,
88- revision = args .revision ,
89- from_pt = args .from_pt ,
90- )
91- params = jax_utils .replicate (params )
92- params ["controlnet" ] = controlnet_params
82+ pipeline_params = pipeline_params .copy ()
83+ pipeline_params ["controlnet" ] = controlnet_params
9384
9485 num_samples = jax .device_count ()
9586 prng_seed = jax .random .split (rng , jax .device_count ())
@@ -121,7 +112,7 @@ def log_validation(controlnet, controlnet_params, tokenizer, args, rng, weight_d
121112 images = pipeline (
122113 prompt_ids = prompt_ids ,
123114 image = processed_image ,
124- params = params ,
115+ params = pipeline_params ,
125116 prng_seed = prng_seed ,
126117 num_inference_steps = 50 ,
127118 jit = True ,
@@ -176,6 +167,7 @@ def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=N
176167- text-to-image
177168- diffusers
178169- controlnet
170+ - jax-diffusers-event
179171inference: true
180172---
181173 """
@@ -800,6 +792,17 @@ def main():
800792 ]:
801793 controlnet_params [key ] = unet_params [key ]
802794
795+ pipeline , pipeline_params = FlaxStableDiffusionControlNetPipeline .from_pretrained (
796+ args .pretrained_model_name_or_path ,
797+ tokenizer = tokenizer ,
798+ controlnet = controlnet ,
799+ safety_checker = None ,
800+ dtype = weight_dtype ,
801+ revision = args .revision ,
802+ from_pt = args .from_pt ,
803+ )
804+ pipeline_params = jax_utils .replicate (pipeline_params )
805+
803806 # Optimization
804807 if args .scale_lr :
805808 args .learning_rate = args .learning_rate * total_train_batch_size
@@ -1073,7 +1076,7 @@ def l2(xs):
10731076 and global_step % args .validation_steps == 0
10741077 and jax .process_index () == 0
10751078 ):
1076- _ = log_validation (controlnet , state .params , tokenizer , args , validation_rng , weight_dtype )
1079+ _ = log_validation (pipeline , pipeline_params , state .params , tokenizer , args , validation_rng , weight_dtype )
10771080
10781081 if global_step % args .logging_steps == 0 and jax .process_index () == 0 :
10791082 if args .report_to == "wandb" :
@@ -1105,7 +1108,7 @@ def l2(xs):
11051108 if args .validation_prompt is not None :
11061109 if args .profile_validation :
11071110 jax .profiler .start_trace (args .output_dir )
1108- image_logs = log_validation (controlnet , state .params , tokenizer , args , validation_rng , weight_dtype )
1111+ image_logs = log_validation (pipeline , pipeline_params , state .params , tokenizer , args , validation_rng , weight_dtype )
11091112 if args .profile_validation :
11101113 jax .profiler .stop_trace ()
11111114 else :
0 commit comments