[Feature] Support batch prefill for POD Attention #1231

Open

wants to merge 30 commits into main

Commits (30)
06ef37d
use FA2DetermineCtaTileQ for pod
Edenzzzz May 20, 2025
14e3f13
modify wrapper..
Edenzzzz Jun 6, 2025
22dc539
fix
Edenzzzz Jun 8, 2025
3257899
bench against persistent
Edenzzzz Jun 25, 2025
82f1550
rename xsize to num_qo_tiles
Edenzzzz Jun 27, 2025
06fee31
fix
Edenzzzz Jun 30, 2025
f47f73e
fix
Edenzzzz Jun 30, 2025
c51cecc
fix
Edenzzzz Jun 30, 2025
e73566c
add mixed scheduler
Edenzzzz Jul 3, 2025
1b2d4c0
rename to num_to_merge_qo_len
Edenzzzz Jul 3, 2025
78e1266
add params
Edenzzzz Jul 3, 2025
4979a2a
plan to use one reduction kernel for prefill and decode
Edenzzzz Jul 4, 2025
2102e22
fix
Edenzzzz Jul 4, 2025
fab82ae
use unifed qkv indptr
Edenzzzz Jul 6, 2025
13b6b19
fix
Edenzzzz Jul 6, 2025
7d29232
fix plan func upper call interface
Edenzzzz Jul 6, 2025
106bfdc
rename new_batch_size to real_batch_size
Edenzzzz Jul 7, 2025
5e3e896
concat request_indices
Edenzzzz Jul 7, 2025
ac07253
unifed indices in wrapper.plan
Edenzzzz Jul 8, 2025
eb8f719
fixes
Edenzzzz Jul 8, 2025
e8b266d
fix params
Edenzzzz Jul 8, 2025
560918b
fix some indices and params
Edenzzzz Jul 8, 2025
2105101
update PODWithKVCacheTensorRun args
Edenzzzz Jul 9, 2025
1a82b17
add paged kv params
Edenzzzz Jul 9, 2025
dd80a06
complete PODWithKVCacheTensorRun params
Edenzzzz Jul 9, 2025
0bb164b
share lse
Edenzzzz Jul 10, 2025
32d762b
templaterize CTA_TILE_Q_P
Edenzzzz Jul 10, 2025
870b0b2
update dispatch logic
Edenzzzz Jul 11, 2025
a7eb44b
add pod template inst and .cu gen
Edenzzzz Jul 11, 2025
f57fde0
fixes
Edenzzzz Jul 11, 2025
59 changes: 59 additions & 0 deletions aot_build_utils/generate.py
@@ -24,6 +24,7 @@
generate_batch_paged_decode_inst,
generate_batch_paged_prefill_inst,
generate_batch_ragged_prefill_inst,
generate_pod_inst,
generate_single_decode_inst,
generate_single_prefill_inst,
)
@@ -250,11 +251,69 @@ def write_if_different(path: Path, content: str) -> None:
f"f16qk_{bool(use_fp16_qk_reduction)}"
)

# POD files
pod_uris = []
for (
head_dim,
pos_encoding_mode,
use_fp16_qk_reduction,
mask_mode_p,
mask_mode_d,
idtype,
) in product(
head_dims,
pos_encoding_modes,
use_fp16_qk_reductions,
mask_modes,
mask_modes, # mask_mode_d can be different from mask_mode_p
idtypes,
):
for dtype_q, dtype_kv in list(zip(prefill_dtypes, prefill_dtypes)) + list(
product(prefill_dtypes, fp8_dtypes)
):
fname = f"pod_head_qk_{head_dim}_head_vo_{head_dim}_posenc_{pos_encoding_mode}_fp16qkred_{use_fp16_qk_reduction}_maskp_{mask_mode_p}_maskd_{mask_mode_d}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_q}_idtype_{idtype}.cu"
content = generate_pod_inst.get_cu_file_str(
head_dim, # head_dim_qk
head_dim, # head_dim_vo
pos_encoding_mode,
use_fp16_qk_reduction,
mask_mode_p,
mask_mode_d,
dtype_q, # dtype_q
dtype_kv, # dtype_kv
dtype_q, # dtype_out
idtype,
)
write_if_different(path / fname, content)

for sliding_window_p in [True, False]:
for sliding_window_d in [True, False]:
for logits_soft_cap_p in [True, False]:
for logits_soft_cap_d in [True, False]:
if (
mask_mode_p == 0 and mask_mode_d == 0
): # NOTE(Zihao): uri do not contain mask, avoid duplicate uris
pod_uris.append(
f"pod_with_kv_cache_dtype_q_{dtype_q}_"
f"dtype_kv_{dtype_kv}_"
f"dtype_o_{dtype_q}_"
f"dtype_idx_{idtype}_"
f"head_dim_qk_{head_dim}_"
f"head_dim_vo_{head_dim}_"
f"posenc_{pos_encoding_mode}_"
f"use_swa_p_{sliding_window_p}_"
f"use_swa_d_{sliding_window_d}_"
f"use_logits_cap_p_{logits_soft_cap_p}_"
f"use_logits_cap_d_{logits_soft_cap_d}_"
f"f16qk_{bool(use_fp16_qk_reduction)}"
)

return (
single_decode_uris
+ batch_decode_uris
+ single_prefill_uris
+ batch_prefill_uris
+ pod_uris
)
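
For reference, a minimal sketch of the file name this loop emits for one point in the sweep; the parameter values below are illustrative, not taken from a real build:

```python
# Illustrative parameter choices (a real build sweeps the full product of options).
head_dim = 128
pos_encoding_mode = 0
use_fp16_qk_reduction = "false"
mask_mode_p, mask_mode_d = 0, 0
dtype_q = dtype_kv = "f16"
idtype = "i32"

fname = (
    f"pod_head_qk_{head_dim}_head_vo_{head_dim}_posenc_{pos_encoding_mode}_"
    f"fp16qkred_{use_fp16_qk_reduction}_maskp_{mask_mode_p}_maskd_{mask_mode_d}_"
    f"dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_q}_idtype_{idtype}.cu"
)
print(fname)
# pod_head_qk_128_head_vo_128_posenc_0_fp16qkred_false_maskp_0_maskd_0_dtypeq_f16_dtypekv_f16_dtypeout_f16_idtype_i32.cu
```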


130 changes: 130 additions & 0 deletions aot_build_utils/generate_pod_inst.py
@@ -0,0 +1,130 @@
"""
Copyright (c) 2024 by FlashInfer team.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import re
import sys
from pathlib import Path

from .literal_map import (
dtype_literal,
idtype_literal,
mask_mode_literal,
pos_encoding_mode_literal,
)


def get_cu_file_str(
head_dim_qk,
head_dim_vo,
pos_encoding_mode,
use_fp16_qk_reduction,
mask_mode_p,
mask_mode_d,
dtype_q,
dtype_kv,
dtype_out,
idtype,
):
cta_tile_q_choice = [128, 64, 16]

def get_insts(attention_variant_p, attention_variant_d, dtype_out):
return "\n".join(
[
"""template cudaError_t PODWithKVCacheTensorDispatched<{head_dim_qk}, {head_dim_vo}, {pos_encoding_mode}, {use_fp16_qk_reduction}, {mask_mode_p}, {cta_tile_q_p}, {cta_tile_q_d}, {mask_mode_d}, {attention_variant_p}, {attention_variant_d}, PrefillParams, DecodeParams>(
PrefillParams prefill_params, DecodeParams decode_params,
{dtype_out}* tmp_v, float* tmp_s, bool enable_pdl, cudaStream_t stream);
""".format(
head_dim_qk=head_dim_qk,
head_dim_vo=head_dim_vo,
pos_encoding_mode=pos_encoding_mode_literal[int(pos_encoding_mode)],
use_fp16_qk_reduction=use_fp16_qk_reduction,
mask_mode_p=mask_mode_literal[int(mask_mode_p)],
cta_tile_q_p=cta_tile_q_p,
cta_tile_q_d=cta_tile_q_d,
mask_mode_d=mask_mode_literal[int(mask_mode_d)],
attention_variant_p=attention_variant_p,
attention_variant_d=attention_variant_d,
dtype_out=dtype_out,
)
for cta_tile_q_p in cta_tile_q_choice
for cta_tile_q_d in cta_tile_q_choice
]
)

use_custom_mask_p = "true" if int(mask_mode_p) == 2 else "false"
use_custom_mask_d = "true" if int(mask_mode_d) == 2 else "false"
dtype_q = dtype_literal[dtype_q]
dtype_kv = dtype_literal[dtype_kv]
dtype_out = dtype_literal[dtype_out]
idtype = idtype_literal[idtype]

content = f"""#include <flashinfer/attention/default_prefill_params.cuh>
#include <flashinfer/attention/default_decode_params.cuh>
#include <flashinfer/attention/variants.cuh>
#include <flashinfer/attention/scheduler.cuh>
#include <flashinfer/attention/mask.cuh>
#include <flashinfer/attention/pod.cuh>
#include <flashinfer/pos_enc.cuh>
#include <flashinfer/utils.cuh>
#include <flashinfer/page.cuh>

#include "pytorch_conversion_utils.h"
#include "pytorch_extension_utils.h"

using namespace flashinfer;

using PrefillParams = BatchPrefillPagedParams<{dtype_q}, {dtype_kv}, {dtype_out}>;
using DecodeParams = BatchPrefillPagedParams<{dtype_q}, {dtype_kv}, {dtype_out}, {idtype}>;

constexpr auto POS_ENCODING_MODE = PosEncodingMode::kNone;

using AttentionVariant1_P = DefaultAttention<{use_custom_mask_p}, /*use_sliding_window=*/true, /*use_logits_soft_cap=*/false, /*use_alibi_bias=*/false>;
using AttentionVariant1_D = DefaultAttention<{use_custom_mask_d}, /*use_sliding_window=*/true, /*use_logits_soft_cap=*/false, /*use_alibi_bias=*/false>;

{get_insts("AttentionVariant1_P", "AttentionVariant1_D", dtype_out)}

using AttentionVariant2_P = DefaultAttention<{use_custom_mask_p}, /*use_sliding_window=*/true, /*use_logits_soft_cap=*/true, /*use_alibi_bias=*/false>;
using AttentionVariant2_D = DefaultAttention<{use_custom_mask_d}, /*use_sliding_window=*/true, /*use_logits_soft_cap=*/true, /*use_alibi_bias=*/false>;

{get_insts("AttentionVariant2_P", "AttentionVariant2_D", dtype_out)}

using AttentionVariant3_P = DefaultAttention<{use_custom_mask_p}, /*use_sliding_window=*/false, /*use_logits_soft_cap=*/false, /*use_alibi_bias=*/false>;
using AttentionVariant3_D = DefaultAttention<{use_custom_mask_d}, /*use_sliding_window=*/false, /*use_logits_soft_cap=*/false, /*use_alibi_bias=*/false>;

{get_insts("AttentionVariant3_P", "AttentionVariant3_D", dtype_out)}

using AttentionVariant4_P = DefaultAttention<{use_custom_mask_p}, /*use_sliding_window=*/false, /*use_logits_soft_cap=*/true, /*use_alibi_bias=*/false>;
using AttentionVariant4_D = DefaultAttention<{use_custom_mask_d}, /*use_sliding_window=*/false, /*use_logits_soft_cap=*/true, /*use_alibi_bias=*/false>;

{get_insts("AttentionVariant4_P", "AttentionVariant4_D", dtype_out)}

}}"""
return content


if __name__ == "__main__":
pattern = (
r"pod_head_qk_([0-9]+)_head_vo_([0-9]+)_posenc_([0-9]+)_"
r"fp16qkred_([a-z]+)_maskp_([0-9]+)_maskd_([0-9]+)_"
r"dtypeq_([a-z0-9]+)_dtypekv_([a-z0-9]+)_dtypeout_([a-z0-9]+)_idtype_([a-z0-9]+)\.cu"
)
compiled_pattern = re.compile(pattern)
path = Path(sys.argv[1])
fname = path.name
match = compiled_pattern.match(fname)

with open(path, "w") as f:
f.write(get_cu_file_str(*match.groups()))
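
As a rough sanity check on the generator above (a back-of-the-envelope count derived from the code, not a measured artifact size): each emitted .cu file instantiates PODWithKVCacheTensorDispatched for every pair of prefill and decode CTA tile sizes, once per attention-variant pair.

```python
# 3 prefill tile sizes x 3 decode tile sizes, for each of the 4 variant pairs
# (sliding window on/off x logits soft cap on/off) defined in the template above.
cta_tile_q_choice = [128, 64, 16]
num_variant_pairs = 4  # AttentionVariant{1..4}_{P,D}
insts_per_file = len(cta_tile_q_choice) ** 2 * num_variant_pairs
print(insts_per_file)  # 36 template instantiations per generated .cu file
```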
@@ -17,48 +17,67 @@ def run_bench(
device=0,
causal=True,
):
# POD Attention only supports page size = 1 due to use of single prefill kernel
page_block_size = 1
# if page size > 1, prefill kv len must be divisible by page size to ensure
# an identical workload as in BatchAttention
page_size = 1
seq_lens = torch.tensor(d_kv_lens + p_kv_lens, dtype=torch.int32)
q_lens = torch.tensor(d_qo_lens + p_qo_lens, dtype=torch.int32)

seq_lens_blocks = torch.ceil(seq_lens / page_block_size).int()
d_seq_lens_blocks = (
torch.tensor(d_kv_lens, dtype=torch.int32) / page_block_size
).int()
seq_lens_blocks = torch.ceil(seq_lens / page_size).int()
p_seq_lens = torch.tensor(p_kv_lens, dtype=torch.int32) / page_size
d_seq_lens = (torch.tensor(d_kv_lens, dtype=torch.int32) / page_size).int()

q_indptr = torch.cat([torch.tensor([0]), torch.cumsum(q_lens, 0)], dim=0).int()
# General params
qo_indptr = torch.cat([torch.tensor([0]), torch.cumsum(q_lens, 0)], dim=0).int()
kv_indptr = torch.cat(
[torch.tensor([0]), torch.cumsum(seq_lens_blocks, 0)], dim=0
).int()
d_q_indptr = torch.cat(
[torch.tensor([0]), torch.cumsum(torch.tensor(d_qo_lens), 0)], dim=0
).int()
d_kv_indptr = torch.cat(
[torch.tensor([0]), torch.cumsum(d_seq_lens_blocks, 0)], dim=0
).int()
num_blocks = kv_indptr[-1].item()

q = torch.rand(q_indptr[-1].item(), num_qo_heads, head_dim).to(
num_pages = kv_indptr[-1].item()
q = torch.rand(qo_indptr[-1].item(), num_qo_heads, head_dim).to(
device, dtype=torch.bfloat16
)
kv_data = torch.randn(num_blocks, 2, page_block_size, num_kv_heads, head_dim).to(
kv_data = torch.randn(num_pages, 2, page_size, num_kv_heads, head_dim).to(
device, dtype=torch.bfloat16
)

# Prefill params
seq_lens_blocks_p = torch.ceil(
torch.tensor(p_kv_lens, dtype=torch.int32) / page_size
).int()
qo_indptr_p = torch.cat(
[torch.tensor([0]), torch.cumsum(torch.tensor(p_qo_lens), 0)], dim=0
).int()
kv_indptr_p = torch.cat(
[torch.tensor([0]), torch.cumsum(p_seq_lens, 0)], dim=0
).int()
num_pages_p = seq_lens_blocks_p[-1].item()
kv_indices_p = torch.arange(num_pages_p, device=device, dtype=torch.int32)
last_page_len_p = (p_seq_lens - 1) % page_size + 1

# Decode params

qo_indptr_d = torch.cat(
[torch.tensor([0]), torch.cumsum(torch.tensor(d_qo_lens), 0)], dim=0
).int()
kv_indptr_d = torch.cat(
[torch.tensor([0]), torch.cumsum(d_seq_lens, 0)], dim=0
).int()
num_pages_d = kv_indptr_d[-1].item()
kv_indices_d = torch.arange(num_pages_d, device=device, dtype=torch.int32)
last_page_len_d = (d_seq_lens - 1) % page_size + 1

workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device)
kv_layout = "NHD"

wrapper_old = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
workspace_buffer,
kv_layout=kv_layout,
backend="fa2",
)
last_page_len = (seq_lens - 1) % page_block_size + 1
last_page_len = (seq_lens - 1) % page_size + 1
wrapper_old.plan(
q_indptr.to(device),
qo_indptr.to(device),
kv_indptr.to(device),
torch.arange(num_blocks).int().to(device),
torch.arange(num_pages, dtype=torch.int32, device=device),
last_page_len,
num_qo_heads,
num_kv_heads,
@@ -72,28 +91,28 @@ ms_old = do_bench(lambda: wrapper_old.run(q, kv_data))
ms_old = do_bench(lambda: wrapper_old.run(q, kv_data))

if len(p_kv_lens) == 1:
q_d = q[: d_q_indptr[-1]]
kv_d = kv_data[: d_kv_indptr[-1]].unbind(1)
q_p = q[d_q_indptr[-1] :]
k_p, v_p = kv_data[d_kv_indptr[-1] :].unbind(1)
k_p, v_p = k_p.squeeze(1), v_p.squeeze(1)
kv_indices_d = torch.arange(
0, d_kv_indptr[-1], device=device, dtype=torch.int32
)
q_d = q[: qo_indptr_d[-1]]
q_p = q[qo_indptr_d[-1] :]
kv_indices_d = torch.arange(0, num_pages_d, device=device, dtype=torch.int32)

last_page_len_d = (d_seq_lens_blocks - 1) % page_block_size + 1
last_page_len_d = (d_seq_lens - 1) % page_size + 1
wrapper_pod = flashinfer.PODWithPagedKVCacheWrapper(
workspace_buffer,
kv_layout=kv_layout,
)
wrapper_pod.plan(
d_kv_indptr.to(device),
qo_indptr_p.to(device),
kv_indptr_p.to(device),
kv_indices_p.to(device),
last_page_len_p,
qo_indptr_d.to(device),
kv_indptr_d.to(device),
kv_indices_d.to(device),
last_page_len=last_page_len_d,
last_page_len_d,
num_qo_heads=num_qo_heads,
num_kv_heads=num_kv_heads,
head_dim=head_dim,
page_size=page_block_size,
page_size=page_size,
q_data_type=torch.bfloat16,
kv_data_type=torch.bfloat16,
)
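
To make the index layout explicit: decode requests come first in the concatenated q and kv_data tensors, prefill requests follow, and each part gets its own exclusive prefix sum indptr. A toy sketch with made-up lengths (page_size = 1, independent of the benchmark's actual shapes):

```python
import torch

# Hypothetical mix: two decode requests (qo len 1, kv lens 3 and 5)
# and one prefill request (qo len 7, kv len 7), with page_size = 1.
page_size = 1
d_qo_lens, d_kv_lens = [1, 1], [3, 5]
p_qo_lens, p_kv_lens = [7], [7]

d_seq_lens = (torch.tensor(d_kv_lens, dtype=torch.int32) / page_size).int()
p_seq_lens = (torch.tensor(p_kv_lens, dtype=torch.int32) / page_size).int()

qo_indptr_d = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(d_qo_lens), 0)]).int()
kv_indptr_d = torch.cat([torch.tensor([0]), torch.cumsum(d_seq_lens, 0)]).int()
qo_indptr_p = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(p_qo_lens), 0)]).int()
kv_indptr_p = torch.cat([torch.tensor([0]), torch.cumsum(p_seq_lens, 0)]).int()

print(qo_indptr_d.tolist(), kv_indptr_d.tolist())  # [0, 1, 2] [0, 3, 8]
print(qo_indptr_p.tolist(), kv_indptr_p.tolist())  # [0, 7] [0, 7]

# With page_size = 1 every page holds a single token, so last_page_len is all ones.
print(((d_seq_lens - 1) % page_size + 1).tolist())  # [1, 1]
```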
Expand All @@ -113,29 +132,47 @@ def run_bench(
ms_pod = do_bench(
lambda: wrapper_pod.run(
q_p,
k_p,
v_p,
q_d,
kv_d,
paged_kv_cache=kv_data,
causal_p=causal,
causal_d=causal,
)
)
# Persistent attention
wrapper = flashinfer.BatchAttention(kv_layout="NHD")
wrapper.plan(
qo_indptr.to(device),
kv_indptr.to(device),
torch.arange(num_pages, dtype=torch.int32, device=device),
seq_lens.to(device),
num_qo_heads,
num_kv_heads,
head_dim,
head_dim,
page_size,
causal=causal,
q_data_type=torch.bfloat16,
kv_data_type=torch.bfloat16,
)
ms_persistent = do_bench(lambda: wrapper.run(q, kv_data))

print(f"Elapsed time (Batched Prefill): {ms_old:.2f} ms")
if len(p_kv_lens) == 1:
print(f"Elapsed time (POD Attention): {ms_pod:.2f} ms")
total_bytes = (
q.numel() * q.element_size() + kv_data.numel() * kv_data.element_size()
)
print(f"Elapsed time (Batched Prefill): {ms_old:.2f} ms")
if len(p_kv_lens) == 1:
print(f"Elapsed time (POD Attention): {ms_pod:.2f} ms")
print(f"Elapsed time (Persistent Attention): {ms_persistent:.2f} ms")

print(f"Loading memory size (MB): {total_bytes / (1024**2):.2f} MB")

bandwidth_old_gb_s = total_bytes / (ms_old * 1e-3) / (1024**3)

bandwidth_new_gb_s = total_bytes / (ms_persistent * 1e-3) / (1024**3)
print(f"Memory bandwidth (Batched Prefill): {bandwidth_old_gb_s:.2f} GB/s")
if len(p_kv_lens) == 1:
bandwidth_pod_gb_s = total_bytes / (ms_pod * 1e-3) / (1024**3)
print(f"Memory bandwidth (POD Attention): {bandwidth_pod_gb_s:.2f} GB/s")
print(f"Memory bandwidth (Persistent Attention): {bandwidth_new_gb_s:.2f} GB/s")


if __name__ == "__main__":
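
The bandwidth numbers printed above are just the bytes of q plus kv_data divided by the measured kernel time; a minimal sketch of the unit conversion with made-up numbers:

```python
# Made-up numbers, only to show the conversion used by the prints above:
# total_bytes would be q.numel()*q.element_size() + kv_data.numel()*kv_data.element_size().
total_bytes = 512 * 1024**2   # ~512 MiB of q + kv_data
ms = 2.0                      # kernel time reported by do_bench, in milliseconds
bandwidth_gb_s = total_bytes / (ms * 1e-3) / (1024**3)
print(f"{bandwidth_gb_s:.2f} GB/s")  # 250.00 GB/s
```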