
Conversation

@yiyixuxu (Collaborator) commented Feb 18, 2024

fix #6925

Allow passing `ip_adapter_image_embeds` directly and skip loading the image_encoder:

```py
import torch
from diffusers import StableDiffusionXLPipeline, AutoencoderKL, DPMSolverMultistepScheduler
from diffusers.utils.testing_utils import load_pt

vae = AutoencoderKL.from_pretrained(
    "madebyollin/sdxl-vae-fp16-fix",
    torch_dtype=torch.float16,
).to("cuda")

pipeline = StableDiffusionXLPipeline.from_pretrained(
    'stabilityai/stable-diffusion-xl-base-1.0',
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True,
    vae=vae
).to('cuda')

pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
    pipeline.scheduler.config, use_karras_sigmas=True
)


pipeline.load_ip_adapter(
    "h94/IP-Adapter",
    subfolder="sdxl_models",
    weight_name="ip-adapter_sdxl_vit-h.safetensors",
    image_encoder_folder=None,
)
pipeline.set_ip_adapter_scale(0.6)

print(f" pipeline.image_encoder: {pipeline.image_encoder}")

prompt = "a horse, highly detailed, 4k, professional"
negative_prompt="blurry"

# diffusers embeds: a list of one (2, 2, 1024) tensor, saved with cfg
image_embeds = load_pt("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/diffusers_style_test.ipadpt")
generator = torch.Generator(device="cpu").manual_seed(33)
image = pipeline(
    prompt=prompt,
    ip_adapter_image_embeds=image_embeds,
    negative_prompt=negative_prompt,
    guidance_scale=7.5,
    num_inference_steps=20,
    num_images_per_prompt=2,
    generator=generator,
).images

image[0].save("yiyi_test_4_out_diffusers_cfg.png")

# diffusers embeds, no cfg: keep only the positive embeds (index 1 of the stacked [negative, positive] pair)
image_embeds_no_cfg = [single_image_embeds[1] for single_image_embeds in image_embeds]
generator = torch.Generator(device="cpu").manual_seed(33)
image = pipeline(
    prompt=prompt,
    ip_adapter_image_embeds=image_embeds_no_cfg,
    negative_prompt=negative_prompt,
    guidance_scale=0,
    num_inference_steps=20,
    num_images_per_prompt=2,
    generator=generator,
).images

image[0].save("yiyi_test_4_out_diffusers_no_cfg.png")


# comfyui embeds (2,2,1024), cfg 
image_embeds = load_pt("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/comfyui_style_test.ipadpt").to(device="cuda", dtype=torch.float16)
image_embeds, negative_image_embeds = image_embeds.chunk(2)

generator = torch.Generator(device="cpu").manual_seed(33)
image = pipeline(
    prompt=prompt,
    ip_adapter_image_embeds=[torch.cat([negative_image_embeds, image_embeds], dim=0)],
    negative_prompt=negative_prompt,
    guidance_scale=7.5,
    num_inference_steps=20,
    num_images_per_prompt=2,
    generator=generator,
).images

image[0].save("yiyi_test_4_out_comfy_cfg.png")

# comfyui embeds, no cfg 
generator = torch.Generator(device="cpu").manual_seed(33)
image = pipeline(
    prompt=prompt,
    ip_adapter_image_embeds=[image_embeds],
    negative_prompt=negative_prompt,
    guidance_scale=0,
    num_inference_steps=20,
    num_images_per_prompt=2,
    generator=generator,
).images

image[0].save("yiyi_test_4_out_comfy_no_cfg.png")

We also allow specifying a different subfolder for image_encoder_folder.

For example, if the IP-Adapter checkpoint needs an image_encoder that does not live in the same subfolder, you no longer need to load it explicitly as described in the docs (https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters#ip-adapter-plus); you can specify image_encoder_folder instead:

```py
from diffusers import AutoPipelineForText2Image
import torch
from diffusers.utils import load_image

pipeline = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16
)

pipeline.load_ip_adapter(
    "h94/IP-Adapter", 
    subfolder="sdxl_models", 
    weight_name="ip-adapter-plus_sdxl_vit-h.safetensors", 
    image_encoder_folder="models/image_encoder")
pipeline.enable_model_cpu_offload()


image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/load_neg_embed.png")
generator = torch.Generator(device="cpu").manual_seed(33)
images = pipeline(
    prompt='best quality, high quality, wearing sunglasses', 
    ip_adapter_image=image,
    negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality", 
    num_inference_steps=50,
    generator=generator,
).images[0]
images.save("yiyi_test_out.png")
```

@yiyixuxu changed the title to "[ip-adapter] allow pass ip_hidden_states directly and skip load image_encoder" Feb 18, 2024
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@yiyixuxu (Collaborator, Author) commented

cc @asomoza feel free to give a review:)

@yiyixuxu yiyixuxu requested a review from sayakpaul February 19, 2024 03:49
@sayakpaul (Member) commented

If you could also describe how "diffusers_style_test.ipadpt" was generated, I think that'd be a useful reference for the community.

```py
if image_encoder_folder is not None:
    if not isinstance(pretrained_model_name_or_path_or_dict, dict):
        logger.info(f"loading image_encoder from {pretrained_model_name_or_path_or_dict}")
        if image_encoder_folder.count("/") == 0:
```
A Member commented on this diff: 🧠

@sayakpaul (Member) left a review:

Looking nice!

I think it'd be nice to add:

@asomoza (Member) commented Feb 19, 2024

Nice that this also solves the problem with the different image encoders for SDXL.

I'm still testing and everything works, except when I try to load the embeddings from comfyui I get a shape error. I'm trying to figure out where the difference lies between the image embeddings from diffusers and the ones from comfyui.

edit: This is not a problem with this PR though; they probably need some kind of conversion.

@asomoza (Member) commented Feb 19, 2024

> If you could also describe how "diffusers_style_test.ipadpt" was generated, I think that'd be a useful reference for the community.

Just did a torch.save after the prepare_ip_adapter_image_embeds with your example in #6868:

```py
image_embeds = prepare_ip_adapter_image_embeds(
    unet=pipeline.unet,
    image_encoder=pipeline.image_encoder,
    feature_extractor=pipeline.feature_extractor,
    ip_adapter_image=[[image_one, image_two]],
    do_classifier_free_guidance=True,
    device="cuda",
    num_images_per_prompt=1,
)

torch.save(image_embeds, "diffusers_style_test.ipadpt")

@asomoza (Member) commented Feb 19, 2024

I found another issue while doing some tests. I didn't notice it before since I don't use multiple images per prompt, but the code right now expects the embeddings to match this argument; for example, in the SDXL pipeline:

```py
single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
```

This ties the embeddings to the number of images per prompt, which is not ideal unless you keep track of this value.

I see two solutions:

1. Document this and leave it to the users.
2. Move this part of the code to after loading the embeddings so they will be the same regardless of the number of images per prompt.

I'll wait to see your resolution since this makes it hard to reuse the embeddings or make them compatible with other apps.

As an example, if I use the same code that generates the horse but with the num_images_per_prompt=4 I get this result:

(screenshot: output with num_images_per_prompt=4)

@sayakpaul (Member) commented

Thanks for reporting!

> Move this part of the code to after loading the embeddings so they will be the same regardless of the number of images per prompt.

I think we should catch these things early and report to the users so that they can call the pipeline with the appropriate values. We'd want the num_images_per_prompt argument to be consistent with how it's generally treated across the library, IMO. WDYT? Also @yiyixuxu? I think this way, we wouldn't have to document it separately.

As an example, if I use the same code that generates the horse but with the num_images_per_prompt=4 I get this result:

What is the expected result?

@asomoza (Member) commented Feb 19, 2024

> I think we should catch these things early and report to the users so that they can call the pipeline with the appropriate values. We'd want the num_images_per_prompt argument to be consistent with how it's generally treated across the library, IMO. WDYT? Also @yiyixuxu? I think this way, we wouldn't have to document it separately.

AFAIK there's no way to catch this: it doesn't produce any errors, it's just that the final result is different and somewhat worse than the original.

The only way to make it work as intended is to match the number of images per prompt with the embeddings, so you'll also need to remember this value, put it in the filename, or save it as metadata in the file, but this also makes the file usable only for that specific case.

This is not a critical issue, especially for people that just use one image per prompt, but as long as it exists, I don't see the point of putting more effort into making it compatible with saved embeddings from other apps or libraries.

> What is the expected result?

The same as the first image in this PR:

(screenshot: the expected result, same as the first image in this PR)

@sayakpaul (Member) commented

I think it does warrant a deeper investigation into why that's the case, then. We should fix the root cause here IMO.

@yiyixuxu (Collaborator, Author) commented Feb 20, 2024

@asomoza @sayakpaul

About the other issue:

> I see two solutions:
> 1. Document this and leave it to the users.
> 2. Move this part of the code to after loading the embeddings so they will be the same regardless of the number of images per prompt.

I think solution 2 is the easy answer, no? Any downside to it that I missed?

@asomoza (Member) commented Feb 20, 2024

> I think solution 2 is the easy answer, no? Any downside to it that I missed?

Yeah, it's the solution I would choose too, but sometimes there's an underlying reason in diffusers that I don't know about which prevents it.

@yiyixuxu (Collaborator, Author) commented Feb 20, 2024

@asomoza

Ok, I will make changes based on solution 2. I think it's fine because it's consistent with prompt_embeds (we do not expect prompt_embeds to match num_images_per_prompt):

```py
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
```
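For the image embeds, a rough sketch of what solution 2 could look like (a hypothetical helper, not the actual diff):

```py
def expand_ip_adapter_image_embeds(image_embeds, num_images_per_prompt):
    # repeat each loaded embedding along the batch dimension, mirroring the
    # prompt_embeds behavior, so saved files no longer depend on
    # num_images_per_prompt
    return [
        single_image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
        for single_image_embeds in image_embeds
    ]
```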

But we will wait for @sayakpaul to be back next week to make a final decision on what we do.

@asomoza (Member) commented Feb 20, 2024

@yiyixuxu @sayakpaul

There's another parameter that gets saved with the image embeddings and ties them up: do_classifier_free_guidance. It means the embeddings will only work if the call matches what was saved (CFG enabled vs. disabled).

So if we're going to move the num_images_per_prompt handling, we should also move the check for do_classifier_free_guidance to after saving/loading the embeddings. This is also consistent with prompt_embeds.
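As a stopgap, the positive half can be recovered from embeddings saved with CFG (a hedged sketch, assuming each saved tensor stacks [negative, positive] along dim 0 as in the examples above):

```py
import torch

# saved with do_classifier_free_guidance=True
saved_embeds = torch.load("image_embeds.ipadpt")
# keep only the positive half of each tensor for a no-CFG run
positive_only = [torch.chunk(embeds, 2, dim=0)[1] for embeds in saved_embeds]
```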

@asomoza (Member) commented Feb 20, 2024

I can now make my saved embeddings work with diffusers and comfyui depending on the point in my code where I save the embeddings, but right now there isn't a solution that makes both of them compatible.

I was wrong about the image projection being done before the saving, though, so the comfyui embeds work when passed via ip_adapter_image_embeds; I do mine with ip_hidden_states and they work ok too.

I can load and run the comfyui embeds like this:

```py
comfy_image_embeds = torch.load("comfyui_style_test.ipadpt")
embeds = torch.unbind(comfy_image_embeds)

single_image_embeds = embeds[0]
single_negative_image_embeds = embeds[1]

single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
single_negative_image_embeds = torch.stack([single_negative_image_embeds] * num_images_per_prompt, dim=0)

image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
```

but there's no way to make the diffusers embeds work with comfyui with all the arguments saved with the embeddings.

@yiyixuxu (Collaborator, Author) commented

@asomoza
I don't quite understand the last comment.
Is this an example of a comfyui embedding? https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/comfyui_style_test.ipadpt?download=true

It is a 2 x 2 x 1024 tensor, so I would assume it is the embedding after the image_encoder and before the projection layer, no?

```diff
-        if self.do_classifier_free_guidance:
+        if do_classifier_free_guidance:
```
@yiyixuxu (Collaborator, Author) commented on this diff Feb 29, 2024

Adding do_classifier_free_guidance as an argument to prepare_ip_adapter_image_embeds so we can use this pipeline method to save image embeddings:

```py
image_embeds = pipeline.prepare_ip_adapter_image_embeds(
    ip_adapter_image=image,
    ip_adapter_image_embeds=None,
    device="cuda",
    num_images_per_prompt=1,
    do_classifier_free_guidance=True,
)

torch.save(image_embeds, "image_embeds.ipadpt")

@yiyixuxu changed the title from "[ip-adapter] allow pass ip_hidden_states directly and skip load image_encoder" to "[ip-adapter] refactor prepare_ip_adapter_image_embeds and skip load image_encoder" Feb 29, 2024
> [!TIP]
> While calling `load_ip_adapter()`, pass `low_cpu_mem_usage=True` to speed up the loading time.

All the pipelines supporting IP-Adapter accept an `ip_adapter_image_embeds` argument. If you need to run the IP-Adapter multiple times with the same image, you can encode the image once and save the embeddings to disk.
@yiyixuxu (Collaborator, Author) commented:

cc @stevhliu here for awareness

I added a section to the ip-adapter guide here. Let me know if you have any comments. If editing in a separate PR is easier, feel free to do so!

@yiyixuxu (Collaborator, Author) commented:

Another very good use case for ip_adapter_image_embeds is probably the multi-ip-adapter: https://huggingface.co/docs/diffusers/main/en/using-diffusers/ip_adapter#multi-ip-adapter

A common practice is to use a folder of 10+ images for styling, reusing the same styling images everywhere to create a consistent style. It would be nice to create an image embedding for these style images once, so you don't have to load the same batch of images from a folder and encode them each time (see the sketch below).
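A hedged sketch of that workflow (the folder path and glob pattern are illustrative assumptions):

```py
import torch
from pathlib import Path
from diffusers.utils import load_image

# encode the folder of style images once and cache the embeddings
style_images = [load_image(str(p)) for p in sorted(Path("style_folder").glob("*.png"))]
image_embeds = pipeline.prepare_ip_adapter_image_embeds(
    ip_adapter_image=[style_images],
    ip_adapter_image_embeds=None,
    device="cuda",
    num_images_per_prompt=1,
    do_classifier_free_guidance=True,
)
torch.save(image_embeds, "style_embeds.ipadpt")
```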

@sayakpaul (Member) commented:

I think we should definitely add that example motivating the use case. WDYT @asomoza?

A Member commented:

I'll edit it in a separate PR, and I can also make a mention of ip_adapter_image_embeds in the multi IP-Adapter section 🙂

@asomoza (Member) commented:

> I think we should definitely add that example motivating the use case. WDYT @asomoza?

Yeah, this is especially helpful when you use a lot of images and multiple IP-Adapters: you just need to save the embeddings, which makes it a lot easier to replicate and saves a lot of space if you use high-quality images.

I'll try to do one with a style and a character and see how it goes, but to see the real potential of this we'll also need controlnet and IP-Adapter masking, so the best use case would be a full demo with all of this incorporated.

@yiyixuxu (Collaborator, Author) commented:

Finally finishing up this PR now. I refactored some more; feel free to give a final review.

cc @sayakpaul @asomoza

> ComfyUI image embeddings are fully compatible with IP-Adapter in diffusers and will work out-of-box.
```py
image_embeds = torch.load("image_embeds.ipadpt")
A Member commented on this diff:

We don't know where this is coming from. Let's include a snippet to download that and explicitly mention that it's coming from ComfyUI.

@sayakpaul (Member) left a review:

Looking pretty solid. I left a couple of suggestions on the docs. I reviewed the changes made to ip_adapter.py and the changes in prepare_ip_adapter_image_embeds() and check_inputs from the SDXL pipeline script. I think the rest of the pipelines share these changes?

I thought we were also supporting passing the image embedding projection as well. Are we not doing so?

Comment on lines +507 to +515

```py
if ip_adapter_image_embeds is not None:
    if not isinstance(ip_adapter_image_embeds, list):
        raise ValueError(
            f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
        )
    elif ip_adapter_image_embeds[0].ndim != 3:
        raise ValueError(
            f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D"
        )
```
A Member commented on this diff:

Do we need any checks on the shapes to conform to what's needed for classifier-free guidance?
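One possible check, purely as a hedged sketch (not the PR's actual diff), assuming CFG-ready embeds stack negative and positive halves along dim 0 as in the examples above:

```py
def check_cfg_shape(ip_adapter_image_embeds, do_classifier_free_guidance):
    # hypothetical validation: with CFG enabled, each tensor is assumed to hold
    # [negative, positive] halves along dim 0, so the leading dim must be even
    if do_classifier_free_guidance:
        for embeds in ip_adapter_image_embeds:
            if embeds.shape[0] % 2 != 0:
                raise ValueError(
                    "expected negative and positive embeds stacked along dim 0, "
                    f"got leading dim {embeds.shape[0]}"
                )
```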

@yiyixuxu (Collaborator, Author) commented Mar 1, 2024

@sayakpaul

> I thought we were also supporting passing the image embedding projection as well. Are we not doing so?

So it turns out the comfyUI embedding is created before the image projection layer, so we don't need to support passing the projection output directly anymore; it is too small a use case.

Successfully merging this pull request may close these issues: IP Adapter Image Embeds - Compatibility with ComfyUI and other libraries/apps