@@ -122,15 +122,15 @@ def torch_bmm_32_bit(input, mat2, *, out=None):
                     mat2[start_idx:end_idx],
                     out=out
                 )
+        torch.xpu.synchronize(input.device)
     else:
         return original_torch_bmm(input, mat2, out=out)
-    torch.xpu.synchronize(input.device)
     return hidden_states
 
 original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
-def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
+def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, **kwargs):
     if query.device.type != "xpu":
-        return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal)
+        return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs)
     do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_sdpa_slice_sizes(query.shape, query.element_size())
 
     # Slice SDPA
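The `**kwargs` passthrough added in this hunk matters because newer PyTorch releases accept extra keyword arguments on `torch.nn.functional.scaled_dot_product_attention` (for example `scale=`); a wrapper with a fixed signature would raise a TypeError for callers that pass them. Below is a minimal, self-contained sketch of the same forwarding pattern; the names (`original_sdpa`, `sdpa_wrapper`) are illustrative and not the file's actual identifiers.

import torch
import torch.nn.functional as F

original_sdpa = F.scaled_dot_product_attention

def sdpa_wrapper(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, **kwargs):
    # Forward any extra keyword arguments to the original op so the wrapper
    # keeps working as PyTorch extends the SDPA signature.
    return original_sdpa(query, key, value, attn_mask=attn_mask,
                         dropout_p=dropout_p, is_causal=is_causal, **kwargs)

q = k = v = torch.randn(1, 4, 8, 16)
out = sdpa_wrapper(q, k, v, scale=0.125)  # scale= is an optional keyword on recent PyTorch versions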
@@ -153,25 +153,25 @@ def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropo
                                 key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
                                 value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
                                 attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attn_mask is not None else attn_mask,
-                                dropout_p=dropout_p, is_causal=is_causal
+                                dropout_p=dropout_p, is_causal=is_causal, **kwargs
                             )
                     else:
                         hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_scaled_dot_product_attention(
                             query[start_idx:end_idx, start_idx_2:end_idx_2],
                             key[start_idx:end_idx, start_idx_2:end_idx_2],
                             value[start_idx:end_idx, start_idx_2:end_idx_2],
                             attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attn_mask is not None else attn_mask,
-                            dropout_p=dropout_p, is_causal=is_causal
+                            dropout_p=dropout_p, is_causal=is_causal, **kwargs
                         )
             else:
                 hidden_states[start_idx:end_idx] = original_scaled_dot_product_attention(
                     query[start_idx:end_idx],
                     key[start_idx:end_idx],
                     value[start_idx:end_idx],
                     attn_mask=attn_mask[start_idx:end_idx] if attn_mask is not None else attn_mask,
-                    dropout_p=dropout_p, is_causal=is_causal
+                    dropout_p=dropout_p, is_causal=is_causal, **kwargs
                 )
+        torch.xpu.synchronize(query.device)
     else:
-        return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal)
-    torch.xpu.synchronize(query.device)
+        return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs)
     return hidden_states
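For context, here is a minimal, self-contained sketch of the slicing pattern both hunks adjust: the matmul is issued slice by slice into a preallocated buffer, and `torch.xpu.synchronize` now runs inside the split branch, right before the assembled result is returned. The helper below simplifies the real slice-size logic to a single `split_slice_size` parameter and is not part of this diff.

import torch

def sliced_bmm(input, mat2, split_slice_size):
    """Batched matmul computed in slices along dim 0, synchronized once at the end."""
    batch = input.shape[0]
    if batch % split_slice_size != 0:
        return torch.bmm(input, mat2)  # fall back when the batch does not split evenly
    hidden_states = torch.zeros(input.shape[0], input.shape[1], mat2.shape[2],
                                device=input.device, dtype=input.dtype)
    for i in range(batch // split_slice_size):
        start_idx = i * split_slice_size
        end_idx = (i + 1) * split_slice_size
        hidden_states[start_idx:end_idx] = torch.bmm(input[start_idx:end_idx], mat2[start_idx:end_idx])
    if input.device.type == "xpu" and hasattr(torch, "xpu"):
        torch.xpu.synchronize(input.device)  # wait for the queued slices before handing back the buffer
    return hidden_states

a = torch.randn(8, 32, 64)
b = torch.randn(8, 64, 16)
assert torch.allclose(sliced_bmm(a, b, 2), torch.bmm(a, b), atol=1e-5)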