13 | 13 | # limitations under the License. |
14 | 14 | import inspect |
15 | 15 | from importlib import import_module |
16 | | -from typing import Callable, Optional, Union |
| 16 | +from typing import Callable, List, Optional, Union |
17 | 17 |
|
18 | 18 | import torch |
19 | 19 | import torch.nn.functional as F |
@@ -2195,42 +2195,78 @@ def __call__( |
2195 | 2195 | hidden_states = attn.batch_to_head_dim(hidden_states) |
2196 | 2196 |
|
2197 | 2197 | if ip_adapter_masks is not None: |
2198 | | - if not isinstance(ip_adapter_masks, torch.Tensor) or ip_adapter_masks.ndim != 4: |
| 2198 | + if not isinstance(ip_adapter_masks, List): |
| 2199 | + # for backward compatibility, we accept `ip_adapter_mask` as a tensor of shape [num_ip_adapter, 1, height, width] |
| 2200 | + ip_adapter_masks = list(ip_adapter_masks.unsqueeze(1)) |
| 2201 | + if not (len(ip_adapter_masks) == len(self.scale) == len(ip_hidden_states)): |
2199 | 2202 | raise ValueError( |
2200 | | - " ip_adapter_mask should be a tensor with shape [num_ip_adapter, 1, height, width]." |
2201 | | - " Please use `IPAdapterMaskProcessor` to preprocess your mask" |
2202 | | - ) |
2203 | | - if len(ip_adapter_masks) != len(self.scale): |
2204 | | - raise ValueError( |
2205 | | - f"Number of ip_adapter_masks ({len(ip_adapter_masks)}) must match number of IP-Adapters ({len(self.scale)})" |
| 2203 | + f"Length of ip_adapter_masks array ({len(ip_adapter_masks)}) must match " |
| 2204 | + f"length of self.scale array ({len(self.scale)}) and number of ip_hidden_states " |
| 2205 | + f"({len(ip_hidden_states)})" |
2206 | 2206 | ) |
| 2207 | + else: |
| 2208 | + for index, (mask, scale, ip_state) in enumerate(zip(ip_adapter_masks, self.scale, ip_hidden_states)): |
| 2209 | + if not isinstance(mask, torch.Tensor) or mask.ndim != 4: |
| 2210 | + raise ValueError( |
| 2211 | + "Each element of the ip_adapter_masks array should be a tensor with shape " |
| 2212 | + "[1, num_images_for_ip_adapter, height, width]." |
| 2213 | + " Please use `IPAdapterMaskProcessor` to preprocess your mask" |
| 2214 | + ) |
| 2215 | + if mask.shape[1] != ip_state.shape[1]: |
| 2216 | + raise ValueError( |
| 2217 | + f"Number of masks ({mask.shape[1]}) does not match " |
| 2218 | + f"number of ip images ({ip_state.shape[1]}) at index {index}" |
| 2219 | + ) |
| 2220 | + if isinstance(scale, list) and not len(scale) == mask.shape[1]: |
| 2221 | + raise ValueError( |
| 2222 | + f"Number of masks ({mask.shape[1]}) does not match " |
| 2223 | + f"number of scales ({len(scale)}) at index {index}" |
| 2224 | + ) |
2207 | 2225 | else: |
2208 | 2226 | ip_adapter_masks = [None] * len(self.scale) |
2209 | 2227 |
|
2210 | 2228 | # for ip-adapter |
2211 | 2229 | for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip( |
2212 | 2230 | ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks |
2213 | 2231 | ): |
2214 | | - ip_key = to_k_ip(current_ip_hidden_states) |
2215 | | - ip_value = to_v_ip(current_ip_hidden_states) |
2216 | | - |
2217 | | - ip_key = attn.head_to_batch_dim(ip_key) |
2218 | | - ip_value = attn.head_to_batch_dim(ip_value) |
| 2232 | + if mask is not None: |
| 2233 | + if not isinstance(scale, list): |
| 2234 | + scale = [scale] |
| 2235 | + |
| 2236 | + current_num_images = mask.shape[1] |
| 2237 | + for i in range(current_num_images): |
| 2238 | + ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :]) |
| 2239 | + ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :]) |
| 2240 | + |
| 2241 | + ip_key = attn.head_to_batch_dim(ip_key) |
| 2242 | + ip_value = attn.head_to_batch_dim(ip_value) |
| 2243 | + |
| 2244 | + ip_attention_probs = attn.get_attention_scores(query, ip_key, None) |
| 2245 | + _current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value) |
| 2246 | + _current_ip_hidden_states = attn.batch_to_head_dim(_current_ip_hidden_states) |
| 2247 | + |
| 2248 | + mask_downsample = IPAdapterMaskProcessor.downsample( |
| 2249 | + mask[:, i, :, :], |
| 2250 | + batch_size, |
| 2251 | + _current_ip_hidden_states.shape[1], |
| 2252 | + _current_ip_hidden_states.shape[2], |
| 2253 | + ) |
2219 | 2254 |
|
2220 | | - ip_attention_probs = attn.get_attention_scores(query, ip_key, None) |
2221 | | - current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value) |
2222 | | - current_ip_hidden_states = attn.batch_to_head_dim(current_ip_hidden_states) |
| 2255 | + mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device) |
2223 | 2256 |
|
2224 | | - if mask is not None: |
2225 | | - mask_downsample = IPAdapterMaskProcessor.downsample( |
2226 | | - mask, batch_size, current_ip_hidden_states.shape[1], current_ip_hidden_states.shape[2] |
2227 | | - ) |
| 2257 | + hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample) |
| 2258 | + else: |
| 2259 | + ip_key = to_k_ip(current_ip_hidden_states) |
| 2260 | + ip_value = to_v_ip(current_ip_hidden_states) |
2228 | 2261 |
|
2229 | | - mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device) |
| 2262 | + ip_key = attn.head_to_batch_dim(ip_key) |
| 2263 | + ip_value = attn.head_to_batch_dim(ip_value) |
2230 | 2264 |
|
2231 | | - current_ip_hidden_states = current_ip_hidden_states * mask_downsample |
| 2265 | + ip_attention_probs = attn.get_attention_scores(query, ip_key, None) |
| 2266 | + current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value) |
| 2267 | + current_ip_hidden_states = attn.batch_to_head_dim(current_ip_hidden_states) |
2232 | 2268 |
|
2233 | | - hidden_states = hidden_states + scale * current_ip_hidden_states |
| 2269 | + hidden_states = hidden_states + scale * current_ip_hidden_states |
2234 | 2270 |
|
2235 | 2271 | # linear proj |
2236 | 2272 | hidden_states = attn.to_out[0](hidden_states) |
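Note on the first hunk: the backward-compatibility branch converts the legacy single-tensor mask format into the new per-adapter list before validation. A minimal sketch of that conversion follows; the tensor sizes are invented for illustration, while the `unsqueeze(1)` / `list(...)` steps and the resulting `[1, num_images_for_ip_adapter, height, width]` shape are taken from the diff above.

```python
import torch

# Invented sizes, for illustration only.
num_ip_adapters, height, width = 2, 64, 64

# Legacy input: a single tensor of shape [num_ip_adapter, 1, height, width].
legacy_masks = torch.ones(num_ip_adapters, 1, height, width)

# The backward-compatibility branch in the diff:
# unsqueeze(1) -> [num_ip_adapter, 1, 1, height, width]; list() then splits along
# dim 0, yielding one [1, num_images_for_ip_adapter (=1), height, width] tensor per adapter.
ip_adapter_masks = list(legacy_masks.unsqueeze(1))

assert len(ip_adapter_masks) == num_ip_adapters
assert ip_adapter_masks[0].shape == (1, 1, height, width)
```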
@@ -2369,49 +2405,91 @@ def __call__( |
2369 | 2405 | hidden_states = hidden_states.to(query.dtype) |
2370 | 2406 |
|
2371 | 2407 | if ip_adapter_masks is not None: |
2372 | | - if not isinstance(ip_adapter_masks, torch.Tensor) or ip_adapter_masks.ndim != 4: |
2373 | | - raise ValueError( |
2374 | | - " ip_adapter_mask should be a tensor with shape [num_ip_adapter, 1, height, width]." |
2375 | | - " Please use `IPAdapterMaskProcessor` to preprocess your mask" |
2376 | | - ) |
2377 | | - if len(ip_adapter_masks) != len(self.scale): |
| 2408 | + if not isinstance(ip_adapter_masks, List): |
| 2409 | + # for backward compatibility, we accept `ip_adapter_mask` as a tensor of shape [num_ip_adapter, 1, height, width] |
| 2410 | + ip_adapter_masks = list(ip_adapter_masks.unsqueeze(1)) |
| 2411 | + if not (len(ip_adapter_masks) == len(self.scale) == len(ip_hidden_states)): |
2378 | 2412 | raise ValueError( |
2379 | | - f"Number of ip_adapter_masks ({len(ip_adapter_masks)}) must match number of IP-Adapters ({len(self.scale)})" |
| 2413 | + f"Length of ip_adapter_masks array ({len(ip_adapter_masks)}) must match " |
| 2414 | + f"length of self.scale array ({len(self.scale)}) and number of ip_hidden_states " |
| 2415 | + f"({len(ip_hidden_states)})" |
2380 | 2416 | ) |
| 2417 | + else: |
| 2418 | + for index, (mask, scale, ip_state) in enumerate(zip(ip_adapter_masks, self.scale, ip_hidden_states)): |
| 2419 | + if not isinstance(mask, torch.Tensor) or mask.ndim != 4: |
| 2420 | + raise ValueError( |
| 2421 | + "Each element of the ip_adapter_masks array should be a tensor with shape " |
| 2422 | + "[1, num_images_for_ip_adapter, height, width]." |
| 2423 | + " Please use `IPAdapterMaskProcessor` to preprocess your mask" |
| 2424 | + ) |
| 2425 | + if mask.shape[1] != ip_state.shape[1]: |
| 2426 | + raise ValueError( |
| 2427 | + f"Number of masks ({mask.shape[1]}) does not match " |
| 2428 | + f"number of ip images ({ip_state.shape[1]}) at index {index}" |
| 2429 | + ) |
| 2430 | + if isinstance(scale, list) and not len(scale) == mask.shape[1]: |
| 2431 | + raise ValueError( |
| 2432 | + f"Number of masks ({mask.shape[1]}) does not match " |
| 2433 | + f"number of scales ({len(scale)}) at index {index}" |
| 2434 | + ) |
2381 | 2435 | else: |
2382 | 2436 | ip_adapter_masks = [None] * len(self.scale) |
2383 | 2437 |
|
2384 | 2438 | # for ip-adapter |
2385 | 2439 | for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip( |
2386 | 2440 | ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks |
2387 | 2441 | ): |
2388 | | - ip_key = to_k_ip(current_ip_hidden_states) |
2389 | | - ip_value = to_v_ip(current_ip_hidden_states) |
| 2442 | + if mask is not None: |
| 2443 | + if not isinstance(scale, list): |
| 2444 | + scale = [scale] |
2390 | 2445 |
|
2391 | | - ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) |
2392 | | - ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) |
| 2446 | + current_num_images = mask.shape[1] |
| 2447 | + for i in range(current_num_images): |
| 2448 | + ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :]) |
| 2449 | + ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :]) |
2393 | 2450 |
|
2394 | | - # the output of sdp = (batch, num_heads, seq_len, head_dim) |
2395 | | - # TODO: add support for attn.scale when we move to Torch 2.1 |
2396 | | - current_ip_hidden_states = F.scaled_dot_product_attention( |
2397 | | - query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False |
2398 | | - ) |
| 2451 | + ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) |
| 2452 | + ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) |
2399 | 2453 |
|
2400 | | - current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape( |
2401 | | - batch_size, -1, attn.heads * head_dim |
2402 | | - ) |
2403 | | - current_ip_hidden_states = current_ip_hidden_states.to(query.dtype) |
| 2454 | + # the output of sdp = (batch, num_heads, seq_len, head_dim) |
| 2455 | + # TODO: add support for attn.scale when we move to Torch 2.1 |
| 2456 | + _current_ip_hidden_states = F.scaled_dot_product_attention( |
| 2457 | + query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False |
| 2458 | + ) |
2404 | 2459 |
|
2405 | | - if mask is not None: |
2406 | | - mask_downsample = IPAdapterMaskProcessor.downsample( |
2407 | | - mask, batch_size, current_ip_hidden_states.shape[1], current_ip_hidden_states.shape[2] |
2408 | | - ) |
| 2460 | + _current_ip_hidden_states = _current_ip_hidden_states.transpose(1, 2).reshape( |
| 2461 | + batch_size, -1, attn.heads * head_dim |
| 2462 | + ) |
| 2463 | + _current_ip_hidden_states = _current_ip_hidden_states.to(query.dtype) |
2409 | 2464 |
|
2410 | | - mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device) |
| 2465 | + mask_downsample = IPAdapterMaskProcessor.downsample( |
| 2466 | + mask[:, i, :, :], |
| 2467 | + batch_size, |
| 2468 | + _current_ip_hidden_states.shape[1], |
| 2469 | + _current_ip_hidden_states.shape[2], |
| 2470 | + ) |
2411 | 2471 |
|
2412 | | - current_ip_hidden_states = current_ip_hidden_states * mask_downsample |
| 2472 | + mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device) |
| 2473 | + hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample) |
| 2474 | + else: |
| 2475 | + ip_key = to_k_ip(current_ip_hidden_states) |
| 2476 | + ip_value = to_v_ip(current_ip_hidden_states) |
| 2477 | + |
| 2478 | + ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) |
| 2479 | + ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) |
| 2480 | + |
| 2481 | + # the output of sdp = (batch, num_heads, seq_len, head_dim) |
| 2482 | + # TODO: add support for attn.scale when we move to Torch 2.1 |
| 2483 | + current_ip_hidden_states = F.scaled_dot_product_attention( |
| 2484 | + query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False |
| 2485 | + ) |
| 2486 | + |
| 2487 | + current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape( |
| 2488 | + batch_size, -1, attn.heads * head_dim |
| 2489 | + ) |
| 2490 | + current_ip_hidden_states = current_ip_hidden_states.to(query.dtype) |
2413 | 2491 |
|
2414 | | - hidden_states = hidden_states + scale * current_ip_hidden_states |
| 2492 | + hidden_states = hidden_states + scale * current_ip_hidden_states |
2415 | 2493 |
|
2416 | 2494 | # linear proj |
2417 | 2495 | hidden_states = attn.to_out[0](hidden_states) |
|
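Both hunks now expect `ip_adapter_masks` to arrive as a list with one tensor per IP-Adapter, each shaped `[1, num_images_for_ip_adapter, height, width]`, with `scale` optionally given per image. A hedged usage sketch is below: the `IPAdapterMaskProcessor` import and its inherited `preprocess` call come from the diffusers image-processor module, and the `cross_attention_kwargs` key matches the processors above; the dummy masks, the 1024x1024 size, and the commented-out pipeline call are placeholders, not part of this commit.

```python
from PIL import Image
from diffusers.image_processor import IPAdapterMaskProcessor

# Dummy grayscale masks standing in for real per-region masks (illustration only).
mask_face = Image.new("L", (1024, 1024), 255)
mask_background = Image.new("L", (1024, 1024), 0)

processor = IPAdapterMaskProcessor()
masks = processor.preprocess([mask_face, mask_background], height=1024, width=1024)

# preprocess returns [num_masks, 1, height, width]; the attention processors above
# expect one tensor per IP-Adapter of shape [1, num_images_for_ip_adapter, height, width],
# so the two masks are folded into the num_images dimension of a single list entry.
ip_adapter_masks = [masks.reshape(1, masks.shape[0], masks.shape[2], masks.shape[3])]

# The list is then forwarded through the pipeline call, e.g. (pipeline and image
# variables assumed, not defined in this commit):
# pipeline(prompt, ip_adapter_image=[[face_image, background_image]],
#          cross_attention_kwargs={"ip_adapter_masks": ip_adapter_masks})
```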