 from diffusers.optimization import get_scheduler
 from diffusers.utils import check_min_version, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.torch_utils import randn_tensor


 if is_wandb_available():
@@ -114,16 +115,17 @@ def log_validation(

     pipeline_args = {}

-    if text_encoder is not None:
-        pipeline_args["text_encoder"] = accelerator.unwrap_model(text_encoder)
-
     if vae is not None:
         pipeline_args["vae"] = vae

+    if text_encoder is not None:
+        text_encoder = accelerator.unwrap_model(text_encoder)
+
     # create pipeline (note: unet and vae are loaded again in float32)
     pipeline = DiffusionPipeline.from_pretrained(
         args.pretrained_model_name_or_path,
         tokenizer=tokenizer,
+        text_encoder=text_encoder,
         unet=accelerator.unwrap_model(unet),
         revision=args.revision,
         torch_dtype=weight_dtype,
@@ -156,10 +158,16 @@ def log_validation(
     # run inference
     generator = None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed)
     images = []
-    for _ in range(args.num_validation_images):
-        with torch.autocast("cuda"):
-            image = pipeline(**pipeline_args, num_inference_steps=25, generator=generator).images[0]
-        images.append(image)
+    if args.validation_images is None:
+        for _ in range(args.num_validation_images):
+            with torch.autocast("cuda"):
+                image = pipeline(**pipeline_args, num_inference_steps=25, generator=generator).images[0]
+            images.append(image)
+    else:
+        for image in args.validation_images:
+            image = Image.open(image)
+            image = pipeline(**pipeline_args, image=image, generator=generator).images[0]
+            images.append(image)

     for tracker in accelerator.trackers:
         if tracker.name == "tensorboard":
@@ -525,6 +533,19 @@ def parse_args(input_args=None):
     parser.add_argument(
         "--skip_save_text_encoder", action="store_true", required=False, help="Set to not save text encoder"
     )
+    parser.add_argument(
+        "--validation_images",
+        required=False,
+        default=None,
+        nargs="+",
+        help="Optional set of images to use for validation. Used when the target pipeline takes an initial image as input such as when training image variation or superresolution.",
+    )
+    parser.add_argument(
+        "--class_labels_conditioning",
+        required=False,
+        default=None,
+        help="The optional `class_label` conditioning to pass to the unet, available values are `timesteps`.",
+    )

     if input_args is not None:
         args = parser.parse_args(input_args)
@@ -1169,7 +1190,7 @@ def compute_text_embeddings(prompt):
                     )
                 else:
                     noise = torch.randn_like(model_input)
-                bsz = model_input.shape[0]
+                bsz, channels, height, width = model_input.shape
                 # Sample a random timestep for each image
                 timesteps = torch.randint(
                     0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
@@ -1191,8 +1212,24 @@ def compute_text_embeddings(prompt):
                         text_encoder_use_attention_mask=args.text_encoder_use_attention_mask,
                     )

+                if unet.config.in_channels > channels:
+                    needed_additional_channels = unet.config.in_channels - channels
+                    additional_latents = randn_tensor(
+                        (bsz, needed_additional_channels, height, width),
+                        device=noisy_model_input.device,
+                        dtype=noisy_model_input.dtype,
+                    )
+                    noisy_model_input = torch.cat([additional_latents, noisy_model_input], dim=1)
+
+                if args.class_labels_conditioning == "timesteps":
+                    class_labels = timesteps
+                else:
+                    class_labels = None
+
                 # Predict the noise residual
-                model_pred = unet(noisy_model_input, timesteps, encoder_hidden_states).sample
+                model_pred = unet(
+                    noisy_model_input, timesteps, encoder_hidden_states, class_labels=class_labels
+                ).sample

                 if model_pred.shape[1] == 6:
                     model_pred, _ = torch.chunk(model_pred, 2, dim=1)
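
Note: the last hunk pads the noisy latents with extra noise channels whenever the unet expects more input channels than the latents provide (e.g. image-variation or superresolution models), and optionally passes the timesteps as class_labels. Below is a minimal standalone sketch of that padding step, not part of the diff; the helper name pad_latents_for_unet is illustrative.

import torch
from diffusers.utils.torch_utils import randn_tensor


def pad_latents_for_unet(noisy_model_input, unet, generator=None):
    # Illustrative helper, not part of the training script.
    bsz, channels, height, width = noisy_model_input.shape
    if unet.config.in_channels > channels:
        # Fill the remaining input channels the unet expects with random noise,
        # matching the device and dtype of the existing latents.
        needed_additional_channels = unet.config.in_channels - channels
        additional_latents = randn_tensor(
            (bsz, needed_additional_channels, height, width),
            generator=generator,
            device=noisy_model_input.device,
            dtype=noisy_model_input.dtype,
        )
        noisy_model_input = torch.cat([additional_latents, noisy_model_input], dim=1)
    return noisy_model_input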