feat: optimize merge_attn_states thread block dispatch #279

Merged: 2 commits, Apr 9, 2025
kernels/openai-triton/merge-attn-states/cuda_merge_attn_states.cu (265 changes: 133 additions, 132 deletions)
@@ -19,10 +19,9 @@ from_float(half& d, float s) { d = __float2half(s); }
static __forceinline__ __device__ void
from_float(__nv_bfloat16& d, float s) { d = __float2bfloat16(s); }

// Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
// can be used to combine partial attention results (in the split-KV case)
template <typename scalar_t, bool kLoopOverHead>
__global__ void merge_attn_states_kernel(

template <typename scalar_t>
__device__ __forceinline__ void merge_attn_states_per_thread(
scalar_t* output, // [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
float* output_lse, // [NUM_HEADS, NUM_TOKENS]
const scalar_t* __restrict__ prefix_output, // [NUM_TOKENS, NUM_HEADS,
@@ -33,140 +32,129 @@ __global__ void merge_attn_states_kernel(
const float* __restrict__ suffix_lse, // [NUM_HEADS, NUM_TOKENS]
const uint num_tokens, // NUM_TOKENS
const uint num_heads, // NUM QUERY HEADS
const uint head_size // HEAD_SIZE, 32,48,64,...,512,etc
const uint head_size, // HEAD_SIZE, 32,48,64,...,512,etc
const uint token_idx,
const uint head_idx,
const uint thr_idx
) {
// TODO(DefTruth): may need to support fp8?
if constexpr (kLoopOverHead) {
// May loop over num heads for large NUM_TOKENS
const uint token_idx = blockIdx.x;
const uint thread_idx = threadIdx.x;
using pack_128b_t = uint4; // float -> 4, half/bf16 -> 8
constexpr uint pack_size = 16 / sizeof(scalar_t);

#pragma unroll
for (uint head_idx = 0; head_idx < num_heads; ++head_idx) {
float p_lse = prefix_lse[head_idx * num_tokens + token_idx];
float s_lse = suffix_lse[head_idx * num_tokens + token_idx];
p_lse =
std::isinf(p_lse) ? -std::numeric_limits<float>::infinity() : p_lse;
s_lse =
std::isinf(s_lse) ? -std::numeric_limits<float>::infinity() : s_lse;
const uint thr_offset = thr_idx * pack_size; // (0~15)*8, etc.
const uint blk_offset =
token_idx * num_heads * head_size + head_idx * head_size;
const scalar_t* prefix_output_blk = prefix_output + blk_offset;
const scalar_t* suffix_output_blk = suffix_output + blk_offset;
scalar_t* output_blk = output + blk_offset;

const float max_lse = fmaxf(p_lse, s_lse);
p_lse = p_lse - max_lse;
s_lse = s_lse - max_lse;
const float p_se = expf(p_lse);
const float s_se = expf(s_lse);
const float out_se = p_se + s_se;
float p_lse = prefix_lse[head_idx * num_tokens + token_idx];
float s_lse = suffix_lse[head_idx * num_tokens + token_idx];
p_lse = std::isinf(p_lse) ? -std::numeric_limits<float>::infinity() : p_lse;
s_lse = std::isinf(s_lse) ? -std::numeric_limits<float>::infinity() : s_lse;

if (output_lse != nullptr) {
float out_lse = logf(out_se) + max_lse;
output_lse[head_idx * num_tokens + token_idx] = out_lse;
}
const float max_lse = fmaxf(p_lse, s_lse);
p_lse = p_lse - max_lse;
s_lse = s_lse - max_lse;
const float p_se = expf(p_lse);
const float s_se = expf(s_lse);
const float out_se = p_se + s_se;
const float p_scale = p_se / out_se;
const float s_scale = s_se / out_se;

const uint blk_offset =
token_idx * num_heads * head_size + head_idx * head_size;
const scalar_t* prefix_output_blk = prefix_output + blk_offset;
const scalar_t* suffix_output_blk = suffix_output + blk_offset;
scalar_t* output_blk = output + blk_offset;
// We only need to write to output_lse once per head.
if (output_lse != nullptr && thr_idx == 0) {
float out_lse = logf(out_se) + max_lse;
output_lse[head_idx * num_tokens + token_idx] = out_lse;
}

// float -> 4, half/bf16 -> 8
using pack_128b_t = uint4;
constexpr uint pack_size = 16 / sizeof(scalar_t);
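// e.g. half/bf16 with head_size = 128: pack_size = 8, so 16 threads cover
// one head and each thread handles a single 16-byte (128-bit) pack.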
if (thr_offset < head_size) {
// Pack 128b load
pack_128b_t p_out_pack = reinterpret_cast<const pack_128b_t*>(
prefix_output_blk)[thr_offset / pack_size];
pack_128b_t s_out_pack = reinterpret_cast<const pack_128b_t*>(
suffix_output_blk)[thr_offset / pack_size];
pack_128b_t o_out_pack;

const uint thr_offset = thread_idx * pack_size;
const float p_scale = p_se / out_se;
const float s_scale = s_se / out_se;
#pragma unroll
for (uint i = 0; i < pack_size; ++i) {
// Always use float for FMA to keep precision.
// half(uint16_t), bfloat16, float -> float.
const float p_out_f =
to_float(reinterpret_cast<const scalar_t*>(&p_out_pack)[i]);
const float s_out_f =
to_float(reinterpret_cast<const scalar_t*>(&s_out_pack)[i]);
// fma: a * b + c = p_out_f * p_scale + (s_out_f * s_scale)
const float o_out_f = p_out_f * p_scale + (s_out_f * s_scale);
// float -> half(uint16_t), bfloat16, float.
from_float(reinterpret_cast<scalar_t*>(&o_out_pack)[i],
o_out_f);
}

if (thr_offset < head_size) {
// Pack 128b load
pack_128b_t p_out_pack = reinterpret_cast<const pack_128b_t*>(
prefix_output_blk)[thr_offset / pack_size];
pack_128b_t s_out_pack = reinterpret_cast<const pack_128b_t*>(
suffix_output_blk)[thr_offset / pack_size];
pack_128b_t o_out_pack;
// Pack 128b storage
reinterpret_cast<pack_128b_t*>(output_blk)[
thr_offset / pack_size] = o_out_pack;
}
}

#pragma unroll
for (uint i = 0; i < pack_size; ++i) {
// Always use float for FMA to keep precision.
// half(uint16_t), bfloat16, float -> float.
const float p_out_f =
to_float(reinterpret_cast<const scalar_t*>(&p_out_pack)[i]);
const float s_out_f =
to_float(reinterpret_cast<const scalar_t*>(&s_out_pack)[i]);
// fma: a * b + c = p_out_f * p_scale + (s_out_f * s_scale)
const float o_out_f = p_out_f * p_scale + (s_out_f * s_scale);
// float -> half(uint16_t), bfloat16, float.
from_float(reinterpret_cast<scalar_t*>(&o_out_pack)[i],
o_out_f);
}
// Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
// can be used to combine partial attention results (in the split-KV case)
template <typename scalar_t, bool kLoopOverHead, bool kFlattenOverHead = false>
__global__ void merge_attn_states_kernel(
scalar_t* output, // [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
float* output_lse, // [NUM_HEADS, NUM_TOKENS]
const scalar_t* __restrict__ prefix_output, // [NUM_TOKENS, NUM_HEADS,
// HEAD_SIZE]
const float* __restrict__ prefix_lse, // [NUM_HEADS, NUM_TOKENS]
const scalar_t* __restrict__ suffix_output, // [NUM_TOKENS, NUM_HEADS,
// HEAD_SIZE]
const float* __restrict__ suffix_lse, // [NUM_HEADS, NUM_TOKENS]
const uint num_tokens, // NUM_TOKENS
const uint num_heads, // NUM QUERY HEADS
const uint head_size // HEAD_SIZE, 32,48,64,...,512,etc
) {
if constexpr (kLoopOverHead) {
// May loop over num heads for large num_tokens
const uint token_idx = blockIdx.x;
const uint thread_idx = threadIdx.x;

if constexpr (kFlattenOverHead) {
// thread num = (num_heads * head_size) / pack_size
// = num_heads * (head_size / pack_size), 16 * (128 / 8)
// tid: 0~255, 0~15->head 0, 16~31->head 1, ..., etc.
constexpr uint pack_size = 16 / sizeof(scalar_t);
const uint head_idx = thread_idx / (head_size / pack_size);
const uint thr_idx = thread_idx % (head_size / pack_size);
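// e.g. fp16, num_heads = 16, head_size = 128: head_size / pack_size = 16,
// so thread 37 maps to head_idx = 2, thr_idx = 5 (pack 5 of head 2).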
merge_attn_states_per_thread<scalar_t>(
output, output_lse, prefix_output,
prefix_lse, suffix_output, suffix_lse,
num_tokens, num_heads, head_size,
token_idx, head_idx, thr_idx
);
} else {
const uint thr_idx = thread_idx;
#pragma unroll
for (uint head_idx = 0; head_idx < num_heads; ++head_idx) {
merge_attn_states_per_thread<scalar_t>(
output, output_lse, prefix_output,
prefix_lse, suffix_output, suffix_lse,
num_tokens, num_heads, head_size,
token_idx, head_idx, thr_idx
);
} // End loop over heads
} // End kFlattenOverHead

// Pack 128b storage
reinterpret_cast<pack_128b_t*>(output_blk)[
thr_offset / pack_size] = o_out_pack;
}
} // End loop over heads
} else {
const uint token_idx = blockIdx.x;
const uint head_idx = blockIdx.y;
const uint thread_idx = threadIdx.x;
const uint thr_idx = thread_idx;

float p_lse = prefix_lse[head_idx * num_tokens + token_idx];
float s_lse = suffix_lse[head_idx * num_tokens + token_idx];
p_lse = std::isinf(p_lse) ? -std::numeric_limits<float>::infinity() : p_lse;
s_lse = std::isinf(s_lse) ? -std::numeric_limits<float>::infinity() : s_lse;

const float max_lse = fmaxf(p_lse, s_lse);
p_lse = p_lse - max_lse;
s_lse = s_lse - max_lse;
const float p_se = expf(p_lse);
const float s_se = expf(s_lse);
const float out_se = p_se + s_se;

if (output_lse != nullptr) {
float out_lse = logf(out_se) + max_lse;
output_lse[head_idx * num_tokens + token_idx] = out_lse;
}

const uint blk_offset =
token_idx * num_heads * head_size + head_idx * head_size;
const scalar_t* prefix_output_blk = prefix_output + blk_offset;
const scalar_t* suffix_output_blk = suffix_output + blk_offset;
scalar_t* output_blk = output + blk_offset;

// float -> 4, half/bf16 -> 8
using pack_128b_t = uint4; // 16 bytes
constexpr uint pack_size = 16 / sizeof(scalar_t);

const uint thr_offset = thread_idx * pack_size;
const float p_scale = p_se / out_se;
const float s_scale = s_se / out_se;

if (thr_offset < head_size) {
// Pack 128b load
pack_128b_t p_out_pack = reinterpret_cast<const pack_128b_t*>(
prefix_output_blk)[thr_offset / pack_size];
pack_128b_t s_out_pack = reinterpret_cast<const pack_128b_t*>(
suffix_output_blk)[thr_offset / pack_size];
pack_128b_t o_out_pack;

#pragma unroll
for (uint i = 0; i < pack_size; ++i) {
// Always use float for FMA to keep precision.
// half(uint16_t), bfloat16, float -> float.
const float p_out_f =
to_float(reinterpret_cast<const scalar_t*>(&p_out_pack)[i]);
const float s_out_f =
to_float(reinterpret_cast<const scalar_t*>(&s_out_pack)[i]);
// fma: a * b + c = p_out_f * p_scale + (s_out_f * s_scale)
const float o_out_f = p_out_f * p_scale + (s_out_f * s_scale);
// float -> half(uint16_t), bfloat16, float.
from_float(reinterpret_cast<scalar_t*>(&o_out_pack)[i],
o_out_f);
}

// Pack 128b storage
reinterpret_cast<pack_128b_t*>(output_blk)[
thr_offset / pack_size] = o_out_pack;
}
merge_attn_states_per_thread<scalar_t>(
output, output_lse, prefix_output,
prefix_lse, suffix_output, suffix_lse,
num_tokens, num_heads, head_size,
token_idx, head_idx, thr_idx
);
}
}
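
For context on what merge_attn_states_per_thread computes: each (token, head) pair has a prefix and a suffix partial attention result plus their log-sum-exp (LSE) values, and the merged output is a convex combination of the two vectors weighted by the softmax of the LSEs (section 2.2 of the linked paper). Below is a minimal host-side reference of that formula; the function name and the plain-float, unpacked layout are illustrative only, not code from this PR.

#include <algorithm>
#include <cmath>
#include <limits>

// Reference merge for a single (token, head) pair: writes the merged head
// vector into `out` and returns the merged LSE. Same math as the kernel,
// in plain float with no packing and no CUDA.
float merge_head_reference(const float* prefix_out, float prefix_lse,
                           const float* suffix_out, float suffix_lse,
                           float* out, int head_size) {
  // Like the kernel, treat an infinite LSE as "no contribution".
  if (std::isinf(prefix_lse)) prefix_lse = -std::numeric_limits<float>::infinity();
  if (std::isinf(suffix_lse)) suffix_lse = -std::numeric_limits<float>::infinity();
  const float max_lse = std::max(prefix_lse, suffix_lse);
  const float p_se = std::exp(prefix_lse - max_lse);
  const float s_se = std::exp(suffix_lse - max_lse);
  const float out_se = p_se + s_se;
  for (int i = 0; i < head_size; ++i) {
    out[i] = prefix_out[i] * (p_se / out_se) + suffix_out[i] * (s_se / out_se);
  }
  return std::log(out_se) + max_lse;  // merged LSE for this (head, token)
}

Subtracting max_lse before exponentiating keeps the weights stable when one LSE is much larger than the other, which is also why the kernel does all of this arithmetic in float regardless of scalar_t.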

@@ -183,9 +171,9 @@ __global__ void merge_attn_states_kernel(
} \
}

#define LAUNCH_MERGE_ATTN_STATES(SCALAR_T, kLoopOverHead) \
#define LAUNCH_MERGE_ATTN_STATES(SCALAR_T, kLoopOverHead, kFlattenOverHead) \
{ \
merge_attn_states_kernel<SCALAR_T, kLoopOverHead> \
merge_attn_states_kernel<SCALAR_T, kLoopOverHead, kFlattenOverHead> \
<<<grid, block>>>( \
reinterpret_cast<SCALAR_T*>(output.data_ptr()), output_lse_ptr, \
reinterpret_cast<SCALAR_T*>(prefix_output.data_ptr()), \
@@ -217,20 +205,33 @@ void merge_attn_states_launcher(
if (output_lse.has_value()) {
output_lse_ptr = output_lse.value().data_ptr<float>();
}
// Keep the number of threads per block <= 512.
const bool skip_flatten_over_head = (
(num_heads * head_size) / pack_size > 512);

const bool skip_loop_over_head = (
num_tokens <= 1024 || num_heads >= 64
|| disable_loop_over_head
disable_loop_over_head || num_tokens <= 1024 ||
(num_heads >= 64 && skip_flatten_over_head)
);

if (skip_loop_over_head) {
dim3 grid(num_tokens, num_heads);
dim3 block(head_size / pack_size);
LAUNCH_MERGE_ATTN_STATES(SCALAR_T, false);
LAUNCH_MERGE_ATTN_STATES(SCALAR_T, false, false);
} else {
// try loop over num heads for large NUM_TOKENS
dim3 grid(num_tokens);
dim3 block(head_size / pack_size);
LAUNCH_MERGE_ATTN_STATES(SCALAR_T, true);
// try loop over num heads for large num_tokens
if (skip_flatten_over_head) {
dim3 grid(num_tokens);
dim3 block(head_size / pack_size);
LAUNCH_MERGE_ATTN_STATES(SCALAR_T, true, false);
} else {
// cases:
// num_tokens 8192, num_heads 16, head_size 128
// num_tokens 4096, num_heads 16, head_size 128
dim3 grid(num_tokens);
dim3 block((num_heads * head_size) / pack_size);
LAUNCH_MERGE_ATTN_STATES(SCALAR_T, true, true);
}
}
}
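
As a quick way to check which variant a given shape selects, here is a standalone sketch that mirrors the dispatch heuristic above; the helper name and the printf reporting are illustrative, not part of this PR.

#include <cstdio>

// Prints the grid/block shape the launcher above would choose for a given
// problem size. pack_size is 16 / sizeof(scalar_t), e.g. 8 for half/bf16.
void describe_dispatch(unsigned num_tokens, unsigned num_heads,
                       unsigned head_size, unsigned pack_size,
                       bool disable_loop_over_head) {
  // Flattening all heads into one block is only allowed up to 512 threads.
  const bool skip_flatten_over_head =
      (num_heads * head_size) / pack_size > 512;
  const bool skip_loop_over_head =
      disable_loop_over_head || num_tokens <= 1024 ||
      (num_heads >= 64 && skip_flatten_over_head);

  if (skip_loop_over_head) {
    std::printf("grid(%u, %u), block(%u)  // 2D grid, one head per block\n",
                num_tokens, num_heads, head_size / pack_size);
  } else if (skip_flatten_over_head) {
    std::printf("grid(%u), block(%u)  // loop over heads inside the kernel\n",
                num_tokens, head_size / pack_size);
  } else {
    std::printf("grid(%u), block(%u)  // all heads flattened into one block\n",
                num_tokens, (num_heads * head_size) / pack_size);
  }
}

For example, describe_dispatch(4096, 16, 128, 8, false) reports grid(4096), block(256), i.e. the new flattened loop-over-head path, while num_heads = 64 with head_size = 128 and fp16 would need 1024 threads per block and therefore stays on the 2D grid(num_tokens, num_heads) variant.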

@@ -15,7 +15,7 @@
"-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
"--expt-relaxed-constexpr",
"--expt-extended-lambda",
"--use_fast_math"
# "--use_fast_math"
],
extra_cflags=['-std=c++17'],
verbose=True
@@ -42,12 +42,12 @@ def merge_attn_states_torch(
return output, output_lse


NUM_TOKENS = [256, 512, 613, 1024, 1536, 4096]
NUM_BATCH_TOKENS = [256, 512, 613, 1024, 1536, 4096]
NUM_QUERY_HEADS = [4, 8, 16, 32]
HEAD_SIZES = [64, 96, 128]
DTYPES = [torch.float32, torch.half, torch.bfloat16]

@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_tokens", NUM_BATCH_TOKENS)
@pytest.mark.parametrize("num_query_heads", NUM_QUERY_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("output_dtype", DTYPES)