@@ -401,13 +401,13 @@ def get_timesteps(self, num_inference_steps, strength, device):
401401 return timesteps , num_inference_steps - t_start
402402
403403 def prepare_latents (self , init_image , timestep , batch_size , num_images_per_prompt , dtype , device ,
404- generator = None , init_image_latents = None , noise = None ):
404+ generator = None , image_latents = None , noise = None ):
405405 init_image = init_image .to (device = device , dtype = dtype )
406406 init_latent_dist = self .vae .encode (init_image ).latent_dist
407- if init_image_latents == None :
407+ if image_latents is None :
408408 init_latents = init_latent_dist .sample (generator = generator )
409409 else :
410- init_latents = init_latent_dist .sample_from_sample (init_image_latents )
410+ init_latents = init_latent_dist .sample_from_sample (image_latents )
411411 init_latents = 0.18215 * init_latents
412412
413413 if batch_size > init_latents .shape [0 ] and batch_size % init_latents .shape [0 ] == 0 :
@@ -462,7 +462,7 @@ def __call__(
462462 num_images_per_prompt : Optional [int ] = 1 ,
463463 eta : Optional [float ] = 0.0 ,
464464 generator : Optional [torch .Generator ] = None ,
465- init_image_latents : Optional [torch .FloatTensor ] = None ,
465+ image_latents : Optional [torch .FloatTensor ] = None ,
466466 noise : Optional [torch .FloatTensor ] = None ,
467467 output_type : Optional [str ] = "pil" ,
468468 return_dict : bool = True ,
@@ -560,8 +560,8 @@ def __call__(
560560
561561 # 6. Prepare latent variables
562562 latents = self .prepare_latents (
563- init_image , latent_timestep , batch_size , num_images_per_prompt , text_embeddings .dtype , device ,
564- generator , init_image_latents , noise
563+ image , latent_timestep , batch_size , num_images_per_prompt , text_embeddings .dtype , device ,
564+ generator , image_latents , noise
565565 )
566566
567567 # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
0 commit comments