
Commit cd91fc0

Re-add xformers enable to UNet2DCondition (huggingface#1627)
* finish
* fix
* Update tests/models/test_models_unet_2d.py
* style

Co-authored-by: Anton Lozhkov <[email protected]>
1 parent: ff65c2d

2 files changed: +48 -0 lines


src/diffusers/modeling_utils.py

Lines changed: 33 additions & 0 deletions
@@ -188,6 +188,39 @@ def disable_gradient_checkpointing(self):
         if self._supports_gradient_checkpointing:
             self.apply(partial(self._set_gradient_checkpointing, value=False))
 
+    def set_use_memory_efficient_attention_xformers(self, valid: bool) -> None:
+        # Recursively walk through all the children. Any child that exposes the
+        # set_use_memory_efficient_attention_xformers method receives the message.
+        def fn_recursive_set_mem_eff(module: torch.nn.Module):
+            if hasattr(module, "set_use_memory_efficient_attention_xformers"):
+                module.set_use_memory_efficient_attention_xformers(valid)
+
+            for child in module.children():
+                fn_recursive_set_mem_eff(child)
+
+        for module in self.children():
+            if isinstance(module, torch.nn.Module):
+                fn_recursive_set_mem_eff(module)
+
+    def enable_xformers_memory_efficient_attention(self):
+        r"""
+        Enable memory efficient attention as implemented in xformers.
+
+        When this option is enabled, you should observe lower GPU memory usage and a potential speed up at
+        inference time. A speed up at training time is not guaranteed.
+
+        Warning: when memory efficient attention and sliced attention are both enabled, memory efficient
+        attention takes precedence.
+        """
+        self.set_use_memory_efficient_attention_xformers(True)
+
+    def disable_xformers_memory_efficient_attention(self):
+        r"""
+        Disable memory efficient attention as implemented in xformers.
+        """
+        self.set_use_memory_efficient_attention_xformers(False)
+
     def save_pretrained(
         self,
         save_directory: Union[str, os.PathLike],
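
The recursive walk is the interesting part of this change: `ModelMixin` never needs to know where the attention blocks live inside a model; any nested submodule that exposes `set_use_memory_efficient_attention_xformers` is reached and toggled. Below is a minimal, self-contained sketch of that pattern using plain `torch.nn` modules. `DummyAttention` and `DummyModel` are hypothetical stand-ins for illustration only, not part of the diff; in diffusers the role of the leaf is played by the attention blocks.

```python
import torch


class DummyAttention(torch.nn.Module):
    # Hypothetical leaf module that opts in to the toggle by exposing the method.
    def __init__(self):
        super().__init__()
        self._use_memory_efficient_attention_xformers = False

    def set_use_memory_efficient_attention_xformers(self, valid: bool) -> None:
        # The sketch just records the flag; the real attention blocks switch
        # their attention computation to the xformers kernel.
        self._use_memory_efficient_attention_xformers = valid


class DummyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Nest the attention module one level down to show that the
        # recursive walk still reaches it.
        self.block = torch.nn.ModuleDict({"inner": DummyAttention()})

    def set_use_memory_efficient_attention_xformers(self, valid: bool) -> None:
        # Same pattern as the ModelMixin method above.
        def fn_recursive_set_mem_eff(module: torch.nn.Module):
            if hasattr(module, "set_use_memory_efficient_attention_xformers"):
                module.set_use_memory_efficient_attention_xformers(valid)
            for child in module.children():
                fn_recursive_set_mem_eff(child)

        for module in self.children():
            if isinstance(module, torch.nn.Module):
                fn_recursive_set_mem_eff(module)


model = DummyModel()
model.set_use_memory_efficient_attention_xformers(True)
assert model.block["inner"]._use_memory_efficient_attention_xformers
```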

tests/models/test_models_unet_2d.py

Lines changed: 15 additions & 0 deletions
@@ -30,6 +30,7 @@
     torch_all_close,
     torch_device,
 )
+from diffusers.utils.import_utils import is_xformers_available
 from parameterized import parameterized
 
 from ..test_modeling_common import ModelTesterMixin
 
@@ -255,6 +256,20 @@ def prepare_init_args_and_inputs_for_common(self):
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict
 
+    @unittest.skipIf(
+        torch_device != "cuda" or not is_xformers_available(),
+        reason="XFormers attention is only available with CUDA and `xformers` installed",
+    )
+    def test_xformers_enable_works(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict)
+
+        model.enable_xformers_memory_efficient_attention()
+
+        assert (
+            model.mid_block.attentions[0].transformer_blocks[0].attn1._use_memory_efficient_attention_xformers
+        ), "xformers is not enabled"
+
     @unittest.skipIf(torch_device == "mps", "Gradient checkpointing skipped on MPS")
     def test_gradient_checkpointing(self):
         # enable deterministic behavior for gradient checkpointing
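
For context, this is roughly how the re-added toggle is meant to be used. The sketch below is illustrative, not part of the commit; it assumes a CUDA machine with `xformers` installed, and the model id is only an example.

```python
from diffusers import UNet2DConditionModel

# The checkpoint below is only an example; any UNet2DConditionModel works.
unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"
).to("cuda")

# Turn memory efficient attention on for every attention block in the model...
unet.enable_xformers_memory_efficient_attention()

# ...and back off when the default PyTorch attention path is preferred.
unet.disable_xformers_memory_efficient_attention()
```

With CUDA and `xformers` available, the new test can be run with `pytest tests/models/test_models_unet_2d.py -k test_xformers_enable_works`.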
