
Commit bce65cd

[refactor] make set_attention_slice recursive (huggingface#1532)
* make attn slice recursive
* remove set_attention_slice from blocks
* fix copies
* make enable_attention_slicing base class method of DiffusionPipeline
* fix set_attention_slice
* fix set_attention_slice
* fix copies
* add tests
* up
* up
* up
* update
* up
* uP

Co-authored-by: Patrick von Platen <[email protected]>
1 parent e289998 commit bce65cd

20 files changed: +292 −609 lines
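In practice, this refactor means attention slicing is configured once on the pipeline and propagated recursively into every attention layer. A minimal usage sketch (the checkpoint name and prompt are illustrative placeholders, not taken from this commit):

# Minimal usage sketch; checkpoint and prompt are illustrative placeholders.
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# After this commit, enable_attention_slicing is a DiffusionPipeline base-class
# method: it forwards to each component's set_attention_slice, which walks
# child modules recursively.
pipe.enable_attention_slicing()

image = pipe("an astronaut riding a horse").images[0]

pipe.disable_attention_slicing()  # back to computing attention in one step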

src/diffusers/models/attention.py

+7 −8
@@ -174,10 +174,6 @@ def __init__(
         self.norm_out = nn.LayerNorm(inner_dim)
         self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)

-    def _set_attention_slice(self, slice_size):
-        for block in self.transformer_blocks:
-            block._set_attention_slice(slice_size)
-
     def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True):
         """
         Args:
@@ -448,10 +444,6 @@ def __init__(
                    f" correctly and a GPU is available: {e}"
                )

-    def _set_attention_slice(self, slice_size):
-        self.attn1._slice_size = slice_size
-        self.attn2._slice_size = slice_size
-
     def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
         if not is_xformers_available():
             print("Here is how to install it")
@@ -534,6 +526,7 @@ def __init__(
         # for slice_size > 0 the attention score computation
        # is split across the batch axis to save memory
        # You can set slice_size with `set_attention_slice`
+        self.sliceable_head_dim = heads
        self._slice_size = None
        self._use_memory_efficient_attention_xformers = False

@@ -559,6 +552,12 @@ def reshape_batch_dim_to_heads(self, tensor):
         tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
         return tensor

+    def set_attention_slice(self, slice_size):
+        if slice_size is not None and slice_size > self.sliceable_head_dim:
+            raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
+
+        self._slice_size = slice_size
+
     def forward(self, hidden_states, context=None, mask=None):
         batch_size, sequence_length, _ = hidden_states.shape
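For readers unfamiliar with `_slice_size`: sliced attention computes the score matrix for only a chunk of the `batch * heads` rows at a time, trading a little speed for a smaller peak memory footprint. A self-contained sketch of the idea (not the exact code of the attention class in attention.py; the tensor shapes are assumptions):

# Illustrative sketch of sliced attention; not the library's exact implementation.
# q, k, v are assumed to already be reshaped to (batch * heads, seq_len, head_dim).
import torch

def sliced_attention(q, k, v, slice_size):
    bh, seq_len, head_dim = q.shape
    scale = head_dim ** -0.5
    out = torch.empty_like(q)
    for start in range(0, bh, slice_size):
        end = min(start + slice_size, bh)
        # only `end - start` of the (seq_len x seq_len) score matrices are held in memory at once
        scores = torch.bmm(q[start:end], k[start:end].transpose(1, 2)) * scale
        out[start:end] = torch.bmm(scores.softmax(dim=-1), v[start:end])
    return out

q, k, v = (torch.randn(8, 64, 40) for _ in range(3))
assert torch.allclose(sliced_attention(q, k, v, 8), sliced_attention(q, k, v, 2), atol=1e-5)

Since each batch-head row is independent, the sliced result matches the unsliced one; only the peak memory of the score matrices changes.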

src/diffusers/models/unet_2d_blocks.py

-53
@@ -401,23 +401,6 @@ def __init__(
         self.attentions = nn.ModuleList(attentions)
         self.resnets = nn.ModuleList(resnets)

-    def set_attention_slice(self, slice_size):
-        head_dims = self.attn_num_head_channels
-        head_dims = [head_dims] if isinstance(head_dims, int) else head_dims
-        if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims):
-            raise ValueError(
-                f"Make sure slice_size {slice_size} is a common divisor of "
-                f"the number of heads used in cross_attention: {head_dims}"
-            )
-        if slice_size is not None and slice_size > min(head_dims):
-            raise ValueError(
-                f"slice_size {slice_size} has to be smaller or equal to "
-                f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}"
-            )
-
-        for attn in self.attentions:
-            attn._set_attention_slice(slice_size)
-
     def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
         hidden_states = self.resnets[0](hidden_states, temb)
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
@@ -595,23 +578,6 @@ def __init__(

         self.gradient_checkpointing = False

-    def set_attention_slice(self, slice_size):
-        head_dims = self.attn_num_head_channels
-        head_dims = [head_dims] if isinstance(head_dims, int) else head_dims
-        if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims):
-            raise ValueError(
-                f"Make sure slice_size {slice_size} is a common divisor of "
-                f"the number of heads used in cross_attention: {head_dims}"
-            )
-        if slice_size is not None and slice_size > min(head_dims):
-            raise ValueError(
-                f"slice_size {slice_size} has to be smaller or equal to "
-                f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}"
-            )
-
-        for attn in self.attentions:
-            attn._set_attention_slice(slice_size)
-
     def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
         output_states = ()

@@ -1190,25 +1156,6 @@ def __init__(

         self.gradient_checkpointing = False

-    def set_attention_slice(self, slice_size):
-        head_dims = self.attn_num_head_channels
-        head_dims = [head_dims] if isinstance(head_dims, int) else head_dims
-        if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims):
-            raise ValueError(
-                f"Make sure slice_size {slice_size} is a common divisor of "
-                f"the number of heads used in cross_attention: {head_dims}"
-            )
-        if slice_size is not None and slice_size > min(head_dims):
-            raise ValueError(
-                f"slice_size {slice_size} has to be smaller or equal to "
-                f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}"
-            )
-
-        for attn in self.attentions:
-            attn._set_attention_slice(slice_size)
-
-        self.gradient_checkpointing = False
-
     def forward(
         self,
         hidden_states,
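These block-level methods can be deleted because the UNet now discovers every sliceable attention module itself by walking `module.children()`: anything that exposes `set_attention_slice` is picked up no matter how deeply it is nested. A toy illustration of that dispatch pattern (ToyAttention and ToyBlock are made-up names for the example, not diffusers classes):

# Toy illustration of recursive discovery via module.children().
import torch.nn as nn

class ToyAttention(nn.Module):
    def __init__(self, heads):
        super().__init__()
        self.sliceable_head_dim = heads
        self._slice_size = None

    def set_attention_slice(self, slice_size):
        self._slice_size = slice_size

class ToyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        # the block itself no longer defines set_attention_slice
        self.attentions = nn.ModuleList([ToyAttention(8), ToyAttention(4)])

def collect_sliceable_dims(module):
    dims = [module.sliceable_head_dim] if hasattr(module, "set_attention_slice") else []
    for child in module.children():
        dims += collect_sliceable_dims(child)
    return dims

print(collect_sliceable_dims(ToyBlock()))  # [8, 4]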

src/diffusers/models/unet_2d_condition.py

+61 −20
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from dataclasses import dataclass
-from typing import Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union

 import torch
 import torch.nn as nn
@@ -229,28 +229,69 @@ def __init__(
         self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)

     def set_attention_slice(self, slice_size):
-        head_dims = self.config.attention_head_dim
-        head_dims = [head_dims] if isinstance(head_dims, int) else head_dims
-        if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims):
-            raise ValueError(
-                f"Make sure slice_size {slice_size} is a common divisor of "
-                f"the number of heads used in cross_attention: {head_dims}"
-            )
-        if slice_size is not None and slice_size > min(head_dims):
-            raise ValueError(
-                f"slice_size {slice_size} has to be smaller or equal to "
-                f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}"
-            )
+        r"""
+        Enable sliced attention computation.
+
+        When this option is enabled, the attention module will split the input tensor in slices, to compute attention
+        in several steps. This is useful to save some memory in exchange for a small speed decrease.
+
+        Args:
+            slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
+                When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
+                `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is
+                provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
+                must be a multiple of `slice_size`.
+        """
+        sliceable_head_dims = []
+
+        def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
+            if hasattr(module, "set_attention_slice"):
+                sliceable_head_dims.append(module.sliceable_head_dim)
+
+            for child in module.children():
+                fn_recursive_retrieve_slicable_dims(child)
+
+        # retrieve number of attention layers
+        for module in self.children():
+            fn_recursive_retrieve_slicable_dims(module)

-        for block in self.down_blocks:
-            if hasattr(block, "attentions") and block.attentions is not None:
-                block.set_attention_slice(slice_size)
+        num_slicable_layers = len(sliceable_head_dims)

-        self.mid_block.set_attention_slice(slice_size)
+        if slice_size == "auto":
+            # half the attention head size is usually a good trade-off between
+            # speed and memory
+            slice_size = [dim // 2 for dim in sliceable_head_dims]
+        elif slice_size == "max":
+            # make smallest slice possible
+            slice_size = num_slicable_layers * [1]
+
+        slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
+
+        if len(slice_size) != len(sliceable_head_dims):
+            raise ValueError(
+                f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
+                f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
+            )

-        for block in self.up_blocks:
-            if hasattr(block, "attentions") and block.attentions is not None:
-                block.set_attention_slice(slice_size)
+        for i in range(len(slice_size)):
+            size = slice_size[i]
+            dim = sliceable_head_dims[i]
+            if size is not None and size > dim:
+                raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
+
+        # Recursively walk through all the children.
+        # Any children which exposes the set_attention_slice method
+        # gets the message
+        def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
+            if hasattr(module, "set_attention_slice"):
+                module.set_attention_slice(slice_size.pop())
+
+            for child in module.children():
+                fn_recursive_set_attention_slice(child, slice_size)
+
+        reversed_slice_size = list(reversed(slice_size))
+        for module in self.children():
+            fn_recursive_set_attention_slice(module, reversed_slice_size)

     def _set_gradient_checkpointing(self, module, value=False):
         if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)):
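One subtlety in the new UNet method: per-layer slice sizes are handed out with `list.pop()`, which removes from the end of the list, so the list is reversed first to keep the sizes aligned with the discovery order of the layers. A standalone sketch of that pattern with toy values (the layer names are placeholders):

# Toy demonstration of the reversed-list + pop() hand-out used above.
slice_size = [4, 2, 1]  # one entry per sliceable layer, in discovery order
reversed_slice_size = list(reversed(slice_size))

assigned = []
for layer in ["down", "mid", "up"]:  # stand-ins for the recursively visited layers
    assigned.append((layer, reversed_slice_size.pop()))

print(assigned)  # [('down', 4), ('mid', 2), ('up', 1)] -- discovery order preserved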

src/diffusers/pipeline_utils.py

+31
@@ -839,3 +839,34 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
             module = getattr(self, module_name)
             if isinstance(module, torch.nn.Module):
                 fn_recursive_set_mem_eff(module)
+
+    def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
+        r"""
+        Enable sliced attention computation.
+
+        When this option is enabled, the attention module will split the input tensor in slices, to compute attention
+        in several steps. This is useful to save some memory in exchange for a small speed decrease.
+
+        Args:
+            slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
+                When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
+                `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is
+                provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
+                must be a multiple of `slice_size`.
+        """
+        self.set_attention_slice(slice_size)
+
+    def disable_attention_slicing(self):
+        r"""
+        Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
+        back to computing attention in one step.
+        """
+        # set slice_size = `None` to disable `attention slicing`
+        self.enable_attention_slicing(None)
+
+    def set_attention_slice(self, slice_size: Optional[int]):
+        module_names, _, _ = self.extract_init_dict(dict(self.config))
+        for module_name in module_names:
+            module = getattr(self, module_name)
+            if isinstance(module, torch.nn.Module) and hasattr(module, "set_attention_slice"):
+                module.set_attention_slice(slice_size)
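Because `DiffusionPipeline` now owns these methods, every pipeline inherits the same knob, which is why the copied implementations in the pipelines below can simply be deleted. A hedged sketch of the accepted values, assuming a pipeline instance `pipe` is already loaded:

# Assumes `pipe` is an already-loaded DiffusionPipeline subclass.
pipe.enable_attention_slicing()        # "auto": slice size = half of each layer's head count
pipe.enable_attention_slicing("max")   # most memory-frugal: one slice at a time per layer
pipe.enable_attention_slicing(2)       # the same fixed slice size for every layer
pipe.disable_attention_slicing()       # equivalent to set_attention_slice(None)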

src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py

-32
Original file line numberDiff line numberDiff line change
@@ -166,38 +166,6 @@ def __init__(
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
         self.register_to_config(requires_safety_checker=requires_safety_checker)

-    def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
-        r"""
-        Enable sliced attention computation.
-
-        When this option is enabled, the attention module will split the input tensor in slices, to compute attention
-        in several steps. This is useful to save some memory in exchange for a small speed decrease.
-
-        Args:
-            slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
-                When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
-                a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
-                `attention_head_dim` must be a multiple of `slice_size`.
-        """
-        if slice_size == "auto":
-            if isinstance(self.unet.config.attention_head_dim, int):
-                # half the attention head size is usually a good trade-off between
-                # speed and memory
-                slice_size = self.unet.config.attention_head_dim // 2
-            else:
-                # if `attention_head_dim` is a list, take the smallest head size
-                slice_size = min(self.unet.config.attention_head_dim)
-
-        self.unet.set_attention_slice(slice_size)
-
-    def disable_attention_slicing(self):
-        r"""
-        Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
-        back to computing attention in one step.
-        """
-        # set slice_size = `None` to disable `attention slicing`
-        self.enable_attention_slicing(None)
-
     def enable_vae_slicing(self):
         r"""
         Enable sliced VAE decoding.

src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py

-32
Original file line numberDiff line numberDiff line change
@@ -179,38 +179,6 @@ def __init__(
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
         self.register_to_config(requires_safety_checker=requires_safety_checker)

-    def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
-        r"""
-        Enable sliced attention computation.
-
-        When this option is enabled, the attention module will split the input tensor in slices, to compute attention
-        in several steps. This is useful to save some memory in exchange for a small speed decrease.
-
-        Args:
-            slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
-                When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
-                a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
-                `attention_head_dim` must be a multiple of `slice_size`.
-        """
-        if slice_size == "auto":
-            if isinstance(self.unet.config.attention_head_dim, int):
-                # half the attention head size is usually a good trade-off between
-                # speed and memory
-                slice_size = self.unet.config.attention_head_dim // 2
-            else:
-                # if `attention_head_dim` is a list, take the smallest head size
-                slice_size = min(self.unet.config.attention_head_dim)
-
-        self.unet.set_attention_slice(slice_size)
-
-    def disable_attention_slicing(self):
-        r"""
-        Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
-        back to computing attention in one step.
-        """
-        # set slice_size = `None` to disable `attention slicing`
-        self.enable_attention_slicing(None)
-
     def enable_sequential_cpu_offload(self, gpu_id=0):
         r"""
         Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,

src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py

-34
Original file line numberDiff line numberDiff line change
@@ -209,40 +209,6 @@ def __init__(
         )
         self.register_to_config(requires_safety_checker=requires_safety_checker)

-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing
-    def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
-        r"""
-        Enable sliced attention computation.
-
-        When this option is enabled, the attention module will split the input tensor in slices, to compute attention
-        in several steps. This is useful to save some memory in exchange for a small speed decrease.
-
-        Args:
-            slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
-                When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
-                a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
-                `attention_head_dim` must be a multiple of `slice_size`.
-        """
-        if slice_size == "auto":
-            if isinstance(self.unet.config.attention_head_dim, int):
-                # half the attention head size is usually a good trade-off between
-                # speed and memory
-                slice_size = self.unet.config.attention_head_dim // 2
-            else:
-                # if `attention_head_dim` is a list, take the smallest head size
-                slice_size = min(self.unet.config.attention_head_dim)
-
-        self.unet.set_attention_slice(slice_size)
-
-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing
-    def disable_attention_slicing(self):
-        r"""
-        Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
-        back to computing attention in one step.
-        """
-        # set slice_size = `None` to disable `attention slicing`
-        self.enable_attention_slicing(None)
-
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload
     def enable_sequential_cpu_offload(self, gpu_id=0):
         r"""
