
Commit ad38212

[CUDA] enable causal in MultiHeadAttention (microsoft#21852)
### Description
Enable causal attention in the MultiHeadAttention CUDA operator. All formats (Q_K_V_BSNH_BSNH_BSNH, Q_K_V_BSNH_BNSH_BNSH, Q_KV_BSNH_BSN2H and QKV_BSN3H) now support causal. Internally, the causal case is dispatched to the flash attention, efficient attention, or unfused attention kernel.

### Motivation and Context
Currently, MultiHeadAttention has causal enabled in the CPU EP but not in the CUDA EP. This can cause issues in ONNX conversion, e.g. a model that runs on CPU but fails on CUDA. Enabling causal in CUDA reduces the difference between the CPU and CUDA support matrices.
1 parent d9c57ac commit ad38212
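Not part of the commit, but for orientation: a minimal sketch of exercising the newly enabled causal path through the com.microsoft MultiHeadAttention contrib op by setting its unidirectional attribute (the attribute the CUDA kernel now honors, per the diff below). Shapes, dtypes, tensor names, and opset versions here are illustrative assumptions, not taken from this PR.

import numpy as np
from onnx import TensorProto, helper
import onnxruntime as ort

batch, seq_len, num_heads, head_size = 2, 8, 4, 32
hidden = num_heads * head_size

# MultiHeadAttention node in Q_K_V_BSNH format with causal (unidirectional) masking.
node = helper.make_node(
    "MultiHeadAttention",
    inputs=["query", "key", "value"],
    outputs=["output"],
    domain="com.microsoft",
    num_heads=num_heads,
    unidirectional=1,  # causal; previously rejected by the CUDA kernel via ORT_ENFORCE
)

graph = helper.make_graph(
    [node],
    "mha_causal_sketch",
    [
        helper.make_tensor_value_info(name, TensorProto.FLOAT16, [batch, seq_len, hidden])
        for name in ("query", "key", "value")
    ],
    [helper.make_tensor_value_info("output", TensorProto.FLOAT16, [batch, seq_len, hidden])],
)
model = helper.make_model(
    graph,
    opset_imports=[helper.make_opsetid("", 17), helper.make_opsetid("com.microsoft", 1)],
)

sess = ort.InferenceSession(
    model.SerializeToString(),
    providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
)
feeds = {name: np.random.randn(batch, seq_len, hidden).astype(np.float16) for name in ("query", "key", "value")}
causal_output = sess.run(None, feeds)[0]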


4 files changed (+37 −28 lines)


onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc

Lines changed: 5 additions & 7 deletions
@@ -46,8 +46,6 @@ MultiHeadAttention<T>::MultiHeadAttention(const OpKernelInfo& info)
 
   scale_ = info.GetAttrOrDefault<float>("scale", 0.0f);
   is_unidirectional_ = info.GetAttrOrDefault<int64_t>("unidirectional", 0) == 1;
-  ORT_ENFORCE(!is_unidirectional_,
-              "MHA support CUDA kernel does not Unidirectional. Consider using Attention or GQA instead.");
 
   kernel_options_ = this->GetAttentionKernelOptions();
 
@@ -208,13 +206,13 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
   bool use_fused_cross_attention =
       kernel_type == AttentionKernelType::AttentionKernel_Default &&
       !disable_fused_cross_attention_ &&
+      !is_unidirectional_ &&
       nullptr == key_padding_mask &&
       nullptr == attention_bias &&
       nullptr == past_key && nullptr == present_key &&
       (parameters.qkv_format == Q_K_V_BSNH || (parameters.qkv_format == Q_KV_BSNH_BSN2H && bias == nullptr)) &&
       parameters.hidden_size == parameters.v_hidden_size &&
-      has_fused_cross_attention_kernel(sm, parameters.head_size,
-                                       parameters.kv_sequence_length);
+      has_fused_cross_attention_kernel(sm, parameters.head_size, parameters.kv_sequence_length);
   if (use_fused_cross_attention) {
     if (fused_fp16_cross_attention_kernel_ == nullptr) {
       std::call_once(fused_cross_init_once_flag_, [&]() {
@@ -233,20 +231,20 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
   bool use_fused_runner =
       kernel_type == AttentionKernelType::AttentionKernel_Default &&
       !disable_fused_self_attention_ &&
+      !is_unidirectional_ &&
       nullptr == attention_bias &&
       (parameters.qkv_format == Q_K_V_BSNH || parameters.qkv_format == QKV_BSN3H) &&
       nullptr == past_key && nullptr == present_key &&
       is_mask_none_or_1d_k_len &&
       parameters.hidden_size == parameters.v_hidden_size &&
       parameters.sequence_length == parameters.kv_sequence_length &&  // self attention only for fused runner
       FusedMHARunnerFP16v2::IsSupported(sm, parameters.head_size, sequence_length,
-                                        enable_trt_flash_attention_, false);
+                                        enable_trt_flash_attention_, is_unidirectional_);
   if (use_fused_runner) {
     // Here we assume that num_heads and head_size does not change for a MultiHeadAttention node.
     if (nullptr == fused_fp16_runner_.get()) {
-      constexpr bool is_unidirectional = false;
       std::call_once(fused_fp16_runner_created_, [&]() {
-        fused_fp16_runner_ = FusedMHARunnerFP16v2::Create(num_heads_, parameters.head_size, sm, is_unidirectional,
+        fused_fp16_runner_ = FusedMHARunnerFP16v2::Create(num_heads_, parameters.head_size, sm, is_unidirectional_,
                                                           enable_trt_flash_attention_, parameters.scale);
       });
     }

onnxruntime/python/tools/transformers/io_binding_helper.py

Lines changed: 1 addition & 2 deletions
@@ -304,7 +304,7 @@ def allocate_buffers(self, shape_dict: Dict[str, Union[Tuple[int], List[int]]]):
                 tensor.data_ptr(),
             )
 
-    def infer(self, feed_dict: Dict[str, torch.Tensor], run_options: RunOptions = None, synchronize: bool = False):
+    def infer(self, feed_dict: Dict[str, torch.Tensor], run_options: RunOptions = None, synchronize: bool = True):
         """Bind input tensors and run inference"""
         for name, tensor in feed_dict.items():
             assert isinstance(tensor, torch.Tensor) and tensor.is_contiguous()
@@ -317,7 +317,6 @@ def infer(self, feed_dict: Dict[str, torch.Tensor], run_options: RunOptions = None, synchronize: bool = True):
                 else:
                     self.bind_input_and_buffer_sharing(name, tensor)
 
-        # Synchronization are not needed in most cases unless different streams are used or inputs/outputs are in CPU.
         if synchronize:
             self.io_binding.synchronize_inputs()
             self.ort_session.run_with_iobinding(self.io_binding, run_options)
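For context on how callers see this change, a small hedged sketch: `session` is an already-constructed CudaSession and `feed_dict` maps input names to contiguous torch tensors; both names are illustrative, not from this diff.

# Hypothetical caller of CudaSession.infer after this change.
outputs = session.infer(feed_dict)                     # synchronize now defaults to True
outputs = session.infer(feed_dict, synchronize=False)  # explicit opt-out when a single CUDA stream is used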

onnxruntime/test/python/transformers/benchmark_mha.py

Lines changed: 2 additions & 3 deletions
@@ -587,8 +587,8 @@ def __init__(self, config: MultiHeadAttentionConfig, session_options=None, use_t
         self.ort_session = create_session(config, session_options, use_tf32=use_tf32)
         self.feed_dict = config.random_inputs()
 
-    def infer(self):
-        return self.ort_session.infer(self.feed_dict)
+    def infer(self, run_options=None, synchronize=True):
+        return self.ort_session.infer(self.feed_dict, run_options=run_options, synchronize=synchronize)
 
 
 def measure_latency(cuda_session: CudaSession, input_dict):
@@ -1356,7 +1356,6 @@ def _parse_arguments():
         args.repeats = 10000 if args.use_gpu else 100
 
     if args.use_gpu:
-        assert args.torch or not args.causal, "no causal cuda kernel in MHA op"
         assert torch.cuda.is_available()
         if not args.torch:
             assert "CUDAExecutionProvider" in get_available_providers()

onnxruntime/test/python/transformers/test_mha.py

Lines changed: 29 additions & 16 deletions
@@ -68,6 +68,22 @@ def get_bias_support(format: InputFormats):
     raise RuntimeError(f"Unknown format: {format}")
 
 
+def get_causal_support(format: InputFormats):
+    if format == InputFormats.Q_K_V_BSNH_BSNH_BSNH:
+        return [True, False]
+
+    if format == InputFormats.Q_K_V_BSNH_BNSH_BNSH:
+        return [True, False]
+
+    if format == InputFormats.Q_KV_BSNH_BSN2H:
+        return [True, False]
+
+    if format == InputFormats.QKV_BSN3H:
+        return [True, False]
+
+    raise RuntimeError(f"Unknown format: {format}")
+
+
 def get_atten_bias_support():
     atten_bias_options = [
         # (has_attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1)
@@ -215,7 +231,7 @@ def no_kv_cache_test_cases(provider: str, comprehensive: bool):
                 for num_heads in heads:
                     for head_size in head_sizes:
                         for format in formats:
-                            for causal in [True, False]:
+                            for causal in get_causal_support(format):
                                 for mask_format in mask_formats:
                                     for has_bias in get_bias_support(format):
                                         for (
@@ -256,8 +272,8 @@ def no_kv_cache_test_cases(provider: str, comprehensive: bool):
             has_attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1 = atten_bias_options[
                 i % len(atten_bias_options)
             ]
-            for causal in [True, False]:
-                for format in formats:
+            for format in formats:
+                for causal in get_causal_support(format):
                     for has_bias in get_bias_support(format):
                         config = MultiHeadAttentionConfig(
                             batch_size=batch_size,
@@ -308,7 +324,7 @@ def kv_cache_test_cases(provider: str, comprehensive: bool):
                 for num_heads in heads:
                     for head_size in head_sizes:
                         for format in formats:
-                            for causal in [True, False]:
+                            for causal in get_causal_support(format):
                                 for has_past_input in [True, False]:
                                     for mask_format in mask_formats:
                                         for has_bias in get_bias_support(format):
@@ -353,8 +369,8 @@ def kv_cache_test_cases(provider: str, comprehensive: bool):
            has_attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1 = atten_bias_options[
                i % len(atten_bias_options)
            ]
-            for causal in [True, False]:
-                for format in formats:
+            for format in formats:
+                for causal in get_causal_support(format):
                    for has_past_input in [True, False]:
                        for has_bias in get_bias_support(format):
                            sequence_length = 1 if has_past_input else past_sequence_length
@@ -397,7 +413,7 @@ def no_kv_cache_multi_thread_test_cases(provider: str, comprehensive: bool):
     device, dtype, formats = get_provider_support_info(provider, False)
 
     for format in formats:
-        for causal in [True, False]:
+        for causal in get_causal_support(format):
             for num_heads in heads:
                 for head_size in head_sizes:
                     configs = []  # list of configurations to run in parallel
@@ -437,7 +453,7 @@ def kv_cache_multi_thread_test_cases(provider: str, comprehensive: bool):
     device, dtype, formats = get_provider_support_info(provider, True)
 
     for format in formats:
-        for causal in [True, False]:
+        for causal in get_causal_support(format):
             for num_heads in heads:
                 for head_size in head_sizes:
                     configs = []
@@ -494,12 +510,8 @@ def parity_check_mha(
     rtol=1e-3,
     atol=1e-3,
 ):
-    # CUDA kernel does not support causal so skip such test cases.
-    if config.causal and config.provider == "CUDAExecutionProvider":
-        return
-
     ort_mha = OrtMultiHeadAttention(config, use_tf32=False)
-    ort_outputs = ort_mha.infer()
+    ort_outputs = ort_mha.infer(synchronize=True)
     out = ort_outputs["output"]
     out = torch.reshape(out, (config.batch_size, config.sequence_length, config.num_heads, config.head_size))
 
@@ -602,9 +614,6 @@ def parity_check_mha_multi_threading(
 ):
     # Use the first config to create a session, which is shared by all configs to run in parallel.
     config = test_inputs[0]["config"]
-    # For now, MHA CUDA kernel does not support causal so skip such test cases.
-    if config.causal and config.provider == "CUDAExecutionProvider":
-        return None
 
     # Some kernel does not support certain input format.
     if attention_kernel not in [
@@ -784,6 +793,10 @@ def run_mha_cpu(self):
 
     def run_mha_cuda_multi_threading(self, attention_kernel):
         for configs in multi_thread_test_cases("CUDAExecutionProvider", comprehensive_mode):
+            if configs and configs[0].causal and (SdpaKernel.TRT_CAUSAL_ATTENTION & attention_kernel != 0):
+                # TRT fused causal is disabled by default so skip the test of causal for multi-threading.
+                continue
+
             test_inputs = []
             for config in configs:
                 ort_inputs = config.random_inputs()