Commit ad38212
[CUDA] enable causal in MultiHeadAttention (microsoft#21852)
### Description
Enable causal in the MultiHeadAttention CUDA operator.
All formats (Q_K_V_BSNH_BSNH_BSNH, Q_K_V_BSNH_BNSH_BNSH, Q_KV_BSNH_BSN2H
and QKV_BSN3H) support causal for now. Internally, causal will be
dispatched to the flash attention, efficient attention or unfused attention
kernel.
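For readers unfamiliar with the term, the following is a minimal sketch of what "causal" means for attention: query position i may attend only to key positions 0..i. It is a plain-Python, single-head reference written for illustration, not the code of the flash, efficient or unfused CUDA kernels; the function name `causal_attention` and the list-of-vectors input layout are assumptions made here.

```python
import math


def causal_attention(q, k, v):
    """Single-head causal attention; q, k, v are lists of seq_len vectors.

    Hypothetical reference helper for illustration only.
    """
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for i, q_i in enumerate(q):
        # Causal rule: query i sees keys 0..i and nothing after it.
        scores = [
            scale * sum(a * b for a, b in zip(q_i, k[j]))
            for j in range(i + 1)
        ]
        # Numerically stable softmax over the visible positions.
        m = max(scores)
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        # Weighted sum of the visible value vectors.
        out.append([
            sum(w * v[j][d] for j, w in enumerate(weights)) / total
            for d in range(len(v[0]))
        ])
    return out
```

A consequence of the rule is that changing the key or value at a later position leaves the outputs at all earlier positions unchanged, which makes a convenient sanity check when comparing kernels.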
### Motivation and Context
Currently, MultiHeadAttention has causal enabled in the CPU EP, but not in
the CUDA EP. This can cause issues in ONNX conversion: some models can
run on CPU but not on CUDA. Enabling causal in CUDA will reduce the
difference between the CPU and CUDA support matrices.
1 parent d9c57ac commit ad38212
File tree
4 files changed, +37 -28 lines changed
- onnxruntime
- contrib_ops/cuda/bert
- python/tools/transformers
- test/python/transformers