Commit fbc3927
[CUDA] cuDNN Flash Attention (microsoft#21629)
### Description

- [x] Add cuDNN flash attention using the cuDNN frontend, and enable it in the MultiHeadAttention operator.
- [x] Support attention mask.
- [x] Support attention bias.
- [x] Update tests and benchmark script.

The cuDNN SDPA kernel is disabled by default. To enable it:
(1) cuDNN 9.3 or a newer version must be installed.
(2) Set the environment variable `ORT_ENABLE_CUDNN_FLASH_ATTENTION=1`, or set the `sdpa_kernel=8` CUDA provider option (a short Python sketch follows the benchmark tables below).
(3) It only works on devices with compute capability >= 8.0.

Note that some combinations of parameters might be rejected due to limited support of head dimensions or sequence lengths.

Future work:
(1) FP8 and BF16 APIs. Currently, only the FP16 API is exposed.
(2) Add an API to support ragged batching (inputs with padding removed).
(3) Support other input formats (like QKV_BS3NH).
(4) Currently, q is converted to BSNH, and k/v are converted to either BSNH or BNSH format. Some experiments may show whether converting q to BNSH is better in some cases.

### Example Benchmark Results on H100

The following tests use the FP16 MultiHeadAttention operator without attention mask and attention bias.

#### Test Setting 1

batch_size | sequence_length | past_sequence_length | num_heads | head_size
-- | -- | -- | -- | --
16 | 256 | 0 | 32 | 128

format | average_latency (s) | TFLOPS | kernel
-- | -- | -- | --
Q,K,V (BNSH) | 0.000075 | 229.5 | torch:flash
Q,K,V (BNSH) | 0.000119 | 144.8 | torch:efficient
Q,K,V (BNSH) | 0.000224 | 76.5 | torch:math
Q,K,V (BSNH) | 0.000075 | 227.8 | ort:cudnn
Q,K,V (BSNH) | 0.000094 | 182.8 | ort:flash
Q,K,V (BSNH) | 0.000138 | 124.7 | ort:efficient
Q,K,V (BSNH) | 0.000438 | 39.3 | ort:math
Q,KV | 0.000129 | 133.0 | ort:cudnn
Q,KV | 0.000151 | 114.1 | ort:flash
Q,KV | 0.000194 | 88.5 | ort:efficient
QKV | 0.000154 | 111.8 | ort:cudnn
QKV | 0.000175 | 98.0 | ort:flash
QKV | 0.000217 | 79.0 | ort:efficient

#### Test Setting 2

batch_size | sequence_length | past_sequence_length | num_heads | head_size
-- | -- | -- | -- | --
16 | 512 | 0 | 16 | 64

format | average_latency (s) | TFLOPS | kernel
-- | -- | -- | --
Q,K,V (BNSH) | 0.000069 | 249.2 | torch:flash
Q,K,V (BNSH) | 0.000141 | 121.7 | torch:efficient
Q,K,V (BNSH) | 0.000294 | 58.5 | torch:math
Q,K,V (BSNH) | 0.000077 | 221.7 | ort:cudnn
Q,K,V (BSNH) | 0.000087 | 196.6 | ort:flash
Q,K,V (BSNH) | 0.000163 | 105.6 | ort:efficient
Q,K,V (BSNH) | 0.000651 | 26.4 | ort:math
Q,KV | 0.000103 | 167.1 | ort:cudnn
Q,KV | 0.000117 | 146.3 | ort:flash
Q,KV | 0.000192 | 89.6 | ort:efficient
QKV | 0.000113 | 151.5 | ort:cudnn
QKV | 0.000128 | 134.7 | ort:flash
QKV | 0.000201 | 85.3 | ort:efficient
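Below is a minimal sketch (not part of this commit) of how the kernel can be selected from the Python API, assuming a CUDA build of onnxruntime; `model.onnx` is a placeholder path, while the `sdpa_kernel=8` value and the `ORT_ENABLE_CUDNN_FLASH_ATTENTION` variable come from the description above.

```python
import os

import onnxruntime as ort

# Option 1: opt in via environment variable (set before the session is created).
os.environ["ORT_ENABLE_CUDNN_FLASH_ATTENTION"] = "1"

# Option 2: opt in via the sdpa_kernel CUDA provider option (8 selects cuDNN flash attention).
providers = [
    ("CUDAExecutionProvider", {"sdpa_kernel": 8}),
    "CPUExecutionProvider",
]

# "model.onnx" is a placeholder for a model that contains MultiHeadAttention nodes.
session = ort.InferenceSession("model.onnx", providers=providers)
```

As a sanity check on the tables above, the TFLOPS column appears consistent with the usual two-matmul SDPA cost estimate; the formula below is an assumption for illustration, not taken from the benchmark script.

```python
# Assumed cost model: QK^T and PV matmuls, i.e. 2 * 2 * B * N * S_q * S_kv * H FLOPs.
batch_size, num_heads, sequence_length, head_size = 16, 32, 256, 128  # Test Setting 1
flops = 4 * batch_size * num_heads * sequence_length * sequence_length * head_size
latency = 0.000075  # seconds (torch:flash row)
print(flops / latency / 1e12)  # ~229 TFLOPS, close to the 229.5 reported above
```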
Parent commit: 9f7e19c

19 files changed: +681 −50 lines

cmake/external/cuDNN.cmake
Lines changed: 0 additions & 2 deletions

@@ -107,5 +107,3 @@ elseif(CUDNN_MAJOR_VERSION EQUAL 9)
       CUDNN::cudnn_heuristic
   )
 endif()
-
-mark_as_advanced(CUDNN_INCLUDE_DIR)

cmake/onnxruntime_rocm_hipify.cmake
Lines changed: 1 addition & 0 deletions

@@ -5,6 +5,7 @@ find_package(Python3 COMPONENTS Interpreter REQUIRED)
 
 # GLOB pattern of file to be excluded
 set(contrib_ops_excluded_files
+  "bert/cudnn_fmha/*"
   "bert/cutlass_fmha/*"
   "bert/fastertransformer_decoder_attention/*"
   "bert/flash_attention/*"

onnxruntime/contrib_ops/cpu/bert/attention_common.h
Lines changed: 1 addition & 0 deletions

@@ -47,6 +47,7 @@ enum AttentionKernelType {
   AttentionKernel_TrtFusedCrossAttention,
   AttentionKernel_CutlassMemoryEfficientAttention,
   AttentionKernel_FlashAttention,
+  AttentionKernel_CudnnFlashAttention,
   AttentionKernel_Default
 };

onnxruntime/contrib_ops/cuda/bert/attention.cc
Lines changed: 4 additions & 1 deletion

@@ -246,6 +246,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
 
   constexpr size_t element_size = sizeof(T);
   constexpr bool use_fused_cross_attention = false;
+  constexpr bool use_cudnn_flash_attention = false;
   size_t workSpaceSize = GetAttentionWorkspaceSize(element_size,
                                                    parameters.batch_size,
                                                    parameters.num_heads,
@@ -258,6 +259,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
                                                    use_flash_attention,
                                                    use_fused_cross_attention,
                                                    use_memory_efficient_attention,
+                                                   use_cudnn_flash_attention,
                                                    false);
   IAllocatorUniquePtr<void> work_space = IAllocator::MakeUniquePtr<void>(allocator, workSpaceSize, false, context->GetComputeStream());
 
@@ -294,7 +296,8 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
     data.out_accum = reinterpret_cast<CudaT*>(out_accum_buffer.get());
   }
 
-  return QkvToContext<CudaT>(device_prop, cublas, context->GetComputeStream(), parameters, data);
+  cudnnHandle_t cudnn = GetCudnnHandle(context);
+  return QkvToContext<CudaT>(device_prop, cublas, cudnn, context->GetComputeStream(), parameters, data);
 }
 
 }  // namespace cuda

onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
Lines changed: 80 additions & 4 deletions

@@ -37,6 +37,7 @@ limitations under the License.
 #include "contrib_ops/cuda/bert/bert_padding.h"
 #include "contrib_ops/cuda/utils/dump_cuda_tensor.h"
 #include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h"
+#include "contrib_ops/cuda/bert/cudnn_fmha/cudnn_flash_attention.h"
 #include "contrib_ops/cuda/bert/flash_attention/flash_api.h"
 #include "contrib_ops/cuda/bert/attention_impl.h"
 
@@ -109,6 +110,7 @@ size_t GetAttentionWorkspaceSize(
     bool use_flash_attention,
     bool use_fused_cross_attention,
     bool use_memory_efficient_attention,
+    bool use_cudnn_flash_attention,
     bool no_qkv_workspace) {
   // Note that q, k and v might need alignment for fused attention kernels.
   const size_t qkv_size = element_size * batch_size * num_heads *
@@ -144,6 +146,10 @@
     return qkv_bytes + 2 * GetSequenceOffsetSize(static_cast<int>(batch_size), true);
   }
 
+  if (use_cudnn_flash_attention) {
+    return qkv_bytes;
+  }
+
   return qkv_bytes + 2 * GetAttentionScratchSize(element_size, batch_size, num_heads, sequence_length,
                                                  total_sequence_length);
 }
@@ -320,6 +326,68 @@ Status FlashAttention(
 }
 #endif
 
+template <typename T>
+Status CudnnFlashAttention(
+    cudnnHandle_t cudnn_handle,
+    Stream* ort_stream,
+    contrib::AttentionParameters& parameters,
+    AttentionData<T>& data,
+    float scale) {
+  assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH ||
+         data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH ||
+         data.qkv_format == AttentionQkvFormat::Q_K_V_BNSH);
+  assert(parameters.mask_type == AttentionMaskType::MASK_NONE ||
+         parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN);
+  constexpr bool is_bf16 = false;
+
+  T* attention_bias = const_cast<T*>(data.attention_bias);
+  int* mask_sequence_lengths_kv = const_cast<int*>(data.mask_index);
+
+  cudnn_sdpa::run(
+      data.output,
+      data.q,
+      data.k,
+      data.v,
+      attention_bias,
+      nullptr,                               // (optional) mask_sequence_lengths_q
+      mask_sequence_lengths_kv,              // (optional) mask_sequence_lengths_kv
+      parameters.batch_size,
+      parameters.num_heads,                  // num_heads_q
+      parameters.num_heads,                  // num_heads_kv
+      parameters.head_size,                  // head_size_qk
+      parameters.v_head_size,                // head_size_v
+      parameters.sequence_length,            // sequence_length_q
+      parameters.total_sequence_length,      // sequence_length_kv
+      scale,                                 // scaling factor applied prior to softmax
+      parameters.is_unidirectional,          // causal
+      is_bf16,                               // true if bfloat16, otherwise float16
+      parameters.broadcast_attn_bias_dim_0,  // broadcast attention bias dimension 0 or not
+      parameters.broadcast_attn_bias_dim_1,  // broadcast attention bias dimension 1 or not
+      0,                                     // sliding window length. 0 means no sliding window.
+      data.qkv_format,
+      cudnn_handle,
+      ort_stream,
+      data.allocator);
+
+  return Status::OK();
+}
+
+template <>
+Status CudnnFlashAttention(
+    cudnnHandle_t cudnn_handle,
+    Stream* ort_stream,
+    contrib::AttentionParameters& parameters,
+    AttentionData<float>& data,
+    float scale) {
+  ORT_UNUSED_PARAMETER(cudnn_handle);
+  ORT_UNUSED_PARAMETER(ort_stream);
+  ORT_UNUSED_PARAMETER(parameters);
+  ORT_UNUSED_PARAMETER(data);
+  ORT_UNUSED_PARAMETER(scale);
+  return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED,
+                         "cudnn flash attention does not support float tensor");
+}
+
 #if USE_MEMORY_EFFICIENT_ATTENTION
 template <typename T>
 Status EfficientAttention(
@@ -498,6 +566,7 @@ template <typename T>
 Status QkvToContext(
     const cudaDeviceProp& device_prop,
     cublasHandle_t& cublas,
+    cudnnHandle_t& cudnn,
     Stream* ort_stream,
     contrib::AttentionParameters& parameters,
     AttentionData<T>& data) {
@@ -512,10 +581,11 @@ Status QkvToContext(
   void* fused_runner = data.fused_runner;
 
   // At most one fused kernel is enabled.
-  assert((int(data.use_flash_attention) +
-          int(data.use_memory_efficient_attention) +
-          int(fused_runner != nullptr) +
-          int(data.fused_cross_attention_kernel != nullptr)) <= 1);
+  assert((static_cast<int>(data.use_flash_attention) +
+          static_cast<int>(data.use_memory_efficient_attention) +
+          static_cast<int>(fused_runner != nullptr) +
+          static_cast<int>(data.fused_cross_attention_kernel != nullptr) +
+          static_cast<int>(data.kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention)) <= 1);
 
   ORT_RETURN_IF_ERROR(PrepareQkv<T>(parameters, data, stream, max_threads_per_block));
 
@@ -577,6 +647,10 @@ Status QkvToContext(
   }
 #endif
 
+  if (data.kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention) {
+    return CudnnFlashAttention(cudnn, ort_stream, parameters, data, scale);
+  }
+
 #if USE_MEMORY_EFFICIENT_ATTENTION
   if (data.use_memory_efficient_attention) {
     return EfficientAttention(device_prop, stream, parameters, data, scale);
@@ -594,13 +668,15 @@ template struct AttentionData<half>;
 template Status QkvToContext<float>(
     const cudaDeviceProp& device_prop,
     cublasHandle_t& cublas,
+    cudnnHandle_t& cudnn,
     Stream* ort_stream,
     contrib::AttentionParameters& parameters,
     AttentionData<float>& data);
 
 template Status QkvToContext<half>(
     const cudaDeviceProp& device_prop,
     cublasHandle_t& cublas,
+    cudnnHandle_t& cudnn,
     Stream* ort_stream,
     contrib::AttentionParameters& parameters,
     AttentionData<half>& data);

onnxruntime/contrib_ops/cuda/bert/attention_impl.h
Lines changed: 7 additions & 2 deletions

@@ -9,6 +9,7 @@
 #include <iostream>
 #include <mutex>
 #include "core/framework/allocator.h"
+#include "core/providers/cuda/cuda_common.h"
 #include "contrib_ops/cpu/bert/attention_common.h"
 
 namespace onnxruntime {
@@ -54,6 +55,7 @@ size_t GetAttentionWorkspaceSize(
     bool use_flash_attention,
     bool use_fused_cross_attention,
     bool use_memory_efficient_attention,
+    bool use_cudnn_flash_attention,
     bool no_qkv_workspace);
 
 template <typename T>
@@ -104,9 +106,11 @@ struct AttentionData {
   size_t workspace_bytes = 0;
   bool allow_debug_info = false;
 
+  // For MultiHeadAttention only.
+  AttentionKernelType kernel_type = AttentionKernelType::AttentionKernel_Default;
+  AllocatorPtr allocator = nullptr;
   bool IsUnfused() const {
-    return !use_flash_attention && !use_memory_efficient_attention &&
-           (fused_runner == nullptr) && (fused_cross_attention_kernel == nullptr);
+    return kernel_type == AttentionKernelType::AttentionKernel_Unfused;
   }
 
   void PrintDebugInfo() const {
@@ -139,6 +143,7 @@ template <typename T>
 Status QkvToContext(
     const cudaDeviceProp& device_prop,
     cublasHandle_t& cublas,
+    cudnnHandle_t& cudnn,
     Stream* stream,
     contrib::AttentionParameters& parameters,
     AttentionData<T>& data);

onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.cc
Lines changed: 13 additions & 3 deletions

@@ -9,11 +9,12 @@
 #include "core/providers/shared_library/provider_api.h"
 #include "core/platform/env_var_utils.h"
 #include "contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.h"
+#include "contrib_ops/cuda/bert/cudnn_fmha/cudnn_flash_attention.h"
 
 using namespace onnxruntime::contrib::attention;
 
 namespace onnxruntime {
-void AttentionKernelOptions::Initialize(int value, bool use_build_flag) {
+void AttentionKernelOptions::Initialize(int value, bool use_build_flag, bool check_cudnn_version) {
   if (value > 0) {
     use_flash_attention_ = (value & static_cast<int>(AttentionBackend::FLASH_ATTENTION)) > 0;
     use_efficient_attention_ = (value & static_cast<int>(AttentionBackend::EFFICIENT_ATTENTION)) > 0;
@@ -28,6 +29,7 @@ void AttentionKernelOptions::Initialize(int value, bool use_build_flag) {
     use_efficient_attention_ = !ParseEnvironmentVariableWithDefault<bool>(kDisableMemoryEfficientAttention, false);
     use_trt_fused_attention_ = !ParseEnvironmentVariableWithDefault<bool>(kDisableFusedSelfAttention, false);
     use_cudnn_flash_attention_ = ParseEnvironmentVariableWithDefault<bool>(kEnableCudnnFlashAttention, false);
+
     use_unfused_ = true;
     use_trt_flash_attention_ = !ParseEnvironmentVariableWithDefault<bool>(kDisableTrtFlashAttention, false);
     use_trt_cross_attention_ = !ParseEnvironmentVariableWithDefault<bool>(kDisableFusedCrossAttention, false);
@@ -45,6 +47,14 @@ void AttentionKernelOptions::Initialize(int value, bool use_build_flag) {
                                                              kMinSeqLenForEfficientAttentionFp32,
                                                              value > 0 ? 0 : kDefaultMinSeqLenForEfficientAttentionFp32);
 
+  // Enable cuDNN flash attention only when it is stable (requires cuDNN version >= 9.3.0).
+  if (use_cudnn_flash_attention_ && check_cudnn_version && !::onnxruntime::cudnn_sdpa::is_stable()) {
+    use_cudnn_flash_attention_ = false;
+    if (enable_kernel_debug_info_) {
+      std::cout << "cuDNN Flash Attention is disabled. Requires cuDNN 9.3 or later." << std::endl;
+    }
+  }
+
   if (use_build_flag) {
     // Some kernels can be disabled at build time. If they are disabled, we should not use them.
 #ifndef USE_FLASH_ATTENTION
@@ -58,9 +68,9 @@ void AttentionKernelOptions::Initialize(int value, bool use_build_flag) {
 }
 
 void AttentionKernelOptions::InitializeOnce(
-    int sdpa_kernel, bool use_build_flag) {
+    int sdpa_kernel, bool use_build_flag, bool check_cudnn_version) {
   std::call_once(this->initialize_once_flag_, [&]() {
-    this->Initialize(sdpa_kernel, use_build_flag);
+    this->Initialize(sdpa_kernel, use_build_flag, check_cudnn_version);
     if (this->enable_kernel_debug_info_) {
       this->Print();
     }

onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.h
Lines changed: 2 additions & 2 deletions

@@ -21,7 +21,7 @@ struct AttentionKernelDebugInfo {
 
 class AttentionKernelOptions {
  public:
-  void InitializeOnce(int sdpa_kernel, bool use_build_flag);
+  void InitializeOnce(int sdpa_kernel, bool use_build_flag, bool check_cudnn_version = false);
 
   bool UseFlashAttention() const { return use_flash_attention_; }
   bool UseEfficientAttention() const { return use_efficient_attention_; }
@@ -40,7 +40,7 @@ class AttentionKernelOptions {
  protected:
   void Print() const;
 
-  void Initialize(int value, bool use_build_flag);
+  void Initialize(int value, bool use_build_flag, bool check_cudnn_version);
 
  private:
   bool use_flash_attention_{true};
