Skip to content

Commit 001b140

Browse files
asomoza and sayakpaul authored
[ip-adapter] fix problem using embeds with the plus version of ip adapters (huggingface#7189)
* initial
* check_inputs fix to the rest of pipelines
* add fix for no cfg too
* use of variable

---------

Co-authored-by: Sayak Paul <[email protected]>
1 parent f55873b commit 001b140

20 files changed: +240 -100 lines changed

src/diffusers/pipelines/animatediff/pipeline_animatediff.py

Lines changed: 12 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -400,15 +400,22 @@ def prepare_ip_adapter_image_embeds(
400400

401401
image_embeds.append(single_image_embeds)
402402
else:
403+
repeat_dims = [1]
403404
image_embeds = []
404405
for single_image_embeds in ip_adapter_image_embeds:
405406
if do_classifier_free_guidance:
406407
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
407-
single_negative_image_embeds = single_negative_image_embeds.repeat(num_images_per_prompt, 1, 1)
408-
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
408+
single_image_embeds = single_image_embeds.repeat(
409+
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
410+
)
411+
single_negative_image_embeds = single_negative_image_embeds.repeat(
412+
num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
413+
)
409414
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
410415
else:
411-
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
416+
single_image_embeds = single_image_embeds.repeat(
417+
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
418+
)
412419
image_embeds.append(single_image_embeds)
413420

414421
return image_embeds
@@ -509,9 +516,9 @@ def check_inputs(
509516
raise ValueError(
510517
f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
511518
)
512-
elif ip_adapter_image_embeds[0].ndim != 3:
519+
elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
513520
raise ValueError(
514-
f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D"
521+
f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
515522
)
516523

517524
# Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents

src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py

Lines changed: 12 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -478,15 +478,22 @@ def prepare_ip_adapter_image_embeds(
478478

479479
image_embeds.append(single_image_embeds)
480480
else:
481+
repeat_dims = [1]
481482
image_embeds = []
482483
for single_image_embeds in ip_adapter_image_embeds:
483484
if do_classifier_free_guidance:
484485
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
485-
single_negative_image_embeds = single_negative_image_embeds.repeat(num_images_per_prompt, 1, 1)
486-
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
486+
single_image_embeds = single_image_embeds.repeat(
487+
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
488+
)
489+
single_negative_image_embeds = single_negative_image_embeds.repeat(
490+
num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
491+
)
487492
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
488493
else:
489-
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
494+
single_image_embeds = single_image_embeds.repeat(
495+
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
496+
)
490497
image_embeds.append(single_image_embeds)
491498

492499
return image_embeds
@@ -589,9 +596,9 @@ def check_inputs(
589596
raise ValueError(
590597
f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
591598
)
592-
elif ip_adapter_image_embeds[0].ndim != 3:
599+
elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
593600
raise ValueError(
594-
f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D"
601+
f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
595602
)
596603

597604
def get_timesteps(self, num_inference_steps, timesteps, strength, device):

src/diffusers/pipelines/controlnet/pipeline_controlnet.py

Lines changed: 12 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -510,15 +510,22 @@ def prepare_ip_adapter_image_embeds(
510510

511511
image_embeds.append(single_image_embeds)
512512
else:
513+
repeat_dims = [1]
513514
image_embeds = []
514515
for single_image_embeds in ip_adapter_image_embeds:
515516
if do_classifier_free_guidance:
516517
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
517-
single_negative_image_embeds = single_negative_image_embeds.repeat(num_images_per_prompt, 1, 1)
518-
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
518+
single_image_embeds = single_image_embeds.repeat(
519+
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
520+
)
521+
single_negative_image_embeds = single_negative_image_embeds.repeat(
522+
num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
523+
)
519524
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
520525
else:
521-
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
526+
single_image_embeds = single_image_embeds.repeat(
527+
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
528+
)
522529
image_embeds.append(single_image_embeds)
523530

524531
return image_embeds
@@ -726,9 +733,9 @@ def check_inputs(
726733
raise ValueError(
727734
f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
728735
)
729-
elif ip_adapter_image_embeds[0].ndim != 3:
736+
elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
730737
raise ValueError(
731-
f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D"
738+
f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
732739
)
733740

734741
def check_image(self, image, prompt, prompt_embeds):

src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py

Lines changed: 12 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -503,15 +503,22 @@ def prepare_ip_adapter_image_embeds(
503503

504504
image_embeds.append(single_image_embeds)
505505
else:
506+
repeat_dims = [1]
506507
image_embeds = []
507508
for single_image_embeds in ip_adapter_image_embeds:
508509
if do_classifier_free_guidance:
509510
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
510-
single_negative_image_embeds = single_negative_image_embeds.repeat(num_images_per_prompt, 1, 1)
511-
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
511+
single_image_embeds = single_image_embeds.repeat(
512+
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
513+
)
514+
single_negative_image_embeds = single_negative_image_embeds.repeat(
515+
num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
516+
)
512517
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
513518
else:
514-
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
519+
single_image_embeds = single_image_embeds.repeat(
520+
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
521+
)
515522
image_embeds.append(single_image_embeds)
516523

517524
return image_embeds
@@ -713,9 +720,9 @@ def check_inputs(
713720
raise ValueError(
714721
f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
715722
)
716-
elif ip_adapter_image_embeds[0].ndim != 3:
723+
elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
717724
raise ValueError(
718-
f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D"
725+
f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
719726
)
720727

721728
# Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image

src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py

Lines changed: 12 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -628,15 +628,22 @@ def prepare_ip_adapter_image_embeds(
628628

629629
image_embeds.append(single_image_embeds)
630630
else:
631+
repeat_dims = [1]
631632
image_embeds = []
632633
for single_image_embeds in ip_adapter_image_embeds:
633634
if do_classifier_free_guidance:
634635
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
635-
single_negative_image_embeds = single_negative_image_embeds.repeat(num_images_per_prompt, 1, 1)
636-
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
636+
single_image_embeds = single_image_embeds.repeat(
637+
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
638+
)
639+
single_negative_image_embeds = single_negative_image_embeds.repeat(
640+
num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
641+
)
637642
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
638643
else:
639-
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
644+
single_image_embeds = single_image_embeds.repeat(
645+
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
646+
)
640647
image_embeds.append(single_image_embeds)
641648

642649
return image_embeds
@@ -871,9 +878,9 @@ def check_inputs(
871878
raise ValueError(
872879
f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
873880
)
874-
elif ip_adapter_image_embeds[0].ndim != 3:
881+
elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
875882
raise ValueError(
876-
f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D"
883+
f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
877884
)
878885

879886
# Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image

src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py

Lines changed: 12 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -537,15 +537,22 @@ def prepare_ip_adapter_image_embeds(
537537

538538
image_embeds.append(single_image_embeds)
539539
else:
540+
repeat_dims = [1]
540541
image_embeds = []
541542
for single_image_embeds in ip_adapter_image_embeds:
542543
if do_classifier_free_guidance:
543544
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
544-
single_negative_image_embeds = single_negative_image_embeds.repeat(num_images_per_prompt, 1, 1)
545-
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
545+
single_image_embeds = single_image_embeds.repeat(
546+
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
547+
)
548+
single_negative_image_embeds = single_negative_image_embeds.repeat(
549+
num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
550+
)
546551
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
547552
else:
548-
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
553+
single_image_embeds = single_image_embeds.repeat(
554+
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
555+
)
549556
image_embeds.append(single_image_embeds)
550557

551558
return image_embeds
@@ -817,9 +824,9 @@ def check_inputs(
817824
raise ValueError(
818825
f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
819826
)
820-
elif ip_adapter_image_embeds[0].ndim != 3:
827+
elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
821828
raise ValueError(
822-
f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D"
829+
f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
823830
)
824831

825832
def prepare_control_image(

src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py

Lines changed: 12 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -515,15 +515,22 @@ def prepare_ip_adapter_image_embeds(
515515

516516
image_embeds.append(single_image_embeds)
517517
else:
518+
repeat_dims = [1]
518519
image_embeds = []
519520
for single_image_embeds in ip_adapter_image_embeds:
520521
if do_classifier_free_guidance:
521522
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
522-
single_negative_image_embeds = single_negative_image_embeds.repeat(num_images_per_prompt, 1, 1)
523-
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
523+
single_image_embeds = single_image_embeds.repeat(
524+
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
525+
)
526+
single_negative_image_embeds = single_negative_image_embeds.repeat(
527+
num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
528+
)
524529
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
525530
else:
526-
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
531+
single_image_embeds = single_image_embeds.repeat(
532+
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
533+
)
527534
image_embeds.append(single_image_embeds)
528535

529536
return image_embeds
@@ -730,9 +737,9 @@ def check_inputs(
730737
raise ValueError(
731738
f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
732739
)
733-
elif ip_adapter_image_embeds[0].ndim != 3:
740+
elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
734741
raise ValueError(
735-
f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D"
742+
f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
736743
)
737744

738745
# Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image

src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py

Lines changed: 12 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -567,15 +567,22 @@ def prepare_ip_adapter_image_embeds(
567567

568568
image_embeds.append(single_image_embeds)
569569
else:
570+
repeat_dims = [1]
570571
image_embeds = []
571572
for single_image_embeds in ip_adapter_image_embeds:
572573
if do_classifier_free_guidance:
573574
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
574-
single_negative_image_embeds = single_negative_image_embeds.repeat(num_images_per_prompt, 1, 1)
575-
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
575+
single_image_embeds = single_image_embeds.repeat(
576+
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
577+
)
578+
single_negative_image_embeds = single_negative_image_embeds.repeat(
579+
num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
580+
)
576581
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
577582
else:
578-
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
583+
single_image_embeds = single_image_embeds.repeat(
584+
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
585+
)
579586
image_embeds.append(single_image_embeds)
580587

581588
return image_embeds
@@ -794,9 +801,9 @@ def check_inputs(
794801
raise ValueError(
795802
f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
796803
)
797-
elif ip_adapter_image_embeds[0].ndim != 3:
804+
elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
798805
raise ValueError(
799-
f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D"
806+
f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
800807
)
801808

802809
# Copied from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.check_image

0 commit comments

Comments
 (0)