diff --git a/aot_build_utils/generate.py b/aot_build_utils/generate.py index d4c21a8b6..5daa47c08 100644 --- a/aot_build_utils/generate.py +++ b/aot_build_utils/generate.py @@ -24,6 +24,7 @@ generate_batch_paged_decode_inst, generate_batch_paged_prefill_inst, generate_batch_ragged_prefill_inst, + generate_pod_inst, generate_single_decode_inst, generate_single_prefill_inst, ) @@ -250,11 +251,69 @@ def write_if_different(path: Path, content: str) -> None: f"f16qk_{bool(use_fp16_qk_reduction)}" ) + # POD files + pod_uris = [] + for ( + head_dim, + pos_encoding_mode, + use_fp16_qk_reduction, + mask_mode_p, + mask_mode_d, + idtype, + ) in product( + head_dims, + pos_encoding_modes, + use_fp16_qk_reductions, + mask_modes, + mask_modes, # mask_mode_d can be different from mask_mode_p + idtypes, + ): + for dtype_q, dtype_kv in list(zip(prefill_dtypes, prefill_dtypes)) + list( + product(prefill_dtypes, fp8_dtypes) + ): + fname = f"pod_head_qk_{head_dim}_head_vo_{head_dim}_posenc_{pos_encoding_mode}_fp16qkred_{use_fp16_qk_reduction}_maskp_{mask_mode_p}_maskd_{mask_mode_d}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_q}_idtype_{idtype}.cu" + content = generate_pod_inst.get_cu_file_str( + head_dim, # head_dim_qk + head_dim, # head_dim_vo + pos_encoding_mode, + use_fp16_qk_reduction, + mask_mode_p, + mask_mode_d, + dtype_q, # dtype_q + dtype_kv, # dtype_kv + dtype_q, # dtype_out + idtype, + ) + write_if_different(path / fname, content) + + for sliding_window_p in [True, False]: + for sliding_window_d in [True, False]: + for logits_soft_cap_p in [True, False]: + for logits_soft_cap_d in [True, False]: + if ( + mask_mode_p == 0 and mask_mode_d == 0 + ): # NOTE(Zihao): uri do not contain mask, avoid duplicate uris + pod_uris.append( + f"pod_with_kv_cache_dtype_q_{dtype_q}_" + f"dtype_kv_{dtype_kv}_" + f"dtype_o_{dtype_q}_" + f"dtype_idx_{idtype}_" + f"head_dim_qk_{head_dim}_" + f"head_dim_vo_{head_dim}_" + f"posenc_{pos_encoding_mode}_" + f"use_swa_p_{sliding_window_p}_" + f"use_swa_d_{sliding_window_d}_" + f"use_logits_cap_p_{logits_soft_cap_p}_" + f"use_logits_cap_d_{logits_soft_cap_d}_" + f"f16qk_{bool(use_fp16_qk_reduction)}" + ) + return ( single_decode_uris + batch_decode_uris + single_prefill_uris + batch_prefill_uris + + pod_uris ) diff --git a/aot_build_utils/generate_pod_inst.py b/aot_build_utils/generate_pod_inst.py new file mode 100644 index 000000000..53e061c09 --- /dev/null +++ b/aot_build_utils/generate_pod_inst.py @@ -0,0 +1,121 @@ +""" +Copyright (c) 2024 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +import re +import sys +from pathlib import Path + +from .literal_map import ( + dtype_literal, + idtype_literal, + mask_mode_literal, + pos_encoding_mode_literal, +) + + +def get_cu_file_str( + head_dim_qk, + head_dim_vo, + pos_encoding_mode, + use_fp16_qk_reduction, + mask_mode_p, + mask_mode_d, + dtype_q, + dtype_kv, + dtype_out, + idtype, +): + cta_tile_q_choice = [128, 64, 16] + cta_tile_q_d = 16 + + def get_insts(attention_variant_p, attention_variant_d, dtype_out): + return "\n".join( + [ + """template cudaError_t PODWithKVCacheTensorDispatched<{head_dim_qk}, {head_dim_vo}, {pos_encoding_mode}, {use_fp16_qk_reduction}, {mask_mode_p}, {cta_tile_q_p}, {cta_tile_q_d}, {mask_mode_d}, {attention_variant_p}, {attention_variant_d}, PrefillParams, DecodeParams>( + PrefillParams prefill_params, DecodeParams decode_params, + {dtype_out}* tmp_v, float* tmp_s, bool enable_pdl, cudaStream_t stream); + """.format( + head_dim_qk=head_dim_qk, + head_dim_vo=head_dim_vo, + pos_encoding_mode=pos_encoding_mode_literal[int(pos_encoding_mode)], + use_fp16_qk_reduction=use_fp16_qk_reduction, + mask_mode_p=mask_mode_literal[int(mask_mode_p)], + cta_tile_q_p=cta_tile_q_p, + cta_tile_q_d=cta_tile_q_d, + mask_mode_d=mask_mode_literal[int(mask_mode_d)], + attention_variant_p=attention_variant_p, + attention_variant_d=attention_variant_d, + dtype_out=dtype_out, + ) + for cta_tile_q_p in cta_tile_q_choice + ] + ) + + use_custom_mask_p = "true" if int(mask_mode_p) == 2 else "false" + use_custom_mask_d = "true" if int(mask_mode_d) == 2 else "false" + dtype_q = dtype_literal[dtype_q] + dtype_kv = dtype_literal[dtype_kv] + dtype_out = dtype_literal[dtype_out] + idtype = idtype_literal[idtype] + + content = f"""#include +#include "pod_config.inc" + +namespace flashinfer {{ + +constexpr auto use_custom_mask_p = MaskMode::kNone == MaskMode::kCustom; +constexpr auto use_custom_mask_d = MaskMode::kNone == MaskMode::kCustom; + +using PrefillParams = BatchPrefillPagedParams<{dtype_q}, {dtype_kv}, {dtype_out}>; +using DecodeParams = BatchPrefillPagedParams<{dtype_q}, {dtype_kv}, {dtype_out}, {idtype}>; + +using AttentionVariant1_P = DefaultAttention; +using AttentionVariant1_D = DefaultAttention; + +{get_insts("AttentionVariant1_P", "AttentionVariant1_D", dtype_out)} + +using AttentionVariant2_P = DefaultAttention; +using AttentionVariant2_D = DefaultAttention; + +{get_insts("AttentionVariant2_P", "AttentionVariant2_D", dtype_out)} + +using AttentionVariant3_P = DefaultAttention; +using AttentionVariant3_D = DefaultAttention; + +{get_insts("AttentionVariant3_P", "AttentionVariant3_D", dtype_out)} + +using AttentionVariant4_P = DefaultAttention; +using AttentionVariant4_D = DefaultAttention; + +{get_insts("AttentionVariant4_P", "AttentionVariant4_D", dtype_out)} + +}}""" + return content + + +if __name__ == "__main__": + pattern = ( + r"pod_head_qk_([0-9]+)_head_vo_([0-9]+)_posenc_([0-9]+)_" + r"fp16qkred_([a-z]+)_maskp_([0-9]+)_maskd_([0-9]+)_" + r"dtypeq_([a-z0-9]+)_dtypekv_([a-z0-9]+)_dtypeout_([a-z0-9]+)_idtype_([a-z0-9]+)\.cu" + ) + compiled_pattern = re.compile(pattern) + path = Path(sys.argv[1]) + fname = path.name + match = compiled_pattern.match(fname) + + with open(path, "w") as f: + f.write(get_cu_file_str(*match.groups())) diff --git a/benchmarks/bench_mixed_attention.py b/benchmarks/bench_pod_attention.py similarity index 66% rename from benchmarks/bench_mixed_attention.py rename to benchmarks/bench_pod_attention.py index f581628b9..6312ce812 100644 --- a/benchmarks/bench_mixed_attention.py +++ 
b/benchmarks/bench_pod_attention.py @@ -17,48 +17,67 @@ def run_bench( device=0, causal=True, ): - # POD Attention only supports page size = 1 due to use of single prefill kernel - page_block_size = 1 + # if page size > 1, prefill kv len must be divisible by page size to ensure + # an identical workload as in BatchAttention + page_size = 1 seq_lens = torch.tensor(d_kv_lens + p_kv_lens, dtype=torch.int32) q_lens = torch.tensor(d_qo_lens + p_qo_lens, dtype=torch.int32) - seq_lens_blocks = torch.ceil(seq_lens / page_block_size).int() - d_seq_lens_blocks = ( - torch.tensor(d_kv_lens, dtype=torch.int32) / page_block_size - ).int() + seq_lens_blocks = torch.ceil(seq_lens / page_size).int() + p_seq_lens = torch.tensor(p_kv_lens, dtype=torch.int32) / page_size + d_seq_lens = (torch.tensor(d_kv_lens, dtype=torch.int32) / page_size).int() - q_indptr = torch.cat([torch.tensor([0]), torch.cumsum(q_lens, 0)], dim=0).int() + # General params + qo_indptr = torch.cat([torch.tensor([0]), torch.cumsum(q_lens, 0)], dim=0).int() kv_indptr = torch.cat( [torch.tensor([0]), torch.cumsum(seq_lens_blocks, 0)], dim=0 ).int() - d_q_indptr = torch.cat( - [torch.tensor([0]), torch.cumsum(torch.tensor(d_qo_lens), 0)], dim=0 - ).int() - d_kv_indptr = torch.cat( - [torch.tensor([0]), torch.cumsum(d_seq_lens_blocks, 0)], dim=0 - ).int() - num_blocks = kv_indptr[-1].item() - - q = torch.rand(q_indptr[-1].item(), num_qo_heads, head_dim).to( + num_pages = kv_indptr[-1].item() + q = torch.rand(qo_indptr[-1].item(), num_qo_heads, head_dim).to( device, dtype=torch.bfloat16 ) - kv_data = torch.randn(num_blocks, 2, page_block_size, num_kv_heads, head_dim).to( + kv_data = torch.randn(num_pages, 2, page_size, num_kv_heads, head_dim).to( device, dtype=torch.bfloat16 ) + # Prefill params + seq_lens_blocks_p = torch.ceil( + torch.tensor(p_kv_lens, dtype=torch.int32) / page_size + ).int() + qo_indptr_p = torch.cat( + [torch.tensor([0]), torch.cumsum(torch.tensor(p_qo_lens), 0)], dim=0 + ).int() + kv_indptr_p = torch.cat( + [torch.tensor([0]), torch.cumsum(p_seq_lens, 0)], dim=0 + ).int() + num_pages_p = seq_lens_blocks_p[-1].item() + kv_indices_p = torch.arange(num_pages_p, device=device, dtype=torch.int32) + last_page_len_p = (p_seq_lens - 1) % page_size + 1 + + # Decode params + + qo_indptr_d = torch.cat( + [torch.tensor([0]), torch.cumsum(torch.tensor(d_qo_lens), 0)], dim=0 + ).int() + kv_indptr_d = torch.cat( + [torch.tensor([0]), torch.cumsum(d_seq_lens, 0)], dim=0 + ).int() + num_pages_d = kv_indptr_d[-1].item() + kv_indices_d = torch.arange(num_pages_d, device=device, dtype=torch.int32) + last_page_len_d = (d_seq_lens - 1) % page_size + 1 + workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device) kv_layout = "NHD" - wrapper_old = flashinfer.BatchPrefillWithPagedKVCacheWrapper( workspace_buffer, kv_layout=kv_layout, backend="fa2", ) - last_page_len = (seq_lens - 1) % page_block_size + 1 + last_page_len = (seq_lens - 1) % page_size + 1 wrapper_old.plan( - q_indptr.to(device), + qo_indptr.to(device), kv_indptr.to(device), - torch.arange(num_blocks).int().to(device), + torch.arange(num_pages, dtype=torch.int32, device=device), last_page_len, num_qo_heads, num_kv_heads, @@ -72,28 +91,28 @@ def run_bench( ms_old = do_bench(lambda: wrapper_old.run(q, kv_data)) if len(p_kv_lens) == 1: - q_d = q[: d_q_indptr[-1]] - kv_d = kv_data[: d_kv_indptr[-1]].unbind(1) - q_p = q[d_q_indptr[-1] :] - k_p, v_p = kv_data[d_kv_indptr[-1] :].unbind(1) - k_p, v_p = k_p.squeeze(1), v_p.squeeze(1) - kv_indices_d = torch.arange( - 
0, d_kv_indptr[-1], device=device, dtype=torch.int32 - ) + q_d = q[: qo_indptr_d[-1]] + q_p = q[qo_indptr_d[-1] :] + kv_indices_d = torch.arange(0, num_pages_d, device=device, dtype=torch.int32) - last_page_len_d = (d_seq_lens_blocks - 1) % page_block_size + 1 + last_page_len_d = (d_seq_lens - 1) % page_size + 1 wrapper_pod = flashinfer.PODWithPagedKVCacheWrapper( workspace_buffer, kv_layout=kv_layout, ) wrapper_pod.plan( - d_kv_indptr.to(device), + qo_indptr_p.to(device), + kv_indptr_p.to(device), + kv_indices_p.to(device), + last_page_len_p, + qo_indptr_d.to(device), + kv_indptr_d.to(device), kv_indices_d.to(device), - last_page_len=last_page_len_d, + last_page_len_d, num_qo_heads=num_qo_heads, num_kv_heads=num_kv_heads, head_dim=head_dim, - page_size=page_block_size, + page_size=page_size, q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16, ) @@ -113,29 +132,47 @@ def run_bench( ms_pod = do_bench( lambda: wrapper_pod.run( q_p, - k_p, - v_p, q_d, - kv_d, + paged_kv_cache=kv_data, causal_p=causal, causal_d=causal, ) ) + # Persistent attention + wrapper = flashinfer.BatchAttention(kv_layout="NHD") + wrapper.plan( + qo_indptr.to(device), + kv_indptr.to(device), + torch.arange(num_pages, dtype=torch.int32, device=device), + seq_lens.to(device), + num_qo_heads, + num_kv_heads, + head_dim, + head_dim, + page_size, + causal=causal, + q_data_type=torch.bfloat16, + kv_data_type=torch.bfloat16, + ) + ms_persistent = do_bench(lambda: wrapper.run(q, kv_data)) - print(f"Elapsed time (Batched Prefill): {ms_old:.2f} ms") - if len(p_kv_lens) == 1: - print(f"Elapsed time (POD Attention): {ms_pod:.2f} ms") total_bytes = ( q.numel() * q.element_size() + kv_data.numel() * kv_data.element_size() ) + print(f"Elapsed time (Batched Prefill): {ms_old:.2f} ms") + if len(p_kv_lens) == 1: + print(f"Elapsed time (POD Attention): {ms_pod:.2f} ms") + print(f"Elapsed time (Persistent Attention): {ms_persistent:.2f} ms") + print(f"Loading memory size (MB): {total_bytes / (1024**2):.2f} MB") bandwidth_old_gb_s = total_bytes / (ms_old * 1e-3) / (1024**3) - + bandwidth_new_gb_s = total_bytes / (ms_persistent * 1e-3) / (1024**3) print(f"Memory bandwidth (Batched Prefill): {bandwidth_old_gb_s:.2f} GB/s") if len(p_kv_lens) == 1: bandwidth_pod_gb_s = total_bytes / (ms_pod * 1e-3) / (1024**3) print(f"Memory bandwidth (POD Attention): {bandwidth_pod_gb_s:.2f} GB/s") + print(f"Memory bandwidth (Persistent Attention): {bandwidth_new_gb_s:.2f} GB/s") if __name__ == "__main__": diff --git a/csrc/batch_attention.cu b/csrc/batch_attention.cu index 86a7fd287..3d6aac8c3 100644 --- a/csrc/batch_attention.cu +++ b/csrc/batch_attention.cu @@ -154,7 +154,7 @@ void BatchPagedAttentionRun(at::Tensor float_workspace_buffer, at::Tensor int_wo GetPtrFromBaseOffset(int_buffer_ptr, plan_info.merge_indptr_offset); params[i].merge_o_indices = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.merge_o_indices_offset); - params[i].num_packed_qo_len = + params[i].num_to_merge_qo_len = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.num_qo_len_offset); params[i].num_kv_heads = num_kv_heads; diff --git a/csrc/batch_attention_customize_config.jinja b/csrc/batch_attention_customize_config.jinja index 6bc85b067..116bfc3d8 100644 --- a/csrc/batch_attention_customize_config.jinja +++ b/csrc/batch_attention_customize_config.jinja @@ -80,7 +80,7 @@ struct PersistentParams { // for state reduction IdType* merge_indptr; IdType* merge_o_indices; - IdType* num_packed_qo_len; + IdType* num_to_merge_qo_len; uint32_t num_kv_heads; uint_fastdiv gqa_group_size; diff 
--git a/csrc/batch_prefill.cu b/csrc/batch_prefill.cu index a51fc7f56..94588b393 100644 --- a/csrc/batch_prefill.cu +++ b/csrc/batch_prefill.cu @@ -47,7 +47,7 @@ at::Tensor BatchPrefillWithKVCachePlan( at::Tensor page_locked_int_workspace_buffer, at::Tensor qo_indptr, at::Tensor kv_indptr, at::Tensor kv_len_arr, int64_t total_num_rows, int64_t batch_size, int64_t num_qo_heads, int64_t num_kv_heads, int64_t page_size, bool enable_cuda_graph, int64_t head_dim_qk, - int64_t head_dim_vo, bool causal) { + int64_t head_dim_vo) { size_t float_workspace_size_in_bytes = float_workspace_buffer.size(0) * float_workspace_buffer.element_size(); size_t int_workspace_size_in_bytes = diff --git a/csrc/flashinfer_ops.cu b/csrc/flashinfer_ops.cu index 526ad969a..e305bfade 100644 --- a/csrc/flashinfer_ops.cu +++ b/csrc/flashinfer_ops.cu @@ -123,23 +123,25 @@ void BatchPrefillWithPagedKVCacheRun( int64_t window_left BATCH_PREFILL_ADDITIONAL_FUNC_PARAMS); //========== pod-attention ========= -void pod_with_kv_cache_tensor( +void PODWithKVCacheTensorRun( + // Shared params + at::Tensor float_workspace_buffer_d, at::Tensor int_workspace_buffer_d, + at::Tensor plan_info_vec, at::Tensor qo_indptr, at::Tensor paged_kv_indptr, + at::Tensor paged_kv_indices, at::Tensor paged_kv_last_page_len, at::Tensor o, // Prefill params - at::Tensor q_p, at::Tensor k_p, at::Tensor v_p, at::Tensor tmp_p, at::Tensor o_p, + at::Tensor q_p, at::Tensor paged_k_p, at::Tensor paged_v_p, std::optional<at::Tensor> maybe_lse_p, int64_t mask_mode_code_p, int64_t layout_p, int64_t window_left_p, std::optional<at::Tensor> maybe_custom_mask_p, std::optional<at::Tensor> maybe_alibi_slopes_p, double logits_soft_cap_p, double sm_scale_p, double rope_rcp_scale_p, double rope_rcp_theta_p, // Decode params - at::Tensor float_workspace_buffer_d, at::Tensor int_workspace_buffer_d, - at::Tensor plan_info_vec, at::Tensor q_d, at::Tensor paged_k_cache_d, - at::Tensor paged_v_cache_d, at::Tensor qo_indptr_d, at::Tensor paged_kv_indptr_d, - at::Tensor paged_kv_indices_d, at::Tensor paged_kv_last_page_len_d, at::Tensor o_d, - std::optional<at::Tensor> maybe_lse_d, int64_t mask_mode_code_d, int64_t layout_d, - int64_t window_left, std::optional<at::Tensor> maybe_custom_mask_d, - std::optional<at::Tensor> maybe_mask_indptr_d, std::optional<at::Tensor> maybe_alibi_slopes_d, - double logits_soft_cap_d, double sm_scale_d, double rope_rcp_scale_d, double rope_rcp_theta_d, - bool enable_pdl); + at::Tensor q_d, at::Tensor paged_k_cache_d, at::Tensor paged_v_cache_d, at::Tensor qo_indptr_d, + at::Tensor paged_kv_indptr_d, at::Tensor paged_kv_indices_d, + at::Tensor paged_kv_last_page_len_d, std::optional<at::Tensor> maybe_lse_d, + int64_t mask_mode_code_d, int64_t layout_d, int64_t window_left_d, + std::optional<at::Tensor> maybe_custom_mask_d, std::optional<at::Tensor> maybe_mask_indptr_d, + std::optional<at::Tensor> maybe_alibi_slopes_d, double logits_soft_cap_d, double sm_scale_d, + double rope_rcp_scale_d, double rope_rcp_theta_d, bool enable_pdl); //========== quantization ========== void packbits(at::Tensor x, const std::string& bitorder, at::Tensor y); @@ -280,7 +282,7 @@ TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) { // pod-attention // Temporarily disabled because we don't generate the implementation yet. 
- // m.def("pod_with_kv_cache_tensor", pod_with_kv_cache_tensor); + // m.def("PODWithKVCacheTensor", PODWithKVCacheTensorRun); // quantization // GPU packbits operator diff --git a/csrc/pod.cu b/csrc/pod.cu index fabde1be7..b655647fa 100644 --- a/csrc/pod.cu +++ b/csrc/pod.cu @@ -23,12 +23,10 @@ namespace flashinfer { template -cudaError_t PODWithKVCacheTensorDispatched(PrefillParams prefill_params, - typename PrefillParams::DTypeO* tmp, - DecodeParams decode_params, + bool USE_FP16_QK_REDUCTION, MaskMode MASK_MODE_P, uint32_t CTA_TILE_Q_P, + uint32_t CTA_TILE_Q_D, MaskMode MASK_MODE_D, typename PrefillAttentionVariant, + typename DecodeAttentionVariant, typename PrefillParams, typename DecodeParams> +cudaError_t PODWithKVCacheTensorDispatched(PrefillParams prefill_params, DecodeParams decode_params, typename DecodeParams::DTypeO* tmp_v, float* tmp_s, bool enable_pdl, cudaStream_t stream); @@ -36,135 +34,156 @@ cudaError_t PODWithKVCacheTensorDispatched(PrefillParams prefill_params, using namespace flashinfer; -void pod_with_kv_cache_tensor( +at::Tensor PODWithKVCachePlan(at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, + at::Tensor page_locked_int_workspace_buffer, at::Tensor qo_indptr_p, + at::Tensor kv_indptr_p, at::Tensor kv_len_arr_p, + uint32_t total_num_rows_p, uint32_t batch_size_p, + at::Tensor qo_indptr_d, at::Tensor kv_indptr_d, + uint32_t total_num_rows_d, uint32_t batch_size_d, + uint32_t num_qo_heads_p, uint32_t num_kv_heads, uint32_t head_dim_qk, + uint32_t head_dim_vo, uint32_t page_size, bool enable_cuda_graph) { + size_t float_workspace_size_in_bytes = + float_workspace_buffer.size(0) * float_workspace_buffer.element_size(); + size_t int_workspace_size_in_bytes = + int_workspace_buffer.size(0) * int_workspace_buffer.element_size(); + + PODPlanInfo plan_info; + + const c10::cuda::OptionalCUDAGuard device_guard(float_workspace_buffer.device()); + const cudaStream_t stream = c10::cuda::getCurrentCUDAStream(); + cudaError_t status = + PODPlan(float_workspace_buffer.data_ptr(), float_workspace_size_in_bytes, + int_workspace_buffer.data_ptr(), page_locked_int_workspace_buffer.data_ptr(), + int_workspace_size_in_bytes, plan_info, qo_indptr_p.data_ptr(), + kv_indptr_p.data_ptr(), total_num_rows_p, batch_size_p, + qo_indptr_d.data_ptr(), kv_indptr_d.data_ptr(), + total_num_rows_d, batch_size_d, num_qo_heads_p, num_kv_heads, head_dim_qk, + head_dim_vo, page_size, enable_cuda_graph, /*sizeof_dtype_o=*/2, stream); + + TORCH_CHECK(status == cudaSuccess, + "Failed to plan prefill with error: ", cudaGetErrorString(status)); + + return vec_to_tensor(plan_info.ToVector()); +} + +void PODWithKVCacheTensorRun( + // Shared params + at::Tensor float_workspace_buffer_d, at::Tensor int_workspace_buffer_d, + at::Tensor plan_info_vec, at::Tensor paged_k_cache, at::Tensor paged_v_cache, + at::Tensor qo_indptr, at::Tensor paged_kv_indptr, at::Tensor paged_kv_indices, + at::Tensor paged_kv_last_page_len, at::Tensor o, std::optional maybe_lse, + int64_t layout, // Prefill params - at::Tensor q_p, at::Tensor k_p, at::Tensor v_p, at::Tensor tmp_p, at::Tensor o_p, - std::optional maybe_lse_p, int64_t mask_mode_code_p, int64_t layout_p, - int64_t window_left_p, std::optional maybe_custom_mask_p, - std::optional maybe_alibi_slopes_p, double logits_soft_cap_p, double sm_scale_p, - double rope_rcp_scale_p, double rope_rcp_theta_p, + at::Tensor q_p, int64_t mask_mode_code_p, int64_t window_left_p, + std::optional maybe_custom_mask_p, std::optional maybe_alibi_slopes_p, + double 
logits_soft_cap_p, double sm_scale_p, double rope_rcp_scale_p, double rope_rcp_theta_p, // Decode params - at::Tensor float_workspace_buffer_d, at::Tensor int_workspace_buffer_d, - at::Tensor plan_info_vec, at::Tensor q_d, at::Tensor paged_k_cache_d, - at::Tensor paged_v_cache_d, at::Tensor qo_indptr_d, at::Tensor paged_kv_indptr_d, - at::Tensor paged_kv_indices_d, at::Tensor paged_kv_last_page_len_d, at::Tensor o_d, - std::optional maybe_lse_d, int64_t mask_mode_code_d, int64_t layout_d, - int64_t window_left_d, std::optional maybe_custom_mask_d, - std::optional maybe_mask_indptr_d, std::optional maybe_alibi_slopes_d, - double logits_soft_cap_d, double sm_scale_d, double rope_rcp_scale_d, double rope_rcp_theta_d, - bool enable_pdl) { + at::Tensor q_d, int64_t mask_mode_code_d, int64_t window_left_d, + std::optional maybe_custom_mask_d, std::optional maybe_mask_indptr_d, + std::optional maybe_alibi_slopes_d, double logits_soft_cap_d, double sm_scale_d, + double rope_rcp_scale_d, double rope_rcp_theta_d, bool enable_pdl) { + PODPlanInfo plan_info; + plan_info.FromVector(tensor_to_vec(plan_info_vec)); + auto device = q_d.device(); + uint32_t batch_size = paged_kv_indptr.size(0) - 1; + void* float_buffer_ptr = static_cast(float_workspace_buffer_d.data_ptr()); + void* int_buffer_ptr = static_cast(int_workspace_buffer_d.data_ptr()); + // get kv_cache_strides + const int64_t* kv_cache_strides = nullptr; + auto k_strides = paged_k_cache.strides(); + auto v_strides = paged_v_cache.strides(); + TORCH_CHECK(k_strides == v_strides, "k/v strides must be identical"); + kv_cache_strides = k_strides.data(); + // Prefill setup - unsigned int head_dim_qk = q_p.size(2); - unsigned int kv_len_p, qo_len_p, num_kv_heads, num_qo_heads; - QKVLayout kv_layout_p = static_cast(layout_p); - qo_len_p = q_p.size(0); - num_qo_heads = q_p.size(1); - uint32_t q_stride_n_p = q_p.stride(0), q_stride_h_p = q_p.stride(1), k_stride_n_p, k_stride_h_p, - v_stride_n_p, v_stride_h_p; - if (kv_layout_p == QKVLayout::kNHD) { - kv_len_p = k_p.size(0); - num_kv_heads = k_p.size(1); - k_stride_n_p = k_p.stride(0); - k_stride_h_p = k_p.stride(1); - v_stride_n_p = v_p.stride(0); - v_stride_h_p = v_p.stride(1); - } else { - kv_len_p = k_p.size(1); - num_kv_heads = k_p.size(0); - k_stride_h_p = k_p.stride(0); - k_stride_n_p = k_p.stride(1); - v_stride_h_p = v_p.stride(0); - v_stride_n_p = v_p.stride(1); - } - if (maybe_lse_p) { - const auto& lse = *maybe_lse_p; - TORCH_CHECK(lse.size(0) == qo_len_p, lse.size(0), q_p.size(0)); - TORCH_CHECK(lse.size(1) == num_qo_heads, lse.size(1), q_p.size(1)); + uint32_t head_dim_qk = q_p.size(2); + uint32_t qo_len, num_qo_heads_p; + QKVLayout kv_layout = static_cast(layout); + qo_len = q_p.size(0) + q_d.size(0); + num_qo_heads_p = q_p.size(1); + uint32_t q_stride_n_p = q_p.stride(0), q_stride_h_p = q_p.stride(1); + if (maybe_lse) { + const auto& lse = *maybe_lse; + TORCH_CHECK(lse.size(0) == qo_len, lse.size(0), qo_len); + TORCH_CHECK(lse.size(1) == num_qo_heads_p, lse.size(1), q_p.size(1)); } const MaskMode mask_mode_p = static_cast(mask_mode_code_p); auto q_scalar_type = q_p.scalar_type(); - auto kv_scalar_type = k_p.scalar_type(); // Decode setup (Tensor decode = batched prefill) - PrefillPlanInfo plan_info; - plan_info.FromVector(tensor_to_vec(plan_info_vec)); - QKVLayout kv_layout_d = static_cast(layout_d); - auto device = q_d.device(); - int64_t batch_size = paged_kv_indptr_d.size(0) - 1; - int64_t num_qo_heads_d = q_d.size(1); - - TORCH_CHECK(num_qo_heads == num_qo_heads_d, + uint32_t 
num_qo_heads = q_d.size(1); + TORCH_CHECK(num_qo_heads_p == num_qo_heads, "POD currently requires same # Query heads for prefill and decode"); - int64_t num_kv_heads_d, page_size_d; - uint32_t head_dim_qk_d = q_d.size(2); - if (kv_layout_d == QKVLayout::kHND) { - num_kv_heads_d = paged_k_cache_d.size(1); - page_size_d = paged_k_cache_d.size(2); + uint32_t num_kv_heads_d, num_kv_heads, page_size; + if (kv_layout == QKVLayout::kHND) { + num_kv_heads = paged_k_cache.size(1); + num_kv_heads_d = paged_k_cache.size(1); + page_size = paged_k_cache.size(2); } else { - page_size_d = paged_k_cache_d.size(1); - num_kv_heads_d = paged_k_cache_d.size(2); + num_kv_heads = paged_k_cache.size(2); + num_kv_heads_d = paged_k_cache.size(2); + page_size = paged_k_cache.size(1); } TORCH_CHECK(num_kv_heads == num_kv_heads_d, "POD currently requires same # KV heads for prefill and decode; Prefill: ", num_kv_heads, ", Decode: ", num_kv_heads_d); - if (maybe_lse_d) { - const auto& lse = *maybe_lse_d; - TORCH_CHECK(lse.size(0) == q_d.size(0), lse.size(0), q_d.size(0)); - TORCH_CHECK(lse.size(1) == q_d.size(1), lse.size(1), q_d.size(1)); - } - - void* float_buffer_ptr = static_cast(float_workspace_buffer_d.data_ptr()); - void* int_buffer_ptr = static_cast(int_workspace_buffer_d.data_ptr()); - const MaskMode mask_mode_d = static_cast(mask_mode_code_d); - auto q_scalar_type_d = q_d.scalar_type(); - auto kv_scalar_type_d = paged_k_cache_d.scalar_type(); // get q_stride_n and q_stride_h const auto q_stride_n_d = q_d.stride(0); const auto q_stride_h_d = q_d.stride(1); - // get kv_cache_strides - const int64_t* kv_cache_strides_d = nullptr; - auto k_strides_d = paged_k_cache_d.strides(); - auto v_strides_d = paged_v_cache_d.strides(); - TORCH_CHECK(k_strides_d == v_strides_d, "k/v strides must be identical"); - kv_cache_strides_d = k_strides_d.data(); - const c10::cuda::OptionalCUDAGuard device_guard(float_workspace_buffer_d.device()); const cudaStream_t stream = c10::cuda::getCurrentCUDAStream(); DISPATCH_context( MASK_MODE_P, MASK_MODE_D, DTypeQ, DTypeKV, HEAD_DIM_QK, USE_SLIDING_WINDOW_P, USE_SLIDING_WINDOW_D, USE_LOGITS_SOFT_CAP, [&] { + paged_kv_t paged_kv( + num_kv_heads, page_size, HEAD_DIM_VO, batch_size, kv_layout, + static_cast(paged_k_cache.data_ptr()), + static_cast(paged_v_cache.data_ptr()), kv_cache_strides, + static_cast(paged_kv_indices.data_ptr()), + static_cast(paged_kv_indptr.data_ptr()), + static_cast(paged_kv_last_page_len.data_ptr())); PrefillParams prefill_params; { // Make params a reference to prefill_params to set values PrefillParams& params = prefill_params; params.q = static_cast(q_p.data_ptr()); - params.k = static_cast(k_p.data_ptr()); - params.v = static_cast(v_p.data_ptr()); - params.o = static_cast(o_p.data_ptr()); - params.lse = maybe_lse_p ? static_cast(maybe_lse_p->data_ptr()) : nullptr; - params.num_qo_heads = num_qo_heads; - params.num_kv_heads = num_kv_heads; - params.group_size = uint_fastdiv(num_qo_heads / num_kv_heads); - params.qo_len = qo_len_p; - params.kv_len = kv_len_p; + params.paged_kv = paged_kv; + params.q_indptr = static_cast(qo_indptr.data_ptr()); + params.o = static_cast(o.data_ptr()); + params.lse = maybe_lse ? 
static_cast(maybe_lse->data_ptr()) : nullptr; + params.group_size = uint_fastdiv(num_qo_heads / paged_kv.num_heads); params.q_stride_n = q_stride_n_p; params.q_stride_h = q_stride_h_p; - params.k_stride_n = k_stride_n_p; - params.k_stride_h = k_stride_h_p; - params.v_stride_n = v_stride_n_p; - params.v_stride_h = v_stride_h_p; - params.window_left = window_left_p; - params.partition_kv = false; + params.num_kv_heads = num_kv_heads; + params.num_qo_heads = num_qo_heads; + params.request_indices = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.request_indices_offset); + params.qo_tile_indices = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.qo_tile_indices_offset); + params.kv_tile_indices = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.kv_tile_indices_offset); + params.o_indptr = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.o_indptr_offset); + if (plan_info.split_kv) { + params.merge_indptr = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.merge_indptr_offset); + if (plan_info.enable_cuda_graph) { + params.block_valid_mask = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.block_valid_mask_offset); + } + } + params.kv_chunk_size_ptr = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.kv_chunk_size_ptr_offset_p); + params.padded_batch_size = plan_info.padded_batch_size_p; params.maybe_custom_mask = maybe_custom_mask_p ? static_cast(maybe_custom_mask_p->data_ptr()) : nullptr; @@ -175,6 +194,18 @@ void pod_with_kv_cache_tensor( params.sm_scale = sm_scale_p; params.rope_rcp_scale = rope_rcp_scale_p; params.rope_rcp_theta = rope_rcp_theta_p; + params.max_total_num_rows = plan_info.total_num_rows; + if (plan_info.enable_cuda_graph) { + params.total_num_rows = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.total_num_rows_offset); + } + params.partition_kv = plan_info.split_kv; + if (plan_info.split_kv) { + if (plan_info.enable_cuda_graph) { + params.block_valid_mask = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.block_valid_mask_offset); + } + } } DecodeParams decode_params; @@ -183,35 +214,35 @@ void pod_with_kv_cache_tensor( { DecodeParams& params = decode_params; params.q = static_cast(q_d.data_ptr()); - paged_kv_t paged_kv( - num_kv_heads, page_size_d, HEAD_DIM_VO, batch_size, kv_layout_d, - static_cast(paged_k_cache_d.data_ptr()), - static_cast(paged_v_cache_d.data_ptr()), kv_cache_strides_d, - static_cast(paged_kv_indices_d.data_ptr()), - static_cast(paged_kv_indptr_d.data_ptr()), - static_cast(paged_kv_last_page_len_d.data_ptr())); params.paged_kv = paged_kv; - params.q_indptr = static_cast(qo_indptr_d.data_ptr()); - params.o = static_cast(o_d.data_ptr()); - - params.lse = maybe_lse_d ? static_cast(maybe_lse_d->data_ptr()) : nullptr; - params.num_qo_heads = num_qo_heads; + params.q_indptr = static_cast(qo_indptr.data_ptr()); + params.o = static_cast(o.data_ptr()); + params.lse = maybe_lse ? 
static_cast(maybe_lse->data_ptr()) : nullptr; params.group_size = uint_fastdiv(num_qo_heads / paged_kv.num_heads); params.q_stride_n = q_stride_n_d; params.q_stride_h = q_stride_h_d; params.window_left = window_left_d; + params.num_kv_heads = num_kv_heads; + params.num_qo_heads = num_qo_heads; - params.request_indices = nullptr; - params.qo_tile_indices = nullptr; - params.kv_tile_indices = nullptr; - params.merge_indptr = nullptr; - params.o_indptr = nullptr; - params.kv_chunk_size_ptr = nullptr; - params.block_valid_mask = nullptr; - params.total_num_rows = nullptr; - params.max_total_num_rows = 0; - params.padded_batch_size = 0; - params.partition_kv = false; + params.request_indices = prefill_params.request_indices; + params.qo_tile_indices = prefill_params.qo_tile_indices; + params.kv_tile_indices = prefill_params.kv_tile_indices; + params.o_indptr = prefill_params.o_indptr; + params.kv_chunk_size_ptr = prefill_params.kv_chunk_size_ptr; + + params.partition_kv = plan_info.split_kv; + if (plan_info.split_kv) { + params.merge_indptr = prefill_params.merge_indptr; + // These should be assigned from plan info, not from prefill_params + tmp_v = GetPtrFromBaseOffset(float_buffer_ptr, plan_info.v_offset); + tmp_s = GetPtrFromBaseOffset(float_buffer_ptr, plan_info.s_offset); + if (plan_info.enable_cuda_graph) { + params.block_valid_mask = prefill_params.block_valid_mask; + } + } + params.padded_batch_size = plan_info.padded_batch_size_d; + params.max_total_num_rows = plan_info.total_num_rows; params.maybe_mask_indptr = maybe_mask_indptr_d ? static_cast(maybe_mask_indptr_d->data_ptr()) @@ -224,30 +255,8 @@ void pod_with_kv_cache_tensor( params.rope_rcp_scale = rope_rcp_scale_d; params.rope_rcp_theta = rope_rcp_theta_d; - params.request_indices = - GetPtrFromBaseOffset(int_buffer_ptr, plan_info.request_indices_offset); - params.qo_tile_indices = - GetPtrFromBaseOffset(int_buffer_ptr, plan_info.qo_tile_indices_offset); - params.kv_tile_indices = - GetPtrFromBaseOffset(int_buffer_ptr, plan_info.kv_tile_indices_offset); - params.o_indptr = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.o_indptr_offset); - params.kv_chunk_size_ptr = - GetPtrFromBaseOffset(int_buffer_ptr, plan_info.kv_chunk_size_ptr_offset); - if (plan_info.split_kv) { - params.merge_indptr = - GetPtrFromBaseOffset(int_buffer_ptr, plan_info.merge_indptr_offset); - tmp_v = GetPtrFromBaseOffset(float_buffer_ptr, plan_info.v_offset); - tmp_s = GetPtrFromBaseOffset(float_buffer_ptr, plan_info.s_offset); - if (plan_info.enable_cuda_graph) { - params.block_valid_mask = - GetPtrFromBaseOffset(int_buffer_ptr, plan_info.block_valid_mask_offset); - } - } - params.padded_batch_size = plan_info.padded_batch_size; - params.max_total_num_rows = plan_info.total_num_rows; if (plan_info.enable_cuda_graph) { - params.total_num_rows = - GetPtrFromBaseOffset(int_buffer_ptr, plan_info.total_num_rows_offset); + params.total_num_rows = prefill_params.total_num_rows; } } @@ -260,12 +269,13 @@ void pod_with_kv_cache_tensor( DefaultAttention; // DISPATCH_CTA_TILE_Q(plan_info.cta_tile_q, CTA_TILE_Q, { - constexpr size_t CTA_TILE_Q = 16; + constexpr size_t CTA_TILE_Q_P = plan_info.cta_tile_q_p; + constexpr size_t CTA_TILE_Q_D = plan_info.cta_tile_q_d; cudaError_t status = flashinfer::PODWithKVCacheTensorDispatched< HEAD_DIM_QK, HEAD_DIM_VO, POS_ENCODING_MODE, USE_FP16_QK_REDUCTION, MASK_MODE_P, - CTA_TILE_Q, MASK_MODE_D, PrefillAttentionVariant, DecodeAttentionVariant>( - prefill_params, static_cast(tmp_p.data_ptr()), decode_params, tmp_v, tmp_s, - enable_pdl, 
stream); + CTA_TILE_Q_P, CTA_TILE_Q_D, MASK_MODE_D, PrefillAttentionVariant, + DecodeAttentionVariant>(prefill_params, decode_params, tmp_v, tmp_s, enable_pdl, + stream); TORCH_CHECK(status == cudaSuccess, "PODWithKVCache kernel launch failed, error: " + std::string(cudaGetErrorString(status))); //}); diff --git a/csrc/pod_config.inc b/csrc/pod_config.inc index 16306651a..3c95f4286 100644 --- a/csrc/pod_config.inc +++ b/csrc/pod_config.inc @@ -30,7 +30,7 @@ using namespace flashinfer; return DISPATCH_BOOL(window_left_d > -1, USE_SLIDING_WINDOW_D, [&] { \ return DISPATCH_BOOL(false, USE_LOGITS_SOFT_CAP, [&] { \ using IdType = int32_t; \ - using PrefillParams = SinglePrefillParams;\ + using PrefillParams = BatchPrefillPagedParams;\ using DecodeParams = BatchPrefillPagedParams; \ __VA_ARGS__(); \ diff --git a/csrc/pod_customize_config.jinja b/csrc/pod_customize_config.jinja index b4c56a0e8..ed66afb4b 100644 --- a/csrc/pod_customize_config.jinja +++ b/csrc/pod_customize_config.jinja @@ -30,7 +30,7 @@ constexpr auto USE_SLIDING_WINDOW_D = {{ use_sliding_window_d }}; constexpr auto POS_ENCODING_MODE = PosEncodingMode::kNone; constexpr bool USE_LOGITS_SOFT_CAP = false; -using PrefillParams = SinglePrefillParams; +using PrefillParams = BatchPrefillPagedParams; using DecodeParams = BatchPrefillPagedParams; #define DISPATCH_context(MASK_MODE_P, MASK_MODE_D, DTypeQ, DTypeKV, HEAD_DIM_QK, \ diff --git a/csrc/pod_jit_pybind.cu b/csrc/pod_jit_pybind.cu index 2e8d47bf2..db5e44d2c 100644 --- a/csrc/pod_jit_pybind.cu +++ b/csrc/pod_jit_pybind.cu @@ -16,25 +16,24 @@ #include "pod_config.inc" #include "pytorch_extension_utils.h" -void pod_with_kv_cache_tensor( +void PODWithKVCacheTensorRun( + // Shared params + at::Tensor float_workspace_buffer_d, at::Tensor int_workspace_buffer_d, + at::Tensor plan_info_vec, at::Tensor paged_k_cache, at::Tensor paged_v_cache, + at::Tensor qo_indptr, at::Tensor paged_kv_indptr, at::Tensor paged_kv_indices, + at::Tensor paged_kv_last_page_len, at::Tensor o, std::optional maybe_lse, + int64_t layout, // Prefill params - at::Tensor q_p, at::Tensor k_p, at::Tensor v_p, at::Tensor tmp_p, at::Tensor o_p, - std::optional maybe_lse_p, int64_t mask_mode_code_p, int64_t layout_p, - int64_t window_left_p, std::optional maybe_custom_mask_p, - std::optional maybe_alibi_slopes_p, double logits_soft_cap_p, double sm_scale_p, - double rope_rcp_scale_p, double rope_rcp_theta_p, + at::Tensor q_p, int64_t mask_mode_code_p, int64_t window_left_p, + std::optional maybe_custom_mask_p, std::optional maybe_alibi_slopes_p, + double logits_soft_cap_p, double sm_scale_p, double rope_rcp_scale_p, double rope_rcp_theta_p, // Decode params - at::Tensor float_workspace_buffer_d, at::Tensor int_workspace_buffer_d, - at::Tensor plan_info_vec, at::Tensor q_d, at::Tensor paged_k_cache_d, - at::Tensor paged_v_cache_d, at::Tensor qo_indptr_d, at::Tensor paged_kv_indptr_d, - at::Tensor paged_kv_indices_d, at::Tensor paged_kv_last_page_len_d, at::Tensor o_d, - std::optional maybe_lse_d, int64_t mask_mode_code_d, int64_t layout_d, - int64_t window_left_d, std::optional maybe_custom_mask_d, - std::optional maybe_mask_indptr_d, std::optional maybe_alibi_slopes_d, - double logits_soft_cap_d, double sm_scale_d, double rope_rcp_scale_d, double rope_rcp_theta_d, - bool enable_pdl); + at::Tensor q_d, int64_t mask_mode_code_d, int64_t window_left_d, + std::optional maybe_custom_mask_d, std::optional maybe_mask_indptr_d, + std::optional maybe_alibi_slopes_d, double logits_soft_cap_d, double sm_scale_d, + double 
rope_rcp_scale_d, double rope_rcp_theta_d, bool enable_pdl); TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) { // Batch-request prefill attention with KV-Cache operator - m.def("pod_with_kv_cache_tensor", pod_with_kv_cache_tensor); + m.def("PODWithKVCacheTensor", PODWithKVCacheTensorRun); } diff --git a/csrc/pod_kernel_inst.jinja b/csrc/pod_kernel_inst.jinja index e6938d58b..19b86efbd 100644 --- a/csrc/pod_kernel_inst.jinja +++ b/csrc/pod_kernel_inst.jinja @@ -13,20 +13,25 @@ #include "pod_config.inc" -using namespace flashinfer; - namespace flashinfer { + +using PrefillParams = BatchPrefillPagedParams<{{ dtype_q }}, {{ dtype_kv }}, {{ dtype_o }}>; +using DecodeParams = BatchPrefillPagedParams<{{ dtype_q }}, {{ dtype_kv }}, {{ dtype_o }}, {{ idtype }}>; + +constexpr auto POS_ENCODING_MODE = PosEncodingMode::kNone; + constexpr auto use_custom_mask_p = {{ mask_mode_p }} == MaskMode::kCustom; constexpr auto use_custom_mask_d = {{ mask_mode_d }} == MaskMode::kCustom; -// Not sure about the below declaration -constexpr auto POS_ENCODING_MODE = PosEncodingMode::kNone; +using PrefillAttentionVariant = DefaultAttention; +using DecodeAttentionVariant = DefaultAttention; + +{% for cta_tile_q_p in [16, 64, 128] %} template cudaError_t PODWithKVCacheTensorDispatched< - {{ head_dim_qk }}, {{ head_dim_vo }}, POS_ENCODING_MODE, - {{ use_fp16_qk_reduction }}, {{ mask_mode_p }}, 16, - {{ mask_mode_d }}, {{ variant_name_p }}, - {{ variant_name_d }}, PrefillParams, DecodeParams>( - PrefillParams prefill_params, {{ dtype_o }}* tmp, - DecodeParams decode_params, {{ dtype_o }}* tmp_v, - float *tmp_s, bool enable_pdl, cudaStream_t stream); -}; + {{ head_dim_qk }}, {{ head_dim_vo }}, POS_ENCODING_MODE, {{ use_fp16_qk_reduction }}, {{ mask_mode_p }}, {{ cta_tile_q_p }}, 16, {{ mask_mode_d }}, + PrefillAttentionVariant, DecodeAttentionVariant, PrefillParams, DecodeParams>( + PrefillParams prefill_params, DecodeParams decode_params, + {{ dtype_o }}* tmp_v, float* tmp_s, bool enable_pdl, cudaStream_t stream); +{% endfor %} + +}; // namespace flashinfer diff --git a/flashinfer/decode.py b/flashinfer/decode.py index 1cb685bcd..d9af24d74 100644 --- a/flashinfer/decode.py +++ b/flashinfer/decode.py @@ -948,7 +948,6 @@ def plan( self.is_cuda_graph_enabled, head_dim, head_dim, - False, # causal ) else: if self._jit_module is not None: diff --git a/flashinfer/pod.py b/flashinfer/pod.py index 49b2847a0..ebb331b73 100644 --- a/flashinfer/pod.py +++ b/flashinfer/pod.py @@ -44,7 +44,7 @@ @functools.cache def get_pod_module(*args): module = gen_pod_module(*args).build_and_load() - return SimpleNamespace(run_tensor=module.pod_with_kv_cache_tensor.default) + return SimpleNamespace(run_tensor=module.PODWithKVCacheTensor.default) class PODWithPagedKVCacheWrapper: @@ -65,7 +65,7 @@ class PODWithPagedKVCacheWrapper: >>> page_size = 16 >>> # allocate 128MB workspace buffer >>> workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda:0") - >>> decode_wrapper = flashinfer.PODWithPagedKVCacheWrapper( + >>> wrapper = flashinfer.PODWithPagedKVCacheWrapper( ... workspace_buffer, "NHD" ... ) >>> batch_size = 7 @@ -83,7 +83,7 @@ class PODWithPagedKVCacheWrapper: ... ) for _ in range(num_layers) ... ] >>> # create auxiliary data structures for batch decode attention - >>> decode_wrapper.plan( + >>> wrapper.plan( ... kv_page_indptr, ... kv_page_indices, ... 
kv_last_page_len, @@ -118,9 +118,12 @@ def __init__( float_workspace_buffer: torch.Tensor, kv_layout: str = "NHD", use_cuda_graph: bool = False, + qo_indptr_buffer: Optional[torch.Tensor] = None, paged_kv_indptr_buffer: Optional[torch.Tensor] = None, paged_kv_indices_buffer: Optional[torch.Tensor] = None, paged_kv_last_page_len_buffer: Optional[torch.Tensor] = None, + custom_mask_buf_p: Optional[torch.Tensor] = None, + mask_indptr_buf_p: Optional[torch.Tensor] = None, jit_args: Optional[List[Any]] = None, ) -> None: r"""Constructor of :class:`PODWithPagedKVCacheWrapper`. @@ -140,19 +143,24 @@ def __init__( auxiliary data structures will be stored as the provided buffers. The ``batch_size`` cannot change during the lifecycle of this wrapper when CUDAGraph is enabled. - indptr_buffer : Optional[torch.Tensor] - The user reserved buffer on GPU to store the indptr of the paged kv cache, the size + qo_indptr_buffer: Optional[torch.Tensor] + The user reserved buffer to store the ``qo_indptr`` array, the size of the buffer + should be ``[batch_size + 1]``. + This argument is only effective when ``use_cuda_graph`` is ``True``. + + paged_kv_indptr_buffer: Optional[torch.Tensor] + The user reserved buffer on GPU to store the indptr of the prefill paged kv cache, the size of the buffer should be ``[batch_size + 1]``. Only needed when ``use_cuda_graph`` is ``True``. - indices_buffer : Optional[torch.Tensor] - The user reserved buffer on GPU to store the page indices of the paged kv cache, + paged_kv_indices_buffer: Optional[torch.Tensor] + The user reserved buffer on GPU to store the page indices of the prefill paged kv cache, should be large enough to store the maximum number of page indices (``max_num_pages``) during the lifecycle of this wrapper. Only needed when ``use_cuda_graph`` is ``True``. - last_page_len_buffer : Optional[torch.Tensor] - The user reserved buffer on GPU to store the number of entries in the last page, the + paged_kv_last_page_len_buffer: Optional[torch.Tensor] + The user reserved buffer on GPU to store the number of entries in the last page for prefill, the size of the buffer should be ``[batch_size]``. Only needed when ``use_cuda_graph`` is ``True``. @@ -176,10 +184,14 @@ def __init__( # Override options. Only tensor core version is performant. 
use_tensor_cores = True self._jit_module = None + assert ( + custom_mask_buf_p is None and mask_indptr_buf_p is None + ), "custom_mask_buf_p and mask_indptr_buf_p are not supported yet" self._kv_layout = kv_layout self._float_workspace_buffer = float_workspace_buffer self.device = float_workspace_buffer.device + self._qo_indptr_buf = qo_indptr_buffer self._int_workspace_buffer = torch.empty( (8 * 1024 * 1024,), dtype=torch.uint8, device=self.device ) @@ -191,22 +203,36 @@ def __init__( ) if use_cuda_graph: - if not torch.is_tensor(paged_kv_indptr_buffer): + if not torch.is_tensor(qo_indptr_buffer): + raise ValueError( + "qo_indptr_buffer should be a torch.Tensor in CUDA graph mode" + ) + if not torch.is_tensor(paged_kv_indptr_buffer) or not torch.is_tensor( + paged_kv_indptr_buffer + ): raise ValueError( "paged_kv_indptr_buffer should be a torch.Tensor in cudagraph mode" ) - if not torch.is_tensor(paged_kv_indices_buffer): + if not torch.is_tensor(paged_kv_indices_buffer) or not torch.is_tensor( + paged_kv_indices_buffer + ): raise ValueError( "paged_kv_indices_buffer should be a torch.Tensor in cudagraph mode" ) - if not torch.is_tensor(paged_kv_last_page_len_buffer): + if not torch.is_tensor( + paged_kv_last_page_len_buffer + ) or not torch.is_tensor(paged_kv_last_page_len_buffer): raise ValueError( "paged_kv_last_page_len_buffer should be a torch.Tensor in cudagraph mode" ) self._fixed_batch_size = len(paged_kv_last_page_len_buffer) if len(paged_kv_indptr_buffer) != self._fixed_batch_size + 1: raise ValueError( - "The size of paged_kv_indptr_buffer should be batch_size + 1" + "The length of paged_kv_indptr_buffer_p should be batch_size + 1" + ) + if len(paged_kv_last_page_len_buffer) != self._fixed_batch_size: + raise ValueError( + "The length of paged_kv_last_page_len_buffer_p should be batch_size" ) else: self._fixed_batch_size = 0 @@ -216,7 +242,6 @@ def __init__( self._paged_kv_last_page_len_buf = paged_kv_last_page_len_buffer self._use_tensor_cores = use_tensor_cores self._use_cuda_graph = use_cuda_graph - if use_cuda_graph: # NOTE(Zihao): if once created, no need to update it in plan/run self._qo_indptr_buf = torch.arange( @@ -255,9 +280,13 @@ def reset_workspace_buffer( def plan( self, - indptr: torch.Tensor, - indices: torch.Tensor, - last_page_len: torch.Tensor, + qo_indptr_p: torch.Tensor, + kv_indptr_p: torch.Tensor, + kv_indices_p: torch.Tensor, + last_page_len_p: torch.Tensor, + kv_indptr_d: torch.Tensor, + kv_indices_d: torch.Tensor, + last_page_len_d: torch.Tensor, num_qo_heads: int, num_kv_heads: int, head_dim: int, @@ -276,12 +305,21 @@ def plan( Parameters ---------- - indptr : torch.Tensor - The indptr of the paged kv cache, shape: ``[batch_size + 1]`` - indices : torch.Tensor - The page indices of the paged kv cache, shape: ``[qo_indptr[-1]]`` - last_page_len : torch.Tensor - The number of entries in the last page of each request in the paged kv + qo_indptr_p: torch.Tensor + The indptr of the query/output tensor for prefill, shape: ``[batch_size + 1]``. + kv_indptr_p: torch.Tensor + The indptr of the paged kv cache for prefill, shape: ``[batch_size + 1]``. + kv_indices_p: torch.Tensor + The page indices of the paged kv cache for prefill, shape: ``[kv_indptr[-1]]``. 
+ last_page_len_p : torch.Tensor + The number of entries in the last page of each request in the kv + cache, shape: ``[batch_size]`` + kv_indptr_d : torch.Tensor + The indptr of the paged kv cache for decode, shape: ``[batch_size + 1]`` + kv_indices_d : torch.Tensor + The page indices of the paged kv cache for decode, shape: ``[kv_indptr[-1]]`` + last_page_len_d : torch.Tensor + The number of entries in the last page of each request in the kv cache, shape: ``[batch_size]`` num_qo_heads : int The number of query/output heads @@ -324,46 +362,73 @@ def plan( """ # Logits soft cap is not supported currently logits_soft_cap = False - batch_size = len(last_page_len) + batch_size_p = len(last_page_len_p) + batch_size_d = len(last_page_len_d) + batch_size = batch_size_p + batch_size_d if logits_soft_cap is None: logits_soft_cap = 0.0 - qo_indptr_host = _get_range_buf(batch_size + 1, "cpu") + qo_indptr_host_p = qo_indptr_p.to("cpu", non_blocking=True) + qo_indptr_host_d = _get_range_buf(batch_size_d + 1, "cpu") if self.is_cuda_graph_enabled: if batch_size != self._fixed_batch_size: raise ValueError( "The batch size should be fixed in cudagraph mode, the runtime batch size {} " " mismatches the batch size set during initialization {}".format( - batch_size, self._fixed_batch_size + batch_size_d, self._fixed_batch_size ) ) - if len(indices) > len(self._paged_kv_indices_buf): + if len(kv_indices_d) + len(kv_indices_p) > len(self._paged_kv_indices_buf): raise ValueError( "The size of indices should be less than or equal to the allocated buffer" ) - self._paged_kv_indptr_buf.copy_(indptr, non_blocking=non_blocking) - self._paged_kv_last_page_len_buf.copy_( - last_page_len, non_blocking=non_blocking + self._paged_kv_indptr_buf[: batch_size_p + 1].copy_( + kv_indptr_p, non_blocking=non_blocking + ) + self._paged_kv_indptr_buf[ + batch_size_p + 1 : batch_size_p + batch_size_d + 2 + ].copy_( + kv_indptr_d, + non_blocking=(kv_indptr_d.device == self.device) and non_blocking, + ) + self._paged_kv_last_page_len_buf[:batch_size_p].copy_( + last_page_len_p, non_blocking=non_blocking + ) + self._paged_kv_last_page_len_buf[ + batch_size_p : batch_size_p + batch_size_d + ].copy_( + last_page_len_d, + non_blocking=(last_page_len_d.device == self.device) and non_blocking, ) - self._paged_kv_indices_buf[: len(indices)].copy_( - indices, non_blocking=(indices.device == self.device) and non_blocking + self._paged_kv_indices_buf[: len(kv_indices_p)].copy_( + kv_indices_p, + non_blocking=(kv_indices_p.device == self.device) and non_blocking, + ) + self._paged_kv_indices_buf[ + len(kv_indices_p) : len(kv_indices_p) + len(kv_indices_d) + ].copy_( + kv_indices_d, + non_blocking=(kv_indices_d.device == self.device) and non_blocking, ) else: - self._paged_kv_indptr_buf = indptr.to( - self.device, non_blocking=non_blocking + to_device = lambda x: x.to(self.device, non_blocking=non_blocking) + self._qo_indptr_buf = torch.cat( + [to_device(qo_indptr_p), to_device(qo_indptr_host_d)] ) - self._paged_kv_indices_buf = indices.to( - self.device, non_blocking=non_blocking + self._paged_kv_indptr_buf = torch.cat( + [to_device(kv_indptr_p), to_device(kv_indptr_d)] ) - self._paged_kv_last_page_len_buf = last_page_len.to( - self.device, non_blocking=non_blocking + self._paged_kv_indices_buf = torch.cat( + [to_device(kv_indices_p), to_device(kv_indices_d)] ) - self._qo_indptr_buf = qo_indptr_host.to( - self.device, non_blocking=non_blocking + self._paged_kv_last_page_len_buf = torch.cat( + [to_device(last_page_len_p), 
to_device(last_page_len_d)] ) - indptr_host = indptr.to("cpu") - last_page_len_host = last_page_len.to("cpu") + kv_indptr_host_p = kv_indptr_p.to("cpu", non_blocking=True) + kv_indptr_host_d = kv_indptr_d.to("cpu", non_blocking=True) + last_page_len_host_p = last_page_len_p.to("cpu", non_blocking=True) + last_page_len_host_d = last_page_len_d.to("cpu", non_blocking=True) if data_type is not None: if q_data_type is None: @@ -378,7 +443,13 @@ def plan( self._cached_q_data_type = q_data_type self._cached_kv_data_type = kv_data_type - kv_lens_arr_host = get_seq_lens(indptr_host, last_page_len_host, page_size) + torch.cuda.synchronize() + kv_lens_arr_host_p = get_seq_lens( + kv_indptr_host_p, last_page_len_host_p, page_size + ) + kv_lens_arr_host_d = get_seq_lens( + kv_indptr_host_d, last_page_len_host_d, page_size + ) if self._jit_module is not None: self._cached_module = self._jit_module else: @@ -387,7 +458,7 @@ def plan( q_data_type, kv_data_type, q_data_type, - indptr.dtype, + kv_indptr_d.dtype, head_dim, # head_dim_qk head_dim, # head_dim_vo PosEncodingMode[pos_encoding_mode].value, @@ -399,21 +470,25 @@ def plan( self._float_workspace_buffer, self._int_workspace_buffer, self._pin_memory_int_workspace_buffer, - qo_indptr_host, - indptr_host, - kv_lens_arr_host, - batch_size, # total_num_rows - batch_size, + qo_indptr_host_p, + kv_indptr_host_p, + kv_lens_arr_host_p, + batch_size_p, + qo_indptr_host_p[-1], # total_num_rows_p + qo_indptr_host_d, + kv_indptr_host_d, + kv_lens_arr_host_d, + batch_size_d, + qo_indptr_host_d[-1], # total_num_rows_d num_qo_heads, num_kv_heads, - page_size, - self.is_cuda_graph_enabled, head_dim, head_dim, - False, # causal + page_size, + self.is_cuda_graph_enabled, ) - self._indptr_type = indptr.dtype + self._indptr_type = kv_indptr_d.dtype self._pos_encoding_mode = pos_encoding_mode self._window_left = window_left self._logits_soft_cap = logits_soft_cap @@ -427,26 +502,21 @@ def run( self, # Main params (prefill and decode) q_p: torch.Tensor, - k_p: torch.Tensor, - v_p: torch.Tensor, q_d: torch.Tensor, - paged_kv_cache_d: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], + paged_kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], # Prefill options custom_mask_p: Optional[torch.Tensor] = None, packed_custom_mask_p: Optional[torch.Tensor] = None, causal_p: bool = False, - kv_layout_p: str = "NHD", pos_encoding_mode_p: str = "NONE", sm_scale_p: Optional[float] = None, window_left_p: int = -1, rope_scale_p: Optional[float] = None, rope_theta_p: Optional[float] = None, - return_lse_p: bool = False, # Decode options custom_mask_d: Optional[torch.Tensor] = None, packed_custom_mask_d: Optional[torch.Tensor] = None, causal_d: bool = False, - kv_layout_d: str = "NHD", pos_encoding_mode_d: str = "NONE", sm_scale_d: Optional[float] = None, window_left_d: int = -1, @@ -455,7 +525,7 @@ def run( q_scale: Optional[float] = None, k_scale: Optional[float] = None, v_scale: Optional[float] = None, - return_lse_d: bool = False, + return_lse: bool = False, use_fp16_qk_reduction: bool = False, enable_pdl: Optional[bool] = None, *args, @@ -469,8 +539,6 @@ def run( logits_soft_cap_d = None # Prefill setup _check_pos_encoding_mode(pos_encoding_mode_p) - _check_kv_layout(kv_layout_p) - tmp_p = _get_cache_buf("pod_with_kv_cache_tmp", 32 * 1024 * 1024, q_p.device) if logits_soft_cap_p is None: logits_soft_cap_p = 0.0 if sm_scale_p is None: @@ -493,18 +561,27 @@ def run( else: mask_mode_p = MaskMode.NON_CAUSAL.value - lse_p = None - if return_lse_p: - lse_p = 
torch.empty( - (q_p.size(0), q_p.size(1)), dtype=torch.float32, device=q_p.device + lse = None + if return_lse: + lse = torch.empty( + (q_p.size(0) + q_d.size(0), q_p.size(1)), + dtype=torch.float32, + device=q_p.device, ) - - out_p = torch.empty_like(q_p) + qo_len_p, num_qo_heads, head_dim = q_p.shape + qo_len_d, _, _ = q_d.shape + out = torch.empty( + qo_len_p + qo_len_d, + num_qo_heads, + head_dim, + device=q_p.device, + dtype=q_p.dtype, + ) # Decode setup - k_cache_d, v_cache_d = _unpack_paged_kv_cache(paged_kv_cache_d, self._kv_layout) + k_cache, v_cache = _unpack_paged_kv_cache(paged_kv_cache, self._kv_layout) _check_cached_qkv_data_type( - q_d, k_cache_d, self._cached_q_data_type, self._cached_kv_data_type + q_d, k_cache, self._cached_q_data_type, self._cached_kv_data_type ) # TODO_AK: Where are these coming from? pos_encoding_mode_d = self._pos_encoding_mode @@ -529,17 +606,10 @@ def run( if rope_theta_d is None: rope_theta_d = 1e4 - lse_d = None - if return_lse_d: - lse_d = torch.empty( - (q_d.size(0), q_d.size(1)), dtype=torch.float32, device=q_d.device - ) - out_d = torch.empty_like(q_d) - module_getter = get_pod_module( # Prefill params q_p.dtype, - k_p.dtype, + k_cache.dtype, q_p.dtype, q_p.shape[-1], PosEncodingMode[pos_encoding_mode_p].value, @@ -558,15 +628,22 @@ def run( logits_soft_cap_d > 0, # use_logits_soft_cap ) module_getter.run_tensor( + # Shared params + self._float_workspace_buffer, + self._int_workspace_buffer, + self._plan_info, + k_cache, + v_cache, + self._qo_indptr_buf, # contains both prefill and decode indptr + self._paged_kv_indptr_buf, + self._paged_kv_indices_buf, + self._paged_kv_last_page_len_buf, + out, + lse, + TensorLayout[self._kv_layout].value, # Prefill params q_p, - k_p, - v_p, - tmp_p, - out_p, - lse_p, mask_mode_p, - TensorLayout[kv_layout_p].value, window_left_p, packed_custom_mask_p, _get_cache_alibi_slopes_buf(q_p.shape[1], q_p.device), @@ -575,20 +652,8 @@ def run( 1.0 / rope_scale_p, 1.0 / rope_theta_p, # Decode params - self._float_workspace_buffer, - self._int_workspace_buffer, - self._plan_info, q_d, - k_cache_d, - v_cache_d, - self._qo_indptr_buf, - self._paged_kv_indptr_buf, - self._paged_kv_indices_buf, - self._paged_kv_last_page_len_buf, - out_d, - lse_d, MaskMode.NON_CAUSAL.value, - TensorLayout[self._kv_layout].value, window_left_d, None, # packed_custom_mask None, # mask_indptr_buf @@ -601,9 +666,9 @@ def run( ) if v_scale is not None: - out_d *= v_scale + out *= v_scale - return (out_p, out_d) + return out[:qo_len_p], out[qo_len_p:] def end_forward(self) -> None: r"""Warning: this function is deprecated and has no effect.""" diff --git a/flashinfer/prefill.py b/flashinfer/prefill.py index f35ab2cc3..6c5068716 100644 --- a/flashinfer/prefill.py +++ b/flashinfer/prefill.py @@ -1305,7 +1305,7 @@ def plan( paged_kv_indptr : torch.Tensor The indptr of the paged kv-cache, shape: ``[batch_size + 1]``. paged_kv_indices : torch.Tensor - The page indices of the paged kv-cache, shape: ``[qo_indptr[-1]]``. + The page indices of the paged kv-cache, shape: ``[kv_indptr[-1]]``. paged_kv_last_page_len : torch.Tensor The number of entries in the last page of each request in the paged kv-cache, shape: ``[batch_size]``. 
@@ -1428,9 +1428,11 @@ def plan( self._max_item_len_ptr = max_item_len_ptr # NOTE(Zihao): only required if qo_indptr/paged_kv_indptr are device tensors - qo_indptr_host = qo_indptr.to("cpu") - paged_kv_indptr_host = paged_kv_indptr.to("cpu") - paged_kv_last_page_len_host = paged_kv_last_page_len.to("cpu") + qo_indptr_host = qo_indptr.to("cpu", non_blocking=True) + paged_kv_indptr_host = paged_kv_indptr.to("cpu", non_blocking=True) + paged_kv_last_page_len_host = paged_kv_last_page_len.to( + "cpu", non_blocking=True + ) kv_lens_arr_host = get_seq_lens( paged_kv_indptr_host, paged_kv_last_page_len_host, page_size ) @@ -1438,6 +1440,7 @@ def plan( kv_lens_arr_host, non_blocking=non_blocking ) + torch.cuda.synchronize() total_num_rows = qo_indptr_host[-1] if self.is_cuda_graph_enabled: @@ -1576,7 +1579,6 @@ def plan( self.is_cuda_graph_enabled, head_dim_qk, head_dim_vo, - causal, ) self._causal = causal @@ -2350,7 +2352,6 @@ def plan( self.is_cuda_graph_enabled, head_dim_qk, head_dim_vo, - causal, ) self._causal = causal diff --git a/include/flashinfer/attention/cascade.cuh b/include/flashinfer/attention/cascade.cuh index 8fb5e6b91..1ef2f3a34 100644 --- a/include/flashinfer/attention/cascade.cuh +++ b/include/flashinfer/attention/cascade.cuh @@ -354,11 +354,11 @@ __global__ void MergeStatesLargeNumIndexSetsKernel(DTypeIn* __restrict__ V, floa * \tparam DTypeO The data type of v_merged. * \param V The partial v of index sets. (nnz, h, d) * \param S The logsumexp value of index sets. (nnz, h) - * \param indptr The start offsets of each position in the variable length array. + * \param merge_indptr The start offsets of each position in the variable length array. * \param v_merged The merged v of index sets union. (n, h, d) * \param s_merged The merged logsumexp value of index sets union. (n, h) * \param max_seq_len The maximum sequence length supported by the kernel. - * \param seq_len_ptr The current sequence length (number of positions populated in indptr). + * \param seq_len_ptr The current sequence length (number of positions populated in merge_indptr). * \param num_heads The number of heads of v. * \param head_dim The dimension of each head. * \note s are logsumexp values with base 2. 
@@ -366,9 +366,9 @@ __global__ void MergeStatesLargeNumIndexSetsKernel(DTypeIn* __restrict__ V, floa template __global__ void PersistentVariableLengthMergeStatesKernel( - DTypeIn* __restrict__ V, float* __restrict__ S, IdType* indptr, DTypeO* __restrict__ v_merged, - float* __restrict__ s_merged, uint32_t max_seq_len, uint32_t* __restrict__ seq_len_ptr, - uint32_t num_heads) { + DTypeIn* __restrict__ V, float* __restrict__ S, IdType* merge_indptr, + DTypeO* __restrict__ v_merged, float* __restrict__ s_merged, uint32_t max_seq_len, + uint32_t* __restrict__ seq_len_ptr, uint32_t num_heads) { uint32_t tx = threadIdx.x, ty = threadIdx.y; uint32_t cta_id = blockIdx.x; uint32_t num_ctas = gridDim.x; @@ -389,7 +389,7 @@ __global__ void PersistentVariableLengthMergeStatesKernel( uint32_t pos = i / num_heads; uint32_t head_idx = i % num_heads; state_t st; - const uint32_t num_index_sets = indptr[pos + 1] - indptr[pos]; + const uint32_t num_index_sets = merge_indptr[pos + 1] - merge_indptr[pos]; if (num_index_sets == 0) { vec_t v; @@ -403,10 +403,10 @@ __global__ void PersistentVariableLengthMergeStatesKernel( if (num_index_sets == 1) { vec_t v; - v.cast_load(V + (indptr[pos] * num_heads + head_idx) * head_dim + tx * vec_size); + v.cast_load(V + (merge_indptr[pos] * num_heads + head_idx) * head_dim + tx * vec_size); v.store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * vec_size); if (s_merged != nullptr) { - s_merged[pos * num_heads + head_idx] = S[indptr[pos] * num_heads + head_idx]; + s_merged[pos * num_heads + head_idx] = S[merge_indptr[pos] * num_heads + head_idx]; } continue; } @@ -415,7 +415,8 @@ __global__ void PersistentVariableLengthMergeStatesKernel( for (uint32_t iter = 0; iter < num_smem_stages; ++iter) { cp_async::pred_load( v_smem + (iter * bdy + ty) * head_dim + tx * vec_size, - V + ((indptr[pos] + (iter * bdy + ty)) * num_heads + head_idx) * head_dim + tx * vec_size, + V + ((merge_indptr[pos] + (iter * bdy + ty)) * num_heads + head_idx) * head_dim + + tx * vec_size, (iter * bdy + ty) < num_index_sets); cp_async::commit_group(); } @@ -424,7 +425,7 @@ __global__ void PersistentVariableLengthMergeStatesKernel( if (iter % bdx == 0) { s_smem[ty * bdx + tx] = iter * bdy + (ty * bdx + tx) < num_index_sets - ? S[(indptr[pos] + (iter * bdy + ty * bdx + tx)) * num_heads + head_idx] + ? 
S[(merge_indptr[pos] + (iter * bdy + ty * bdx + tx)) * num_heads + head_idx] : 0.f; __syncthreads(); } @@ -440,7 +441,7 @@ __global__ void PersistentVariableLengthMergeStatesKernel( cp_async::pred_load( v_smem + ((iter % num_smem_stages) * bdy + ty) * head_dim + tx * vec_size, V + - ((indptr[pos] + ((iter + num_smem_stages) * bdy + ty)) * num_heads + head_idx) * + ((merge_indptr[pos] + ((iter + num_smem_stages) * bdy + ty)) * num_heads + head_idx) * head_dim + tx * vec_size, (iter + num_smem_stages) * bdy + ty < num_index_sets); @@ -465,11 +466,9 @@ __global__ void PersistentVariableLengthMergeStatesKernel( template -__global__ void PersistentVariableLengthAttentionSumKernel(DTypeIn* __restrict__ V, IdType* indptr, - DTypeO* __restrict__ v_sum, - uint32_t max_seq_len, - uint32_t* __restrict__ seq_len_ptr, - uint32_t num_heads) { +__global__ void PersistentVariableLengthAttentionSumKernel( + DTypeIn* __restrict__ V, IdType* merge_indptr, DTypeO* __restrict__ v_sum, uint32_t max_seq_len, + uint32_t* __restrict__ seq_len_ptr, uint32_t num_heads) { uint32_t tx = threadIdx.x, ty = threadIdx.y; uint32_t cta_id = blockIdx.x; uint32_t num_ctas = gridDim.x; @@ -489,7 +488,7 @@ __global__ void PersistentVariableLengthAttentionSumKernel(DTypeIn* __restrict__ for (uint32_t i = cta_id; i < seq_len * num_heads; i += num_ctas) { uint32_t pos = i / num_heads; uint32_t head_idx = i % num_heads; - const uint32_t num_index_sets = indptr[pos + 1] - indptr[pos]; + const uint32_t num_index_sets = merge_indptr[pos + 1] - merge_indptr[pos]; if (num_index_sets == 0) { vec_t v; @@ -500,7 +499,7 @@ __global__ void PersistentVariableLengthAttentionSumKernel(DTypeIn* __restrict__ if (num_index_sets == 1) { vec_t v; - v.cast_load(V + (indptr[pos] * num_heads + head_idx) * head_dim + tx * vec_size); + v.cast_load(V + (merge_indptr[pos] * num_heads + head_idx) * head_dim + tx * vec_size); v.store(v_sum + (pos * num_heads + head_idx) * head_dim + tx * vec_size); continue; } @@ -509,7 +508,8 @@ __global__ void PersistentVariableLengthAttentionSumKernel(DTypeIn* __restrict__ for (uint32_t iter = 0; iter < num_smem_stages; ++iter) { cp_async::pred_load( v_smem + (iter * bdy + ty) * head_dim + tx * vec_size, - V + ((indptr[pos] + (iter * bdy + ty)) * num_heads + head_idx) * head_dim + tx * vec_size, + V + ((merge_indptr[pos] + (iter * bdy + ty)) * num_heads + head_idx) * head_dim + + tx * vec_size, (iter * bdy + ty) < num_index_sets); cp_async::commit_group(); } @@ -529,7 +529,7 @@ __global__ void PersistentVariableLengthAttentionSumKernel(DTypeIn* __restrict__ cp_async::pred_load( v_smem + ((iter % num_smem_stages) * bdy + ty) * head_dim + tx * vec_size, V + - ((indptr[pos] + ((iter + num_smem_stages) * bdy + ty)) * num_heads + head_idx) * + ((merge_indptr[pos] + ((iter + num_smem_stages) * bdy + ty)) * num_heads + head_idx) * head_dim + tx * vec_size, (iter + num_smem_stages) * bdy + ty < num_index_sets); @@ -679,7 +679,7 @@ cudaError_t AttentionSum(DTypeIn* v, DTypeO* v_sum, uint32_t num_index_sets, uin } template -cudaError_t VariableLengthMergeStates(DTypeIn* v, float* s, IdType* indptr, DTypeO* v_merged, +cudaError_t VariableLengthMergeStates(DTypeIn* v, float* s, IdType* merge_indptr, DTypeO* v_merged, float* s_merged, uint32_t max_seq_len, uint32_t* seq_len, uint32_t num_heads, uint32_t head_dim, bool enable_pdl, cudaStream_t stream = nullptr) { @@ -705,7 +705,8 @@ cudaError_t VariableLengthMergeStates(DTypeIn* v, float* s, IdType* indptr, DTyp dim3 nblks(num_sms * num_blocks_per_sm); dim3 nthrs(bdx, bdy); - void* 
args[] = {&v, &s, &indptr, &v_merged, &s_merged, &max_seq_len, &seq_len, &num_heads}; + void* args[] = {&v, &s, &merge_indptr, &v_merged, + &s_merged, &max_seq_len, &seq_len, &num_heads}; FLASHINFER_CUDA_CALL( cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); @@ -721,8 +722,8 @@ cudaError_t VariableLengthMergeStates(DTypeIn* v, float* s, IdType* indptr, DTyp config.blockDim = nthrs; config.dynamicSmemBytes = smem_size; config.stream = stream; - FLASHINFER_CUDA_CALL(cudaLaunchKernelEx(&config, kernel, v, s, indptr, v_merged, s_merged, - max_seq_len, seq_len, num_heads)); + FLASHINFER_CUDA_CALL(cudaLaunchKernelEx(&config, kernel, v, s, merge_indptr, v_merged, + s_merged, max_seq_len, seq_len, num_heads)); } else { FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); } @@ -731,7 +732,7 @@ cudaError_t VariableLengthMergeStates(DTypeIn* v, float* s, IdType* indptr, DTyp } template -cudaError_t VariableLengthAttentionSum(DTypeIn* v, IdType* indptr, DTypeO* v_sum, +cudaError_t VariableLengthAttentionSum(DTypeIn* v, IdType* merge_indptr, DTypeO* v_sum, uint32_t max_seq_len, uint32_t* seq_len, uint32_t num_heads, uint32_t head_dim, bool enable_pdl, cudaStream_t stream = nullptr) { @@ -756,7 +757,7 @@ cudaError_t VariableLengthAttentionSum(DTypeIn* v, IdType* indptr, DTypeO* v_sum dim3 nblks(num_sms * num_blocks_per_sm); dim3 nthrs(bdx, bdy); - void* args[] = {&v, &indptr, &v_sum, &max_seq_len, &seq_len, &num_heads}; + void* args[] = {&v, &merge_indptr, &v_sum, &max_seq_len, &seq_len, &num_heads}; FLASHINFER_CUDA_CALL( cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); @@ -772,8 +773,8 @@ cudaError_t VariableLengthAttentionSum(DTypeIn* v, IdType* indptr, DTypeO* v_sum config.blockDim = nthrs; config.dynamicSmemBytes = smem_size; config.stream = stream; - FLASHINFER_CUDA_CALL( - cudaLaunchKernelEx(&config, kernel, v, indptr, v_sum, max_seq_len, seq_len, num_heads)); + FLASHINFER_CUDA_CALL(cudaLaunchKernelEx(&config, kernel, v, merge_indptr, v_sum, max_seq_len, + seq_len, num_heads)); } else { FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); } diff --git a/include/flashinfer/attention/persistent.cuh b/include/flashinfer/attention/persistent.cuh index 8c00aa579..3256bc14a 100644 --- a/include/flashinfer/attention/persistent.cuh +++ b/include/flashinfer/attention/persistent.cuh @@ -437,7 +437,7 @@ struct BlockBatchReductionPersistent { static __device__ __forceinline__ void Run( typename KTraits::DTypeIn* __restrict__ V, typename KTraits::DTypeO* __restrict__ v_merged, float* __restrict__ S, float* __restrict__ s_merged, - const typename KTraits::IdType num_packed_qo_len, const uint_fastdiv gqa_group_size, + const typename KTraits::IdType num_to_merge_qo_len, const uint_fastdiv gqa_group_size, const uint32_t num_kv_heads, const typename KTraits::IdType* indptr, const typename KTraits::IdType* o_indices, uint8_t* smem PROFILER_CLOSURE_FUNC_PARAMS) { using DTypeIn = typename KTraits::DTypeIn; @@ -465,10 +465,10 @@ struct BlockBatchReductionPersistent { float* s_smem = (float*)(smem + num_warps * num_smem_stages * bdy * head_dim * sizeof(DTypeIn) + warp_idx * 32 * sizeof(float)); - // V: [num_packed_qo_len x num_kv_tiles, num_kv_heads, head_dim] + // V: [num_to_merge_qo_len x num_kv_tiles, num_kv_heads, head_dim] // v_merged: [qo_len, num_kv_heads, gqa_group_size, head_dim] #pragma unroll 1 - for (uint32_t i = worker_id; i < num_packed_qo_len * 
num_kv_heads; i += num_workers) { + for (uint32_t i = worker_id; i < num_to_merge_qo_len * num_kv_heads; i += num_workers) { PROFILER_EVENT_START(profiler_closure, PersistentProfileEventType::kReduction); // remap workload diff --git a/include/flashinfer/attention/persistent_template.cuh b/include/flashinfer/attention/persistent_template.cuh index 3bd2331b3..3c367cc8f 100644 --- a/include/flashinfer/attention/persistent_template.cuh +++ b/include/flashinfer/attention/persistent_template.cuh @@ -81,7 +81,7 @@ __global__ __launch_bounds__( grid.sync(); BlockReductionRunner::Run(params_1.partial_o, params_1.final_o, params_1.partial_lse, - params_1.final_lse, *(params_1.num_packed_qo_len), + params_1.final_lse, *(params_1.num_to_merge_qo_len), params_1.gqa_group_size, params_1.num_kv_heads, params_1.merge_indptr, params_1.merge_o_indices, smem); #else @@ -90,7 +90,7 @@ __global__ __launch_bounds__( grid.sync(); BlockReductionRunner::Run(params_1.partial_o, params_1.final_o, params_1.partial_lse, - params_1.final_lse, *(params_1.num_packed_qo_len), + params_1.final_lse, *(params_1.num_to_merge_qo_len), params_1.gqa_group_size, params_1.num_kv_heads, params_1.merge_indptr, params_1.merge_o_indices, smem, profiler_closure); #endif diff --git a/include/flashinfer/attention/pod.cuh b/include/flashinfer/attention/pod.cuh index f705cd11e..d7cb66a27 100644 --- a/include/flashinfer/attention/pod.cuh +++ b/include/flashinfer/attention/pod.cuh @@ -34,29 +34,25 @@ enum Operation { DECODE = 1, }; -template +template __global__ __launch_bounds__(std::max( KTraits_P::NUM_THREADS, - KTraits_D::NUM_THREADS)) void PODWithKVCacheTensorKernel(const uint32_t xsize, - const __grid_constant__ PrefillParams + KTraits_D::NUM_THREADS)) void PODWithKVCacheTensorKernel(const __grid_constant__ PrefillParams prefill_params, const __grid_constant__ DecodeParams decode_params, int* tbAssign) { extern __shared__ uint8_t smem[]; + const uint32_t num_kv_heads = prefill_params.num_kv_heads; // PREFILL VARS - const uint32_t num_kv_heads_p = prefill_params.num_kv_heads; - const uint32_t num_chunks = prefill_params.partition_kv; - const uint32_t qo_len = prefill_params.qo_len; + const uint32_t padded_bsize_p = prefill_params.padded_batch_size; // DECODE VARS - const uint32_t padded_bsize = decode_params.padded_batch_size; - const uint32_t num_kv_heads_d = decode_params.paged_kv.num_heads; + const uint32_t padded_bsize_d = decode_params.padded_batch_size; // THREADBLOCKS - const uint32_t prefill_blocks = num_kv_heads_p * xsize * (PartitionKV_P ? num_chunks : 1); - const uint32_t decode_blocks = padded_bsize * num_kv_heads_d; + const uint32_t prefill_blocks = padded_bsize_p * num_kv_heads; + const uint32_t decode_blocks = padded_bsize_d * num_kv_heads; int op; int linear_bid; @@ -110,72 +106,59 @@ __global__ __launch_bounds__(std::max( op = !op; linear_bid = atomicAdd(&tbAssign[num_SMs + 0], 1); } - // Write the blockId and operation to shared memory + // Write the global blockId and operation to shared memory ((int*)smem)[0] = linear_bid; ((int*)smem)[1] = op; } - // Sync to wait for dynamic scheduler to finish + // Sync to wait for dynamic scheduler to write to smem __syncthreads(); // Fetch from shared memory the assigned blockId and operation. 
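+    // linear_bid is a virtual block index within the chosen op's grid; it is
+    // bounds-checked against prefill_blocks / decode_blocks below.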
linear_bid = ((int*)smem)[0]; op = ((int*)smem)[1]; // Sync to force all threads to wait - __syncthreads(); + // __syncthreads(); if (op == PREFILL) { const uint32_t linear_tid = threadIdx.x; // Return if threadId exceeds number of threads for this op if (linear_tid >= 32 * KTraits_P::NUM_WARPS_Q * KTraits_P::NUM_WARPS_KV) return; + if (linear_bid >= prefill_blocks) return; const dim3 tid = dim3(linear_tid % 32, (linear_tid / 32) % KTraits_P::NUM_WARPS_Q, (linear_tid / 32) / KTraits_P::NUM_WARPS_Q); - // dim3 nblks(ceil_div(qo_len * group_size, CTA_TILE_Q), 1, num_kv_heads); - // dim3 nblks(ceil_div(qo_len * group_size, CTA_TILE_Q), num_chunks, num_kv_heads); - // BlockID exceeds limit - if (linear_bid >= prefill_blocks) return; - - const uint32_t bx = linear_bid % xsize; auto& smem_storage = reinterpret_cast(smem); - // Not partition_kv - if constexpr (!PartitionKV_P) { - const uint32_t chunk_idx = 0; - const uint32_t kv_head_idx = linear_bid / xsize; - SinglePrefillWithKVCacheDevice(prefill_params, smem_storage, tid, bx, chunk_idx, - kv_head_idx, 1, num_kv_heads_p); - } else { - const uint32_t chunk_idx = (linear_bid / xsize) % num_chunks; - const uint32_t kv_head_idx = linear_bid / (xsize * num_chunks); - SinglePrefillWithKVCacheDevice(prefill_params, smem_storage, tid, bx, chunk_idx, - kv_head_idx, num_chunks, num_kv_heads_p); - } - } else /* OP == DECODE */ { - auto& smem_storage = reinterpret_cast(smem); - // dim3 nblks_d(padded_batch_size_d, 1, num_kv_heads); - if (linear_bid >= decode_blocks) return; + const uint32_t bx = linear_bid % padded_bsize_p; + const uint32_t kv_head_idx = linear_bid / padded_bsize_p; - const uint32_t bx = linear_bid % padded_bsize; - const uint32_t kv_head_idx = linear_bid / padded_bsize; + BatchPrefillWithPagedKVCacheDevice(prefill_params, smem_storage, tid, bx, + kv_head_idx, num_kv_heads); - // dim3 nthrs_d(32, NUM_WARPS_Q_D, NUM_WARPS_KV_D); + } else /* OP == DECODE */ { const uint32_t linear_tid = threadIdx.x; // Return if threadId exceeds number of threads for this op if (linear_tid >= 32 * KTraits_D::NUM_WARPS_Q * KTraits_D::NUM_WARPS_KV) return; + if (linear_bid >= decode_blocks) return; const dim3 tid = dim3(linear_tid % 32, (linear_tid / 32) % KTraits_D::NUM_WARPS_Q, (linear_tid / 32) / KTraits_D::NUM_WARPS_Q); + auto& smem_storage = reinterpret_cast(smem); + // dim3 nblks_d(padded_batch_size_d, 1, num_kv_heads); + const uint32_t bx = linear_bid % padded_bsize_d; + const uint32_t kv_head_idx = linear_bid / padded_bsize_d; + + // dim3 nthrs_d(32, NUM_WARPS_Q_D, NUM_WARPS_KV_D); + // Decode is faster with tensor cores, which are usually not saturated by prefill BatchPrefillWithPagedKVCacheDevice(decode_params, smem_storage, tid, bx, kv_head_idx, - num_kv_heads_d); + num_kv_heads); } } template -cudaError_t PODWithKVCacheTensorDispatched(PrefillParams prefill_params, - typename PrefillParams::DTypeO* tmp_p, - DecodeParams decode_params, + bool USE_FP16_QK_REDUCTION, MaskMode MASK_MODE_P, uint32_t CTA_TILE_Q_P, + uint32_t CTA_TILE_Q_D, MaskMode MASK_MODE_D, typename PrefillAttentionVariant, + typename DecodeAttentionVariant, typename PrefillParams, typename DecodeParams> +cudaError_t PODWithKVCacheTensorDispatched(PrefillParams prefill_params, DecodeParams decode_params, typename DecodeParams::DTypeO* tmp_v, float* tmp_s, bool enable_pdl, cudaStream_t stream) { static_assert(std::is_same::value); @@ -191,42 +174,10 @@ cudaError_t PODWithKVCacheTensorDispatched(PrefillParams prefill_params, using DTypeO_P = typename PrefillParams::DTypeO; const 
uint32_t num_qo_heads = prefill_params.num_qo_heads; const uint32_t num_kv_heads = prefill_params.num_kv_heads; - const uint32_t qo_len = prefill_params.qo_len; - const uint32_t kv_len = prefill_params.kv_len; - if (kv_len < qo_len && MASK_MODE_P == MaskMode::kCausal) { - std::ostringstream err_msg; - err_msg << "When mask_mode is set to MaskMode::kCausal, kv_len must be greater than or equal " - "to qo_len, got kv_len" - << kv_len << " and qo_len " << qo_len; - FLASHINFER_ERROR(err_msg.str()); - } - const uint32_t group_size = num_qo_heads / num_kv_heads; - const uint_fastdiv group_size_fastdiv(group_size); constexpr uint32_t NUM_MMA_D_QK = HEAD_DIM_QK / 16; constexpr uint32_t NUM_MMA_D_VO = HEAD_DIM_VO / 16; - uint32_t cta_tile_q_p = 0; - int64_t unpacked_qo_len = qo_len * group_size; - if (unpacked_qo_len > 64 && HEAD_DIM_VO < 256) { - cta_tile_q_p = 128; - } else { - auto compute_capacity = GetCudaComputeCapability(); - if (compute_capacity.first >= 8) { - // Ampere or newer - if (unpacked_qo_len > 16) { - // avg_packed_qo_len <= 64 - cta_tile_q_p = 64; - } else { - // avg_packed_qo_len <= 16 - cta_tile_q_p = 16; - } - } else { - // NOTE(Zihao): not enough shared memory on Turing for 1x4 warp layout - cta_tile_q_p = 64; - } - } - // Decode vars setup using DTypeQ_D = typename DecodeParams::DTypeQ; using DTypeKV_D = typename DecodeParams::DTypeKV; @@ -269,209 +220,172 @@ cudaError_t PODWithKVCacheTensorDispatched(PrefillParams prefill_params, NUM_MMA_Q_D * NUM_WARPS_Q_D) / (2 * NUM_WARPS_KV_D); - DISPATCH_CTA_TILE_Q(cta_tile_q_p, CTA_TILE_Q_P, { - constexpr uint32_t NUM_WARPS_Q_P = get_num_warps_q(CTA_TILE_Q_P); - constexpr uint32_t NUM_WARPS_KV_P = get_num_warps_kv(CTA_TILE_Q_P); - constexpr uint32_t NUM_MMA_Q_P = get_num_mma_q(CTA_TILE_Q_P); - - using DTypeQKAccum_P = - typename std::conditional, half, - float>::type; - - // we expect each sm execute two threadblocks - // TODO(Zihao): fix the following computation - const int num_ctas_per_sm_p = - max_smem_per_sm > (16 * HEAD_DIM_QK * sizeof(DTypeQ_P) * 16) ? 2 : 1; - const int max_smem_per_threadblock_p = max_smem_per_sm / num_ctas_per_sm_p; - - constexpr uint32_t max_num_mma_kv_reg_p = - (HEAD_DIM_VO >= 128 && NUM_MMA_Q_P == 2 && - POS_ENCODING_MODE == PosEncodingMode::kRoPELlama && !USE_FP16_QK_REDUCTION) - ? 2 - : (8 / NUM_MMA_Q_P); - // TODO(Zihao): fix the following computation - const uint32_t max_num_mma_kv_smem_p = - (max_smem_per_threadblock_p / (16 * HEAD_DIM_QK * sizeof(DTypeQ_P)) - - NUM_MMA_Q_P * NUM_WARPS_Q_P) / - (2 * NUM_WARPS_KV_P); - - // control NUM_MMA_KV for maximum warp occupancy - DISPATCH_NUM_MMA_KV(min(max_num_mma_kv_smem_p, max_num_mma_kv_reg_p), NUM_MMA_KV_P, { - using KTraits_P = - KernelTraits; - - if constexpr (KTraits_P::IsInvalid()) { - // Invalid configuration, skip - std::ostringstream err_msg; - err_msg << "FlashInfer Internal Error: Invalid configuration : NUM_MMA_Q=" << NUM_MMA_Q_P - << " NUM_MMA_D_QK=" << NUM_MMA_D_QK << " NUM_MMA_D_VO=" << NUM_MMA_D_VO - << " NUM_MMA_KV=" << NUM_MMA_KV_P << " NUM_WARPS_Q=" << NUM_WARPS_Q_P - << " NUM_WARPS_KV=" << NUM_WARPS_KV_P - << " please create an issue (https://github.com/flashinfer-ai/flashinfer/issues)" - " and report the issue to the developers."; - FLASHINFER_ERROR(err_msg.str()); - } else { - // Decode stuff - // TODO: Is there a way to avoid this nested dispatch? 
- DISPATCH_NUM_MMA_KV(min(max_num_mma_kv_smem_d, max_num_mma_kv_reg_d), NUM_MMA_KV_D, { - using KTraits_D = - KernelTraits; - if constexpr (KTraits_D::IsInvalid()) { - // Invalid configuration, skip - std::ostringstream err_msg; - err_msg - << "FlashInfer Internal Error: Invalid configuration : NUM_MMA_Q=" << NUM_MMA_Q_D - << " NUM_MMA_D_QK=" << NUM_MMA_D_QK << " NUM_MMA_D_VO=" << NUM_MMA_D_VO - << " NUM_MMA_KV=" << NUM_MMA_KV_D << " NUM_WARPS_Q=" << NUM_WARPS_Q_D - << " NUM_WARPS_KV=" << NUM_WARPS_KV_D - << " please create an issue (https://github.com/flashinfer-ai/flashinfer/issues)" - " and report the issue to the developers."; - FLASHINFER_ERROR(err_msg.str()); - } else { - // End decode stuff - constexpr uint32_t num_threads_p = (NUM_WARPS_Q_P * NUM_WARPS_KV_P) * WARP_SIZE; - size_t smem_size_p = sizeof(typename KTraits_P::SharedStorage); - size_t smem_size_d = sizeof(typename KTraits_D::SharedStorage); - - auto kernel = - PODWithKVCacheTensorKernel; - // Prefill: decide num_splits for split-kv - int num_blocks_per_sm = 0; - int num_sm = 0; - FLASHINFER_CUDA_CALL( - cudaDeviceGetAttribute(&num_sm, cudaDevAttrMultiProcessorCount, dev_id)); - // FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor( - // &num_blocks_per_sm, kernel, num_threads_p, smem_size_p)); - // Above function returns 0 for some reason, so we use a workaround - num_blocks_per_sm = std::max( - 1, std::min((int)(max_smem_per_sm / smem_size_p), (int)(256 / num_threads_p))); - uint32_t max_num_kv_chunks = - (num_blocks_per_sm * num_sm) / - (num_kv_heads * ceil_div(qo_len * group_size, KTraits_P::CTA_TILE_Q)); - uint32_t num_chunks; - if (max_num_kv_chunks > 0) { - uint32_t chunk_size = max(ceil_div(kv_len, max_num_kv_chunks), 256); - num_chunks = ceil_div(kv_len, chunk_size); - } else { - num_chunks = 0; - } + constexpr uint32_t NUM_WARPS_Q_P = get_num_warps_q(CTA_TILE_Q_P); + constexpr uint32_t NUM_WARPS_KV_P = get_num_warps_kv(CTA_TILE_Q_P); + constexpr uint32_t NUM_MMA_Q_P = get_num_mma_q(CTA_TILE_Q_P); - // Setup new prefill params if (not) split - auto o_p = prefill_params.o; - auto lse_p = prefill_params.lse; - float* tmp_lse = (float*)(tmp_p + num_chunks * qo_len * num_qo_heads * HEAD_DIM_VO); - if (num_chunks <= 1 || tmp_p == nullptr) { - // Enough parallelism, do not split-kv - prefill_params.partition_kv = 0; - kernel = PODWithKVCacheTensorKernel; - } else { - // Use cooperative groups to increase occupancy - prefill_params.partition_kv = num_chunks; - prefill_params.o = tmp_p; - prefill_params.lse = tmp_lse; - kernel = PODWithKVCacheTensorKernel; - } + using DTypeQKAccum_P = + typename std::conditional, half, + float>::type; - // Setup new decode params if (not) split - auto o_d = decode_params.o; - auto lse_d = decode_params.lse; - if (tmp_v == nullptr) { - // do not partition kv - decode_params.partition_kv = false; - } else { - decode_params.partition_kv = true; - decode_params.o = tmp_v; - decode_params.lse = tmp_s; - } - uint32_t xsize = ceil_div(qo_len * group_size, KTraits_P::CTA_TILE_Q); - int nblks_p(xsize * (prefill_params.partition_kv ? 
prefill_params.partition_kv : 1) * - num_kv_heads); - int nthrs_p(32 * NUM_WARPS_Q_P * NUM_WARPS_KV_P); - - int nblks_d(padded_batch_size_d * 1 * num_kv_heads); - int nthrs_d(32 * NUM_WARPS_Q_D * NUM_WARPS_KV_D); - - // ******* Select final combined sizes here ******* / - size_t smem_size = max(smem_size_p, smem_size_d); - int nblks = nblks_p + nblks_d; - int nthrs = max(nthrs_p, nthrs_d); - - // printf("Smem: prefill %zu, decode %zu, total %zu\n", smem_size_p, smem_size_d, - // smem_size); printf("Blocks: prefill %d, decode %d, total %d\n", nblks_p, nblks_d, - // nblks); printf("Threads: prefill %d, decode %d, total %d\n", nthrs_p, nthrs_d, - // nthrs); - // ************************************************ / - - static int* tbAssign = nullptr; - if (tbAssign == nullptr) cudaMalloc(&tbAssign, sizeof(int) * (num_sm + 2)); - cudaMemset(tbAssign, 0, sizeof(int) * (num_sm + 2)); - - // Setup kernel arguments - void* args[] = {(void*)&xsize, (void*)&prefill_params, (void*)&decode_params, - (void*)&tbAssign}; - FLASHINFER_CUDA_CALL(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - - // Launch kernel - if (enable_pdl) { - cudaLaunchAttribute attribute[1]; - cudaLaunchConfig_t config; - attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; - attribute[0].val.programmaticStreamSerializationAllowed = 1; - config.attrs = attribute; - config.numAttrs = 1; - config.gridDim = nblks; - config.blockDim = nthrs; - config.dynamicSmemBytes = smem_size; - config.stream = stream; - FLASHINFER_CUDA_CALL(cudaLaunchKernelEx(&config, kernel, xsize, prefill_params, - decode_params, tbAssign)); - } else { - FLASHINFER_CUDA_CALL( - cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); - } + // we expect each sm execute two threadblocks + // TODO(Zihao): fix the following computation + const int num_ctas_per_sm_p = + max_smem_per_sm > (16 * HEAD_DIM_QK * sizeof(DTypeQ_P) * 16) ? 2 : 1; + const int max_smem_per_threadblock_p = max_smem_per_sm / num_ctas_per_sm_p; + + constexpr uint32_t max_num_mma_kv_reg_p = + (HEAD_DIM_VO >= 128 && NUM_MMA_Q_P == 2 && POS_ENCODING_MODE == PosEncodingMode::kRoPELlama && + !USE_FP16_QK_REDUCTION) + ? 2 + : (8 / NUM_MMA_Q_P); + // TODO(Zihao): fix the following computation + const uint32_t max_num_mma_kv_smem_p = + (max_smem_per_threadblock_p / (16 * HEAD_DIM_QK * sizeof(DTypeQ_P)) - + NUM_MMA_Q_P * NUM_WARPS_Q_P) / + (2 * NUM_WARPS_KV_P); + + // control NUM_MMA_KV for maximum warp occupancy + DISPATCH_NUM_MMA_KV(min(max_num_mma_kv_smem_p, max_num_mma_kv_reg_p), NUM_MMA_KV_P, { + using KTraits_P = KernelTraits; + + if constexpr (KTraits_P::IsInvalid()) { + // Invalid configuration, skip + std::ostringstream err_msg; + err_msg << "FlashInfer Internal Error: Invalid configuration : NUM_MMA_Q=" << NUM_MMA_Q_P + << " NUM_MMA_D_QK=" << NUM_MMA_D_QK << " NUM_MMA_D_VO=" << NUM_MMA_D_VO + << " NUM_MMA_KV=" << NUM_MMA_KV_P << " NUM_WARPS_Q=" << NUM_WARPS_Q_P + << " NUM_WARPS_KV=" << NUM_WARPS_KV_P + << " please create an issue (https://github.com/flashinfer-ai/flashinfer/issues)" + " and report the issue to the developers."; + FLASHINFER_ERROR(err_msg.str()); + } else { + // Decode stuff + // TODO: Is there a way to avoid this nested dispatch? 
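+        // Inner dispatch: choose NUM_MMA_KV for the decode kernel traits from the
+        // shared-memory / register budget, mirroring the prefill dispatch above.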
+ DISPATCH_NUM_MMA_KV(min(max_num_mma_kv_smem_d, max_num_mma_kv_reg_d), NUM_MMA_KV_D, { + using KTraits_D = + KernelTraits; + if constexpr (KTraits_D::IsInvalid()) { + // Invalid configuration, skip + std::ostringstream err_msg; + err_msg << "FlashInfer Internal Error: Invalid configuration : NUM_MMA_Q=" << NUM_MMA_Q_D + << " NUM_MMA_D_QK=" << NUM_MMA_D_QK << " NUM_MMA_D_VO=" << NUM_MMA_D_VO + << " NUM_MMA_KV=" << NUM_MMA_KV_D << " NUM_WARPS_Q=" << NUM_WARPS_Q_D + << " NUM_WARPS_KV=" << NUM_WARPS_KV_D + << " please create an issue (https://github.com/flashinfer-ai/flashinfer/issues)" + " and report the issue to the developers."; + FLASHINFER_ERROR(err_msg.str()); + } else { + // End decode stuff + constexpr uint32_t num_threads_p = (NUM_WARPS_Q_P * NUM_WARPS_KV_P) * WARP_SIZE; + size_t smem_size_p = sizeof(typename KTraits_P::SharedStorage); + size_t smem_size_d = sizeof(typename KTraits_D::SharedStorage); + + auto kernel = + PODWithKVCacheTensorKernel; + // Prefill: decide num_splits for split-kv + int num_blocks_per_sm = 0; + int num_sm = 0; + FLASHINFER_CUDA_CALL( + cudaDeviceGetAttribute(&num_sm, cudaDevAttrMultiProcessorCount, dev_id)); + // FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + // &num_blocks_per_sm, kernel, num_threads_p, smem_size_p)); + // Above function returns 0 for some reason, so we use a workaround + num_blocks_per_sm = std::max( + 1, std::min((int)(max_smem_per_sm / smem_size_p), (int)(256 / num_threads_p))); + + // Setup new prefill params if (not) split + auto o = prefill_params.o; + auto lse = prefill_params.lse; + if (prefill_params.partition_kv) { + // Use cooperative groups to increase occupancy + assert(tmp_v != nullptr); + prefill_params.o = tmp_v; + prefill_params.lse = tmp_s; + } + + // Setup new decode params if (not) split + if (prefill_params.partition_kv) { + assert(tmp_v != nullptr); + decode_params.o = tmp_v; + decode_params.lse = tmp_s; + } + + uint32_t padded_batch_size_p = prefill_params.padded_batch_size; + uint32_t padded_batch_size_d = decode_params.padded_batch_size; + printf("Debug: launching prefill with padded_batch_size_p %d, num_kv_heads %d\n", + padded_batch_size_p, num_kv_heads); + int nblks_p(padded_batch_size_p * num_kv_heads); + int nthrs_p(32 * NUM_WARPS_Q_P * NUM_WARPS_KV_P); + printf("Debug: launching decode with padded_batch_size_d %d, num_kv_heads %d\n", + padded_batch_size_d, num_kv_heads); + int nblks_d(padded_batch_size_d * num_kv_heads); + int nthrs_d(32 * NUM_WARPS_Q_D * NUM_WARPS_KV_D); + + // ******* Select final combined sizes here ******* / + size_t smem_size = max(smem_size_p, smem_size_d); + int nblks = nblks_p + nblks_d; + int nthrs = max(nthrs_p, nthrs_d); + + // printf("Smem: prefill %zu, decode %zu, total %zu\n", smem_size_p, smem_size_d, + // smem_size); printf("Blocks: prefill %d, decode %d, total %d\n", nblks_p, nblks_d, + // nblks); printf("Threads: prefill %d, decode %d, total %d\n", nthrs_p, nthrs_d, + // nthrs); + // ************************************************ / + + static int* tbAssign = nullptr; + if (tbAssign == nullptr) cudaMalloc(&tbAssign, sizeof(int) * (num_sm + 2)); + cudaMemset(tbAssign, 0, sizeof(int) * (num_sm + 2)); + + // Setup kernel arguments + void* args[] = {(void*)&prefill_params, (void*)&decode_params, (void*)&tbAssign}; + FLASHINFER_CUDA_CALL( + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + + // Launch kernel + if (enable_pdl) { + cudaLaunchAttribute attribute[1]; + cudaLaunchConfig_t config; + attribute[0].id = 
cudaLaunchAttributeProgrammaticStreamSerialization; + attribute[0].val.programmaticStreamSerializationAllowed = 1; + config.attrs = attribute; + config.numAttrs = 1; + config.gridDim = nblks; + config.blockDim = nthrs; + config.dynamicSmemBytes = smem_size; + config.stream = stream; + FLASHINFER_CUDA_CALL( + cudaLaunchKernelEx(&config, kernel, prefill_params, decode_params, tbAssign)); + } else { FLASHINFER_CUDA_CALL( cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); + } - // Post-kernel stuff for split-kv prefill - if (!(num_chunks <= 1 || tmp_p == nullptr)) { - if constexpr (PrefillAttentionVariant::use_softmax) { - FLASHINFER_CUDA_CALL(MergeStates(tmp_p, tmp_lse, o_p, lse_p, num_chunks, qo_len, - num_qo_heads, HEAD_DIM_VO, stream)); - } else { - FLASHINFER_CUDA_CALL(AttentionSum(tmp_p, o_p, num_chunks, qo_len, num_qo_heads, - HEAD_DIM_VO, stream)); - } - } - // Post-kernel stuff for split-kv decode - if (tmp_v != nullptr) { - if constexpr (DecodeAttentionVariant::use_softmax) { - FLASHINFER_CUDA_CALL(VariableLengthMergeStates( - tmp_v, tmp_s, decode_params.merge_indptr, o_d, lse_d, - decode_params.max_total_num_rows, decode_params.total_num_rows, num_qo_heads, - HEAD_DIM_VO, stream)); - } else { - FLASHINFER_CUDA_CALL(VariableLengthAttentionSum( - tmp_v, decode_params.merge_indptr, o_d, decode_params.max_total_num_rows, - decode_params.total_num_rows, num_qo_heads, HEAD_DIM_VO, stream)); - } + // Post-kernel stuff for split-kv + if (tmp_v != nullptr) { + if constexpr (DecodeAttentionVariant::use_softmax) { + FLASHINFER_CUDA_CALL(VariableLengthMergeStates( + tmp_v, tmp_s, decode_params.merge_indptr, o, lse, + decode_params.max_total_num_rows, decode_params.total_num_rows, num_qo_heads, + HEAD_DIM_VO, stream)); + } else { + FLASHINFER_CUDA_CALL(VariableLengthAttentionSum( + tmp_v, decode_params.merge_indptr, o, decode_params.max_total_num_rows, + decode_params.total_num_rows, num_qo_heads, HEAD_DIM_VO, stream)); } } - }); - } - }); + } + }); + } }); return cudaSuccess; } - } // namespace flashinfer #endif // FLASHINFER_PREFILL_CUH_ diff --git a/include/flashinfer/attention/prefill.cuh b/include/flashinfer/attention/prefill.cuh index 135cee8e3..26b2b0bef 100644 --- a/include/flashinfer/attention/prefill.cuh +++ b/include/flashinfer/attention/prefill.cuh @@ -2560,7 +2560,7 @@ cudaError_t BatchPrefillWithPagedKVCacheDispatched(Params params, typename Param // this won't happen in CUDAGraph mode because we fixed the padded_batch_size return cudaSuccess; } - + // bs = num_qo_tiles * num_kv_tiles * gqa_group_size dim3 nblks(padded_batch_size, 1, num_kv_heads); dim3 nthrs(32, NUM_WARPS_Q, NUM_WARPS_KV); diff --git a/include/flashinfer/attention/scheduler.cuh b/include/flashinfer/attention/scheduler.cuh index aa57d34c6..809795095 100644 --- a/include/flashinfer/attention/scheduler.cuh +++ b/include/flashinfer/attention/scheduler.cuh @@ -58,7 +58,7 @@ inline void CopyToPageLockedBuffer(void* page_locked_int_buffer, int64_t offset, } /*! - * \brief Compute the maximum number of pages per batch and the new batch size + * \brief Compute the maximum number of pages per batch and the new batch size (grid dim x) * after we partition Paged KV-Cache into multiple chunks on KV sequence length * dimension. 
* \tparam IdType A template type indicates the index data type @@ -67,7 +67,7 @@ inline void CopyToPageLockedBuffer(void* page_locked_int_buffer, int64_t offset, * \param num_pages The number of pages per request in the batch * \param max_num_pages_per_batch_lb The pre-set lower bound of maximum number of * pages per batch, default to 1 - * \return (max_num_pages_per_batch, new_batch_size) The number of pages per batch and + * \return (max_num_pages_per_batch, real_batch_size) The number of pages per batch and * the new batch size after the partition. */ template @@ -78,24 +78,24 @@ inline auto PartitionPagedKVCacheBinarySearchMinNumPagePerBatch( for (const IdType& elem : num_pages) { high = max(high, elem); } - uint32_t new_batch_size; + uint32_t real_batch_size; while (low < high) { uint32_t mid = (low + high) / 2; - new_batch_size = 0; + real_batch_size = 0; for (const IdType& elem : num_pages) { - new_batch_size += ceil_div(elem, mid); + real_batch_size += ceil_div(elem, mid); } - if (new_batch_size * gdy > max_grid_size) { + if (real_batch_size * gdy > max_grid_size) { low = mid + 1; } else { high = mid; } } - new_batch_size = 0; + real_batch_size = 0; for (const IdType& elem : num_pages) { - new_batch_size += ceil_div(std::max(elem, 1), low); + real_batch_size += ceil_div(std::max(elem, 1), low); } - return std::make_tuple(low, new_batch_size); + return std::make_tuple(low, real_batch_size); } inline auto PrefillBinarySearchKVChunkSize(const bool enable_cuda_graph, @@ -115,12 +115,12 @@ inline auto PrefillBinarySearchKVChunkSize(const bool enable_cuda_graph, constexpr int64_t min_kv_len = 1; while (low < high) { const int64_t mid = (low + high) / 2; - int64_t new_batch_size = 0; + int64_t real_batch_size = 0; for (uint32_t i = 0; i < batch_size; ++i) { - new_batch_size += ceil_div(packed_qo_len_arr[i], qo_chunk_size) * - ceil_div(std::max(kv_len_arr[i], min_kv_len), mid); + real_batch_size += ceil_div(packed_qo_len_arr[i], qo_chunk_size) * + ceil_div(std::max(kv_len_arr[i], min_kv_len), mid); } - if (new_batch_size > max_batch_size_if_split) { + if (real_batch_size > max_batch_size_if_split) { low = mid + 1; } else { high = mid; @@ -138,7 +138,7 @@ inline auto PrefillBinarySearchKVChunkSize(const bool enable_cuda_graph, * \param split_kv Whether to split the KV cache into multiple chunks * \param max_grid_size The maximum grid size that can be used in a partiton-kv kernel * \param max_num_pages_per_batch The maximum number of pages per batch - * \param new_batch_size The new batch size after the partition + * \param real_batch_size The new batch size after the partition * \param paged_kv The paged kv cache data structure * \param num_qo_heads A integer indicates the number of heads of query and output * \param pos_encoding_mode The positional encoding mode @@ -149,7 +149,7 @@ template inline cudaError_t BatchDecodeWithPagedKVCacheWorkEstimationDispatched( bool& split_kv, uint32_t& max_grid_size, uint32_t& max_num_pages_per_batch, - uint32_t& new_batch_size, uint32_t& gdy, uint32_t batch_size, + uint32_t& real_batch_size, uint32_t& gdy, uint32_t batch_size, typename Params::IdType* kv_indptr_h, const uint32_t num_qo_heads, const uint32_t page_size, bool enable_cuda_graph, cudaStream_t stream) { using DTypeKV = typename Params::DTypeKV; @@ -165,7 +165,7 @@ inline cudaError_t BatchDecodeWithPagedKVCacheWorkEstimationDispatched( constexpr uint32_t tile_size_per_bdx = GROUP_SIZE == 1 ? (sizeof(DTypeKV) == 1 ? 
2U : 4U) : 1U; const uint32_t num_kv_heads = num_qo_heads / GROUP_SIZE; gdy = num_kv_heads; - const uint32_t smem_size = + const uint32_t smem_size = // kv + max + denominator 2 * NUM_STAGES_SMEM * tile_size_per_bdx * bdy * bdz * HEAD_DIM * sizeof(DTypeKV) + std::max(tile_size_per_bdx * num_threads * sizeof(DTypeKV*), 2 * bdy * bdz * sizeof(float)); @@ -187,17 +187,17 @@ inline cudaError_t BatchDecodeWithPagedKVCacheWorkEstimationDispatched( max_num_pages_per_batch = std::max( max_num_pages_per_batch, kv_indptr_h[batch_idx + 1] - kv_indptr_h[batch_idx]); } - new_batch_size = batch_size; + real_batch_size = batch_size; } else { - // compute max_num_pages_per_batch and new_batch_size + // compute max_num_pages_per_batch and real_batch_size std::vector num_pages(batch_size); for (uint32_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) { num_pages[batch_idx] = kv_indptr_h[batch_idx + 1] - kv_indptr_h[batch_idx]; } - std::tie(max_num_pages_per_batch, new_batch_size) = + std::tie(max_num_pages_per_batch, real_batch_size) = PartitionPagedKVCacheBinarySearchMinNumPagePerBatch(max_grid_size, gdy, num_pages, std::max(128 / page_size, 1U)); - if (new_batch_size == batch_size && !enable_cuda_graph) { + if (real_batch_size == batch_size && !enable_cuda_graph) { // do not use partition-kv kernel for short sequence, when not using CUDAGraph split_kv = false; } else { @@ -212,7 +212,7 @@ inline cudaError_t BatchDecodeWithPagedKVCacheWorkEstimationDispatched( template inline cudaError_t BatchDecodeWithPagedKVCacheWorkEstimationDispatchedMLA( bool& split_kv, uint32_t& max_grid_size, uint32_t& max_num_pages_per_batch, - uint32_t& new_batch_size, uint32_t& gdy, uint32_t batch_size, + uint32_t& real_batch_size, uint32_t& gdy, uint32_t batch_size, typename Params::IdType* kv_indptr_h, const uint32_t num_qo_heads, const uint32_t page_size, bool enable_cuda_graph, cudaStream_t stream) { using DTypeKV = typename Params::DTypeKV; @@ -253,17 +253,17 @@ inline cudaError_t BatchDecodeWithPagedKVCacheWorkEstimationDispatchedMLA( max_num_pages_per_batch = std::max( max_num_pages_per_batch, kv_indptr_h[batch_idx + 1] - kv_indptr_h[batch_idx]); } - new_batch_size = batch_size; + real_batch_size = batch_size; } else { - // compute max_num_pages_per_batch and new_batch_size + // compute max_num_pages_per_batch and real_batch_size std::vector num_pages(batch_size); for (uint32_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) { num_pages[batch_idx] = kv_indptr_h[batch_idx + 1] - kv_indptr_h[batch_idx]; } - std::tie(max_num_pages_per_batch, new_batch_size) = + std::tie(max_num_pages_per_batch, real_batch_size) = PartitionPagedKVCacheBinarySearchMinNumPagePerBatch(max_grid_size, gdy, num_pages, std::max(128 / page_size, 1U)); - if (new_batch_size == batch_size && !enable_cuda_graph) { + if (real_batch_size == batch_size && !enable_cuda_graph) { // do not use partition-kv kernel for short sequence, when not using CUDAGraph split_kv = false; } else { @@ -280,7 +280,7 @@ template inline cudaError_t BatchDecodeWithPagedKVCacheWorkEstimationDispatchedMlaCuteSM80( bool& split_kv, uint32_t& max_grid_size, uint32_t& max_num_pages_per_batch, - uint32_t& new_batch_size, uint32_t& gdy_, uint32_t batch_size, + uint32_t& real_batch_size, uint32_t& gdy_, uint32_t batch_size, typename Params::IdType* kv_indptr_h, const uint32_t num_qo_heads, const uint32_t page_size, bool enable_cuda_graph, cudaStream_t stream) { using DTypeKV = typename Params::DTypeKV; @@ -313,17 +313,17 @@ inline cudaError_t 
BatchDecodeWithPagedKVCacheWorkEstimationDispatchedMlaCuteSM8 max_num_pages_per_batch = std::max( max_num_pages_per_batch, kv_indptr_h[batch_idx + 1] - kv_indptr_h[batch_idx]); } - new_batch_size = batch_size; + real_batch_size = batch_size; } else { - // compute max_num_pages_per_batch and new_batch_size + // compute max_num_pages_per_batch and real_batch_size std::vector num_pages(batch_size); for (uint32_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) { num_pages[batch_idx] = kv_indptr_h[batch_idx + 1] - kv_indptr_h[batch_idx]; } - std::tie(max_num_pages_per_batch, new_batch_size) = + std::tie(max_num_pages_per_batch, real_batch_size) = PartitionPagedKVCacheBinarySearchMinNumPagePerBatch(max_grid_size, gdy, num_pages, std::max(128 / page_size, 1U)); - if (new_batch_size == batch_size && !enable_cuda_graph) { + if (real_batch_size == batch_size && !enable_cuda_graph) { // do not use partition-kv kernel for short sequence, when not using CUDAGraph split_kv = false; } else { @@ -432,16 +432,16 @@ inline cudaError_t DecodePlan(void* float_buffer, size_t float_workspace_size_in using DTypeO = typename Params::DTypeO; using IdType = typename Params::IdType; bool split_kv; - uint32_t max_grid_size, kv_chunk_size_in_pages, new_batch_size, gdy; + uint32_t max_grid_size, kv_chunk_size_in_pages, real_batch_size, gdy; FLASHINFER_CUDA_CALL(work_estimation_func(split_kv, max_grid_size, kv_chunk_size_in_pages, - new_batch_size, gdy, batch_size, indptr_h, num_qo_heads, - page_size, enable_cuda_graph, stream)); + real_batch_size, gdy, batch_size, indptr_h, + num_qo_heads, page_size, enable_cuda_graph, stream)); size_t padded_batch_size; plan_info.enable_cuda_graph = enable_cuda_graph; plan_info.split_kv = split_kv; padded_batch_size = - (enable_cuda_graph) ? (split_kv ? max_grid_size / gdy : batch_size) : new_batch_size; + (enable_cuda_graph) ? (split_kv ? max_grid_size / gdy : batch_size) : real_batch_size; plan_info.padded_batch_size = padded_batch_size; auto [request_indices_vec, kv_tile_indices_vec, o_indptr_vec] = @@ -481,7 +481,7 @@ inline cudaError_t DecodePlan(void* float_buffer, size_t float_workspace_size_in bool* block_valid_mask_h = GetPtrFromBaseOffset(page_locked_int_buffer, plan_info.block_valid_mask_offset); for (uint32_t i = 0; i < padded_batch_size; ++i) { - block_valid_mask_h[i] = i < new_batch_size; + block_valid_mask_h[i] = i < real_batch_size; } } @@ -492,19 +492,8 @@ inline cudaError_t DecodePlan(void* float_buffer, size_t float_workspace_size_in return cudaSuccess; } -template -inline auto PrefillSplitQOKVIndptr(IdType* qo_indptr_h, IdType* kv_indptr_h, - uint32_t total_num_rows, uint32_t batch_size, - uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t head_dim, - uint32_t page_size, uint32_t max_batch_size_if_split, - bool enable_cuda_graph) { - std::vector request_indices, qo_tile_indices, kv_tile_indices, merge_indptr, o_indptr; - merge_indptr.push_back(0); - o_indptr.push_back(0); - - const uint32_t gqa_group_size = num_qo_heads / num_kv_heads; - - // step 1: determine packed_qo_len_arr and verify qo_indptr contents. 
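+// Helper extracted from PrefillSplitQOKVIndptr: build per-request packed qo lengths
+// (qo_len * gqa_group_size) and kv lengths, verifying the indptr contents.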
+inline auto get_qkv_len_arr(IdType* qo_indptr_h, IdType* kv_indptr_h, uint32_t batch_size, + uint32_t num_qo_heads, uint32_t gqa_group_size) { std::vector packed_qo_len_arr(batch_size), kv_len_arr(batch_size); for (uint32_t i = 0; i < batch_size; ++i) { packed_qo_len_arr[i] = int64_t(qo_indptr_h[i + 1] - qo_indptr_h[i]) * int64_t(gqa_group_size); @@ -522,44 +511,81 @@ inline auto PrefillSplitQOKVIndptr(IdType* qo_indptr_h, IdType* kv_indptr_h, FLASHINFER_ERROR(err_msg.str()); } } + return std::make_tuple(packed_qo_len_arr, kv_len_arr); +} - // step 2: determine cta_tile_q, kv_chunk_size and total_num_tiles_q - const uint32_t min_kv_chunk_size = std::max((128 / page_size), 1U); +inline auto get_q_tiles(std::vector& packed_qo_len_arr, uint32_t batch_size, + uint32_t head_dim, uint32_t page_size, uint32_t total_num_rows, + uint32_t gqa_group_size, bool enable_cuda_graph, bool is_decode = false) { uint32_t cta_tile_q; uint32_t total_num_tiles_q; if (enable_cuda_graph) { // When CUDA graphs are enabled, the lengths of sequences determined by // qo_indptr_h can vary. We assume that the dummy data based on which // the CUDA graph is created fixes the maximum number of tokens. - const uint64_t max_seq_len = total_num_rows - batch_size + 1; - uint64_t max_qo_len = uint64_t(max_seq_len) * gqa_group_size; - cta_tile_q = FA2DetermineCtaTileQ(max_qo_len, head_dim); - + if (is_decode) { + cta_tile_q = 16; + } else { + const uint64_t max_seq_len = total_num_rows - batch_size + 1; + uint64_t max_qo_len = uint64_t(max_seq_len) * gqa_group_size; + cta_tile_q = FA2DetermineCtaTileQ(max_qo_len, head_dim); + } // Find an upper bound for the number of tiles, derived from the total // number of rows and the batch size. The sum of qo lengths rounded // up to cta_tile_q will not exceed this number derived from the total // number of rows. 
total_num_tiles_q = ceil_div(total_num_rows * gqa_group_size, cta_tile_q) + batch_size - 1; } else { - int64_t sum_packed_qo_len = 0; - for (uint32_t i = 0; i < batch_size; ++i) { - sum_packed_qo_len += packed_qo_len_arr[i]; + if (is_decode) { + cta_tile_q = 16; + } else { + int64_t sum_packed_qo_len = 0; + for (uint32_t i = 0; i < batch_size; ++i) { + sum_packed_qo_len += packed_qo_len_arr[i]; + } + const int64_t avg_packed_qo_len = sum_packed_qo_len / batch_size; + cta_tile_q = FA2DetermineCtaTileQ(avg_packed_qo_len, head_dim); } - const int64_t avg_packed_qo_len = sum_packed_qo_len / batch_size; - cta_tile_q = FA2DetermineCtaTileQ(avg_packed_qo_len, head_dim); total_num_tiles_q = 0; for (uint32_t i = 0; i < batch_size; ++i) { total_num_tiles_q += ceil_div(packed_qo_len_arr[i], cta_tile_q); } } + return std::make_tuple(cta_tile_q, total_num_tiles_q); +} - auto [split_kv, kv_chunk_size] = - PrefillBinarySearchKVChunkSize(enable_cuda_graph, max_batch_size_if_split, packed_qo_len_arr, - kv_len_arr, cta_tile_q, min_kv_chunk_size); +inline auto get_qkv_tile_indices(std::vector& packed_qo_len_arr, + std::vector& kv_len_arr, uint32_t batch_size, + uint32_t cta_tile_q, uint32_t kv_chunk_size, + uint32_t gqa_group_size, + std::vector& request_indices = nullptr, + std::vector& qo_tile_indices = nullptr, + std::vector& kv_tile_indices = nullptr, + std::vector& merge_indptr = nullptr, + std::vector& o_indptr = nullptr) { + uint32_t start_req_idx = 0; // for global q,k,v,o indexing in POD Attention + if (request_indices == nullptr) { + request_indices = std::vector(); + } else { + start_req_idx = request_indices.back(); + } + if (qo_tile_indices == nullptr) { + qo_tile_indices = std::vector(); + } + if (kv_tile_indices == nullptr) { + kv_tile_indices = std::vector(); + } + if (merge_indptr == nullptr) { + merge_indptr = std::vector(); + merge_indptr.push_back(0); + } + if (o_indptr == nullptr) { + o_indptr = std::vector(); + o_indptr.push_back(0); + } - // step 3: split qo_indptr and kv_indptr - uint32_t new_batch_size = 0; + uint32_t real_batch_size = 0; for (uint32_t request_idx = 0; request_idx < batch_size; ++request_idx) { const int64_t packed_qo_len = packed_qo_len_arr[request_idx]; const int64_t kv_len = std::max(int(kv_len_arr[request_idx]), 1); @@ -568,8 +594,8 @@ inline auto PrefillSplitQOKVIndptr(IdType* qo_indptr_h, IdType* kv_indptr_h, for (uint32_t q_tile_idx = 0; q_tile_idx < num_tiles_q; ++q_tile_idx) { for (uint32_t kv_tile_idx = 0; kv_tile_idx < num_tiles_kv; ++kv_tile_idx) { - new_batch_size += 1; - request_indices.push_back(request_idx); + real_batch_size += 1; + request_indices.push_back(request_idx + start_req_idx); qo_tile_indices.push_back(q_tile_idx); kv_tile_indices.push_back(kv_tile_idx); } @@ -581,16 +607,49 @@ inline auto PrefillSplitQOKVIndptr(IdType* qo_indptr_h, IdType* kv_indptr_h, } o_indptr.push_back(o_indptr.back() + qo_len * num_tiles_kv); } + return std::make_tuple(request_indices, qo_tile_indices, kv_tile_indices, merge_indptr, o_indptr, + real_batch_size); +} + +template +inline auto PrefillSplitQOKVIndptr(IdType* qo_indptr_h, IdType* kv_indptr_h, + uint32_t total_num_rows, uint32_t batch_size, + uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t head_dim, + uint32_t page_size, uint32_t max_batch_size_if_split, + bool enable_cuda_graph) { + std::vector request_indices, qo_tile_indices, kv_tile_indices, merge_indptr, o_indptr; + merge_indptr.push_back(0); + o_indptr.push_back(0); + + const uint32_t gqa_group_size = num_qo_heads / num_kv_heads; + const 
uint32_t min_kv_chunk_size = std::max((128 / page_size), 1U); + + // step 1: determine packed_qo_len_arr and verify qo_indptr contents. + auto [packed_qo_len_arr, kv_len_arr] = + get_qkv_len_arr(qo_indptr_h, kv_indptr_h, batch_size, num_qo_heads, gqa_group_size); + + // step 2: determine cta_tile_q, kv_chunk_size and total_num_tiles_q + auto [cta_tile_q, total_num_tiles_q] = + get_q_tiles(packed_qo_len_arr, batch_size, head_dim, page_size, total_num_rows, + gqa_group_size, enable_cuda_graph); + + auto [split_kv, kv_chunk_size] = + PrefillBinarySearchKVChunkSize(enable_cuda_graph, max_batch_size_if_split, packed_qo_len_arr, + kv_len_arr, cta_tile_q, min_kv_chunk_size); + + auto [request_indices, qo_tile_indices, kv_tile_indices, merge_indptr, o_indptr, + real_batch_size] = get_qkv_tile_indices(packed_qo_len_arr, kv_len_arr, batch_size, + cta_tile_q, kv_chunk_size, gqa_group_size); const size_t padded_batch_size = - enable_cuda_graph ? std::max(max_batch_size_if_split, total_num_tiles_q) : new_batch_size; - FLASHINFER_CHECK(new_batch_size <= padded_batch_size, + enable_cuda_graph ? std::max(max_batch_size_if_split, total_num_tiles_q) : real_batch_size; + FLASHINFER_CHECK(real_batch_size <= padded_batch_size, "new batch size should not exceed padded batch size"); // step 4: multiply kv_chunk_size by page_size kv_chunk_size *= page_size; - return std::make_tuple(split_kv, new_batch_size, padded_batch_size, cta_tile_q, kv_chunk_size, + return std::make_tuple(split_kv, real_batch_size, padded_batch_size, cta_tile_q, kv_chunk_size, std::move(request_indices), std::move(qo_tile_indices), std::move(kv_tile_indices), std::move(merge_indptr), std::move(o_indptr)); } @@ -699,11 +758,11 @@ inline cudaError_t PrefillPlan(void* float_buffer, size_t float_workspace_size_i uint32_t max_batch_size_if_split = max_grid_size / num_kv_heads; // step 2: determine kv_chunk_size - auto [split_kv, new_batch_size, padded_batch_size, cta_tile_q, kv_chunk_size, request_indices_vec, - qo_tile_indices_vec, kv_tile_indices_vec, merge_indptr_vec, o_indptr_vec] = - PrefillSplitQOKVIndptr(qo_indptr_h, kv_indptr_h, total_num_rows, batch_size, num_qo_heads, - num_kv_heads, head_dim_vo, page_size, max_batch_size_if_split, - enable_cuda_graph); + auto [split_kv, real_batch_size, padded_batch_size, cta_tile_q, kv_chunk_size, + request_indices_vec, qo_tile_indices_vec, kv_tile_indices_vec, merge_indptr_vec, + o_indptr_vec] = PrefillSplitQOKVIndptr(qo_indptr_h, kv_indptr_h, total_num_rows, batch_size, + num_qo_heads, num_kv_heads, head_dim_vo, page_size, + max_batch_size_if_split, enable_cuda_graph); plan_info.cta_tile_q = cta_tile_q; plan_info.total_num_rows = total_num_rows; @@ -765,7 +824,298 @@ inline cudaError_t PrefillPlan(void* float_buffer, size_t float_workspace_size_i GetPtrFromBaseOffset(page_locked_int_buffer, plan_info.block_valid_mask_offset); std::copy(merge_indptr_vec.begin(), merge_indptr_vec.end(), merge_indptr_h); for (uint32_t i = 0; i < padded_batch_size; ++i) { - block_valid_mask_h[i] = i < new_batch_size; + block_valid_mask_h[i] = i < real_batch_size; + } + } + + size_t num_bytes_to_copy = int_allocator.num_allocated_bytes(); + FLASHINFER_CUDA_CALL(cudaMemcpyAsync(int_buffer, page_locked_int_buffer, num_bytes_to_copy, + cudaMemcpyHostToDevice, stream)); + + return cudaSuccess; +} + +/* +Modifed to support two tile sizes, and assign blocks proportional to +the number of tiles. 
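+Prefill and decode are planned together; CTAs are split between the two phases in
+proportion to their query-tile counts (num_tiles_q_p vs. num_tiles_q_d).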
+*/ +template +inline auto PODSplitQOKVIndptr(IdType* qo_indptr_p, IdType* kv_indptr_p, uint32_t total_num_rows_p, + uint32_t batch_size_p, IdType* qo_indptr_d, IdType* kv_indptr_d, + uint32_t total_num_rows_d, uint32_t batch_size_d, + uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t head_dim, + uint32_t page_size, uint32_t max_batch_size_if_split, + bool enable_cuda_graph) { + std::vector request_indices, qo_tile_indices, kv_tile_indices, merge_indptr, o_indptr; + merge_indptr.push_back(0); + o_indptr.push_back(0); + + const uint32_t gqa_group_size = num_qo_heads / num_kv_heads; + const uint32_t min_kv_chunk_size = std::max((128 / page_size), 1U); + // step 1: determine packed_qo_len_arr and verify qo_indptr contents. + auto [packed_qo_len_arr_p, kv_len_arr_p] = + get_qkv_len_arr(qo_indptr_p, kv_indptr_p, batch_size_p, num_qo_heads, gqa_group_size); + auto [packed_qo_len_arr_d, kv_len_arr_d] = + get_qkv_len_arr(qo_indptr_d, kv_indptr_d, batch_size_d, num_qo_heads, gqa_group_size); + + // step 2: determine cta_tile_q, kv_chunk_size and total_num_tiles_q + auto [cta_tile_q_p, num_tiles_q_p] = + get_q_tiles(packed_qo_len_arr_p, batch_size_p, head_dim, page_size, total_num_rows_p, + gqa_group_size, enable_cuda_graph); + auto [cta_tile_q_d, num_tiles_q_d] = + get_q_tiles(packed_qo_len_arr_d, batch_size_d, head_dim, page_size, total_num_rows_d, + gqa_group_size, enable_cuda_graph, /*is_decode=*/true); + + uint32_t total_num_tiles_q = num_tiles_q_p + num_tiles_q_d; + // Allocate CTAs proportional to the number of query tiles in prefill and decode + // TODO(Wenxuan): explore a more balanced cost function considering kv len. + // See discussion: https://github.com/flashinfer-ai/flashinfer/issues/1175 + uint32_t max_bs_p = max_batch_size_if_split * num_tiles_q_p / total_num_tiles_q; + uint32_t max_bs_d = max_batch_size_if_split - max_bs_p; + auto [split_kv_p, kv_chunk_size_p] = + PrefillBinarySearchKVChunkSize(enable_cuda_graph, max_bs_p, packed_qo_len_arr_p, kv_len_arr_p, + cta_tile_q_p, min_kv_chunk_size); + auto [split_kv_d, kv_chunk_size_d] = + PrefillBinarySearchKVChunkSize(enable_cuda_graph, max_bs_d, packed_qo_len_arr_d, kv_len_arr_d, + cta_tile_q_d, min_kv_chunk_size); + + // step 3: split qo_indptr and kv_indptr + // Use one set of qkv indices, merge_indptr and o_indptr to simply merging. + auto [request_indices, qo_tile_indices, kv_tile_indices, merge_indptr, o_indptr, + new_batch_size_p] = get_qkv_tile_indices(packed_qo_len_arr_p, kv_len_arr_p, batch_size_p, + cta_tile_q_p, kv_chunk_size_p, gqa_group_size); + auto [request_indices, qo_tile_indices, kv_tile_indices, merge_indptr, o_indptr, + new_batch_size_d] = + get_qkv_tile_indices(packed_qo_len_arr_d, kv_len_arr_d, batch_size_d, cta_tile_q_d, + kv_chunk_size_d, gqa_group_size, request_indices, qo_tile_indices, + kv_tile_indices, merge_indptr, o_indptr); + + bool split_kv = split_kv_p || split_kv_d; + uint32_t real_batch_size = new_batch_size_p + new_batch_size_d; + const size_t padded_batch_size_p = + enable_cuda_graph ? std::max(max_bs_p, num_tiles_q_p) : new_batch_size_p; + const size_t padded_batch_size_d = + enable_cuda_graph ? 
std::max(max_bs_d, num_tiles_q_d) : new_batch_size_d; + FLASHINFER_CHECK(real_batch_size <= padded_batch_size_p + padded_batch_size_d, + "new batch size should not exceed padded batch size"); + + // step 4: multiply kv_chunk_size by page_size + kv_chunk_size_p *= page_size; + kv_chunk_size_d *= page_size; + + return std::make_tuple(split_kv, real_batch_size, padded_batch_size_p, padded_batch_size_d, + cta_tile_q_p, cta_tile_q_d, kv_chunk_size_p, kv_chunk_size_d, + std::move(request_indices), std::move(qo_tile_indices), + std::move(kv_tile_indices), std::move(merge_indptr), std::move(o_indptr)); +} + +struct PODPlanInfo { + int64_t padded_batch_size_p; + int64_t padded_batch_size_d; + int64_t total_num_rows; + int64_t total_num_rows_p; + int64_t total_num_rows_d; + int64_t total_num_rows_offset; + uint16_t cta_tile_q_p; + uint16_t cta_tile_q_d; + int64_t request_indices_offset; + int64_t qo_tile_indices_offset; + int64_t kv_tile_indices_offset; + int64_t merge_indptr_offset; + int64_t o_indptr_offset; + int64_t kv_chunk_size_ptr_offset_p; + int64_t kv_chunk_size_ptr_offset_d; + int64_t v_offset; + int64_t s_offset; + int64_t block_valid_mask_offset; + bool enable_cuda_graph; + bool split_kv; + + PODPlanInfo() + : padded_batch_size_p(0), + padded_batch_size_d(0), + total_num_rows(0), + total_num_rows_p(0), + total_num_rows_d(0), + total_num_rows_offset(0), + cta_tile_q_p(0), + cta_tile_q_d(0), + request_indices_offset(0), + qo_tile_indices_offset(0), + kv_tile_indices_offset(0), + merge_indptr_offset(0), + o_indptr_offset(0), + kv_chunk_size_ptr_offset_p(0), + kv_chunk_size_ptr_offset_d(0), + v_offset(0), + s_offset(0), + block_valid_mask_offset(0), + enable_cuda_graph(false), + split_kv(false) {} + + // convert PrefillPlanInfo to std::vector + std::vector ToVector() const { + return {padded_batch_size_p, + padded_batch_size_d, + total_num_rows, + total_num_rows_p, + total_num_rows_d, + total_num_rows_offset, + cta_tile_q_p, + cta_tile_q_d, + request_indices_offset, + qo_tile_indices_offset, + kv_tile_indices_offset, + merge_indptr_offset, + o_indptr_offset, + kv_chunk_size_ptr_offset_p, + kv_chunk_size_ptr_offset_d, + v_offset, + s_offset, + block_valid_mask_offset, + enable_cuda_graph, + split_kv}; + } + + // From std::vector to PodPlanInfo + void FromVector(const std::vector& vec) { + if (vec.size() != 20) { + std::ostringstream err_msg; + err_msg << "PodPlanInfo::FromVector: vec.size() should be 20, but got " << vec.size(); + FLASHINFER_ERROR(err_msg.str()); + } + padded_batch_size_p = vec[0]; + padded_batch_size_d = vec[1]; + total_num_rows = vec[2]; + total_num_rows_p = vec[3]; + total_num_rows_d = vec[4]; + total_num_rows_offset = vec[5]; + cta_tile_q_p = vec[6]; + cta_tile_q_d = vec[7]; + request_indices_offset = vec[8]; + qo_tile_indices_offset = vec[9]; + kv_tile_indices_offset = vec[10]; + merge_indptr_offset = vec[11]; + o_indptr_offset = vec[12]; + kv_chunk_size_ptr_offset_p = vec[13]; + kv_chunk_size_ptr_offset_d = vec[14]; + v_offset = vec[15]; + s_offset = vec[16]; + block_valid_mask_offset = vec[17]; + enable_cuda_graph = vec[18]; + split_kv = vec[19]; + } +}; + +template +inline cudaError_t PODPlan(void* float_buffer, size_t float_workspace_size_in_bytes, + void* int_buffer, void* page_locked_int_buffer, + size_t int_workspace_size_in_bytes, PODPlanInfo& plan_info, + IdType* qo_indptr_p, IdType* kv_indptr_p, uint32_t total_num_rows_p, + uint32_t batch_size_p, IdType* qo_indptr_d, IdType* kv_indptr_d, + uint32_t total_num_rows_d, uint32_t batch_size_d, uint32_t num_qo_heads, 
+
+template <typename IdType>
+inline cudaError_t PODPlan(void* float_buffer, size_t float_workspace_size_in_bytes,
+                           void* int_buffer, void* page_locked_int_buffer,
+                           size_t int_workspace_size_in_bytes, PODPlanInfo& plan_info,
+                           IdType* qo_indptr_p, IdType* kv_indptr_p, uint32_t total_num_rows_p,
+                           uint32_t batch_size_p, IdType* qo_indptr_d, IdType* kv_indptr_d,
+                           uint32_t total_num_rows_d, uint32_t batch_size_d, uint32_t num_qo_heads,
+                           uint32_t num_kv_heads, uint32_t head_dim_qk, uint32_t head_dim_vo,
+                           uint32_t page_size, bool enable_cuda_graph, uint32_t sizeof_dtype_o,
+                           cudaStream_t stream) {
+  if (num_qo_heads % num_kv_heads != 0) {
+    std::ostringstream err_msg;
+    err_msg << "num_qo_heads " << num_qo_heads << " should be divisible by num_kv_heads "
+            << num_kv_heads;
+    FLASHINFER_ERROR(err_msg.str());
+  }
+
+  // step 0: get the number of SMs
+  int num_sm = 0;
+  int dev_id = 0;
+  FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id));
+  FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&num_sm, cudaDevAttrMultiProcessorCount, dev_id));
+  int num_blocks_per_sm = 3;  // TODO(Wenxuan): increase this to reduce wave quantization?
+  int max_grid_size = num_blocks_per_sm * num_sm;
+  uint32_t max_batch_size_if_split = max_grid_size / num_kv_heads;
+
+  // step 1: determine kv_chunk_size and split the work across CTAs
+  auto [split_kv, real_batch_size, padded_batch_size_p, padded_batch_size_d, cta_tile_q_p,
+        cta_tile_q_d, kv_chunk_size_p, kv_chunk_size_d, request_indices_vec, qo_tile_indices_vec,
+        kv_tile_indices_vec, merge_indptr_vec, o_indptr_vec] =
+      PODSplitQOKVIndptr(qo_indptr_p, kv_indptr_p, qo_indptr_d, kv_indptr_d, total_num_rows_p,
+                         batch_size_p, total_num_rows_d, batch_size_d, num_qo_heads, num_kv_heads,
+                         head_dim_vo, page_size, max_batch_size_if_split, enable_cuda_graph);
+  uint32_t padded_batch_size = padded_batch_size_p + padded_batch_size_d;
+  uint32_t batch_size = batch_size_p + batch_size_d;
+  uint32_t total_num_rows = total_num_rows_p + total_num_rows_d;
+
+  plan_info.padded_batch_size_p = padded_batch_size_p;
+  plan_info.padded_batch_size_d = padded_batch_size_d;
+  plan_info.total_num_rows_p = total_num_rows_p;
+  plan_info.total_num_rows_d = total_num_rows_d;
+  plan_info.total_num_rows = total_num_rows;
+  plan_info.cta_tile_q_p = cta_tile_q_p;
+  plan_info.cta_tile_q_d = cta_tile_q_d;
+  plan_info.enable_cuda_graph = enable_cuda_graph;
+  plan_info.split_kv = split_kv;
+
+  AlignedAllocator int_allocator(int_buffer, int_workspace_size_in_bytes);
+  plan_info.request_indices_offset = int_allocator.aligned_alloc_offset(
+      sizeof(IdType) * padded_batch_size, 16, "pod_prefill_request_indices");
+  plan_info.qo_tile_indices_offset = int_allocator.aligned_alloc_offset(
+      sizeof(IdType) * padded_batch_size, 16, "pod_prefill_qo_tile_indices");
+  plan_info.kv_tile_indices_offset = int_allocator.aligned_alloc_offset(
+      sizeof(IdType) * padded_batch_size, 16, "pod_prefill_kv_tile_indices");
+  plan_info.o_indptr_offset =
+      int_allocator.aligned_alloc_offset(sizeof(IdType) * (batch_size + 1), 16, "pod_o_indptr");
+  plan_info.kv_chunk_size_ptr_offset_p =
+      int_allocator.aligned_alloc_offset(sizeof(IdType), 1, "pod_prefill_kv_chunk_size_ptr");
+  plan_info.kv_chunk_size_ptr_offset_d =
+      int_allocator.aligned_alloc_offset(sizeof(IdType), 1, "pod_decode_kv_chunk_size_ptr");
+
+  if (plan_info.enable_cuda_graph) {
+    plan_info.total_num_rows_offset =
+        int_allocator.aligned_alloc_offset(sizeof(uint32_t), 16, "batch_prefill_total_num_rows");
+    uint32_t* total_num_rows_h =
+        GetPtrFromBaseOffset<uint32_t>(page_locked_int_buffer, plan_info.total_num_rows_offset);
+    *total_num_rows_h = total_num_rows_p + total_num_rows_d;
+  }
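+
+  // Workspace-offset sketch (assumed helper semantics): each *_offset recorded in plan_info is a
+  // byte offset into the int workspace, and GetPtrFromBaseOffset<T>(base, offset) is taken to
+  // return the T* at base + offset, e.g.
+  //   IdType* o_indptr_host =
+  //       GetPtrFromBaseOffset<IdType>(page_locked_int_buffer, plan_info.o_indptr_offset);
+  // The index vectors produced by PODSplitQOKVIndptr are staged into the page-locked buffer below.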
+  IdType* request_indices_h =
+      GetPtrFromBaseOffset<IdType>(page_locked_int_buffer, plan_info.request_indices_offset);
+  IdType* qo_tile_indices_h =
+      GetPtrFromBaseOffset<IdType>(page_locked_int_buffer, plan_info.qo_tile_indices_offset);
+  IdType* kv_tile_indices_h =
+      GetPtrFromBaseOffset<IdType>(page_locked_int_buffer, plan_info.kv_tile_indices_offset);
+  IdType* o_indptr_h =
+      GetPtrFromBaseOffset<IdType>(page_locked_int_buffer, plan_info.o_indptr_offset);
+  IdType* kv_chunk_size_ptr_p =
+      GetPtrFromBaseOffset<IdType>(page_locked_int_buffer, plan_info.kv_chunk_size_ptr_offset_p);
+  IdType* kv_chunk_size_ptr_d =
+      GetPtrFromBaseOffset<IdType>(page_locked_int_buffer, plan_info.kv_chunk_size_ptr_offset_d);
+  std::copy(request_indices_vec.begin(), request_indices_vec.end(), request_indices_h);
+  std::copy(qo_tile_indices_vec.begin(), qo_tile_indices_vec.end(), qo_tile_indices_h);
+  std::copy(kv_tile_indices_vec.begin(), kv_tile_indices_vec.end(), kv_tile_indices_h);
+  std::copy(o_indptr_vec.begin(), o_indptr_vec.end(), o_indptr_h);
+  kv_chunk_size_ptr_p[0] = kv_chunk_size_p;
+  kv_chunk_size_ptr_d[0] = kv_chunk_size_d;
+
+  if (split_kv) {
+    // TODO(Wenxuan): write through for non-split-kv requests
+    uint32_t num_outputs_p = num_qo_heads * padded_batch_size_p * cta_tile_q_p;
+    uint32_t num_outputs_d = num_qo_heads * padded_batch_size_d * cta_tile_q_d;
+    AlignedAllocator float_allocator(float_buffer, float_workspace_size_in_bytes);
+    plan_info.v_offset = float_allocator.aligned_alloc_offset(
+        (num_outputs_p + num_outputs_d) * head_dim_vo * sizeof(float), 16, "pod_tmp_v");
+    plan_info.s_offset = float_allocator.aligned_alloc_offset(
+        (num_outputs_p + num_outputs_d) * sizeof(float), 16, "pod_tmp_s");
+    plan_info.merge_indptr_offset = int_allocator.aligned_alloc_offset(
+        sizeof(IdType) * (plan_info.total_num_rows + 1), 16, "pod_merge_indptr");
+    plan_info.block_valid_mask_offset = int_allocator.aligned_alloc_offset(
+        sizeof(bool) * padded_batch_size, 16, "pod_block_valid_mask");
+
+    IdType* merge_indptr_h =
+        GetPtrFromBaseOffset<IdType>(page_locked_int_buffer, plan_info.merge_indptr_offset);
+    bool* block_valid_mask_h =
+        GetPtrFromBaseOffset<bool>(page_locked_int_buffer, plan_info.block_valid_mask_offset);
+    std::copy(merge_indptr_vec.begin(), merge_indptr_vec.end(), merge_indptr_h);
+    for (uint32_t i = 0; i < padded_batch_size; ++i) {
+      block_valid_mask_h[i] = i < real_batch_size;
+    }
+  }
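+
+  // Worked sizing example (hypothetical shapes, not taken from this patch): with
+  // num_qo_heads = 32, padded_batch_size_p = 16, cta_tile_q_p = 128, padded_batch_size_d = 64,
+  // cta_tile_q_d = 16 and head_dim_vo = 128, the split-kv branch above reserves
+  //   num_outputs_p + num_outputs_d = 32 * (16 * 128 + 64 * 16) = 98304 partial output rows,
+  // i.e. 98304 * 128 * 4 B = 48 MiB for pod_tmp_v and 98304 * 4 B = 384 KiB for pod_tmp_s of
+  // float workspace, plus the merge-indptr and block-valid-mask bookkeeping in the int workspace.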
@@ -1169,7 +1519,7 @@ inline cudaError_t TwoStageHolisticPlan(void* float_buffer, size_t float_workspa
   // used for remapping the output offsets
   // layout [packed_qo_len x num_kv_tiels, num_kv_heads, head_dim]
   int partial_o_nnz = 0;
-  std::vector<IdType> merge_indptr, merge_o_indices, num_expand_qo_len_vec;
+  std::vector<IdType> merge_indptr, merge_o_indices, num_to_merge_qo_len_vec;
   merge_indptr.push_back(partial_o_nnz);
   for (uint32_t task = 0; task < NUM_TASKS; ++task) {
     int cluster_tile_q = CTA_TILE_Q_SIZES[task] * cluster_size;
@@ -1314,7 +1664,7 @@ inline cudaError_t TwoStageHolisticPlan(void* float_buffer, size_t float_workspa
   }
   // update num_qo_len_vec
-  num_expand_qo_len_vec.push_back(merge_indptr.size() - 1);
+  num_to_merge_qo_len_vec.push_back(merge_indptr.size() - 1);
   // allocate buffer for state merge function
   plan_info.merge_indptr_offset =
       int_allocator.aligned_alloc_offset(sizeof(IdType) * max_packed_qo_lens, 16, "merge_indptr");
@@ -1326,7 +1676,7 @@ inline cudaError_t TwoStageHolisticPlan(void* float_buffer, size_t float_workspa
   CopyToPageLockedBuffer(page_locked_int_buffer, plan_info.merge_indptr_offset, merge_indptr);
   CopyToPageLockedBuffer(page_locked_int_buffer, plan_info.merge_o_indices_offset, merge_o_indices);
   CopyToPageLockedBuffer(page_locked_int_buffer, plan_info.num_qo_len_offset,
-                         num_expand_qo_len_vec);
+                         num_to_merge_qo_len_vec);
   size_t num_bytes_to_copy = int_allocator.num_allocated_bytes();
   FLASHINFER_CUDA_CALL(cudaMemcpyAsync(int_buffer, page_locked_int_buffer, num_bytes_to_copy,
diff --git a/tvm_binding/batch_prefill.cu b/tvm_binding/batch_prefill.cu
index bf7161116..710764cf4 100644
--- a/tvm_binding/batch_prefill.cu
+++ b/tvm_binding/batch_prefill.cu
@@ -41,12 +41,14 @@ cudaError_t BatchPrefillWithRaggedKVCacheDispatched(Params params, typename Para

 using namespace flashinfer;

-IntTuple BatchPrefillWithKVCachePlan(
-    DLTensor* float_workspace_buffer, DLTensor* int_workspace_buffer,
-    DLTensor* page_locked_int_workspace_buffer, DLTensor* qo_indptr, DLTensor* kv_indptr,
-    IntTuple kv_len_arr, int64_t total_num_rows, int64_t batch_size, int64_t num_qo_heads,
-    int64_t num_kv_heads, int64_t page_size, bool enable_cuda_graph, int64_t head_dim_qk,
-    int64_t head_dim_vo, bool causal, TVMStreamHandle cuda_stream) {
+IntTuple BatchPrefillWithKVCachePlan(DLTensor* float_workspace_buffer,
+                                     DLTensor* int_workspace_buffer,
+                                     DLTensor* page_locked_int_workspace_buffer,
+                                     DLTensor* qo_indptr, DLTensor* kv_indptr, IntTuple kv_len_arr,
+                                     int64_t total_num_rows, int64_t batch_size,
+                                     int64_t num_qo_heads, int64_t num_kv_heads, int64_t page_size,
+                                     bool enable_cuda_graph, int64_t head_dim_qk,
+                                     int64_t head_dim_vo, TVMStreamHandle cuda_stream) {
   size_t float_workspace_size_in_bytes =
       float_workspace_buffer->shape[0] * DataType(float_workspace_buffer->dtype).bytes();
   size_t int_workspace_size_in_bytes =