12 | 12 | # See the License for the specific language governing permissions and
13 | 13 | # limitations under the License.
14 | 14 | from dataclasses import dataclass
15 |    | -from typing import Optional, Tuple, Union
   | 15 | +from typing import List, Optional, Tuple, Union
16 | 16 |
17 | 17 | import torch
18 | 18 | import torch.nn as nn
@@ -229,28 +229,69 @@ def __init__(
229 | 229 |         self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
230 | 230 |
231 | 231 |     def set_attention_slice(self, slice_size):
232 |     | -        head_dims = self.config.attention_head_dim
233 |     | -        head_dims = [head_dims] if isinstance(head_dims, int) else head_dims
234 |     | -        if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims):
235 |     | -            raise ValueError(
236 |     | -                f"Make sure slice_size {slice_size} is a common divisor of "
237 |     | -                f"the number of heads used in cross_attention: {head_dims}"
238 |     | -            )
239 |     | -        if slice_size is not None and slice_size > min(head_dims):
240 |     | -            raise ValueError(
241 |     | -                f"slice_size {slice_size} has to be smaller or equal to "
242 |     | -                f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}"
243 |     | -            )
    | 232 | +        r"""
    | 233 | +        Enable sliced attention computation.
    | 234 | +
    | 235 | +        When this option is enabled, the attention module will split the input tensor into slices to compute
    | 236 | +        attention in several steps. This is useful for saving some memory in exchange for a small speed decrease.
    | 237 | +
    | 238 | +        Args:
    | 239 | +            slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
    | 240 | +                When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
    | 241 | +                `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
    | 242 | +                provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
    | 243 | +                must be a multiple of `slice_size`.
    | 244 | +        """
    | 245 | +        sliceable_head_dims = []
    | 246 | +
    | 247 | +        def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
    | 248 | +            if hasattr(module, "set_attention_slice"):
    | 249 | +                sliceable_head_dims.append(module.sliceable_head_dim)
    | 250 | +
    | 251 | +            for child in module.children():
    | 252 | +                fn_recursive_retrieve_slicable_dims(child)
    | 253 | +
    | 254 | +        # retrieve number of attention layers
    | 255 | +        for module in self.children():
    | 256 | +            fn_recursive_retrieve_slicable_dims(module)
244 | 257 |
245 |     | -        for block in self.down_blocks:
246 |     | -            if hasattr(block, "attentions") and block.attentions is not None:
247 |     | -                block.set_attention_slice(slice_size)
    | 258 | +        num_slicable_layers = len(sliceable_head_dims)
248 | 259 |
249 |     | -        self.mid_block.set_attention_slice(slice_size)
    | 260 | +        if slice_size == "auto":
    | 261 | +            # half the attention head size is usually a good trade-off between
    | 262 | +            # speed and memory
    | 263 | +            slice_size = [dim // 2 for dim in sliceable_head_dims]
    | 264 | +        elif slice_size == "max":
    | 265 | +            # make smallest slice possible
    | 266 | +            slice_size = num_slicable_layers * [1]
    | 267 | +
    | 268 | +        slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
    | 269 | +
    | 270 | +        if len(slice_size) != len(sliceable_head_dims):
    | 271 | +            raise ValueError(
    | 272 | +                f"You have provided {len(slice_size)} slice sizes, but {self.config} has {len(sliceable_head_dims)}"
    | 273 | +                f" different attention layers. Make sure to match `len(slice_size)` to {len(sliceable_head_dims)}."
    | 274 | +            )
250 | 275 |
251 |     | -        for block in self.up_blocks:
252 |     | -            if hasattr(block, "attentions") and block.attentions is not None:
253 |     | -                block.set_attention_slice(slice_size)
    | 276 | +        for i in range(len(slice_size)):
    | 277 | +            size = slice_size[i]
    | 278 | +            dim = sliceable_head_dims[i]
    | 279 | +            if size is not None and size > dim:
    | 280 | +                raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
    | 281 | +
    | 282 | +        # Recursively walk through all the children.
    | 283 | +        # Any child that exposes the set_attention_slice method
    | 284 | +        # gets the message
    | 285 | +        def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
    | 286 | +            if hasattr(module, "set_attention_slice"):
    | 287 | +                module.set_attention_slice(slice_size.pop())
    | 288 | +
    | 289 | +            for child in module.children():
    | 290 | +                fn_recursive_set_attention_slice(child, slice_size)
    | 291 | +
    | 292 | +        reversed_slice_size = list(reversed(slice_size))
    | 293 | +        for module in self.children():
    | 294 | +            fn_recursive_set_attention_slice(module, reversed_slice_size)
254 | 295 |
255 | 296 |     def _set_gradient_checkpointing(self, module, value=False):
256 | 297 |         if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)):
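
For context on how the new signature is meant to be called, here is a minimal usage sketch. It assumes the method above belongs to diffusers' `UNet2DConditionModel` (the class name is not visible in this hunk, so that is an assumption) and uses a default-constructed model purely for illustration:

```python
# Minimal usage sketch, not part of the diff. Assumes the method above lands on
# diffusers' UNet2DConditionModel; the default config is used only for illustration.
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel()  # any config with cross-attention layers works

# "auto": halve every sliceable head dimension, so attention runs in two steps per layer.
unet.set_attention_slice("auto")

# "max": one slice at a time in every attention layer, for maximum memory savings.
unet.set_attention_slice("max")

# A single int is broadcast to every sliceable layer; per the docstring it should
# divide each layer's attention_head_dim.
unet.set_attention_slice(2)
```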
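To make the `"auto"` / `"max"` / int normalization branch concrete, this standalone sketch (not diffusers code; `normalize_slice_size` and the example head dims are hypothetical) replicates that logic and shows the resulting per-layer slice sizes:

```python
from typing import List, Union


def normalize_slice_size(slice_size: Union[str, int, List[int]], sliceable_head_dims: List[int]) -> List[int]:
    """Standalone re-implementation of the normalization branch above, for illustration only."""
    num_sliceable_layers = len(sliceable_head_dims)
    if slice_size == "auto":
        # half of each head dimension: attention is computed in two steps per layer
        slice_size = [dim // 2 for dim in sliceable_head_dims]
    elif slice_size == "max":
        # smallest possible slices: one element at a time
        slice_size = num_sliceable_layers * [1]
    # a bare int is broadcast to one entry per sliceable layer
    slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
    if len(slice_size) != num_sliceable_layers:
        raise ValueError(f"expected {num_sliceable_layers} slice sizes, got {len(slice_size)}")
    return slice_size


# hypothetical head dims for three attention layers
dims = [8, 8, 4]
assert normalize_slice_size("auto", dims) == [4, 4, 2]
assert normalize_slice_size("max", dims) == [1, 1, 1]
assert normalize_slice_size(2, dims) == [2, 2, 2]
assert normalize_slice_size([4, 2, 2], dims) == [4, 2, 2]
```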
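The reversed-list plus `pop()` trick at the end is easy to miss: `pop()` removes from the tail of the list, so reversing the per-layer sizes first means modules visited in `self.children()` order still receive their sizes in the original order. A self-contained toy illustration (the `ToyAttention` / `ToyModel` classes are hypothetical stand-ins, not diffusers modules):

```python
from typing import List

import torch.nn as nn


class ToyAttention(nn.Module):
    """Stand-in for a sliceable attention module exposing set_attention_slice."""

    def __init__(self, sliceable_head_dim: int):
        super().__init__()
        self.sliceable_head_dim = sliceable_head_dim
        self.slice_size = None

    def set_attention_slice(self, slice_size: int):
        self.slice_size = slice_size


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.block_a = ToyAttention(8)
        self.block_b = ToyAttention(4)


def fn_recursive_set_attention_slice(module: nn.Module, slice_size: List[int]):
    # same pattern as in the diff: pop one size per sliceable module, then recurse
    if hasattr(module, "set_attention_slice"):
        module.set_attention_slice(slice_size.pop())
    for child in module.children():
        fn_recursive_set_attention_slice(child, slice_size)


model = ToyModel()
sizes = [4, 2]  # one entry per sliceable layer, in traversal order
reversed_sizes = list(reversed(sizes))
for module in model.children():
    fn_recursive_set_attention_slice(module, reversed_sizes)

# block_a was visited first and received the first size, block_b the second
assert model.block_a.slice_size == 4 and model.block_b.slice_size == 2
```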