@@ -185,17 +185,23 @@ def get_attention_scores(self, query, key, attention_mask=None):
             query = query.float()
             key = key.float()
 
+        if attention_mask is None:
+            baddbmm_input = torch.empty(
+                query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
+            )
+            beta = 0
+        else:
+            baddbmm_input = attention_mask
+            beta = 1
+
         attention_scores = torch.baddbmm(
-            torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
+            baddbmm_input,
             query,
             key.transpose(-1, -2),
-            beta=0,
+            beta=beta,
             alpha=self.scale,
         )
 
-        if attention_mask is not None:
-            attention_scores = attention_scores + attention_mask
-
         if self.upcast_softmax:
             attention_scores = attention_scores.float()
 
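Why this hunk works: `torch.baddbmm(input, batch1, batch2, beta=beta, alpha=alpha)` computes `beta * input + alpha * (batch1 @ batch2)`, and the PyTorch docs guarantee that with `beta=0` the `input` tensor is ignored entirely (NaN/inf in it are not propagated), which is why the old code could pass an uninitialized `torch.empty` buffer. Passing the attention mask as `input` with `beta=1` folds the former separate `attention_scores + attention_mask` addition into the fused kernel. A minimal standalone sketch checking the equivalence; the shapes and `scale` value below are placeholders, not taken from the library:

```python
import torch

# Placeholder shapes; in the real module these come from the query/key tensors.
batch_heads, q_len, k_len, head_dim = 2, 4, 6, 8
scale = head_dim ** -0.5  # stand-in for self.scale

query = torch.randn(batch_heads, q_len, head_dim)
key = torch.randn(batch_heads, k_len, head_dim)
mask = torch.randn(batch_heads, q_len, k_len)

# Old path: beta=0 ignores the uninitialized buffer; the mask is added afterwards.
old = torch.baddbmm(
    torch.empty(batch_heads, q_len, k_len, dtype=query.dtype, device=query.device),
    query,
    key.transpose(-1, -2),
    beta=0,
    alpha=scale,
) + mask

# New path: the mask is baddbmm's input, so the addition happens inside the fused op.
new = torch.baddbmm(mask, query, key.transpose(-1, -2), beta=1, alpha=scale)

torch.testing.assert_close(old, new)
```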
@@ -228,11 +234,12 @@ def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None
         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
 
         query = attn.to_q(hidden_states)
-        query = attn.head_to_batch_dim(query)
 
         encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)
+
+        query = attn.head_to_batch_dim(query)
         key = attn.head_to_batch_dim(key)
         value = attn.head_to_batch_dim(value)
 
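The second hunk is a pure reorder: `head_to_batch_dim(query)` is moved so that all three projections are reshaped together once `key` and `value` exist; behavior is unchanged. For orientation, a sketch of the reshape this helper performs, folding the heads into the batch dimension; the exact signature and internals of `attn.head_to_batch_dim` are an assumption here, not copied from the library:

```python
import torch

def head_to_batch_dim(tensor: torch.Tensor, heads: int) -> torch.Tensor:
    # Assumed semantics: split the channel dim into heads,
    # then fold the head dim into the batch dim.
    batch_size, seq_len, dim = tensor.shape
    tensor = tensor.reshape(batch_size, seq_len, heads, dim // heads)
    tensor = tensor.permute(0, 2, 1, 3)
    return tensor.reshape(batch_size * heads, seq_len, dim // heads)

x = torch.randn(2, 77, 320)           # (batch, seq, heads * head_dim)
print(head_to_batch_dim(x, 8).shape)  # torch.Size([16, 77, 40])
```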