
Commit a816a87

Authored by Benjamin Lefaudeux (blefaudeux)
[refactor] Making the xformers mem-efficient attention activation recursive (huggingface#1493)
* Moving the mem-efficient attention activation to the top + recursive
* black, too bad there's no pre-commit?

Co-authored-by: Benjamin Lefaudeux <[email protected]>

1 parent: f21415d · commit: a816a87

21 files changed: +37 -366 lines

examples/community/lpw_stable_diffusion.py (-18)

@@ -488,24 +488,6 @@ def __init__(
             feature_extractor=feature_extractor,
         )

-    def enable_xformers_memory_efficient_attention(self):
-        r"""
-        Enable memory efficient attention as implemented in xformers.
-
-        When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
-        time. Speed up at training time is not guaranteed.
-
-        Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
-        is used.
-        """
-        self.unet.set_use_memory_efficient_attention_xformers(True)
-
-    def disable_xformers_memory_efficient_attention(self):
-        r"""
-        Disable memory efficient attention as implemented in xformers.
-        """
-        self.unet.set_use_memory_efficient_attention_xformers(False)
-
     def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
         r"""
         Enable sliced attention computation.

examples/community/sd_text2img_k_diffusion.py (-18)

@@ -106,24 +106,6 @@ def set_sampler(self, scheduler_type: str):
         sampling = getattr(library, "sampling")
         self.sampler = getattr(sampling, scheduler_type)

-    def enable_xformers_memory_efficient_attention(self):
-        r"""
-        Enable memory efficient attention as implemented in xformers.
-
-        When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
-        time. Speed up at training time is not guaranteed.
-
-        Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
-        is used.
-        """
-        self.unet.set_use_memory_efficient_attention_xformers(True)
-
-    def disable_xformers_memory_efficient_attention(self):
-        r"""
-        Disable memory efficient attention as implemented in xformers.
-        """
-        self.unet.set_use_memory_efficient_attention_xformers(False)
-
     def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
         r"""
         Enable sliced attention computation.

examples/community/text_inpainting.py (-18)

@@ -183,24 +183,6 @@ def _execution_device(self):
                 return torch.device(module._hf_hook.execution_device)
         return self.device

-    def enable_xformers_memory_efficient_attention(self):
-        r"""
-        Enable memory efficient attention as implemented in xformers.
-
-        When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
-        time. Speed up at training time is not guaranteed.
-
-        Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
-        is used.
-        """
-        self.unet.set_use_memory_efficient_attention_xformers(True)
-
-    def disable_xformers_memory_efficient_attention(self):
-        r"""
-        Disable memory efficient attention as implemented in xformers.
-        """
-        self.unet.set_use_memory_efficient_attention_xformers(False)
-
     @torch.no_grad()
     def __call__(
         self,

src/diffusers/models/attention.py (+2 -14)

@@ -246,10 +246,6 @@ def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, retu

         return Transformer2DModelOutput(sample=output)

-    def _set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
-        for block in self.transformer_blocks:
-            block._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
-

 class AttentionBlock(nn.Module):
     """
@@ -414,7 +410,7 @@ def __init__(
         # if xformers is installed try to use memory_efficient_attention by default
         if is_xformers_available():
             try:
-                self._set_use_memory_efficient_attention_xformers(True)
+                self.set_use_memory_efficient_attention_xformers(True)
             except Exception as e:
                 warnings.warn(
                     "Could not enable memory efficient attention. Make sure xformers is installed"
@@ -425,7 +421,7 @@ def _set_attention_slice(self, slice_size):
         self.attn1._slice_size = slice_size
         self.attn2._slice_size = slice_size

-    def _set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
+    def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
         if not is_xformers_available():
             print("Here is how to install it")
             raise ModuleNotFoundError(
@@ -835,11 +831,3 @@ def forward(self, hidden_states, encoder_hidden_states, timestep=None, return_di
             return (output_states,)

         return Transformer2DModelOutput(sample=output_states)
-
-    def _set_attention_slice(self, slice_size):
-        for transformer in self.transformers:
-            transformer._set_attention_slice(slice_size)
-
-    def _set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
-        for transformer in self.transformers:
-            transformer._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
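These deletions are the crux of the refactor: container modules such as Transformer2DModel no longer forward the flag by hand, because the recursive walker added in src/diffusers/pipeline_utils.py (further down in this diff) reaches any submodule exposing the now-public set_use_memory_efficient_attention_xformers hook. A minimal sketch of the dispatch pattern, with made-up module names (LeafAttention, Container) standing in for the real diffusers classes:

import torch

class LeafAttention(torch.nn.Module):
    """Stand-in for BasicTransformerBlock: exposes the public hook."""
    def __init__(self):
        super().__init__()
        self.use_xformers = False

    def set_use_memory_efficient_attention_xformers(self, valid: bool):
        self.use_xformers = valid

class Container(torch.nn.Module):
    """Stand-in for a UNet block: needs no forwarding method of its own."""
    def __init__(self):
        super().__init__()
        self.blocks = torch.nn.ModuleList([LeafAttention(), LeafAttention()])

def fn_recursive_set_mem_eff(module: torch.nn.Module, valid: bool):
    # Mirrors the walker added in pipeline_utils.py: set the flag on any
    # module exposing the hook, then keep descending into its children.
    if hasattr(module, "set_use_memory_efficient_attention_xformers"):
        module.set_use_memory_efficient_attention_xformers(valid)
    for child in module.children():
        fn_recursive_set_mem_eff(child, valid)

root = Container()
fn_recursive_set_mem_eff(root, True)
assert all(block.use_xformers for block in root.blocks)

Because torch.nn.Module.children() already knows the module tree, adding a new attention-bearing block requires no bookkeeping: it just has to expose the hook.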

src/diffusers/models/unet_2d_blocks.py (-12)

@@ -418,10 +418,6 @@ def set_attention_slice(self, slice_size):
         for attn in self.attentions:
             attn._set_attention_slice(slice_size)

-    def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
-        for attn in self.attentions:
-            attn._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
-
     def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
         hidden_states = self.resnets[0](hidden_states, temb)
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
@@ -616,10 +612,6 @@ def set_attention_slice(self, slice_size):
         for attn in self.attentions:
             attn._set_attention_slice(slice_size)

-    def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
-        for attn in self.attentions:
-            attn._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
-
     def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
         output_states = ()

@@ -1217,10 +1209,6 @@ def set_attention_slice(self, slice_size):

         self.gradient_checkpointing = False

-    def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
-        for attn in self.attentions:
-            attn._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
-
     def forward(
         self,
         hidden_states,

src/diffusers/models/unet_2d_condition.py (-11)

@@ -252,17 +252,6 @@ def set_attention_slice(self, slice_size):
             if hasattr(block, "attentions") and block.attentions is not None:
                 block.set_attention_slice(slice_size)

-    def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
-        for block in self.down_blocks:
-            if hasattr(block, "attentions") and block.attentions is not None:
-                block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
-
-        self.mid_block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
-
-        for block in self.up_blocks:
-            if hasattr(block, "attentions") and block.attentions is not None:
-                block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
-
     def _set_gradient_checkpointing(self, module, value=False):
         if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)):
             module.gradient_checkpointing = value

src/diffusers/pipeline_utils.py (+35)

@@ -789,3 +789,38 @@ def progress_bar(self, iterable=None, total=None):

     def set_progress_bar_config(self, **kwargs):
         self._progress_bar_config = kwargs
+
+    def enable_xformers_memory_efficient_attention(self):
+        r"""
+        Enable memory efficient attention as implemented in xformers.
+
+        When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
+        time. Speed up at training time is not guaranteed.
+
+        Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
+        is used.
+        """
+        self.set_use_memory_efficient_attention_xformers(True)
+
+    def disable_xformers_memory_efficient_attention(self):
+        r"""
+        Disable memory efficient attention as implemented in xformers.
+        """
+        self.set_use_memory_efficient_attention_xformers(False)
+
+    def set_use_memory_efficient_attention_xformers(self, valid: bool) -> None:
+        # Recursively walk through all the children.
+        # Any children which exposes the set_use_memory_efficient_attention_xformers method
+        # gets the message
+        def fn_recursive_set_mem_eff(module: torch.nn.Module):
+            if hasattr(module, "set_use_memory_efficient_attention_xformers"):
+                module.set_use_memory_efficient_attention_xformers(valid)
+
+            for child in module.children():
+                fn_recursive_set_mem_eff(child)
+
+        module_names, _, _ = self.extract_init_dict(dict(self.config))
+        for module_name in module_names:
+            module = getattr(self, module_name)
+            if isinstance(module, torch.nn.Module):
+                fn_recursive_set_mem_eff(module)
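From the caller's side the public API is unchanged; the enable/disable pair simply moved to the shared DiffusionPipeline base class, and the toggle now reaches every registered torch.nn.Module (unet, vae, text_encoder, ...) rather than only self.unet. A usage sketch; the checkpoint name and prompt are illustrative, and it assumes xformers is installed and a CUDA device is available:

import torch
from diffusers import StableDiffusionPipeline

# Illustrative checkpoint; any pipeline deriving from DiffusionPipeline works.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Recursively flips the flag on every child module that exposes
# set_use_memory_efficient_attention_xformers.
pipe.enable_xformers_memory_efficient_attention()
image = pipe("a photo of an astronaut riding a horse").images[0]

# The same recursive walk turns it back off.
pipe.disable_xformers_memory_efficient_attention()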

src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py (-18)

@@ -166,24 +166,6 @@ def __init__(
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
         self.register_to_config(requires_safety_checker=requires_safety_checker)

-    def enable_xformers_memory_efficient_attention(self):
-        r"""
-        Enable memory efficient attention as implemented in xformers.
-
-        When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
-        time. Speed up at training time is not guaranteed.
-
-        Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
-        is used.
-        """
-        self.unet.set_use_memory_efficient_attention_xformers(True)
-
-    def disable_xformers_memory_efficient_attention(self):
-        r"""
-        Disable memory efficient attention as implemented in xformers.
-        """
-        self.unet.set_use_memory_efficient_attention_xformers(False)
-
     def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
         r"""
         Enable sliced attention computation.

src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py (-18)

@@ -251,24 +251,6 @@ def _execution_device(self):
                 return torch.device(module._hf_hook.execution_device)
         return self.device

-    def enable_xformers_memory_efficient_attention(self):
-        r"""
-        Enable memory efficient attention as implemented in xformers.
-
-        When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
-        time. Speed up at training time is not guaranteed.
-
-        Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
-        is used.
-        """
-        self.unet.set_use_memory_efficient_attention_xformers(True)
-
-    def disable_xformers_memory_efficient_attention(self):
-        r"""
-        Disable memory efficient attention as implemented in xformers.
-        """
-        self.unet.set_use_memory_efficient_attention_xformers(False)
-
     def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
         r"""
         Encodes the prompt into text encoder hidden states.

src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py (-20)

@@ -285,26 +285,6 @@ def _execution_device(self):
                 return torch.device(module._hf_hook.execution_device)
         return self.device

-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_xformers_memory_efficient_attention
-    def enable_xformers_memory_efficient_attention(self):
-        r"""
-        Enable memory efficient attention as implemented in xformers.
-
-        When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
-        time. Speed up at training time is not guaranteed.
-
-        Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
-        is used.
-        """
-        self.unet.set_use_memory_efficient_attention_xformers(True)
-
-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_xformers_memory_efficient_attention
-    def disable_xformers_memory_efficient_attention(self):
-        r"""
-        Disable memory efficient attention as implemented in xformers.
-        """
-        self.unet.set_use_memory_efficient_attention_xformers(False)
-
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
     def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
         r"""

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py (-18)

@@ -165,24 +165,6 @@ def __init__(
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
         self.register_to_config(requires_safety_checker=requires_safety_checker)

-    def enable_xformers_memory_efficient_attention(self):
-        r"""
-        Enable memory efficient attention as implemented in xformers.
-
-        When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
-        time. Speed up at training time is not guaranteed.
-
-        Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
-        is used.
-        """
-        self.unet.set_use_memory_efficient_attention_xformers(True)
-
-    def disable_xformers_memory_efficient_attention(self):
-        r"""
-        Disable memory efficient attention as implemented in xformers.
-        """
-        self.unet.set_use_memory_efficient_attention_xformers(False)
-
     def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
         r"""
         Enable sliced attention computation.

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py (-20)

@@ -134,26 +134,6 @@ def __init__(
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
         self.register_to_config(requires_safety_checker=requires_safety_checker)

-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_xformers_memory_efficient_attention
-    def enable_xformers_memory_efficient_attention(self):
-        r"""
-        Enable memory efficient attention as implemented in xformers.
-
-        When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
-        time. Speed up at training time is not guaranteed.
-
-        Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
-        is used.
-        """
-        self.unet.set_use_memory_efficient_attention_xformers(True)
-
-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_xformers_memory_efficient_attention
-    def disable_xformers_memory_efficient_attention(self):
-        r"""
-        Disable memory efficient attention as implemented in xformers.
-        """
-        self.unet.set_use_memory_efficient_attention_xformers(False)
-
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing
     def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
         r"""

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py (-20)

@@ -254,26 +254,6 @@ def _execution_device(self):
                 return torch.device(module._hf_hook.execution_device)
         return self.device

-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_xformers_memory_efficient_attention
-    def enable_xformers_memory_efficient_attention(self):
-        r"""
-        Enable memory efficient attention as implemented in xformers.
-
-        When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
-        time. Speed up at training time is not guaranteed.
-
-        Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
-        is used.
-        """
-        self.unet.set_use_memory_efficient_attention_xformers(True)
-
-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_xformers_memory_efficient_attention
-    def disable_xformers_memory_efficient_attention(self):
-        r"""
-        Disable memory efficient attention as implemented in xformers.
-        """
-        self.unet.set_use_memory_efficient_attention_xformers(False)
-
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
     def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
         r"""

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py (-20)

@@ -300,26 +300,6 @@ def enable_sequential_cpu_offload(self, gpu_id=0):
            # fix by only offloading self.safety_checker for now
            cpu_offload(self.safety_checker.vision_model, device)

-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_xformers_memory_efficient_attention
-    def enable_xformers_memory_efficient_attention(self):
-        r"""
-        Enable memory efficient attention as implemented in xformers.
-
-        When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
-        time. Speed up at training time is not guaranteed.
-
-        Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
-        is used.
-        """
-        self.unet.set_use_memory_efficient_attention_xformers(True)
-
-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_xformers_memory_efficient_attention
-    def disable_xformers_memory_efficient_attention(self):
-        r"""
-        Disable memory efficient attention as implemented in xformers.
-        """
-        self.unet.set_use_memory_efficient_attention_xformers(False)
-
     @property
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
     def _execution_device(self):
