
Commit a0cf607

fabiorigano, sayakpaul, and yiyixuxu authored

Multi-image masking for single IP Adapter (huggingface#7499)

* Support multi-image masking

---------

Co-authored-by: Sayak Paul <[email protected]>
Co-authored-by: YiYi Xu <[email protected]>

1 parent a341b53 commit a0cf607
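
For context on how the new API is used: with a single loaded IP-Adapter, the pipeline now accepts several reference images, one mask per image, and optionally one scale per image. Below is a minimal usage sketch that mirrors the test added in this commit; the image and mask URLs are hypothetical placeholders, while the model IDs, `IPAdapterMaskProcessor`, and the reshape convention come from the diff itself.

```python
import torch
from diffusers import StableDiffusionXLPipeline
from diffusers.image_processor import IPAdapterMaskProcessor
from diffusers.utils import load_image

pipeline = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipeline.load_ip_adapter(
    "h94/IP-Adapter", subfolder="sdxl_models", weight_name=["ip-adapter-plus-face_sdxl_vit-h.safetensors"]
)
# One scale per reference image; the outer list has one entry per adapter.
pipeline.set_ip_adapter_scale([[0.7, 0.7]])

# Hypothetical placeholder assets; any two face images and two binary masks work.
face1 = load_image("https://example.com/face1.png")
face2 = load_image("https://example.com/face2.png")
mask1 = load_image("https://example.com/mask1.png")
mask2 = load_image("https://example.com/mask2.png")

processor = IPAdapterMaskProcessor()
masks = processor.preprocess([mask1, mask2], height=1024, width=1024)
# preprocess returns [num_images, 1, H, W]; a single adapter expects one
# tensor of shape [1, num_images, H, W], as in the new test below.
masks = [masks.reshape(1, masks.shape[0], masks.shape[2], masks.shape[3])]

image = pipeline(
    prompt="two people in a park",
    ip_adapter_image=[[face1, face2]],  # nested list: one adapter, two images
    cross_attention_kwargs={"ip_adapter_masks": masks},
).images[0]
```

The nested lists exist because the pipeline supports multiple adapters; with one adapter, each argument is a one-element list.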

File tree

3 files changed: +208 −50 lines changed

src/diffusers/models/attention_processor.py

Lines changed: 128 additions & 50 deletions
```diff
@@ -13,7 +13,7 @@
 # limitations under the License.
 import inspect
 from importlib import import_module
-from typing import Callable, Optional, Union
+from typing import Callable, List, Optional, Union

 import torch
 import torch.nn.functional as F
```
```diff
@@ -2195,42 +2195,78 @@ def __call__(
         hidden_states = attn.batch_to_head_dim(hidden_states)

         if ip_adapter_masks is not None:
-            if not isinstance(ip_adapter_masks, torch.Tensor) or ip_adapter_masks.ndim != 4:
+            if not isinstance(ip_adapter_masks, List):
+                # for backward compatibility, we accept `ip_adapter_mask` as a tensor of shape [num_ip_adapter, 1, height, width]
+                ip_adapter_masks = list(ip_adapter_masks.unsqueeze(1))
+            if not (len(ip_adapter_masks) == len(self.scale) == len(ip_hidden_states)):
                 raise ValueError(
-                    " ip_adapter_mask should be a tensor with shape [num_ip_adapter, 1, height, width]."
-                    " Please use `IPAdapterMaskProcessor` to preprocess your mask"
-                )
-            if len(ip_adapter_masks) != len(self.scale):
-                raise ValueError(
-                    f"Number of ip_adapter_masks ({len(ip_adapter_masks)}) must match number of IP-Adapters ({len(self.scale)})"
+                    f"Length of ip_adapter_masks array ({len(ip_adapter_masks)}) must match "
+                    f"length of self.scale array ({len(self.scale)}) and number of ip_hidden_states "
+                    f"({len(ip_hidden_states)})"
                 )
+            else:
+                for index, (mask, scale, ip_state) in enumerate(zip(ip_adapter_masks, self.scale, ip_hidden_states)):
+                    if not isinstance(mask, torch.Tensor) or mask.ndim != 4:
+                        raise ValueError(
+                            "Each element of the ip_adapter_masks array should be a tensor with shape "
+                            "[1, num_images_for_ip_adapter, height, width]."
+                            " Please use `IPAdapterMaskProcessor` to preprocess your mask"
+                        )
+                    if mask.shape[1] != ip_state.shape[1]:
+                        raise ValueError(
+                            f"Number of masks ({mask.shape[1]}) does not match "
+                            f"number of ip images ({ip_state.shape[1]}) at index {index}"
+                        )
+                    if isinstance(scale, list) and not len(scale) == mask.shape[1]:
+                        raise ValueError(
+                            f"Number of masks ({mask.shape[1]}) does not match "
+                            f"number of scales ({len(scale)}) at index {index}"
+                        )
         else:
             ip_adapter_masks = [None] * len(self.scale)

         # for ip-adapter
         for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip(
             ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks
         ):
-            ip_key = to_k_ip(current_ip_hidden_states)
-            ip_value = to_v_ip(current_ip_hidden_states)
-
-            ip_key = attn.head_to_batch_dim(ip_key)
-            ip_value = attn.head_to_batch_dim(ip_value)
+            if mask is not None:
+                if not isinstance(scale, list):
+                    scale = [scale]
+
+                current_num_images = mask.shape[1]
+                for i in range(current_num_images):
+                    ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :])
+                    ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :])
+
+                    ip_key = attn.head_to_batch_dim(ip_key)
+                    ip_value = attn.head_to_batch_dim(ip_value)
+
+                    ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
+                    _current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
+                    _current_ip_hidden_states = attn.batch_to_head_dim(_current_ip_hidden_states)
+
+                    mask_downsample = IPAdapterMaskProcessor.downsample(
+                        mask[:, i, :, :],
+                        batch_size,
+                        _current_ip_hidden_states.shape[1],
+                        _current_ip_hidden_states.shape[2],
+                    )

-            ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
-            current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
-            current_ip_hidden_states = attn.batch_to_head_dim(current_ip_hidden_states)
+                    mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)

-            if mask is not None:
-                mask_downsample = IPAdapterMaskProcessor.downsample(
-                    mask, batch_size, current_ip_hidden_states.shape[1], current_ip_hidden_states.shape[2]
-                )
+                    hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample)
+            else:
+                ip_key = to_k_ip(current_ip_hidden_states)
+                ip_value = to_v_ip(current_ip_hidden_states)

-                mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
+                ip_key = attn.head_to_batch_dim(ip_key)
+                ip_value = attn.head_to_batch_dim(ip_value)

-                current_ip_hidden_states = current_ip_hidden_states * mask_downsample
+                ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
+                current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
+                current_ip_hidden_states = attn.batch_to_head_dim(current_ip_hidden_states)

-            hidden_states = hidden_states + scale * current_ip_hidden_states
+                hidden_states = hidden_states + scale * current_ip_hidden_states

         # linear proj
         hidden_states = attn.to_out[0](hidden_states)
```
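
The backward-compatibility branch above converts the old single-tensor convention into the new list convention. A quick sketch of what that conversion does to shapes (values are illustrative):

```python
import torch

# Old-style input: one 4D tensor covering all adapters, [num_ip_adapter, 1, H, W].
old_masks = torch.ones(2, 1, 64, 64)

# The back-compat path: unsqueeze, then split along dim 0 into a Python list.
new_masks = list(old_masks.unsqueeze(1))
print(len(new_masks), new_masks[0].shape)  # 2 torch.Size([1, 1, 64, 64])
```

Each element then matches the new per-adapter convention `[1, num_images_for_ip_adapter, height, width]`, with one image per adapter.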
```diff
@@ -2369,49 +2405,91 @@ def __call__(
         hidden_states = hidden_states.to(query.dtype)

         if ip_adapter_masks is not None:
-            if not isinstance(ip_adapter_masks, torch.Tensor) or ip_adapter_masks.ndim != 4:
-                raise ValueError(
-                    " ip_adapter_mask should be a tensor with shape [num_ip_adapter, 1, height, width]."
-                    " Please use `IPAdapterMaskProcessor` to preprocess your mask"
-                )
-            if len(ip_adapter_masks) != len(self.scale):
+            if not isinstance(ip_adapter_masks, List):
+                # for backward compatibility, we accept `ip_adapter_mask` as a tensor of shape [num_ip_adapter, 1, height, width]
+                ip_adapter_masks = list(ip_adapter_masks.unsqueeze(1))
+            if not (len(ip_adapter_masks) == len(self.scale) == len(ip_hidden_states)):
                 raise ValueError(
-                    f"Number of ip_adapter_masks ({len(ip_adapter_masks)}) must match number of IP-Adapters ({len(self.scale)})"
+                    f"Length of ip_adapter_masks array ({len(ip_adapter_masks)}) must match "
+                    f"length of self.scale array ({len(self.scale)}) and number of ip_hidden_states "
+                    f"({len(ip_hidden_states)})"
                 )
+            else:
+                for index, (mask, scale, ip_state) in enumerate(zip(ip_adapter_masks, self.scale, ip_hidden_states)):
+                    if not isinstance(mask, torch.Tensor) or mask.ndim != 4:
+                        raise ValueError(
+                            "Each element of the ip_adapter_masks array should be a tensor with shape "
+                            "[1, num_images_for_ip_adapter, height, width]."
+                            " Please use `IPAdapterMaskProcessor` to preprocess your mask"
+                        )
+                    if mask.shape[1] != ip_state.shape[1]:
+                        raise ValueError(
+                            f"Number of masks ({mask.shape[1]}) does not match "
+                            f"number of ip images ({ip_state.shape[1]}) at index {index}"
+                        )
+                    if isinstance(scale, list) and not len(scale) == mask.shape[1]:
+                        raise ValueError(
+                            f"Number of masks ({mask.shape[1]}) does not match "
+                            f"number of scales ({len(scale)}) at index {index}"
+                        )
         else:
             ip_adapter_masks = [None] * len(self.scale)

         # for ip-adapter
         for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip(
             ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks
         ):
-            ip_key = to_k_ip(current_ip_hidden_states)
-            ip_value = to_v_ip(current_ip_hidden_states)
+            if mask is not None:
+                if not isinstance(scale, list):
+                    scale = [scale]

-            ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-            ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+                current_num_images = mask.shape[1]
+                for i in range(current_num_images):
+                    ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :])
+                    ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :])

-            # the output of sdp = (batch, num_heads, seq_len, head_dim)
-            # TODO: add support for attn.scale when we move to Torch 2.1
-            current_ip_hidden_states = F.scaled_dot_product_attention(
-                query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
-            )
+                    ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+                    ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

-            current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
-                batch_size, -1, attn.heads * head_dim
-            )
-            current_ip_hidden_states = current_ip_hidden_states.to(query.dtype)
+                    # the output of sdp = (batch, num_heads, seq_len, head_dim)
+                    # TODO: add support for attn.scale when we move to Torch 2.1
+                    _current_ip_hidden_states = F.scaled_dot_product_attention(
+                        query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
+                    )

-            if mask is not None:
-                mask_downsample = IPAdapterMaskProcessor.downsample(
-                    mask, batch_size, current_ip_hidden_states.shape[1], current_ip_hidden_states.shape[2]
-                )
+                    _current_ip_hidden_states = _current_ip_hidden_states.transpose(1, 2).reshape(
+                        batch_size, -1, attn.heads * head_dim
+                    )
+                    _current_ip_hidden_states = _current_ip_hidden_states.to(query.dtype)

-                mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
+                    mask_downsample = IPAdapterMaskProcessor.downsample(
+                        mask[:, i, :, :],
+                        batch_size,
+                        _current_ip_hidden_states.shape[1],
+                        _current_ip_hidden_states.shape[2],
+                    )

-                current_ip_hidden_states = current_ip_hidden_states * mask_downsample
+                    mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
+                    hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample)
+            else:
+                ip_key = to_k_ip(current_ip_hidden_states)
+                ip_value = to_v_ip(current_ip_hidden_states)
+
+                ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+                ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+                # the output of sdp = (batch, num_heads, seq_len, head_dim)
+                # TODO: add support for attn.scale when we move to Torch 2.1
+                current_ip_hidden_states = F.scaled_dot_product_attention(
+                    query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
+                )
+
+                current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
+                    batch_size, -1, attn.heads * head_dim
+                )
+                current_ip_hidden_states = current_ip_hidden_states.to(query.dtype)

-            hidden_states = hidden_states + scale * current_ip_hidden_states
+                hidden_states = hidden_states + scale * current_ip_hidden_states

         # linear proj
         hidden_states = attn.to_out[0](hidden_states)
```
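
Conceptually, the masked branch in both processors computes one attention result per reference image and accumulates each into `hidden_states`, weighted by that image's scale and its downsampled mask. A standalone sketch of the accumulation under assumed toy shapes (not the diffusers API; the real mask comes from `IPAdapterMaskProcessor.downsample`):

```python
import torch

# Toy shapes: batch of 2, 16 query tokens, embed dim 8, 3 reference images.
batch, seq_len, dim, num_images = 2, 16, 8, 3
hidden_states = torch.zeros(batch, seq_len, dim)

# One attention output per reference image (stand-in for the per-image
# bmm/SDPA result) and one downsampled mask broadcastable over `dim`.
per_image_attn = [torch.randn(batch, seq_len, dim) for _ in range(num_images)]
masks = [torch.rand(batch, seq_len, 1) for _ in range(num_images)]
scales = [0.7, 0.7, 0.5]  # one scale per image

for attn_out, mask, s in zip(per_image_attn, masks, scales):
    # mirrors: hidden_states += scale[i] * (_current_ip_hidden_states * mask_downsample)
    hidden_states = hidden_states + s * (attn_out * mask)
```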

tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py

Lines changed: 30 additions & 0 deletions
```diff
@@ -544,3 +544,33 @@ def test_ip_adapter_multiple_masks(self):

         max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice)
         assert max_diff < 5e-4
+
+    def test_ip_adapter_multiple_masks_one_adapter(self):
+        image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder")
+        pipeline = StableDiffusionXLPipeline.from_pretrained(
+            "stabilityai/stable-diffusion-xl-base-1.0",
+            image_encoder=image_encoder,
+            torch_dtype=self.dtype,
+        )
+        pipeline.enable_model_cpu_offload()
+        pipeline.load_ip_adapter(
+            "h94/IP-Adapter", subfolder="sdxl_models", weight_name=["ip-adapter-plus-face_sdxl_vit-h.safetensors"]
+        )
+        pipeline.set_ip_adapter_scale([[0.7, 0.7]])
+
+        inputs = self.get_dummy_inputs(for_masks=True)
+        masks = inputs["cross_attention_kwargs"]["ip_adapter_masks"]
+        processor = IPAdapterMaskProcessor()
+        masks = processor.preprocess(masks)
+        masks = masks.reshape(1, masks.shape[0], masks.shape[2], masks.shape[3])
+        inputs["cross_attention_kwargs"]["ip_adapter_masks"] = [masks]
+        ip_images = inputs["ip_adapter_image"]
+        inputs["ip_adapter_image"] = [[image[0] for image in ip_images]]
+        images = pipeline(**inputs).images
+        image_slice = images[0, :3, :3, -1].flatten()
+        expected_slice = np.array(
+            [0.79474676, 0.7977683, 0.8013954, 0.7988008, 0.7970615, 0.8029355, 0.80614823, 0.8050743, 0.80627424]
+        )
+
+        max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice)
+        assert max_diff < 5e-4
```
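
The reshape in this test is the key step for the new convention: `IPAdapterMaskProcessor.preprocess` stacks masks along the batch dimension, while a single adapter expects all of its image masks along dimension 1. A small sketch of the shape change (sizes are illustrative):

```python
import torch

# `IPAdapterMaskProcessor().preprocess([m1, m2])` stacks masks as [num_masks, 1, H, W].
masks = torch.rand(2, 1, 64, 64)

# A single adapter with two reference images expects [1, num_images, H, W],
# hence the reshape used in the test above.
masks = masks.reshape(1, masks.shape[0], masks.shape[2], masks.shape[3])
print(masks.shape)  # torch.Size([1, 2, 64, 64])
```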

tests/pipelines/test_pipelines_common.py

Lines changed: 50 additions & 0 deletions
```diff
@@ -238,6 +238,11 @@ def test_pipeline_signature(self):
     def _get_dummy_image_embeds(self, cross_attention_dim: int = 32):
         return torch.randn((2, 1, cross_attention_dim), device=torch_device)

+    def _get_dummy_masks(self, input_size: int = 64):
+        _masks = torch.zeros((1, 1, input_size, input_size), device=torch_device)
+        _masks[0, :, :, : int(input_size / 2)] = 1
+        return _masks
+
     def _modify_inputs_for_ip_adapter_test(self, inputs: Dict[str, Any]):
         parameters = inspect.signature(self.pipeline_class.__call__).parameters
         if "image" in parameters.keys() and "strength" in parameters.keys():
@@ -365,6 +370,51 @@ def test_ip_adapter_cfg(self, expected_max_diff: float = 1e-4):

         assert out_cfg.shape == out_no_cfg.shape

+    def test_ip_adapter_masks(self, expected_max_diff: float = 1e-4):
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components).to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+        cross_attention_dim = pipe.unet.config.get("cross_attention_dim", 32)
+        sample_size = pipe.unet.config.get("sample_size", 32)
+        block_out_channels = pipe.vae.config.get("block_out_channels", [128, 256, 512, 512])
+        input_size = sample_size * (2 ** (len(block_out_channels) - 1))
+
+        # forward pass without ip adapter
+        inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
+        output_without_adapter = pipe(**inputs)[0]
+        output_without_adapter = output_without_adapter[0, -3:, -3:, -1].flatten()
+
+        adapter_state_dict = create_ip_adapter_state_dict(pipe.unet)
+        pipe.unet._load_ip_adapter_weights(adapter_state_dict)
+
+        # forward pass with single ip adapter and masks, but scale=0 which should have no effect
+        inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
+        inputs["ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)]
+        inputs["cross_attention_kwargs"] = {"ip_adapter_masks": [self._get_dummy_masks(input_size)]}
+        pipe.set_ip_adapter_scale(0.0)
+        output_without_adapter_scale = pipe(**inputs)[0]
+        output_without_adapter_scale = output_without_adapter_scale[0, -3:, -3:, -1].flatten()
+
+        # forward pass with single ip adapter and masks, but with scale of adapter weights
+        inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
+        inputs["ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)]
+        inputs["cross_attention_kwargs"] = {"ip_adapter_masks": [self._get_dummy_masks(input_size)]}
+        pipe.set_ip_adapter_scale(42.0)
+        output_with_adapter_scale = pipe(**inputs)[0]
+        output_with_adapter_scale = output_with_adapter_scale[0, -3:, -3:, -1].flatten()
+
+        max_diff_without_adapter_scale = np.abs(output_without_adapter_scale - output_without_adapter).max()
+        max_diff_with_adapter_scale = np.abs(output_with_adapter_scale - output_without_adapter).max()
+
+        self.assertLess(
+            max_diff_without_adapter_scale,
+            expected_max_diff,
+            "Output without ip-adapter must be same as normal inference",
+        )
+        self.assertGreater(
+            max_diff_with_adapter_scale, 1e-3, "Output with ip-adapter must be different from normal inference"
+        )
+

 class PipelineLatentTesterMixin:
     """
```
