
Commit 8e4733b

Only test for xformers when enabling them huggingface#1773 (huggingface#1776)
* only check for xformers when xformers are enabled
* only test for xformers when enabling them
1 parent 847daf2 commit 8e4733b
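
In practice the fix means that turning the flag off no longer requires xformers or a CUDA device to be present; the availability checks only run when the feature is being enabled. A rough usage sketch, assuming the usual diffusers pipeline helpers (enable_xformers_memory_efficient_attention / disable_xformers_memory_efficient_attention, which eventually call the per-module setters patched below):

from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# With this patch, disabling is safe on a CPU-only machine without xformers installed:
pipe.disable_xformers_memory_efficient_attention()

# Enabling still runs the import / CUDA checks and may raise ModuleNotFoundError or ValueError:
# pipe.enable_xformers_memory_efficient_attention()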

1 file changed: +47 −45 lines

src/diffusers/models/attention.py

Lines changed: 47 additions & 45 deletions
@@ -288,28 +288,29 @@ def __init__(
         self._use_memory_efficient_attention_xformers = False
 
     def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
-        if not is_xformers_available():
-            raise ModuleNotFoundError(
-                "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
-                " xformers",
-                name="xformers",
-            )
-        elif not torch.cuda.is_available():
-            raise ValueError(
-                "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
-                " available for GPU "
-            )
-        else:
-            try:
-                # Make sure we can run the memory efficient attention
-                _ = xformers.ops.memory_efficient_attention(
-                    torch.randn((1, 2, 40), device="cuda"),
-                    torch.randn((1, 2, 40), device="cuda"),
-                    torch.randn((1, 2, 40), device="cuda"),
+        if use_memory_efficient_attention_xformers:
+            if not is_xformers_available():
+                raise ModuleNotFoundError(
+                    "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
+                    " xformers",
+                    name="xformers",
                 )
-            except Exception as e:
-                raise e
-        self._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
+            elif not torch.cuda.is_available():
+                raise ValueError(
+                    "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
+                    " available for GPU "
+                )
+            else:
+                try:
+                    # Make sure we can run the memory efficient attention
+                    _ = xformers.ops.memory_efficient_attention(
+                        torch.randn((1, 2, 40), device="cuda"),
+                        torch.randn((1, 2, 40), device="cuda"),
+                        torch.randn((1, 2, 40), device="cuda"),
+                    )
+                except Exception as e:
+                    raise e
+        self._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
 
     def reshape_heads_to_batch_dim(self, tensor):
         batch_size, seq_len, dim = tensor.shape
@@ -450,31 +451,32 @@ def __init__(
         self.norm3 = nn.LayerNorm(dim)
 
     def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
-        if not is_xformers_available():
-            print("Here is how to install it")
-            raise ModuleNotFoundError(
-                "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
-                " xformers",
-                name="xformers",
-            )
-        elif not torch.cuda.is_available():
-            raise ValueError(
-                "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
-                " available for GPU "
-            )
-        else:
-            try:
-                # Make sure we can run the memory efficient attention
-                _ = xformers.ops.memory_efficient_attention(
-                    torch.randn((1, 2, 40), device="cuda"),
-                    torch.randn((1, 2, 40), device="cuda"),
-                    torch.randn((1, 2, 40), device="cuda"),
+        if use_memory_efficient_attention_xformers:
+            if not is_xformers_available():
+                print("Here is how to install it")
+                raise ModuleNotFoundError(
+                    "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
+                    " xformers",
+                    name="xformers",
                 )
-            except Exception as e:
-                raise e
-        self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
-        if self.attn2 is not None:
-            self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
+            elif not torch.cuda.is_available():
+                raise ValueError(
+                    "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
+                    " available for GPU "
+                )
+            else:
+                try:
+                    # Make sure we can run the memory efficient attention
+                    _ = xformers.ops.memory_efficient_attention(
+                        torch.randn((1, 2, 40), device="cuda"),
+                        torch.randn((1, 2, 40), device="cuda"),
+                        torch.randn((1, 2, 40), device="cuda"),
+                    )
+                except Exception as e:
+                    raise e
+        self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
+        if self.attn2 is not None:
+            self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
 
     def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None):
         # 1. Self-Attention
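
Both hunks patch the per-module setter; higher-level enable/disable helpers typically reach these setters by walking the module tree. A minimal sketch of such a traversal (illustrative only, not the diffusers implementation; the helper name is made up):

import torch.nn as nn

def set_xformers_flag_recursively(module: nn.Module, enabled: bool) -> None:
    # Call the per-module setter where it exists, then recurse into children
    # so every attention block in the model sees the same flag.
    if hasattr(module, "set_use_memory_efficient_attention_xformers"):
        module.set_use_memory_efficient_attention_xformers(enabled)
    for child in module.children():
        set_xformers_flag_recursively(child, enabled)

With the guard added in this commit, calling set_xformers_flag_recursively(model, False) is harmless on machines without xformers, while passing True still validates the environment.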
