@@ -224,11 +224,13 @@ def __init__(
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
         # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
         # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
-        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
-        latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
+        self.latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
+        self.image_processor = VaeImageProcessor(
+            vae_scale_factor=self.vae_scale_factor * 2, vae_latent_channels=self.latent_channels
+        )
         self.mask_processor = VaeImageProcessor(
             vae_scale_factor=self.vae_scale_factor * 2,
-            vae_latent_channels=latent_channels,
+            vae_latent_channels=self.latent_channels,
             do_normalize=False,
             do_binarize=True,
             do_convert_grayscale=True,
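Note on the reordering above: `latent_channels` becomes an instance attribute and is passed to both processors because `prepare_latents` (further down in this diff) accepts either a pixel-space image or pre-encoded latents and tells them apart by channel count. A minimal sketch of that dispatch, with illustrative names:

    FLUX_LATENT_CHANNELS = 16  # the fallback this diff uses when no VAE is loaded

    def needs_vae_encode(image_or_latents) -> bool:
        # Pixel images arrive as (B, 3, H, W); Flux latents as (B, 16, h, w).
        return image_or_latents.shape[1] != FLUX_LATENT_CHANNELS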
@@ -493,10 +495,38 @@ def encode_prompt(
 
         return prompt_embeds, pooled_prompt_embeds, text_ids
 
+    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image
+    def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
+        if isinstance(generator, list):
+            image_latents = [
+                retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
+                for i in range(image.shape[0])
+            ]
+            image_latents = torch.cat(image_latents, dim=0)
+        else:
+            image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
+
+        image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
+
+        return image_latents
+
+    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps
+    def get_timesteps(self, num_inference_steps, strength, device):
+        # get the original timestep using init_timestep
+        init_timestep = min(num_inference_steps * strength, num_inference_steps)
+
+        t_start = int(max(num_inference_steps - init_timestep, 0))
+        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+        if hasattr(self.scheduler, "set_begin_index"):
+            self.scheduler.set_begin_index(t_start * self.scheduler.order)
+
+        return timesteps, num_inference_steps - t_start
+
     def check_inputs(
         self,
         prompt,
         prompt_2,
+        strength,
         height,
         width,
         prompt_embeds=None,
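The two methods added above do the img2img bookkeeping: `_encode_vae_image` encodes the init image and normalizes it with the VAE's `shift_factor` and `scaling_factor`, while `get_timesteps` skips the earliest (noisiest) part of the schedule in proportion to `strength`. A standalone sketch of the `get_timesteps` arithmetic (the function name is illustrative, not part of the diff):

    def effective_steps(num_inference_steps: int, strength: float) -> int:
        # Same arithmetic as get_timesteps above, returning only the step count.
        init_timestep = min(num_inference_steps * strength, num_inference_steps)
        t_start = int(max(num_inference_steps - init_timestep, 0))
        return num_inference_steps - t_start

    assert effective_steps(50, 0.6) == 30  # skips the first 20 steps of the schedule
    assert effective_steps(50, 1.0) == 50  # full schedule; the init image is fully noised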
@@ -507,6 +537,9 @@ def check_inputs(
         mask_image=None,
         masked_image_latents=None,
     ):
+        if strength < 0 or strength > 1:
+            raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}")
+
         if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
             logger.warning(
                 f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
@@ -624,9 +657,11 @@ def disable_vae_tiling(self):
         """
         self.vae.disable_tiling()
 
-    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_latents
+    # Copied from diffusers.pipelines.flux.pipeline_flux_img2img.FluxImg2ImgPipeline.prepare_latents
     def prepare_latents(
         self,
+        image,
+        timestep,
         batch_size,
         num_channels_latents,
         height,
@@ -636,28 +671,41 @@ def prepare_latents(
         generator,
         latents=None,
     ):
+        if isinstance(generator, list) and len(generator) != batch_size:
+            raise ValueError(
+                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+            )
+
         # VAE applies 8x compression on images but we must also account for packing which requires
         # latent height and width to be divisible by 2.
         height = 2 * (int(height) // (self.vae_scale_factor * 2))
         width = 2 * (int(width) // (self.vae_scale_factor * 2))
-
         shape = (batch_size, num_channels_latents, height, width)
+        latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
 
         if latents is not None:
-            latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
             return latents.to(device=device, dtype=dtype), latent_image_ids
 
-        if isinstance(generator, list) and len(generator) != batch_size:
+        image = image.to(device=device, dtype=dtype)
+        if image.shape[1] != self.latent_channels:
+            image_latents = self._encode_vae_image(image=image, generator=generator)
+        else:
+            image_latents = image
+        if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
+            # expand init_latents for batch_size
+            additional_image_per_prompt = batch_size // image_latents.shape[0]
+            image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
+        elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
             raise ValueError(
-                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
-                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+                f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
             )
+        else:
+            image_latents = torch.cat([image_latents], dim=0)
 
-        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+        noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+        latents = self.scheduler.scale_noise(image_latents, timestep, noise)
         latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
-
-        latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
-
         return latents, latent_image_ids
 
     @property
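The behavioral core of this hunk is in the last few added lines: rather than sampling pure noise, the pipeline now noises the encoded init image to the starting timestep with `scheduler.scale_noise` before packing. For the flow-matching scheduler Flux uses, that call amounts to a linear interpolation between sample and noise; a rough sketch of the equivalent math (assuming `FlowMatchEulerDiscreteScheduler` semantics):

    import torch

    def noised_start(image_latents: torch.Tensor, sigma: float, noise: torch.Tensor) -> torch.Tensor:
        # Flow-matching forward process: sigma near 0 keeps the image mostly
        # intact, while sigma = 1.0 (strength = 1.0) yields pure noise, so the
        # init image has no influence -- matching the strength docstring below.
        return (1.0 - sigma) * image_latents + sigma * noise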
@@ -687,6 +735,7 @@ def __call__(
         masked_image_latents: Optional[torch.FloatTensor] = None,
         height: Optional[int] = None,
         width: Optional[int] = None,
+        strength: float = 1.0,
         num_inference_steps: int = 50,
         sigmas: Optional[List[float]] = None,
         guidance_scale: float = 30.0,
@@ -731,6 +780,12 @@ def __call__(
                 The height in pixels of the generated image. This is set to 1024 by default for the best results.
             width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                 The width in pixels of the generated image. This is set to 1024 by default for the best results.
+            strength (`float`, *optional*, defaults to 1.0):
+                Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
+                starting point and more noise is added the higher the `strength`. The number of denoising steps
+                depends on the amount of noise initially added. When `strength` is 1, added noise is maximum and the
+                denoising process runs for the full number of iterations specified in `num_inference_steps`. A value
+                of 1 essentially ignores `image`.
             num_inference_steps (`int`, *optional*, defaults to 50):
                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                 expense of slower inference.
@@ -794,6 +849,7 @@ def __call__(
         self.check_inputs(
             prompt,
             prompt_2,
+            strength,
             height,
             width,
             prompt_embeds=prompt_embeds,
@@ -809,6 +865,9 @@ def __call__(
         self._joint_attention_kwargs = joint_attention_kwargs
         self._interrupt = False
 
+        init_image = self.image_processor.preprocess(image, height=height, width=width)
+        init_image = init_image.to(dtype=torch.float32)
+
         # 2. Define call parameters
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
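`init_image` is preprocessed once, up front, because it now feeds both `prepare_latents` (as the img2img starting point) and the mask branch later on; the float32 cast guards the VAE encode against reduced-precision inputs. For reference, `VaeImageProcessor.preprocess` with its defaults resizes and rescales pixels to [-1, 1] (`pipe` below is a hypothetical pipeline instance):

    # Illustrative shapes/ranges only (do_normalize=True is the processor default):
    # PIL.Image of any size  ->  torch.Tensor of shape (1, 3, height, width) in [-1, 1]
    init_image = pipe.image_processor.preprocess(image, height=1024, width=1024)
    assert init_image.min() >= -1.0 and init_image.max() <= 1.0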
@@ -838,9 +897,37 @@ def __call__(
             lora_scale=lora_scale,
         )
 
-        # 4. Prepare latent variables
+        # 4. Prepare timesteps
+        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
+        image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)
+        mu = calculate_shift(
+            image_seq_len,
+            self.scheduler.config.get("base_image_seq_len", 256),
+            self.scheduler.config.get("max_image_seq_len", 4096),
+            self.scheduler.config.get("base_shift", 0.5),
+            self.scheduler.config.get("max_shift", 1.15),
+        )
+        timesteps, num_inference_steps = retrieve_timesteps(
+            self.scheduler,
+            num_inference_steps,
+            device,
+            sigmas=sigmas,
+            mu=mu,
+        )
+        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
+
+        if num_inference_steps < 1:
+            raise ValueError(
+                f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
+                f" steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
+            )
+        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+
+        # 5. Prepare latent variables
         num_channels_latents = self.vae.config.latent_channels
         latents, latent_image_ids = self.prepare_latents(
+            init_image,
+            latent_timestep,
             batch_size * num_images_per_prompt,
             num_channels_latents,
             height,
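Two things are worth noting in this hunk. First, timestep preparation moves ahead of latent preparation because `latent_timestep` (the first retained timestep) is now an input to `prepare_latents`. Second, `image_seq_len` can no longer be read off `latents.shape[1]`, so it is recomputed from `height` and `width`; the two forms agree:

    # For the 1024x1024 default with vae_scale_factor = 8 (assumed here):
    height = width = 1024
    vae_scale_factor = 8
    image_seq_len = (height // vae_scale_factor // 2) * (width // vae_scale_factor // 2)
    assert image_seq_len == 4096  # 64 * 64 packed 2x2-patch tokens, same as latents.shape[1]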
@@ -851,17 +938,16 @@ def __call__(
             latents,
         )
 
-        # 5. Prepare mask and masked image latents
+        # 6. Prepare mask and masked image latents
         if masked_image_latents is not None:
             masked_image_latents = masked_image_latents.to(latents.device)
         else:
-            image = self.image_processor.preprocess(image, height=height, width=width)
             mask_image = self.mask_processor.preprocess(mask_image, height=height, width=width)
 
-            masked_image = image * (1 - mask_image)
+            masked_image = init_image * (1 - mask_image)
             masked_image = masked_image.to(device=device, dtype=prompt_embeds.dtype)
 
-            height, width = image.shape[-2:]
+            height, width = init_image.shape[-2:]
             mask, masked_image_latents = self.prepare_mask_latents(
                 mask_image,
                 masked_image,
@@ -876,23 +962,6 @@ def __call__(
             )
             masked_image_latents = torch.cat((masked_image_latents, mask), dim=-1)
 
-        # 6. Prepare timesteps
-        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
-        image_seq_len = latents.shape[1]
-        mu = calculate_shift(
-            image_seq_len,
-            self.scheduler.config.get("base_image_seq_len", 256),
-            self.scheduler.config.get("max_image_seq_len", 4096),
-            self.scheduler.config.get("base_shift", 0.5),
-            self.scheduler.config.get("max_shift", 1.15),
-        )
-        timesteps, num_inference_steps = retrieve_timesteps(
-            self.scheduler,
-            num_inference_steps,
-            device,
-            sigmas=sigmas,
-            mu=mu,
-        )
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
         self._num_timesteps = len(timesteps)
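Taken together, the diff converts the fixed pure-noise start into a strength-controlled one. A hypothetical end-to-end call after this change; the pipeline class and checkpoint are assumptions on my part (the mask handling and the `guidance_scale` default of 30.0 suggest the Flux Fill pipeline), and the file names are placeholders:

    import torch
    from diffusers import FluxFillPipeline
    from diffusers.utils import load_image

    pipe = FluxFillPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16
    ).to("cuda")

    result = pipe(
        prompt="a red leather sofa",
        image=load_image("room.png"),
        mask_image=load_image("mask.png"),
        strength=0.8,            # new knob: 50 * 0.8 = 40 denoising steps actually run
        num_inference_steps=50,
        guidance_scale=30.0,
    ).images[0]
    result.save("inpainted.png")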