 )
 from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import randn_tensor
+from ..free_init_utils import FreeInitMixin
 from ..pipeline_utils import DiffusionPipeline
 from .pipeline_output import AnimateDiffPipelineOutput
 
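For context: FreeInitMixin is what supplies the free_init_enabled flag, the _free_init_num_iters counter, and the _apply_free_init() hook used in the denoising loop further down. A minimal sketch of toggling it from user code, assuming the mixin's enable_free_init()/disable_free_init() methods and an already-constructed pipeline instance (pipe is hypothetical here):

# sketch: the FreeInitMixin surface this file relies on
pipe.enable_free_init(num_iters=3)   # sets free_init_enabled=True, _free_init_num_iters=3
assert pipe.free_init_enabled
pipe.disable_free_init()             # back to single-pass sampling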
@@ -163,7 +164,9 @@ def retrieve_timesteps(
     return timesteps, num_inference_steps
 
 
-class AnimateDiffVideoToVideoPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin):
+class AnimateDiffVideoToVideoPipeline(
+    DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin, FreeInitMixin
+):
     r"""
     Pipeline for video-to-video generation.
 
@@ -193,7 +196,7 @@ class AnimateDiffVideoToVideoPipeline(DiffusionPipeline, TextualInversionLoaderM
193196 """
194197
195198 model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
196- _optional_components = ["feature_extractor" , "image_encoder" ]
199+ _optional_components = ["feature_extractor" , "image_encoder" , "motion_adapter" ]
197200 _callback_tensor_inputs = ["latents" , "prompt_embeds" , "negative_prompt_embeds" ]
198201
199202 def __init__ (
@@ -215,7 +218,8 @@ def __init__(
         image_encoder: CLIPVisionModelWithProjection = None,
     ):
         super().__init__()
-        unet = UNetMotionModel.from_unet2d(unet, motion_adapter)
+        if isinstance(unet, UNet2DConditionModel):
+            unet = UNetMotionModel.from_unet2d(unet, motion_adapter)
 
         self.register_modules(
             vae=vae,
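Outside the diff, a hedged note on why the isinstance() guard matters: save_pretrained() serializes the assembled UNetMotionModel, so a reloaded pipeline hands __init__ a motion UNet rather than a 2D UNet plus adapter. The snippet below reproduces the conversion the guard now performs only when needed (checkpoint IDs are illustrative):

from diffusers import MotionAdapter, UNet2DConditionModel, UNetMotionModel

unet2d = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2")

# what __init__ does when handed a 2D UNet:
motion_unet = UNetMotionModel.from_unet2d(unet2d, adapter)
# a UNetMotionModel (e.g. from a reloaded pipeline) now bypasses this conversion,
# which together with the optional "motion_adapter" entry above is what lets
# save_pretrained()/from_pretrained() round-trips work.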
@@ -584,12 +588,12 @@ def check_inputs(
         if video is not None and latents is not None:
             raise ValueError("Only one of `video` or `latents` should be provided")
 
-    def get_timesteps(self, num_inference_steps, strength, device):
+    def get_timesteps(self, num_inference_steps, timesteps, strength, device):
         # get the original timestep using init_timestep
         init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
 
         t_start = max(num_inference_steps - init_timestep, 0)
-        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+        timesteps = timesteps[t_start * self.scheduler.order :]
 
         return timesteps, num_inference_steps - t_start
 
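The change above makes get_timesteps() truncate whatever timestep list it is given (custom timesteps from retrieve_timesteps(), or the re-derived schedule from FreeInit) instead of always re-reading self.scheduler.timesteps. A worked example of the arithmetic, assuming a first-order scheduler:

num_inference_steps, strength, order = 50, 0.6, 1

init_timestep = min(int(num_inference_steps * strength), num_inference_steps)  # 30
t_start = max(num_inference_steps - init_timestep, 0)                          # 20

# keep the last 30 of the 50 timesteps, i.e. timesteps[t_start * order:],
# and report 50 - 20 = 30 remaining denoising steps
print(t_start * order, num_inference_steps - t_start)  # 20 30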
@@ -876,9 +880,8 @@ def __call__(
 
         # 4. Prepare timesteps
         timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
-        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
+        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, timesteps, strength, device)
         latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt)
-        self._num_timesteps = len(timesteps)
 
         # 5. Prepare latent variables
         num_channels_latents = self.unet.config.in_channels
@@ -901,42 +904,55 @@ def __call__(
         # 7. Add image embeds for IP-Adapter
         added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
 
-        # 8. Denoising loop
-        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
-        with self.progress_bar(total=num_inference_steps) as progress_bar:
-            for i, t in enumerate(timesteps):
-                # expand the latents if we are doing classifier free guidance
-                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
-                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
-
-                # predict the noise residual
-                noise_pred = self.unet(
-                    latent_model_input,
-                    t,
-                    encoder_hidden_states=prompt_embeds,
-                    cross_attention_kwargs=self.cross_attention_kwargs,
-                    added_cond_kwargs=added_cond_kwargs,
-                ).sample
-
-                # perform guidance
-                if self.do_classifier_free_guidance:
-                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
-
-                # compute the previous noisy sample x_t -> x_t-1
-                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
-
-                if callback_on_step_end is not None:
-                    callback_kwargs = {}
-                    for k in callback_on_step_end_tensor_inputs:
-                        callback_kwargs[k] = locals()[k]
-                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
-
-                    latents = callback_outputs.pop("latents", latents)
-                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
-                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
-
-                progress_bar.update()
+        num_free_init_iters = self._free_init_num_iters if self.free_init_enabled else 1
+        for free_init_iter in range(num_free_init_iters):
+            if self.free_init_enabled:
+                latents, timesteps = self._apply_free_init(
+                    latents, free_init_iter, num_inference_steps, device, latents.dtype, generator
+                )
+                num_inference_steps = len(timesteps)
+                # make sure to readjust timesteps based on strength
+                timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, timesteps, strength, device)
+
+            self._num_timesteps = len(timesteps)
+            num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+            # 8. Denoising loop
+            with self.progress_bar(total=num_inference_steps) as progress_bar:
+                for i, t in enumerate(timesteps):
+                    # expand the latents if we are doing classifier free guidance
+                    latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+                    latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+                    # predict the noise residual
+                    noise_pred = self.unet(
+                        latent_model_input,
+                        t,
+                        encoder_hidden_states=prompt_embeds,
+                        cross_attention_kwargs=self.cross_attention_kwargs,
+                        added_cond_kwargs=added_cond_kwargs,
+                    ).sample
+
+                    # perform guidance
+                    if self.do_classifier_free_guidance:
+                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                    # compute the previous noisy sample x_t -> x_t-1
+                    latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+                    if callback_on_step_end is not None:
+                        callback_kwargs = {}
+                        for k in callback_on_step_end_tensor_inputs:
+                            callback_kwargs[k] = locals()[k]
+                        callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                        latents = callback_outputs.pop("latents", latents)
+                        prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+                        negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+                    # call the callback, if provided
+                    if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                        progress_bar.update()
 
         if output_type == "latent":
             return AnimateDiffPipelineOutput(frames=latents)
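Taken together, a hedged end-to-end sketch of how the changed pieces fit at the user level (checkpoint IDs, the input clip, and the FreeInit settings are illustrative assumptions, not part of this diff):

import imageio
import torch
from PIL import Image

from diffusers import AnimateDiffVideoToVideoPipeline, MotionAdapter
from diffusers.utils import export_to_gif

adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2")
pipe = AnimateDiffVideoToVideoPipeline.from_pretrained(
    "SG161222/Realistic_Vision_V5.1_noVAE", motion_adapter=adapter, torch_dtype=torch.float16
).to("cuda")

video = [Image.fromarray(frame) for frame in imageio.get_reader("input.gif")]  # hypothetical clip

pipe.enable_free_init(num_iters=3)  # each iteration re-runs the strength-truncated denoising loop
frames = pipe(
    prompt="a panda surfing, high quality",
    video=video,
    strength=0.6,  # only the last ~60% of the schedule is denoised
    num_inference_steps=25,
).frames[0]
export_to_gif(frames, "output.gif")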