@@ -288,28 +288,29 @@ def __init__(
         self._use_memory_efficient_attention_xformers = False
 
     def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
-        if not is_xformers_available():
-            raise ModuleNotFoundError(
-                "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
-                " xformers",
-                name="xformers",
-            )
-        elif not torch.cuda.is_available():
-            raise ValueError(
-                "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
-                " available for GPU "
-            )
-        else:
-            try:
-                # Make sure we can run the memory efficient attention
-                _ = xformers.ops.memory_efficient_attention(
-                    torch.randn((1, 2, 40), device="cuda"),
-                    torch.randn((1, 2, 40), device="cuda"),
-                    torch.randn((1, 2, 40), device="cuda"),
+        if use_memory_efficient_attention_xformers:
+            if not is_xformers_available():
+                raise ModuleNotFoundError(
+                    "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
+                    " xformers",
+                    name="xformers",
                 )
-            except Exception as e:
-                raise e
-        self._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
+            elif not torch.cuda.is_available():
+                raise ValueError(
+                    "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
+                    " available for GPU "
+                )
+            else:
+                try:
+                    # Make sure we can run the memory efficient attention
+                    _ = xformers.ops.memory_efficient_attention(
+                        torch.randn((1, 2, 40), device="cuda"),
+                        torch.randn((1, 2, 40), device="cuda"),
+                        torch.randn((1, 2, 40), device="cuda"),
+                    )
+                except Exception as e:
+                    raise e
+        self._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
 
     def reshape_heads_to_batch_dim(self, tensor):
         batch_size, seq_len, dim = tensor.shape
@@ -450,31 +451,32 @@ def __init__(
         self.norm3 = nn.LayerNorm(dim)
 
     def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
-        if not is_xformers_available():
-            print("Here is how to install it")
-            raise ModuleNotFoundError(
-                "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
-                " xformers",
-                name="xformers",
-            )
-        elif not torch.cuda.is_available():
-            raise ValueError(
-                "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
-                " available for GPU "
-            )
-        else:
-            try:
-                # Make sure we can run the memory efficient attention
-                _ = xformers.ops.memory_efficient_attention(
-                    torch.randn((1, 2, 40), device="cuda"),
-                    torch.randn((1, 2, 40), device="cuda"),
-                    torch.randn((1, 2, 40), device="cuda"),
+        if use_memory_efficient_attention_xformers:
+            if not is_xformers_available():
+                print("Here is how to install it")
+                raise ModuleNotFoundError(
+                    "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
+                    " xformers",
+                    name="xformers",
                 )
-            except Exception as e:
-                raise e
-        self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
-        if self.attn2 is not None:
-            self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
+            elif not torch.cuda.is_available():
+                raise ValueError(
+                    "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
+                    " available for GPU "
+                )
+            else:
+                try:
+                    # Make sure we can run the memory efficient attention
+                    _ = xformers.ops.memory_efficient_attention(
+                        torch.randn((1, 2, 40), device="cuda"),
+                        torch.randn((1, 2, 40), device="cuda"),
+                        torch.randn((1, 2, 40), device="cuda"),
+                    )
+                except Exception as e:
+                    raise e
+        self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
+        if self.attn2 is not None:
+            self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
 
     def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None):
         # 1. Self-Attention
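
Note (not part of the diff above): the change wraps the xformers and CUDA availability checks in "if use_memory_efficient_attention_xformers:", so calling the setter with False no longer raises when xformers or a GPU is missing. A minimal usage sketch under assumptions, the CrossAttention import path and constructor arguments below are illustrative guesses and are not taken from this commit:

    import torch
    from diffusers.models.attention import CrossAttention  # assumed import path

    # Assumed constructor arguments, for illustration only.
    attn = CrossAttention(query_dim=320, heads=8, dim_head=40)

    try:
        # Enabling runs the checks shown in the diff (is_xformers_available, CUDA, smoke test).
        attn.set_use_memory_efficient_attention_xformers(True)
    except (ModuleNotFoundError, ValueError):
        # Disabling skips those checks, so this is safe on CPU-only or xformers-free setups.
        attn.set_use_memory_efficient_attention_xformers(False)

The same setter on BasicTransformerBlock forwards the flag to attn1 and, when present, attn2.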