
Commit 06b01ea

yiyixuxu and sayakpaul authored
[ip-adapter] refactor prepare_ip_adapter_image_embeds and skip load image_encoder (huggingface#7016)
* add

---------

Co-authored-by: Sayak Paul <[email protected]>
Co-authored-by: yiyixuxu <yixu310@gmail.com>
Co-authored-by: Sayak Paul <[email protected]>
1 parent f4fc750 commit 06b01ea

Showing 25 changed files with 769 additions and 138 deletions.

docs/source/en/using-diffusers/ip_adapter.md

Lines changed: 33 additions & 0 deletions
@@ -234,6 +234,39 @@ export_to_gif(frames, "gummy_bear.gif")
 > [!TIP]
 > While calling `load_ip_adapter()`, pass `low_cpu_mem_usage=True` to speed up the loading time.
 
+All the pipelines supporting IP-Adapter accept an `ip_adapter_image_embeds` argument. If you need to run the IP-Adapter multiple times with the same image, you can encode the image once and save the embedding to disk.
+
+```py
+image_embeds = pipeline.prepare_ip_adapter_image_embeds(
+    ip_adapter_image=image,
+    ip_adapter_image_embeds=None,
+    device="cuda",
+    num_images_per_prompt=1,
+    do_classifier_free_guidance=True,
+)
+
+torch.save(image_embeds, "image_embeds.ipadpt")
+```
+
+Load the image embedding and pass it to the pipeline as `ip_adapter_image_embeds`.
+
+> [!TIP]
+> ComfyUI image embeddings for IP-Adapters are fully compatible in Diffusers and should work out of the box.
+
+```py
+image_embeds = torch.load("image_embeds.ipadpt")
+images = pipeline(
+    prompt="a polar bear sitting in a chair drinking a milkshake",
+    ip_adapter_image_embeds=image_embeds,
+    negative_prompt="deformed, ugly, wrong proportion, low res, bad anatomy, worst quality, low quality",
+    num_inference_steps=100,
+    generator=generator,
+).images
+```
+
+> [!TIP]
+> If you use IP-Adapter with `ip_adapter_image_embeds` instead of `ip_adapter_image`, you can choose not to load an image encoder by passing `image_encoder_folder=None` to `load_ip_adapter()`.
+
 ## Specific use cases
 
 IP-Adapter's image prompting and compatibility with other adapters and models make it a versatile tool for a variety of use cases. This section covers some of the more popular applications of IP-Adapter, and we can't wait to see what you come up with!
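Taken together, the last two tips mean a pipeline never has to instantiate a CLIP image encoder at inference time. Below is a minimal sketch of that flow, assuming a Stable Diffusion 1.5 pipeline and the `h94/IP-Adapter` checkpoint (both illustrative choices, not mandated by this change):

```py
import torch
from diffusers import AutoPipelineForText2Image

# Illustrative model ids; any pipeline with IP-Adapter support works the same way.
pipeline = AutoPipelineForText2Image.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Skip loading the image encoder entirely: only saved embeddings will be used.
pipeline.load_ip_adapter(
    "h94/IP-Adapter",
    subfolder="models",
    weight_name="ip-adapter_sd15.bin",
    image_encoder_folder=None,
)

image_embeds = torch.load("image_embeds.ipadpt")  # saved earlier as shown above
images = pipeline(
    prompt="a polar bear sitting in a chair drinking a milkshake",
    ip_adapter_image_embeds=image_embeds,
    num_inference_steps=50,
    generator=torch.Generator("cuda").manual_seed(0),
).images
```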

examples/community/pipeline_animatediff_controlnet.py

Lines changed: 4 additions & 2 deletions
@@ -799,8 +799,10 @@ def __call__(
             ip_adapter_image (`PipelineImageInput`, *optional*):
                 Optional image input to work with IP Adapters.
             ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
-                Pre-generated image embeddings for IP-Adapter. If not
-                provided, embeddings are computed from the `ip_adapter_image` input argument.
+                Pre-generated image embeddings for IP-Adapter. It should be a list with the same length as the number
+                of IP-Adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`, and it
+                should contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
+                provided, embeddings are computed from the `ip_adapter_image` input argument.
             conditioning_frames (`List[PipelineImageInput]`, *optional*):
                 The ControlNet input condition to provide guidance to the `unet` for generation. If multiple ControlNets
                 are specified, images must be passed as a list such that each element of the list can be correctly
examples/community/pipeline_animatediff_img2video.py

Lines changed: 4 additions & 2 deletions
@@ -798,8 +798,10 @@ def __call__(
             ip_adapter_image: (`PipelineImageInput`, *optional*):
                 Optional image input to work with IP Adapters.
             ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
-                Pre-generated image embeddings for IP-Adapter. If not
-                provided, embeddings are computed from the `ip_adapter_image` input argument.
+                Pre-generated image embeddings for IP-Adapter. It should be a list with the same length as the number
+                of IP-Adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`, and it
+                should contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
+                provided, embeddings are computed from the `ip_adapter_image` input argument.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generated video. Choose between `torch.FloatTensor`, `PIL.Image` or
                 `np.array`.
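The new docstring wording pins down a shape convention that is easy to get wrong: when classifier-free guidance is on, each list element carries the negative embedding stacked in front of the positive one. A small illustrative sketch (the dimensions are made up; `emb_dim` depends on the image encoder in use):

```py
import torch

# Illustrative dimensions only; emb_dim depends on the image encoder.
batch_size, num_images, emb_dim = 1, 1, 1024

positive = torch.randn(batch_size, num_images, emb_dim)
negative = torch.zeros_like(positive)  # e.g. the embedding of a blank image

# With do_classifier_free_guidance=True: negative half first, positive second,
# so each list element has shape (2 * batch_size, num_images, emb_dim).
embeds_with_cfg = [torch.cat([negative, positive])]

# Without guidance: just the positive embedding, shape (batch_size, num_images, emb_dim).
embeds_no_cfg = [positive]
```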

src/diffusers/loaders/ip_adapter.py

Lines changed: 37 additions & 14 deletions
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 from pathlib import Path
-from typing import Dict, List, Union
+from typing import Dict, List, Optional, Union
 
 import torch
 from huggingface_hub.utils import validate_hf_hub_args
@@ -52,11 +52,12 @@ def load_ip_adapter(
         pretrained_model_name_or_path_or_dict: Union[str, List[str], Dict[str, torch.Tensor]],
         subfolder: Union[str, List[str]],
         weight_name: Union[str, List[str]],
+        image_encoder_folder: Optional[str] = "image_encoder",
         **kwargs,
     ):
         """
         Parameters:
-            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+            pretrained_model_name_or_path_or_dict (`str` or `List[str]` or `os.PathLike` or `List[os.PathLike]` or `dict` or `List[dict]`):
                 Can be either:
 
                 - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
@@ -65,7 +66,18 @@ def load_ip_adapter(
                   with [`ModelMixin.save_pretrained`].
                 - A [torch state
                   dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
-
+            subfolder (`str` or `List[str]`):
+                The subfolder location of a model file within a larger model repository on the Hub or locally.
+                If a list is passed, it should have the same length as `weight_name`.
+            weight_name (`str` or `List[str]`):
+                The name of the weight file to load. If a list is passed, it should have the same length as
+                `subfolder`.
+            image_encoder_folder (`str`, *optional*, defaults to `"image_encoder"`):
+                The subfolder location of the image encoder within a larger model repository on the Hub or locally.
+                Pass `None` to not load the image encoder. If the image encoder is located in a folder inside
+                `subfolder`, you only need to pass the name of that folder, e.g. `image_encoder_folder="image_encoder"`.
+                If the image encoder is located in a folder other than `subfolder`, pass the path to the folder that
+                contains the image encoder weights, for example, `image_encoder_folder="different_subfolder/image_encoder"`.
             cache_dir (`Union[str, os.PathLike]`, *optional*):
                 Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
                 is not used.
@@ -87,8 +99,6 @@ def load_ip_adapter(
             revision (`str`, *optional*, defaults to `"main"`):
                 The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
                 allowed by Git.
-            subfolder (`str`, *optional*, defaults to `""`):
-                The subfolder location of a model file within a larger model repository on the Hub or locally.
             low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
                 Speed up model loading by only loading the pretrained weights and not initializing the weights. This also
                 tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
@@ -184,16 +194,29 @@ def load_ip_adapter(
 
         # load CLIP image encoder here if it has not been registered to the pipeline yet
        if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None:
-            if not isinstance(pretrained_model_name_or_path_or_dict, dict):
-                logger.info(f"loading image_encoder from {pretrained_model_name_or_path_or_dict}")
-                image_encoder = CLIPVisionModelWithProjection.from_pretrained(
-                    pretrained_model_name_or_path_or_dict,
-                    subfolder=Path(subfolder, "image_encoder").as_posix(),
-                    low_cpu_mem_usage=low_cpu_mem_usage,
-                ).to(self.device, dtype=self.dtype)
-                self.register_modules(image_encoder=image_encoder)
+            if image_encoder_folder is not None:
+                if not isinstance(pretrained_model_name_or_path_or_dict, dict):
+                    logger.info(f"loading image_encoder from {pretrained_model_name_or_path_or_dict}")
+                    if image_encoder_folder.count("/") == 0:
+                        image_encoder_subfolder = Path(subfolder, image_encoder_folder).as_posix()
+                    else:
+                        image_encoder_subfolder = Path(image_encoder_folder).as_posix()
+
+                    image_encoder = CLIPVisionModelWithProjection.from_pretrained(
+                        pretrained_model_name_or_path_or_dict,
+                        subfolder=image_encoder_subfolder,
+                        low_cpu_mem_usage=low_cpu_mem_usage,
+                    ).to(self.device, dtype=self.dtype)
+                    self.register_modules(image_encoder=image_encoder)
+                else:
+                    raise ValueError(
+                        "`image_encoder` cannot be loaded because `pretrained_model_name_or_path_or_dict` is a state dict."
+                    )
             else:
-                raise ValueError("`image_encoder` cannot be None when using IP Adapters.")
+                logger.warning(
+                    "`image_encoder` is not loaded since `image_encoder_folder=None` was passed. You will not be able to use `ip_adapter_image` "
+                    "when calling the pipeline with IP-Adapter. Use `ip_adapter_image_embeds` to pass pre-generated image embeddings instead."
+                )
 
         # create feature extractor if it has not been registered to the pipeline yet
         if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is None:

src/diffusers/pipelines/animatediff/pipeline_animatediff.py

Lines changed: 32 additions & 6 deletions
@@ -370,7 +370,7 @@ def encode_image(self, image, device, num_images_per_prompt, output_hidden_state
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
     def prepare_ip_adapter_image_embeds(
-        self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
+        self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
     ):
         if ip_adapter_image_embeds is None:
             if not isinstance(ip_adapter_image, list):
@@ -394,13 +394,23 @@ def prepare_ip_adapter_image_embeds(
                     [single_negative_image_embeds] * num_images_per_prompt, dim=0
                 )
 
-                if self.do_classifier_free_guidance:
+                if do_classifier_free_guidance:
                     single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
                     single_image_embeds = single_image_embeds.to(device)
 
                 image_embeds.append(single_image_embeds)
         else:
-            image_embeds = ip_adapter_image_embeds
+            image_embeds = []
+            for single_image_embeds in ip_adapter_image_embeds:
+                if do_classifier_free_guidance:
+                    single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
+                    single_negative_image_embeds = single_negative_image_embeds.repeat(num_images_per_prompt, 1, 1)
+                    single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
+                    single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
+                else:
+                    single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
+                image_embeds.append(single_image_embeds)
+
         return image_embeds
 
     # Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents
@@ -494,6 +504,16 @@ def check_inputs(
                 "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
             )
 
+        if ip_adapter_image_embeds is not None:
+            if not isinstance(ip_adapter_image_embeds, list):
+                raise ValueError(
+                    f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
+                )
+            elif ip_adapter_image_embeds[0].ndim != 3:
+                raise ValueError(
+                    f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D"
+                )
+
     # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents
     def prepare_latents(
         self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
@@ -612,8 +632,10 @@ def __call__(
             ip_adapter_image: (`PipelineImageInput`, *optional*):
                 Optional image input to work with IP Adapters.
             ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
-                Pre-generated image embeddings for IP-Adapter. If not
-                provided, embeddings are computed from the `ip_adapter_image` input argument.
+                Pre-generated image embeddings for IP-Adapter. It should be a list with the same length as the number
+                of IP-Adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`, and it
+                should contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
+                provided, embeddings are computed from the `ip_adapter_image` input argument.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generated video. Choose between `torch.FloatTensor`, `PIL.Image` or
                 `np.array`.
@@ -717,7 +739,11 @@ def __call__(
 
         if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
             image_embeds = self.prepare_ip_adapter_image_embeds(
-                ip_adapter_image, ip_adapter_image_embeds, device, batch_size * num_videos_per_prompt
+                ip_adapter_image,
+                ip_adapter_image_embeds,
+                device,
+                batch_size * num_videos_per_prompt,
+                self.do_classifier_free_guidance,
             )
 
         # 4. Prepare timesteps
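The substance of the refactor is the new `else` branch: pre-generated embeddings are no longer passed through untouched but are split into their negative/positive halves, tiled per prompt, and re-concatenated. A standalone sketch of that branch (illustrative; the helper name is made up, and the input is assumed to follow the shape convention documented above):

```py
import torch

def expand_user_embeds(ip_adapter_image_embeds, num_images_per_prompt, do_classifier_free_guidance):
    # Mirrors the new `else` branch: tile user-supplied embeddings per prompt.
    image_embeds = []
    for single_image_embeds in ip_adapter_image_embeds:
        if do_classifier_free_guidance:
            # Saved embeddings store the negative half first, then the positive half.
            negative, positive = single_image_embeds.chunk(2)
            negative = negative.repeat(num_images_per_prompt, 1, 1)
            positive = positive.repeat(num_images_per_prompt, 1, 1)
            single_image_embeds = torch.cat([negative, positive])
        else:
            single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
        image_embeds.append(single_image_embeds)
    return image_embeds

# Example: one adapter, CFG on, batch of 1 tiled to 2 images per prompt.
embeds = [torch.randn(2, 1, 1024)]  # (2 * batch_size, num_images, emb_dim)
out = expand_user_embeds(embeds, num_images_per_prompt=2, do_classifier_free_guidance=True)
assert out[0].shape == (4, 1, 1024)
```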

src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py

Lines changed: 41 additions & 6 deletions
@@ -448,7 +448,7 @@ def encode_image(self, image, device, num_images_per_prompt, output_hidden_state
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
     def prepare_ip_adapter_image_embeds(
-        self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
+        self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
     ):
         if ip_adapter_image_embeds is None:
             if not isinstance(ip_adapter_image, list):
@@ -472,13 +472,23 @@ def prepare_ip_adapter_image_embeds(
                     [single_negative_image_embeds] * num_images_per_prompt, dim=0
                 )
 
-                if self.do_classifier_free_guidance:
+                if do_classifier_free_guidance:
                     single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
                     single_image_embeds = single_image_embeds.to(device)
 
                 image_embeds.append(single_image_embeds)
         else:
-            image_embeds = ip_adapter_image_embeds
+            image_embeds = []
+            for single_image_embeds in ip_adapter_image_embeds:
+                if do_classifier_free_guidance:
+                    single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
+                    single_negative_image_embeds = single_negative_image_embeds.repeat(num_images_per_prompt, 1, 1)
+                    single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
+                    single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
+                else:
+                    single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
+                image_embeds.append(single_image_embeds)
+
         return image_embeds
 
     # Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents
@@ -523,6 +533,8 @@ def check_inputs(
         negative_prompt=None,
         prompt_embeds=None,
         negative_prompt_embeds=None,
+        ip_adapter_image=None,
+        ip_adapter_image_embeds=None,
         callback_on_step_end_tensor_inputs=None,
     ):
         if strength < 0 or strength > 1:
@@ -567,6 +579,21 @@ def check_inputs(
         if video is not None and latents is not None:
             raise ValueError("Only one of `video` or `latents` should be provided")
 
+        if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
+            raise ValueError(
+                "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
+            )
+
+        if ip_adapter_image_embeds is not None:
+            if not isinstance(ip_adapter_image_embeds, list):
+                raise ValueError(
+                    f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
+                )
+            elif ip_adapter_image_embeds[0].ndim != 3:
+                raise ValueError(
+                    f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D"
+                )
+
     def get_timesteps(self, num_inference_steps, timesteps, strength, device):
         # get the original timestep using init_timestep
         init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
@@ -765,8 +792,10 @@ def __call__(
             ip_adapter_image: (`PipelineImageInput`, *optional*):
                 Optional image input to work with IP Adapters.
             ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
-                Pre-generated image embeddings for IP-Adapter. If not
-                provided, embeddings are computed from the `ip_adapter_image` input argument.
+                Pre-generated image embeddings for IP-Adapter. It should be a list with the same length as the number
+                of IP-Adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`, and it
+                should contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
+                provided, embeddings are computed from the `ip_adapter_image` input argument.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generated video. Choose between `torch.FloatTensor`, `PIL.Image` or
                 `np.array`.
@@ -814,6 +843,8 @@ def __call__(
             negative_prompt_embeds=negative_prompt_embeds,
             video=video,
             latents=latents,
+            ip_adapter_image=ip_adapter_image,
+            ip_adapter_image_embeds=ip_adapter_image_embeds,
             callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
         )
 
@@ -855,7 +886,11 @@ def __call__(
 
         if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
            image_embeds = self.prepare_ip_adapter_image_embeds(
-                ip_adapter_image, ip_adapter_image_embeds, device, batch_size * num_videos_per_prompt
+                ip_adapter_image,
+                ip_adapter_image_embeds,
+                device,
+                batch_size * num_videos_per_prompt,
+                self.do_classifier_free_guidance,
            )
 
         # 4. Prepare timesteps
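Because the helper no longer reads `self.do_classifier_free_guidance` internally, the guidance flag is threaded through every call site, and the method can also be driven outside a denoising call, as the docs example does. A hypothetical standalone call (`pipeline` is any IP-Adapter-enabled pipeline; the saved file name follows the docs example):

```py
import torch

# Assumed setup: `pipeline` has an IP-Adapter loaded, and "image_embeds.ipadpt"
# was saved earlier with do_classifier_free_guidance=True.
saved_embeds = torch.load("image_embeds.ipadpt")  # list of 3D tensors, one per adapter

image_embeds = pipeline.prepare_ip_adapter_image_embeds(
    ip_adapter_image=None,  # mutually exclusive with ip_adapter_image_embeds
    ip_adapter_image_embeds=saved_embeds,
    device="cuda",
    num_images_per_prompt=1,
    do_classifier_free_guidance=True,  # must match how the embeddings were saved
)
```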
