From 06ef37d7848cc4c8115dc4eb574e065b3bb68e73 Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Tue, 20 May 2025 21:01:35 +0000 Subject: [PATCH 01/33] use FA2DetermineCtaTileQ for pod --- include/flashinfer/attention/pod.cuh | 40 +++++++++++++----------- include/flashinfer/attention/prefill.cuh | 2 +- 2 files changed, 22 insertions(+), 20 deletions(-) diff --git a/include/flashinfer/attention/pod.cuh b/include/flashinfer/attention/pod.cuh index f705cd11e..fc350e9a7 100644 --- a/include/flashinfer/attention/pod.cuh +++ b/include/flashinfer/attention/pod.cuh @@ -207,25 +207,27 @@ cudaError_t PODWithKVCacheTensorDispatched(PrefillParams prefill_params, constexpr uint32_t NUM_MMA_D_VO = HEAD_DIM_VO / 16; uint32_t cta_tile_q_p = 0; - int64_t unpacked_qo_len = qo_len * group_size; - if (unpacked_qo_len > 64 && HEAD_DIM_VO < 256) { - cta_tile_q_p = 128; - } else { - auto compute_capacity = GetCudaComputeCapability(); - if (compute_capacity.first >= 8) { - // Ampere or newer - if (unpacked_qo_len > 16) { - // avg_packed_qo_len <= 64 - cta_tile_q_p = 64; - } else { - // avg_packed_qo_len <= 16 - cta_tile_q_p = 16; - } - } else { - // NOTE(Zihao): not enough shared memory on Turing for 1x4 warp layout - cta_tile_q_p = 64; - } - } + int64_t unpacked_qo_len = + qo_len * group_size; // TODO(@Wenxuan): Include batch size in calculation + // if (unpacked_qo_len > 64 && HEAD_DIM_VO < 256) { + // cta_tile_q_p = 128; + // } else { + // auto compute_capacity = GetCudaComputeCapability(); + // if (compute_capacity.first >= 8) { + // // Ampere or newer + // if (unpacked_qo_len > 16) { + // // avg_packed_qo_len <= 64 + // cta_tile_q_p = 64; + // } else { + // // avg_packed_qo_len <= 16 + // cta_tile_q_p = 16; + // } + // } else { + // // NOTE(Zihao): not enough shared memory on Turing for 1x4 warp layout + // cta_tile_q_p = 64; + // } + // } + cta_tile_q_p = FA2DetermineCtaTileQ(unpacked_qo_len, HEAD_DIM_VO); // Decode vars setup using DTypeQ_D = typename DecodeParams::DTypeQ; diff --git a/include/flashinfer/attention/prefill.cuh b/include/flashinfer/attention/prefill.cuh index 135cee8e3..26b2b0bef 100644 --- a/include/flashinfer/attention/prefill.cuh +++ b/include/flashinfer/attention/prefill.cuh @@ -2560,7 +2560,7 @@ cudaError_t BatchPrefillWithPagedKVCacheDispatched(Params params, typename Param // this won't happen in CUDAGraph mode because we fixed the padded_batch_size return cudaSuccess; } - + // bs = num_qo_tiles * num_kv_tiles * gqa_group_size dim3 nblks(padded_batch_size, 1, num_kv_heads); dim3 nthrs(32, NUM_WARPS_Q, NUM_WARPS_KV); From 14e3f13fc9072fe929ccf6286f59f65f97c01ad0 Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Fri, 6 Jun 2025 02:42:08 +0000 Subject: [PATCH 02/33] modify wrapper.. 
--- benchmarks/bench_mixed_attention.py | 2 +- flashinfer/pod.py | 79 ++++++++++++++-------- include/flashinfer/attention/scheduler.cuh | 2 +- 3 files changed, 51 insertions(+), 32 deletions(-) diff --git a/benchmarks/bench_mixed_attention.py b/benchmarks/bench_mixed_attention.py index f581628b9..515e353d0 100644 --- a/benchmarks/bench_mixed_attention.py +++ b/benchmarks/bench_mixed_attention.py @@ -89,7 +89,7 @@ def run_bench( wrapper_pod.plan( d_kv_indptr.to(device), kv_indices_d.to(device), - last_page_len=last_page_len_d, + last_page_len_d=last_page_len_d, num_qo_heads=num_qo_heads, num_kv_heads=num_kv_heads, head_dim=head_dim, diff --git a/flashinfer/pod.py b/flashinfer/pod.py index 49b2847a0..cdf8b0672 100644 --- a/flashinfer/pod.py +++ b/flashinfer/pod.py @@ -118,9 +118,13 @@ def __init__( float_workspace_buffer: torch.Tensor, kv_layout: str = "NHD", use_cuda_graph: bool = False, - paged_kv_indptr_buffer: Optional[torch.Tensor] = None, - paged_kv_indices_buffer: Optional[torch.Tensor] = None, - paged_kv_last_page_len_buffer: Optional[torch.Tensor] = None, + qo_indptr_buffer_p: Optional[torch.Tensor] = None, + paged_kv_indptr_buffer_p: Optional[torch.Tensor] = None, + paged_kv_indices_buffer_p: Optional[torch.Tensor] = None, + paged_kv_last_page_len_buffer_p: Optional[torch.Tensor] = None, + paged_kv_indptr_buffer_d: Optional[torch.Tensor] = None, + paged_kv_indices_buffer_d: Optional[torch.Tensor] = None, + paged_kv_last_page_len_buffer_d: Optional[torch.Tensor] = None, jit_args: Optional[List[Any]] = None, ) -> None: r"""Constructor of :class:`PODWithPagedKVCacheWrapper`. @@ -140,22 +144,36 @@ def __init__( auxiliary data structures will be stored as the provided buffers. The ``batch_size`` cannot change during the lifecycle of this wrapper when CUDAGraph is enabled. - indptr_buffer : Optional[torch.Tensor] - The user reserved buffer on GPU to store the indptr of the paged kv cache, the size + qo_indptr_buffer_p: Optional[torch.Tensor] + The user reserved buffer to store the ``qo_indptr`` array, the size of the buffer + should be ``[batch_size + 1]``. + This argument is only effective when ``use_cuda_graph`` is ``True``. + + paged_kv_indptr_buffer_p: Optional[torch.Tensor] + The user reserved buffer on GPU to store the indptr of the prefill paged kv cache, the size of the buffer should be ``[batch_size + 1]``. Only needed when ``use_cuda_graph`` is ``True``. - indices_buffer : Optional[torch.Tensor] - The user reserved buffer on GPU to store the page indices of the paged kv cache, + paged_kv_indices_buffer_p: Optional[torch.Tensor] + The user reserved buffer on GPU to store the page indices of the prefill paged kv cache, should be large enough to store the maximum number of page indices (``max_num_pages``) during the lifecycle of this wrapper. Only needed when ``use_cuda_graph`` is ``True``. - last_page_len_buffer : Optional[torch.Tensor] - The user reserved buffer on GPU to store the number of entries in the last page, the + paged_kv_last_page_len_buffer_p: Optional[torch.Tensor] + The user reserved buffer on GPU to store the number of entries in the last page for prefill, the size of the buffer should be ``[batch_size]``. Only needed when ``use_cuda_graph`` is ``True``. + paged_kv_indptr_buffer_d: Optional[torch.Tensor] + Same as ``paged_kv_indptr_buffer_p`` but for decode. + + paged_kv_indices_buffer_d: Optional[torch.Tensor] + Same as ``paged_kv_indices_buffer_p`` but for decode. 
+ + paged_kv_last_page_len_buffer_d: Optional[torch.Tensor] + Same as ``paged_kv_last_page_len_buffer_p`` but for decode. + jit_args : Optional[List[Any]] If provided, the wrapper will use the provided arguments to create the JIT module, otherwise, the wrapper will use default attention implementation. @@ -255,9 +273,9 @@ def reset_workspace_buffer( def plan( self, - indptr: torch.Tensor, - indices: torch.Tensor, - last_page_len: torch.Tensor, + indptr_d: torch.Tensor, + indices_d: torch.Tensor, + last_page_len_d: torch.Tensor, num_qo_heads: int, num_kv_heads: int, head_dim: int, @@ -276,11 +294,11 @@ def plan( Parameters ---------- - indptr : torch.Tensor - The indptr of the paged kv cache, shape: ``[batch_size + 1]`` - indices : torch.Tensor - The page indices of the paged kv cache, shape: ``[qo_indptr[-1]]`` - last_page_len : torch.Tensor + indptr_d : torch.Tensor + The indptr of the paged kv cache for decode, shape: ``[batch_size + 1]`` + indices_d : torch.Tensor + The page indices of the paged kv cache for decode, shape: ``[qo_indptr[-1]]`` + last_page_len_d : torch.Tensor The number of entries in the last page of each request in the paged kv cache, shape: ``[batch_size]`` num_qo_heads : int @@ -324,7 +342,7 @@ def plan( """ # Logits soft cap is not supported currently logits_soft_cap = False - batch_size = len(last_page_len) + batch_size = len(last_page_len_d) if logits_soft_cap is None: logits_soft_cap = 0.0 @@ -337,33 +355,34 @@ def plan( batch_size, self._fixed_batch_size ) ) - if len(indices) > len(self._paged_kv_indices_buf): + if len(indices_d) > len(self._paged_kv_indices_buf): raise ValueError( "The size of indices should be less than or equal to the allocated buffer" ) - self._paged_kv_indptr_buf.copy_(indptr, non_blocking=non_blocking) + self._paged_kv_indptr_buf.copy_(indptr_d, non_blocking=non_blocking) self._paged_kv_last_page_len_buf.copy_( - last_page_len, non_blocking=non_blocking + last_page_len_d, non_blocking=non_blocking ) - self._paged_kv_indices_buf[: len(indices)].copy_( - indices, non_blocking=(indices.device == self.device) and non_blocking + self._paged_kv_indices_buf[: len(indices_d)].copy_( + indices_d, + non_blocking=(indices_d.device == self.device) and non_blocking, ) else: - self._paged_kv_indptr_buf = indptr.to( + self._paged_kv_indptr_buf = indptr_d.to( self.device, non_blocking=non_blocking ) - self._paged_kv_indices_buf = indices.to( + self._paged_kv_indices_buf = indices_d.to( self.device, non_blocking=non_blocking ) - self._paged_kv_last_page_len_buf = last_page_len.to( + self._paged_kv_last_page_len_buf = last_page_len_d.to( self.device, non_blocking=non_blocking ) self._qo_indptr_buf = qo_indptr_host.to( self.device, non_blocking=non_blocking ) - indptr_host = indptr.to("cpu") - last_page_len_host = last_page_len.to("cpu") + indptr_host = indptr_d.to("cpu") + last_page_len_host = last_page_len_d.to("cpu") if data_type is not None: if q_data_type is None: @@ -387,7 +406,7 @@ def plan( q_data_type, kv_data_type, q_data_type, - indptr.dtype, + indptr_d.dtype, head_dim, # head_dim_qk head_dim, # head_dim_vo PosEncodingMode[pos_encoding_mode].value, @@ -413,7 +432,7 @@ def plan( False, # causal ) - self._indptr_type = indptr.dtype + self._indptr_type = indptr_d.dtype self._pos_encoding_mode = pos_encoding_mode self._window_left = window_left self._logits_soft_cap = logits_soft_cap diff --git a/include/flashinfer/attention/scheduler.cuh b/include/flashinfer/attention/scheduler.cuh index aa57d34c6..1b6250231 100644 --- 
a/include/flashinfer/attention/scheduler.cuh +++ b/include/flashinfer/attention/scheduler.cuh @@ -165,7 +165,7 @@ inline cudaError_t BatchDecodeWithPagedKVCacheWorkEstimationDispatched( constexpr uint32_t tile_size_per_bdx = GROUP_SIZE == 1 ? (sizeof(DTypeKV) == 1 ? 2U : 4U) : 1U; const uint32_t num_kv_heads = num_qo_heads / GROUP_SIZE; gdy = num_kv_heads; - const uint32_t smem_size = + const uint32_t smem_size = // kv + max + denominator 2 * NUM_STAGES_SMEM * tile_size_per_bdx * bdy * bdz * HEAD_DIM * sizeof(DTypeKV) + std::max(tile_size_per_bdx * num_threads * sizeof(DTypeKV*), 2 * bdy * bdz * sizeof(float)); From 22dc539fd159f6b88ceaee8962ec820a7eaf508b Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Sun, 8 Jun 2025 02:30:45 +0000 Subject: [PATCH 03/33] fix --- csrc/pod_config.inc | 2 +- flashinfer/pod.py | 54 ++++++++++++++++++++++++++++++++++++--------- 2 files changed, 44 insertions(+), 12 deletions(-) diff --git a/csrc/pod_config.inc b/csrc/pod_config.inc index 16306651a..3c95f4286 100644 --- a/csrc/pod_config.inc +++ b/csrc/pod_config.inc @@ -30,7 +30,7 @@ using namespace flashinfer; return DISPATCH_BOOL(window_left_d > -1, USE_SLIDING_WINDOW_D, [&] { \ return DISPATCH_BOOL(false, USE_LOGITS_SOFT_CAP, [&] { \ using IdType = int32_t; \ - using PrefillParams = SinglePrefillParams;\ + using PrefillParams = BatchPrefillPagedParams;\ using DecodeParams = BatchPrefillPagedParams; \ __VA_ARGS__(); \ diff --git a/flashinfer/pod.py b/flashinfer/pod.py index cdf8b0672..82a64b786 100644 --- a/flashinfer/pod.py +++ b/flashinfer/pod.py @@ -125,6 +125,8 @@ def __init__( paged_kv_indptr_buffer_d: Optional[torch.Tensor] = None, paged_kv_indices_buffer_d: Optional[torch.Tensor] = None, paged_kv_last_page_len_buffer_d: Optional[torch.Tensor] = None, + custom_mask_buf_p: Optional[torch.Tensor] = None, + mask_indptr_buf_p: Optional[torch.Tensor] = None, jit_args: Optional[List[Any]] = None, ) -> None: r"""Constructor of :class:`PODWithPagedKVCacheWrapper`. @@ -194,14 +196,27 @@ def __init__( # Override options. Only tensor core version is performant. 
use_tensor_cores = True self._jit_module = None + assert ( + custom_mask_buf_p is None and mask_indptr_buf_p is None + ), "custom_mask_buf_p and mask_indptr_buf_p are not supported yet" self._kv_layout = kv_layout self._float_workspace_buffer = float_workspace_buffer self.device = float_workspace_buffer.device - self._int_workspace_buffer = torch.empty( + self._qo_indptr_buf_p = qo_indptr_buffer_p + self._int_workspace_buffer_p = torch.empty( (8 * 1024 * 1024,), dtype=torch.uint8, device=self.device ) - self._pin_memory_int_workspace_buffer = torch.empty( + self._pin_memory_int_workspace_buffer_p = torch.empty( + (8 * 1024 * 1024,), + dtype=torch.uint8, + pin_memory=True, + device="cpu", + ) + self._int_workspace_buffer_d = torch.empty( + (8 * 1024 * 1024,), dtype=torch.uint8, device=self.device + ) + self._pin_memory_int_workspace_buffer_d = torch.empty( (8 * 1024 * 1024,), dtype=torch.uint8, pin_memory=True, @@ -209,29 +224,46 @@ def __init__( ) if use_cuda_graph: - if not torch.is_tensor(paged_kv_indptr_buffer): + if not torch.is_tensor(qo_indptr_buffer_p): + raise ValueError( + "qo_indptr_buffer_p should be a torch.Tensor in CUDA graph mode" + ) + if not torch.is_tensor(paged_kv_indptr_buffer_p) or not torch.is_tensor( + paged_kv_indptr_buffer_d + ): raise ValueError( "paged_kv_indptr_buffer should be a torch.Tensor in cudagraph mode" ) - if not torch.is_tensor(paged_kv_indices_buffer): + if not torch.is_tensor(paged_kv_indices_buffer_p) or not torch.is_tensor( + paged_kv_indices_buffer_d + ): raise ValueError( "paged_kv_indices_buffer should be a torch.Tensor in cudagraph mode" ) - if not torch.is_tensor(paged_kv_last_page_len_buffer): + if not torch.is_tensor( + paged_kv_last_page_len_buffer_p + ) or not torch.is_tensor(paged_kv_last_page_len_buffer_d): raise ValueError( "paged_kv_last_page_len_buffer should be a torch.Tensor in cudagraph mode" ) - self._fixed_batch_size = len(paged_kv_last_page_len_buffer) - if len(paged_kv_indptr_buffer) != self._fixed_batch_size + 1: + self._fixed_batch_size = len(paged_kv_last_page_len_buffer_p) + if len(paged_kv_indptr_buffer_p) != self._fixed_batch_size + 1: + raise ValueError( + "The length of paged_kv_indptr_buffer_p should be batch_size + 1" + ) + if len(paged_kv_last_page_len_buffer_p) != self._fixed_batch_size: raise ValueError( - "The size of paged_kv_indptr_buffer should be batch_size + 1" + "The length of paged_kv_last_page_len_buffer_p should be batch_size" ) else: self._fixed_batch_size = 0 - self._paged_kv_indptr_buf = paged_kv_indptr_buffer - self._paged_kv_indices_buf = paged_kv_indices_buffer - self._paged_kv_last_page_len_buf = paged_kv_last_page_len_buffer + self._paged_kv_indptr_buf_p = paged_kv_indptr_buffer_p + self._paged_kv_indices_buf_p = paged_kv_indices_buffer_p + self._paged_kv_last_page_len_buf_p = paged_kv_last_page_len_buffer_p + self._paged_kv_indptr_buf_d = paged_kv_indptr_buffer_d + self._paged_kv_indices_buf_d = paged_kv_indices_buffer_d + self._paged_kv_last_page_len_buf_d = paged_kv_last_page_len_buffer_d self._use_tensor_cores = use_tensor_cores self._use_cuda_graph = use_cuda_graph From 3257899e686f003708d64e8c0e846900e7b45455 Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Wed, 25 Jun 2025 04:26:01 +0000 Subject: [PATCH 04/33] bench against persistent --- benchmarks/bench_mixed_attention.py | 28 ++++++++++++++++++++++++---- include/flashinfer/attention/pod.cuh | 22 ++-------------------- 2 files changed, 26 insertions(+), 24 deletions(-) diff --git a/benchmarks/bench_mixed_attention.py 
b/benchmarks/bench_mixed_attention.py index 515e353d0..4cc821441 100644 --- a/benchmarks/bench_mixed_attention.py +++ b/benchmarks/bench_mixed_attention.py @@ -121,21 +121,41 @@ def run_bench( causal_d=causal, ) ) + # Persistent attention + wrapper = flashinfer.BatchAttention(kv_layout="NHD") + wrapper.plan( + q_indptr.to(device), + kv_indptr.to(device), + torch.arange(num_blocks, dtype=torch.int32, device=device), + seq_lens.to(device), + num_qo_heads, + num_kv_heads, + head_dim, + head_dim, + page_block_size, + causal=causal, + q_data_type=torch.bfloat16, + kv_data_type=torch.bfloat16, + ) + ms_persistent = do_bench(lambda: wrapper.run(q, kv_data)) - print(f"Elapsed time (Batched Prefill): {ms_old:.2f} ms") - if len(p_kv_lens) == 1: - print(f"Elapsed time (POD Attention): {ms_pod:.2f} ms") total_bytes = ( q.numel() * q.element_size() + kv_data.numel() * kv_data.element_size() ) + print(f"Elapsed time (Batched Prefill): {ms_old:.2f} ms") + if len(p_kv_lens) == 1: + print(f"Elapsed time (POD Attention): {ms_pod:.2f} ms") + print(f"Elapsed time (Persistent Attention): {ms_persistent:.2f} ms") + print(f"Loading memory size (MB): {total_bytes / (1024**2):.2f} MB") bandwidth_old_gb_s = total_bytes / (ms_old * 1e-3) / (1024**3) - + bandwidth_new_gb_s = total_bytes / (ms_persistent * 1e-3) / (1024**3) print(f"Memory bandwidth (Batched Prefill): {bandwidth_old_gb_s:.2f} GB/s") if len(p_kv_lens) == 1: bandwidth_pod_gb_s = total_bytes / (ms_pod * 1e-3) / (1024**3) print(f"Memory bandwidth (POD Attention): {bandwidth_pod_gb_s:.2f} GB/s") + print(f"Memory bandwidth (Persistent Attention): {bandwidth_new_gb_s:.2f} GB/s") if __name__ == "__main__": diff --git a/include/flashinfer/attention/pod.cuh b/include/flashinfer/attention/pod.cuh index fc350e9a7..f261945bd 100644 --- a/include/flashinfer/attention/pod.cuh +++ b/include/flashinfer/attention/pod.cuh @@ -120,7 +120,7 @@ __global__ __launch_bounds__(std::max( linear_bid = ((int*)smem)[0]; op = ((int*)smem)[1]; // Sync to force all threads to wait - __syncthreads(); + // __syncthreads(); if (op == PREFILL) { const uint32_t linear_tid = threadIdx.x; @@ -209,24 +209,6 @@ cudaError_t PODWithKVCacheTensorDispatched(PrefillParams prefill_params, uint32_t cta_tile_q_p = 0; int64_t unpacked_qo_len = qo_len * group_size; // TODO(@Wenxuan): Include batch size in calculation - // if (unpacked_qo_len > 64 && HEAD_DIM_VO < 256) { - // cta_tile_q_p = 128; - // } else { - // auto compute_capacity = GetCudaComputeCapability(); - // if (compute_capacity.first >= 8) { - // // Ampere or newer - // if (unpacked_qo_len > 16) { - // // avg_packed_qo_len <= 64 - // cta_tile_q_p = 64; - // } else { - // // avg_packed_qo_len <= 16 - // cta_tile_q_p = 16; - // } - // } else { - // // NOTE(Zihao): not enough shared memory on Turing for 1x4 warp layout - // cta_tile_q_p = 64; - // } - // } cta_tile_q_p = FA2DetermineCtaTileQ(unpacked_qo_len, HEAD_DIM_VO); // Decode vars setup @@ -413,7 +395,7 @@ cudaError_t PODWithKVCacheTensorDispatched(PrefillParams prefill_params, // ************************************************ / static int* tbAssign = nullptr; - if (tbAssign == nullptr) cudaMalloc(&tbAssign, sizeof(int) * (num_sm + 2)); + cudaMalloc(&tbAssign, sizeof(int) * (num_sm + 2)); cudaMemset(tbAssign, 0, sizeof(int) * (num_sm + 2)); // Setup kernel arguments From 82f1550c9d952eb49b277f7c8ee574a8cf2b1938 Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Fri, 27 Jun 2025 04:53:05 +0000 Subject: [PATCH 05/33] rename xsize to num_qo_tiles --- include/flashinfer/attention/pod.cuh | 21 
+++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/include/flashinfer/attention/pod.cuh b/include/flashinfer/attention/pod.cuh index f261945bd..d74c535f8 100644 --- a/include/flashinfer/attention/pod.cuh +++ b/include/flashinfer/attention/pod.cuh @@ -38,7 +38,7 @@ template __global__ __launch_bounds__(std::max( KTraits_P::NUM_THREADS, - KTraits_D::NUM_THREADS)) void PODWithKVCacheTensorKernel(const uint32_t xsize, + KTraits_D::NUM_THREADS)) void PODWithKVCacheTensorKernel(const uint32_t num_qo_tiles, const __grid_constant__ PrefillParams prefill_params, const __grid_constant__ DecodeParams @@ -55,7 +55,7 @@ __global__ __launch_bounds__(std::max( const uint32_t num_kv_heads_d = decode_params.paged_kv.num_heads; // THREADBLOCKS - const uint32_t prefill_blocks = num_kv_heads_p * xsize * (PartitionKV_P ? num_chunks : 1); + const uint32_t prefill_blocks = num_kv_heads_p * num_qo_tiles * (PartitionKV_P ? num_chunks : 1); const uint32_t decode_blocks = padded_bsize * num_kv_heads_d; int op; @@ -134,17 +134,17 @@ __global__ __launch_bounds__(std::max( // BlockID exceeds limit if (linear_bid >= prefill_blocks) return; - const uint32_t bx = linear_bid % xsize; + const uint32_t bx = linear_bid % num_qo_tiles; auto& smem_storage = reinterpret_cast(smem); // Not partition_kv if constexpr (!PartitionKV_P) { const uint32_t chunk_idx = 0; - const uint32_t kv_head_idx = linear_bid / xsize; + const uint32_t kv_head_idx = linear_bid / num_qo_tiles; SinglePrefillWithKVCacheDevice(prefill_params, smem_storage, tid, bx, chunk_idx, kv_head_idx, 1, num_kv_heads_p); } else { - const uint32_t chunk_idx = (linear_bid / xsize) % num_chunks; - const uint32_t kv_head_idx = linear_bid / (xsize * num_chunks); + const uint32_t chunk_idx = (linear_bid / num_qo_tiles) % num_chunks; + const uint32_t kv_head_idx = linear_bid / (num_qo_tiles * num_chunks); SinglePrefillWithKVCacheDevice(prefill_params, smem_storage, tid, bx, chunk_idx, kv_head_idx, num_chunks, num_kv_heads_p); } @@ -375,8 +375,9 @@ cudaError_t PODWithKVCacheTensorDispatched(PrefillParams prefill_params, decode_params.o = tmp_v; decode_params.lse = tmp_s; } - uint32_t xsize = ceil_div(qo_len * group_size, KTraits_P::CTA_TILE_Q); - int nblks_p(xsize * (prefill_params.partition_kv ? prefill_params.partition_kv : 1) * + uint32_t num_qo_tiles = ceil_div(qo_len * group_size, KTraits_P::CTA_TILE_Q); + int nblks_p(num_qo_tiles * + (prefill_params.partition_kv ? 
prefill_params.partition_kv : 1) * num_kv_heads); int nthrs_p(32 * NUM_WARPS_Q_P * NUM_WARPS_KV_P); @@ -399,7 +400,7 @@ cudaError_t PODWithKVCacheTensorDispatched(PrefillParams prefill_params, cudaMemset(tbAssign, 0, sizeof(int) * (num_sm + 2)); // Setup kernel arguments - void* args[] = {(void*)&xsize, (void*)&prefill_params, (void*)&decode_params, + void* args[] = {(void*)&num_qo_tiles, (void*)&prefill_params, (void*)&decode_params, (void*)&tbAssign}; FLASHINFER_CUDA_CALL(cudaFuncSetAttribute( kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); @@ -416,7 +417,7 @@ cudaError_t PODWithKVCacheTensorDispatched(PrefillParams prefill_params, config.blockDim = nthrs; config.dynamicSmemBytes = smem_size; config.stream = stream; - FLASHINFER_CUDA_CALL(cudaLaunchKernelEx(&config, kernel, xsize, prefill_params, + FLASHINFER_CUDA_CALL(cudaLaunchKernelEx(&config, kernel, num_qo_tiles, prefill_params, decode_params, tbAssign)); } else { FLASHINFER_CUDA_CALL( From 06fee3183fcd712af3fea69bec2213d7d72139e3 Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Mon, 30 Jun 2025 04:16:11 +0000 Subject: [PATCH 06/33] fix --- benchmarks/bench_mixed_attention.py | 15 +- flashinfer/pod.py | 72 +++----- include/flashinfer/attention/scheduler.cuh | 182 ++++++++++++++++++++- 3 files changed, 215 insertions(+), 54 deletions(-) diff --git a/benchmarks/bench_mixed_attention.py b/benchmarks/bench_mixed_attention.py index 4cc821441..a35ed968e 100644 --- a/benchmarks/bench_mixed_attention.py +++ b/benchmarks/bench_mixed_attention.py @@ -163,13 +163,18 @@ def run_bench( torch.random.manual_seed(42) # Irregular sequence lengths for prefill and decode - d_q_len_configs = [[1] * 122, [1] * 128, [1] * 242, [1] * 256] - d_kv_len_configs = [[600] * 122, [10000] * 128, [400] * 242, [8192] * 256] - p_q_configs = [[17] * 1, [10000], [17] * 1, []] - p_kv_configs = [[10000] * 1, [10000], [8192] * 1, []] + # d_q_len_configs = [[1] * 122, [1] * 128, [1] * 242, [1] * 256] + # d_kv_len_configs = [[600] * 122, [10000] * 128, [400] * 242, [8192] * 256] + # p_q_configs = [[17] * 1, [10000], [17] * 1, []] + # p_kv_configs = [[10000] * 1, [10000], [8192] * 1, []] + + p_q_configs = [] + p_kv_configs = [] + d_q_len_configs = [] + d_kv_len_configs = [] # construct random length testcases - for _ in range(1): + for _ in range(3): bsz = 256 stride = 16 sparsity = 0.05 diff --git a/flashinfer/pod.py b/flashinfer/pod.py index 82a64b786..1b0294431 100644 --- a/flashinfer/pod.py +++ b/flashinfer/pod.py @@ -118,13 +118,10 @@ def __init__( float_workspace_buffer: torch.Tensor, kv_layout: str = "NHD", use_cuda_graph: bool = False, - qo_indptr_buffer_p: Optional[torch.Tensor] = None, - paged_kv_indptr_buffer_p: Optional[torch.Tensor] = None, - paged_kv_indices_buffer_p: Optional[torch.Tensor] = None, - paged_kv_last_page_len_buffer_p: Optional[torch.Tensor] = None, - paged_kv_indptr_buffer_d: Optional[torch.Tensor] = None, - paged_kv_indices_buffer_d: Optional[torch.Tensor] = None, - paged_kv_last_page_len_buffer_d: Optional[torch.Tensor] = None, + qo_indptr_buffer: Optional[torch.Tensor] = None, + paged_kv_indptr_buffer: Optional[torch.Tensor] = None, + paged_kv_indices_buffer: Optional[torch.Tensor] = None, + paged_kv_last_page_len_buffer: Optional[torch.Tensor] = None, custom_mask_buf_p: Optional[torch.Tensor] = None, mask_indptr_buf_p: Optional[torch.Tensor] = None, jit_args: Optional[List[Any]] = None, @@ -146,36 +143,27 @@ def __init__( auxiliary data structures will be stored as the provided buffers. 
The ``batch_size`` cannot change during the lifecycle of this wrapper when CUDAGraph is enabled. - qo_indptr_buffer_p: Optional[torch.Tensor] + qo_indptr_buffer: Optional[torch.Tensor] The user reserved buffer to store the ``qo_indptr`` array, the size of the buffer should be ``[batch_size + 1]``. This argument is only effective when ``use_cuda_graph`` is ``True``. - paged_kv_indptr_buffer_p: Optional[torch.Tensor] + paged_kv_indptr_buffer: Optional[torch.Tensor] The user reserved buffer on GPU to store the indptr of the prefill paged kv cache, the size of the buffer should be ``[batch_size + 1]``. Only needed when ``use_cuda_graph`` is ``True``. - paged_kv_indices_buffer_p: Optional[torch.Tensor] + paged_kv_indices_buffer: Optional[torch.Tensor] The user reserved buffer on GPU to store the page indices of the prefill paged kv cache, should be large enough to store the maximum number of page indices (``max_num_pages``) during the lifecycle of this wrapper. Only needed when ``use_cuda_graph`` is ``True``. - paged_kv_last_page_len_buffer_p: Optional[torch.Tensor] + paged_kv_last_page_len_buffer: Optional[torch.Tensor] The user reserved buffer on GPU to store the number of entries in the last page for prefill, the size of the buffer should be ``[batch_size]``. Only needed when ``use_cuda_graph`` is ``True``. - paged_kv_indptr_buffer_d: Optional[torch.Tensor] - Same as ``paged_kv_indptr_buffer_p`` but for decode. - - paged_kv_indices_buffer_d: Optional[torch.Tensor] - Same as ``paged_kv_indices_buffer_p`` but for decode. - - paged_kv_last_page_len_buffer_d: Optional[torch.Tensor] - Same as ``paged_kv_last_page_len_buffer_p`` but for decode. - jit_args : Optional[List[Any]] If provided, the wrapper will use the provided arguments to create the JIT module, otherwise, the wrapper will use default attention implementation. 
@@ -203,20 +191,11 @@ def __init__( self._kv_layout = kv_layout self._float_workspace_buffer = float_workspace_buffer self.device = float_workspace_buffer.device - self._qo_indptr_buf_p = qo_indptr_buffer_p - self._int_workspace_buffer_p = torch.empty( + self._qo_indptr_buf = qo_indptr_buffer + self._int_workspace_buffer = torch.empty( (8 * 1024 * 1024,), dtype=torch.uint8, device=self.device ) - self._pin_memory_int_workspace_buffer_p = torch.empty( - (8 * 1024 * 1024,), - dtype=torch.uint8, - pin_memory=True, - device="cpu", - ) - self._int_workspace_buffer_d = torch.empty( - (8 * 1024 * 1024,), dtype=torch.uint8, device=self.device - ) - self._pin_memory_int_workspace_buffer_d = torch.empty( + self._pin_memory_int_workspace_buffer = torch.empty( (8 * 1024 * 1024,), dtype=torch.uint8, pin_memory=True, @@ -224,46 +203,43 @@ def __init__( ) if use_cuda_graph: - if not torch.is_tensor(qo_indptr_buffer_p): + if not torch.is_tensor(qo_indptr_buffer): raise ValueError( "qo_indptr_buffer_p should be a torch.Tensor in CUDA graph mode" ) - if not torch.is_tensor(paged_kv_indptr_buffer_p) or not torch.is_tensor( - paged_kv_indptr_buffer_d + if not torch.is_tensor(paged_kv_indptr_buffer) or not torch.is_tensor( + paged_kv_indptr_buffer ): raise ValueError( "paged_kv_indptr_buffer should be a torch.Tensor in cudagraph mode" ) - if not torch.is_tensor(paged_kv_indices_buffer_p) or not torch.is_tensor( - paged_kv_indices_buffer_d + if not torch.is_tensor(paged_kv_indices_buffer) or not torch.is_tensor( + paged_kv_indices_buffer ): raise ValueError( "paged_kv_indices_buffer should be a torch.Tensor in cudagraph mode" ) if not torch.is_tensor( - paged_kv_last_page_len_buffer_p - ) or not torch.is_tensor(paged_kv_last_page_len_buffer_d): + paged_kv_last_page_len_buffer + ) or not torch.is_tensor(paged_kv_last_page_len_buffer): raise ValueError( "paged_kv_last_page_len_buffer should be a torch.Tensor in cudagraph mode" ) - self._fixed_batch_size = len(paged_kv_last_page_len_buffer_p) - if len(paged_kv_indptr_buffer_p) != self._fixed_batch_size + 1: + self._fixed_batch_size = len(paged_kv_last_page_len_buffer) + if len(paged_kv_indptr_buffer) != self._fixed_batch_size + 1: raise ValueError( "The length of paged_kv_indptr_buffer_p should be batch_size + 1" ) - if len(paged_kv_last_page_len_buffer_p) != self._fixed_batch_size: + if len(paged_kv_last_page_len_buffer) != self._fixed_batch_size: raise ValueError( "The length of paged_kv_last_page_len_buffer_p should be batch_size" ) else: self._fixed_batch_size = 0 - self._paged_kv_indptr_buf_p = paged_kv_indptr_buffer_p - self._paged_kv_indices_buf_p = paged_kv_indices_buffer_p - self._paged_kv_last_page_len_buf_p = paged_kv_last_page_len_buffer_p - self._paged_kv_indptr_buf_d = paged_kv_indptr_buffer_d - self._paged_kv_indices_buf_d = paged_kv_indices_buffer_d - self._paged_kv_last_page_len_buf_d = paged_kv_last_page_len_buffer_d + self._paged_kv_indptr_buf = paged_kv_indptr_buffer + self._paged_kv_indices_buf = paged_kv_indices_buffer + self._paged_kv_last_page_len_buf = paged_kv_last_page_len_buffer self._use_tensor_cores = use_tensor_cores self._use_cuda_graph = use_cuda_graph diff --git a/include/flashinfer/attention/scheduler.cuh b/include/flashinfer/attention/scheduler.cuh index 1b6250231..445efd1f2 100644 --- a/include/flashinfer/attention/scheduler.cuh +++ b/include/flashinfer/attention/scheduler.cuh @@ -677,7 +677,7 @@ template inline cudaError_t PrefillPlan(void* float_buffer, size_t float_workspace_size_in_bytes, void* int_buffer, void* 
page_locked_int_buffer, size_t int_workspace_size_in_bytes, PrefillPlanInfo& plan_info, - IdType* qo_indptr_h, IdType* kv_indptr_h, uint32_t total_num_rows, + IdType* qo_indptr_p, IdType* kv_indptr_p, uint32_t total_num_rows, uint32_t batch_size, uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t head_dim_qk, uint32_t head_dim_vo, uint32_t page_size, bool enable_cuda_graph, uint32_t sizeof_dtype_o, @@ -689,6 +689,186 @@ inline cudaError_t PrefillPlan(void* float_buffer, size_t float_workspace_size_i FLASHINFER_ERROR(err_msg.str()); } + // step 0: get the number of SMs + int num_sm = 0; + int dev_id = 0; + FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id)); + FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&num_sm, cudaDevAttrMultiProcessorCount, dev_id)); + int num_blocks_per_sm = 3; + int max_grid_size = num_blocks_per_sm * num_sm; + uint32_t max_batch_size_if_split = max_grid_size / num_kv_heads; + + // step 2: determine kv_chunk_size + auto [split_kv, new_batch_size, padded_batch_size, cta_tile_q, kv_chunk_size, request_indices_vec, + qo_tile_indices_vec, kv_tile_indices_vec, merge_indptr_vec, o_indptr_vec] = + PrefillSplitQOKVIndptr(qo_indptr_p, kv_indptr_p, total_num_rows, batch_size, num_qo_heads, + num_kv_heads, head_dim_vo, page_size, max_batch_size_if_split, + enable_cuda_graph); + + plan_info.cta_tile_q = cta_tile_q; + plan_info.total_num_rows = total_num_rows; + plan_info.enable_cuda_graph = enable_cuda_graph; + plan_info.padded_batch_size = padded_batch_size; + plan_info.split_kv = split_kv; + + AlignedAllocator int_allocator(int_buffer, int_workspace_size_in_bytes); + plan_info.request_indices_offset = int_allocator.aligned_alloc_offset( + sizeof(IdType) * padded_batch_size, 16, "batch_prefill_request_indices"); + plan_info.qo_tile_indices_offset = int_allocator.aligned_alloc_offset( + sizeof(IdType) * padded_batch_size, 16, "batch_prefill_qo_tile_indices"); + plan_info.kv_tile_indices_offset = int_allocator.aligned_alloc_offset( + sizeof(IdType) * padded_batch_size, 16, "batch_prefill_kv_tile_indices"); + plan_info.o_indptr_offset = int_allocator.aligned_alloc_offset(sizeof(IdType) * (batch_size + 1), + 16, "batch_prefill_o_indptr"); + plan_info.kv_chunk_size_ptr_offset = + int_allocator.aligned_alloc_offset(sizeof(IdType), 1, "batch_prefill_kv_chunk_size_ptr"); + + if (plan_info.enable_cuda_graph) { + plan_info.total_num_rows_offset = + int_allocator.aligned_alloc_offset(sizeof(uint32_t), 16, "batch_prefill_total_num_rows"); + uint32_t* total_num_rows_h = + GetPtrFromBaseOffset(page_locked_int_buffer, plan_info.total_num_rows_offset); + *total_num_rows_h = qo_indptr_h[batch_size]; + } + + IdType* request_indices_h = + GetPtrFromBaseOffset(page_locked_int_buffer, plan_info.request_indices_offset); + IdType* qo_tile_indices_h = + GetPtrFromBaseOffset(page_locked_int_buffer, plan_info.qo_tile_indices_offset); + IdType* kv_tile_indices_h = + GetPtrFromBaseOffset(page_locked_int_buffer, plan_info.kv_tile_indices_offset); + IdType* o_indptr_h = + GetPtrFromBaseOffset(page_locked_int_buffer, plan_info.o_indptr_offset); + IdType* kv_chunk_size_ptr_h = + GetPtrFromBaseOffset(page_locked_int_buffer, plan_info.kv_chunk_size_ptr_offset); + std::copy(request_indices_vec.begin(), request_indices_vec.end(), request_indices_h); + std::copy(qo_tile_indices_vec.begin(), qo_tile_indices_vec.end(), qo_tile_indices_h); + std::copy(kv_tile_indices_vec.begin(), kv_tile_indices_vec.end(), kv_tile_indices_h); + std::copy(o_indptr_vec.begin(), o_indptr_vec.end(), o_indptr_h); + kv_chunk_size_ptr_h[0] = 
kv_chunk_size; + + if (split_kv) { + AlignedAllocator float_allocator(float_buffer, float_workspace_size_in_bytes); + plan_info.v_offset = float_allocator.aligned_alloc_offset( + num_qo_heads * padded_batch_size * cta_tile_q * head_dim_vo * sizeof(float), 16, + "batch_prefill_tmp_v"); + plan_info.s_offset = float_allocator.aligned_alloc_offset( + num_qo_heads * padded_batch_size * cta_tile_q * sizeof(float), 16, "batch_prefill_tmp_s"); + plan_info.merge_indptr_offset = int_allocator.aligned_alloc_offset( + sizeof(IdType) * (plan_info.total_num_rows + 1), 16, "batch_prefill_merge_indptr"); + plan_info.block_valid_mask_offset = int_allocator.aligned_alloc_offset( + sizeof(bool) * padded_batch_size, 16, "batch_prefill_block_valid_mask"); + + IdType* merge_indptr_h = + GetPtrFromBaseOffset(page_locked_int_buffer, plan_info.merge_indptr_offset); + bool* block_valid_mask_h = + GetPtrFromBaseOffset(page_locked_int_buffer, plan_info.block_valid_mask_offset); + std::copy(merge_indptr_vec.begin(), merge_indptr_vec.end(), merge_indptr_h); + for (uint32_t i = 0; i < padded_batch_size; ++i) { + block_valid_mask_h[i] = i < new_batch_size; + } + } + + size_t num_bytes_to_copy = int_allocator.num_allocated_bytes(); + FLASHINFER_CUDA_CALL(cudaMemcpyAsync(int_buffer, page_locked_int_buffer, num_bytes_to_copy, + cudaMemcpyHostToDevice, stream)); + + return cudaSuccess; +} + +struct PODPlanInfo { + int64_t padded_batch_size; + int64_t total_num_rows; + int64_t total_num_rows_offset; + int64_t cta_tile_q; + int64_t request_indices_offset; + int64_t qo_tile_indices_offset; + int64_t kv_tile_indices_offset; + int64_t merge_indptr_offset; + int64_t o_indptr_offset; + int64_t kv_chunk_size_ptr_offset; + int64_t v_offset; + int64_t s_offset; + int64_t block_valid_mask_offset; + bool enable_cuda_graph; + bool split_kv; + + PODPlanInfo() + : padded_batch_size(0), + total_num_rows(0), + total_num_rows_offset(0), + cta_tile_q(0), + request_indices_offset(0), + qo_tile_indices_offset(0), + kv_tile_indices_offset(0), + merge_indptr_offset(0), + o_indptr_offset(0), + kv_chunk_size_ptr_offset(0), + v_offset(0), + s_offset(0), + block_valid_mask_offset(0), + enable_cuda_graph(false), + split_kv(false) {} + + // convert PrefillPlanInfo to std::vector + std::vector ToVector() const { + return {padded_batch_size, + total_num_rows, + total_num_rows_offset, + cta_tile_q, + request_indices_offset, + qo_tile_indices_offset, + kv_tile_indices_offset, + merge_indptr_offset, + o_indptr_offset, + kv_chunk_size_ptr_offset, + v_offset, + s_offset, + block_valid_mask_offset, + enable_cuda_graph, + split_kv}; + } + + // From std::vector to PodPlanInfo + void FromVector(const std::vector& vec) { + if (vec.size() != 15) { + std::ostringstream err_msg; + err_msg << "PodPlanInfo::FromVector: vec.size() should be 15, but got " << vec.size(); + FLASHINFER_ERROR(err_msg.str()); + } + padded_batch_size = vec[0]; + total_num_rows = vec[1]; + total_num_rows_offset = vec[2]; + cta_tile_q = vec[3]; + request_indices_offset = vec[4]; + qo_tile_indices_offset = vec[5]; + kv_tile_indices_offset = vec[6]; + merge_indptr_offset = vec[7]; + o_indptr_offset = vec[8]; + kv_chunk_size_ptr_offset = vec[9]; + v_offset = vec[10]; + s_offset = vec[11]; + block_valid_mask_offset = vec[12]; + enable_cuda_graph = vec[13]; + split_kv = vec[14]; + } +}; + +template +inline cudaError_t PODPlan(void* float_buffer, size_t float_workspace_size_in_bytes, + void* int_buffer, void* page_locked_int_buffer, + size_t int_workspace_size_in_bytes, PODPlanInfo& plan_info, + 
IdType* qo_indptr_h, IdType* kv_indptr_h, uint32_t total_num_rows, + uint32_t batch_size, uint32_t num_qo_heads, uint32_t num_kv_heads, + uint32_t head_dim_qk, uint32_t head_dim_vo, uint32_t page_size, + bool enable_cuda_graph, uint32_t sizeof_dtype_o, cudaStream_t stream) { + if (num_qo_heads % num_kv_heads != 0) { + std::ostringstream err_msg; + err_msg << "num_qo_heads " << num_qo_heads << " should be divisible by num_kv_heads " + << num_kv_heads; + FLASHINFER_ERROR(err_msg.str()); + } + // step 0: get the number of SMs int num_sm = 0; int dev_id = 0; From f47f73e3673d7cd862c9e9302a476e977f8789e7 Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Mon, 30 Jun 2025 15:20:14 +0000 Subject: [PATCH 07/33] fix --- benchmarks/bench_mixed_attention.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/benchmarks/bench_mixed_attention.py b/benchmarks/bench_mixed_attention.py index a35ed968e..4cc821441 100644 --- a/benchmarks/bench_mixed_attention.py +++ b/benchmarks/bench_mixed_attention.py @@ -163,18 +163,13 @@ def run_bench( torch.random.manual_seed(42) # Irregular sequence lengths for prefill and decode - # d_q_len_configs = [[1] * 122, [1] * 128, [1] * 242, [1] * 256] - # d_kv_len_configs = [[600] * 122, [10000] * 128, [400] * 242, [8192] * 256] - # p_q_configs = [[17] * 1, [10000], [17] * 1, []] - # p_kv_configs = [[10000] * 1, [10000], [8192] * 1, []] - - p_q_configs = [] - p_kv_configs = [] - d_q_len_configs = [] - d_kv_len_configs = [] + d_q_len_configs = [[1] * 122, [1] * 128, [1] * 242, [1] * 256] + d_kv_len_configs = [[600] * 122, [10000] * 128, [400] * 242, [8192] * 256] + p_q_configs = [[17] * 1, [10000], [17] * 1, []] + p_kv_configs = [[10000] * 1, [10000], [8192] * 1, []] # construct random length testcases - for _ in range(3): + for _ in range(1): bsz = 256 stride = 16 sparsity = 0.05 From c51cecc7c7364c16984c71315b1b27c9b8984a9e Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Mon, 30 Jun 2025 23:15:53 +0000 Subject: [PATCH 08/33] fix --- csrc/pod.cu | 28 ++++++++++++++++++++++ include/flashinfer/attention/scheduler.cuh | 2 +- 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/csrc/pod.cu b/csrc/pod.cu index fabde1be7..8168c03e5 100644 --- a/csrc/pod.cu +++ b/csrc/pod.cu @@ -271,3 +271,31 @@ void pod_with_kv_cache_tensor( //}); }); } + +at::Tensor PODWithKVCachePlan(at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, + at::Tensor page_locked_int_workspace_buffer, at::Tensor qo_indptr_p, + at::Tensor kv_indptr_p, at::Tensor kv_len_arr, int64_t total_num_rows, + int64_t batch_size, int64_t num_qo_heads, int64_t num_kv_heads, + int64_t page_size, bool enable_cuda_graph, int64_t head_dim_qk, + int64_t head_dim_vo, bool causal) { + size_t float_workspace_size_in_bytes = + float_workspace_buffer.size(0) * float_workspace_buffer.element_size(); + size_t int_workspace_size_in_bytes = + int_workspace_buffer.size(0) * int_workspace_buffer.element_size(); + + PrefillPlanInfo plan_info; + + const c10::cuda::OptionalCUDAGuard device_guard(float_workspace_buffer.device()); + const cudaStream_t stream = c10::cuda::getCurrentCUDAStream(); + cudaError_t status = PrefillPlan( + float_workspace_buffer.data_ptr(), float_workspace_size_in_bytes, + int_workspace_buffer.data_ptr(), page_locked_int_workspace_buffer.data_ptr(), + int_workspace_size_in_bytes, plan_info, qo_indptr.data_ptr(), + kv_indptr.data_ptr(), total_num_rows, batch_size, num_qo_heads, num_kv_heads, + head_dim_qk, head_dim_vo, page_size, enable_cuda_graph, /*sizeof_dtype_o=*/2, 
stream); + + TORCH_CHECK(status == cudaSuccess, + "Failed to plan prefill with error: ", cudaGetErrorString(status)); + + return vec_to_tensor(plan_info.ToVector()); +} diff --git a/include/flashinfer/attention/scheduler.cuh b/include/flashinfer/attention/scheduler.cuh index 445efd1f2..8f42e4bf0 100644 --- a/include/flashinfer/attention/scheduler.cuh +++ b/include/flashinfer/attention/scheduler.cuh @@ -694,7 +694,7 @@ inline cudaError_t PrefillPlan(void* float_buffer, size_t float_workspace_size_i int dev_id = 0; FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id)); FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&num_sm, cudaDevAttrMultiProcessorCount, dev_id)); - int num_blocks_per_sm = 3; + int num_blocks_per_sm = 2; int max_grid_size = num_blocks_per_sm * num_sm; uint32_t max_batch_size_if_split = max_grid_size / num_kv_heads; From e73566cff643922424aca9a124ad971cb53d1bbd Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Thu, 3 Jul 2025 00:21:32 +0000 Subject: [PATCH 09/33] add mixed scheduler --- csrc/pod.cu | 67 ++++-- flashinfer/pod.py | 12 +- include/flashinfer/attention/scheduler.cuh | 248 ++++++++++++++++----- 3 files changed, 245 insertions(+), 82 deletions(-) diff --git a/csrc/pod.cu b/csrc/pod.cu index 8168c03e5..f990c1fbb 100644 --- a/csrc/pod.cu +++ b/csrc/pod.cu @@ -36,6 +36,36 @@ cudaError_t PODWithKVCacheTensorDispatched(PrefillParams prefill_params, using namespace flashinfer; +at::Tensor PODWithKVCachePlan(at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, + at::Tensor page_locked_int_workspace_buffer, at::Tensor qo_indptr_p, + at::Tensor kv_indptr_p, at::Tensor kv_len_arr_p, + at::Tensor kv_indptr_d, at::Tensor qo_indptr_d, + int64_t total_num_rows, int64_t batch_size, int64_t num_qo_heads, + int64_t num_kv_heads, int64_t page_size, bool enable_cuda_graph, + int64_t head_dim_qk, int64_t head_dim_vo, bool causal) { + size_t float_workspace_size_in_bytes = + float_workspace_buffer.size(0) * float_workspace_buffer.element_size(); + size_t int_workspace_size_in_bytes = + int_workspace_buffer.size(0) * int_workspace_buffer.element_size(); + + PrefillPlanInfo plan_info; + + const c10::cuda::OptionalCUDAGuard device_guard(float_workspace_buffer.device()); + const cudaStream_t stream = c10::cuda::getCurrentCUDAStream(); + cudaError_t status = PrefillPlan( + float_workspace_buffer.data_ptr(), float_workspace_size_in_bytes, + int_workspace_buffer.data_ptr(), page_locked_int_workspace_buffer.data_ptr(), + int_workspace_size_in_bytes, plan_info, qo_indptr_p.data_ptr(), + kv_indptr_p.data_ptr(), qo_indptr_d.data_ptr(), + kv_indptr_d.data_ptr(), total_num_rows, batch_size, num_qo_heads, num_kv_heads, + head_dim_qk, head_dim_vo, page_size, enable_cuda_graph, /*sizeof_dtype_o=*/2, stream); + + TORCH_CHECK(status == cudaSuccess, + "Failed to plan prefill with error: ", cudaGetErrorString(status)); + + return vec_to_tensor(plan_info.ToVector()); +} + void pod_with_kv_cache_tensor( // Prefill params at::Tensor q_p, at::Tensor k_p, at::Tensor v_p, at::Tensor tmp_p, at::Tensor o_p, @@ -200,30 +230,6 @@ void pod_with_kv_cache_tensor( params.q_stride_n = q_stride_n_d; params.q_stride_h = q_stride_h_d; params.window_left = window_left_d; - - params.request_indices = nullptr; - params.qo_tile_indices = nullptr; - params.kv_tile_indices = nullptr; - params.merge_indptr = nullptr; - params.o_indptr = nullptr; - params.kv_chunk_size_ptr = nullptr; - params.block_valid_mask = nullptr; - params.total_num_rows = nullptr; - params.max_total_num_rows = 0; - params.padded_batch_size = 0; - 
params.partition_kv = false; - - params.maybe_mask_indptr = maybe_mask_indptr_d - ? static_cast(maybe_mask_indptr_d->data_ptr()) - : nullptr; - params.maybe_alibi_slopes = maybe_alibi_slopes_d - ? static_cast(maybe_alibi_slopes_d->data_ptr()) - : nullptr; - params.logits_soft_cap = logits_soft_cap_d; - params.sm_scale = sm_scale_d; - params.rope_rcp_scale = rope_rcp_scale_d; - params.rope_rcp_theta = rope_rcp_theta_d; - params.request_indices = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.request_indices_offset); params.qo_tile_indices = @@ -245,6 +251,19 @@ void pod_with_kv_cache_tensor( } params.padded_batch_size = plan_info.padded_batch_size; params.max_total_num_rows = plan_info.total_num_rows; + + params.partition_kv = false; + params.maybe_mask_indptr = maybe_mask_indptr_d + ? static_cast(maybe_mask_indptr_d->data_ptr()) + : nullptr; + params.maybe_alibi_slopes = maybe_alibi_slopes_d + ? static_cast(maybe_alibi_slopes_d->data_ptr()) + : nullptr; + params.logits_soft_cap = logits_soft_cap_d; + params.sm_scale = sm_scale_d; + params.rope_rcp_scale = rope_rcp_scale_d; + params.rope_rcp_theta = rope_rcp_theta_d; + if (plan_info.enable_cuda_graph) { params.total_num_rows = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.total_num_rows_offset); diff --git a/flashinfer/pod.py b/flashinfer/pod.py index 1b0294431..ecb74c205 100644 --- a/flashinfer/pod.py +++ b/flashinfer/pod.py @@ -350,17 +350,17 @@ def plan( """ # Logits soft cap is not supported currently logits_soft_cap = False - batch_size = len(last_page_len_d) + batch_size_d = len(last_page_len_d) if logits_soft_cap is None: logits_soft_cap = 0.0 - qo_indptr_host = _get_range_buf(batch_size + 1, "cpu") + qo_indptr_host = _get_range_buf(batch_size_d + 1, "cpu") if self.is_cuda_graph_enabled: - if batch_size != self._fixed_batch_size: + if batch_size_d != self._fixed_batch_size: raise ValueError( "The batch size should be fixed in cudagraph mode, the runtime batch size {} " " mismatches the batch size set during initialization {}".format( - batch_size, self._fixed_batch_size + batch_size_d, self._fixed_batch_size ) ) if len(indices_d) > len(self._paged_kv_indices_buf): @@ -429,8 +429,8 @@ def plan( qo_indptr_host, indptr_host, kv_lens_arr_host, - batch_size, # total_num_rows - batch_size, + batch_size_d, # total_num_rows + batch_size_d, num_qo_heads, num_kv_heads, page_size, diff --git a/include/flashinfer/attention/scheduler.cuh b/include/flashinfer/attention/scheduler.cuh index 8f42e4bf0..200cf0080 100644 --- a/include/flashinfer/attention/scheduler.cuh +++ b/include/flashinfer/attention/scheduler.cuh @@ -492,19 +492,8 @@ inline cudaError_t DecodePlan(void* float_buffer, size_t float_workspace_size_in return cudaSuccess; } -template -inline auto PrefillSplitQOKVIndptr(IdType* qo_indptr_h, IdType* kv_indptr_h, - uint32_t total_num_rows, uint32_t batch_size, - uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t head_dim, - uint32_t page_size, uint32_t max_batch_size_if_split, - bool enable_cuda_graph) { - std::vector request_indices, qo_tile_indices, kv_tile_indices, merge_indptr, o_indptr; - merge_indptr.push_back(0); - o_indptr.push_back(0); - - const uint32_t gqa_group_size = num_qo_heads / num_kv_heads; - - // step 1: determine packed_qo_len_arr and verify qo_indptr contents. 
+inline auto get_qkv_len_arr(IdType* qo_indptr_h, IdType* kv_indptr_h, uint32_t batch_size, + uint32_t num_qo_heads, uint32_t gqa_group_size) { std::vector packed_qo_len_arr(batch_size), kv_len_arr(batch_size); for (uint32_t i = 0; i < batch_size; ++i) { packed_qo_len_arr[i] = int64_t(qo_indptr_h[i + 1] - qo_indptr_h[i]) * int64_t(gqa_group_size); @@ -522,8 +511,12 @@ inline auto PrefillSplitQOKVIndptr(IdType* qo_indptr_h, IdType* kv_indptr_h, FLASHINFER_ERROR(err_msg.str()); } } + return std::make_tuple(packed_qo_len_arr, kv_len_arr); +} - // step 2: determine cta_tile_q, kv_chunk_size and total_num_tiles_q +inline auto get_q_tiles(std::vector& packed_qo_len_arr, uint32_t batch_size, + uint32_t head_dim, uint32_t page_size, uint32_t total_num_rows, + uint32_t gqa_group_size, bool enable_cuda_graph, uint32_t tile_size = -1) { const uint32_t min_kv_chunk_size = std::max((128 / page_size), 1U); uint32_t cta_tile_q; uint32_t total_num_tiles_q; @@ -533,7 +526,11 @@ inline auto PrefillSplitQOKVIndptr(IdType* qo_indptr_h, IdType* kv_indptr_h, // the CUDA graph is created fixes the maximum number of tokens. const uint64_t max_seq_len = total_num_rows - batch_size + 1; uint64_t max_qo_len = uint64_t(max_seq_len) * gqa_group_size; - cta_tile_q = FA2DetermineCtaTileQ(max_qo_len, head_dim); + if (tile_size == -1) { + cta_tile_q = FA2DetermineCtaTileQ(max_qo_len, head_dim); + } else { + cta_tile_q = tile_size; + } // Find an upper bound for the number of tiles, derived from the total // number of rows and the batch size. The sum of qo lengths rounded @@ -546,19 +543,36 @@ inline auto PrefillSplitQOKVIndptr(IdType* qo_indptr_h, IdType* kv_indptr_h, sum_packed_qo_len += packed_qo_len_arr[i]; } const int64_t avg_packed_qo_len = sum_packed_qo_len / batch_size; - cta_tile_q = FA2DetermineCtaTileQ(avg_packed_qo_len, head_dim); + if (tile_size == -1) { + cta_tile_q = FA2DetermineCtaTileQ(avg_packed_qo_len, head_dim); + } else { + cta_tile_q = tile_size; + } total_num_tiles_q = 0; for (uint32_t i = 0; i < batch_size; ++i) { total_num_tiles_q += ceil_div(packed_qo_len_arr[i], cta_tile_q); } } + return std::make_tuple(cta_tile_q, total_num_tiles_q); +} - auto [split_kv, kv_chunk_size] = - PrefillBinarySearchKVChunkSize(enable_cuda_graph, max_batch_size_if_split, packed_qo_len_arr, - kv_len_arr, cta_tile_q, min_kv_chunk_size); +inline auto get_qkv_tile_indices(std::vector& packed_qo_len_arr, + std::vector& kv_len_arr, uint32_t batch_size, + uint32_t cta_tile_q, uint32_t kv_chunk_size, + uint32_t gqa_group_size, + std::vector& merge_indptr = nullptr, + std::vector& o_indptr = nullptr) { + std::vector request_indices, qo_tile_indices, kv_tile_indices; + if (merge_indptr == nullptr) { + merge_indptr = std::vector(); + merge_indptr.push_back(0); + } + if (o_indptr == nullptr) { + o_indptr = std::vector(); + o_indptr.push_back(0); + } - // step 3: split qo_indptr and kv_indptr uint32_t new_batch_size = 0; for (uint32_t request_idx = 0; request_idx < batch_size; ++request_idx) { const int64_t packed_qo_len = packed_qo_len_arr[request_idx]; @@ -581,6 +595,38 @@ inline auto PrefillSplitQOKVIndptr(IdType* qo_indptr_h, IdType* kv_indptr_h, } o_indptr.push_back(o_indptr.back() + qo_len * num_tiles_kv); } + return std::make_tuple(request_indices, qo_tile_indices, kv_tile_indices, merge_indptr, o_indptr, + new_batch_size); +} + +template +inline auto PrefillSplitQOKVIndptr(IdType* qo_indptr_h, IdType* kv_indptr_h, + uint32_t total_num_rows, uint32_t batch_size, + uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t 
head_dim, + uint32_t page_size, uint32_t max_batch_size_if_split, + bool enable_cuda_graph) { + std::vector request_indices, qo_tile_indices, kv_tile_indices, merge_indptr, o_indptr; + merge_indptr.push_back(0); + o_indptr.push_back(0); + + const uint32_t gqa_group_size = num_qo_heads / num_kv_heads; + + // step 1: determine packed_qo_len_arr and verify qo_indptr contents. + auto [packed_qo_len_arr, kv_len_arr] = + get_qkv_len_arr(qo_indptr_h, kv_indptr_h, batch_size, num_qo_heads, gqa_group_size); + + // step 2: determine cta_tile_q, kv_chunk_size and total_num_tiles_q + auto [cta_tile_q, total_num_tiles_q] = + get_q_tiles(packed_qo_len_arr, batch_size, head_dim, page_size, total_num_rows, + gqa_group_size, enable_cuda_graph); + + auto [split_kv, kv_chunk_size] = + PrefillBinarySearchKVChunkSize(enable_cuda_graph, max_batch_size_if_split, packed_qo_len_arr, + kv_len_arr, cta_tile_q, min_kv_chunk_size); + + auto [request_indices, qo_tile_indices, kv_tile_indices, merge_indptr, o_indptr, new_batch_size] = + get_qkv_tile_indices(packed_qo_len_arr, kv_len_arr, batch_size, cta_tile_q, kv_chunk_size, + gqa_group_size); const size_t padded_batch_size = enable_cuda_graph ? std::max(max_batch_size_if_split, total_num_tiles_q) : new_batch_size; @@ -776,20 +822,99 @@ inline cudaError_t PrefillPlan(void* float_buffer, size_t float_workspace_size_i return cudaSuccess; } +/* +Modifed from PrefillSplitQOKVIndptr to support two tile sizes, and assign blocks proportional to the +number of tiles. +*/ +template +inline auto PODSplitQOKVIndptr(IdType* qo_indptr_p, IdType* kv_indptr_p, uint32_t total_num_rows_p, + uint32_t batch_size_p, IdType* qo_indptr_d, IdType* kv_indptr_d, + uint32_t total_num_rows_d, uint32_t batch_size_d, + uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t head_dim, + uint32_t page_size, uint32_t max_batch_size_if_split, + bool enable_cuda_graph) { + std::vector request_indices, qo_tile_indices, kv_tile_indices, merge_indptr, o_indptr; + merge_indptr.push_back(0); + o_indptr.push_back(0); + + const uint32_t gqa_group_size = num_qo_heads / num_kv_heads; + + // step 1: determine packed_qo_len_arr and verify qo_indptr contents. + auto [packed_qo_len_arr_p, kv_len_arr_p] = + get_qkv_len_arr(qo_indptr_p, kv_indptr_p, batch_size_p, num_qo_heads, gqa_group_size); + auto [packed_qo_len_arr_d, kv_len_arr_d] = + get_qkv_len_arr(qo_indptr_d, kv_indptr_d, batch_size_d, num_qo_heads, gqa_group_size); + + // step 2: determine cta_tile_q, kv_chunk_size and total_num_tiles_q + auto [cta_tile_q_p, num_tiles_q_p] = + get_q_tiles(packed_qo_len_arr_p, batch_size_p, head_dim, page_size, total_num_rows_p, + gqa_group_size, enable_cuda_graph); + auto cta_tile_q_d = 16; // minimum for tensor core decode + auto [cta_tile_q_d, num_tiles_q_d] = + get_q_tiles(packed_qo_len_arr_d, batch_size_d, head_dim, page_size, total_num_rows_d, + gqa_group_size, enable_cuda_graph, cta_tile_q_d); + + uint32_t total_num_tiles_q = num_tiles_q_p + num_tiles_q_d; + // Assign CTAs proportional to the number of query tiles + // TODO(Wenxuan): explore a more balanced cost function considering kv len. 
+  // See discussion: https://github.com/flashinfer-ai/flashinfer/issues/1175
+  uint32_t max_bs_p = max_batch_size_if_split * num_tiles_q_p / total_num_tiles_q;
+  uint32_t max_bs_d = max_batch_size_if_split * num_tiles_q_d / total_num_tiles_q;
+  auto [split_kv_p, kv_chunk_size_p] =
+      PrefillBinarySearchKVChunkSize(enable_cuda_graph, max_bs_p, packed_qo_len_arr_p, kv_len_arr_p,
+                                     cta_tile_q_p, min_kv_chunk_size);
+  auto [split_kv_d, kv_chunk_size_d] =
+      PrefillBinarySearchKVChunkSize(enable_cuda_graph, max_bs_d, packed_qo_len_arr_d, kv_len_arr_d,
+                                     cta_tile_q_d, min_kv_chunk_size);
+
+  // step 3: split qo_indptr and kv_indptr
+  auto [request_indices_p, qo_tile_indices_p, kv_tile_indices_p, merge_indptr, o_indptr,
+        new_batch_size_p] = get_qkv_tile_indices(packed_qo_len_arr_p, kv_len_arr_p, batch_size_p,
+                                                 cta_tile_q_p, kv_chunk_size_p, gqa_group_size);
+  auto [request_indices_d, qo_tile_indices_d, kv_tile_indices_d, merge_indptr_d, o_indptr_d,
+        new_batch_size_d] =
+      get_qkv_tile_indices(packed_qo_len_arr_d, kv_len_arr_d, batch_size_d, cta_tile_q_d,
+                           kv_chunk_size_d, gqa_group_size, merge_indptr, o_indptr);
+
+  bool split_kv = split_kv_p || split_kv_d;
+  uint32_t new_batch_size = new_batch_size_p + new_batch_size_d;
+  const size_t padded_batch_size_p =
+      enable_cuda_graph ? std::max(max_bs_p, num_tiles_q_p) : new_batch_size_p;
+  const size_t padded_batch_size_d =
+      enable_cuda_graph ? std::max(max_bs_d, num_tiles_q_d) : new_batch_size_d;
+  FLASHINFER_CHECK(new_batch_size <= padded_batch_size_p + padded_batch_size_d,
+                   "new batch size should not exceed padded batch size");
+
+  // step 4: multiply kv_chunk_size by page_size
+  kv_chunk_size_p *= page_size;
+  kv_chunk_size_d *= page_size;
+
+  return std::make_tuple(
+      split_kv, new_batch_size, padded_batch_size_p, padded_batch_size_d, cta_tile_q_p,
+      cta_tile_q_d, kv_chunk_size_p, kv_chunk_size_d, std::move(merge_indptr), std::move(o_indptr),
+      std::move(request_indices_p), std::move(qo_tile_indices_p), std::move(kv_tile_indices_p),
+      std::move(request_indices_d), std::move(qo_tile_indices_d), std::move(kv_tile_indices_d));
+}
+
 struct PODPlanInfo {
   int64_t padded_batch_size;
   int64_t total_num_rows;
   int64_t total_num_rows_offset;
   int64_t cta_tile_q;
-  int64_t request_indices_offset;
-  int64_t qo_tile_indices_offset;
-  int64_t kv_tile_indices_offset;
+  int64_t request_indices_offset_p;
+  int64_t request_indices_offset_d;
+  int64_t qo_tile_indices_offset_p;
+  int64_t qo_tile_indices_offset_d;
+  int64_t kv_tile_indices_offset_p;
+  int64_t kv_tile_indices_offset_d;
   int64_t merge_indptr_offset;
   int64_t o_indptr_offset;
   int64_t kv_chunk_size_ptr_offset;
-  int64_t v_offset;
-  int64_t s_offset;
-  int64_t block_valid_mask_offset;
+  int64_t v_offset_p;
+  int64_t v_offset_d;
+  int64_t s_offset_p;
+  int64_t s_offset_d;
+  int64_t block_valid_mask_offset_p;
+  int64_t block_valid_mask_offset_d;
   bool enable_cuda_graph;
   bool split_kv;
@@ -798,15 +923,21 @@ struct PODPlanInfo {
         total_num_rows(0),
         total_num_rows_offset(0),
         cta_tile_q(0),
-        request_indices_offset(0),
-        qo_tile_indices_offset(0),
-        kv_tile_indices_offset(0),
+        request_indices_offset_p(0),
+        request_indices_offset_d(0),
+        qo_tile_indices_offset_p(0),
+        qo_tile_indices_offset_d(0),
+        kv_tile_indices_offset_p(0),
+        kv_tile_indices_offset_d(0),
         merge_indptr_offset(0),
         o_indptr_offset(0),
         kv_chunk_size_ptr_offset(0),
-        v_offset(0),
-        s_offset(0),
-        block_valid_mask_offset(0),
+        v_offset_p(0),
+        v_offset_d(0),
+        s_offset_p(0),
+        s_offset_d(0),
+        block_valid_mask_offset_p(0),
+
block_valid_mask_offset_d(0), enable_cuda_graph(false), split_kv(false) {} @@ -840,15 +971,21 @@ struct PODPlanInfo { total_num_rows = vec[1]; total_num_rows_offset = vec[2]; cta_tile_q = vec[3]; - request_indices_offset = vec[4]; - qo_tile_indices_offset = vec[5]; - kv_tile_indices_offset = vec[6]; + request_indices_offset_p = vec[4]; + request_indices_offset_d = vec[5]; + qo_tile_indices_offset_p = vec[6]; + qo_tile_indices_offset_d = vec[7]; + kv_tile_indices_offset_p = vec[8]; + kv_tile_indices_offset_d = vec[9]; merge_indptr_offset = vec[7]; o_indptr_offset = vec[8]; kv_chunk_size_ptr_offset = vec[9]; - v_offset = vec[10]; - s_offset = vec[11]; - block_valid_mask_offset = vec[12]; + v_offset_p = vec[10]; + v_offset_d = vec[11]; + s_offset_p = vec[12]; + s_offset_d = vec[13]; + block_valid_mask_offset_p = vec[14]; + block_valid_mask_offset_d = vec[15]; enable_cuda_graph = vec[13]; split_kv = vec[14]; } @@ -858,10 +995,12 @@ template inline cudaError_t PODPlan(void* float_buffer, size_t float_workspace_size_in_bytes, void* int_buffer, void* page_locked_int_buffer, size_t int_workspace_size_in_bytes, PODPlanInfo& plan_info, - IdType* qo_indptr_h, IdType* kv_indptr_h, uint32_t total_num_rows, - uint32_t batch_size, uint32_t num_qo_heads, uint32_t num_kv_heads, - uint32_t head_dim_qk, uint32_t head_dim_vo, uint32_t page_size, - bool enable_cuda_graph, uint32_t sizeof_dtype_o, cudaStream_t stream) { + IdType* qo_indptr_p, IdType* kv_indptr_p, uint32_t total_num_rows_p, + uint32_t batch_size_p, IdType* qo_indptr_d, IdType* kv_indptr_d, + uint32_t total_num_rows_d, uint32_t batch_size_d, uint32_t num_qo_heads, + uint32_t num_kv_heads, uint32_t head_dim_qk, uint32_t head_dim_vo, + uint32_t page_size, bool enable_cuda_graph, uint32_t sizeof_dtype_o, + cudaStream_t stream) { if (num_qo_heads % num_kv_heads != 0) { std::ostringstream err_msg; err_msg << "num_qo_heads " << num_qo_heads << " should be divisible by num_kv_heads " @@ -874,21 +1013,26 @@ inline cudaError_t PODPlan(void* float_buffer, size_t float_workspace_size_in_by int dev_id = 0; FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id)); FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&num_sm, cudaDevAttrMultiProcessorCount, dev_id)); - int num_blocks_per_sm = 2; + int num_blocks_per_sm = 2; // TODO(Wenxuan): increase this to reduce wave quantization? 
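  // num_blocks_per_sm is an occupancy estimate (CTAs resident per SM); the
  // division by num_kv_heads below yields the CTA budget available to each
  // KV head, which is what the KV-split chunk-size search compares against.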
int max_grid_size = num_blocks_per_sm * num_sm; uint32_t max_batch_size_if_split = max_grid_size / num_kv_heads; // step 2: determine kv_chunk_size - auto [split_kv, new_batch_size, padded_batch_size, cta_tile_q, kv_chunk_size, request_indices_vec, - qo_tile_indices_vec, kv_tile_indices_vec, merge_indptr_vec, o_indptr_vec] = - PrefillSplitQOKVIndptr(qo_indptr_h, kv_indptr_h, total_num_rows, batch_size, num_qo_heads, - num_kv_heads, head_dim_vo, page_size, max_batch_size_if_split, - enable_cuda_graph); - - plan_info.cta_tile_q = cta_tile_q; - plan_info.total_num_rows = total_num_rows; + auto [split_kv, new_batch_size, padded_batch_size_p, padded_batch_size_d, cta_tile_q_p, + cta_tile_q_d, kv_chunk_size_p, kv_chunk_size_d, merge_indptr_vec, o_indptr_vec, + request_indices_vec_p, qo_tile_indices_vec_p, kv_tile_indices_vec_p, request_indices_vec_d, + qo_tile_indices_vec_d, kv_tile_indices_vec_d] = + PODSplitQOKVIndptr(qo_indptr_p, kv_indptr_p, qo_indptr_d, kv_indptr_d, total_num_rows_p, + batch_size_p, total_num_rows_d, batch_size_d, num_qo_heads, num_kv_heads, + head_dim_vo, page_size, max_batch_size_if_split, enable_cuda_graph); + + plan_info.cta_tile_q_p = cta_tile_q_p; + plan_info.cta_tile_q_d = cta_tile_q_d; + plan_info.total_num_rows_p = total_num_rows_p; + plan_info.total_num_rows_d = total_num_rows_d; plan_info.enable_cuda_graph = enable_cuda_graph; - plan_info.padded_batch_size = padded_batch_size; + plan_info.padded_batch_size_p = padded_batch_size_p; + plan_info.padded_batch_size_d = padded_batch_size_d; plan_info.split_kv = split_kv; AlignedAllocator int_allocator(int_buffer, int_workspace_size_in_bytes); From 1b2d4c092bf124bfb899324c1dcba0b76c101ad7 Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Thu, 3 Jul 2025 05:01:14 +0000 Subject: [PATCH 10/33] rename to num_to_merge_qo_len --- csrc/batch_attention.cu | 2 +- csrc/batch_attention_customize_config.jinja | 2 +- include/flashinfer/attention/persistent.cuh | 6 +++--- include/flashinfer/attention/persistent_template.cuh | 4 ++-- include/flashinfer/attention/scheduler.cuh | 6 +++--- 5 files changed, 10 insertions(+), 10 deletions(-) diff --git a/csrc/batch_attention.cu b/csrc/batch_attention.cu index 86a7fd287..3d6aac8c3 100644 --- a/csrc/batch_attention.cu +++ b/csrc/batch_attention.cu @@ -154,7 +154,7 @@ void BatchPagedAttentionRun(at::Tensor float_workspace_buffer, at::Tensor int_wo GetPtrFromBaseOffset(int_buffer_ptr, plan_info.merge_indptr_offset); params[i].merge_o_indices = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.merge_o_indices_offset); - params[i].num_packed_qo_len = + params[i].num_to_merge_qo_len = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.num_qo_len_offset); params[i].num_kv_heads = num_kv_heads; diff --git a/csrc/batch_attention_customize_config.jinja b/csrc/batch_attention_customize_config.jinja index 6bc85b067..116bfc3d8 100644 --- a/csrc/batch_attention_customize_config.jinja +++ b/csrc/batch_attention_customize_config.jinja @@ -80,7 +80,7 @@ struct PersistentParams { // for state reduction IdType* merge_indptr; IdType* merge_o_indices; - IdType* num_packed_qo_len; + IdType* num_to_merge_qo_len; uint32_t num_kv_heads; uint_fastdiv gqa_group_size; diff --git a/include/flashinfer/attention/persistent.cuh b/include/flashinfer/attention/persistent.cuh index 8c00aa579..3256bc14a 100644 --- a/include/flashinfer/attention/persistent.cuh +++ b/include/flashinfer/attention/persistent.cuh @@ -437,7 +437,7 @@ struct BlockBatchReductionPersistent { static __device__ __forceinline__ void Run( typename KTraits::DTypeIn* 
__restrict__ V, typename KTraits::DTypeO* __restrict__ v_merged, float* __restrict__ S, float* __restrict__ s_merged, - const typename KTraits::IdType num_packed_qo_len, const uint_fastdiv gqa_group_size, + const typename KTraits::IdType num_to_merge_qo_len, const uint_fastdiv gqa_group_size, const uint32_t num_kv_heads, const typename KTraits::IdType* indptr, const typename KTraits::IdType* o_indices, uint8_t* smem PROFILER_CLOSURE_FUNC_PARAMS) { using DTypeIn = typename KTraits::DTypeIn; @@ -465,10 +465,10 @@ struct BlockBatchReductionPersistent { float* s_smem = (float*)(smem + num_warps * num_smem_stages * bdy * head_dim * sizeof(DTypeIn) + warp_idx * 32 * sizeof(float)); - // V: [num_packed_qo_len x num_kv_tiles, num_kv_heads, head_dim] + // V: [num_to_merge_qo_len x num_kv_tiles, num_kv_heads, head_dim] // v_merged: [qo_len, num_kv_heads, gqa_group_size, head_dim] #pragma unroll 1 - for (uint32_t i = worker_id; i < num_packed_qo_len * num_kv_heads; i += num_workers) { + for (uint32_t i = worker_id; i < num_to_merge_qo_len * num_kv_heads; i += num_workers) { PROFILER_EVENT_START(profiler_closure, PersistentProfileEventType::kReduction); // remap workload diff --git a/include/flashinfer/attention/persistent_template.cuh b/include/flashinfer/attention/persistent_template.cuh index 3bd2331b3..3c367cc8f 100644 --- a/include/flashinfer/attention/persistent_template.cuh +++ b/include/flashinfer/attention/persistent_template.cuh @@ -81,7 +81,7 @@ __global__ __launch_bounds__( grid.sync(); BlockReductionRunner::Run(params_1.partial_o, params_1.final_o, params_1.partial_lse, - params_1.final_lse, *(params_1.num_packed_qo_len), + params_1.final_lse, *(params_1.num_to_merge_qo_len), params_1.gqa_group_size, params_1.num_kv_heads, params_1.merge_indptr, params_1.merge_o_indices, smem); #else @@ -90,7 +90,7 @@ __global__ __launch_bounds__( grid.sync(); BlockReductionRunner::Run(params_1.partial_o, params_1.final_o, params_1.partial_lse, - params_1.final_lse, *(params_1.num_packed_qo_len), + params_1.final_lse, *(params_1.num_to_merge_qo_len), params_1.gqa_group_size, params_1.num_kv_heads, params_1.merge_indptr, params_1.merge_o_indices, smem, profiler_closure); #endif diff --git a/include/flashinfer/attention/scheduler.cuh b/include/flashinfer/attention/scheduler.cuh index 200cf0080..25646d23d 100644 --- a/include/flashinfer/attention/scheduler.cuh +++ b/include/flashinfer/attention/scheduler.cuh @@ -1493,7 +1493,7 @@ inline cudaError_t TwoStageHolisticPlan(void* float_buffer, size_t float_workspa // used for remapping the output offsets // layout [packed_qo_len x num_kv_tiels, num_kv_heads, head_dim] int partial_o_nnz = 0; - std::vector merge_indptr, merge_o_indices, num_expand_qo_len_vec; + std::vector merge_indptr, merge_o_indices, num_to_merge_qo_len_vec; merge_indptr.push_back(partial_o_nnz); for (uint32_t task = 0; task < NUM_TASKS; ++task) { int cluster_tile_q = CTA_TILE_Q_SIZES[task] * cluster_size; @@ -1638,7 +1638,7 @@ inline cudaError_t TwoStageHolisticPlan(void* float_buffer, size_t float_workspa } // update num_qo_len_vec - num_expand_qo_len_vec.push_back(merge_indptr.size() - 1); + num_to_merge_qo_len_vec.push_back(merge_indptr.size() - 1); // allocate buffer for state merge function plan_info.merge_indptr_offset = int_allocator.aligned_alloc_offset(sizeof(IdType) * max_packed_qo_lens, 16, "merge_indptr"); @@ -1650,7 +1650,7 @@ inline cudaError_t TwoStageHolisticPlan(void* float_buffer, size_t float_workspa CopyToPageLockedBuffer(page_locked_int_buffer, 
plan_info.merge_indptr_offset, merge_indptr); CopyToPageLockedBuffer(page_locked_int_buffer, plan_info.merge_o_indices_offset, merge_o_indices); CopyToPageLockedBuffer(page_locked_int_buffer, plan_info.num_qo_len_offset, - num_expand_qo_len_vec); + num_to_merge_qo_len_vec); size_t num_bytes_to_copy = int_allocator.num_allocated_bytes(); FLASHINFER_CUDA_CALL(cudaMemcpyAsync(int_buffer, page_locked_int_buffer, num_bytes_to_copy, From 78e1266aaca88fda1319139c9db8c860dbd518bc Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Thu, 3 Jul 2025 05:11:17 +0000 Subject: [PATCH 11/33] add params --- include/flashinfer/attention/scheduler.cuh | 79 ++++++++++++---------- 1 file changed, 44 insertions(+), 35 deletions(-) diff --git a/include/flashinfer/attention/scheduler.cuh b/include/flashinfer/attention/scheduler.cuh index 25646d23d..3e4ad8513 100644 --- a/include/flashinfer/attention/scheduler.cuh +++ b/include/flashinfer/attention/scheduler.cuh @@ -561,9 +561,14 @@ inline auto get_qkv_tile_indices(std::vector& packed_qo_len_arr, std::vector& kv_len_arr, uint32_t batch_size, uint32_t cta_tile_q, uint32_t kv_chunk_size, uint32_t gqa_group_size, + std::vector& request_indices = nullptr, std::vector& merge_indptr = nullptr, std::vector& o_indptr = nullptr) { - std::vector request_indices, qo_tile_indices, kv_tile_indices; + std::vector qo_tile_indices, kv_tile_indices; + if (request_indices == nullptr) { + request_indices = std::vector(); + request_indices.push_back(0); + } if (merge_indptr == nullptr) { merge_indptr = std::vector(); merge_indptr.push_back(0); @@ -868,13 +873,13 @@ inline auto PODSplitQOKVIndptr(IdType* qo_indptr_p, IdType* kv_indptr_p, uint32_ cta_tile_q_d, min_kv_chunk_size); // step 3: split qo_indptr and kv_indptr - auto [request_indices_p, qo_tile_indices_p, kv_tile_indices_p, merge_indptr, o_indptr, + auto [request_indices, qo_tile_indices_p, kv_tile_indices_p, merge_indptr, o_indptr, new_batch_size_p] = get_qkv_tile_indices(packed_qo_len_arr_p, kv_len_arr_p, batch_size_p, cta_tile_q_p, kv_chunk_size_p, gqa_group_size); - auto [request_indices_d, qo_tile_indices_d, kv_tile_indices_d, merge_indptr_d, o_indptr_d, - new_batch_size_d] = - get_qkv_tile_indices(packed_qo_len_arr_d, kv_len_arr_d, batch_size_d, cta_tile_q_d, - kv_chunk_size_d, gqa_group_size, merge_indptr, o_indptr); + auto [request_indices, qo_tile_indices_d, kv_tile_indices_d, merge_indptr_d, o_indptr_d, + new_batch_size_d] = get_qkv_tile_indices(packed_qo_len_arr_d, kv_len_arr_d, batch_size_d, + cta_tile_q_d, kv_chunk_size_d, gqa_group_size, + request_indices, merge_indptr, o_indptr); bool split_kv = split_kv_p || split_kv_d; uint32_t new_batch_size = new_batch_size_p + new_batch_size_d; @@ -889,11 +894,11 @@ inline auto PODSplitQOKVIndptr(IdType* qo_indptr_p, IdType* kv_indptr_p, uint32_ kv_chunk_size_p *= page_size; kv_chunk_size_d *= page_size; - return std::make_tuple( - split_kv, new_batch_size, padded_batch_size_p, padded_batch_size_d, cta_tile_q_p, - cta_tile_q_d, kv_chunk_size_p, kv_chunk_size_d, std::move(merge_indptr), std::move(o_indptr), - std::move(request_indices_p), std::move(qo_tile_indices_p), std::move(kv_tile_indices_p), - std::move(request_indices_d), std::move(qo_tile_indices_d), std::move(kv_tile_indices_d)); + return std::make_tuple(split_kv, new_batch_size, padded_batch_size_p, padded_batch_size_d, + cta_tile_q_p, cta_tile_q_d, kv_chunk_size_p, kv_chunk_size_d, + std::move(merge_indptr), std::move(o_indptr), std::move(request_indices), + std::move(qo_tile_indices_p), std::move(kv_tile_indices_p), + 
std::move(qo_tile_indices_d), std::move(kv_tile_indices_d)); } struct PODPlanInfo { @@ -909,7 +914,8 @@ struct PODPlanInfo { int64_t kv_tile_indices_offset_d; int64_t merge_indptr_offset; int64_t o_indptr_offset; - int64_t kv_chunk_size_ptr_offset; + int64_t kv_chunk_size_ptr_offset_p; + int64_t kv_chunk_size_ptr_offset_d; int64_t v_offset_p; int64_t v_offset_d int64_t s_offset_p; int64_t s_offset_d; @@ -931,13 +937,13 @@ struct PODPlanInfo { kv_tile_indices_offset_d(0), merge_indptr_offset(0), o_indptr_offset(0), - kv_chunk_size_ptr_offset(0), + kv_chunk_size_ptr_offset_p(0), + kv_chunk_size_ptr_offset_d(0), v_offset_p(0), v_offset_d(0), s_offset_p(0), s_offset_d(0), - block_valid_mask_offset_p(0), - block_valid_mask_offset_d(0), + block_valid_mask_offset(0), enable_cuda_graph(false), split_kv(false) {} @@ -952,7 +958,8 @@ struct PODPlanInfo { kv_tile_indices_offset, merge_indptr_offset, o_indptr_offset, - kv_chunk_size_ptr_offset, + kv_chunk_size_ptr_offset_p, + kv_chunk_size_ptr_offset_d, v_offset, s_offset, block_valid_mask_offset, @@ -979,13 +986,13 @@ struct PODPlanInfo { kv_tile_indices_offset_d = vec[9]; merge_indptr_offset = vec[7]; o_indptr_offset = vec[8]; - kv_chunk_size_ptr_offset = vec[9]; - v_offset_p = vec[10]; - v_offset_d = vec[11]; - s_offset_p = vec[12]; + kv_chunk_size_ptr_offset_p = vec[9]; + kv_chunk_size_ptr_offset_d = vec[10]; + v_offset_p = vec[11]; + v_offset_d = vec[12]; + s_offset_p = vec[13]; s_offset_d = vec[13]; - block_valid_mask_offset_p = vec[14]; - block_valid_mask_offset_d = vec[15]; + block_valid_mask_offset = vec[14]; enable_cuda_graph = vec[13]; split_kv = vec[14]; } @@ -1020,8 +1027,8 @@ inline cudaError_t PODPlan(void* float_buffer, size_t float_workspace_size_in_by // step 2: determine kv_chunk_size auto [split_kv, new_batch_size, padded_batch_size_p, padded_batch_size_d, cta_tile_q_p, cta_tile_q_d, kv_chunk_size_p, kv_chunk_size_d, merge_indptr_vec, o_indptr_vec, - request_indices_vec_p, qo_tile_indices_vec_p, kv_tile_indices_vec_p, request_indices_vec_d, - qo_tile_indices_vec_d, kv_tile_indices_vec_d] = + request_indices_vec, qo_tile_indices_vec_p, kv_tile_indices_vec_p, qo_tile_indices_vec_d, + kv_tile_indices_vec_d] = PODSplitQOKVIndptr(qo_indptr_p, kv_indptr_p, qo_indptr_d, kv_indptr_d, total_num_rows_p, batch_size_p, total_num_rows_d, batch_size_d, num_qo_heads, num_kv_heads, head_dim_vo, page_size, max_batch_size_if_split, enable_cuda_graph); @@ -1037,15 +1044,17 @@ inline cudaError_t PODPlan(void* float_buffer, size_t float_workspace_size_in_by AlignedAllocator int_allocator(int_buffer, int_workspace_size_in_bytes); plan_info.request_indices_offset = int_allocator.aligned_alloc_offset( - sizeof(IdType) * padded_batch_size, 16, "batch_prefill_request_indices"); + sizeof(IdType) * padded_batch_size, 16, "pod_prefill_request_indices"); plan_info.qo_tile_indices_offset = int_allocator.aligned_alloc_offset( - sizeof(IdType) * padded_batch_size, 16, "batch_prefill_qo_tile_indices"); + sizeof(IdType) * padded_batch_size, 16, "pod_prefill_qo_tile_indices"); plan_info.kv_tile_indices_offset = int_allocator.aligned_alloc_offset( - sizeof(IdType) * padded_batch_size, 16, "batch_prefill_kv_tile_indices"); - plan_info.o_indptr_offset = int_allocator.aligned_alloc_offset(sizeof(IdType) * (batch_size + 1), - 16, "batch_prefill_o_indptr"); - plan_info.kv_chunk_size_ptr_offset = - int_allocator.aligned_alloc_offset(sizeof(IdType), 1, "batch_prefill_kv_chunk_size_ptr"); + sizeof(IdType) * padded_batch_size, 16, "pod_prefill_kv_tile_indices"); + 
plan_info.o_indptr_offset = + int_allocator.aligned_alloc_offset(sizeof(IdType) * (batch_size + 1), 16, "pod_o_indptr"); + plan_info.kv_chunk_size_ptr_offset_p = + int_allocator.aligned_alloc_offset(sizeof(IdType), 1, "pod_prefill_kv_chunk_size_ptr"); + plan_info.kv_chunk_size_ptr_offset_d = + int_allocator.aligned_alloc_offset(sizeof(IdType), 1, "pod_prefill_kv_chunk_size_ptr"); if (plan_info.enable_cuda_graph) { plan_info.total_num_rows_offset = @@ -1075,13 +1084,13 @@ inline cudaError_t PODPlan(void* float_buffer, size_t float_workspace_size_in_by AlignedAllocator float_allocator(float_buffer, float_workspace_size_in_bytes); plan_info.v_offset = float_allocator.aligned_alloc_offset( num_qo_heads * padded_batch_size * cta_tile_q * head_dim_vo * sizeof(float), 16, - "batch_prefill_tmp_v"); + "pod_tmp_v"); plan_info.s_offset = float_allocator.aligned_alloc_offset( - num_qo_heads * padded_batch_size * cta_tile_q * sizeof(float), 16, "batch_prefill_tmp_s"); + num_qo_heads * padded_batch_size * cta_tile_q * sizeof(float), 16, "pod_tmp_s"); plan_info.merge_indptr_offset = int_allocator.aligned_alloc_offset( - sizeof(IdType) * (plan_info.total_num_rows + 1), 16, "batch_prefill_merge_indptr"); + sizeof(IdType) * (plan_info.total_num_rows + 1), 16, "pod_merge_indptr"); plan_info.block_valid_mask_offset = int_allocator.aligned_alloc_offset( - sizeof(bool) * padded_batch_size, 16, "batch_prefill_block_valid_mask"); + sizeof(bool) * padded_batch_size, 16, "pod_block_valid_mask"); IdType* merge_indptr_h = GetPtrFromBaseOffset(page_locked_int_buffer, plan_info.merge_indptr_offset); From 4979a2a4f666bc7d7bbd326e512c2343669d575d Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Fri, 4 Jul 2025 05:07:34 +0000 Subject: [PATCH 12/33] plan to use one reduction kernel for prefill and decode --- csrc/flashinfer_ops.cu | 4 ++-- csrc/pod.cu | 2 +- csrc/pod_jit_pybind.cu | 4 ++-- flashinfer/pod.py | 2 +- include/flashinfer/attention/pod.cuh | 16 +++------------- include/flashinfer/attention/scheduler.cuh | 3 ++- 6 files changed, 11 insertions(+), 20 deletions(-) diff --git a/csrc/flashinfer_ops.cu b/csrc/flashinfer_ops.cu index 526ad969a..37b776e3c 100644 --- a/csrc/flashinfer_ops.cu +++ b/csrc/flashinfer_ops.cu @@ -123,7 +123,7 @@ void BatchPrefillWithPagedKVCacheRun( int64_t window_left BATCH_PREFILL_ADDITIONAL_FUNC_PARAMS); //========== pod-attention ========= -void pod_with_kv_cache_tensor( +void PODWithKVCacheTensorRun( // Prefill params at::Tensor q_p, at::Tensor k_p, at::Tensor v_p, at::Tensor tmp_p, at::Tensor o_p, std::optional maybe_lse_p, int64_t mask_mode_code_p, int64_t layout_p, @@ -280,7 +280,7 @@ TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) { // pod-attention // Temporarily disabled because we don't generate the implementation yet. 
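  // The runnable binding comes from the JIT path instead: pod_jit_pybind.cu
  // (below) registers "PODWithKVCacheTensor", which flashinfer/pod.py loads
  // through gen_pod_module(...).build_and_load().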
- // m.def("pod_with_kv_cache_tensor", pod_with_kv_cache_tensor); + // m.def("PODWithKVCacheTensor", PODWithKVCacheTensorRun); // quantization // GPU packbits operator diff --git a/csrc/pod.cu b/csrc/pod.cu index f990c1fbb..577e83463 100644 --- a/csrc/pod.cu +++ b/csrc/pod.cu @@ -66,7 +66,7 @@ at::Tensor PODWithKVCachePlan(at::Tensor float_workspace_buffer, at::Tensor int_ return vec_to_tensor(plan_info.ToVector()); } -void pod_with_kv_cache_tensor( +void PODWithKVCacheTensorRun( // Prefill params at::Tensor q_p, at::Tensor k_p, at::Tensor v_p, at::Tensor tmp_p, at::Tensor o_p, std::optional maybe_lse_p, int64_t mask_mode_code_p, int64_t layout_p, diff --git a/csrc/pod_jit_pybind.cu b/csrc/pod_jit_pybind.cu index 2e8d47bf2..66561a5af 100644 --- a/csrc/pod_jit_pybind.cu +++ b/csrc/pod_jit_pybind.cu @@ -16,7 +16,7 @@ #include "pod_config.inc" #include "pytorch_extension_utils.h" -void pod_with_kv_cache_tensor( +void PODWithKVCacheTensorRun( // Prefill params at::Tensor q_p, at::Tensor k_p, at::Tensor v_p, at::Tensor tmp_p, at::Tensor o_p, std::optional maybe_lse_p, int64_t mask_mode_code_p, int64_t layout_p, @@ -36,5 +36,5 @@ void pod_with_kv_cache_tensor( TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) { // Batch-request prefill attention with KV-Cache operator - m.def("pod_with_kv_cache_tensor", pod_with_kv_cache_tensor); + m.def("PODWithKVCacheTensor", PODWithKVCacheTensorRun); } diff --git a/flashinfer/pod.py b/flashinfer/pod.py index ecb74c205..d2625b608 100644 --- a/flashinfer/pod.py +++ b/flashinfer/pod.py @@ -44,7 +44,7 @@ @functools.cache def get_pod_module(*args): module = gen_pod_module(*args).build_and_load() - return SimpleNamespace(run_tensor=module.pod_with_kv_cache_tensor.default) + return SimpleNamespace(run_tensor=module.PODWithKVCacheTensor.default) class PODWithPagedKVCacheWrapper: diff --git a/include/flashinfer/attention/pod.cuh b/include/flashinfer/attention/pod.cuh index d74c535f8..05d854781 100644 --- a/include/flashinfer/attention/pod.cuh +++ b/include/flashinfer/attention/pod.cuh @@ -426,26 +426,16 @@ cudaError_t PODWithKVCacheTensorDispatched(PrefillParams prefill_params, FLASHINFER_CUDA_CALL( cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); - // Post-kernel stuff for split-kv prefill - if (!(num_chunks <= 1 || tmp_p == nullptr)) { - if constexpr (PrefillAttentionVariant::use_softmax) { - FLASHINFER_CUDA_CALL(MergeStates(tmp_p, tmp_lse, o_p, lse_p, num_chunks, qo_len, - num_qo_heads, HEAD_DIM_VO, stream)); - } else { - FLASHINFER_CUDA_CALL(AttentionSum(tmp_p, o_p, num_chunks, qo_len, num_qo_heads, - HEAD_DIM_VO, stream)); - } - } - // Post-kernel stuff for split-kv decode + // Post-kernel stuff for split-kv if (tmp_v != nullptr) { if constexpr (DecodeAttentionVariant::use_softmax) { FLASHINFER_CUDA_CALL(VariableLengthMergeStates( - tmp_v, tmp_s, decode_params.merge_indptr, o_d, lse_d, + tmp_v, tmp_s, decode_params.merge_indptr, o, lse, decode_params.max_total_num_rows, decode_params.total_num_rows, num_qo_heads, HEAD_DIM_VO, stream)); } else { FLASHINFER_CUDA_CALL(VariableLengthAttentionSum( - tmp_v, decode_params.merge_indptr, o_d, decode_params.max_total_num_rows, + tmp_v, decode_params.merge_indptr, o, decode_params.max_total_num_rows, decode_params.total_num_rows, num_qo_heads, HEAD_DIM_VO, stream)); } } diff --git a/include/flashinfer/attention/scheduler.cuh b/include/flashinfer/attention/scheduler.cuh index 3e4ad8513..ab025054a 100644 --- a/include/flashinfer/attention/scheduler.cuh +++ b/include/flashinfer/attention/scheduler.cuh @@ 
-873,6 +873,7 @@ inline auto PODSplitQOKVIndptr(IdType* qo_indptr_p, IdType* kv_indptr_p, uint32_ cta_tile_q_d, min_kv_chunk_size); // step 3: split qo_indptr and kv_indptr + // use one merge_indptr and o_indptr to simply merging auto [request_indices, qo_tile_indices_p, kv_tile_indices_p, merge_indptr, o_indptr, new_batch_size_p] = get_qkv_tile_indices(packed_qo_len_arr_p, kv_len_arr_p, batch_size_p, cta_tile_q_p, kv_chunk_size_p, gqa_group_size); @@ -1054,7 +1055,7 @@ inline cudaError_t PODPlan(void* float_buffer, size_t float_workspace_size_in_by plan_info.kv_chunk_size_ptr_offset_p = int_allocator.aligned_alloc_offset(sizeof(IdType), 1, "pod_prefill_kv_chunk_size_ptr"); plan_info.kv_chunk_size_ptr_offset_d = - int_allocator.aligned_alloc_offset(sizeof(IdType), 1, "pod_prefill_kv_chunk_size_ptr"); + int_allocator.aligned_alloc_offset(sizeof(IdType), 1, "pod_decode_kv_chunk_size_ptr"); if (plan_info.enable_cuda_graph) { plan_info.total_num_rows_offset = From 2102e220066fb49780a9895ddd040c7d1f667bfc Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Fri, 4 Jul 2025 05:15:01 +0000 Subject: [PATCH 13/33] fix --- include/flashinfer/attention/scheduler.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/flashinfer/attention/scheduler.cuh b/include/flashinfer/attention/scheduler.cuh index ab025054a..b3a3acb17 100644 --- a/include/flashinfer/attention/scheduler.cuh +++ b/include/flashinfer/attention/scheduler.cuh @@ -877,7 +877,7 @@ inline auto PODSplitQOKVIndptr(IdType* qo_indptr_p, IdType* kv_indptr_p, uint32_ auto [request_indices, qo_tile_indices_p, kv_tile_indices_p, merge_indptr, o_indptr, new_batch_size_p] = get_qkv_tile_indices(packed_qo_len_arr_p, kv_len_arr_p, batch_size_p, cta_tile_q_p, kv_chunk_size_p, gqa_group_size); - auto [request_indices, qo_tile_indices_d, kv_tile_indices_d, merge_indptr_d, o_indptr_d, + auto [request_indices, qo_tile_indices_d, kv_tile_indices_d, merge_indptr_d, o_indptr, new_batch_size_d] = get_qkv_tile_indices(packed_qo_len_arr_d, kv_len_arr_d, batch_size_d, cta_tile_q_d, kv_chunk_size_d, gqa_group_size, request_indices, merge_indptr, o_indptr); From fab82aeb7faecd1e99b18c6e0dacec4ba93adc12 Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Sun, 6 Jul 2025 04:12:01 +0000 Subject: [PATCH 14/33] use unifed qkv indptr --- include/flashinfer/attention/cascade.cuh | 57 +++++++++++----------- include/flashinfer/attention/pod.cuh | 2 +- include/flashinfer/attention/scheduler.cuh | 39 +++++++++------ 3 files changed, 54 insertions(+), 44 deletions(-) diff --git a/include/flashinfer/attention/cascade.cuh b/include/flashinfer/attention/cascade.cuh index 8fb5e6b91..1ef2f3a34 100644 --- a/include/flashinfer/attention/cascade.cuh +++ b/include/flashinfer/attention/cascade.cuh @@ -354,11 +354,11 @@ __global__ void MergeStatesLargeNumIndexSetsKernel(DTypeIn* __restrict__ V, floa * \tparam DTypeO The data type of v_merged. * \param V The partial v of index sets. (nnz, h, d) * \param S The logsumexp value of index sets. (nnz, h) - * \param indptr The start offsets of each position in the variable length array. + * \param merge_indptr The start offsets of each position in the variable length array. * \param v_merged The merged v of index sets union. (n, h, d) * \param s_merged The merged logsumexp value of index sets union. (n, h) * \param max_seq_len The maximum sequence length supported by the kernel. - * \param seq_len_ptr The current sequence length (number of positions populated in indptr). 
+ * \param seq_len_ptr The current sequence length (number of positions populated in merge_indptr). * \param num_heads The number of heads of v. * \param head_dim The dimension of each head. * \note s are logsumexp values with base 2. @@ -366,9 +366,9 @@ __global__ void MergeStatesLargeNumIndexSetsKernel(DTypeIn* __restrict__ V, floa template __global__ void PersistentVariableLengthMergeStatesKernel( - DTypeIn* __restrict__ V, float* __restrict__ S, IdType* indptr, DTypeO* __restrict__ v_merged, - float* __restrict__ s_merged, uint32_t max_seq_len, uint32_t* __restrict__ seq_len_ptr, - uint32_t num_heads) { + DTypeIn* __restrict__ V, float* __restrict__ S, IdType* merge_indptr, + DTypeO* __restrict__ v_merged, float* __restrict__ s_merged, uint32_t max_seq_len, + uint32_t* __restrict__ seq_len_ptr, uint32_t num_heads) { uint32_t tx = threadIdx.x, ty = threadIdx.y; uint32_t cta_id = blockIdx.x; uint32_t num_ctas = gridDim.x; @@ -389,7 +389,7 @@ __global__ void PersistentVariableLengthMergeStatesKernel( uint32_t pos = i / num_heads; uint32_t head_idx = i % num_heads; state_t st; - const uint32_t num_index_sets = indptr[pos + 1] - indptr[pos]; + const uint32_t num_index_sets = merge_indptr[pos + 1] - merge_indptr[pos]; if (num_index_sets == 0) { vec_t v; @@ -403,10 +403,10 @@ __global__ void PersistentVariableLengthMergeStatesKernel( if (num_index_sets == 1) { vec_t v; - v.cast_load(V + (indptr[pos] * num_heads + head_idx) * head_dim + tx * vec_size); + v.cast_load(V + (merge_indptr[pos] * num_heads + head_idx) * head_dim + tx * vec_size); v.store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * vec_size); if (s_merged != nullptr) { - s_merged[pos * num_heads + head_idx] = S[indptr[pos] * num_heads + head_idx]; + s_merged[pos * num_heads + head_idx] = S[merge_indptr[pos] * num_heads + head_idx]; } continue; } @@ -415,7 +415,8 @@ __global__ void PersistentVariableLengthMergeStatesKernel( for (uint32_t iter = 0; iter < num_smem_stages; ++iter) { cp_async::pred_load( v_smem + (iter * bdy + ty) * head_dim + tx * vec_size, - V + ((indptr[pos] + (iter * bdy + ty)) * num_heads + head_idx) * head_dim + tx * vec_size, + V + ((merge_indptr[pos] + (iter * bdy + ty)) * num_heads + head_idx) * head_dim + + tx * vec_size, (iter * bdy + ty) < num_index_sets); cp_async::commit_group(); } @@ -424,7 +425,7 @@ __global__ void PersistentVariableLengthMergeStatesKernel( if (iter % bdx == 0) { s_smem[ty * bdx + tx] = iter * bdy + (ty * bdx + tx) < num_index_sets - ? S[(indptr[pos] + (iter * bdy + ty * bdx + tx)) * num_heads + head_idx] + ? 
S[(merge_indptr[pos] + (iter * bdy + ty * bdx + tx)) * num_heads + head_idx] : 0.f; __syncthreads(); } @@ -440,7 +441,7 @@ __global__ void PersistentVariableLengthMergeStatesKernel( cp_async::pred_load( v_smem + ((iter % num_smem_stages) * bdy + ty) * head_dim + tx * vec_size, V + - ((indptr[pos] + ((iter + num_smem_stages) * bdy + ty)) * num_heads + head_idx) * + ((merge_indptr[pos] + ((iter + num_smem_stages) * bdy + ty)) * num_heads + head_idx) * head_dim + tx * vec_size, (iter + num_smem_stages) * bdy + ty < num_index_sets); @@ -465,11 +466,9 @@ __global__ void PersistentVariableLengthMergeStatesKernel( template -__global__ void PersistentVariableLengthAttentionSumKernel(DTypeIn* __restrict__ V, IdType* indptr, - DTypeO* __restrict__ v_sum, - uint32_t max_seq_len, - uint32_t* __restrict__ seq_len_ptr, - uint32_t num_heads) { +__global__ void PersistentVariableLengthAttentionSumKernel( + DTypeIn* __restrict__ V, IdType* merge_indptr, DTypeO* __restrict__ v_sum, uint32_t max_seq_len, + uint32_t* __restrict__ seq_len_ptr, uint32_t num_heads) { uint32_t tx = threadIdx.x, ty = threadIdx.y; uint32_t cta_id = blockIdx.x; uint32_t num_ctas = gridDim.x; @@ -489,7 +488,7 @@ __global__ void PersistentVariableLengthAttentionSumKernel(DTypeIn* __restrict__ for (uint32_t i = cta_id; i < seq_len * num_heads; i += num_ctas) { uint32_t pos = i / num_heads; uint32_t head_idx = i % num_heads; - const uint32_t num_index_sets = indptr[pos + 1] - indptr[pos]; + const uint32_t num_index_sets = merge_indptr[pos + 1] - merge_indptr[pos]; if (num_index_sets == 0) { vec_t v; @@ -500,7 +499,7 @@ __global__ void PersistentVariableLengthAttentionSumKernel(DTypeIn* __restrict__ if (num_index_sets == 1) { vec_t v; - v.cast_load(V + (indptr[pos] * num_heads + head_idx) * head_dim + tx * vec_size); + v.cast_load(V + (merge_indptr[pos] * num_heads + head_idx) * head_dim + tx * vec_size); v.store(v_sum + (pos * num_heads + head_idx) * head_dim + tx * vec_size); continue; } @@ -509,7 +508,8 @@ __global__ void PersistentVariableLengthAttentionSumKernel(DTypeIn* __restrict__ for (uint32_t iter = 0; iter < num_smem_stages; ++iter) { cp_async::pred_load( v_smem + (iter * bdy + ty) * head_dim + tx * vec_size, - V + ((indptr[pos] + (iter * bdy + ty)) * num_heads + head_idx) * head_dim + tx * vec_size, + V + ((merge_indptr[pos] + (iter * bdy + ty)) * num_heads + head_idx) * head_dim + + tx * vec_size, (iter * bdy + ty) < num_index_sets); cp_async::commit_group(); } @@ -529,7 +529,7 @@ __global__ void PersistentVariableLengthAttentionSumKernel(DTypeIn* __restrict__ cp_async::pred_load( v_smem + ((iter % num_smem_stages) * bdy + ty) * head_dim + tx * vec_size, V + - ((indptr[pos] + ((iter + num_smem_stages) * bdy + ty)) * num_heads + head_idx) * + ((merge_indptr[pos] + ((iter + num_smem_stages) * bdy + ty)) * num_heads + head_idx) * head_dim + tx * vec_size, (iter + num_smem_stages) * bdy + ty < num_index_sets); @@ -679,7 +679,7 @@ cudaError_t AttentionSum(DTypeIn* v, DTypeO* v_sum, uint32_t num_index_sets, uin } template -cudaError_t VariableLengthMergeStates(DTypeIn* v, float* s, IdType* indptr, DTypeO* v_merged, +cudaError_t VariableLengthMergeStates(DTypeIn* v, float* s, IdType* merge_indptr, DTypeO* v_merged, float* s_merged, uint32_t max_seq_len, uint32_t* seq_len, uint32_t num_heads, uint32_t head_dim, bool enable_pdl, cudaStream_t stream = nullptr) { @@ -705,7 +705,8 @@ cudaError_t VariableLengthMergeStates(DTypeIn* v, float* s, IdType* indptr, DTyp dim3 nblks(num_sms * num_blocks_per_sm); dim3 nthrs(bdx, bdy); - void* 
args[] = {&v, &s, &indptr, &v_merged, &s_merged, &max_seq_len, &seq_len, &num_heads}; + void* args[] = {&v, &s, &merge_indptr, &v_merged, + &s_merged, &max_seq_len, &seq_len, &num_heads}; FLASHINFER_CUDA_CALL( cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); @@ -721,8 +722,8 @@ cudaError_t VariableLengthMergeStates(DTypeIn* v, float* s, IdType* indptr, DTyp config.blockDim = nthrs; config.dynamicSmemBytes = smem_size; config.stream = stream; - FLASHINFER_CUDA_CALL(cudaLaunchKernelEx(&config, kernel, v, s, indptr, v_merged, s_merged, - max_seq_len, seq_len, num_heads)); + FLASHINFER_CUDA_CALL(cudaLaunchKernelEx(&config, kernel, v, s, merge_indptr, v_merged, + s_merged, max_seq_len, seq_len, num_heads)); } else { FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); } @@ -731,7 +732,7 @@ cudaError_t VariableLengthMergeStates(DTypeIn* v, float* s, IdType* indptr, DTyp } template -cudaError_t VariableLengthAttentionSum(DTypeIn* v, IdType* indptr, DTypeO* v_sum, +cudaError_t VariableLengthAttentionSum(DTypeIn* v, IdType* merge_indptr, DTypeO* v_sum, uint32_t max_seq_len, uint32_t* seq_len, uint32_t num_heads, uint32_t head_dim, bool enable_pdl, cudaStream_t stream = nullptr) { @@ -756,7 +757,7 @@ cudaError_t VariableLengthAttentionSum(DTypeIn* v, IdType* indptr, DTypeO* v_sum dim3 nblks(num_sms * num_blocks_per_sm); dim3 nthrs(bdx, bdy); - void* args[] = {&v, &indptr, &v_sum, &max_seq_len, &seq_len, &num_heads}; + void* args[] = {&v, &merge_indptr, &v_sum, &max_seq_len, &seq_len, &num_heads}; FLASHINFER_CUDA_CALL( cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); @@ -772,8 +773,8 @@ cudaError_t VariableLengthAttentionSum(DTypeIn* v, IdType* indptr, DTypeO* v_sum config.blockDim = nthrs; config.dynamicSmemBytes = smem_size; config.stream = stream; - FLASHINFER_CUDA_CALL( - cudaLaunchKernelEx(&config, kernel, v, indptr, v_sum, max_seq_len, seq_len, num_heads)); + FLASHINFER_CUDA_CALL(cudaLaunchKernelEx(&config, kernel, v, merge_indptr, v_sum, max_seq_len, + seq_len, num_heads)); } else { FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); } diff --git a/include/flashinfer/attention/pod.cuh b/include/flashinfer/attention/pod.cuh index 05d854781..d1a8f570c 100644 --- a/include/flashinfer/attention/pod.cuh +++ b/include/flashinfer/attention/pod.cuh @@ -114,7 +114,7 @@ __global__ __launch_bounds__(std::max( ((int*)smem)[0] = linear_bid; ((int*)smem)[1] = op; } - // Sync to wait for dynamic scheduler to finish + // Sync to wait for dynamic scheduler to write to smem __syncthreads(); // Fetch from shared memory the assigned blockId and operation. 
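  // Every thread reads the same two words published above, so the whole CTA
  // agrees on its linear block id and on which operation (prefill or decode
  // work in the fused POD kernel) it should execute next.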
linear_bid = ((int*)smem)[0]; diff --git a/include/flashinfer/attention/scheduler.cuh b/include/flashinfer/attention/scheduler.cuh index b3a3acb17..9f27246e3 100644 --- a/include/flashinfer/attention/scheduler.cuh +++ b/include/flashinfer/attention/scheduler.cuh @@ -562,13 +562,22 @@ inline auto get_qkv_tile_indices(std::vector& packed_qo_len_arr, uint32_t cta_tile_q, uint32_t kv_chunk_size, uint32_t gqa_group_size, std::vector& request_indices = nullptr, + std::vector& qo_tile_indices = nullptr, + std::vector& kv_tile_indices = nullptr, std::vector& merge_indptr = nullptr, std::vector& o_indptr = nullptr) { - std::vector qo_tile_indices, kv_tile_indices; if (request_indices == nullptr) { request_indices = std::vector(); request_indices.push_back(0); } + if (qo_tile_indices == nullptr) { + qo_tile_indices = std::vector(); + qo_tile_indices.push_back(0); + } + if (kv_tile_indices == nullptr) { + kv_tile_indices = std::vector(); + kv_tile_indices.push_back(0); + } if (merge_indptr == nullptr) { merge_indptr = std::vector(); merge_indptr.push_back(0); @@ -828,8 +837,8 @@ inline cudaError_t PrefillPlan(void* float_buffer, size_t float_workspace_size_i } /* -Modifed from PrefillSplitQOKVIndptr to support two tile sizes, and assign blocks proportional to the -number of tiles. +Modifed to support two tile sizes, and assign blocks proportional to +the number of tiles. */ template inline auto PODSplitQOKVIndptr(IdType* qo_indptr_p, IdType* kv_indptr_p, uint32_t total_num_rows_p, @@ -873,14 +882,15 @@ inline auto PODSplitQOKVIndptr(IdType* qo_indptr_p, IdType* kv_indptr_p, uint32_ cta_tile_q_d, min_kv_chunk_size); // step 3: split qo_indptr and kv_indptr - // use one merge_indptr and o_indptr to simply merging - auto [request_indices, qo_tile_indices_p, kv_tile_indices_p, merge_indptr, o_indptr, + // Use one set of qkv indices, merge_indptr and o_indptr to simply merging. 
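  // The prefill call below fills these vectors first; the decode call then
  // appends to the same request/qo-tile/kv-tile vectors and to the shared
  // merge_indptr / o_indptr, so one reduction pass can merge the partial
  // outputs of both phases.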
+ auto [request_indices, qo_tile_indices, kv_tile_indices, merge_indptr, o_indptr, new_batch_size_p] = get_qkv_tile_indices(packed_qo_len_arr_p, kv_len_arr_p, batch_size_p, cta_tile_q_p, kv_chunk_size_p, gqa_group_size); - auto [request_indices, qo_tile_indices_d, kv_tile_indices_d, merge_indptr_d, o_indptr, - new_batch_size_d] = get_qkv_tile_indices(packed_qo_len_arr_d, kv_len_arr_d, batch_size_d, - cta_tile_q_d, kv_chunk_size_d, gqa_group_size, - request_indices, merge_indptr, o_indptr); + auto [request_indices, qo_tile_indices, kv_tile_indices, merge_indptr, o_indptr, + new_batch_size_d] = + get_qkv_tile_indices(packed_qo_len_arr_d, kv_len_arr_d, batch_size_d, cta_tile_q_d, + kv_chunk_size_d, gqa_group_size, request_indices, qo_tile_indices, + kv_tile_indices, merge_indptr, o_indptr); bool split_kv = split_kv_p || split_kv_d; uint32_t new_batch_size = new_batch_size_p + new_batch_size_d; @@ -897,9 +907,9 @@ inline auto PODSplitQOKVIndptr(IdType* qo_indptr_p, IdType* kv_indptr_p, uint32_ return std::make_tuple(split_kv, new_batch_size, padded_batch_size_p, padded_batch_size_d, cta_tile_q_p, cta_tile_q_d, kv_chunk_size_p, kv_chunk_size_d, - std::move(merge_indptr), std::move(o_indptr), std::move(request_indices), - std::move(qo_tile_indices_p), std::move(kv_tile_indices_p), - std::move(qo_tile_indices_d), std::move(kv_tile_indices_d)); + std::move(request_indices), std::move(qo_tile_indices), + std::move(kv_tile_indices)) std::move(merge_indptr), + std::move(o_indptr); } struct PODPlanInfo { @@ -1027,9 +1037,8 @@ inline cudaError_t PODPlan(void* float_buffer, size_t float_workspace_size_in_by // step 2: determine kv_chunk_size auto [split_kv, new_batch_size, padded_batch_size_p, padded_batch_size_d, cta_tile_q_p, - cta_tile_q_d, kv_chunk_size_p, kv_chunk_size_d, merge_indptr_vec, o_indptr_vec, - request_indices_vec, qo_tile_indices_vec_p, kv_tile_indices_vec_p, qo_tile_indices_vec_d, - kv_tile_indices_vec_d] = + cta_tile_q_d, kv_chunk_size_p, kv_chunk_size_d, request_indices, qo_tile_indices, + kv_tile_indices, merge_indptr, o_indptr] = PODSplitQOKVIndptr(qo_indptr_p, kv_indptr_p, qo_indptr_d, kv_indptr_d, total_num_rows_p, batch_size_p, total_num_rows_d, batch_size_d, num_qo_heads, num_kv_heads, head_dim_vo, page_size, max_batch_size_if_split, enable_cuda_graph); From 13b6b19139ca02bd8c1cccca9c2b7d792aa7accd Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Sun, 6 Jul 2025 05:22:00 +0000 Subject: [PATCH 15/33] fix --- include/flashinfer/attention/scheduler.cuh | 3 --- 1 file changed, 3 deletions(-) diff --git a/include/flashinfer/attention/scheduler.cuh b/include/flashinfer/attention/scheduler.cuh index 9f27246e3..ac9721151 100644 --- a/include/flashinfer/attention/scheduler.cuh +++ b/include/flashinfer/attention/scheduler.cuh @@ -568,15 +568,12 @@ inline auto get_qkv_tile_indices(std::vector& packed_qo_len_arr, std::vector& o_indptr = nullptr) { if (request_indices == nullptr) { request_indices = std::vector(); - request_indices.push_back(0); } if (qo_tile_indices == nullptr) { qo_tile_indices = std::vector(); - qo_tile_indices.push_back(0); } if (kv_tile_indices == nullptr) { kv_tile_indices = std::vector(); - kv_tile_indices.push_back(0); } if (merge_indptr == nullptr) { merge_indptr = std::vector(); From 7d29232881a5662604936f206dc6ff6d412ad6fa Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Sun, 6 Jul 2025 23:43:36 +0000 Subject: [PATCH 16/33] fix plan func upper call interface --- csrc/batch_prefill.cu | 2 +- csrc/pod.cu | 32 ++++++------ flashinfer/decode.py | 1 - flashinfer/pod.py | 2 
- flashinfer/prefill.py | 2 - include/flashinfer/attention/scheduler.cuh | 58 +++++++++------------- tvm_binding/batch_prefill.cu | 14 +++--- 7 files changed, 49 insertions(+), 62 deletions(-) diff --git a/csrc/batch_prefill.cu b/csrc/batch_prefill.cu index a51fc7f56..94588b393 100644 --- a/csrc/batch_prefill.cu +++ b/csrc/batch_prefill.cu @@ -47,7 +47,7 @@ at::Tensor BatchPrefillWithKVCachePlan( at::Tensor page_locked_int_workspace_buffer, at::Tensor qo_indptr, at::Tensor kv_indptr, at::Tensor kv_len_arr, int64_t total_num_rows, int64_t batch_size, int64_t num_qo_heads, int64_t num_kv_heads, int64_t page_size, bool enable_cuda_graph, int64_t head_dim_qk, - int64_t head_dim_vo, bool causal) { + int64_t head_dim_vo) { size_t float_workspace_size_in_bytes = float_workspace_buffer.size(0) * float_workspace_buffer.element_size(); size_t int_workspace_size_in_bytes = diff --git a/csrc/pod.cu b/csrc/pod.cu index 577e83463..1995d9287 100644 --- a/csrc/pod.cu +++ b/csrc/pod.cu @@ -39,26 +39,28 @@ using namespace flashinfer; at::Tensor PODWithKVCachePlan(at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, at::Tensor page_locked_int_workspace_buffer, at::Tensor qo_indptr_p, at::Tensor kv_indptr_p, at::Tensor kv_len_arr_p, - at::Tensor kv_indptr_d, at::Tensor qo_indptr_d, - int64_t total_num_rows, int64_t batch_size, int64_t num_qo_heads, - int64_t num_kv_heads, int64_t page_size, bool enable_cuda_graph, - int64_t head_dim_qk, int64_t head_dim_vo, bool causal) { + uint32_t total_num_rows_p, uint32_t batch_size_p, + at::Tensor qo_indptr_d, at::Tensor kv_indptr_d, + uint32_t total_num_rows_d, uint32_t batch_size_d, + uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t head_dim_qk, + uint32_t head_dim_vo, uint32_t page_size, bool enable_cuda_graph) { size_t float_workspace_size_in_bytes = float_workspace_buffer.size(0) * float_workspace_buffer.element_size(); size_t int_workspace_size_in_bytes = int_workspace_buffer.size(0) * int_workspace_buffer.element_size(); - PrefillPlanInfo plan_info; + PODPlanInfo plan_info; const c10::cuda::OptionalCUDAGuard device_guard(float_workspace_buffer.device()); const cudaStream_t stream = c10::cuda::getCurrentCUDAStream(); - cudaError_t status = PrefillPlan( - float_workspace_buffer.data_ptr(), float_workspace_size_in_bytes, - int_workspace_buffer.data_ptr(), page_locked_int_workspace_buffer.data_ptr(), - int_workspace_size_in_bytes, plan_info, qo_indptr_p.data_ptr(), - kv_indptr_p.data_ptr(), qo_indptr_d.data_ptr(), - kv_indptr_d.data_ptr(), total_num_rows, batch_size, num_qo_heads, num_kv_heads, - head_dim_qk, head_dim_vo, page_size, enable_cuda_graph, /*sizeof_dtype_o=*/2, stream); + cudaError_t status = + PODPlan(float_workspace_buffer.data_ptr(), float_workspace_size_in_bytes, + int_workspace_buffer.data_ptr(), page_locked_int_workspace_buffer.data_ptr(), + int_workspace_size_in_bytes, plan_info, qo_indptr_p.data_ptr(), + kv_indptr_p.data_ptr(), total_num_rows_p, batch_size_p, + qo_indptr_d.data_ptr(), kv_indptr_d.data_ptr(), + total_num_rows_d, batch_size_d, num_qo_heads, num_kv_heads, head_dim_qk, + head_dim_vo, page_size, enable_cuda_graph, /*sizeof_dtype_o=*/2, stream); TORCH_CHECK(status == cudaSuccess, "Failed to plan prefill with error: ", cudaGetErrorString(status)); @@ -118,7 +120,7 @@ void PODWithKVCacheTensorRun( auto kv_scalar_type = k_p.scalar_type(); // Decode setup (Tensor decode = batched prefill) - PrefillPlanInfo plan_info; + PODPlanInfo plan_info; plan_info.FromVector(tensor_to_vec(plan_info_vec)); QKVLayout kv_layout_d = 
static_cast(layout_d); auto device = q_d.device(); @@ -296,13 +298,13 @@ at::Tensor PODWithKVCachePlan(at::Tensor float_workspace_buffer, at::Tensor int_ at::Tensor kv_indptr_p, at::Tensor kv_len_arr, int64_t total_num_rows, int64_t batch_size, int64_t num_qo_heads, int64_t num_kv_heads, int64_t page_size, bool enable_cuda_graph, int64_t head_dim_qk, - int64_t head_dim_vo, bool causal) { + int64_t head_dim_vo) { size_t float_workspace_size_in_bytes = float_workspace_buffer.size(0) * float_workspace_buffer.element_size(); size_t int_workspace_size_in_bytes = int_workspace_buffer.size(0) * int_workspace_buffer.element_size(); - PrefillPlanInfo plan_info; + PODPlanInfo plan_info; const c10::cuda::OptionalCUDAGuard device_guard(float_workspace_buffer.device()); const cudaStream_t stream = c10::cuda::getCurrentCUDAStream(); diff --git a/flashinfer/decode.py b/flashinfer/decode.py index 1cb685bcd..d9af24d74 100644 --- a/flashinfer/decode.py +++ b/flashinfer/decode.py @@ -948,7 +948,6 @@ def plan( self.is_cuda_graph_enabled, head_dim, head_dim, - False, # causal ) else: if self._jit_module is not None: diff --git a/flashinfer/pod.py b/flashinfer/pod.py index d2625b608..eacf7dc13 100644 --- a/flashinfer/pod.py +++ b/flashinfer/pod.py @@ -242,7 +242,6 @@ def __init__( self._paged_kv_last_page_len_buf = paged_kv_last_page_len_buffer self._use_tensor_cores = use_tensor_cores self._use_cuda_graph = use_cuda_graph - if use_cuda_graph: # NOTE(Zihao): if once created, no need to update it in plan/run self._qo_indptr_buf = torch.arange( @@ -437,7 +436,6 @@ def plan( self.is_cuda_graph_enabled, head_dim, head_dim, - False, # causal ) self._indptr_type = indptr_d.dtype diff --git a/flashinfer/prefill.py b/flashinfer/prefill.py index f35ab2cc3..74a192ae2 100644 --- a/flashinfer/prefill.py +++ b/flashinfer/prefill.py @@ -1576,7 +1576,6 @@ def plan( self.is_cuda_graph_enabled, head_dim_qk, head_dim_vo, - causal, ) self._causal = causal @@ -2350,7 +2349,6 @@ def plan( self.is_cuda_graph_enabled, head_dim_qk, head_dim_vo, - causal, ) self._causal = causal diff --git a/include/flashinfer/attention/scheduler.cuh b/include/flashinfer/attention/scheduler.cuh index ac9721151..c47c564b2 100644 --- a/include/flashinfer/attention/scheduler.cuh +++ b/include/flashinfer/attention/scheduler.cuh @@ -524,9 +524,9 @@ inline auto get_q_tiles(std::vector& packed_qo_len_arr, uint32_t batch_ // When CUDA graphs are enabled, the lengths of sequences determined by // qo_indptr_h can vary. We assume that the dummy data based on which // the CUDA graph is created fixes the maximum number of tokens. - const uint64_t max_seq_len = total_num_rows - batch_size + 1; - uint64_t max_qo_len = uint64_t(max_seq_len) * gqa_group_size; if (tile_size == -1) { + const uint64_t max_seq_len = total_num_rows - batch_size + 1; + uint64_t max_qo_len = uint64_t(max_seq_len) * gqa_group_size; cta_tile_q = FA2DetermineCtaTileQ(max_qo_len, head_dim); } else { cta_tile_q = tile_size; @@ -538,12 +538,12 @@ inline auto get_q_tiles(std::vector& packed_qo_len_arr, uint32_t batch_ // number of rows. 
total_num_tiles_q = ceil_div(total_num_rows * gqa_group_size, cta_tile_q) + batch_size - 1; } else { - int64_t sum_packed_qo_len = 0; - for (uint32_t i = 0; i < batch_size; ++i) { - sum_packed_qo_len += packed_qo_len_arr[i]; - } - const int64_t avg_packed_qo_len = sum_packed_qo_len / batch_size; if (tile_size == -1) { + int64_t sum_packed_qo_len = 0; + for (uint32_t i = 0; i < batch_size; ++i) { + sum_packed_qo_len += packed_qo_len_arr[i]; + } + const int64_t avg_packed_qo_len = sum_packed_qo_len / batch_size; cta_tile_q = FA2DetermineCtaTileQ(avg_packed_qo_len, head_dim); } else { cta_tile_q = tile_size; @@ -866,11 +866,11 @@ inline auto PODSplitQOKVIndptr(IdType* qo_indptr_p, IdType* kv_indptr_p, uint32_ gqa_group_size, enable_cuda_graph, cta_tile_q_d); uint32_t total_num_tiles_q = num_tiles_q_p + num_tiles_q_d; - // Assign CTAs proportional to the number of query tiles + // Allocate CTAs proportional to the number of query tiles in prefill and decode // TODO(Wenxuan): explore a more balanced cost function considering kv len. // See discussion: https://github.com/flashinfer-ai/flashinfer/issues/1175 uint32_t max_bs_p = max_batch_size_if_split * num_tiles_q_p / total_num_tiles_q; - uint32_t max_bs_d = max_batch_size_if_split * num_tiles_q_d / total_num_tiles_q; + uint32_t max_bs_d = max_batch_size_if_split - max_bs_p; auto [split_kv_p, kv_chunk_size_p] = PrefillBinarySearchKVChunkSize(enable_cuda_graph, max_bs_p, packed_qo_len_arr_p, kv_len_arr_p, cta_tile_q_p, min_kv_chunk_size); @@ -905,8 +905,7 @@ inline auto PODSplitQOKVIndptr(IdType* qo_indptr_p, IdType* kv_indptr_p, uint32_ return std::make_tuple(split_kv, new_batch_size, padded_batch_size_p, padded_batch_size_d, cta_tile_q_p, cta_tile_q_d, kv_chunk_size_p, kv_chunk_size_d, std::move(request_indices), std::move(qo_tile_indices), - std::move(kv_tile_indices)) std::move(merge_indptr), - std::move(o_indptr); + std::move(kv_tile_indices), std::move(merge_indptr), std::move(o_indptr)); } struct PODPlanInfo { @@ -914,21 +913,16 @@ struct PODPlanInfo { int64_t total_num_rows; int64_t total_num_rows_offset; int64_t cta_tile_q; - int64_t request_indices_offset_p; - int64_t request_indices_offset_d; - int64_t qo_tile_indices_offset_p; - int64_t qo_tile_indices_offset_d; - int64_t kv_tile_indices_offset_p; - int64_t kv_tile_indices_offset_d; + int64_t request_indices_offset; + int64_t qo_tile_indices_offset; + int64_t kv_tile_indices_offset; int64_t merge_indptr_offset; int64_t o_indptr_offset; int64_t kv_chunk_size_ptr_offset_p; int64_t kv_chunk_size_ptr_offset_d; - int64_t v_offset_p; - int64_t v_offset_d int64_t s_offset_p; - int64_t s_offset_d; - int64_t block_valid_mask_offset_p; - int64_t block_valid_mask_offset_d; + int64_t v_offset; + int64_t s_offset; + int64_t block_valid_mask_offset; bool enable_cuda_graph; bool split_kv; @@ -937,20 +931,14 @@ struct PODPlanInfo { total_num_rows(0), total_num_rows_offset(0), cta_tile_q(0), - request_indices_offset_p(0), - request_indices_offset_d(0), - qo_tile_indices_offset_p(0), - qo_tile_indices_offset_d(0), - kv_tile_indices_offset_p(0), - kv_tile_indices_offset_d(0), + request_indices_offset(0), + qo_tile_indices_offset(0), + kv_tile_indices_offset(0), merge_indptr_offset(0), o_indptr_offset(0), - kv_chunk_size_ptr_offset_p(0), - kv_chunk_size_ptr_offset_d(0), - v_offset_p(0), - v_offset_d(0), - s_offset_p(0), - s_offset_d(0), + kv_chunk_size_ptr_offset(0), + v_offset(0), + s_offset(0), block_valid_mask_offset(0), enable_cuda_graph(false), split_kv(false) {} @@ -1028,7 +1016,7 @@ inline 
cudaError_t PODPlan(void* float_buffer, size_t float_workspace_size_in_by int dev_id = 0; FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id)); FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&num_sm, cudaDevAttrMultiProcessorCount, dev_id)); - int num_blocks_per_sm = 2; // TODO(Wenxuan): increase this to reduce wave quantization? + int num_blocks_per_sm = 3; // TODO(Wenxuan): increase this to reduce wave quantization? int max_grid_size = num_blocks_per_sm * num_sm; uint32_t max_batch_size_if_split = max_grid_size / num_kv_heads; diff --git a/tvm_binding/batch_prefill.cu b/tvm_binding/batch_prefill.cu index bf7161116..710764cf4 100644 --- a/tvm_binding/batch_prefill.cu +++ b/tvm_binding/batch_prefill.cu @@ -41,12 +41,14 @@ cudaError_t BatchPrefillWithRaggedKVCacheDispatched(Params params, typename Para using namespace flashinfer; -IntTuple BatchPrefillWithKVCachePlan( - DLTensor* float_workspace_buffer, DLTensor* int_workspace_buffer, - DLTensor* page_locked_int_workspace_buffer, DLTensor* qo_indptr, DLTensor* kv_indptr, - IntTuple kv_len_arr, int64_t total_num_rows, int64_t batch_size, int64_t num_qo_heads, - int64_t num_kv_heads, int64_t page_size, bool enable_cuda_graph, int64_t head_dim_qk, - int64_t head_dim_vo, bool causal, TVMStreamHandle cuda_stream) { +IntTuple BatchPrefillWithKVCachePlan(DLTensor* float_workspace_buffer, + DLTensor* int_workspace_buffer, + DLTensor* page_locked_int_workspace_buffer, + DLTensor* qo_indptr, DLTensor* kv_indptr, IntTuple kv_len_arr, + int64_t total_num_rows, int64_t batch_size, + int64_t num_qo_heads, int64_t num_kv_heads, int64_t page_size, + bool enable_cuda_graph, int64_t head_dim_qk, + int64_t head_dim_vo, TVMStreamHandle cuda_stream) { size_t float_workspace_size_in_bytes = float_workspace_buffer->shape[0] * DataType(float_workspace_buffer->dtype).bytes(); size_t int_workspace_size_in_bytes = From 106bfdc31b8bb299a84c225229aabb67956ed14d Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Mon, 7 Jul 2025 04:14:46 +0000 Subject: [PATCH 17/33] rename new_batch_size to real_batch_size --- include/flashinfer/attention/scheduler.cuh | 106 ++++++++++----------- 1 file changed, 53 insertions(+), 53 deletions(-) diff --git a/include/flashinfer/attention/scheduler.cuh b/include/flashinfer/attention/scheduler.cuh index c47c564b2..1eefd2f9c 100644 --- a/include/flashinfer/attention/scheduler.cuh +++ b/include/flashinfer/attention/scheduler.cuh @@ -67,7 +67,7 @@ inline void CopyToPageLockedBuffer(void* page_locked_int_buffer, int64_t offset, * \param num_pages The number of pages per request in the batch * \param max_num_pages_per_batch_lb The pre-set lower bound of maximum number of * pages per batch, default to 1 - * \return (max_num_pages_per_batch, new_batch_size) The number of pages per batch and + * \return (max_num_pages_per_batch, real_batch_size) The number of pages per batch and * the new batch size after the partition. 
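 * \note The binary search returns the smallest page count per chunk for which
 *       the partitioned batch size, multiplied by gdy, still fits in max_grid_size.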
*/ template @@ -78,24 +78,24 @@ inline auto PartitionPagedKVCacheBinarySearchMinNumPagePerBatch( for (const IdType& elem : num_pages) { high = max(high, elem); } - uint32_t new_batch_size; + uint32_t real_batch_size; while (low < high) { uint32_t mid = (low + high) / 2; - new_batch_size = 0; + real_batch_size = 0; for (const IdType& elem : num_pages) { - new_batch_size += ceil_div(elem, mid); + real_batch_size += ceil_div(elem, mid); } - if (new_batch_size * gdy > max_grid_size) { + if (real_batch_size * gdy > max_grid_size) { low = mid + 1; } else { high = mid; } } - new_batch_size = 0; + real_batch_size = 0; for (const IdType& elem : num_pages) { - new_batch_size += ceil_div(std::max(elem, 1), low); + real_batch_size += ceil_div(std::max(elem, 1), low); } - return std::make_tuple(low, new_batch_size); + return std::make_tuple(low, real_batch_size); } inline auto PrefillBinarySearchKVChunkSize(const bool enable_cuda_graph, @@ -115,12 +115,12 @@ inline auto PrefillBinarySearchKVChunkSize(const bool enable_cuda_graph, constexpr int64_t min_kv_len = 1; while (low < high) { const int64_t mid = (low + high) / 2; - int64_t new_batch_size = 0; + int64_t real_batch_size = 0; for (uint32_t i = 0; i < batch_size; ++i) { - new_batch_size += ceil_div(packed_qo_len_arr[i], qo_chunk_size) * - ceil_div(std::max(kv_len_arr[i], min_kv_len), mid); + real_batch_size += ceil_div(packed_qo_len_arr[i], qo_chunk_size) * + ceil_div(std::max(kv_len_arr[i], min_kv_len), mid); } - if (new_batch_size > max_batch_size_if_split) { + if (real_batch_size > max_batch_size_if_split) { low = mid + 1; } else { high = mid; @@ -138,7 +138,7 @@ inline auto PrefillBinarySearchKVChunkSize(const bool enable_cuda_graph, * \param split_kv Whether to split the KV cache into multiple chunks * \param max_grid_size The maximum grid size that can be used in a partiton-kv kernel * \param max_num_pages_per_batch The maximum number of pages per batch - * \param new_batch_size The new batch size after the partition + * \param real_batch_size The new batch size after the partition * \param paged_kv The paged kv cache data structure * \param num_qo_heads A integer indicates the number of heads of query and output * \param pos_encoding_mode The positional encoding mode @@ -149,7 +149,7 @@ template inline cudaError_t BatchDecodeWithPagedKVCacheWorkEstimationDispatched( bool& split_kv, uint32_t& max_grid_size, uint32_t& max_num_pages_per_batch, - uint32_t& new_batch_size, uint32_t& gdy, uint32_t batch_size, + uint32_t& real_batch_size, uint32_t& gdy, uint32_t batch_size, typename Params::IdType* kv_indptr_h, const uint32_t num_qo_heads, const uint32_t page_size, bool enable_cuda_graph, cudaStream_t stream) { using DTypeKV = typename Params::DTypeKV; @@ -187,17 +187,17 @@ inline cudaError_t BatchDecodeWithPagedKVCacheWorkEstimationDispatched( max_num_pages_per_batch = std::max( max_num_pages_per_batch, kv_indptr_h[batch_idx + 1] - kv_indptr_h[batch_idx]); } - new_batch_size = batch_size; + real_batch_size = batch_size; } else { - // compute max_num_pages_per_batch and new_batch_size + // compute max_num_pages_per_batch and real_batch_size std::vector num_pages(batch_size); for (uint32_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) { num_pages[batch_idx] = kv_indptr_h[batch_idx + 1] - kv_indptr_h[batch_idx]; } - std::tie(max_num_pages_per_batch, new_batch_size) = + std::tie(max_num_pages_per_batch, real_batch_size) = PartitionPagedKVCacheBinarySearchMinNumPagePerBatch(max_grid_size, gdy, num_pages, std::max(128 / page_size, 1U)); - if 
(new_batch_size == batch_size && !enable_cuda_graph) { + if (real_batch_size == batch_size && !enable_cuda_graph) { // do not use partition-kv kernel for short sequence, when not using CUDAGraph split_kv = false; } else { @@ -212,7 +212,7 @@ inline cudaError_t BatchDecodeWithPagedKVCacheWorkEstimationDispatched( template inline cudaError_t BatchDecodeWithPagedKVCacheWorkEstimationDispatchedMLA( bool& split_kv, uint32_t& max_grid_size, uint32_t& max_num_pages_per_batch, - uint32_t& new_batch_size, uint32_t& gdy, uint32_t batch_size, + uint32_t& real_batch_size, uint32_t& gdy, uint32_t batch_size, typename Params::IdType* kv_indptr_h, const uint32_t num_qo_heads, const uint32_t page_size, bool enable_cuda_graph, cudaStream_t stream) { using DTypeKV = typename Params::DTypeKV; @@ -253,17 +253,17 @@ inline cudaError_t BatchDecodeWithPagedKVCacheWorkEstimationDispatchedMLA( max_num_pages_per_batch = std::max( max_num_pages_per_batch, kv_indptr_h[batch_idx + 1] - kv_indptr_h[batch_idx]); } - new_batch_size = batch_size; + real_batch_size = batch_size; } else { - // compute max_num_pages_per_batch and new_batch_size + // compute max_num_pages_per_batch and real_batch_size std::vector num_pages(batch_size); for (uint32_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) { num_pages[batch_idx] = kv_indptr_h[batch_idx + 1] - kv_indptr_h[batch_idx]; } - std::tie(max_num_pages_per_batch, new_batch_size) = + std::tie(max_num_pages_per_batch, real_batch_size) = PartitionPagedKVCacheBinarySearchMinNumPagePerBatch(max_grid_size, gdy, num_pages, std::max(128 / page_size, 1U)); - if (new_batch_size == batch_size && !enable_cuda_graph) { + if (real_batch_size == batch_size && !enable_cuda_graph) { // do not use partition-kv kernel for short sequence, when not using CUDAGraph split_kv = false; } else { @@ -280,7 +280,7 @@ template inline cudaError_t BatchDecodeWithPagedKVCacheWorkEstimationDispatchedMlaCuteSM80( bool& split_kv, uint32_t& max_grid_size, uint32_t& max_num_pages_per_batch, - uint32_t& new_batch_size, uint32_t& gdy_, uint32_t batch_size, + uint32_t& real_batch_size, uint32_t& gdy_, uint32_t batch_size, typename Params::IdType* kv_indptr_h, const uint32_t num_qo_heads, const uint32_t page_size, bool enable_cuda_graph, cudaStream_t stream) { using DTypeKV = typename Params::DTypeKV; @@ -313,17 +313,17 @@ inline cudaError_t BatchDecodeWithPagedKVCacheWorkEstimationDispatchedMlaCuteSM8 max_num_pages_per_batch = std::max( max_num_pages_per_batch, kv_indptr_h[batch_idx + 1] - kv_indptr_h[batch_idx]); } - new_batch_size = batch_size; + real_batch_size = batch_size; } else { - // compute max_num_pages_per_batch and new_batch_size + // compute max_num_pages_per_batch and real_batch_size std::vector num_pages(batch_size); for (uint32_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) { num_pages[batch_idx] = kv_indptr_h[batch_idx + 1] - kv_indptr_h[batch_idx]; } - std::tie(max_num_pages_per_batch, new_batch_size) = + std::tie(max_num_pages_per_batch, real_batch_size) = PartitionPagedKVCacheBinarySearchMinNumPagePerBatch(max_grid_size, gdy, num_pages, std::max(128 / page_size, 1U)); - if (new_batch_size == batch_size && !enable_cuda_graph) { + if (real_batch_size == batch_size && !enable_cuda_graph) { // do not use partition-kv kernel for short sequence, when not using CUDAGraph split_kv = false; } else { @@ -432,16 +432,16 @@ inline cudaError_t DecodePlan(void* float_buffer, size_t float_workspace_size_in using DTypeO = typename Params::DTypeO; using IdType = typename Params::IdType; bool split_kv; - 
uint32_t max_grid_size, kv_chunk_size_in_pages, new_batch_size, gdy; + uint32_t max_grid_size, kv_chunk_size_in_pages, real_batch_size, gdy; FLASHINFER_CUDA_CALL(work_estimation_func(split_kv, max_grid_size, kv_chunk_size_in_pages, - new_batch_size, gdy, batch_size, indptr_h, num_qo_heads, - page_size, enable_cuda_graph, stream)); + real_batch_size, gdy, batch_size, indptr_h, + num_qo_heads, page_size, enable_cuda_graph, stream)); size_t padded_batch_size; plan_info.enable_cuda_graph = enable_cuda_graph; plan_info.split_kv = split_kv; padded_batch_size = - (enable_cuda_graph) ? (split_kv ? max_grid_size / gdy : batch_size) : new_batch_size; + (enable_cuda_graph) ? (split_kv ? max_grid_size / gdy : batch_size) : real_batch_size; plan_info.padded_batch_size = padded_batch_size; auto [request_indices_vec, kv_tile_indices_vec, o_indptr_vec] = @@ -481,7 +481,7 @@ inline cudaError_t DecodePlan(void* float_buffer, size_t float_workspace_size_in bool* block_valid_mask_h = GetPtrFromBaseOffset(page_locked_int_buffer, plan_info.block_valid_mask_offset); for (uint32_t i = 0; i < padded_batch_size; ++i) { - block_valid_mask_h[i] = i < new_batch_size; + block_valid_mask_h[i] = i < real_batch_size; } } @@ -584,7 +584,7 @@ inline auto get_qkv_tile_indices(std::vector& packed_qo_len_arr, o_indptr.push_back(0); } - uint32_t new_batch_size = 0; + uint32_t real_batch_size = 0; for (uint32_t request_idx = 0; request_idx < batch_size; ++request_idx) { const int64_t packed_qo_len = packed_qo_len_arr[request_idx]; const int64_t kv_len = std::max(int(kv_len_arr[request_idx]), 1); @@ -593,7 +593,7 @@ inline auto get_qkv_tile_indices(std::vector& packed_qo_len_arr, for (uint32_t q_tile_idx = 0; q_tile_idx < num_tiles_q; ++q_tile_idx) { for (uint32_t kv_tile_idx = 0; kv_tile_idx < num_tiles_kv; ++kv_tile_idx) { - new_batch_size += 1; + real_batch_size += 1; request_indices.push_back(request_idx); qo_tile_indices.push_back(q_tile_idx); kv_tile_indices.push_back(kv_tile_idx); @@ -607,7 +607,7 @@ inline auto get_qkv_tile_indices(std::vector& packed_qo_len_arr, o_indptr.push_back(o_indptr.back() + qo_len * num_tiles_kv); } return std::make_tuple(request_indices, qo_tile_indices, kv_tile_indices, merge_indptr, o_indptr, - new_batch_size); + real_batch_size); } template @@ -635,19 +635,19 @@ inline auto PrefillSplitQOKVIndptr(IdType* qo_indptr_h, IdType* kv_indptr_h, PrefillBinarySearchKVChunkSize(enable_cuda_graph, max_batch_size_if_split, packed_qo_len_arr, kv_len_arr, cta_tile_q, min_kv_chunk_size); - auto [request_indices, qo_tile_indices, kv_tile_indices, merge_indptr, o_indptr, new_batch_size] = - get_qkv_tile_indices(packed_qo_len_arr, kv_len_arr, batch_size, cta_tile_q, kv_chunk_size, - gqa_group_size); + auto [request_indices, qo_tile_indices, kv_tile_indices, merge_indptr, o_indptr, + real_batch_size] = get_qkv_tile_indices(packed_qo_len_arr, kv_len_arr, batch_size, + cta_tile_q, kv_chunk_size, gqa_group_size); const size_t padded_batch_size = - enable_cuda_graph ? std::max(max_batch_size_if_split, total_num_tiles_q) : new_batch_size; - FLASHINFER_CHECK(new_batch_size <= padded_batch_size, + enable_cuda_graph ? 
std::max(max_batch_size_if_split, total_num_tiles_q) : real_batch_size; + FLASHINFER_CHECK(real_batch_size <= padded_batch_size, "new batch size should not exceed padded batch size"); // step 4: multiply kv_chunk_size by page_size kv_chunk_size *= page_size; - return std::make_tuple(split_kv, new_batch_size, padded_batch_size, cta_tile_q, kv_chunk_size, + return std::make_tuple(split_kv, real_batch_size, padded_batch_size, cta_tile_q, kv_chunk_size, std::move(request_indices), std::move(qo_tile_indices), std::move(kv_tile_indices), std::move(merge_indptr), std::move(o_indptr)); } @@ -756,11 +756,11 @@ inline cudaError_t PrefillPlan(void* float_buffer, size_t float_workspace_size_i uint32_t max_batch_size_if_split = max_grid_size / num_kv_heads; // step 2: determine kv_chunk_size - auto [split_kv, new_batch_size, padded_batch_size, cta_tile_q, kv_chunk_size, request_indices_vec, - qo_tile_indices_vec, kv_tile_indices_vec, merge_indptr_vec, o_indptr_vec] = - PrefillSplitQOKVIndptr(qo_indptr_p, kv_indptr_p, total_num_rows, batch_size, num_qo_heads, - num_kv_heads, head_dim_vo, page_size, max_batch_size_if_split, - enable_cuda_graph); + auto [split_kv, real_batch_size, padded_batch_size, cta_tile_q, kv_chunk_size, + request_indices_vec, qo_tile_indices_vec, kv_tile_indices_vec, merge_indptr_vec, + o_indptr_vec] = PrefillSplitQOKVIndptr(qo_indptr_p, kv_indptr_p, total_num_rows, batch_size, + num_qo_heads, num_kv_heads, head_dim_vo, page_size, + max_batch_size_if_split, enable_cuda_graph); plan_info.cta_tile_q = cta_tile_q; plan_info.total_num_rows = total_num_rows; @@ -822,7 +822,7 @@ inline cudaError_t PrefillPlan(void* float_buffer, size_t float_workspace_size_i GetPtrFromBaseOffset(page_locked_int_buffer, plan_info.block_valid_mask_offset); std::copy(merge_indptr_vec.begin(), merge_indptr_vec.end(), merge_indptr_h); for (uint32_t i = 0; i < padded_batch_size; ++i) { - block_valid_mask_h[i] = i < new_batch_size; + block_valid_mask_h[i] = i < real_batch_size; } } @@ -890,19 +890,19 @@ inline auto PODSplitQOKVIndptr(IdType* qo_indptr_p, IdType* kv_indptr_p, uint32_ kv_tile_indices, merge_indptr, o_indptr); bool split_kv = split_kv_p || split_kv_d; - uint32_t new_batch_size = new_batch_size_p + new_batch_size_d; + uint32_t real_batch_size = new_batch_size_p + new_batch_size_d; const size_t padded_batch_size_p = enable_cuda_graph ? std::max(max_bs_p, total_num_tiles_q_p) : new_batch_size_p; const size_t padded_batch_size_d = enable_cuda_graph ? 
std::max(max_bs_d, total_num_tiles_q_d) : new_batch_size_d; - FLASHINFER_CHECK(new_batch_size <= padded_batch_size_p + padded_batch_size_d, + FLASHINFER_CHECK(real_batch_size <= padded_batch_size_p + padded_batch_size_d, "new batch size should not exceed padded batch size"); // step 4: multiply kv_chunk_size by page_size kv_chunk_size_p *= page_size; kv_chunk_size_d *= page_size; - return std::make_tuple(split_kv, new_batch_size, padded_batch_size_p, padded_batch_size_d, + return std::make_tuple(split_kv, real_batch_size, padded_batch_size_p, padded_batch_size_d, cta_tile_q_p, cta_tile_q_d, kv_chunk_size_p, kv_chunk_size_d, std::move(request_indices), std::move(qo_tile_indices), std::move(kv_tile_indices), std::move(merge_indptr), std::move(o_indptr)); @@ -1021,7 +1021,7 @@ inline cudaError_t PODPlan(void* float_buffer, size_t float_workspace_size_in_by uint32_t max_batch_size_if_split = max_grid_size / num_kv_heads; // step 2: determine kv_chunk_size - auto [split_kv, new_batch_size, padded_batch_size_p, padded_batch_size_d, cta_tile_q_p, + auto [split_kv, real_batch_size, padded_batch_size_p, padded_batch_size_d, cta_tile_q_p, cta_tile_q_d, kv_chunk_size_p, kv_chunk_size_d, request_indices, qo_tile_indices, kv_tile_indices, merge_indptr, o_indptr] = PODSplitQOKVIndptr(qo_indptr_p, kv_indptr_p, qo_indptr_d, kv_indptr_d, total_num_rows_p, @@ -1093,7 +1093,7 @@ inline cudaError_t PODPlan(void* float_buffer, size_t float_workspace_size_in_by GetPtrFromBaseOffset(page_locked_int_buffer, plan_info.block_valid_mask_offset); std::copy(merge_indptr_vec.begin(), merge_indptr_vec.end(), merge_indptr_h); for (uint32_t i = 0; i < padded_batch_size; ++i) { - block_valid_mask_h[i] = i < new_batch_size; + block_valid_mask_h[i] = i < real_batch_size; } } From 5e3e896e3464b28c6a22c4a73711f69cfb1e4a9f Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Mon, 7 Jul 2025 05:23:41 +0000 Subject: [PATCH 18/33] concat request_indices --- include/flashinfer/attention/scheduler.cuh | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/include/flashinfer/attention/scheduler.cuh b/include/flashinfer/attention/scheduler.cuh index 1eefd2f9c..e07c68038 100644 --- a/include/flashinfer/attention/scheduler.cuh +++ b/include/flashinfer/attention/scheduler.cuh @@ -566,8 +566,11 @@ inline auto get_qkv_tile_indices(std::vector& packed_qo_len_arr, std::vector& kv_tile_indices = nullptr, std::vector& merge_indptr = nullptr, std::vector& o_indptr = nullptr) { + uint32_t start_req_idx = 0; // for global q,k,v,o indexing in POD Attention if (request_indices == nullptr) { request_indices = std::vector(); + } else { + start_req_idx = request_indices.back(); } if (qo_tile_indices == nullptr) { qo_tile_indices = std::vector(); @@ -594,7 +597,7 @@ inline auto get_qkv_tile_indices(std::vector& packed_qo_len_arr, for (uint32_t q_tile_idx = 0; q_tile_idx < num_tiles_q; ++q_tile_idx) { for (uint32_t kv_tile_idx = 0; kv_tile_idx < num_tiles_kv; ++kv_tile_idx) { real_batch_size += 1; - request_indices.push_back(request_idx); + request_indices.push_back(request_idx + start_req_idx); qo_tile_indices.push_back(q_tile_idx); kv_tile_indices.push_back(kv_tile_idx); } From ac072530e3d68b74f050921c2ea073959ad7f2f9 Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Tue, 8 Jul 2025 05:40:25 +0000 Subject: [PATCH 19/33] unifed indices in wrapper.plan --- flashinfer/pod.py | 115 ++++++++++++++------- flashinfer/prefill.py | 9 +- include/flashinfer/attention/scheduler.cuh | 9 +- 3 files changed, 92 insertions(+), 41 deletions(-) diff --git 
a/flashinfer/pod.py b/flashinfer/pod.py index eacf7dc13..46ef4d01e 100644 --- a/flashinfer/pod.py +++ b/flashinfer/pod.py @@ -205,7 +205,7 @@ def __init__( if use_cuda_graph: if not torch.is_tensor(qo_indptr_buffer): raise ValueError( - "qo_indptr_buffer_p should be a torch.Tensor in CUDA graph mode" + "qo_indptr_buffer should be a torch.Tensor in CUDA graph mode" ) if not torch.is_tensor(paged_kv_indptr_buffer) or not torch.is_tensor( paged_kv_indptr_buffer @@ -280,8 +280,12 @@ def reset_workspace_buffer( def plan( self, - indptr_d: torch.Tensor, - indices_d: torch.Tensor, + qo_indptr_p: torch.Tensor, + kv_indptr_p: torch.Tensor, + kv_indices_p: torch.Tensor, + last_page_len_p: torch.Tensor, + kv_indptr_d: torch.Tensor, + kv_indices_d: torch.Tensor, last_page_len_d: torch.Tensor, num_qo_heads: int, num_kv_heads: int, @@ -301,12 +305,18 @@ def plan( Parameters ---------- - indptr_d : torch.Tensor + qo_indptr_p: torch.Tensor + The indptr of the query/output tensor for prefill, shape: ``[batch_size + 1]``. + kv_indptr_p: torch.Tensor + The indptr of the paged kv cache for prefill, shape: ``[batch_size + 1]``. + kv_indices_p: torch.Tensor + The page indices of the paged kv cache for prefill, shape: ``[qo_indptr[-1]]``. + kv_indptr_d : torch.Tensor The indptr of the paged kv cache for decode, shape: ``[batch_size + 1]`` - indices_d : torch.Tensor + kv_indices_d : torch.Tensor The page indices of the paged kv cache for decode, shape: ``[qo_indptr[-1]]`` last_page_len_d : torch.Tensor - The number of entries in the last page of each request in the paged kv + The number of entries in the last page of each request in the kv cache, shape: ``[batch_size]`` num_qo_heads : int The number of query/output heads @@ -349,47 +359,71 @@ def plan( """ # Logits soft cap is not supported currently logits_soft_cap = False + batch_size_p = len(last_page_len_p) batch_size_d = len(last_page_len_d) + batch_size = batch_size_p + batch_size_d if logits_soft_cap is None: logits_soft_cap = 0.0 - qo_indptr_host = _get_range_buf(batch_size_d + 1, "cpu") + qo_indptr_host_p = qo_indptr_p.to("cpu", non_blocking=True) + qo_indptr_host_d = _get_range_buf(batch_size_d + 1, "cpu") if self.is_cuda_graph_enabled: - if batch_size_d != self._fixed_batch_size: + if batch_size != self._fixed_batch_size: raise ValueError( "The batch size should be fixed in cudagraph mode, the runtime batch size {} " " mismatches the batch size set during initialization {}".format( batch_size_d, self._fixed_batch_size ) ) - if len(indices_d) > len(self._paged_kv_indices_buf): + if len(kv_indices_d) + len(kv_indices_p) > len(self._paged_kv_indices_buf): raise ValueError( "The size of indices should be less than or equal to the allocated buffer" ) - self._paged_kv_indptr_buf.copy_(indptr_d, non_blocking=non_blocking) - self._paged_kv_last_page_len_buf.copy_( - last_page_len_d, non_blocking=non_blocking + self._paged_kv_indptr_buf[:batch_size_p].copy_( + kv_indptr_p, non_blocking=non_blocking ) - self._paged_kv_indices_buf[: len(indices_d)].copy_( - indices_d, - non_blocking=(indices_d.device == self.device) and non_blocking, + self._paged_kv_indptr_buf[batch_size_p : batch_size_p + batch_size_d].copy_( + kv_indptr_d, + non_blocking=(kv_indptr_d.device == self.device) and non_blocking, + ) + self._paged_kv_last_page_len_buf[:batch_size_p].copy_( + last_page_len_p, non_blocking=non_blocking + ) + self._paged_kv_last_page_len_buf[ + batch_size_p : batch_size_p + batch_size_d + ].copy_( + last_page_len_d, + non_blocking=(last_page_len_d.device == 
self.device) and non_blocking, + ) + self._paged_kv_indices_buf[:batch_size_p].copy_( + kv_indices_d, + non_blocking=(kv_indices_d.device == self.device) and non_blocking, + ) + self._paged_kv_indices_buf[ + batch_size_p : batch_size_p + batch_size_d + ].copy_( + kv_indices_d, + non_blocking=(kv_indices_d.device == self.device) and non_blocking, ) else: - self._paged_kv_indptr_buf = indptr_d.to( - self.device, non_blocking=non_blocking + to_device = lambda x: x.to(self.device, non_blocking=non_blocking) + self._qo_indptr_buf = torch.cat( + [to_device(qo_indptr_p), to_device(qo_indptr_host_d)] ) - self._paged_kv_indices_buf = indices_d.to( - self.device, non_blocking=non_blocking + self._paged_kv_indptr_buf = torch.cat( + [to_device(kv_indptr_p), to_device(kv_indptr_d)] ) - self._paged_kv_last_page_len_buf = last_page_len_d.to( - self.device, non_blocking=non_blocking + self._paged_kv_indices_buf = torch.cat( + [to_device(kv_indices_p), to_device(kv_indices_d)] ) - self._qo_indptr_buf = qo_indptr_host.to( - self.device, non_blocking=non_blocking + self._paged_kv_last_page_len_buf = torch.cat( + [to_device(last_page_len_p), to_device(last_page_len_d)] ) - indptr_host = indptr_d.to("cpu") - last_page_len_host = last_page_len_d.to("cpu") + kv_indptr_host_p = kv_indptr_p.to("cpu", non_blocking=True) + kv_indptr_host_d = kv_indptr_d.to("cpu", non_blocking=True) + last_page_len_host_p = last_page_len_p.to("cpu", non_blocking=True) + last_page_len_host_d = last_page_len_d.to("cpu", non_blocking=True) if data_type is not None: if q_data_type is None: @@ -404,7 +438,13 @@ def plan( self._cached_q_data_type = q_data_type self._cached_kv_data_type = kv_data_type - kv_lens_arr_host = get_seq_lens(indptr_host, last_page_len_host, page_size) + torch.cuda.synchronize() + kv_lens_arr_host_p = get_seq_lens( + kv_indptr_host_p, last_page_len_host_p, page_size + ) + kv_lens_arr_host_d = get_seq_lens( + kv_indptr_host_d, last_page_len_host_d, page_size + ) if self._jit_module is not None: self._cached_module = self._jit_module else: @@ -413,7 +453,7 @@ def plan( q_data_type, kv_data_type, q_data_type, - indptr_d.dtype, + kv_indptr_d.dtype, head_dim, # head_dim_qk head_dim, # head_dim_vo PosEncodingMode[pos_encoding_mode].value, @@ -425,20 +465,25 @@ def plan( self._float_workspace_buffer, self._int_workspace_buffer, self._pin_memory_int_workspace_buffer, - qo_indptr_host, - indptr_host, - kv_lens_arr_host, - batch_size_d, # total_num_rows + qo_indptr_host_p, + kv_indptr_host_p, + kv_lens_arr_host_p, + batch_size_p, + batch_size_p, # total_num_rows_p + qo_indptr_host_d, + kv_indptr_host_d, + kv_lens_arr_host_d, batch_size_d, + batch_size_d, # total_num_rows_d num_qo_heads, num_kv_heads, - page_size, - self.is_cuda_graph_enabled, head_dim, head_dim, + page_size, + self.is_cuda_graph_enabled, ) - self._indptr_type = indptr_d.dtype + self._indptr_type = kv_indptr_d.dtype self._pos_encoding_mode = pos_encoding_mode self._window_left = window_left self._logits_soft_cap = logits_soft_cap @@ -606,7 +651,7 @@ def run( q_d, k_cache_d, v_cache_d, - self._qo_indptr_buf, + self._qo_indptr_buf_d, self._paged_kv_indptr_buf, self._paged_kv_indices_buf, self._paged_kv_last_page_len_buf, diff --git a/flashinfer/prefill.py b/flashinfer/prefill.py index 74a192ae2..985bc2737 100644 --- a/flashinfer/prefill.py +++ b/flashinfer/prefill.py @@ -1428,9 +1428,11 @@ def plan( self._max_item_len_ptr = max_item_len_ptr # NOTE(Zihao): only required if qo_indptr/paged_kv_indptr are device tensors - qo_indptr_host = qo_indptr.to("cpu") - 
paged_kv_indptr_host = paged_kv_indptr.to("cpu") - paged_kv_last_page_len_host = paged_kv_last_page_len.to("cpu") + qo_indptr_host = qo_indptr.to("cpu", non_blocking=True) + paged_kv_indptr_host = paged_kv_indptr.to("cpu", non_blocking=True) + paged_kv_last_page_len_host = paged_kv_last_page_len.to( + "cpu", non_blocking=True + ) kv_lens_arr_host = get_seq_lens( paged_kv_indptr_host, paged_kv_last_page_len_host, page_size ) @@ -1438,6 +1440,7 @@ def plan( kv_lens_arr_host, non_blocking=non_blocking ) + torch.cuda.synchronize() total_num_rows = qo_indptr_host[-1] if self.is_cuda_graph_enabled: diff --git a/include/flashinfer/attention/scheduler.cuh b/include/flashinfer/attention/scheduler.cuh index e07c68038..73eb4f97c 100644 --- a/include/flashinfer/attention/scheduler.cuh +++ b/include/flashinfer/attention/scheduler.cuh @@ -1070,13 +1070,16 @@ inline cudaError_t PODPlan(void* float_buffer, size_t float_workspace_size_in_by GetPtrFromBaseOffset(page_locked_int_buffer, plan_info.kv_tile_indices_offset); IdType* o_indptr_h = GetPtrFromBaseOffset(page_locked_int_buffer, plan_info.o_indptr_offset); - IdType* kv_chunk_size_ptr_h = - GetPtrFromBaseOffset(page_locked_int_buffer, plan_info.kv_chunk_size_ptr_offset); + IdType* kv_chunk_size_ptr_p = + GetPtrFromBaseOffset(page_locked_int_buffer, plan_info.kv_chunk_size_ptr_offset_p); + IdType* kv_chunk_size_ptr_d = + GetPtrFromBaseOffset(page_locked_int_buffer, plan_info.kv_chunk_size_ptr_offset_d); std::copy(request_indices_vec.begin(), request_indices_vec.end(), request_indices_h); std::copy(qo_tile_indices_vec.begin(), qo_tile_indices_vec.end(), qo_tile_indices_h); std::copy(kv_tile_indices_vec.begin(), kv_tile_indices_vec.end(), kv_tile_indices_h); std::copy(o_indptr_vec.begin(), o_indptr_vec.end(), o_indptr_h); - kv_chunk_size_ptr_h[0] = kv_chunk_size; + kv_chunk_size_ptr_p[0] = kv_chunk_size_p; + kv_chunk_size_ptr_d[0] = kv_chunk_size_d; if (split_kv) { AlignedAllocator float_allocator(float_buffer, float_workspace_size_in_bytes); From eb8f7196db93b9698ab3b4d127855b9521b8b791 Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Tue, 8 Jul 2025 21:21:14 +0000 Subject: [PATCH 20/33] fixes --- csrc/pod.cu | 28 ---------------------------- flashinfer/pod.py | 12 +++++++----- include/flashinfer/attention/pod.cuh | 2 +- 3 files changed, 8 insertions(+), 34 deletions(-) diff --git a/csrc/pod.cu b/csrc/pod.cu index 1995d9287..c87b090d8 100644 --- a/csrc/pod.cu +++ b/csrc/pod.cu @@ -292,31 +292,3 @@ void PODWithKVCacheTensorRun( //}); }); } - -at::Tensor PODWithKVCachePlan(at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, - at::Tensor page_locked_int_workspace_buffer, at::Tensor qo_indptr_p, - at::Tensor kv_indptr_p, at::Tensor kv_len_arr, int64_t total_num_rows, - int64_t batch_size, int64_t num_qo_heads, int64_t num_kv_heads, - int64_t page_size, bool enable_cuda_graph, int64_t head_dim_qk, - int64_t head_dim_vo) { - size_t float_workspace_size_in_bytes = - float_workspace_buffer.size(0) * float_workspace_buffer.element_size(); - size_t int_workspace_size_in_bytes = - int_workspace_buffer.size(0) * int_workspace_buffer.element_size(); - - PODPlanInfo plan_info; - - const c10::cuda::OptionalCUDAGuard device_guard(float_workspace_buffer.device()); - const cudaStream_t stream = c10::cuda::getCurrentCUDAStream(); - cudaError_t status = PrefillPlan( - float_workspace_buffer.data_ptr(), float_workspace_size_in_bytes, - int_workspace_buffer.data_ptr(), page_locked_int_workspace_buffer.data_ptr(), - int_workspace_size_in_bytes, plan_info, 
qo_indptr.data_ptr(), - kv_indptr.data_ptr(), total_num_rows, batch_size, num_qo_heads, num_kv_heads, - head_dim_qk, head_dim_vo, page_size, enable_cuda_graph, /*sizeof_dtype_o=*/2, stream); - - TORCH_CHECK(status == cudaSuccess, - "Failed to plan prefill with error: ", cudaGetErrorString(status)); - - return vec_to_tensor(plan_info.ToVector()); -} diff --git a/flashinfer/pod.py b/flashinfer/pod.py index 46ef4d01e..42ffab93e 100644 --- a/flashinfer/pod.py +++ b/flashinfer/pod.py @@ -379,10 +379,12 @@ def plan( raise ValueError( "The size of indices should be less than or equal to the allocated buffer" ) - self._paged_kv_indptr_buf[:batch_size_p].copy_( + self._paged_kv_indptr_buf[: batch_size_p + 1].copy_( kv_indptr_p, non_blocking=non_blocking ) - self._paged_kv_indptr_buf[batch_size_p : batch_size_p + batch_size_d].copy_( + self._paged_kv_indptr_buf[ + batch_size_p + 1 : batch_size_p + batch_size_d + 2 + ].copy_( kv_indptr_d, non_blocking=(kv_indptr_d.device == self.device) and non_blocking, ) @@ -395,12 +397,12 @@ def plan( last_page_len_d, non_blocking=(last_page_len_d.device == self.device) and non_blocking, ) - self._paged_kv_indices_buf[:batch_size_p].copy_( + self._paged_kv_indices_buf[: batch_size_p + 1].copy_( kv_indices_d, non_blocking=(kv_indices_d.device == self.device) and non_blocking, ) self._paged_kv_indices_buf[ - batch_size_p : batch_size_p + batch_size_d + batch_size_p + 1 : batch_size_p + batch_size_d + 2 ].copy_( kv_indices_d, non_blocking=(kv_indices_d.device == self.device) and non_blocking, @@ -651,7 +653,7 @@ def run( q_d, k_cache_d, v_cache_d, - self._qo_indptr_buf_d, + self._qo_indptr_buf, self._paged_kv_indptr_buf, self._paged_kv_indices_buf, self._paged_kv_last_page_len_buf, diff --git a/include/flashinfer/attention/pod.cuh b/include/flashinfer/attention/pod.cuh index d1a8f570c..681cfaec5 100644 --- a/include/flashinfer/attention/pod.cuh +++ b/include/flashinfer/attention/pod.cuh @@ -396,7 +396,7 @@ cudaError_t PODWithKVCacheTensorDispatched(PrefillParams prefill_params, // ************************************************ / static int* tbAssign = nullptr; - cudaMalloc(&tbAssign, sizeof(int) * (num_sm + 2)); + if (tbAssign == nullptr) cudaMalloc(&tbAssign, sizeof(int) * (num_sm + 2)); cudaMemset(tbAssign, 0, sizeof(int) * (num_sm + 2)); // Setup kernel arguments From e8b266db2f7e6edb3f7bb5abd3f825d5a6408208 Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Tue, 8 Jul 2025 22:19:16 +0000 Subject: [PATCH 21/33] fix params --- include/flashinfer/attention/scheduler.cuh | 51 +++++++++++----------- 1 file changed, 26 insertions(+), 25 deletions(-) diff --git a/include/flashinfer/attention/scheduler.cuh b/include/flashinfer/attention/scheduler.cuh index 73eb4f97c..4e2baa301 100644 --- a/include/flashinfer/attention/scheduler.cuh +++ b/include/flashinfer/attention/scheduler.cuh @@ -915,7 +915,8 @@ struct PODPlanInfo { int64_t padded_batch_size; int64_t total_num_rows; int64_t total_num_rows_offset; - int64_t cta_tile_q; + uint16_t cta_tile_q_p; + uint16_t cta_tile_q_d; int64_t request_indices_offset; int64_t qo_tile_indices_offset; int64_t kv_tile_indices_offset; @@ -933,7 +934,8 @@ struct PODPlanInfo { : padded_batch_size(0), total_num_rows(0), total_num_rows_offset(0), - cta_tile_q(0), + cta_tile_q_p(0), + cta_tile_q_d(0), request_indices_offset(0), qo_tile_indices_offset(0), kv_tile_indices_offset(0), @@ -951,7 +953,8 @@ struct PODPlanInfo { return {padded_batch_size, total_num_rows, total_num_rows_offset, - cta_tile_q, + cta_tile_q_p, + cta_tile_q_d, 
request_indices_offset, qo_tile_indices_offset, kv_tile_indices_offset, @@ -976,24 +979,20 @@ struct PODPlanInfo { padded_batch_size = vec[0]; total_num_rows = vec[1]; total_num_rows_offset = vec[2]; - cta_tile_q = vec[3]; - request_indices_offset_p = vec[4]; - request_indices_offset_d = vec[5]; - qo_tile_indices_offset_p = vec[6]; - qo_tile_indices_offset_d = vec[7]; - kv_tile_indices_offset_p = vec[8]; - kv_tile_indices_offset_d = vec[9]; - merge_indptr_offset = vec[7]; - o_indptr_offset = vec[8]; - kv_chunk_size_ptr_offset_p = vec[9]; - kv_chunk_size_ptr_offset_d = vec[10]; - v_offset_p = vec[11]; - v_offset_d = vec[12]; - s_offset_p = vec[13]; - s_offset_d = vec[13]; + cta_tile_q_p = vec[3]; + cta_tile_q_d = vec[4]; + request_indices_offset = vec[5]; + qo_tile_indices_offset = vec[6]; + kv_tile_indices_offset = vec[7]; + merge_indptr_offset = vec[8]; + o_indptr_offset = vec[9]; + kv_chunk_size_ptr_offset_p = vec[10]; + kv_chunk_size_ptr_offset_d = vec[11]; + v_offset = vec[12]; + s_offset = vec[13]; block_valid_mask_offset = vec[14]; - enable_cuda_graph = vec[13]; - split_kv = vec[14]; + enable_cuda_graph = vec[15]; + split_kv = vec[16]; } }; @@ -1025,11 +1024,12 @@ inline cudaError_t PODPlan(void* float_buffer, size_t float_workspace_size_in_by // step 2: determine kv_chunk_size auto [split_kv, real_batch_size, padded_batch_size_p, padded_batch_size_d, cta_tile_q_p, - cta_tile_q_d, kv_chunk_size_p, kv_chunk_size_d, request_indices, qo_tile_indices, - kv_tile_indices, merge_indptr, o_indptr] = + cta_tile_q_d, kv_chunk_size_p, kv_chunk_size_d, request_indices_vec, qo_tile_indices_vec, + kv_tile_indices_vec, merge_indptr_vec, o_indptr_vec] = PODSplitQOKVIndptr(qo_indptr_p, kv_indptr_p, qo_indptr_d, kv_indptr_d, total_num_rows_p, batch_size_p, total_num_rows_d, batch_size_d, num_qo_heads, num_kv_heads, head_dim_vo, page_size, max_batch_size_if_split, enable_cuda_graph); + uint32_t padded_batch_size = padded_batch_size_p + padded_batch_size_d; plan_info.cta_tile_q_p = cta_tile_q_p; plan_info.cta_tile_q_d = cta_tile_q_d; @@ -1082,12 +1082,13 @@ inline cudaError_t PODPlan(void* float_buffer, size_t float_workspace_size_in_by kv_chunk_size_ptr_d[0] = kv_chunk_size_d; if (split_kv) { + uint32_t num_outputs_p = num_qo_heads * padded_batch_size_p * cta_tile_q_p * head_dim_vo; + uint32_t num_outputs_d = num_qo_heads * padded_batch_size_d * cta_tile_q_d * head_dim_vo; AlignedAllocator float_allocator(float_buffer, float_workspace_size_in_bytes); plan_info.v_offset = float_allocator.aligned_alloc_offset( - num_qo_heads * padded_batch_size * cta_tile_q * head_dim_vo * sizeof(float), 16, - "pod_tmp_v"); + (num_outputs_p + num_outputs_d) * sizeof(float), 16, "pod_tmp_v"); plan_info.s_offset = float_allocator.aligned_alloc_offset( - num_qo_heads * padded_batch_size * cta_tile_q * sizeof(float), 16, "pod_tmp_s"); + (num_outputs_p + num_outputs_d) * sizeof(float), 16, "pod_tmp_s"); plan_info.merge_indptr_offset = int_allocator.aligned_alloc_offset( sizeof(IdType) * (plan_info.total_num_rows + 1), 16, "pod_merge_indptr"); plan_info.block_valid_mask_offset = int_allocator.aligned_alloc_offset( From 560918b25931fae6254494c80b4d21ef9e1072ba Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Tue, 8 Jul 2025 23:44:55 +0000 Subject: [PATCH 22/33] fix some indices and params --- csrc/pod.cu | 4 +- flashinfer/pod.py | 12 ++-- include/flashinfer/attention/pod.cuh | 10 +-- include/flashinfer/attention/scheduler.cuh | 73 ++++++++++++++-------- 4 files changed, 59 insertions(+), 40 deletions(-) diff --git a/csrc/pod.cu 
b/csrc/pod.cu index c87b090d8..950c01c31 100644 --- a/csrc/pod.cu +++ b/csrc/pod.cu @@ -196,7 +196,7 @@ void PODWithKVCacheTensorRun( params.window_left = window_left_p; params.partition_kv = false; - + params.padded_batch_size = plan_info.padded_batch_size_p; params.maybe_custom_mask = maybe_custom_mask_p ? static_cast(maybe_custom_mask_p->data_ptr()) : nullptr; @@ -251,7 +251,7 @@ void PODWithKVCacheTensorRun( GetPtrFromBaseOffset(int_buffer_ptr, plan_info.block_valid_mask_offset); } } - params.padded_batch_size = plan_info.padded_batch_size; + params.padded_batch_size = plan_info.padded_batch_size_d; params.max_total_num_rows = plan_info.total_num_rows; params.partition_kv = false; diff --git a/flashinfer/pod.py b/flashinfer/pod.py index 42ffab93e..68b17caa7 100644 --- a/flashinfer/pod.py +++ b/flashinfer/pod.py @@ -397,12 +397,12 @@ def plan( last_page_len_d, non_blocking=(last_page_len_d.device == self.device) and non_blocking, ) - self._paged_kv_indices_buf[: batch_size_p + 1].copy_( - kv_indices_d, - non_blocking=(kv_indices_d.device == self.device) and non_blocking, + self._paged_kv_indices_buf[: len(kv_indices_p)].copy_( + kv_indices_p, + non_blocking=(kv_indices_p.device == self.device) and non_blocking, ) self._paged_kv_indices_buf[ - batch_size_p + 1 : batch_size_p + batch_size_d + 2 + len(kv_indices_p) : len(kv_indices_p) + len(kv_indices_d) ].copy_( kv_indices_d, non_blocking=(kv_indices_d.device == self.device) and non_blocking, @@ -471,12 +471,12 @@ def plan( kv_indptr_host_p, kv_lens_arr_host_p, batch_size_p, - batch_size_p, # total_num_rows_p + qo_indptr_host_p[-1], # total_num_rows_p qo_indptr_host_d, kv_indptr_host_d, kv_lens_arr_host_d, batch_size_d, - batch_size_d, # total_num_rows_d + qo_indptr_host_d[-1], # total_num_rows_d num_qo_heads, num_kv_heads, head_dim, diff --git a/include/flashinfer/attention/pod.cuh b/include/flashinfer/attention/pod.cuh index 681cfaec5..9826b7bfc 100644 --- a/include/flashinfer/attention/pod.cuh +++ b/include/flashinfer/attention/pod.cuh @@ -375,13 +375,13 @@ cudaError_t PODWithKVCacheTensorDispatched(PrefillParams prefill_params, decode_params.o = tmp_v; decode_params.lse = tmp_s; } - uint32_t num_qo_tiles = ceil_div(qo_len * group_size, KTraits_P::CTA_TILE_Q); - int nblks_p(num_qo_tiles * - (prefill_params.partition_kv ? prefill_params.partition_kv : 1) * - num_kv_heads); + // uint32_t num_qo_tiles = ceil_div(qo_len * group_size, KTraits_P::CTA_TILE_Q); + uint32_t padded_batch_size_p = prefill_params.padded_batch_size; + uint32_t padded_batch_size_d = decode_params.padded_batch_size; + int nblks_p(padded_batch_size_p * num_kv_heads); int nthrs_p(32 * NUM_WARPS_Q_P * NUM_WARPS_KV_P); - int nblks_d(padded_batch_size_d * 1 * num_kv_heads); + int nblks_d(padded_batch_size_d * num_kv_heads); int nthrs_d(32 * NUM_WARPS_Q_D * NUM_WARPS_KV_D); // ******* Select final combined sizes here ******* / diff --git a/include/flashinfer/attention/scheduler.cuh b/include/flashinfer/attention/scheduler.cuh index 4e2baa301..380fa3fa9 100644 --- a/include/flashinfer/attention/scheduler.cuh +++ b/include/flashinfer/attention/scheduler.cuh @@ -58,7 +58,7 @@ inline void CopyToPageLockedBuffer(void* page_locked_int_buffer, int64_t offset, } /*! - * \brief Compute the maximum number of pages per batch and the new batch size + * \brief Compute the maximum number of pages per batch and the new batch size (grid dim x) * after we partition Paged KV-Cache into multiple chunks on KV sequence length * dimension. 
* \tparam IdType A template type indicates the index data type @@ -737,7 +737,7 @@ template inline cudaError_t PrefillPlan(void* float_buffer, size_t float_workspace_size_in_bytes, void* int_buffer, void* page_locked_int_buffer, size_t int_workspace_size_in_bytes, PrefillPlanInfo& plan_info, - IdType* qo_indptr_p, IdType* kv_indptr_p, uint32_t total_num_rows, + IdType* qo_indptr_h, IdType* kv_indptr_h, uint32_t total_num_rows, uint32_t batch_size, uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t head_dim_qk, uint32_t head_dim_vo, uint32_t page_size, bool enable_cuda_graph, uint32_t sizeof_dtype_o, @@ -761,7 +761,7 @@ inline cudaError_t PrefillPlan(void* float_buffer, size_t float_workspace_size_i // step 2: determine kv_chunk_size auto [split_kv, real_batch_size, padded_batch_size, cta_tile_q, kv_chunk_size, request_indices_vec, qo_tile_indices_vec, kv_tile_indices_vec, merge_indptr_vec, - o_indptr_vec] = PrefillSplitQOKVIndptr(qo_indptr_p, kv_indptr_p, total_num_rows, batch_size, + o_indptr_vec] = PrefillSplitQOKVIndptr(qo_indptr_h, kv_indptr_h, total_num_rows, batch_size, num_qo_heads, num_kv_heads, head_dim_vo, page_size, max_batch_size_if_split, enable_cuda_graph); @@ -912,8 +912,11 @@ inline auto PODSplitQOKVIndptr(IdType* qo_indptr_p, IdType* kv_indptr_p, uint32_ } struct PODPlanInfo { - int64_t padded_batch_size; + int64_t padded_batch_size_p; + int64_t padded_batch_size_d; int64_t total_num_rows; + int64_t total_num_rows_p; + int64_t total_num_rows_d; int64_t total_num_rows_offset; uint16_t cta_tile_q_p; uint16_t cta_tile_q_d; @@ -931,8 +934,11 @@ struct PODPlanInfo { bool split_kv; PODPlanInfo() - : padded_batch_size(0), + : padded_batch_size_p(0), + padded_batch_size_d(0), total_num_rows(0), + total_num_rows_p(0), + total_num_rows_d(0), total_num_rows_offset(0), cta_tile_q_p(0), cta_tile_q_d(0), @@ -941,7 +947,8 @@ struct PODPlanInfo { kv_tile_indices_offset(0), merge_indptr_offset(0), o_indptr_offset(0), - kv_chunk_size_ptr_offset(0), + kv_chunk_size_ptr_offset_p(0), + kv_chunk_size_ptr_offset_d(0), v_offset(0), s_offset(0), block_valid_mask_offset(0), @@ -950,8 +957,11 @@ struct PODPlanInfo { // convert PrefillPlanInfo to std::vector std::vector ToVector() const { - return {padded_batch_size, + return {padded_batch_size_p, + padded_batch_size_d, total_num_rows, + total_num_rows_p, + total_num_rows_d, total_num_rows_offset, cta_tile_q_p, cta_tile_q_d, @@ -971,28 +981,31 @@ struct PODPlanInfo { // From std::vector to PodPlanInfo void FromVector(const std::vector& vec) { - if (vec.size() != 15) { + if (vec.size() != 19) { std::ostringstream err_msg; - err_msg << "PodPlanInfo::FromVector: vec.size() should be 15, but got " << vec.size(); + err_msg << "PodPlanInfo::FromVector: vec.size() should be 19, but got " << vec.size(); FLASHINFER_ERROR(err_msg.str()); } - padded_batch_size = vec[0]; - total_num_rows = vec[1]; - total_num_rows_offset = vec[2]; - cta_tile_q_p = vec[3]; - cta_tile_q_d = vec[4]; - request_indices_offset = vec[5]; - qo_tile_indices_offset = vec[6]; - kv_tile_indices_offset = vec[7]; - merge_indptr_offset = vec[8]; - o_indptr_offset = vec[9]; - kv_chunk_size_ptr_offset_p = vec[10]; - kv_chunk_size_ptr_offset_d = vec[11]; - v_offset = vec[12]; - s_offset = vec[13]; - block_valid_mask_offset = vec[14]; - enable_cuda_graph = vec[15]; - split_kv = vec[16]; + padded_batch_size_p = vec[0]; + padded_batch_size_d = vec[1]; + total_num_rows = vec[2]; + total_num_rows_p = vec[3]; + total_num_rows_d = vec[4]; + total_num_rows_offset = vec[4]; + cta_tile_q_p = vec[5]; + 
cta_tile_q_d = vec[6]; + request_indices_offset = vec[7]; + qo_tile_indices_offset = vec[8]; + kv_tile_indices_offset = vec[9]; + merge_indptr_offset = vec[10]; + o_indptr_offset = vec[11]; + kv_chunk_size_ptr_offset_p = vec[12]; + kv_chunk_size_ptr_offset_d = vec[13]; + v_offset = vec[14]; + s_offset = vec[15]; + block_valid_mask_offset = vec[16]; + enable_cuda_graph = vec[17]; + split_kv = vec[18]; } }; @@ -1030,7 +1043,12 @@ inline cudaError_t PODPlan(void* float_buffer, size_t float_workspace_size_in_by batch_size_p, total_num_rows_d, batch_size_d, num_qo_heads, num_kv_heads, head_dim_vo, page_size, max_batch_size_if_split, enable_cuda_graph); uint32_t padded_batch_size = padded_batch_size_p + padded_batch_size_d; + uint32_t batch_size = batch_size_p + batch_size_d; + uint32_t total_num_rows = total_num_rows_p + total_num_rows_d; + plan_info.padded_batch_size_p = padded_batch_size_p; + plan_info.padded_batch_size_d = padded_batch_size_d; + plan_info.total_num_rows = total_num_rows; plan_info.cta_tile_q_p = cta_tile_q_p; plan_info.cta_tile_q_d = cta_tile_q_d; plan_info.total_num_rows_p = total_num_rows_p; @@ -1059,7 +1077,7 @@ inline cudaError_t PODPlan(void* float_buffer, size_t float_workspace_size_in_by int_allocator.aligned_alloc_offset(sizeof(uint32_t), 16, "batch_prefill_total_num_rows"); uint32_t* total_num_rows_h = GetPtrFromBaseOffset(page_locked_int_buffer, plan_info.total_num_rows_offset); - *total_num_rows_h = qo_indptr_h[batch_size]; + *total_num_rows_h = total_num_rows_p + total_num_rows_d; } IdType* request_indices_h = @@ -1082,6 +1100,7 @@ inline cudaError_t PODPlan(void* float_buffer, size_t float_workspace_size_in_by kv_chunk_size_ptr_d[0] = kv_chunk_size_d; if (split_kv) { + // TODO(Wenxuan): write through for non-split-kv requests uint32_t num_outputs_p = num_qo_heads * padded_batch_size_p * cta_tile_q_p * head_dim_vo; uint32_t num_outputs_d = num_qo_heads * padded_batch_size_d * cta_tile_q_d * head_dim_vo; AlignedAllocator float_allocator(float_buffer, float_workspace_size_in_bytes); From 2105101395f430f368c70b01a0c64b762bda0c33 Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Wed, 9 Jul 2025 05:23:23 +0000 Subject: [PATCH 23/33] update PODWithKVCacheTensorRun args --- csrc/flashinfer_ops.cu | 22 ++++---- csrc/pod.cu | 52 +++++++++--------- csrc/pod_customize_config.jinja | 2 +- csrc/pod_jit_pybind.cu | 22 ++++---- flashinfer/pod.py | 43 ++++++++------- flashinfer/prefill.py | 2 +- include/flashinfer/attention/scheduler.cuh | 62 ++++++++++------------ 7 files changed, 107 insertions(+), 98 deletions(-) diff --git a/csrc/flashinfer_ops.cu b/csrc/flashinfer_ops.cu index 37b776e3c..e305bfade 100644 --- a/csrc/flashinfer_ops.cu +++ b/csrc/flashinfer_ops.cu @@ -124,22 +124,24 @@ void BatchPrefillWithPagedKVCacheRun( //========== pod-attention ========= void PODWithKVCacheTensorRun( + // Shared params + at::Tensor float_workspace_buffer_d, at::Tensor int_workspace_buffer_d, + at::Tensor plan_info_vec, at::Tensor qo_indptr, at::Tensor paged_kv_indptr, + at::Tensor paged_kv_indices, at::Tensor paged_kv_last_page_len, at::Tensor o, // Prefill params - at::Tensor q_p, at::Tensor k_p, at::Tensor v_p, at::Tensor tmp_p, at::Tensor o_p, + at::Tensor q_p, at::Tensor paged_k_p, at::Tensor paged_v_p, std::optional maybe_lse_p, int64_t mask_mode_code_p, int64_t layout_p, int64_t window_left_p, std::optional maybe_custom_mask_p, std::optional maybe_alibi_slopes_p, double logits_soft_cap_p, double sm_scale_p, double rope_rcp_scale_p, double rope_rcp_theta_p, // Decode params - at::Tensor 
float_workspace_buffer_d, at::Tensor int_workspace_buffer_d, - at::Tensor plan_info_vec, at::Tensor q_d, at::Tensor paged_k_cache_d, - at::Tensor paged_v_cache_d, at::Tensor qo_indptr_d, at::Tensor paged_kv_indptr_d, - at::Tensor paged_kv_indices_d, at::Tensor paged_kv_last_page_len_d, at::Tensor o_d, - std::optional maybe_lse_d, int64_t mask_mode_code_d, int64_t layout_d, - int64_t window_left, std::optional maybe_custom_mask_d, - std::optional maybe_mask_indptr_d, std::optional maybe_alibi_slopes_d, - double logits_soft_cap_d, double sm_scale_d, double rope_rcp_scale_d, double rope_rcp_theta_d, - bool enable_pdl); + at::Tensor q_d, at::Tensor paged_k_cache_d, at::Tensor paged_v_cache_d, at::Tensor qo_indptr_d, + at::Tensor paged_kv_indptr_d, at::Tensor paged_kv_indices_d, + at::Tensor paged_kv_last_page_len_d, std::optional maybe_lse_d, + int64_t mask_mode_code_d, int64_t layout_d, int64_t window_left_d, + std::optional maybe_custom_mask_d, std::optional maybe_mask_indptr_d, + std::optional maybe_alibi_slopes_d, double logits_soft_cap_d, double sm_scale_d, + double rope_rcp_scale_d, double rope_rcp_theta_d, bool enable_pdl); //========== quantization ========== void packbits(at::Tensor x, const std::string& bitorder, at::Tensor y); diff --git a/csrc/pod.cu b/csrc/pod.cu index 950c01c31..cb6926f3a 100644 --- a/csrc/pod.cu +++ b/csrc/pod.cu @@ -69,22 +69,24 @@ at::Tensor PODWithKVCachePlan(at::Tensor float_workspace_buffer, at::Tensor int_ } void PODWithKVCacheTensorRun( + // Shared params + at::Tensor float_workspace_buffer_d, at::Tensor int_workspace_buffer_d, + at::Tensor plan_info_vec, at::Tensor qo_indptr, at::Tensor paged_kv_indptr, + at::Tensor paged_kv_indices, at::Tensor paged_kv_last_page_len, at::Tensor o, // Prefill params - at::Tensor q_p, at::Tensor k_p, at::Tensor v_p, at::Tensor tmp_p, at::Tensor o_p, + at::Tensor q_p, at::Tensor paged_k_p, at::Tensor paged_v_p, std::optional maybe_lse_p, int64_t mask_mode_code_p, int64_t layout_p, int64_t window_left_p, std::optional maybe_custom_mask_p, std::optional maybe_alibi_slopes_p, double logits_soft_cap_p, double sm_scale_p, double rope_rcp_scale_p, double rope_rcp_theta_p, // Decode params - at::Tensor float_workspace_buffer_d, at::Tensor int_workspace_buffer_d, - at::Tensor plan_info_vec, at::Tensor q_d, at::Tensor paged_k_cache_d, - at::Tensor paged_v_cache_d, at::Tensor qo_indptr_d, at::Tensor paged_kv_indptr_d, - at::Tensor paged_kv_indices_d, at::Tensor paged_kv_last_page_len_d, at::Tensor o_d, - std::optional maybe_lse_d, int64_t mask_mode_code_d, int64_t layout_d, - int64_t window_left_d, std::optional maybe_custom_mask_d, - std::optional maybe_mask_indptr_d, std::optional maybe_alibi_slopes_d, - double logits_soft_cap_d, double sm_scale_d, double rope_rcp_scale_d, double rope_rcp_theta_d, - bool enable_pdl) { + at::Tensor q_d, at::Tensor paged_k_cache_d, at::Tensor paged_v_cache_d, at::Tensor qo_indptr_d, + at::Tensor paged_kv_indptr_d, at::Tensor paged_kv_indices_d, + at::Tensor paged_kv_last_page_len_d, std::optional maybe_lse_d, + int64_t mask_mode_code_d, int64_t layout_d, int64_t window_left_d, + std::optional maybe_custom_mask_d, std::optional maybe_mask_indptr_d, + std::optional maybe_alibi_slopes_d, double logits_soft_cap_d, double sm_scale_d, + double rope_rcp_scale_d, double rope_rcp_theta_d, bool enable_pdl) { // Prefill setup unsigned int head_dim_qk = q_p.size(2); unsigned int kv_len_p, qo_len_p, num_kv_heads, num_qo_heads; @@ -94,19 +96,19 @@ void PODWithKVCacheTensorRun( uint32_t q_stride_n_p = 
q_p.stride(0), q_stride_h_p = q_p.stride(1), k_stride_n_p, k_stride_h_p, v_stride_n_p, v_stride_h_p; if (kv_layout_p == QKVLayout::kNHD) { - kv_len_p = k_p.size(0); - num_kv_heads = k_p.size(1); - k_stride_n_p = k_p.stride(0); - k_stride_h_p = k_p.stride(1); - v_stride_n_p = v_p.stride(0); - v_stride_h_p = v_p.stride(1); + kv_len_p = paged_k_p.size(0); + num_kv_heads = paged_k_p.size(1); + k_stride_n_p = paged_k_p.stride(0); + k_stride_h_p = paged_k_p.stride(1); + v_stride_n_p = paged_v_p.stride(0); + v_stride_h_p = paged_v_p.stride(1); } else { - kv_len_p = k_p.size(1); - num_kv_heads = k_p.size(0); - k_stride_h_p = k_p.stride(0); - k_stride_n_p = k_p.stride(1); - v_stride_h_p = v_p.stride(0); - v_stride_n_p = v_p.stride(1); + kv_len_p = paged_k_p.size(1); + num_kv_heads = paged_k_p.size(0); + k_stride_h_p = paged_k_p.stride(0); + k_stride_n_p = paged_k_p.stride(1); + v_stride_h_p = paged_v_p.stride(0); + v_stride_n_p = paged_v_p.stride(1); } if (maybe_lse_p) { const auto& lse = *maybe_lse_p; @@ -117,7 +119,7 @@ void PODWithKVCacheTensorRun( const MaskMode mask_mode_p = static_cast(mask_mode_code_p); auto q_scalar_type = q_p.scalar_type(); - auto kv_scalar_type = k_p.scalar_type(); + auto kv_scalar_type = paged_k_p.scalar_type(); // Decode setup (Tensor decode = batched prefill) PODPlanInfo plan_info; @@ -178,8 +180,8 @@ void PODWithKVCacheTensorRun( // Make params a reference to prefill_params to set values PrefillParams& params = prefill_params; params.q = static_cast(q_p.data_ptr()); - params.k = static_cast(k_p.data_ptr()); - params.v = static_cast(v_p.data_ptr()); + params.k = static_cast(paged_k_p.data_ptr()); + params.v = static_cast(paged_v_p.data_ptr()); params.o = static_cast(o_p.data_ptr()); params.lse = maybe_lse_p ? static_cast(maybe_lse_p->data_ptr()) : nullptr; params.num_qo_heads = num_qo_heads; diff --git a/csrc/pod_customize_config.jinja b/csrc/pod_customize_config.jinja index b4c56a0e8..ed66afb4b 100644 --- a/csrc/pod_customize_config.jinja +++ b/csrc/pod_customize_config.jinja @@ -30,7 +30,7 @@ constexpr auto USE_SLIDING_WINDOW_D = {{ use_sliding_window_d }}; constexpr auto POS_ENCODING_MODE = PosEncodingMode::kNone; constexpr bool USE_LOGITS_SOFT_CAP = false; -using PrefillParams = SinglePrefillParams; +using PrefillParams = BatchPrefillPagedParams; using DecodeParams = BatchPrefillPagedParams; #define DISPATCH_context(MASK_MODE_P, MASK_MODE_D, DTypeQ, DTypeKV, HEAD_DIM_QK, \ diff --git a/csrc/pod_jit_pybind.cu b/csrc/pod_jit_pybind.cu index 66561a5af..d9d71c6a5 100644 --- a/csrc/pod_jit_pybind.cu +++ b/csrc/pod_jit_pybind.cu @@ -17,22 +17,24 @@ #include "pytorch_extension_utils.h" void PODWithKVCacheTensorRun( + // Shared params + at::Tensor float_workspace_buffer_d, at::Tensor int_workspace_buffer_d, + at::Tensor plan_info_vec, at::Tensor qo_indptr, at::Tensor paged_kv_indptr, + at::Tensor paged_kv_indices, at::Tensor paged_kv_last_page_len, at::Tensor o, // Prefill params - at::Tensor q_p, at::Tensor k_p, at::Tensor v_p, at::Tensor tmp_p, at::Tensor o_p, + at::Tensor q_p, at::Tensor paged_k_p, at::Tensor paged_v_p, std::optional maybe_lse_p, int64_t mask_mode_code_p, int64_t layout_p, int64_t window_left_p, std::optional maybe_custom_mask_p, std::optional maybe_alibi_slopes_p, double logits_soft_cap_p, double sm_scale_p, double rope_rcp_scale_p, double rope_rcp_theta_p, // Decode params - at::Tensor float_workspace_buffer_d, at::Tensor int_workspace_buffer_d, - at::Tensor plan_info_vec, at::Tensor q_d, at::Tensor paged_k_cache_d, - at::Tensor paged_v_cache_d, 
at::Tensor qo_indptr_d, at::Tensor paged_kv_indptr_d, - at::Tensor paged_kv_indices_d, at::Tensor paged_kv_last_page_len_d, at::Tensor o_d, - std::optional maybe_lse_d, int64_t mask_mode_code_d, int64_t layout_d, - int64_t window_left_d, std::optional maybe_custom_mask_d, - std::optional maybe_mask_indptr_d, std::optional maybe_alibi_slopes_d, - double logits_soft_cap_d, double sm_scale_d, double rope_rcp_scale_d, double rope_rcp_theta_d, - bool enable_pdl); + at::Tensor q_d, at::Tensor paged_k_cache_d, at::Tensor paged_v_cache_d, at::Tensor qo_indptr_d, + at::Tensor paged_kv_indptr_d, at::Tensor paged_kv_indices_d, + at::Tensor paged_kv_last_page_len_d, std::optional maybe_lse_d, + int64_t mask_mode_code_d, int64_t layout_d, int64_t window_left_d, + std::optional maybe_custom_mask_d, std::optional maybe_mask_indptr_d, + std::optional maybe_alibi_slopes_d, double logits_soft_cap_d, double sm_scale_d, + double rope_rcp_scale_d, double rope_rcp_theta_d, bool enable_pdl); TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) { // Batch-request prefill attention with KV-Cache operator diff --git a/flashinfer/pod.py b/flashinfer/pod.py index 68b17caa7..a395d1970 100644 --- a/flashinfer/pod.py +++ b/flashinfer/pod.py @@ -310,11 +310,14 @@ def plan( kv_indptr_p: torch.Tensor The indptr of the paged kv cache for prefill, shape: ``[batch_size + 1]``. kv_indices_p: torch.Tensor - The page indices of the paged kv cache for prefill, shape: ``[qo_indptr[-1]]``. + The page indices of the paged kv cache for prefill, shape: ``[kv_indptr[-1]]``. + last_page_len_p : torch.Tensor + The number of entries in the last page of each request in the kv + cache, shape: ``[batch_size]`` kv_indptr_d : torch.Tensor The indptr of the paged kv cache for decode, shape: ``[batch_size + 1]`` kv_indices_d : torch.Tensor - The page indices of the paged kv cache for decode, shape: ``[qo_indptr[-1]]`` + The page indices of the paged kv cache for decode, shape: ``[kv_indptr[-1]]`` last_page_len_d : torch.Tensor The number of entries in the last page of each request in the kv cache, shape: ``[batch_size]`` @@ -542,7 +545,6 @@ def run( # Prefill setup _check_pos_encoding_mode(pos_encoding_mode_p) _check_kv_layout(kv_layout_p) - tmp_p = _get_cache_buf("pod_with_kv_cache_tmp", 32 * 1024 * 1024, q_p.device) if logits_soft_cap_p is None: logits_soft_cap_p = 0.0 if sm_scale_p is None: @@ -570,8 +572,15 @@ def run( lse_p = torch.empty( (q_p.size(0), q_p.size(1)), dtype=torch.float32, device=q_p.device ) - - out_p = torch.empty_like(q_p) + qo_len_p, num_qo_heads, head_dim = q_p.shape + qo_len_d, _, _ = q_d.shape + out = torch.empty( + qo_len_p + qo_len_d, + num_qo_heads, + head_dim, + device=q_p.device, + dtype=q_p.dtype, + ) # Decode setup k_cache_d, v_cache_d = _unpack_paged_kv_cache(paged_kv_cache_d, self._kv_layout) @@ -606,7 +615,6 @@ def run( lse_d = torch.empty( (q_d.size(0), q_d.size(1)), dtype=torch.float32, device=q_d.device ) - out_d = torch.empty_like(q_d) module_getter = get_pod_module( # Prefill params @@ -630,12 +638,19 @@ def run( logits_soft_cap_d > 0, # use_logits_soft_cap ) module_getter.run_tensor( + # Shared params + self._float_workspace_buffer, + self._int_workspace_buffer, + self._plan_info, + self._qo_indptr_buf, + self._paged_kv_indptr_buf, + self._paged_kv_indices_buf, + self._paged_kv_last_page_len_buf, + out, # Prefill params q_p, k_p, v_p, - tmp_p, - out_p, lse_p, mask_mode_p, TensorLayout[kv_layout_p].value, @@ -647,17 +662,9 @@ def run( 1.0 / rope_scale_p, 1.0 / rope_theta_p, # Decode params - 
self._float_workspace_buffer, - self._int_workspace_buffer, - self._plan_info, q_d, k_cache_d, v_cache_d, - self._qo_indptr_buf, - self._paged_kv_indptr_buf, - self._paged_kv_indices_buf, - self._paged_kv_last_page_len_buf, - out_d, lse_d, MaskMode.NON_CAUSAL.value, TensorLayout[self._kv_layout].value, @@ -673,9 +680,9 @@ def run( ) if v_scale is not None: - out_d *= v_scale + out *= v_scale - return (out_p, out_d) + return out[:qo_len_p], out[qo_len_p:] def end_forward(self) -> None: r"""Warning: this function is deprecated and has no effect.""" diff --git a/flashinfer/prefill.py b/flashinfer/prefill.py index 985bc2737..6c5068716 100644 --- a/flashinfer/prefill.py +++ b/flashinfer/prefill.py @@ -1305,7 +1305,7 @@ def plan( paged_kv_indptr : torch.Tensor The indptr of the paged kv-cache, shape: ``[batch_size + 1]``. paged_kv_indices : torch.Tensor - The page indices of the paged kv-cache, shape: ``[qo_indptr[-1]]``. + The page indices of the paged kv-cache, shape: ``[kv_indptr[-1]]``. paged_kv_last_page_len : torch.Tensor The number of entries in the last page of each request in the paged kv-cache, shape: ``[batch_size]``. diff --git a/include/flashinfer/attention/scheduler.cuh b/include/flashinfer/attention/scheduler.cuh index 380fa3fa9..02bb82025 100644 --- a/include/flashinfer/attention/scheduler.cuh +++ b/include/flashinfer/attention/scheduler.cuh @@ -516,37 +516,35 @@ inline auto get_qkv_len_arr(IdType* qo_indptr_h, IdType* kv_indptr_h, uint32_t b inline auto get_q_tiles(std::vector& packed_qo_len_arr, uint32_t batch_size, uint32_t head_dim, uint32_t page_size, uint32_t total_num_rows, - uint32_t gqa_group_size, bool enable_cuda_graph, uint32_t tile_size = -1) { - const uint32_t min_kv_chunk_size = std::max((128 / page_size), 1U); + uint32_t gqa_group_size, bool enable_cuda_graph, bool is_decode = false) { uint32_t cta_tile_q; uint32_t total_num_tiles_q; if (enable_cuda_graph) { // When CUDA graphs are enabled, the lengths of sequences determined by // qo_indptr_h can vary. We assume that the dummy data based on which // the CUDA graph is created fixes the maximum number of tokens. - if (tile_size == -1) { + if (is_decode) { + cta_tile_q = 16; + } else { const uint64_t max_seq_len = total_num_rows - batch_size + 1; uint64_t max_qo_len = uint64_t(max_seq_len) * gqa_group_size; cta_tile_q = FA2DetermineCtaTileQ(max_qo_len, head_dim); - } else { - cta_tile_q = tile_size; } - // Find an upper bound for the number of tiles, derived from the total // number of rows and the batch size. The sum of qo lengths rounded // up to cta_tile_q will not exceed this number derived from the total // number of rows. total_num_tiles_q = ceil_div(total_num_rows * gqa_group_size, cta_tile_q) + batch_size - 1; } else { - if (tile_size == -1) { + if (is_decode) { + cta_tile_q = 16; + } else { int64_t sum_packed_qo_len = 0; for (uint32_t i = 0; i < batch_size; ++i) { sum_packed_qo_len += packed_qo_len_arr[i]; } const int64_t avg_packed_qo_len = sum_packed_qo_len / batch_size; cta_tile_q = FA2DetermineCtaTileQ(avg_packed_qo_len, head_dim); - } else { - cta_tile_q = tile_size; } total_num_tiles_q = 0; @@ -624,6 +622,7 @@ inline auto PrefillSplitQOKVIndptr(IdType* qo_indptr_h, IdType* kv_indptr_h, o_indptr.push_back(0); const uint32_t gqa_group_size = num_qo_heads / num_kv_heads; + const uint32_t min_kv_chunk_size = std::max((128 / page_size), 1U); // step 1: determine packed_qo_len_arr and verify qo_indptr contents. 
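  // Per-request lengths used below (illustrative description): packed_qo_len_arr[i] is the query
  // length (qo_indptr[i+1] - qo_indptr[i]) multiplied by gqa_group_size, and kv_len_arr[i] is the
  // number of KV pages (kv_indptr[i+1] - kv_indptr[i]). E.g. a request with 7 query tokens and
  // gqa_group_size = 4 contributes 28 packed rows to be tiled by cta_tile_q.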
auto [packed_qo_len_arr, kv_len_arr] = @@ -852,7 +851,7 @@ inline auto PODSplitQOKVIndptr(IdType* qo_indptr_p, IdType* kv_indptr_p, uint32_ o_indptr.push_back(0); const uint32_t gqa_group_size = num_qo_heads / num_kv_heads; - + const uint32_t min_kv_chunk_size = std::max((128 / page_size), 1U); // step 1: determine packed_qo_len_arr and verify qo_indptr contents. auto [packed_qo_len_arr_p, kv_len_arr_p] = get_qkv_len_arr(qo_indptr_p, kv_indptr_p, batch_size_p, num_qo_heads, gqa_group_size); @@ -863,10 +862,9 @@ inline auto PODSplitQOKVIndptr(IdType* qo_indptr_p, IdType* kv_indptr_p, uint32_ auto [cta_tile_q_p, num_tiles_q_p] = get_q_tiles(packed_qo_len_arr_p, batch_size_p, head_dim, page_size, total_num_rows_p, gqa_group_size, enable_cuda_graph); - auto cta_tile_q_d = 16; // minimum for tensor core decode auto [cta_tile_q_d, num_tiles_q_d] = get_q_tiles(packed_qo_len_arr_d, batch_size_d, head_dim, page_size, total_num_rows_d, - gqa_group_size, enable_cuda_graph, cta_tile_q_d); + gqa_group_size, enable_cuda_graph, /*is_decode=*/true); uint32_t total_num_tiles_q = num_tiles_q_p + num_tiles_q_d; // Allocate CTAs proportional to the number of query tiles in prefill and decode @@ -895,9 +893,9 @@ inline auto PODSplitQOKVIndptr(IdType* qo_indptr_p, IdType* kv_indptr_p, uint32_ bool split_kv = split_kv_p || split_kv_d; uint32_t real_batch_size = new_batch_size_p + new_batch_size_d; const size_t padded_batch_size_p = - enable_cuda_graph ? std::max(max_bs_p, total_num_tiles_q_p) : new_batch_size_p; + enable_cuda_graph ? std::max(max_bs_p, num_tiles_q_p) : new_batch_size_p; const size_t padded_batch_size_d = - enable_cuda_graph ? std::max(max_bs_d, total_num_tiles_q_d) : new_batch_size_d; + enable_cuda_graph ? std::max(max_bs_d, num_tiles_q_d) : new_batch_size_d; FLASHINFER_CHECK(real_batch_size <= padded_batch_size_p + padded_batch_size_d, "new batch size should not exceed padded batch size"); @@ -981,9 +979,9 @@ struct PODPlanInfo { // From std::vector to PodPlanInfo void FromVector(const std::vector& vec) { - if (vec.size() != 19) { + if (vec.size() != 20) { std::ostringstream err_msg; - err_msg << "PodPlanInfo::FromVector: vec.size() should be 19, but got " << vec.size(); + err_msg << "PodPlanInfo::FromVector: vec.size() should be 20, but got " << vec.size(); FLASHINFER_ERROR(err_msg.str()); } padded_batch_size_p = vec[0]; @@ -991,21 +989,21 @@ struct PODPlanInfo { total_num_rows = vec[2]; total_num_rows_p = vec[3]; total_num_rows_d = vec[4]; - total_num_rows_offset = vec[4]; - cta_tile_q_p = vec[5]; - cta_tile_q_d = vec[6]; - request_indices_offset = vec[7]; - qo_tile_indices_offset = vec[8]; - kv_tile_indices_offset = vec[9]; - merge_indptr_offset = vec[10]; - o_indptr_offset = vec[11]; - kv_chunk_size_ptr_offset_p = vec[12]; - kv_chunk_size_ptr_offset_d = vec[13]; - v_offset = vec[14]; - s_offset = vec[15]; - block_valid_mask_offset = vec[16]; - enable_cuda_graph = vec[17]; - split_kv = vec[18]; + total_num_rows_offset = vec[5]; + cta_tile_q_p = vec[6]; + cta_tile_q_d = vec[7]; + request_indices_offset = vec[8]; + qo_tile_indices_offset = vec[9]; + kv_tile_indices_offset = vec[10]; + merge_indptr_offset = vec[11]; + o_indptr_offset = vec[12]; + kv_chunk_size_ptr_offset_p = vec[13]; + kv_chunk_size_ptr_offset_d = vec[14]; + v_offset = vec[15]; + s_offset = vec[16]; + block_valid_mask_offset = vec[17]; + enable_cuda_graph = vec[18]; + split_kv = vec[19]; } }; @@ -1054,8 +1052,6 @@ inline cudaError_t PODPlan(void* float_buffer, size_t float_workspace_size_in_by plan_info.total_num_rows_p = 
total_num_rows_p; plan_info.total_num_rows_d = total_num_rows_d; plan_info.enable_cuda_graph = enable_cuda_graph; - plan_info.padded_batch_size_p = padded_batch_size_p; - plan_info.padded_batch_size_d = padded_batch_size_d; plan_info.split_kv = split_kv; AlignedAllocator int_allocator(int_buffer, int_workspace_size_in_bytes); From 1a82b17674a5a19fe9f7ee709127eef622794ff8 Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Wed, 9 Jul 2025 16:01:19 +0000 Subject: [PATCH 24/33] add paged kv params --- csrc/pod.cu | 97 ++++++++++++++++++++++++----------------------- flashinfer/pod.py | 10 ++--- 2 files changed, 53 insertions(+), 54 deletions(-) diff --git a/csrc/pod.cu b/csrc/pod.cu index cb6926f3a..5163114ef 100644 --- a/csrc/pod.cu +++ b/csrc/pod.cu @@ -42,7 +42,7 @@ at::Tensor PODWithKVCachePlan(at::Tensor float_workspace_buffer, at::Tensor int_ uint32_t total_num_rows_p, uint32_t batch_size_p, at::Tensor qo_indptr_d, at::Tensor kv_indptr_d, uint32_t total_num_rows_d, uint32_t batch_size_d, - uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t head_dim_qk, + uint32_t num_qo_heads_p, uint32_t num_kv_heads, uint32_t head_dim_qk, uint32_t head_dim_vo, uint32_t page_size, bool enable_cuda_graph) { size_t float_workspace_size_in_bytes = float_workspace_buffer.size(0) * float_workspace_buffer.element_size(); @@ -59,7 +59,7 @@ at::Tensor PODWithKVCachePlan(at::Tensor float_workspace_buffer, at::Tensor int_ int_workspace_size_in_bytes, plan_info, qo_indptr_p.data_ptr(), kv_indptr_p.data_ptr(), total_num_rows_p, batch_size_p, qo_indptr_d.data_ptr(), kv_indptr_d.data_ptr(), - total_num_rows_d, batch_size_d, num_qo_heads, num_kv_heads, head_dim_qk, + total_num_rows_d, batch_size_d, num_qo_heads_p, num_kv_heads, head_dim_qk, head_dim_vo, page_size, enable_cuda_graph, /*sizeof_dtype_o=*/2, stream); TORCH_CHECK(status == cudaSuccess, @@ -72,27 +72,37 @@ void PODWithKVCacheTensorRun( // Shared params at::Tensor float_workspace_buffer_d, at::Tensor int_workspace_buffer_d, at::Tensor plan_info_vec, at::Tensor qo_indptr, at::Tensor paged_kv_indptr, - at::Tensor paged_kv_indices, at::Tensor paged_kv_last_page_len, at::Tensor o, + at::Tensor paged_kv_indices, at::Tensor paged_kv_last_page_len, at::Tensor o, int64_t layout, // Prefill params at::Tensor q_p, at::Tensor paged_k_p, at::Tensor paged_v_p, - std::optional maybe_lse_p, int64_t mask_mode_code_p, int64_t layout_p, - int64_t window_left_p, std::optional maybe_custom_mask_p, - std::optional maybe_alibi_slopes_p, double logits_soft_cap_p, double sm_scale_p, - double rope_rcp_scale_p, double rope_rcp_theta_p, + std::optional maybe_lse_p, int64_t mask_mode_code_p, int64_t window_left_p, + std::optional maybe_custom_mask_p, std::optional maybe_alibi_slopes_p, + double logits_soft_cap_p, double sm_scale_p, double rope_rcp_scale_p, double rope_rcp_theta_p, // Decode params at::Tensor q_d, at::Tensor paged_k_cache_d, at::Tensor paged_v_cache_d, at::Tensor qo_indptr_d, - at::Tensor paged_kv_indptr_d, at::Tensor paged_kv_indices_d, - at::Tensor paged_kv_last_page_len_d, std::optional maybe_lse_d, - int64_t mask_mode_code_d, int64_t layout_d, int64_t window_left_d, + std::optional maybe_lse_d, int64_t mask_mode_code_d, int64_t window_left_d, std::optional maybe_custom_mask_d, std::optional maybe_mask_indptr_d, std::optional maybe_alibi_slopes_d, double logits_soft_cap_d, double sm_scale_d, double rope_rcp_scale_d, double rope_rcp_theta_d, bool enable_pdl) { + PODPlanInfo plan_info; + plan_info.FromVector(tensor_to_vec(plan_info_vec)); + auto device = q_d.device(); + 
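  // plan_info_vec is the 20-element vector written by PODWithKVCachePlan; PODPlanInfo::FromVector
  // unpacks the padded batch sizes, row counts, cta tile sizes, workspace offsets, and the
  // enable_cuda_graph / split_kv flags in that fixed order.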
uint32_t batch_size = paged_kv_indptr.size(0) - 1; + void* float_buffer_ptr = static_cast(float_workspace_buffer_d.data_ptr()); + void* int_buffer_ptr = static_cast(int_workspace_buffer_d.data_ptr()); + // get kv_cache_strides + const int64_t* kv_cache_strides = nullptr; + auto k_strides = paged_k_cache.strides(); + auto v_strides = paged_v_cache.strides(); + TORCH_CHECK(k_strides == v_strides, "k/v strides must be identical"); + kv_cache_strides = k_strides.data(); + // Prefill setup - unsigned int head_dim_qk = q_p.size(2); - unsigned int kv_len_p, qo_len_p, num_kv_heads, num_qo_heads; + uint32_t head_dim_qk = q_p.size(2); + uint32_t kv_len_p, qo_len_p, num_kv_heads, num_qo_heads_p; QKVLayout kv_layout_p = static_cast(layout_p); qo_len_p = q_p.size(0); - num_qo_heads = q_p.size(1); + num_qo_heads_p = q_p.size(1); uint32_t q_stride_n_p = q_p.stride(0), q_stride_h_p = q_p.stride(1), k_stride_n_p, k_stride_h_p, v_stride_n_p, v_stride_h_p; if (kv_layout_p == QKVLayout::kNHD) { @@ -113,7 +123,7 @@ void PODWithKVCacheTensorRun( if (maybe_lse_p) { const auto& lse = *maybe_lse_p; TORCH_CHECK(lse.size(0) == qo_len_p, lse.size(0), q_p.size(0)); - TORCH_CHECK(lse.size(1) == num_qo_heads, lse.size(1), q_p.size(1)); + TORCH_CHECK(lse.size(1) == num_qo_heads_p, lse.size(1), q_p.size(1)); } const MaskMode mask_mode_p = static_cast(mask_mode_code_p); @@ -122,24 +132,20 @@ void PODWithKVCacheTensorRun( auto kv_scalar_type = paged_k_p.scalar_type(); // Decode setup (Tensor decode = batched prefill) - PODPlanInfo plan_info; - plan_info.FromVector(tensor_to_vec(plan_info_vec)); - QKVLayout kv_layout_d = static_cast(layout_d); - auto device = q_d.device(); - int64_t batch_size = paged_kv_indptr_d.size(0) - 1; - int64_t num_qo_heads_d = q_d.size(1); + QKVLayout kv_layout = static_cast(layout); + uint32_t num_qo_heads_d = q_d.size(1); - TORCH_CHECK(num_qo_heads == num_qo_heads_d, + TORCH_CHECK(num_qo_heads_p == num_qo_heads_d, "POD currently requires same # Query heads for prefill and decode"); - int64_t num_kv_heads_d, page_size_d; - uint32_t head_dim_qk_d = q_d.size(2); - if (kv_layout_d == QKVLayout::kHND) { - num_kv_heads_d = paged_k_cache_d.size(1); - page_size_d = paged_k_cache_d.size(2); + uint32_t num_kv_heads, page_size; + uint32_t head_dim_qk = q_d.size(2); + if (kv_layout == QKVLayout::kHND) { + num_kv_heads = paged_k_cache.size(1); + page_size = paged_k_cache.size(2); } else { - page_size_d = paged_k_cache_d.size(1); - num_kv_heads_d = paged_k_cache_d.size(2); + page_size = paged_k_cache.size(1); + num_kv_heads = paged_k_cache.size(2); } TORCH_CHECK(num_kv_heads == num_kv_heads_d, "POD currently requires same # KV heads for prefill and decode; Prefill: ", @@ -151,9 +157,6 @@ void PODWithKVCacheTensorRun( TORCH_CHECK(lse.size(1) == q_d.size(1), lse.size(1), q_d.size(1)); } - void* float_buffer_ptr = static_cast(float_workspace_buffer_d.data_ptr()); - void* int_buffer_ptr = static_cast(int_workspace_buffer_d.data_ptr()); - const MaskMode mask_mode_d = static_cast(mask_mode_code_d); auto q_scalar_type_d = q_d.scalar_type(); auto kv_scalar_type_d = paged_k_cache_d.scalar_type(); @@ -163,11 +166,11 @@ void PODWithKVCacheTensorRun( const auto q_stride_h_d = q_d.stride(1); // get kv_cache_strides - const int64_t* kv_cache_strides_d = nullptr; - auto k_strides_d = paged_k_cache_d.strides(); - auto v_strides_d = paged_v_cache_d.strides(); - TORCH_CHECK(k_strides_d == v_strides_d, "k/v strides must be identical"); - kv_cache_strides_d = k_strides_d.data(); + const int64_t* kv_cache_strides = nullptr; + auto 
k_strides = paged_k_cache_d.strides(); + auto v_strides = paged_v_cache_d.strides(); + TORCH_CHECK(k_strides == v_strides, "k/v strides must be identical"); + kv_cache_strides = k_strides.data(); const c10::cuda::OptionalCUDAGuard device_guard(float_workspace_buffer_d.device()); const cudaStream_t stream = c10::cuda::getCurrentCUDAStream(); @@ -175,6 +178,13 @@ void PODWithKVCacheTensorRun( DISPATCH_context( MASK_MODE_P, MASK_MODE_D, DTypeQ, DTypeKV, HEAD_DIM_QK, USE_SLIDING_WINDOW_P, USE_SLIDING_WINDOW_D, USE_LOGITS_SOFT_CAP, [&] { + paged_kv_t paged_kv( + num_kv_heads, page_size, HEAD_DIM_VO, batch_size, kv_layout, + static_cast(paged_k_cache_d.data_ptr()), + static_cast(paged_v_cache_d.data_ptr()), kv_cache_strides, + static_cast(paged_kv_indices_d.data_ptr()), + static_cast(paged_kv_indptr_d.data_ptr()), + static_cast(paged_kv_last_page_len_d.data_ptr())); PrefillParams prefill_params; { // Make params a reference to prefill_params to set values @@ -184,9 +194,9 @@ void PODWithKVCacheTensorRun( params.v = static_cast(paged_v_p.data_ptr()); params.o = static_cast(o_p.data_ptr()); params.lse = maybe_lse_p ? static_cast(maybe_lse_p->data_ptr()) : nullptr; - params.num_qo_heads = num_qo_heads; + params.num_qo_heads_p = num_qo_heads_p; params.num_kv_heads = num_kv_heads; - params.group_size = uint_fastdiv(num_qo_heads / num_kv_heads); + params.group_size = uint_fastdiv(num_qo_heads_p / num_kv_heads); params.qo_len = qo_len_p; params.kv_len = kv_len_p; params.q_stride_n = q_stride_n_p; @@ -217,20 +227,13 @@ void PODWithKVCacheTensorRun( { DecodeParams& params = decode_params; params.q = static_cast(q_d.data_ptr()); - paged_kv_t paged_kv( - num_kv_heads, page_size_d, HEAD_DIM_VO, batch_size, kv_layout_d, - static_cast(paged_k_cache_d.data_ptr()), - static_cast(paged_v_cache_d.data_ptr()), kv_cache_strides_d, - static_cast(paged_kv_indices_d.data_ptr()), - static_cast(paged_kv_indptr_d.data_ptr()), - static_cast(paged_kv_last_page_len_d.data_ptr())); params.paged_kv = paged_kv; params.q_indptr = static_cast(qo_indptr_d.data_ptr()); params.o = static_cast(o_d.data_ptr()); params.lse = maybe_lse_d ? static_cast(maybe_lse_d->data_ptr()) : nullptr; - params.num_qo_heads = num_qo_heads; - params.group_size = uint_fastdiv(num_qo_heads / paged_kv.num_heads); + params.num_qo_heads_p = num_qo_heads_p; + params.group_size = uint_fastdiv(num_qo_heads_p / paged_kv.num_heads); params.q_stride_n = q_stride_n_d; params.q_stride_h = q_stride_h_d; params.window_left = window_left_d; diff --git a/flashinfer/pod.py b/flashinfer/pod.py index a395d1970..adb4ce2e5 100644 --- a/flashinfer/pod.py +++ b/flashinfer/pod.py @@ -65,7 +65,7 @@ class PODWithPagedKVCacheWrapper: >>> page_size = 16 >>> # allocate 128MB workspace buffer >>> workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda:0") - >>> decode_wrapper = flashinfer.PODWithPagedKVCacheWrapper( + >>> wrapper = flashinfer.PODWithPagedKVCacheWrapper( ... workspace_buffer, "NHD" ... ) >>> batch_size = 7 @@ -83,7 +83,7 @@ class PODWithPagedKVCacheWrapper: ... ) for _ in range(num_layers) ... ] >>> # create auxiliary data structures for batch decode attention - >>> decode_wrapper.plan( + >>> wrapper.plan( ... kv_page_indptr, ... kv_page_indices, ... 
kv_last_page_len, @@ -510,7 +510,6 @@ def run( custom_mask_p: Optional[torch.Tensor] = None, packed_custom_mask_p: Optional[torch.Tensor] = None, causal_p: bool = False, - kv_layout_p: str = "NHD", pos_encoding_mode_p: str = "NONE", sm_scale_p: Optional[float] = None, window_left_p: int = -1, @@ -521,7 +520,6 @@ def run( custom_mask_d: Optional[torch.Tensor] = None, packed_custom_mask_d: Optional[torch.Tensor] = None, causal_d: bool = False, - kv_layout_d: str = "NHD", pos_encoding_mode_d: str = "NONE", sm_scale_d: Optional[float] = None, window_left_d: int = -1, @@ -544,7 +542,6 @@ def run( logits_soft_cap_d = None # Prefill setup _check_pos_encoding_mode(pos_encoding_mode_p) - _check_kv_layout(kv_layout_p) if logits_soft_cap_p is None: logits_soft_cap_p = 0.0 if sm_scale_p is None: @@ -647,13 +644,13 @@ def run( self._paged_kv_indices_buf, self._paged_kv_last_page_len_buf, out, + TensorLayout[self._kv_layout].value, # Prefill params q_p, k_p, v_p, lse_p, mask_mode_p, - TensorLayout[kv_layout_p].value, window_left_p, packed_custom_mask_p, _get_cache_alibi_slopes_buf(q_p.shape[1], q_p.device), @@ -667,7 +664,6 @@ def run( v_cache_d, lse_d, MaskMode.NON_CAUSAL.value, - TensorLayout[self._kv_layout].value, window_left_d, None, # packed_custom_mask None, # mask_indptr_buf From dd80a062dfa2365acaa2b48416439691083ebb84 Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Wed, 9 Jul 2025 23:32:59 +0000 Subject: [PATCH 25/33] complete PODWithKVCacheTensorRun params --- csrc/pod.cu | 84 +++++++++++++++------------- include/flashinfer/attention/pod.cuh | 6 +- 2 files changed, 49 insertions(+), 41 deletions(-) diff --git a/csrc/pod.cu b/csrc/pod.cu index 5163114ef..6c1ad8905 100644 --- a/csrc/pod.cu +++ b/csrc/pod.cu @@ -133,9 +133,9 @@ void PODWithKVCacheTensorRun( // Decode setup (Tensor decode = batched prefill) QKVLayout kv_layout = static_cast(layout); - uint32_t num_qo_heads_d = q_d.size(1); + uint32_t num_qo_heads = q_d.size(1); - TORCH_CHECK(num_qo_heads_p == num_qo_heads_d, + TORCH_CHECK(num_qo_heads_p == num_qo_heads, "POD currently requires same # Query heads for prefill and decode"); uint32_t num_kv_heads, page_size; @@ -190,24 +190,33 @@ void PODWithKVCacheTensorRun( // Make params a reference to prefill_params to set values PrefillParams& params = prefill_params; params.q = static_cast(q_p.data_ptr()); - params.k = static_cast(paged_k_p.data_ptr()); - params.v = static_cast(paged_v_p.data_ptr()); + params.paged_kv = paged_kv; + params.q_indptr = static_cast(qo_indptr_p.data_ptr()); params.o = static_cast(o_p.data_ptr()); params.lse = maybe_lse_p ? 
static_cast(maybe_lse_p->data_ptr()) : nullptr; - params.num_qo_heads_p = num_qo_heads_p; - params.num_kv_heads = num_kv_heads; - params.group_size = uint_fastdiv(num_qo_heads_p / num_kv_heads); - params.qo_len = qo_len_p; - params.kv_len = kv_len_p; + params.num_qo_heads = num_qo_heads_p; + params.group_size = uint_fastdiv(num_qo_heads_p / paged_kv.num_heads); params.q_stride_n = q_stride_n_p; params.q_stride_h = q_stride_h_p; - params.k_stride_n = k_stride_n_p; - params.k_stride_h = k_stride_h_p; - params.v_stride_n = v_stride_n_p; - params.v_stride_h = v_stride_h_p; - params.window_left = window_left_p; - params.partition_kv = false; + + params.request_indices = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.request_indices_offset); + params.qo_tile_indices = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.qo_tile_indices_offset); + params.kv_tile_indices = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.kv_tile_indices_offset); + params.o_indptr = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.o_indptr_offset); + if (plan_info.split_kv) { + params.merge_indptr = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.merge_indptr_offset); + if (plan_info.enable_cuda_graph) { + params.block_valid_mask = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.block_valid_mask_offset); + } + } + params.kv_chunk_size_ptr = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.kv_chunk_size_ptr_offset_p); params.padded_batch_size = plan_info.padded_batch_size_p; params.maybe_custom_mask = maybe_custom_mask_p ? static_cast(maybe_custom_mask_p->data_ptr()) @@ -219,6 +228,11 @@ void PODWithKVCacheTensorRun( params.sm_scale = sm_scale_p; params.rope_rcp_scale = rope_rcp_scale_p; params.rope_rcp_theta = rope_rcp_theta_p; + params.max_total_num_rows = plan_info.total_num_rows; + if (plan_info.enable_cuda_graph) { + params.total_num_rows = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.total_num_rows_offset); + } } DecodeParams decode_params; @@ -232,28 +246,22 @@ void PODWithKVCacheTensorRun( params.o = static_cast(o_d.data_ptr()); params.lse = maybe_lse_d ? 
static_cast(maybe_lse_d->data_ptr()) : nullptr; - params.num_qo_heads_p = num_qo_heads_p; - params.group_size = uint_fastdiv(num_qo_heads_p / paged_kv.num_heads); + params.num_qo_heads = num_qo_heads; + params.group_size = uint_fastdiv(num_qo_heads / paged_kv.num_heads); params.q_stride_n = q_stride_n_d; params.q_stride_h = q_stride_h_d; params.window_left = window_left_d; - params.request_indices = - GetPtrFromBaseOffset(int_buffer_ptr, plan_info.request_indices_offset); - params.qo_tile_indices = - GetPtrFromBaseOffset(int_buffer_ptr, plan_info.qo_tile_indices_offset); - params.kv_tile_indices = - GetPtrFromBaseOffset(int_buffer_ptr, plan_info.kv_tile_indices_offset); - params.o_indptr = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.o_indptr_offset); - params.kv_chunk_size_ptr = - GetPtrFromBaseOffset(int_buffer_ptr, plan_info.kv_chunk_size_ptr_offset); + params.request_indices = prefill_params.request_indices; + params.qo_tile_indices = prefill_params.qo_tile_indices; + params.kv_tile_indices = prefill_params.kv_tile_indices; + params.o_indptr = prefill_params.o_indptr; + params.kv_chunk_size_ptr = prefill_params.kv_chunk_size_ptr; if (plan_info.split_kv) { - params.merge_indptr = - GetPtrFromBaseOffset(int_buffer_ptr, plan_info.merge_indptr_offset); - tmp_v = GetPtrFromBaseOffset(float_buffer_ptr, plan_info.v_offset); - tmp_s = GetPtrFromBaseOffset(float_buffer_ptr, plan_info.s_offset); + params.merge_indptr = prefill_params.merge_indptr; + tmp_v = prefill_params.v; + tmp_s = prefill_params.s; if (plan_info.enable_cuda_graph) { - params.block_valid_mask = - GetPtrFromBaseOffset(int_buffer_ptr, plan_info.block_valid_mask_offset); + params.block_valid_mask = prefill_params.block_valid_mask; } } params.padded_batch_size = plan_info.padded_batch_size_d; @@ -272,8 +280,7 @@ void PODWithKVCacheTensorRun( params.rope_rcp_theta = rope_rcp_theta_d; if (plan_info.enable_cuda_graph) { - params.total_num_rows = - GetPtrFromBaseOffset(int_buffer_ptr, plan_info.total_num_rows_offset); + params.total_num_rows = prefill_params.total_num_rows; } } @@ -286,12 +293,13 @@ void PODWithKVCacheTensorRun( DefaultAttention; // DISPATCH_CTA_TILE_Q(plan_info.cta_tile_q, CTA_TILE_Q, { - constexpr size_t CTA_TILE_Q = 16; + constexpr size_t CTA_TILE_Q_P = plan_info.cta_tile_q_p; + constexpr size_t CTA_TILE_Q_D = plan_info.cta_tile_q_d; cudaError_t status = flashinfer::PODWithKVCacheTensorDispatched< HEAD_DIM_QK, HEAD_DIM_VO, POS_ENCODING_MODE, USE_FP16_QK_REDUCTION, MASK_MODE_P, - CTA_TILE_Q, MASK_MODE_D, PrefillAttentionVariant, DecodeAttentionVariant>( - prefill_params, static_cast(tmp_p.data_ptr()), decode_params, tmp_v, tmp_s, - enable_pdl, stream); + CTA_TILE_Q_P, CTA_TILE_Q_D, MASK_MODE_D, PrefillAttentionVariant, + DecodeAttentionVariant>(prefill_params, static_cast(tmp_p.data_ptr()), + decode_params, tmp_v, tmp_s, enable_pdl, stream); TORCH_CHECK(status == cudaSuccess, "PODWithKVCache kernel launch failed, error: " + std::string(cudaGetErrorString(status))); //}); diff --git a/include/flashinfer/attention/pod.cuh b/include/flashinfer/attention/pod.cuh index 9826b7bfc..ae6454d35 100644 --- a/include/flashinfer/attention/pod.cuh +++ b/include/flashinfer/attention/pod.cuh @@ -170,9 +170,9 @@ __global__ __launch_bounds__(std::max( } template + bool USE_FP16_QK_REDUCTION, MaskMode MASK_MODE_P, uint32_t CTA_TILE_Q_P, + uint32_t CTA_TILE_Q_D, MaskMode MASK_MODE_D, typename PrefillAttentionVariant, + typename DecodeAttentionVariant, typename PrefillParams, typename DecodeParams> cudaError_t 
PODWithKVCacheTensorDispatched(PrefillParams prefill_params, typename PrefillParams::DTypeO* tmp_p, DecodeParams decode_params, From 0bb164be63f51a9617f42acb39d92bff35283c55 Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Thu, 10 Jul 2025 05:09:30 +0000 Subject: [PATCH 26/33] share lse --- csrc/pod.cu | 41 +++++++++++----------- flashinfer/pod.py | 22 +++++------- include/flashinfer/attention/pod.cuh | 11 +++--- include/flashinfer/attention/scheduler.cuh | 10 +++--- 4 files changed, 38 insertions(+), 46 deletions(-) diff --git a/csrc/pod.cu b/csrc/pod.cu index 6c1ad8905..3842bbaac 100644 --- a/csrc/pod.cu +++ b/csrc/pod.cu @@ -72,7 +72,8 @@ void PODWithKVCacheTensorRun( // Shared params at::Tensor float_workspace_buffer_d, at::Tensor int_workspace_buffer_d, at::Tensor plan_info_vec, at::Tensor qo_indptr, at::Tensor paged_kv_indptr, - at::Tensor paged_kv_indices, at::Tensor paged_kv_last_page_len, at::Tensor o, int64_t layout, + at::Tensor paged_kv_indices, at::Tensor paged_kv_last_page_len, at::Tensor o, + std::optional maybe_lse, int64_t layout, // Prefill params at::Tensor q_p, at::Tensor paged_k_p, at::Tensor paged_v_p, std::optional maybe_lse_p, int64_t mask_mode_code_p, int64_t window_left_p, @@ -99,9 +100,9 @@ void PODWithKVCacheTensorRun( // Prefill setup uint32_t head_dim_qk = q_p.size(2); - uint32_t kv_len_p, qo_len_p, num_kv_heads, num_qo_heads_p; + uint32_t kv_len_p, qo_len, num_kv_heads, num_qo_heads_p; QKVLayout kv_layout_p = static_cast(layout_p); - qo_len_p = q_p.size(0); + qo_len = q_p.size(0) + q_d.size(0); num_qo_heads_p = q_p.size(1); uint32_t q_stride_n_p = q_p.stride(0), q_stride_h_p = q_p.stride(1), k_stride_n_p, k_stride_h_p, v_stride_n_p, v_stride_h_p; @@ -120,9 +121,9 @@ void PODWithKVCacheTensorRun( v_stride_h_p = paged_v_p.stride(0); v_stride_n_p = paged_v_p.stride(1); } - if (maybe_lse_p) { - const auto& lse = *maybe_lse_p; - TORCH_CHECK(lse.size(0) == qo_len_p, lse.size(0), q_p.size(0)); + if (maybe_lse) { + const auto& lse = *maybe_lse; + TORCH_CHECK(lse.size(0) == qo_len, lse.size(0), qo_len); TORCH_CHECK(lse.size(1) == num_qo_heads_p, lse.size(1), q_p.size(1)); } @@ -151,15 +152,7 @@ void PODWithKVCacheTensorRun( "POD currently requires same # KV heads for prefill and decode; Prefill: ", num_kv_heads, ", Decode: ", num_kv_heads_d); - if (maybe_lse_d) { - const auto& lse = *maybe_lse_d; - TORCH_CHECK(lse.size(0) == q_d.size(0), lse.size(0), q_d.size(0)); - TORCH_CHECK(lse.size(1) == q_d.size(1), lse.size(1), q_d.size(1)); - } - const MaskMode mask_mode_d = static_cast(mask_mode_code_d); - auto q_scalar_type_d = q_d.scalar_type(); - auto kv_scalar_type_d = paged_k_cache_d.scalar_type(); // get q_stride_n and q_stride_h const auto q_stride_n_d = q_d.stride(0); @@ -193,7 +186,7 @@ void PODWithKVCacheTensorRun( params.paged_kv = paged_kv; params.q_indptr = static_cast(qo_indptr_p.data_ptr()); params.o = static_cast(o_p.data_ptr()); - params.lse = maybe_lse_p ? static_cast(maybe_lse_p->data_ptr()) : nullptr; + params.lse = maybe_lse ? 
static_cast(maybe_lse->data_ptr()) : nullptr; params.num_qo_heads = num_qo_heads_p; params.group_size = uint_fastdiv(num_qo_heads_p / paged_kv.num_heads); params.q_stride_n = q_stride_n_p; @@ -233,6 +226,12 @@ void PODWithKVCacheTensorRun( params.total_num_rows = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.total_num_rows_offset); } + if (plan_info.split_kv) { + if (plan_info.enable_cuda_graph) { + params.block_valid_mask = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.block_valid_mask_offset); + } + } } DecodeParams decode_params; @@ -245,7 +244,7 @@ void PODWithKVCacheTensorRun( params.q_indptr = static_cast(qo_indptr_d.data_ptr()); params.o = static_cast(o_d.data_ptr()); - params.lse = maybe_lse_d ? static_cast(maybe_lse_d->data_ptr()) : nullptr; + params.lse = maybe_lse ? static_cast(maybe_lse->data_ptr()) : nullptr; params.num_qo_heads = num_qo_heads; params.group_size = uint_fastdiv(num_qo_heads / paged_kv.num_heads); params.q_stride_n = q_stride_n_d; @@ -256,10 +255,12 @@ void PODWithKVCacheTensorRun( params.kv_tile_indices = prefill_params.kv_tile_indices; params.o_indptr = prefill_params.o_indptr; params.kv_chunk_size_ptr = prefill_params.kv_chunk_size_ptr; + if (plan_info.split_kv) { params.merge_indptr = prefill_params.merge_indptr; - tmp_v = prefill_params.v; - tmp_s = prefill_params.s; + // These should be assigned from plan info, not from prefill_params + tmp_v = GetPtrFromBaseOffset(float_buffer_ptr, plan_info.v_offset); + tmp_s = GetPtrFromBaseOffset(float_buffer_ptr, plan_info.s_offset); if (plan_info.enable_cuda_graph) { params.block_valid_mask = prefill_params.block_valid_mask; } @@ -298,8 +299,8 @@ void PODWithKVCacheTensorRun( cudaError_t status = flashinfer::PODWithKVCacheTensorDispatched< HEAD_DIM_QK, HEAD_DIM_VO, POS_ENCODING_MODE, USE_FP16_QK_REDUCTION, MASK_MODE_P, CTA_TILE_Q_P, CTA_TILE_Q_D, MASK_MODE_D, PrefillAttentionVariant, - DecodeAttentionVariant>(prefill_params, static_cast(tmp_p.data_ptr()), - decode_params, tmp_v, tmp_s, enable_pdl, stream); + DecodeAttentionVariant>(prefill_params, decode_params, tmp_v, tmp_s, enable_pdl, + stream); TORCH_CHECK(status == cudaSuccess, "PODWithKVCache kernel launch failed, error: " + std::string(cudaGetErrorString(status))); //}); diff --git a/flashinfer/pod.py b/flashinfer/pod.py index adb4ce2e5..a751807c3 100644 --- a/flashinfer/pod.py +++ b/flashinfer/pod.py @@ -515,7 +515,6 @@ def run( window_left_p: int = -1, rope_scale_p: Optional[float] = None, rope_theta_p: Optional[float] = None, - return_lse_p: bool = False, # Decode options custom_mask_d: Optional[torch.Tensor] = None, packed_custom_mask_d: Optional[torch.Tensor] = None, @@ -528,7 +527,7 @@ def run( q_scale: Optional[float] = None, k_scale: Optional[float] = None, v_scale: Optional[float] = None, - return_lse_d: bool = False, + return_lse: bool = False, use_fp16_qk_reduction: bool = False, enable_pdl: Optional[bool] = None, *args, @@ -564,10 +563,12 @@ def run( else: mask_mode_p = MaskMode.NON_CAUSAL.value - lse_p = None - if return_lse_p: - lse_p = torch.empty( - (q_p.size(0), q_p.size(1)), dtype=torch.float32, device=q_p.device + lse = None + if return_lse: + lse = torch.empty( + (q_p.size(0) + q_d.size(0), q_p.size(1)), + dtype=torch.float32, + device=q_p.device, ) qo_len_p, num_qo_heads, head_dim = q_p.shape qo_len_d, _, _ = q_d.shape @@ -607,12 +608,6 @@ def run( if rope_theta_d is None: rope_theta_d = 1e4 - lse_d = None - if return_lse_d: - lse_d = torch.empty( - (q_d.size(0), q_d.size(1)), dtype=torch.float32, device=q_d.device - ) - module_getter 
= get_pod_module( # Prefill params q_p.dtype, @@ -644,12 +639,12 @@ def run( self._paged_kv_indices_buf, self._paged_kv_last_page_len_buf, out, + lse, TensorLayout[self._kv_layout].value, # Prefill params q_p, k_p, v_p, - lse_p, mask_mode_p, window_left_p, packed_custom_mask_p, @@ -662,7 +657,6 @@ def run( q_d, k_cache_d, v_cache_d, - lse_d, MaskMode.NON_CAUSAL.value, window_left_d, None, # packed_custom_mask diff --git a/include/flashinfer/attention/pod.cuh b/include/flashinfer/attention/pod.cuh index ae6454d35..58ba4a706 100644 --- a/include/flashinfer/attention/pod.cuh +++ b/include/flashinfer/attention/pod.cuh @@ -173,9 +173,7 @@ template -cudaError_t PODWithKVCacheTensorDispatched(PrefillParams prefill_params, - typename PrefillParams::DTypeO* tmp_p, - DecodeParams decode_params, +cudaError_t PODWithKVCacheTensorDispatched(PrefillParams prefill_params, DecodeParams decode_params, typename DecodeParams::DTypeO* tmp_v, float* tmp_s, bool enable_pdl, cudaStream_t stream) { static_assert(std::is_same::value); @@ -349,8 +347,7 @@ cudaError_t PODWithKVCacheTensorDispatched(PrefillParams prefill_params, // Setup new prefill params if (not) split auto o_p = prefill_params.o; auto lse_p = prefill_params.lse; - float* tmp_lse = (float*)(tmp_p + num_chunks * qo_len * num_qo_heads * HEAD_DIM_VO); - if (num_chunks <= 1 || tmp_p == nullptr) { + if (num_chunks <= 1 || tmp_v == nullptr) { // Enough parallelism, do not split-kv prefill_params.partition_kv = 0; kernel = PODWithKVCacheTensorKernel; } diff --git a/include/flashinfer/attention/scheduler.cuh b/include/flashinfer/attention/scheduler.cuh index 02bb82025..809795095 100644 --- a/include/flashinfer/attention/scheduler.cuh +++ b/include/flashinfer/attention/scheduler.cuh @@ -1046,11 +1046,11 @@ inline cudaError_t PODPlan(void* float_buffer, size_t float_workspace_size_in_by plan_info.padded_batch_size_p = padded_batch_size_p; plan_info.padded_batch_size_d = padded_batch_size_d; + plan_info.total_num_rows_p = total_num_rows_p; + plan_info.total_num_rows_d = total_num_rows_d; plan_info.total_num_rows = total_num_rows; plan_info.cta_tile_q_p = cta_tile_q_p; plan_info.cta_tile_q_d = cta_tile_q_d; - plan_info.total_num_rows_p = total_num_rows_p; - plan_info.total_num_rows_d = total_num_rows_d; plan_info.enable_cuda_graph = enable_cuda_graph; plan_info.split_kv = split_kv; @@ -1097,11 +1097,11 @@ inline cudaError_t PODPlan(void* float_buffer, size_t float_workspace_size_in_by if (split_kv) { // TODO(Wenxuan): write through for non-split-kv requests - uint32_t num_outputs_p = num_qo_heads * padded_batch_size_p * cta_tile_q_p * head_dim_vo; - uint32_t num_outputs_d = num_qo_heads * padded_batch_size_d * cta_tile_q_d * head_dim_vo; + uint32_t num_outputs_p = num_qo_heads * padded_batch_size_p * cta_tile_q_p; + uint32_t num_outputs_d = num_qo_heads * padded_batch_size_d * cta_tile_q_d; AlignedAllocator float_allocator(float_buffer, float_workspace_size_in_bytes); plan_info.v_offset = float_allocator.aligned_alloc_offset( - (num_outputs_p + num_outputs_d) * sizeof(float), 16, "pod_tmp_v"); + (num_outputs_p + num_outputs_d) * head_dim_vo * sizeof(float), 16, "pod_tmp_v"); plan_info.s_offset = float_allocator.aligned_alloc_offset( (num_outputs_p + num_outputs_d) * sizeof(float), 16, "pod_tmp_s"); plan_info.merge_indptr_offset = int_allocator.aligned_alloc_offset( From 32d762b669bab2f223baee184752819d53e2b1b1 Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Thu, 10 Jul 2025 05:16:17 +0000 Subject: [PATCH 27/33] templaterize CTA_TILE_Q_P --- csrc/pod.cu | 6 +- 
include/flashinfer/attention/pod.cuh | 366 +++++++++++++-------------- 2 files changed, 183 insertions(+), 189 deletions(-) diff --git a/csrc/pod.cu b/csrc/pod.cu index 3842bbaac..ac33efe66 100644 --- a/csrc/pod.cu +++ b/csrc/pod.cu @@ -23,9 +23,9 @@ namespace flashinfer { template + bool USE_FP16_QK_REDUCTION, MaskMode MASK_MODE_P, uint32_t CTA_TILE_Q_P, + uint32_t CTA_TILE_Q_D, MaskMode MASK_MODE_D, typename PrefillAttentionVariant, + typename DecodeAttentionVariant, typename PrefillParams, typename DecodeParams> cudaError_t PODWithKVCacheTensorDispatched(PrefillParams prefill_params, typename PrefillParams::DTypeO* tmp, DecodeParams decode_params, diff --git a/include/flashinfer/attention/pod.cuh b/include/flashinfer/attention/pod.cuh index 58ba4a706..21b014d6b 100644 --- a/include/flashinfer/attention/pod.cuh +++ b/include/flashinfer/attention/pod.cuh @@ -204,10 +204,7 @@ cudaError_t PODWithKVCacheTensorDispatched(PrefillParams prefill_params, DecodeP constexpr uint32_t NUM_MMA_D_QK = HEAD_DIM_QK / 16; constexpr uint32_t NUM_MMA_D_VO = HEAD_DIM_VO / 16; - uint32_t cta_tile_q_p = 0; - int64_t unpacked_qo_len = - qo_len * group_size; // TODO(@Wenxuan): Include batch size in calculation - cta_tile_q_p = FA2DetermineCtaTileQ(unpacked_qo_len, HEAD_DIM_VO); + int64_t unpacked_qo_len = qo_len * group_size; // Decode vars setup using DTypeQ_D = typename DecodeParams::DTypeQ; @@ -251,197 +248,194 @@ cudaError_t PODWithKVCacheTensorDispatched(PrefillParams prefill_params, DecodeP NUM_MMA_Q_D * NUM_WARPS_Q_D) / (2 * NUM_WARPS_KV_D); - DISPATCH_CTA_TILE_Q(cta_tile_q_p, CTA_TILE_Q_P, { - constexpr uint32_t NUM_WARPS_Q_P = get_num_warps_q(CTA_TILE_Q_P); - constexpr uint32_t NUM_WARPS_KV_P = get_num_warps_kv(CTA_TILE_Q_P); - constexpr uint32_t NUM_MMA_Q_P = get_num_mma_q(CTA_TILE_Q_P); - - using DTypeQKAccum_P = - typename std::conditional, half, - float>::type; - - // we expect each sm execute two threadblocks - // TODO(Zihao): fix the following computation - const int num_ctas_per_sm_p = - max_smem_per_sm > (16 * HEAD_DIM_QK * sizeof(DTypeQ_P) * 16) ? 2 : 1; - const int max_smem_per_threadblock_p = max_smem_per_sm / num_ctas_per_sm_p; - - constexpr uint32_t max_num_mma_kv_reg_p = - (HEAD_DIM_VO >= 128 && NUM_MMA_Q_P == 2 && - POS_ENCODING_MODE == PosEncodingMode::kRoPELlama && !USE_FP16_QK_REDUCTION) - ? 2 - : (8 / NUM_MMA_Q_P); - // TODO(Zihao): fix the following computation - const uint32_t max_num_mma_kv_smem_p = - (max_smem_per_threadblock_p / (16 * HEAD_DIM_QK * sizeof(DTypeQ_P)) - - NUM_MMA_Q_P * NUM_WARPS_Q_P) / - (2 * NUM_WARPS_KV_P); - - // control NUM_MMA_KV for maximum warp occupancy - DISPATCH_NUM_MMA_KV(min(max_num_mma_kv_smem_p, max_num_mma_kv_reg_p), NUM_MMA_KV_P, { - using KTraits_P = - KernelTraits; - - if constexpr (KTraits_P::IsInvalid()) { - // Invalid configuration, skip - std::ostringstream err_msg; - err_msg << "FlashInfer Internal Error: Invalid configuration : NUM_MMA_Q=" << NUM_MMA_Q_P - << " NUM_MMA_D_QK=" << NUM_MMA_D_QK << " NUM_MMA_D_VO=" << NUM_MMA_D_VO - << " NUM_MMA_KV=" << NUM_MMA_KV_P << " NUM_WARPS_Q=" << NUM_WARPS_Q_P - << " NUM_WARPS_KV=" << NUM_WARPS_KV_P - << " please create an issue (https://github.com/flashinfer-ai/flashinfer/issues)" - " and report the issue to the developers."; - FLASHINFER_ERROR(err_msg.str()); - } else { - // Decode stuff - // TODO: Is there a way to avoid this nested dispatch? 
- DISPATCH_NUM_MMA_KV(min(max_num_mma_kv_smem_d, max_num_mma_kv_reg_d), NUM_MMA_KV_D, { - using KTraits_D = - KernelTraits; - if constexpr (KTraits_D::IsInvalid()) { - // Invalid configuration, skip - std::ostringstream err_msg; - err_msg - << "FlashInfer Internal Error: Invalid configuration : NUM_MMA_Q=" << NUM_MMA_Q_D - << " NUM_MMA_D_QK=" << NUM_MMA_D_QK << " NUM_MMA_D_VO=" << NUM_MMA_D_VO - << " NUM_MMA_KV=" << NUM_MMA_KV_D << " NUM_WARPS_Q=" << NUM_WARPS_Q_D - << " NUM_WARPS_KV=" << NUM_WARPS_KV_D - << " please create an issue (https://github.com/flashinfer-ai/flashinfer/issues)" - " and report the issue to the developers."; - FLASHINFER_ERROR(err_msg.str()); + constexpr uint32_t NUM_WARPS_Q_P = get_num_warps_q(CTA_TILE_Q_P); + constexpr uint32_t NUM_WARPS_KV_P = get_num_warps_kv(CTA_TILE_Q_P); + constexpr uint32_t NUM_MMA_Q_P = get_num_mma_q(CTA_TILE_Q_P); + + using DTypeQKAccum_P = + typename std::conditional, half, + float>::type; + + // we expect each sm execute two threadblocks + // TODO(Zihao): fix the following computation + const int num_ctas_per_sm_p = + max_smem_per_sm > (16 * HEAD_DIM_QK * sizeof(DTypeQ_P) * 16) ? 2 : 1; + const int max_smem_per_threadblock_p = max_smem_per_sm / num_ctas_per_sm_p; + + constexpr uint32_t max_num_mma_kv_reg_p = + (HEAD_DIM_VO >= 128 && NUM_MMA_Q_P == 2 && POS_ENCODING_MODE == PosEncodingMode::kRoPELlama && + !USE_FP16_QK_REDUCTION) + ? 2 + : (8 / NUM_MMA_Q_P); + // TODO(Zihao): fix the following computation + const uint32_t max_num_mma_kv_smem_p = + (max_smem_per_threadblock_p / (16 * HEAD_DIM_QK * sizeof(DTypeQ_P)) - + NUM_MMA_Q_P * NUM_WARPS_Q_P) / + (2 * NUM_WARPS_KV_P); + + // control NUM_MMA_KV for maximum warp occupancy + DISPATCH_NUM_MMA_KV(min(max_num_mma_kv_smem_p, max_num_mma_kv_reg_p), NUM_MMA_KV_P, { + using KTraits_P = KernelTraits; + + if constexpr (KTraits_P::IsInvalid()) { + // Invalid configuration, skip + std::ostringstream err_msg; + err_msg << "FlashInfer Internal Error: Invalid configuration : NUM_MMA_Q=" << NUM_MMA_Q_P + << " NUM_MMA_D_QK=" << NUM_MMA_D_QK << " NUM_MMA_D_VO=" << NUM_MMA_D_VO + << " NUM_MMA_KV=" << NUM_MMA_KV_P << " NUM_WARPS_Q=" << NUM_WARPS_Q_P + << " NUM_WARPS_KV=" << NUM_WARPS_KV_P + << " please create an issue (https://github.com/flashinfer-ai/flashinfer/issues)" + " and report the issue to the developers."; + FLASHINFER_ERROR(err_msg.str()); + } else { + // Decode stuff + // TODO: Is there a way to avoid this nested dispatch? 
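      // The decode side repeats DISPATCH_NUM_MMA_KV because its occupancy limits
      // (max_num_mma_kv_smem_d / max_num_mma_kv_reg_d) are derived from the decode tile shape,
      // so KTraits_D may settle on a different NUM_MMA_KV than KTraits_P.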
+ DISPATCH_NUM_MMA_KV(min(max_num_mma_kv_smem_d, max_num_mma_kv_reg_d), NUM_MMA_KV_D, { + using KTraits_D = + KernelTraits; + if constexpr (KTraits_D::IsInvalid()) { + // Invalid configuration, skip + std::ostringstream err_msg; + err_msg << "FlashInfer Internal Error: Invalid configuration : NUM_MMA_Q=" << NUM_MMA_Q_D + << " NUM_MMA_D_QK=" << NUM_MMA_D_QK << " NUM_MMA_D_VO=" << NUM_MMA_D_VO + << " NUM_MMA_KV=" << NUM_MMA_KV_D << " NUM_WARPS_Q=" << NUM_WARPS_Q_D + << " NUM_WARPS_KV=" << NUM_WARPS_KV_D + << " please create an issue (https://github.com/flashinfer-ai/flashinfer/issues)" + " and report the issue to the developers."; + FLASHINFER_ERROR(err_msg.str()); + } else { + // End decode stuff + constexpr uint32_t num_threads_p = (NUM_WARPS_Q_P * NUM_WARPS_KV_P) * WARP_SIZE; + size_t smem_size_p = sizeof(typename KTraits_P::SharedStorage); + size_t smem_size_d = sizeof(typename KTraits_D::SharedStorage); + + auto kernel = + PODWithKVCacheTensorKernel; + // Prefill: decide num_splits for split-kv + int num_blocks_per_sm = 0; + int num_sm = 0; + FLASHINFER_CUDA_CALL( + cudaDeviceGetAttribute(&num_sm, cudaDevAttrMultiProcessorCount, dev_id)); + // FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + // &num_blocks_per_sm, kernel, num_threads_p, smem_size_p)); + // Above function returns 0 for some reason, so we use a workaround + num_blocks_per_sm = std::max( + 1, std::min((int)(max_smem_per_sm / smem_size_p), (int)(256 / num_threads_p))); + uint32_t max_num_kv_chunks = + (num_blocks_per_sm * num_sm) / + (num_kv_heads * ceil_div(qo_len * group_size, KTraits_P::CTA_TILE_Q)); + uint32_t num_chunks; + if (max_num_kv_chunks > 0) { + uint32_t chunk_size = max(ceil_div(kv_len, max_num_kv_chunks), 256); + num_chunks = ceil_div(kv_len, chunk_size); } else { - // End decode stuff - constexpr uint32_t num_threads_p = (NUM_WARPS_Q_P * NUM_WARPS_KV_P) * WARP_SIZE; - size_t smem_size_p = sizeof(typename KTraits_P::SharedStorage); - size_t smem_size_d = sizeof(typename KTraits_D::SharedStorage); + num_chunks = 0; + } - auto kernel = + // Setup new prefill params if (not) split + auto o_p = prefill_params.o; + auto lse_p = prefill_params.lse; + if (num_chunks <= 1 || tmp_v == nullptr) { + // Enough parallelism, do not split-kv + prefill_params.partition_kv = 0; + kernel = PODWithKVCacheTensorKernel; + } else { + // Use cooperative groups to increase occupancy + prefill_params.partition_kv = num_chunks; + prefill_params.o = tmp_v; + prefill_params.lse = tmp_s; + kernel = PODWithKVCacheTensorKernel; - // Prefill: decide num_splits for split-kv - int num_blocks_per_sm = 0; - int num_sm = 0; - FLASHINFER_CUDA_CALL( - cudaDeviceGetAttribute(&num_sm, cudaDevAttrMultiProcessorCount, dev_id)); - // FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor( - // &num_blocks_per_sm, kernel, num_threads_p, smem_size_p)); - // Above function returns 0 for some reason, so we use a workaround - num_blocks_per_sm = std::max( - 1, std::min((int)(max_smem_per_sm / smem_size_p), (int)(256 / num_threads_p))); - uint32_t max_num_kv_chunks = - (num_blocks_per_sm * num_sm) / - (num_kv_heads * ceil_div(qo_len * group_size, KTraits_P::CTA_TILE_Q)); - uint32_t num_chunks; - if (max_num_kv_chunks > 0) { - uint32_t chunk_size = max(ceil_div(kv_len, max_num_kv_chunks), 256); - num_chunks = ceil_div(kv_len, chunk_size); - } else { - num_chunks = 0; - } - - // Setup new prefill params if (not) split - auto o_p = prefill_params.o; - auto lse_p = prefill_params.lse; - if (num_chunks <= 1 || tmp_v == nullptr) { 
- // Enough parallelism, do not split-kv - prefill_params.partition_kv = 0; - kernel = PODWithKVCacheTensorKernel; - } else { - // Use cooperative groups to increase occupancy - prefill_params.partition_kv = num_chunks; - prefill_params.o = tmp_v; - prefill_params.lse = tmp_s; - kernel = PODWithKVCacheTensorKernel; - } + } - // Setup new decode params if (not) split - auto o_d = decode_params.o; - auto lse_d = decode_params.lse; - if (tmp_v == nullptr) { - // do not partition kv - decode_params.partition_kv = false; - } else { - decode_params.partition_kv = true; - decode_params.o = tmp_v; - decode_params.lse = tmp_s; - } - // uint32_t num_qo_tiles = ceil_div(qo_len * group_size, KTraits_P::CTA_TILE_Q); - uint32_t padded_batch_size_p = prefill_params.padded_batch_size; - uint32_t padded_batch_size_d = decode_params.padded_batch_size; - int nblks_p(padded_batch_size_p * num_kv_heads); - int nthrs_p(32 * NUM_WARPS_Q_P * NUM_WARPS_KV_P); - - int nblks_d(padded_batch_size_d * num_kv_heads); - int nthrs_d(32 * NUM_WARPS_Q_D * NUM_WARPS_KV_D); - - // ******* Select final combined sizes here ******* / - size_t smem_size = max(smem_size_p, smem_size_d); - int nblks = nblks_p + nblks_d; - int nthrs = max(nthrs_p, nthrs_d); - - // printf("Smem: prefill %zu, decode %zu, total %zu\n", smem_size_p, smem_size_d, - // smem_size); printf("Blocks: prefill %d, decode %d, total %d\n", nblks_p, nblks_d, - // nblks); printf("Threads: prefill %d, decode %d, total %d\n", nthrs_p, nthrs_d, - // nthrs); - // ************************************************ / - - static int* tbAssign = nullptr; - if (tbAssign == nullptr) cudaMalloc(&tbAssign, sizeof(int) * (num_sm + 2)); - cudaMemset(tbAssign, 0, sizeof(int) * (num_sm + 2)); - - // Setup kernel arguments - void* args[] = {(void*)&num_qo_tiles, (void*)&prefill_params, (void*)&decode_params, - (void*)&tbAssign}; - FLASHINFER_CUDA_CALL(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - - // Launch kernel - if (enable_pdl) { - cudaLaunchAttribute attribute[1]; - cudaLaunchConfig_t config; - attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; - attribute[0].val.programmaticStreamSerializationAllowed = 1; - config.attrs = attribute; - config.numAttrs = 1; - config.gridDim = nblks; - config.blockDim = nthrs; - config.dynamicSmemBytes = smem_size; - config.stream = stream; - FLASHINFER_CUDA_CALL(cudaLaunchKernelEx(&config, kernel, num_qo_tiles, prefill_params, - decode_params, tbAssign)); - } else { - FLASHINFER_CUDA_CALL( - cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); - } + // Setup new decode params if (not) split + auto o_d = decode_params.o; + auto lse_d = decode_params.lse; + if (tmp_v == nullptr) { + // do not partition kv + decode_params.partition_kv = false; + } else { + decode_params.partition_kv = true; + decode_params.o = tmp_v; + decode_params.lse = tmp_s; + } + // uint32_t num_qo_tiles = ceil_div(qo_len * group_size, KTraits_P::CTA_TILE_Q); + uint32_t padded_batch_size_p = prefill_params.padded_batch_size; + uint32_t padded_batch_size_d = decode_params.padded_batch_size; + int nblks_p(padded_batch_size_p * num_kv_heads); + int nthrs_p(32 * NUM_WARPS_Q_P * NUM_WARPS_KV_P); + + int nblks_d(padded_batch_size_d * num_kv_heads); + int nthrs_d(32 * NUM_WARPS_Q_D * NUM_WARPS_KV_D); + + // ******* Select final combined sizes here ******* / + size_t smem_size = max(smem_size_p, smem_size_d); + int nblks = nblks_p + nblks_d; + int nthrs = max(nthrs_p, nthrs_d); + + // printf("Smem: 
prefill %zu, decode %zu, total %zu\n", smem_size_p, smem_size_d, + // smem_size); printf("Blocks: prefill %d, decode %d, total %d\n", nblks_p, nblks_d, + // nblks); printf("Threads: prefill %d, decode %d, total %d\n", nthrs_p, nthrs_d, + // nthrs); + // ************************************************ / + + static int* tbAssign = nullptr; + if (tbAssign == nullptr) cudaMalloc(&tbAssign, sizeof(int) * (num_sm + 2)); + cudaMemset(tbAssign, 0, sizeof(int) * (num_sm + 2)); + + // Setup kernel arguments + void* args[] = {(void*)&num_qo_tiles, (void*)&prefill_params, (void*)&decode_params, + (void*)&tbAssign}; + FLASHINFER_CUDA_CALL( + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + + // Launch kernel + if (enable_pdl) { + cudaLaunchAttribute attribute[1]; + cudaLaunchConfig_t config; + attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; + attribute[0].val.programmaticStreamSerializationAllowed = 1; + config.attrs = attribute; + config.numAttrs = 1; + config.gridDim = nblks; + config.blockDim = nthrs; + config.dynamicSmemBytes = smem_size; + config.stream = stream; + FLASHINFER_CUDA_CALL(cudaLaunchKernelEx(&config, kernel, num_qo_tiles, prefill_params, + decode_params, tbAssign)); + } else { FLASHINFER_CUDA_CALL( cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); - - // Post-kernel stuff for split-kv - if (tmp_v != nullptr) { - if constexpr (DecodeAttentionVariant::use_softmax) { - FLASHINFER_CUDA_CALL(VariableLengthMergeStates( - tmp_v, tmp_s, decode_params.merge_indptr, o, lse, - decode_params.max_total_num_rows, decode_params.total_num_rows, num_qo_heads, - HEAD_DIM_VO, stream)); - } else { - FLASHINFER_CUDA_CALL(VariableLengthAttentionSum( - tmp_v, decode_params.merge_indptr, o, decode_params.max_total_num_rows, - decode_params.total_num_rows, num_qo_heads, HEAD_DIM_VO, stream)); - } + } + FLASHINFER_CUDA_CALL( + cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); + + // Post-kernel stuff for split-kv + if (tmp_v != nullptr) { + if constexpr (DecodeAttentionVariant::use_softmax) { + FLASHINFER_CUDA_CALL(VariableLengthMergeStates( + tmp_v, tmp_s, decode_params.merge_indptr, o, lse, + decode_params.max_total_num_rows, decode_params.total_num_rows, num_qo_heads, + HEAD_DIM_VO, stream)); + } else { + FLASHINFER_CUDA_CALL(VariableLengthAttentionSum( + tmp_v, decode_params.merge_indptr, o, decode_params.max_total_num_rows, + decode_params.total_num_rows, num_qo_heads, HEAD_DIM_VO, stream)); } } - }); - } - }); + } + }); + } }); - return cudaSuccess; +}); +return cudaSuccess; } } // namespace flashinfer From 870b0b237ed1eb139d3599e714061fc40166f7d3 Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Fri, 11 Jul 2025 01:44:10 +0000 Subject: [PATCH 28/33] update dispatch logic --- csrc/pod.cu | 87 +++++++----------- flashinfer/pod.py | 4 +- include/flashinfer/attention/pod.cuh | 130 ++++++++------------------- 3 files changed, 72 insertions(+), 149 deletions(-) diff --git a/csrc/pod.cu b/csrc/pod.cu index ac33efe66..b655647fa 100644 --- a/csrc/pod.cu +++ b/csrc/pod.cu @@ -26,9 +26,7 @@ template -cudaError_t PODWithKVCacheTensorDispatched(PrefillParams prefill_params, - typename PrefillParams::DTypeO* tmp, - DecodeParams decode_params, +cudaError_t PODWithKVCacheTensorDispatched(PrefillParams prefill_params, DecodeParams decode_params, typename DecodeParams::DTypeO* tmp_v, float* tmp_s, bool enable_pdl, cudaStream_t stream); @@ -71,17 +69,16 @@ at::Tensor PODWithKVCachePlan(at::Tensor 
float_workspace_buffer, at::Tensor int_ void PODWithKVCacheTensorRun( // Shared params at::Tensor float_workspace_buffer_d, at::Tensor int_workspace_buffer_d, - at::Tensor plan_info_vec, at::Tensor qo_indptr, at::Tensor paged_kv_indptr, - at::Tensor paged_kv_indices, at::Tensor paged_kv_last_page_len, at::Tensor o, - std::optional maybe_lse, int64_t layout, + at::Tensor plan_info_vec, at::Tensor paged_k_cache, at::Tensor paged_v_cache, + at::Tensor qo_indptr, at::Tensor paged_kv_indptr, at::Tensor paged_kv_indices, + at::Tensor paged_kv_last_page_len, at::Tensor o, std::optional maybe_lse, + int64_t layout, // Prefill params - at::Tensor q_p, at::Tensor paged_k_p, at::Tensor paged_v_p, - std::optional maybe_lse_p, int64_t mask_mode_code_p, int64_t window_left_p, + at::Tensor q_p, int64_t mask_mode_code_p, int64_t window_left_p, std::optional maybe_custom_mask_p, std::optional maybe_alibi_slopes_p, double logits_soft_cap_p, double sm_scale_p, double rope_rcp_scale_p, double rope_rcp_theta_p, // Decode params - at::Tensor q_d, at::Tensor paged_k_cache_d, at::Tensor paged_v_cache_d, at::Tensor qo_indptr_d, - std::optional maybe_lse_d, int64_t mask_mode_code_d, int64_t window_left_d, + at::Tensor q_d, int64_t mask_mode_code_d, int64_t window_left_d, std::optional maybe_custom_mask_d, std::optional maybe_mask_indptr_d, std::optional maybe_alibi_slopes_d, double logits_soft_cap_d, double sm_scale_d, double rope_rcp_scale_d, double rope_rcp_theta_d, bool enable_pdl) { @@ -100,27 +97,11 @@ void PODWithKVCacheTensorRun( // Prefill setup uint32_t head_dim_qk = q_p.size(2); - uint32_t kv_len_p, qo_len, num_kv_heads, num_qo_heads_p; - QKVLayout kv_layout_p = static_cast(layout_p); + uint32_t qo_len, num_qo_heads_p; + QKVLayout kv_layout = static_cast(layout); qo_len = q_p.size(0) + q_d.size(0); num_qo_heads_p = q_p.size(1); - uint32_t q_stride_n_p = q_p.stride(0), q_stride_h_p = q_p.stride(1), k_stride_n_p, k_stride_h_p, - v_stride_n_p, v_stride_h_p; - if (kv_layout_p == QKVLayout::kNHD) { - kv_len_p = paged_k_p.size(0); - num_kv_heads = paged_k_p.size(1); - k_stride_n_p = paged_k_p.stride(0); - k_stride_h_p = paged_k_p.stride(1); - v_stride_n_p = paged_v_p.stride(0); - v_stride_h_p = paged_v_p.stride(1); - } else { - kv_len_p = paged_k_p.size(1); - num_kv_heads = paged_k_p.size(0); - k_stride_h_p = paged_k_p.stride(0); - k_stride_n_p = paged_k_p.stride(1); - v_stride_h_p = paged_v_p.stride(0); - v_stride_n_p = paged_v_p.stride(1); - } + uint32_t q_stride_n_p = q_p.stride(0), q_stride_h_p = q_p.stride(1); if (maybe_lse) { const auto& lse = *maybe_lse; TORCH_CHECK(lse.size(0) == qo_len, lse.size(0), qo_len); @@ -130,23 +111,21 @@ void PODWithKVCacheTensorRun( const MaskMode mask_mode_p = static_cast(mask_mode_code_p); auto q_scalar_type = q_p.scalar_type(); - auto kv_scalar_type = paged_k_p.scalar_type(); // Decode setup (Tensor decode = batched prefill) - QKVLayout kv_layout = static_cast(layout); uint32_t num_qo_heads = q_d.size(1); - TORCH_CHECK(num_qo_heads_p == num_qo_heads, "POD currently requires same # Query heads for prefill and decode"); - uint32_t num_kv_heads, page_size; - uint32_t head_dim_qk = q_d.size(2); + uint32_t num_kv_heads_d, num_kv_heads, page_size; if (kv_layout == QKVLayout::kHND) { num_kv_heads = paged_k_cache.size(1); + num_kv_heads_d = paged_k_cache.size(1); page_size = paged_k_cache.size(2); } else { - page_size = paged_k_cache.size(1); num_kv_heads = paged_k_cache.size(2); + num_kv_heads_d = paged_k_cache.size(2); + page_size = paged_k_cache.size(1); } 
TORCH_CHECK(num_kv_heads == num_kv_heads_d, "POD currently requires same # KV heads for prefill and decode; Prefill: ", @@ -158,13 +137,6 @@ void PODWithKVCacheTensorRun( const auto q_stride_n_d = q_d.stride(0); const auto q_stride_h_d = q_d.stride(1); - // get kv_cache_strides - const int64_t* kv_cache_strides = nullptr; - auto k_strides = paged_k_cache_d.strides(); - auto v_strides = paged_v_cache_d.strides(); - TORCH_CHECK(k_strides == v_strides, "k/v strides must be identical"); - kv_cache_strides = k_strides.data(); - const c10::cuda::OptionalCUDAGuard device_guard(float_workspace_buffer_d.device()); const cudaStream_t stream = c10::cuda::getCurrentCUDAStream(); @@ -173,25 +145,26 @@ void PODWithKVCacheTensorRun( USE_SLIDING_WINDOW_D, USE_LOGITS_SOFT_CAP, [&] { paged_kv_t paged_kv( num_kv_heads, page_size, HEAD_DIM_VO, batch_size, kv_layout, - static_cast(paged_k_cache_d.data_ptr()), - static_cast(paged_v_cache_d.data_ptr()), kv_cache_strides, - static_cast(paged_kv_indices_d.data_ptr()), - static_cast(paged_kv_indptr_d.data_ptr()), - static_cast(paged_kv_last_page_len_d.data_ptr())); + static_cast(paged_k_cache.data_ptr()), + static_cast(paged_v_cache.data_ptr()), kv_cache_strides, + static_cast(paged_kv_indices.data_ptr()), + static_cast(paged_kv_indptr.data_ptr()), + static_cast(paged_kv_last_page_len.data_ptr())); PrefillParams prefill_params; { // Make params a reference to prefill_params to set values PrefillParams& params = prefill_params; params.q = static_cast(q_p.data_ptr()); params.paged_kv = paged_kv; - params.q_indptr = static_cast(qo_indptr_p.data_ptr()); - params.o = static_cast(o_p.data_ptr()); + params.q_indptr = static_cast(qo_indptr.data_ptr()); + params.o = static_cast(o.data_ptr()); params.lse = maybe_lse ? static_cast(maybe_lse->data_ptr()) : nullptr; - params.num_qo_heads = num_qo_heads_p; - params.group_size = uint_fastdiv(num_qo_heads_p / paged_kv.num_heads); + params.group_size = uint_fastdiv(num_qo_heads / paged_kv.num_heads); params.q_stride_n = q_stride_n_p; params.q_stride_h = q_stride_h_p; params.window_left = window_left_p; + params.num_kv_heads = num_kv_heads; + params.num_qo_heads = num_qo_heads; params.request_indices = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.request_indices_offset); @@ -226,6 +199,7 @@ void PODWithKVCacheTensorRun( params.total_num_rows = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.total_num_rows_offset); } + params.partition_kv = plan_info.split_kv; if (plan_info.split_kv) { if (plan_info.enable_cuda_graph) { params.block_valid_mask = @@ -241,21 +215,23 @@ void PODWithKVCacheTensorRun( DecodeParams& params = decode_params; params.q = static_cast(q_d.data_ptr()); params.paged_kv = paged_kv; - params.q_indptr = static_cast(qo_indptr_d.data_ptr()); - params.o = static_cast(o_d.data_ptr()); - + params.q_indptr = static_cast(qo_indptr.data_ptr()); + params.o = static_cast(o.data_ptr()); params.lse = maybe_lse ? 
static_cast(maybe_lse->data_ptr()) : nullptr; - params.num_qo_heads = num_qo_heads; params.group_size = uint_fastdiv(num_qo_heads / paged_kv.num_heads); params.q_stride_n = q_stride_n_d; params.q_stride_h = q_stride_h_d; params.window_left = window_left_d; + params.num_kv_heads = num_kv_heads; + params.num_qo_heads = num_qo_heads; + params.request_indices = prefill_params.request_indices; params.qo_tile_indices = prefill_params.qo_tile_indices; params.kv_tile_indices = prefill_params.kv_tile_indices; params.o_indptr = prefill_params.o_indptr; params.kv_chunk_size_ptr = prefill_params.kv_chunk_size_ptr; + params.partition_kv = plan_info.split_kv; if (plan_info.split_kv) { params.merge_indptr = prefill_params.merge_indptr; // These should be assigned from plan info, not from prefill_params @@ -268,7 +244,6 @@ void PODWithKVCacheTensorRun( params.padded_batch_size = plan_info.padded_batch_size_d; params.max_total_num_rows = plan_info.total_num_rows; - params.partition_kv = false; params.maybe_mask_indptr = maybe_mask_indptr_d ? static_cast(maybe_mask_indptr_d->data_ptr()) : nullptr; diff --git a/flashinfer/pod.py b/flashinfer/pod.py index a751807c3..bf6d7f6af 100644 --- a/flashinfer/pod.py +++ b/flashinfer/pod.py @@ -638,6 +638,8 @@ def run( self._paged_kv_indptr_buf, self._paged_kv_indices_buf, self._paged_kv_last_page_len_buf, + k_cache_d, + v_cache_d, out, lse, TensorLayout[self._kv_layout].value, @@ -655,8 +657,6 @@ def run( 1.0 / rope_theta_p, # Decode params q_d, - k_cache_d, - v_cache_d, MaskMode.NON_CAUSAL.value, window_left_d, None, # packed_custom_mask diff --git a/include/flashinfer/attention/pod.cuh b/include/flashinfer/attention/pod.cuh index 21b014d6b..d7cb66a27 100644 --- a/include/flashinfer/attention/pod.cuh +++ b/include/flashinfer/attention/pod.cuh @@ -34,29 +34,25 @@ enum Operation { DECODE = 1, }; -template +template __global__ __launch_bounds__(std::max( KTraits_P::NUM_THREADS, - KTraits_D::NUM_THREADS)) void PODWithKVCacheTensorKernel(const uint32_t num_qo_tiles, - const __grid_constant__ PrefillParams + KTraits_D::NUM_THREADS)) void PODWithKVCacheTensorKernel(const __grid_constant__ PrefillParams prefill_params, const __grid_constant__ DecodeParams decode_params, int* tbAssign) { extern __shared__ uint8_t smem[]; + const uint32_t num_kv_heads = prefill_params.num_kv_heads; // PREFILL VARS - const uint32_t num_kv_heads_p = prefill_params.num_kv_heads; - const uint32_t num_chunks = prefill_params.partition_kv; - const uint32_t qo_len = prefill_params.qo_len; + const uint32_t padded_bsize_p = prefill_params.padded_batch_size; // DECODE VARS - const uint32_t padded_bsize = decode_params.padded_batch_size; - const uint32_t num_kv_heads_d = decode_params.paged_kv.num_heads; + const uint32_t padded_bsize_d = decode_params.padded_batch_size; // THREADBLOCKS - const uint32_t prefill_blocks = num_kv_heads_p * num_qo_tiles * (PartitionKV_P ? 
num_chunks : 1); - const uint32_t decode_blocks = padded_bsize * num_kv_heads_d; + const uint32_t prefill_blocks = padded_bsize_p * num_kv_heads; + const uint32_t decode_blocks = padded_bsize_d * num_kv_heads; int op; int linear_bid; @@ -110,7 +106,7 @@ __global__ __launch_bounds__(std::max( op = !op; linear_bid = atomicAdd(&tbAssign[num_SMs + 0], 1); } - // Write the blockId and operation to shared memory + // Write the global blockId and operation to shared memory ((int*)smem)[0] = linear_bid; ((int*)smem)[1] = op; } @@ -126,46 +122,35 @@ __global__ __launch_bounds__(std::max( const uint32_t linear_tid = threadIdx.x; // Return if threadId exceeds number of threads for this op if (linear_tid >= 32 * KTraits_P::NUM_WARPS_Q * KTraits_P::NUM_WARPS_KV) return; + if (linear_bid >= prefill_blocks) return; const dim3 tid = dim3(linear_tid % 32, (linear_tid / 32) % KTraits_P::NUM_WARPS_Q, (linear_tid / 32) / KTraits_P::NUM_WARPS_Q); - // dim3 nblks(ceil_div(qo_len * group_size, CTA_TILE_Q), 1, num_kv_heads); - // dim3 nblks(ceil_div(qo_len * group_size, CTA_TILE_Q), num_chunks, num_kv_heads); - // BlockID exceeds limit - if (linear_bid >= prefill_blocks) return; - - const uint32_t bx = linear_bid % num_qo_tiles; auto& smem_storage = reinterpret_cast(smem); - // Not partition_kv - if constexpr (!PartitionKV_P) { - const uint32_t chunk_idx = 0; - const uint32_t kv_head_idx = linear_bid / num_qo_tiles; - SinglePrefillWithKVCacheDevice(prefill_params, smem_storage, tid, bx, chunk_idx, - kv_head_idx, 1, num_kv_heads_p); - } else { - const uint32_t chunk_idx = (linear_bid / num_qo_tiles) % num_chunks; - const uint32_t kv_head_idx = linear_bid / (num_qo_tiles * num_chunks); - SinglePrefillWithKVCacheDevice(prefill_params, smem_storage, tid, bx, chunk_idx, - kv_head_idx, num_chunks, num_kv_heads_p); - } - } else /* OP == DECODE */ { - auto& smem_storage = reinterpret_cast(smem); - // dim3 nblks_d(padded_batch_size_d, 1, num_kv_heads); - if (linear_bid >= decode_blocks) return; + const uint32_t bx = linear_bid % padded_bsize_p; + const uint32_t kv_head_idx = linear_bid / padded_bsize_p; - const uint32_t bx = linear_bid % padded_bsize; - const uint32_t kv_head_idx = linear_bid / padded_bsize; + BatchPrefillWithPagedKVCacheDevice(prefill_params, smem_storage, tid, bx, + kv_head_idx, num_kv_heads); - // dim3 nthrs_d(32, NUM_WARPS_Q_D, NUM_WARPS_KV_D); + } else /* OP == DECODE */ { const uint32_t linear_tid = threadIdx.x; // Return if threadId exceeds number of threads for this op if (linear_tid >= 32 * KTraits_D::NUM_WARPS_Q * KTraits_D::NUM_WARPS_KV) return; + if (linear_bid >= decode_blocks) return; const dim3 tid = dim3(linear_tid % 32, (linear_tid / 32) % KTraits_D::NUM_WARPS_Q, (linear_tid / 32) / KTraits_D::NUM_WARPS_Q); + auto& smem_storage = reinterpret_cast(smem); + // dim3 nblks_d(padded_batch_size_d, 1, num_kv_heads); + const uint32_t bx = linear_bid % padded_bsize_d; + const uint32_t kv_head_idx = linear_bid / padded_bsize_d; + + // dim3 nthrs_d(32, NUM_WARPS_Q_D, NUM_WARPS_KV_D); + // Decode is faster with tensor cores, which are usually not saturated by prefill BatchPrefillWithPagedKVCacheDevice(decode_params, smem_storage, tid, bx, kv_head_idx, - num_kv_heads_d); + num_kv_heads); } } @@ -189,23 +174,10 @@ cudaError_t PODWithKVCacheTensorDispatched(PrefillParams prefill_params, DecodeP using DTypeO_P = typename PrefillParams::DTypeO; const uint32_t num_qo_heads = prefill_params.num_qo_heads; const uint32_t num_kv_heads = prefill_params.num_kv_heads; - const uint32_t qo_len = 
prefill_params.qo_len; - const uint32_t kv_len = prefill_params.kv_len; - if (kv_len < qo_len && MASK_MODE_P == MaskMode::kCausal) { - std::ostringstream err_msg; - err_msg << "When mask_mode is set to MaskMode::kCausal, kv_len must be greater than or equal " - "to qo_len, got kv_len" - << kv_len << " and qo_len " << qo_len; - FLASHINFER_ERROR(err_msg.str()); - } - const uint32_t group_size = num_qo_heads / num_kv_heads; - const uint_fastdiv group_size_fastdiv(group_size); constexpr uint32_t NUM_MMA_D_QK = HEAD_DIM_QK / 16; constexpr uint32_t NUM_MMA_D_VO = HEAD_DIM_VO / 16; - int64_t unpacked_qo_len = qo_len * group_size; - // Decode vars setup using DTypeQ_D = typename DecodeParams::DTypeQ; using DTypeKV_D = typename DecodeParams::DTypeKV; @@ -316,7 +288,7 @@ cudaError_t PODWithKVCacheTensorDispatched(PrefillParams prefill_params, DecodeP size_t smem_size_d = sizeof(typename KTraits_D::SharedStorage); auto kernel = - PODWithKVCacheTensorKernel; + PODWithKVCacheTensorKernel; // Prefill: decide num_splits for split-kv int num_blocks_per_sm = 0; int num_sm = 0; @@ -327,51 +299,32 @@ cudaError_t PODWithKVCacheTensorDispatched(PrefillParams prefill_params, DecodeP // Above function returns 0 for some reason, so we use a workaround num_blocks_per_sm = std::max( 1, std::min((int)(max_smem_per_sm / smem_size_p), (int)(256 / num_threads_p))); - uint32_t max_num_kv_chunks = - (num_blocks_per_sm * num_sm) / - (num_kv_heads * ceil_div(qo_len * group_size, KTraits_P::CTA_TILE_Q)); - uint32_t num_chunks; - if (max_num_kv_chunks > 0) { - uint32_t chunk_size = max(ceil_div(kv_len, max_num_kv_chunks), 256); - num_chunks = ceil_div(kv_len, chunk_size); - } else { - num_chunks = 0; - } // Setup new prefill params if (not) split - auto o_p = prefill_params.o; - auto lse_p = prefill_params.lse; - if (num_chunks <= 1 || tmp_v == nullptr) { - // Enough parallelism, do not split-kv - prefill_params.partition_kv = 0; - kernel = PODWithKVCacheTensorKernel; - } else { + auto o = prefill_params.o; + auto lse = prefill_params.lse; + if (prefill_params.partition_kv) { // Use cooperative groups to increase occupancy - prefill_params.partition_kv = num_chunks; + assert(tmp_v != nullptr); prefill_params.o = tmp_v; prefill_params.lse = tmp_s; - kernel = - PODWithKVCacheTensorKernel; } // Setup new decode params if (not) split - auto o_d = decode_params.o; - auto lse_d = decode_params.lse; - if (tmp_v == nullptr) { - // do not partition kv - decode_params.partition_kv = false; - } else { - decode_params.partition_kv = true; + if (prefill_params.partition_kv) { + assert(tmp_v != nullptr); decode_params.o = tmp_v; decode_params.lse = tmp_s; } - // uint32_t num_qo_tiles = ceil_div(qo_len * group_size, KTraits_P::CTA_TILE_Q); + uint32_t padded_batch_size_p = prefill_params.padded_batch_size; uint32_t padded_batch_size_d = decode_params.padded_batch_size; + printf("Debug: launching prefill with padded_batch_size_p %d, num_kv_heads %d\n", + padded_batch_size_p, num_kv_heads); int nblks_p(padded_batch_size_p * num_kv_heads); int nthrs_p(32 * NUM_WARPS_Q_P * NUM_WARPS_KV_P); - + printf("Debug: launching decode with padded_batch_size_d %d, num_kv_heads %d\n", + padded_batch_size_d, num_kv_heads); int nblks_d(padded_batch_size_d * num_kv_heads); int nthrs_d(32 * NUM_WARPS_Q_D * NUM_WARPS_KV_D); @@ -391,8 +344,7 @@ cudaError_t PODWithKVCacheTensorDispatched(PrefillParams prefill_params, DecodeP cudaMemset(tbAssign, 0, sizeof(int) * (num_sm + 2)); // Setup kernel arguments - void* args[] = {(void*)&num_qo_tiles, 
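  // Note: cudaLaunchKernel consumes a void* array whose entries point at the kernel arguments
  // in declaration order, so the array below must mirror the updated
  // PODWithKVCacheTensorKernel signature (prefill_params, decode_params, tbAssign) now that
  // num_qo_tiles is no longer a kernel parameter.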
(void*)&prefill_params, (void*)&decode_params, - (void*)&tbAssign}; + void* args[] = {(void*)&prefill_params, (void*)&decode_params, (void*)&tbAssign}; FLASHINFER_CUDA_CALL( cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); @@ -408,14 +360,12 @@ cudaError_t PODWithKVCacheTensorDispatched(PrefillParams prefill_params, DecodeP config.blockDim = nthrs; config.dynamicSmemBytes = smem_size; config.stream = stream; - FLASHINFER_CUDA_CALL(cudaLaunchKernelEx(&config, kernel, num_qo_tiles, prefill_params, - decode_params, tbAssign)); + FLASHINFER_CUDA_CALL( + cudaLaunchKernelEx(&config, kernel, prefill_params, decode_params, tbAssign)); } else { FLASHINFER_CUDA_CALL( cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); } - FLASHINFER_CUDA_CALL( - cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); // Post-kernel stuff for split-kv if (tmp_v != nullptr) { @@ -434,10 +384,8 @@ cudaError_t PODWithKVCacheTensorDispatched(PrefillParams prefill_params, DecodeP }); } }); -}); -return cudaSuccess; + return cudaSuccess; } - } // namespace flashinfer #endif // FLASHINFER_PREFILL_CUH_ From a7eb44b84117559ee059e37040d60af40604e40f Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Fri, 11 Jul 2025 05:09:06 +0000 Subject: [PATCH 29/33] add pod template inst and .cu gen --- aot_build_utils/generate.py | 59 ++++++++ aot_build_utils/generate_pod_inst.py | 130 ++++++++++++++++++ ...ed_attention.py => bench_pod_attention.py} | 86 +++++++----- csrc/pod_jit_pybind.cu | 19 ++- csrc/pod_kernel_inst.jinja | 29 ++-- flashinfer/pod.py | 18 +-- 6 files changed, 276 insertions(+), 65 deletions(-) create mode 100644 aot_build_utils/generate_pod_inst.py rename benchmarks/{bench_mixed_attention.py => bench_pod_attention.py} (74%) diff --git a/aot_build_utils/generate.py b/aot_build_utils/generate.py index d4c21a8b6..5daa47c08 100644 --- a/aot_build_utils/generate.py +++ b/aot_build_utils/generate.py @@ -24,6 +24,7 @@ generate_batch_paged_decode_inst, generate_batch_paged_prefill_inst, generate_batch_ragged_prefill_inst, + generate_pod_inst, generate_single_decode_inst, generate_single_prefill_inst, ) @@ -250,11 +251,69 @@ def write_if_different(path: Path, content: str) -> None: f"f16qk_{bool(use_fp16_qk_reduction)}" ) + # POD files + pod_uris = [] + for ( + head_dim, + pos_encoding_mode, + use_fp16_qk_reduction, + mask_mode_p, + mask_mode_d, + idtype, + ) in product( + head_dims, + pos_encoding_modes, + use_fp16_qk_reductions, + mask_modes, + mask_modes, # mask_mode_d can be different from mask_mode_p + idtypes, + ): + for dtype_q, dtype_kv in list(zip(prefill_dtypes, prefill_dtypes)) + list( + product(prefill_dtypes, fp8_dtypes) + ): + fname = f"pod_head_qk_{head_dim}_head_vo_{head_dim}_posenc_{pos_encoding_mode}_fp16qkred_{use_fp16_qk_reduction}_maskp_{mask_mode_p}_maskd_{mask_mode_d}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_q}_idtype_{idtype}.cu" + content = generate_pod_inst.get_cu_file_str( + head_dim, # head_dim_qk + head_dim, # head_dim_vo + pos_encoding_mode, + use_fp16_qk_reduction, + mask_mode_p, + mask_mode_d, + dtype_q, # dtype_q + dtype_kv, # dtype_kv + dtype_q, # dtype_out + idtype, + ) + write_if_different(path / fname, content) + + for sliding_window_p in [True, False]: + for sliding_window_d in [True, False]: + for logits_soft_cap_p in [True, False]: + for logits_soft_cap_d in [True, False]: + if ( + mask_mode_p == 0 and mask_mode_d == 0 + ): # NOTE(Zihao): uri do not contain mask, avoid duplicate uris + pod_uris.append( + 
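                            # Illustrative expansion (assuming the short dtype keys used
                            # elsewhere, e.g. dtype_q = "f16", idtype = "i32", head_dim = 128,
                            # pos_encoding_mode = 0 and all boolean options False), roughly:
                            # "pod_with_kv_cache_dtype_q_f16_dtype_kv_f16_dtype_o_f16_dtype_idx_i32_"
                            # "head_dim_qk_128_head_dim_vo_128_posenc_0_use_swa_p_False_use_swa_d_False_"
                            # "use_logits_cap_p_False_use_logits_cap_d_False_f16qk_False"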
f"pod_with_kv_cache_dtype_q_{dtype_q}_" + f"dtype_kv_{dtype_kv}_" + f"dtype_o_{dtype_q}_" + f"dtype_idx_{idtype}_" + f"head_dim_qk_{head_dim}_" + f"head_dim_vo_{head_dim}_" + f"posenc_{pos_encoding_mode}_" + f"use_swa_p_{sliding_window_p}_" + f"use_swa_d_{sliding_window_d}_" + f"use_logits_cap_p_{logits_soft_cap_p}_" + f"use_logits_cap_d_{logits_soft_cap_d}_" + f"f16qk_{bool(use_fp16_qk_reduction)}" + ) + return ( single_decode_uris + batch_decode_uris + single_prefill_uris + batch_prefill_uris + + pod_uris ) diff --git a/aot_build_utils/generate_pod_inst.py b/aot_build_utils/generate_pod_inst.py new file mode 100644 index 000000000..cb32c698c --- /dev/null +++ b/aot_build_utils/generate_pod_inst.py @@ -0,0 +1,130 @@ +""" +Copyright (c) 2024 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import re +import sys +from pathlib import Path + +from .literal_map import ( + dtype_literal, + idtype_literal, + mask_mode_literal, + pos_encoding_mode_literal, +) + + +def get_cu_file_str( + head_dim_qk, + head_dim_vo, + pos_encoding_mode, + use_fp16_qk_reduction, + mask_mode_p, + mask_mode_d, + dtype_q, + dtype_kv, + dtype_out, + idtype, +): + cta_tile_q_choice = [128, 64, 16] + + def get_insts(attention_variant_p, attention_variant_d, dtype_out): + return "\n".join( + [ + """template cudaError_t PODWithKVCacheTensorDispatched<{head_dim_qk}, {head_dim_vo}, {pos_encoding_mode}, {use_fp16_qk_reduction}, {mask_mode_p}, {cta_tile_q_p}, {cta_tile_q_d}, {mask_mode_d}, {attention_variant_p}, {attention_variant_d}, PrefillParams, DecodeParams>( + PrefillParams prefill_params, DecodeParams decode_params, + {dtype_out}* tmp_v, float* tmp_s, bool enable_pdl, cudaStream_t stream); + """.format( + head_dim_qk=head_dim_qk, + head_dim_vo=head_dim_vo, + pos_encoding_mode=pos_encoding_mode_literal[int(pos_encoding_mode)], + use_fp16_qk_reduction=use_fp16_qk_reduction, + mask_mode_p=mask_mode_literal[int(mask_mode_p)], + cta_tile_q_p=cta_tile_q_p, + cta_tile_q_d=cta_tile_q_d, + mask_mode_d=mask_mode_literal[int(mask_mode_d)], + attention_variant_p=attention_variant_p, + attention_variant_d=attention_variant_d, + dtype_out=dtype_out, + ) + for cta_tile_q_p in cta_tile_q_choice + for cta_tile_q_d in cta_tile_q_choice + ] + ) + + use_custom_mask_p = "true" if int(mask_mode_p) == 2 else "false" + use_custom_mask_d = "true" if int(mask_mode_d) == 2 else "false" + dtype_q = dtype_literal[dtype_q] + dtype_kv = dtype_literal[dtype_kv] + dtype_out = dtype_literal[dtype_out] + idtype = idtype_literal[idtype] + + content = f"""#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "pytorch_conversion_utils.h" +#include "pytorch_extension_utils.h" + +using namespace flashinfer; + +using PrefillParams = BatchPrefillPagedParams<{dtype_q}, {dtype_kv}, {dtype_out}>; +using DecodeParams = BatchPrefillPagedParams<{dtype_q}, {dtype_kv}, {dtype_out}, {idtype}>; + +constexpr auto POS_ENCODING_MODE = PosEncodingMode::kNone; + +using AttentionVariant1_P = DefaultAttention<{use_custom_mask_p}, 
/*use_sliding_window=*/true, /*use_logits_soft_cap=*/false, /*use_alibi_bias=*/false>; +using AttentionVariant1_D = DefaultAttention<{use_custom_mask_d}, /*use_sliding_window=*/true, /*use_logits_soft_cap=*/false, /*use_alibi_bias=*/false>; + +{get_insts("AttentionVariant1_P", "AttentionVariant1_D", dtype_out)} + +using AttentionVariant2_P = DefaultAttention<{use_custom_mask_p}, /*use_sliding_window=*/true, /*use_logits_soft_cap=*/true, /*use_alibi_bias=*/false>; +using AttentionVariant2_D = DefaultAttention<{use_custom_mask_d}, /*use_sliding_window=*/true, /*use_logits_soft_cap=*/true, /*use_alibi_bias=*/false>; + +{get_insts("AttentionVariant2_P", "AttentionVariant2_D", dtype_out)} + +using AttentionVariant3_P = DefaultAttention<{use_custom_mask_p}, /*use_sliding_window=*/false, /*use_logits_soft_cap=*/false, /*use_alibi_bias=*/false>; +using AttentionVariant3_D = DefaultAttention<{use_custom_mask_d}, /*use_sliding_window=*/false, /*use_logits_soft_cap=*/false, /*use_alibi_bias=*/false>; + +{get_insts("AttentionVariant3_P", "AttentionVariant3_D", dtype_out)} + +using AttentionVariant4_P = DefaultAttention<{use_custom_mask_p}, /*use_sliding_window=*/false, /*use_logits_soft_cap=*/true, /*use_alibi_bias=*/false>; +using AttentionVariant4_D = DefaultAttention<{use_custom_mask_d}, /*use_sliding_window=*/false, /*use_logits_soft_cap=*/true, /*use_alibi_bias=*/false>; + +{get_insts("AttentionVariant4_P", "AttentionVariant4_D", dtype_out)} + +}}""" + return content + + +if __name__ == "__main__": + pattern = ( + r"pod_head_qk_([0-9]+)_head_vo_([0-9]+)_posenc_([0-9]+)_" + r"fp16qkred_([a-z]+)_maskp_([0-9]+)_maskd_([0-9]+)_" + r"dtypeq_([a-z0-9]+)_dtypekv_([a-z0-9]+)_dtypeout_([a-z0-9]+)_idtype_([a-z0-9]+)\.cu" + ) + compiled_pattern = re.compile(pattern) + path = Path(sys.argv[1]) + fname = path.name + match = compiled_pattern.match(fname) + + with open(path, "w") as f: + f.write(get_cu_file_str(*match.groups())) diff --git a/benchmarks/bench_mixed_attention.py b/benchmarks/bench_pod_attention.py similarity index 74% rename from benchmarks/bench_mixed_attention.py rename to benchmarks/bench_pod_attention.py index 4cc821441..105698e94 100644 --- a/benchmarks/bench_mixed_attention.py +++ b/benchmarks/bench_pod_attention.py @@ -17,35 +17,54 @@ def run_bench( device=0, causal=True, ): - # POD Attention only supports page size = 1 due to use of single prefill kernel - page_block_size = 1 + # if page size > 1, prefill kv len must be divisible by page size to ensure + # an identical workload as in BatchAttention + page_size = 1 seq_lens = torch.tensor(d_kv_lens + p_kv_lens, dtype=torch.int32) q_lens = torch.tensor(d_qo_lens + p_qo_lens, dtype=torch.int32) - seq_lens_blocks = torch.ceil(seq_lens / page_block_size).int() - d_seq_lens_blocks = ( - torch.tensor(d_kv_lens, dtype=torch.int32) / page_block_size - ).int() + seq_lens_blocks = torch.ceil(seq_lens / page_size).int() + p_seq_lens = torch.tensor(p_kv_lens, dtype=torch.int32) / page_size + d_seq_lens = (torch.tensor(d_kv_lens, dtype=torch.int32) / page_size).int() - q_indptr = torch.cat([torch.tensor([0]), torch.cumsum(q_lens, 0)], dim=0).int() + # General params + qo_indptr = torch.cat([torch.tensor([0]), torch.cumsum(q_lens, 0)], dim=0).int() kv_indptr = torch.cat( [torch.tensor([0]), torch.cumsum(seq_lens_blocks, 0)], dim=0 ).int() - d_q_indptr = torch.cat( - [torch.tensor([0]), torch.cumsum(torch.tensor(d_qo_lens), 0)], dim=0 - ).int() - d_kv_indptr = torch.cat( - [torch.tensor([0]), torch.cumsum(d_seq_lens_blocks, 0)], dim=0 - ).int() - 
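    # Shape sanity check (illustrative lengths): q_lens = [1, 1, 128] gives
    # qo_indptr = [0, 1, 2, 130]; kv_indptr is built the same way from per-request page
    # counts, so kv_indptr[-1] is the total number of pages in the batch.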
num_blocks = kv_indptr[-1].item() - - q = torch.rand(q_indptr[-1].item(), num_qo_heads, head_dim).to( + num_pages = kv_indptr[-1].item() + q = torch.rand(qo_indptr[-1].item(), num_qo_heads, head_dim).to( device, dtype=torch.bfloat16 ) - kv_data = torch.randn(num_blocks, 2, page_block_size, num_kv_heads, head_dim).to( + kv_data = torch.randn(num_pages, 2, page_size, num_kv_heads, head_dim).to( device, dtype=torch.bfloat16 ) + # Prefill params + num_pages_p = torch.ceil( + torch.tensor(p_kv_lens, dtype=torch.int32) / page_size + ).int() + qo_indptr_p = torch.cat( + [torch.tensor([0]), torch.cumsum(torch.tensor(p_qo_lens), 0)], dim=0 + ).int() + kv_indptr_p = torch.cat( + [torch.tensor([0]), torch.cumsum(p_seq_lens, 0)], dim=0 + ).int() + kv_indices_p = torch.arange(num_pages_p, device=device, dtype=torch.int32) + kv_data_p = kv_data[:num_pages_p] + last_page_len_p = (p_seq_lens - 1) % page_size + 1 + + # Decode params + qo_indptr_d = torch.cat( + [torch.tensor([0]), torch.cumsum(torch.tensor(d_qo_lens), 0)], dim=0 + ).int() + kv_indptr_d = torch.cat( + [torch.tensor([0]), torch.cumsum(d_seq_lens, 0)], dim=0 + ).int() + kv_indices_d = torch.arange(num_pages_p, device=device, dtype=torch.int32) + kv_data_d = kv_data[num_pages_p:] + last_page_len_d = (d_seq_lens - 1) % page_size + 1 + workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device) kv_layout = "NHD" @@ -54,11 +73,11 @@ def run_bench( kv_layout=kv_layout, backend="fa2", ) - last_page_len = (seq_lens - 1) % page_block_size + 1 + last_page_len = (seq_lens - 1) % page_size + 1 wrapper_old.plan( q_indptr.to(device), kv_indptr.to(device), - torch.arange(num_blocks).int().to(device), + torch.arange(num_blocks, dtype=torch.int32, device=device), last_page_len, num_qo_heads, num_kv_heads, @@ -72,28 +91,33 @@ def run_bench( ms_old = do_bench(lambda: wrapper_old.run(q, kv_data)) if len(p_kv_lens) == 1: - q_d = q[: d_q_indptr[-1]] - kv_d = kv_data[: d_kv_indptr[-1]].unbind(1) - q_p = q[d_q_indptr[-1] :] - k_p, v_p = kv_data[d_kv_indptr[-1] :].unbind(1) + q_d = q[: qo_indptr_d[-1]] + kv_d = kv_data[: kv_indptr_d[-1]].unbind(1) + q_p = q[qo_indptr_d[-1] :] + k_p, v_p = kv_data[kv_indptr_d[-1] :].unbind(1) k_p, v_p = k_p.squeeze(1), v_p.squeeze(1) kv_indices_d = torch.arange( - 0, d_kv_indptr[-1], device=device, dtype=torch.int32 + 0, kv_indptr_d[-1], device=device, dtype=torch.int32 ) - last_page_len_d = (d_seq_lens_blocks - 1) % page_block_size + 1 + last_page_len_d = (d_seq_lens - 1) % page_size + 1 wrapper_pod = flashinfer.PODWithPagedKVCacheWrapper( workspace_buffer, kv_layout=kv_layout, ) wrapper_pod.plan( - d_kv_indptr.to(device), + qo_indptr_p.to(device), + kv_indptr_p.to(device), + kv_indices_p.to(device), + last_page_len_p, + qo_indptr_d.to(device), + kv_indptr_d.to(device), kv_indices_d.to(device), - last_page_len_d=last_page_len_d, + last_page_len_d, num_qo_heads=num_qo_heads, num_kv_heads=num_kv_heads, head_dim=head_dim, - page_size=page_block_size, + page_size=page_size, q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16, ) @@ -124,15 +148,15 @@ def run_bench( # Persistent attention wrapper = flashinfer.BatchAttention(kv_layout="NHD") wrapper.plan( - q_indptr.to(device), + qo_indptr.to(device), kv_indptr.to(device), - torch.arange(num_blocks, dtype=torch.int32, device=device), + torch.arange(num_pages, dtype=torch.int32, device=device), seq_lens.to(device), num_qo_heads, num_kv_heads, head_dim, head_dim, - page_block_size, + page_size, causal=causal, q_data_type=torch.bfloat16, 
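        # Worked example (hypothetical): page_size = 4 and a kv length of 10 occupy
        # ceil(10 / 4) = 3 pages, with last_page_len = (10 - 1) % 4 + 1 = 2 entries in the
        # final page; with page_size = 1 this is always 1.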
kv_data_type=torch.bfloat16, diff --git a/csrc/pod_jit_pybind.cu b/csrc/pod_jit_pybind.cu index d9d71c6a5..db5e44d2c 100644 --- a/csrc/pod_jit_pybind.cu +++ b/csrc/pod_jit_pybind.cu @@ -19,19 +19,16 @@ void PODWithKVCacheTensorRun( // Shared params at::Tensor float_workspace_buffer_d, at::Tensor int_workspace_buffer_d, - at::Tensor plan_info_vec, at::Tensor qo_indptr, at::Tensor paged_kv_indptr, - at::Tensor paged_kv_indices, at::Tensor paged_kv_last_page_len, at::Tensor o, + at::Tensor plan_info_vec, at::Tensor paged_k_cache, at::Tensor paged_v_cache, + at::Tensor qo_indptr, at::Tensor paged_kv_indptr, at::Tensor paged_kv_indices, + at::Tensor paged_kv_last_page_len, at::Tensor o, std::optional maybe_lse, + int64_t layout, // Prefill params - at::Tensor q_p, at::Tensor paged_k_p, at::Tensor paged_v_p, - std::optional maybe_lse_p, int64_t mask_mode_code_p, int64_t layout_p, - int64_t window_left_p, std::optional maybe_custom_mask_p, - std::optional maybe_alibi_slopes_p, double logits_soft_cap_p, double sm_scale_p, - double rope_rcp_scale_p, double rope_rcp_theta_p, + at::Tensor q_p, int64_t mask_mode_code_p, int64_t window_left_p, + std::optional maybe_custom_mask_p, std::optional maybe_alibi_slopes_p, + double logits_soft_cap_p, double sm_scale_p, double rope_rcp_scale_p, double rope_rcp_theta_p, // Decode params - at::Tensor q_d, at::Tensor paged_k_cache_d, at::Tensor paged_v_cache_d, at::Tensor qo_indptr_d, - at::Tensor paged_kv_indptr_d, at::Tensor paged_kv_indices_d, - at::Tensor paged_kv_last_page_len_d, std::optional maybe_lse_d, - int64_t mask_mode_code_d, int64_t layout_d, int64_t window_left_d, + at::Tensor q_d, int64_t mask_mode_code_d, int64_t window_left_d, std::optional maybe_custom_mask_d, std::optional maybe_mask_indptr_d, std::optional maybe_alibi_slopes_d, double logits_soft_cap_d, double sm_scale_d, double rope_rcp_scale_d, double rope_rcp_theta_d, bool enable_pdl); diff --git a/csrc/pod_kernel_inst.jinja b/csrc/pod_kernel_inst.jinja index e6938d58b..19b86efbd 100644 --- a/csrc/pod_kernel_inst.jinja +++ b/csrc/pod_kernel_inst.jinja @@ -13,20 +13,25 @@ #include "pod_config.inc" -using namespace flashinfer; - namespace flashinfer { + +using PrefillParams = BatchPrefillPagedParams<{{ dtype_q }}, {{ dtype_kv }}, {{ dtype_o }}>; +using DecodeParams = BatchPrefillPagedParams<{{ dtype_q }}, {{ dtype_kv }}, {{ dtype_o }}, {{ idtype }}>; + +constexpr auto POS_ENCODING_MODE = PosEncodingMode::kNone; + constexpr auto use_custom_mask_p = {{ mask_mode_p }} == MaskMode::kCustom; constexpr auto use_custom_mask_d = {{ mask_mode_d }} == MaskMode::kCustom; -// Not sure about the below declaration -constexpr auto POS_ENCODING_MODE = PosEncodingMode::kNone; +using PrefillAttentionVariant = DefaultAttention; +using DecodeAttentionVariant = DefaultAttention; + +{% for cta_tile_q_p in [16, 64, 128] %} template cudaError_t PODWithKVCacheTensorDispatched< - {{ head_dim_qk }}, {{ head_dim_vo }}, POS_ENCODING_MODE, - {{ use_fp16_qk_reduction }}, {{ mask_mode_p }}, 16, - {{ mask_mode_d }}, {{ variant_name_p }}, - {{ variant_name_d }}, PrefillParams, DecodeParams>( - PrefillParams prefill_params, {{ dtype_o }}* tmp, - DecodeParams decode_params, {{ dtype_o }}* tmp_v, - float *tmp_s, bool enable_pdl, cudaStream_t stream); -}; + {{ head_dim_qk }}, {{ head_dim_vo }}, POS_ENCODING_MODE, {{ use_fp16_qk_reduction }}, {{ mask_mode_p }}, {{ cta_tile_q_p }}, 16, {{ mask_mode_d }}, + PrefillAttentionVariant, DecodeAttentionVariant, PrefillParams, DecodeParams>( + PrefillParams prefill_params, DecodeParams 
decode_params, + {{ dtype_o }}* tmp_v, float* tmp_s, bool enable_pdl, cudaStream_t stream); +{% endfor %} + +}; // namespace flashinfer diff --git a/flashinfer/pod.py b/flashinfer/pod.py index bf6d7f6af..ebb331b73 100644 --- a/flashinfer/pod.py +++ b/flashinfer/pod.py @@ -502,10 +502,8 @@ def run( self, # Main params (prefill and decode) q_p: torch.Tensor, - k_p: torch.Tensor, - v_p: torch.Tensor, q_d: torch.Tensor, - paged_kv_cache_d: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], + paged_kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], # Prefill options custom_mask_p: Optional[torch.Tensor] = None, packed_custom_mask_p: Optional[torch.Tensor] = None, @@ -581,9 +579,9 @@ def run( ) # Decode setup - k_cache_d, v_cache_d = _unpack_paged_kv_cache(paged_kv_cache_d, self._kv_layout) + k_cache, v_cache = _unpack_paged_kv_cache(paged_kv_cache, self._kv_layout) _check_cached_qkv_data_type( - q_d, k_cache_d, self._cached_q_data_type, self._cached_kv_data_type + q_d, k_cache, self._cached_q_data_type, self._cached_kv_data_type ) # TODO_AK: Where are these coming from? pos_encoding_mode_d = self._pos_encoding_mode @@ -611,7 +609,7 @@ def run( module_getter = get_pod_module( # Prefill params q_p.dtype, - k_p.dtype, + k_cache.dtype, q_p.dtype, q_p.shape[-1], PosEncodingMode[pos_encoding_mode_p].value, @@ -634,19 +632,17 @@ def run( self._float_workspace_buffer, self._int_workspace_buffer, self._plan_info, - self._qo_indptr_buf, + k_cache, + v_cache, + self._qo_indptr_buf, # contains both prefill and decode indptr self._paged_kv_indptr_buf, self._paged_kv_indices_buf, self._paged_kv_last_page_len_buf, - k_cache_d, - v_cache_d, out, lse, TensorLayout[self._kv_layout].value, # Prefill params q_p, - k_p, - v_p, mask_mode_p, window_left_p, packed_custom_mask_p, From f57fde0fc0090569d9f1b7e844bdcbc6be706b73 Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Fri, 11 Jul 2025 05:27:50 +0000 Subject: [PATCH 30/33] fixes --- benchmarks/bench_pod_attention.py | 25 +++++++++---------------- 1 file changed, 9 insertions(+), 16 deletions(-) diff --git a/benchmarks/bench_pod_attention.py b/benchmarks/bench_pod_attention.py index 105698e94..6312ce812 100644 --- a/benchmarks/bench_pod_attention.py +++ b/benchmarks/bench_pod_attention.py @@ -41,7 +41,7 @@ def run_bench( ) # Prefill params - num_pages_p = torch.ceil( + seq_lens_blocks_p = torch.ceil( torch.tensor(p_kv_lens, dtype=torch.int32) / page_size ).int() qo_indptr_p = torch.cat( @@ -50,24 +50,24 @@ def run_bench( kv_indptr_p = torch.cat( [torch.tensor([0]), torch.cumsum(p_seq_lens, 0)], dim=0 ).int() + num_pages_p = seq_lens_blocks_p[-1].item() kv_indices_p = torch.arange(num_pages_p, device=device, dtype=torch.int32) - kv_data_p = kv_data[:num_pages_p] last_page_len_p = (p_seq_lens - 1) % page_size + 1 # Decode params + qo_indptr_d = torch.cat( [torch.tensor([0]), torch.cumsum(torch.tensor(d_qo_lens), 0)], dim=0 ).int() kv_indptr_d = torch.cat( [torch.tensor([0]), torch.cumsum(d_seq_lens, 0)], dim=0 ).int() - kv_indices_d = torch.arange(num_pages_p, device=device, dtype=torch.int32) - kv_data_d = kv_data[num_pages_p:] + num_pages_d = kv_indptr_d[-1].item() + kv_indices_d = torch.arange(num_pages_d, device=device, dtype=torch.int32) last_page_len_d = (d_seq_lens - 1) % page_size + 1 workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device) kv_layout = "NHD" - wrapper_old = flashinfer.BatchPrefillWithPagedKVCacheWrapper( workspace_buffer, kv_layout=kv_layout, @@ -75,9 +75,9 @@ def run_bench( ) last_page_len = 
(seq_lens - 1) % page_size + 1 wrapper_old.plan( - q_indptr.to(device), + qo_indptr.to(device), kv_indptr.to(device), - torch.arange(num_blocks, dtype=torch.int32, device=device), + torch.arange(num_pages, dtype=torch.int32, device=device), last_page_len, num_qo_heads, num_kv_heads, @@ -92,13 +92,8 @@ def run_bench( if len(p_kv_lens) == 1: q_d = q[: qo_indptr_d[-1]] - kv_d = kv_data[: kv_indptr_d[-1]].unbind(1) q_p = q[qo_indptr_d[-1] :] - k_p, v_p = kv_data[kv_indptr_d[-1] :].unbind(1) - k_p, v_p = k_p.squeeze(1), v_p.squeeze(1) - kv_indices_d = torch.arange( - 0, kv_indptr_d[-1], device=device, dtype=torch.int32 - ) + kv_indices_d = torch.arange(0, num_pages_d, device=device, dtype=torch.int32) last_page_len_d = (d_seq_lens - 1) % page_size + 1 wrapper_pod = flashinfer.PODWithPagedKVCacheWrapper( @@ -137,10 +132,8 @@ def run_bench( ms_pod = do_bench( lambda: wrapper_pod.run( q_p, - k_p, - v_p, q_d, - kv_d, + paged_kv_cache=kv_data, causal_p=causal, causal_d=causal, ) From 0832bacd9383f6e796e732211595154c01387206 Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Fri, 11 Jul 2025 17:20:53 +0000 Subject: [PATCH 31/33] trying to fix template --- aot_build_utils/generate_pod_inst.py | 37 +++++++++++----------------- 1 file changed, 14 insertions(+), 23 deletions(-) diff --git a/aot_build_utils/generate_pod_inst.py b/aot_build_utils/generate_pod_inst.py index cb32c698c..53e061c09 100644 --- a/aot_build_utils/generate_pod_inst.py +++ b/aot_build_utils/generate_pod_inst.py @@ -39,6 +39,7 @@ def get_cu_file_str( idtype, ): cta_tile_q_choice = [128, 64, 16] + cta_tile_q_d = 16 def get_insts(attention_variant_p, attention_variant_d, dtype_out): return "\n".join( @@ -60,7 +61,6 @@ def get_insts(attention_variant_p, attention_variant_d, dtype_out): dtype_out=dtype_out, ) for cta_tile_q_p in cta_tile_q_choice - for cta_tile_q_d in cta_tile_q_choice ] ) @@ -71,43 +71,34 @@ def get_insts(attention_variant_p, attention_variant_d, dtype_out): dtype_out = dtype_literal[dtype_out] idtype = idtype_literal[idtype] - content = f"""#include -#include -#include -#include -#include -#include -#include -#include -#include + content = f"""#include +#include "pod_config.inc" -#include "pytorch_conversion_utils.h" -#include "pytorch_extension_utils.h" +namespace flashinfer {{ -using namespace flashinfer; +constexpr auto use_custom_mask_p = MaskMode::kNone == MaskMode::kCustom; +constexpr auto use_custom_mask_d = MaskMode::kNone == MaskMode::kCustom; using PrefillParams = BatchPrefillPagedParams<{dtype_q}, {dtype_kv}, {dtype_out}>; using DecodeParams = BatchPrefillPagedParams<{dtype_q}, {dtype_kv}, {dtype_out}, {idtype}>; -constexpr auto POS_ENCODING_MODE = PosEncodingMode::kNone; - -using AttentionVariant1_P = DefaultAttention<{use_custom_mask_p}, /*use_sliding_window=*/true, /*use_logits_soft_cap=*/false, /*use_alibi_bias=*/false>; -using AttentionVariant1_D = DefaultAttention<{use_custom_mask_d}, /*use_sliding_window=*/true, /*use_logits_soft_cap=*/false, /*use_alibi_bias=*/false>; +using AttentionVariant1_P = DefaultAttention; +using AttentionVariant1_D = DefaultAttention; {get_insts("AttentionVariant1_P", "AttentionVariant1_D", dtype_out)} -using AttentionVariant2_P = DefaultAttention<{use_custom_mask_p}, /*use_sliding_window=*/true, /*use_logits_soft_cap=*/true, /*use_alibi_bias=*/false>; -using AttentionVariant2_D = DefaultAttention<{use_custom_mask_d}, /*use_sliding_window=*/true, /*use_logits_soft_cap=*/true, /*use_alibi_bias=*/false>; +using AttentionVariant2_P = DefaultAttention; +using AttentionVariant2_D = 
DefaultAttention; {get_insts("AttentionVariant2_P", "AttentionVariant2_D", dtype_out)} -using AttentionVariant3_P = DefaultAttention<{use_custom_mask_p}, /*use_sliding_window=*/false, /*use_logits_soft_cap=*/false, /*use_alibi_bias=*/false>; -using AttentionVariant3_D = DefaultAttention<{use_custom_mask_d}, /*use_sliding_window=*/false, /*use_logits_soft_cap=*/false, /*use_alibi_bias=*/false>; +using AttentionVariant3_P = DefaultAttention; +using AttentionVariant3_D = DefaultAttention; {get_insts("AttentionVariant3_P", "AttentionVariant3_D", dtype_out)} -using AttentionVariant4_P = DefaultAttention<{use_custom_mask_p}, /*use_sliding_window=*/false, /*use_logits_soft_cap=*/true, /*use_alibi_bias=*/false>; -using AttentionVariant4_D = DefaultAttention<{use_custom_mask_d}, /*use_sliding_window=*/false, /*use_logits_soft_cap=*/true, /*use_alibi_bias=*/false>; +using AttentionVariant4_P = DefaultAttention; +using AttentionVariant4_D = DefaultAttention; {get_insts("AttentionVariant4_P", "AttentionVariant4_D", dtype_out)} From a570bff88d7e8945b0648f0cd7f9e9c5c9cff85f Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Sun, 13 Jul 2025 22:50:29 +0000 Subject: [PATCH 32/33] fix PODSplitQOKVIndptr param --- aot_build_utils/generate.py | 2 +- csrc/pod_kernel_inst.jinja | 4 +--- include/flashinfer/attention/scheduler.cuh | 4 ++-- 3 files changed, 4 insertions(+), 6 deletions(-) diff --git a/aot_build_utils/generate.py b/aot_build_utils/generate.py index 5daa47c08..5711f52e8 100644 --- a/aot_build_utils/generate.py +++ b/aot_build_utils/generate.py @@ -313,7 +313,7 @@ def write_if_different(path: Path, content: str) -> None: + batch_decode_uris + single_prefill_uris + batch_prefill_uris - + pod_uris + # + pod_uris ) diff --git a/csrc/pod_kernel_inst.jinja b/csrc/pod_kernel_inst.jinja index 19b86efbd..646288336 100644 --- a/csrc/pod_kernel_inst.jinja +++ b/csrc/pod_kernel_inst.jinja @@ -26,12 +26,10 @@ constexpr auto use_custom_mask_d = {{ mask_mode_d }} == MaskMode::kCustom; using PrefillAttentionVariant = DefaultAttention; using DecodeAttentionVariant = DefaultAttention; -{% for cta_tile_q_p in [16, 64, 128] %} template cudaError_t PODWithKVCacheTensorDispatched< - {{ head_dim_qk }}, {{ head_dim_vo }}, POS_ENCODING_MODE, {{ use_fp16_qk_reduction }}, {{ mask_mode_p }}, {{ cta_tile_q_p }}, 16, {{ mask_mode_d }}, + {{ head_dim_qk }}, {{ head_dim_vo }}, POS_ENCODING_MODE, {{ use_fp16_qk_reduction }}, {{ mask_mode_p }}, {{ 128 }}, 16, {{ mask_mode_d }}, PrefillAttentionVariant, DecodeAttentionVariant, PrefillParams, DecodeParams>( PrefillParams prefill_params, DecodeParams decode_params, {{ dtype_o }}* tmp_v, float* tmp_s, bool enable_pdl, cudaStream_t stream); -{% endfor %} }; // namespace flashinfer diff --git a/include/flashinfer/attention/scheduler.cuh b/include/flashinfer/attention/scheduler.cuh index 809795095..075c012d6 100644 --- a/include/flashinfer/attention/scheduler.cuh +++ b/include/flashinfer/attention/scheduler.cuh @@ -1037,8 +1037,8 @@ inline cudaError_t PODPlan(void* float_buffer, size_t float_workspace_size_in_by auto [split_kv, real_batch_size, padded_batch_size_p, padded_batch_size_d, cta_tile_q_p, cta_tile_q_d, kv_chunk_size_p, kv_chunk_size_d, request_indices_vec, qo_tile_indices_vec, kv_tile_indices_vec, merge_indptr_vec, o_indptr_vec] = - PODSplitQOKVIndptr(qo_indptr_p, kv_indptr_p, qo_indptr_d, kv_indptr_d, total_num_rows_p, - batch_size_p, total_num_rows_d, batch_size_d, num_qo_heads, num_kv_heads, + PODSplitQOKVIndptr(qo_indptr_p, kv_indptr_p, total_num_rows_p, batch_size_p, 
qo_indptr_d, + kv_indptr_d, total_num_rows_d, batch_size_d, num_qo_heads, num_kv_heads, head_dim_vo, page_size, max_batch_size_if_split, enable_cuda_graph); uint32_t padded_batch_size = padded_batch_size_p + padded_batch_size_d; uint32_t batch_size = batch_size_p + batch_size_d; From dbc499a41826946288fe3799cb58a919f9783b22 Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Tue, 15 Jul 2025 02:27:16 +0000 Subject: [PATCH 33/33] fix template and vec type errors --- csrc/batch_prefill.cu | 2 +- flashinfer/prefill.py | 2 + include/flashinfer/attention/scheduler.cuh | 136 +++++++++------------ tvm_binding/batch_prefill.cu | 14 +-- 4 files changed, 69 insertions(+), 85 deletions(-) diff --git a/csrc/batch_prefill.cu b/csrc/batch_prefill.cu index 94588b393..a51fc7f56 100644 --- a/csrc/batch_prefill.cu +++ b/csrc/batch_prefill.cu @@ -47,7 +47,7 @@ at::Tensor BatchPrefillWithKVCachePlan( at::Tensor page_locked_int_workspace_buffer, at::Tensor qo_indptr, at::Tensor kv_indptr, at::Tensor kv_len_arr, int64_t total_num_rows, int64_t batch_size, int64_t num_qo_heads, int64_t num_kv_heads, int64_t page_size, bool enable_cuda_graph, int64_t head_dim_qk, - int64_t head_dim_vo) { + int64_t head_dim_vo, bool causal) { size_t float_workspace_size_in_bytes = float_workspace_buffer.size(0) * float_workspace_buffer.element_size(); size_t int_workspace_size_in_bytes = diff --git a/flashinfer/prefill.py b/flashinfer/prefill.py index 6c5068716..afa8bf2f6 100644 --- a/flashinfer/prefill.py +++ b/flashinfer/prefill.py @@ -1579,6 +1579,7 @@ def plan( self.is_cuda_graph_enabled, head_dim_qk, head_dim_vo, + causal, ) self._causal = causal @@ -2352,6 +2353,7 @@ def plan( self.is_cuda_graph_enabled, head_dim_qk, head_dim_vo, + causal, ) self._causal = causal diff --git a/include/flashinfer/attention/scheduler.cuh b/include/flashinfer/attention/scheduler.cuh index 075c012d6..bf05242d6 100644 --- a/include/flashinfer/attention/scheduler.cuh +++ b/include/flashinfer/attention/scheduler.cuh @@ -100,22 +100,22 @@ inline auto PartitionPagedKVCacheBinarySearchMinNumPagePerBatch( inline auto PrefillBinarySearchKVChunkSize(const bool enable_cuda_graph, const uint32_t max_batch_size_if_split, - const std::vector& packed_qo_len_arr, - const std::vector& kv_len_arr, + const std::vector& packed_qo_len_arr, + const std::vector& kv_len_arr, const uint32_t qo_chunk_size, const uint32_t min_kv_chunk_size = 1) { - const int64_t batch_size = packed_qo_len_arr.size(); - int64_t max_kv_len = 1; - for (const int64_t& kv_len : kv_len_arr) { + const uint32_t batch_size = packed_qo_len_arr.size(); + uint32_t max_kv_len = 1; + for (const uint32_t& kv_len : kv_len_arr) { max_kv_len = std::max(max_kv_len, kv_len); } - int64_t low = min_kv_chunk_size; - int64_t high = max_kv_len; - constexpr int64_t min_kv_len = 1; + uint32_t low = min_kv_chunk_size; + uint32_t high = max_kv_len; + constexpr uint32_t min_kv_len = 1; while (low < high) { - const int64_t mid = (low + high) / 2; - int64_t real_batch_size = 0; + const uint32_t mid = (low + high) / 2; + uint32_t real_batch_size = 0; for (uint32_t i = 0; i < batch_size; ++i) { real_batch_size += ceil_div(packed_qo_len_arr[i], qo_chunk_size) * ceil_div(std::max(kv_len_arr[i], min_kv_len), mid); @@ -492,18 +492,19 @@ inline cudaError_t DecodePlan(void* float_buffer, size_t float_workspace_size_in return cudaSuccess; } +template inline auto get_qkv_len_arr(IdType* qo_indptr_h, IdType* kv_indptr_h, uint32_t batch_size, uint32_t num_qo_heads, uint32_t gqa_group_size) { - std::vector packed_qo_len_arr(batch_size), 
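      // e.g. (illustrative) qo_indptr = {0, 3, 7} with gqa_group_size = 4 yields
      // packed_qo_len_arr = {12, 16}; kv_len_arr holds the corresponding kv_indptr
      // differences.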
kv_len_arr(batch_size); + std::vector packed_qo_len_arr(batch_size), kv_len_arr(batch_size); for (uint32_t i = 0; i < batch_size; ++i) { - packed_qo_len_arr[i] = int64_t(qo_indptr_h[i + 1] - qo_indptr_h[i]) * int64_t(gqa_group_size); + packed_qo_len_arr[i] = uint32_t(qo_indptr_h[i + 1] - qo_indptr_h[i]) * uint32_t(gqa_group_size); if (packed_qo_len_arr[i] < 0) { std::ostringstream err_msg; err_msg << "qo_indptr[" << i + 1 << "]" << qo_indptr_h[i + 1] << " - qo_indptr[" << i << "]" << qo_indptr_h[i] << " should be non-negative"; FLASHINFER_ERROR(err_msg.str()); } - kv_len_arr[i] = int64_t(kv_indptr_h[i + 1] - kv_indptr_h[i]); + kv_len_arr[i] = uint32_t(kv_indptr_h[i + 1] - kv_indptr_h[i]); if (kv_len_arr[i] < 0) { std::ostringstream err_msg; err_msg << "kv_indptr[" << i + 1 << "]" << kv_indptr_h[i + 1] << " - kv_indptr[" << i << "]" @@ -514,7 +515,7 @@ inline auto get_qkv_len_arr(IdType* qo_indptr_h, IdType* kv_indptr_h, uint32_t b return std::make_tuple(packed_qo_len_arr, kv_len_arr); } -inline auto get_q_tiles(std::vector& packed_qo_len_arr, uint32_t batch_size, +inline auto get_q_tiles(std::vector packed_qo_len_arr, uint32_t batch_size, uint32_t head_dim, uint32_t page_size, uint32_t total_num_rows, uint32_t gqa_group_size, bool enable_cuda_graph, bool is_decode = false) { uint32_t cta_tile_q; @@ -555,60 +556,53 @@ inline auto get_q_tiles(std::vector& packed_qo_len_arr, uint32_t batch_ return std::make_tuple(cta_tile_q, total_num_tiles_q); } -inline auto get_qkv_tile_indices(std::vector& packed_qo_len_arr, - std::vector& kv_len_arr, uint32_t batch_size, - uint32_t cta_tile_q, uint32_t kv_chunk_size, - uint32_t gqa_group_size, - std::vector& request_indices = nullptr, - std::vector& qo_tile_indices = nullptr, - std::vector& kv_tile_indices = nullptr, - std::vector& merge_indptr = nullptr, - std::vector& o_indptr = nullptr) { - uint32_t start_req_idx = 0; // for global q,k,v,o indexing in POD Attention - if (request_indices == nullptr) { - request_indices = std::vector(); - } else { - start_req_idx = request_indices.back(); - } - if (qo_tile_indices == nullptr) { - qo_tile_indices = std::vector(); - } - if (kv_tile_indices == nullptr) { - kv_tile_indices = std::vector(); - } - if (merge_indptr == nullptr) { - merge_indptr = std::vector(); - merge_indptr.push_back(0); - } - if (o_indptr == nullptr) { - o_indptr = std::vector(); - o_indptr.push_back(0); +template +inline auto get_qkv_tile_indices( + const std::vector& packed_qo_len_arr, const std::vector& kv_len_arr, + uint32_t batch_size, uint32_t cta_tile_q, uint32_t kv_chunk_size, uint32_t gqa_group_size, + std::vector* request_indices = nullptr, std::vector* qo_tile_indices = nullptr, + std::vector* kv_tile_indices = nullptr, std::vector* merge_indptr = nullptr, + std::vector* o_indptr = nullptr) { + std::vector local_req; + std::vector local_qo; + std::vector local_kv; + std::vector local_merge{0}; + std::vector local_o{0}; + + auto* out_req = request_indices ? request_indices : &local_req; + auto* out_qo = qo_tile_indices ? qo_tile_indices : &local_qo; + auto* out_kv = kv_tile_indices ? kv_tile_indices : &local_kv; + auto* out_merge = merge_indptr ? merge_indptr : &local_merge; + auto* out_o = o_indptr ? 
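  // Illustrative tiling: a request with packed_qo_len = 200 and kv_len = 1000, with
  // cta_tile_q = 128 and kv_chunk_size = 512 (hypothetical values), produces
  // 2 x 2 = 4 (request_idx, q_tile_idx, kv_tile_idx) entries in the loop below;
  // merge_indptr then grows by num_tiles_kv per output row and o_indptr by
  // qo_len * num_tiles_kv per request.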
o_indptr : &local_o; + uint32_t start_req_idx = 0; // for global q,k,v,o indexing + if (request_indices && !request_indices->empty()) { + start_req_idx = request_indices->back(); } uint32_t real_batch_size = 0; for (uint32_t request_idx = 0; request_idx < batch_size; ++request_idx) { - const int64_t packed_qo_len = packed_qo_len_arr[request_idx]; - const int64_t kv_len = std::max(int(kv_len_arr[request_idx]), 1); - const int64_t num_tiles_q = ceil_div(packed_qo_len, cta_tile_q); - const int64_t num_tiles_kv = ceil_div(kv_len, kv_chunk_size); + const uint32_t packed_qo_len = packed_qo_len_arr[request_idx]; + const uint32_t kv_len = std::max(uint32_t(kv_len_arr[request_idx]), uint32_t(1)); + const uint32_t num_tiles_q = ceil_div(packed_qo_len, cta_tile_q); + const uint32_t num_tiles_kv = ceil_div(kv_len, kv_chunk_size); for (uint32_t q_tile_idx = 0; q_tile_idx < num_tiles_q; ++q_tile_idx) { for (uint32_t kv_tile_idx = 0; kv_tile_idx < num_tiles_kv; ++kv_tile_idx) { real_batch_size += 1; - request_indices.push_back(request_idx + start_req_idx); - qo_tile_indices.push_back(q_tile_idx); - kv_tile_indices.push_back(kv_tile_idx); + request_indices->push_back(request_idx + start_req_idx); + qo_tile_indices->push_back(q_tile_idx); + kv_tile_indices->push_back(kv_tile_idx); } } int64_t qo_len = packed_qo_len / gqa_group_size; for (uint32_t row = 0; row < qo_len; ++row) { - merge_indptr.push_back(merge_indptr.back() + num_tiles_kv); + merge_indptr->push_back(merge_indptr->back() + num_tiles_kv); } - o_indptr.push_back(o_indptr.back() + qo_len * num_tiles_kv); + o_indptr->push_back(o_indptr->back() + qo_len * num_tiles_kv); } - return std::make_tuple(request_indices, qo_tile_indices, kv_tile_indices, merge_indptr, o_indptr, - real_batch_size); + return std::make_tuple(std::move(local_req), std::move(local_qo), std::move(local_kv), + std::move(local_merge), std::move(local_o), real_batch_size); } template @@ -617,10 +611,6 @@ inline auto PrefillSplitQOKVIndptr(IdType* qo_indptr_h, IdType* kv_indptr_h, uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t head_dim, uint32_t page_size, uint32_t max_batch_size_if_split, bool enable_cuda_graph) { - std::vector request_indices, qo_tile_indices, kv_tile_indices, merge_indptr, o_indptr; - merge_indptr.push_back(0); - o_indptr.push_back(0); - const uint32_t gqa_group_size = num_qo_heads / num_kv_heads; const uint32_t min_kv_chunk_size = std::max((128 / page_size), 1U); @@ -638,8 +628,8 @@ inline auto PrefillSplitQOKVIndptr(IdType* qo_indptr_h, IdType* kv_indptr_h, kv_len_arr, cta_tile_q, min_kv_chunk_size); auto [request_indices, qo_tile_indices, kv_tile_indices, merge_indptr, o_indptr, - real_batch_size] = get_qkv_tile_indices(packed_qo_len_arr, kv_len_arr, batch_size, - cta_tile_q, kv_chunk_size, gqa_group_size); + real_batch_size] = get_qkv_tile_indices(packed_qo_len_arr, kv_len_arr, batch_size, + cta_tile_q, kv_chunk_size, gqa_group_size); const size_t padded_batch_size = enable_cuda_graph ? 
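  // When CUDA graphs are enabled the captured grid size must stay fixed across replays, so the
  // padded batch size is chosen as an upper bound (at least max_batch_size_if_split, and no
  // smaller than the total number of q tiles) that does not change with the kv split, instead
  // of the split-dependent real_batch_size computed above.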
std::max(max_batch_size_if_split, total_num_tiles_q) : real_batch_size; @@ -650,8 +640,7 @@ inline auto PrefillSplitQOKVIndptr(IdType* qo_indptr_h, IdType* kv_indptr_h, kv_chunk_size *= page_size; return std::make_tuple(split_kv, real_batch_size, padded_batch_size, cta_tile_q, kv_chunk_size, - std::move(request_indices), std::move(qo_tile_indices), - std::move(kv_tile_indices), std::move(merge_indptr), std::move(o_indptr)); + request_indices, qo_tile_indices, kv_tile_indices, merge_indptr, o_indptr); } struct PrefillPlanInfo { @@ -846,10 +835,6 @@ inline auto PODSplitQOKVIndptr(IdType* qo_indptr_p, IdType* kv_indptr_p, uint32_ uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t head_dim, uint32_t page_size, uint32_t max_batch_size_if_split, bool enable_cuda_graph) { - std::vector request_indices, qo_tile_indices, kv_tile_indices, merge_indptr, o_indptr; - merge_indptr.push_back(0); - o_indptr.push_back(0); - const uint32_t gqa_group_size = num_qo_heads / num_kv_heads; const uint32_t min_kv_chunk_size = std::max((128 / page_size), 1U); // step 1: determine packed_qo_len_arr and verify qo_indptr contents. @@ -881,21 +866,20 @@ inline auto PODSplitQOKVIndptr(IdType* qo_indptr_p, IdType* kv_indptr_p, uint32_ // step 3: split qo_indptr and kv_indptr // Use one set of qkv indices, merge_indptr and o_indptr to simply merging. - auto [request_indices, qo_tile_indices, kv_tile_indices, merge_indptr, o_indptr, - new_batch_size_p] = get_qkv_tile_indices(packed_qo_len_arr_p, kv_len_arr_p, batch_size_p, - cta_tile_q_p, kv_chunk_size_p, gqa_group_size); - auto [request_indices, qo_tile_indices, kv_tile_indices, merge_indptr, o_indptr, - new_batch_size_d] = - get_qkv_tile_indices(packed_qo_len_arr_d, kv_len_arr_d, batch_size_d, cta_tile_q_d, - kv_chunk_size_d, gqa_group_size, request_indices, qo_tile_indices, - kv_tile_indices, merge_indptr, o_indptr); + auto [request_indices, qo_tile_indices, kv_tile_indices, merge_indptr, o_indptr, real_bs_p] = + get_qkv_tile_indices(packed_qo_len_arr_p, kv_len_arr_p, batch_size_p, cta_tile_q_p, + kv_chunk_size_p, gqa_group_size); + auto [_, __, _____, _______, _________, real_bs_d] = + get_qkv_tile_indices(packed_qo_len_arr_d, kv_len_arr_d, batch_size_d, cta_tile_q_d, + kv_chunk_size_d, gqa_group_size, &request_indices, + &qo_tile_indices, &kv_tile_indices, &merge_indptr, &o_indptr); bool split_kv = split_kv_p || split_kv_d; - uint32_t real_batch_size = new_batch_size_p + new_batch_size_d; + uint32_t real_batch_size = real_bs_p + real_bs_d; const size_t padded_batch_size_p = - enable_cuda_graph ? std::max(max_bs_p, num_tiles_q_p) : new_batch_size_p; + enable_cuda_graph ? std::max(max_bs_p, num_tiles_q_p) : real_bs_p; const size_t padded_batch_size_d = - enable_cuda_graph ? std::max(max_bs_d, num_tiles_q_d) : new_batch_size_d; + enable_cuda_graph ? 
std::max(max_bs_d, num_tiles_q_d) : real_bs_d; FLASHINFER_CHECK(real_batch_size <= padded_batch_size_p + padded_batch_size_d, "new batch size should not exceed padded batch size"); diff --git a/tvm_binding/batch_prefill.cu b/tvm_binding/batch_prefill.cu index 710764cf4..bf7161116 100644 --- a/tvm_binding/batch_prefill.cu +++ b/tvm_binding/batch_prefill.cu @@ -41,14 +41,12 @@ cudaError_t BatchPrefillWithRaggedKVCacheDispatched(Params params, typename Para using namespace flashinfer; -IntTuple BatchPrefillWithKVCachePlan(DLTensor* float_workspace_buffer, - DLTensor* int_workspace_buffer, - DLTensor* page_locked_int_workspace_buffer, - DLTensor* qo_indptr, DLTensor* kv_indptr, IntTuple kv_len_arr, - int64_t total_num_rows, int64_t batch_size, - int64_t num_qo_heads, int64_t num_kv_heads, int64_t page_size, - bool enable_cuda_graph, int64_t head_dim_qk, - int64_t head_dim_vo, TVMStreamHandle cuda_stream) { +IntTuple BatchPrefillWithKVCachePlan( + DLTensor* float_workspace_buffer, DLTensor* int_workspace_buffer, + DLTensor* page_locked_int_workspace_buffer, DLTensor* qo_indptr, DLTensor* kv_indptr, + IntTuple kv_len_arr, int64_t total_num_rows, int64_t batch_size, int64_t num_qo_heads, + int64_t num_kv_heads, int64_t page_size, bool enable_cuda_graph, int64_t head_dim_qk, + int64_t head_dim_vo, bool causal, TVMStreamHandle cuda_stream) { size_t float_workspace_size_in_bytes = float_workspace_buffer->shape[0] * DataType(float_workspace_buffer->dtype).bytes(); size_t int_workspace_size_in_bytes =