Support arbitrary num_experts and topk, with bfloat16 zip prob. #10583

Merged: 3 commits, May 13, 2025
@@ -17,16 +17,13 @@

#define CUMSUM_BLOCK_SIZE 48 // result of trading off cumsum overhead against parallelism; do not change
#define CUMSUM_INVALID_TAG -1 // marks an invalid cumsum entry; -114514 was tried but failed

#ifndef MAX_NUM_EXPERTS
#define MAX_NUM_EXPERTS 32
#endif
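
Because num_experts and topk are now runtime arguments rather than template parameters, the per-thread and shared arrays below are sized by the compile-time cap MAX_NUM_EXPERTS. A minimal sketch of a host-side guard for that cap; the helper name is an assumption for illustration and is not part of this patch (the cap itself can be raised at build time, e.g. nvcc -DMAX_NUM_EXPERTS=64):

// Hypothetical host-side guard, illustration only (not part of this patch).
#include <stdexcept>

inline void check_num_experts_cap(int num_experts) {
  if (num_experts <= 0 || num_experts > MAX_NUM_EXPERTS) {
    throw std::invalid_argument(
        "num_experts exceeds MAX_NUM_EXPERTS; rebuild with a larger -DMAX_NUM_EXPERTS");
  }
}
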
// Multi-stage algorithm: the number of rows processed per block is tuned to balance the extra overhead
// First parse the routemap to update how many tokens each expert has received so far, then check the prefix sum handed over by the previous block and pass the updated one on to the next block
// Once the destination row indices are known, data movement starts immediately and runs until the task is fully complete (a simplified sketch of the block-to-block hand-off follows)
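
The comments above describe a chained prefix sum across thread blocks: each block counts the tokens it routes to every expert, waits for the running total published by the previous block, then publishes the updated total for the next block before moving rows to their destinations. A minimal sketch of that hand-off, assuming the global cumsum buffer is pre-filled with CUMSUM_INVALID_TAG; the function name, layout, and synchronization details are illustrative, not the kernel's actual code:

// Illustrative sketch only, not part of this patch.
// block_cumsum is laid out as [gridDim.x][num_experts] and pre-filled with
// CUMSUM_INVALID_TAG; local_count holds this block's per-expert token counts.
__device__ void propagate_block_cumsum(volatile int *block_cumsum,
                                       const int *local_count,
                                       int num_experts) {
  for (int e = 0; e < num_experts; e++) {
    int prev = 0;
    if (blockIdx.x > 0) {
      // Spin until the previous block has published its running total.
      while ((prev = block_cumsum[(blockIdx.x - 1) * num_experts + e]) ==
             CUMSUM_INVALID_TAG) {
      }
    }
    // Publish the inclusive running total for the next block to consume.
    block_cumsum[blockIdx.x * num_experts + e] = prev + local_count[e];
    __threadfence();  // order the publish before any later global writes
  }
}
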
template <typename X_T,
typename routemap_T,
typename probs_T,
int topk,
int num_experts,
bool has_scale>
template <typename X_T, typename routemap_T, typename probs_T, bool has_scale>
__global__ void tokens_unzip_stable_kernel(
const X_T *__restrict__ X,
const routemap_T *__restrict__ routemap_topk,
@@ -40,11 +37,13 @@ __global__ void tokens_unzip_stable_kernel(
const int total_zipped_tokens_num,
const int max_tokens_per_expert,
const int token_length,
const int scale_length) {
const int scale_length,
const int num_experts,
const int topk) {
const int block_row_base = blockIdx.x * CUMSUM_BLOCK_SIZE;
int cumsum_offset[num_experts];
int expert_offset[num_experts];
int local_cumsum[num_experts];
int cumsum_offset[MAX_NUM_EXPERTS];
int expert_offset[MAX_NUM_EXPERTS];
int local_cumsum[MAX_NUM_EXPERTS];
#pragma unroll
for (int i = 0; i < num_experts; i++) {
cumsum_offset[i] =
@@ -55,13 +54,13 @@ __global__ void tokens_unzip_stable_kernel(
local_cumsum[i] = 0;
}
const int base_row_idx = blockIdx.x * CUMSUM_BLOCK_SIZE;
__shared__ int shared_expert_rowmap[CUMSUM_BLOCK_SIZE][num_experts];
__shared__ probs_T shared_expert_probmap[CUMSUM_BLOCK_SIZE][num_experts];
__shared__ int shared_expert_rowmap[CUMSUM_BLOCK_SIZE][MAX_NUM_EXPERTS];
__shared__ probs_T shared_expert_probmap[CUMSUM_BLOCK_SIZE][MAX_NUM_EXPERTS];

// --------------------- thread0: single-thread task hand-off -------------------------
if (threadIdx.x == 0) [[unlikely]] {
int local_expert_rowmap[CUMSUM_BLOCK_SIZE][num_experts];
probs_T local_expert_probs[CUMSUM_BLOCK_SIZE][num_experts];
int local_expert_rowmap[CUMSUM_BLOCK_SIZE][MAX_NUM_EXPERTS];
probs_T local_expert_probs[CUMSUM_BLOCK_SIZE][MAX_NUM_EXPERTS];
#pragma unroll
for (int i = 0; i < CUMSUM_BLOCK_SIZE; i++) {
#pragma unroll
@@ -171,35 +170,28 @@ void dispatch_tokens_unzip_stable(
#define GET_DATA(tensor, type) tensor.data<type>()

// Dispatch over the different type combinations
#define DISPATCH_CASE(TOKEN_T, PROB_T, INT_T, TOPK, NUM_EXPERTS, HAS_SCALE) \
auto kernel = tokens_unzip_stable_kernel<TOKEN_T, \
INT_T, \
PROB_T, \
TOPK, \
NUM_EXPERTS, \
HAS_SCALE>; \
kernel<<<grid, block, 0, X.stream()>>>( \
GET_DATA(X, TOKEN_T), \
GET_DATA(expert_routemap_topk, INT_T), \
GET_DATA(expert_prob_topk, PROB_T), \
XScale ? XScale->data<float>() : nullptr, \
GET_DATA(X_unzipped, TOKEN_T), \
GET_DATA(zipped_expertwise_rowmap, INT_T), \
GET_DATA(token_prob_unzipped, PROB_T), \
XScale_unzipped.data<float>(), \
global_expertwise_block_cumsum.data<int>(), \
total_zipped_tokens_num, \
max_tokens_per_expert, \
token_length, \
scale_length);
#define DISPATCH_CASE(TOKEN_T, PROB_T, INT_T, HAS_SCALE) \
auto kernel = tokens_unzip_stable_kernel<TOKEN_T, INT_T, PROB_T, HAS_SCALE>; \
kernel<<<grid, block, 0, X.stream()>>>( \
GET_DATA(X, TOKEN_T), \
GET_DATA(expert_routemap_topk, INT_T), \
GET_DATA(expert_prob_topk, PROB_T), \
XScale ? XScale->data<float>() : nullptr, \
GET_DATA(X_unzipped, TOKEN_T), \
GET_DATA(zipped_expertwise_rowmap, INT_T), \
GET_DATA(token_prob_unzipped, PROB_T), \
XScale_unzipped.data<float>(), \
global_expertwise_block_cumsum.data<int>(), \
total_zipped_tokens_num, \
max_tokens_per_expert, \
token_length, \
scale_length, \
num_experts, \
topk);
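
With the topk/num_experts template specialization removed, one kernel instantiation per type combination now serves any expert count up to MAX_NUM_EXPERTS. A hedged sketch of a call site for the macro above; the launch shape, type arguments, and the assertion are illustrative assumptions (the tensors named inside the macro, such as X and XScale, are assumed to be in scope), not code from this PR:

// Hypothetical call site, illustration only.
#include <cassert>

const int num_experts = 64;  // runtime value, no longer a template parameter
const int topk = 8;          // runtime value, no longer a template parameter
assert(num_experts <= MAX_NUM_EXPERTS);  // local arrays are sized by the compile-time cap

dim3 block(512);
dim3 grid((total_zipped_tokens_num + CUMSUM_BLOCK_SIZE - 1) / CUMSUM_BLOCK_SIZE);
DISPATCH_CASE(__nv_bfloat16, __nv_bfloat16, int, /*HAS_SCALE=*/false)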

// Extensible: handles specific topk and num_experts combinations; can be extended later as needed
#define HANDLE_EXPERT_CASE(TOKEN_T, PROB_T, INT_T, HAS_SCALE) \
if (topk == 8 && num_experts == 4) { \
DISPATCH_CASE(TOKEN_T, PROB_T, INT_T, 8, 4, HAS_SCALE) \
} else { \
std::__throw_invalid_argument; \
}
DISPATCH_CASE(TOKEN_T, PROB_T, INT_T, HAS_SCALE)

#define HANDLE_TOKEN_TYPE(PROB_T, INT_T) \
if (DTYPE_CASE(X.dtype(), BFLOAT16)) { \