@@ -105,6 +105,10 @@ def __init__(
     def set_use_memory_efficient_attention_xformers(
         self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
     ):
+        is_lora = hasattr(self, "processor") and isinstance(
+            self.processor, (LoRACrossAttnProcessor, LoRAXFormersCrossAttnProcessor)
+        )
+
         if use_memory_efficient_attention_xformers:
             if self.added_kv_proj_dim is not None:
                 # TODO(Anton, Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP
@@ -138,9 +142,28 @@ def set_use_memory_efficient_attention_xformers(
                 except Exception as e:
                     raise e

-            processor = XFormersCrossAttnProcessor(attention_op=attention_op)
+            if is_lora:
+                processor = LoRAXFormersCrossAttnProcessor(
+                    hidden_size=self.processor.hidden_size,
+                    cross_attention_dim=self.processor.cross_attention_dim,
+                    rank=self.processor.rank,
+                    attention_op=attention_op,
+                )
+                processor.load_state_dict(self.processor.state_dict())
+                processor.to(self.processor.to_q_lora.up.weight.device)
+            else:
+                processor = XFormersCrossAttnProcessor(attention_op=attention_op)
         else:
-            processor = CrossAttnProcessor()
+            if is_lora:
+                processor = LoRACrossAttnProcessor(
+                    hidden_size=self.processor.hidden_size,
+                    cross_attention_dim=self.processor.cross_attention_dim,
+                    rank=self.processor.rank,
+                )
+                processor.load_state_dict(self.processor.state_dict())
+                processor.to(self.processor.to_q_lora.up.weight.device)
+            else:
+                processor = CrossAttnProcessor()

         self.set_processor(processor)

@@ -324,6 +347,10 @@ class LoRACrossAttnProcessor(nn.Module):
     def __init__(self, hidden_size, cross_attention_dim=None, rank=4):
         super().__init__()

+        self.hidden_size = hidden_size
+        self.cross_attention_dim = cross_attention_dim
+        self.rank = rank
+
         self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
         self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
         self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
@@ -437,9 +464,14 @@ def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=No


 class LoRAXFormersCrossAttnProcessor(nn.Module):
-    def __init__(self, hidden_size, cross_attention_dim, rank=4):
+    def __init__(self, hidden_size, cross_attention_dim, rank=4, attention_op: Optional[Callable] = None):
         super().__init__()

+        self.hidden_size = hidden_size
+        self.cross_attention_dim = cross_attention_dim
+        self.rank = rank
+        self.attention_op = attention_op
+
         self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
         self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
         self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
@@ -462,7 +494,9 @@ def __call__(
         key = attn.head_to_batch_dim(key).contiguous()
         value = attn.head_to_batch_dim(value).contiguous()

-        hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
+        hidden_states = xformers.ops.memory_efficient_attention(
+            query, key, value, attn_bias=attention_mask, op=self.attention_op
+        )
         hidden_states = attn.batch_to_head_dim(hidden_states)

         # linear proj
@@ -595,4 +629,6 @@ def __call__(self, attn: "CrossAttention", hidden_states, encoder_hidden_states=
     SlicedAttnProcessor,
     CrossAttnAddedKVProcessor,
     SlicedAttnAddedKVProcessor,
+    LoRACrossAttnProcessor,
+    LoRAXFormersCrossAttnProcessor,
 ]
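
A minimal sketch (not part of the diff) of the behavior this change enables: toggling the memory-efficient-attention flag on a CrossAttention block that already carries a LoRA processor now rebuilds a LoRA-aware processor and copies its state dict instead of silently swapping in a plain CrossAttnProcessor. The import path, constructor arguments, and dimensions below assume the diffusers version this diff targets and are illustrative only.

import torch

from diffusers.models.cross_attention import CrossAttention, LoRACrossAttnProcessor

# Build a standalone cross-attention block and attach a LoRA processor
# (hidden_size = heads * dim_head = 8 * 40 = 320 in this toy setup).
attn = CrossAttention(query_dim=320, cross_attention_dim=768, heads=8, dim_head=40)
attn.set_processor(LoRACrossAttnProcessor(hidden_size=320, cross_attention_dim=768, rank=4))

# Give the (zero-initialized) LoRA up-projection a non-trivial value so the
# preservation check below is meaningful.
torch.nn.init.normal_(attn.processor.to_q_lora.up.weight)
lora_weight_before = attn.processor.to_q_lora.up.weight.clone()

# The False branch of set_use_memory_efficient_attention_xformers does not need
# xformers or CUDA; with this patch it must keep the LoRA weights intact.
attn.set_use_memory_efficient_attention_xformers(False)

assert isinstance(attn.processor, LoRACrossAttnProcessor)
assert torch.equal(attn.processor.to_q_lora.up.weight, lora_weight_before)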