@@ -621,6 +621,12 @@ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None
         key = self.reshape_heads_to_batch_dim(key)
         value = self.reshape_heads_to_batch_dim(value)
 
+        if attention_mask is not None:
+            if attention_mask.shape[-1] != query.shape[1]:
+                target_length = query.shape[1]
+                attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
+                attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
+
         # attention, what we cannot get enough of
         if self._use_memory_efficient_attention_xformers:
             hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
@@ -630,7 +636,7 @@ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None
             if self._slice_size is None or query.shape[0] // self._slice_size == 1:
                 hidden_states = self._attention(query, key, value, attention_mask)
             else:
-                hidden_states = self._sliced_attention(query, key, value, sequence_length, dim)
+                hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
 
         # linear proj
         hidden_states = self.to_out[0](hidden_states)
@@ -653,11 +659,6 @@ def _attention(self, query, key, value, attention_mask=None):
         )
 
         if attention_mask is not None:
-            if attention_mask.shape != attention_scores.shape:
-                target_length = query.shape[1]
-                attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
-                attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
-
             attention_scores = attention_scores + attention_mask
 
         if self.upcast_softmax:
@@ -675,7 +676,7 @@ def _attention(self, query, key, value, attention_mask=None):
         hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
         return hidden_states
 
-    def _sliced_attention(self, query, key, value, sequence_length, dim):
+    def _sliced_attention(self, query, key, value, sequence_length, dim, attention_mask):
         batch_size_attention = query.shape[0]
         hidden_states = torch.zeros(
             (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
@@ -699,6 +700,13 @@ def _sliced_attention(self, query, key, value, sequence_length, dim):
                 beta=0,
                 alpha=self.scale,
             )
+
+            if attention_mask is not None:
+                attn_slice = attn_slice + attention_mask[start_idx:end_idx]
+
+            if self.upcast_softmax:
+                attn_slice = attn_slice.float()
+
             attn_slice = attn_slice.softmax(dim=-1)
 
             # cast back to the original dtype
@@ -716,7 +724,7 @@ def _memory_efficient_attention_xformers(self, query, key, value, attention_mask
         query = query.contiguous()
         key = key.contiguous()
         value = value.contiguous()
-        hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=None)
+        hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
         hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
         return hidden_states
 
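For readers following the change: the mask is additive, so it is added to the pre-softmax scores (masked positions get a large negative value and thus near-zero probability), and repeat_interleave keeps it aligned with the (batch * heads, ...) layout produced by reshape_heads_to_batch_dim. Below is a minimal, self-contained sketch of that scheme and of why the new sliced path matches the full path. It is illustrative only, not the diffusers module itself; the shapes, slice_size, the -10000.0 fill value, and the simplified padding width are assumptions for the example.

    # Sketch of the masking scheme in this commit: pad an additive mask to the key
    # length, repeat it across heads, add it to the scores before softmax, and check
    # that slicing along the batch*heads dimension gives the same result.
    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)

    batch, heads, q_len, kv_len, head_dim = 2, 4, 6, 6, 8
    scale = head_dim ** -0.5

    # Tensors already reshaped to (batch * heads, seq_len, head_dim), as in the module.
    query = torch.randn(batch * heads, q_len, head_dim)
    key = torch.randn(batch * heads, kv_len, head_dim)
    value = torch.randn(batch * heads, kv_len, head_dim)

    # Additive mask per batch item: 0.0 for visible tokens, large negative for masked ones.
    attention_mask = torch.zeros(batch, 1, kv_len)
    attention_mask[:, :, -2:] = -10000.0  # hide the last two key positions

    # Same preparation the commit moves into forward(): pad to the key length if needed
    # (padding width simplified here), then repeat across heads so the mask lines up
    # with the (batch * heads, ...) batch dimension.
    if attention_mask.shape[-1] != key.shape[1]:
        attention_mask = F.pad(attention_mask, (0, key.shape[1] - attention_mask.shape[-1]), value=0.0)
    attention_mask = attention_mask.repeat_interleave(heads, dim=0)  # (batch * heads, 1, kv_len)

    # Full (unsliced) attention with the additive mask.
    scores = torch.baddbmm(
        torch.zeros(1, 1, 1, dtype=query.dtype), query, key.transpose(-1, -2), beta=0, alpha=scale
    )
    scores = scores + attention_mask
    full_out = scores.softmax(dim=-1) @ value

    # Sliced attention: process chunks of the batch*heads dimension and slice the
    # prepared mask with the same start/end indices, as the commit does.
    slice_size = 2
    sliced_out = torch.zeros_like(full_out)
    for i in range(query.shape[0] // slice_size):
        s, e = i * slice_size, (i + 1) * slice_size
        attn_slice = torch.baddbmm(
            torch.zeros(1, 1, 1, dtype=query.dtype), query[s:e], key[s:e].transpose(-1, -2), beta=0, alpha=scale
        )
        attn_slice = attn_slice + attention_mask[s:e]
        sliced_out[s:e] = attn_slice.softmax(dim=-1) @ value[s:e]

    print(torch.allclose(full_out, sliced_out, atol=1e-6))  # True: both paths agree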