Skip to content

Commit 9f7b2cf

Browse files
juancopi81yiyixuxu
andauthored
Support of ip-adapter to the StableDiffusionControlNetInpaintPipeline (huggingface#5887)
* Change pipeline_controlnet_inpaint.py to add ip-adapter support. Changes are similar to those in pipeline_controlnet * Change tests for the StableDiffusionControlNetInpaintPipeline by adding image_encoder: None * Update src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py Co-authored-by: YiYi Xu <[email protected]> --------- Co-authored-by: YiYi Xu <[email protected]>
1 parent 895c4b7 commit 9f7b2cf

File tree

2 files changed

+37
-6
lines changed

2 files changed

+37
-6
lines changed

src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,10 @@
2121
import PIL.Image
2222
import torch
2323
import torch.nn.functional as F
24-
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
24+
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
2525

2626
from ...image_processor import PipelineImageInput, VaeImageProcessor
27-
from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
27+
from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
2828
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
2929
from ...models.lora import adjust_lora_scale_text_encoder
3030
from ...schedulers import KarrasDiffusionSchedulers
@@ -241,7 +241,7 @@ def prepare_mask_and_masked_image(image, mask, height, width, return_image=False
241241

242242

243243
class StableDiffusionControlNetInpaintPipeline(
244-
DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
244+
DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, IPAdapterMixin, FromSingleFileMixin
245245
):
246246
r"""
247247
Pipeline for image inpainting using Stable Diffusion with ControlNet guidance.
@@ -251,6 +251,7 @@ class StableDiffusionControlNetInpaintPipeline(
251251
252252
The pipeline also inherits the following loading methods:
253253
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
254+
- [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
254255
255256
<Tip>
256257
@@ -288,7 +289,7 @@ class StableDiffusionControlNetInpaintPipeline(
288289
"""
289290

290291
model_cpu_offload_seq = "text_encoder->unet->vae"
291-
_optional_components = ["safety_checker", "feature_extractor"]
292+
_optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
292293
_exclude_from_cpu_offload = ["safety_checker"]
293294
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
294295

@@ -302,6 +303,7 @@ def __init__(
302303
scheduler: KarrasDiffusionSchedulers,
303304
safety_checker: StableDiffusionSafetyChecker,
304305
feature_extractor: CLIPImageProcessor,
306+
image_encoder: CLIPVisionModelWithProjection = None,
305307
requires_safety_checker: bool = True,
306308
):
307309
super().__init__()
@@ -334,6 +336,7 @@ def __init__(
334336
scheduler=scheduler,
335337
safety_checker=safety_checker,
336338
feature_extractor=feature_extractor,
339+
image_encoder=image_encoder,
337340
)
338341
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
339342
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
@@ -593,6 +596,20 @@ def encode_prompt(
593596

594597
return prompt_embeds, negative_prompt_embeds
595598

599+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
600+
def encode_image(self, image, device, num_images_per_prompt):
601+
dtype = next(self.image_encoder.parameters()).dtype
602+
603+
if not isinstance(image, torch.Tensor):
604+
image = self.feature_extractor(image, return_tensors="pt").pixel_values
605+
606+
image = image.to(device=device, dtype=dtype)
607+
image_embeds = self.image_encoder(image).image_embeds
608+
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
609+
610+
uncond_image_embeds = torch.zeros_like(image_embeds)
611+
return image_embeds, uncond_image_embeds
612+
596613
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
597614
def run_safety_checker(self, image, device, dtype):
598615
if self.safety_checker is None:
@@ -1053,6 +1070,7 @@ def __call__(
10531070
latents: Optional[torch.FloatTensor] = None,
10541071
prompt_embeds: Optional[torch.FloatTensor] = None,
10551072
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
1073+
ip_adapter_image: Optional[PipelineImageInput] = None,
10561074
output_type: Optional[str] = "pil",
10571075
return_dict: bool = True,
10581076
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -1131,6 +1149,7 @@ def __call__(
11311149
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
11321150
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
11331151
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
1152+
ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
11341153
output_type (`str`, *optional*, defaults to `"pil"`):
11351154
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
11361155
return_dict (`bool`, *optional*, defaults to `True`):
@@ -1264,6 +1283,11 @@ def __call__(
12641283
if self.do_classifier_free_guidance:
12651284
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
12661285

1286+
if ip_adapter_image is not None:
1287+
image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
1288+
if self.do_classifier_free_guidance:
1289+
image_embeds = torch.cat([negative_image_embeds, image_embeds])
1290+
12671291
# 4. Prepare image
12681292
if isinstance(controlnet, ControlNetModel):
12691293
control_image = self.prepare_control_image(
@@ -1299,7 +1323,7 @@ def __call__(
12991323
else:
13001324
assert False
13011325

1302-
# 4. Preprocess mask and image - resizes image and mask w.r.t height and width
1326+
# 4.1 Preprocess mask and image - resizes image and mask w.r.t height and width
13031327
init_image = self.image_processor.preprocess(image, height=height, width=width)
13041328
init_image = init_image.to(dtype=torch.float32)
13051329

@@ -1360,7 +1384,10 @@ def __call__(
13601384
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
13611385
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
13621386

1363-
# 7.1 Create tensor stating which controlnets to keep
1387+
# 7.1 Add image embeds for IP-Adapter
1388+
added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
1389+
1390+
# 7.2 Create tensor stating which controlnets to keep
13641391
controlnet_keep = []
13651392
for i in range(len(timesteps)):
13661393
keeps = [
@@ -1423,6 +1450,7 @@ def __call__(
14231450
cross_attention_kwargs=self.cross_attention_kwargs,
14241451
down_block_additional_residuals=down_block_res_samples,
14251452
mid_block_additional_residual=mid_block_res_sample,
1453+
added_cond_kwargs=added_cond_kwargs,
14261454
return_dict=False,
14271455
)[0]
14281456

tests/pipelines/controlnet/test_controlnet_inpaint.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@ def get_dummy_components(self):
132132
"tokenizer": tokenizer,
133133
"safety_checker": None,
134134
"feature_extractor": None,
135+
"image_encoder": None,
135136
}
136137
return components
137138

@@ -248,6 +249,7 @@ def get_dummy_components(self):
248249
"tokenizer": tokenizer,
249250
"safety_checker": None,
250251
"feature_extractor": None,
252+
"image_encoder": None,
251253
}
252254
return components
253255

@@ -342,6 +344,7 @@ def init_weights(m):
342344
"tokenizer": tokenizer,
343345
"safety_checker": None,
344346
"feature_extractor": None,
347+
"image_encoder": None,
345348
}
346349
return components
347350

0 commit comments

Comments
 (0)