Commit fbc3927
[CUDA] cuDNN Flash Attention (microsoft#21629)
### Description
- [x] Add cuDNN flash attention using the cuDNN frontend, and enable it in the
MultiHeadAttention operator.
- [x] Support attention mask.
- [x] Support attention bias.
- [x] Update tests and benchmark script.
cuDNN SDPA is disabled by default. Enabling it requires all of the following:
(1) cuDNN 9.3 or newer is installed.
(2) The environment variable `ORT_ENABLE_CUDNN_FLASH_ATTENTION=1` is set, or
the CUDA provider option `sdpa_kernel=8` is set.
(3) The device has compute capability >= 8.0.
Note that some parameter combinations may still be rejected because of
limited support for certain head dimensions or sequence lengths.
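
A minimal sketch of the two ways to turn the kernel on from Python, assuming the standard `onnxruntime` Python package built with CUDA. The environment variable name and the `sdpa_kernel=8` value come from the description above; the helper names and the model path are hypothetical, and setting the variable before the session is created is an assumption.

```python
import os


def enable_cudnn_sdpa_via_env(env=None):
    """Set the environment variable named in the description.

    Assumption: this happens before the inference session is created.
    """
    env = os.environ if env is None else env
    env["ORT_ENABLE_CUDNN_FLASH_ATTENTION"] = "1"
    return env


def cuda_providers_with_cudnn_sdpa():
    """Build a providers list that passes sdpa_kernel=8 to the CUDA provider."""
    cuda_options = {"sdpa_kernel": "8"}  # the value 8 comes from the description
    return [("CUDAExecutionProvider", cuda_options), "CPUExecutionProvider"]


def create_session(model_path):
    """Create a session with the option set (model_path is hypothetical)."""
    import onnxruntime  # imported lazily so the helpers above work without it

    return onnxruntime.InferenceSession(
        model_path, providers=cuda_providers_with_cudnn_sdpa()
    )
```

Either route should be enough on its own. Even with the option set, the kernel may not be selected if the cuDNN version, compute capability, or shape constraints listed above are not met.
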
Future work:
(1) FP8 and BF16 APIs. Currently, only the FP16 API is exposed.
(2) Add an API to support ragged batching (padding removed from inputs).
(3) Support other input formats (like QKV_BS3NH).
(4) Currently, q is converted to BSNH, and k/v are converted to either BSNH
or BNSH format. Experiments may show whether converting q to
BNSH performs better in some cases.
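
The layout names follow the convention B = batch size, S = sequence length, N = number of heads, H = head size, with the rightmost letter varying fastest in memory. A minimal sketch in plain Python (hypothetical helper names, not code from this change) of how the flat offset of the same logical element differs between BSNH and BNSH:

```python
def row_major_strides(layout, sizes):
    """Return element strides per axis for a contiguous tensor.

    layout: axis letters in memory order, e.g. "BSNH" or "BNSH".
    sizes: dict mapping each axis letter to its extent.
    """
    strides = {}
    step = 1
    for axis in reversed(layout):
        strides[axis] = step
        step *= sizes[axis]
    return strides


def element_offset(layout, sizes, index):
    """Flat offset of the element at index (dict of axis letter -> position)."""
    strides = row_major_strides(layout, sizes)
    return sum(index[axis] * strides[axis] for axis in layout)


# Example shape: B=16, S=256, N=32, H=128.
sizes = {"B": 16, "S": 256, "N": 32, "H": 128}
index = {"B": 0, "S": 1, "N": 2, "H": 3}

bsnh_offset = element_offset("BSNH", sizes, index)  # 1*(32*128) + 2*128 + 3
bnsh_offset = element_offset("BNSH", sizes, index)  # 2*(256*128) + 1*128 + 3
```

In BSNH, consecutive tokens of one head are N*H elements apart, while in BNSH they are H elements apart. That difference in access pattern is why the choice of layout for q, k, and v can affect kernel performance.
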
### Example Benchmark Results on H100
The following tests use the FP16 MultiHeadAttention operator without
an attention mask or attention bias.
#### Test Setting 1
batch_size | sequence_length | past_sequence_length | num_heads | head_size
-- | -- | -- | -- | --
16 | 256 | 0 | 32 | 128

format | average_latency (s) | tflops | kernel
-- | -- | -- | --
Q,K,V (BNSH) | 0.000075 | 229.5 | torch:flash
Q,K,V (BNSH) | 0.000119 | 144.8 | torch:efficient
Q,K,V (BNSH) | 0.000224 | 76.5 | torch:math
Q,K,V (BSNH) | 0.000075 | 227.8 | ort:cudnn
Q,K,V (BSNH) | 0.000094 | 182.8 | ort:flash
Q,K,V (BSNH) | 0.000138 | 124.7 | ort:efficient
Q,K,V (BSNH) | 0.000438 | 39.3 | ort:math
Q,KV | 0.000129 | 133.0 | ort:cudnn
Q,KV | 0.000151 | 114.1 | ort:flash
Q,KV | 0.000194 | 88.5 | ort:efficient
QKV | 0.000154 | 111.8 | ort:cudnn
QKV | 0.000175 | 98.0 | ort:flash
QKV | 0.000217 | 79.0 | ort:efficient
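
The `tflops` column is consistent with counting 4 * batch_size * num_heads * sequence_length^2 * head_size floating-point operations per call (two matrix multiplications, two FLOPs per multiply-add) and reading `average_latency` in seconds. That formula is an assumption inferred from the numbers, not taken from the benchmark script. A minimal sketch:

```python
def attention_tflops(batch_size, sequence_length, num_heads, head_size, latency_s):
    """Estimate TFLOPS for non-causal attention without past state.

    Assumes 2 * S * S * H FLOPs for Q @ K^T and the same for P @ V,
    per batch entry and per head.
    """
    flops = 4 * batch_size * num_heads * sequence_length**2 * head_size
    return flops / latency_s / 1e12


# Test Setting 1, ort:cudnn row with Q,K,V (BSNH): reported latency 0.000075
estimate = attention_tflops(16, 256, 32, 128, 0.000075)
print(f"{estimate:.1f}")
```

The estimate comes out at about 229 TFLOPS, close to the 227.8 reported for that row. A gap of that size is what one would expect if the printed latency is rounded to six decimal places.
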
#### Test Setting 2
batch_size | sequence_length | past_sequence_length | num_heads | head_size
-- | -- | -- | -- | --
16 | 512 | 0 | 16 | 64

format | average_latency (s) | tflops | kernel
-- | -- | -- | --
Q,K,V (BNSH) | 0.000069 | 249.2 | torch:flash
Q,K,V (BNSH) | 0.000141 | 121.7 | torch:efficient
Q,K,V (BNSH) | 0.000294 | 58.5 | torch:math
Q,K,V (BSNH) | 0.000077 | 221.7 | ort:cudnn
Q,K,V (BSNH) | 0.000087 | 196.6 | ort:flash
Q,K,V (BSNH) | 0.000163 | 105.6 | ort:efficient
Q,K,V (BSNH) | 0.000651 | 26.4 | ort:math
Q,KV | 0.000103 | 167.1 | ort:cudnn
Q,KV | 0.000117 | 146.3 | ort:flash
Q,KV | 0.000192 | 89.6 | ort:efficient
QKV | 0.000113 | 151.5 | ort:cudnn
QKV | 0.000128 | 134.7 | ort:flash
QKV | 0.000201 | 85.3 | ort:efficient

1 parent 9f7e19c commit fbc3927
File tree (19 files changed, +681 -50 lines changed):

- cmake
  - external
- onnxruntime
  - contrib_ops
    - cpu/bert
    - cuda
      - bert
        - cudnn_fmha
      - quantization
  - core/providers/cuda
  - test
    - contrib_ops
    - python/transformers