diff --git a/slm/model_zoo/gpt-3/external_ops/token_dispatcher_utils/tokens_stable_unzip.cu b/slm/model_zoo/gpt-3/external_ops/token_dispatcher_utils/tokens_stable_unzip.cu
index 67207d1bc160..9d63f0863d2a 100644
--- a/slm/model_zoo/gpt-3/external_ops/token_dispatcher_utils/tokens_stable_unzip.cu
+++ b/slm/model_zoo/gpt-3/external_ops/token_dispatcher_utils/tokens_stable_unzip.cu
@@ -17,16 +17,13 @@
 #define CUMSUM_BLOCK_SIZE 48  // tradeoff between cumsum overhead and parallelism; do not change
 #define CUMSUM_INVALID_TAG -1  // tag for an invalid cumsum; -114514 was tried and failed
-
+#ifndef MAX_NUM_EXPERTS
+#define MAX_NUM_EXPERTS 32
+#endif
 // Multi-stage algorithm; the rows handled per block trade extra overhead against parallelism.
 // First parse the routemap to update each expert's received-token count, then check the prefix sum handed over by the previous block and publish the updated sum to the next block.
 // Once the destination rows are known, start moving data immediately until all work is done.
-template <typename X_T,
-          typename probs_T,
-          typename routemap_T,
-          int topk,
-          int num_experts,
-          bool has_scale>
+template <typename X_T, typename probs_T, typename routemap_T, bool has_scale>
 __global__ void tokens_unzip_stable_kernel(
     const X_T *__restrict__ X,
     const routemap_T *__restrict__ routemap_topk,
     const probs_T *__restrict__ expert_prob_topk,
@@ -40,11 +37,13 @@ __global__ void tokens_unzip_stable_kernel(
     const int total_zipped_tokens_num,
     const int max_tokens_per_expert,
     const int token_length,
-    const int scale_length) {
+    const int scale_length,
+    const int num_experts,
+    const int topk) {
   const int block_row_base = blockIdx.x * CUMSUM_BLOCK_SIZE;
-  int cumsum_offset[num_experts];
-  int expert_offset[num_experts];
-  int local_cumsum[num_experts];
+  int cumsum_offset[MAX_NUM_EXPERTS];
+  int expert_offset[MAX_NUM_EXPERTS];
+  int local_cumsum[MAX_NUM_EXPERTS];
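+  // NOTE: local arrays need a compile-time bound, so they are sized with
+  // MAX_NUM_EXPERTS; the loops below only touch the first num_experts
+  // entries, since num_experts is now a runtime argument rather than a
+  // template parameter.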
 #pragma unroll
   for (int i = 0; i < num_experts; i++) {
     cumsum_offset[i] =
@@ -55,13 +54,13 @@ __global__ void tokens_unzip_stable_kernel(
     local_cumsum[i] = 0;
   }
   const int base_row_idx = blockIdx.x * CUMSUM_BLOCK_SIZE;
-  __shared__ int shared_expert_rowmap[CUMSUM_BLOCK_SIZE][num_experts];
-  __shared__ probs_T shared_expert_probmap[CUMSUM_BLOCK_SIZE][num_experts];
+  __shared__ int shared_expert_rowmap[CUMSUM_BLOCK_SIZE][MAX_NUM_EXPERTS];
+  __shared__ probs_T shared_expert_probmap[CUMSUM_BLOCK_SIZE][MAX_NUM_EXPERTS];
   // --------------------- single-thread task hand-off on thread0 -------------------------
   if (threadIdx.x == 0) [[unlikely]] {
-    int local_expert_rowmap[CUMSUM_BLOCK_SIZE][num_experts];
-    probs_T local_expert_probs[CUMSUM_BLOCK_SIZE][num_experts];
+    int local_expert_rowmap[CUMSUM_BLOCK_SIZE][MAX_NUM_EXPERTS];
+    probs_T local_expert_probs[CUMSUM_BLOCK_SIZE][MAX_NUM_EXPERTS];
 #pragma unroll
     for (int i = 0; i < CUMSUM_BLOCK_SIZE; i++) {
 #pragma unroll
@@ -171,35 +170,28 @@ void dispatch_tokens_unzip_stable(
 #define GET_DATA(tensor, type) tensor.data<type>()
 
 // Dispatch over the supported type combinations
-#define DISPATCH_CASE(TOKEN_T, PROB_T, INT_T, TOPK, NUM_EXPERTS, HAS_SCALE)   \
-  auto kernel = tokens_unzip_stable_kernel<TOKEN_T, PROB_T, INT_T, TOPK,     \
-                                           NUM_EXPERTS, HAS_SCALE>;          \
-  kernel<<<grid, block, 0, X.stream()>>>(                                    \
-      GET_DATA(X, TOKEN_T),                                                  \
-      GET_DATA(expert_routemap_topk, INT_T),                                 \
-      GET_DATA(expert_prob_topk, PROB_T),                                    \
-      XScale ? XScale->data<float>() : nullptr,                              \
-      GET_DATA(X_unzipped, TOKEN_T),                                         \
-      GET_DATA(zipped_expertwise_rowmap, INT_T),                             \
-      GET_DATA(token_prob_unzipped, PROB_T),                                 \
-      XScale_unzipped.data<float>(),                                         \
-      global_expertwise_block_cumsum.data<int>(),                            \
-      total_zipped_tokens_num,                                               \
-      max_tokens_per_expert,                                                 \
-      token_length,                                                          \
-      scale_length);
+#define DISPATCH_CASE(TOKEN_T, PROB_T, INT_T, HAS_SCALE)                     \
+  auto kernel =                                                              \
+      tokens_unzip_stable_kernel<TOKEN_T, PROB_T, INT_T, HAS_SCALE>;         \
+  kernel<<<grid, block, 0, X.stream()>>>(                                    \
+      GET_DATA(X, TOKEN_T),                                                  \
+      GET_DATA(expert_routemap_topk, INT_T),                                 \
+      GET_DATA(expert_prob_topk, PROB_T),                                    \
+      XScale ? XScale->data<float>() : nullptr,                              \
+      GET_DATA(X_unzipped, TOKEN_T),                                         \
+      GET_DATA(zipped_expertwise_rowmap, INT_T),                             \
+      GET_DATA(token_prob_unzipped, PROB_T),                                 \
+      XScale_unzipped.data<float>(),                                         \
+      global_expertwise_block_cumsum.data<int>(),                            \
+      total_zipped_tokens_num,                                               \
+      max_tokens_per_expert,                                                 \
+      token_length,                                                          \
+      scale_length,                                                          \
+      num_experts,                                                           \
+      topk);
 
 // Extensible: specific (topk, num_experts) combinations can be special-cased here as future needs arise
 #define HANDLE_EXPERT_CASE(TOKEN_T, PROB_T, INT_T, HAS_SCALE) \
-  if (topk == 8 && num_experts == 4) {                        \
-    DISPATCH_CASE(TOKEN_T, PROB_T, INT_T, 8, 4, HAS_SCALE)    \
-  } else {                                                    \
-    std::__throw_invalid_argument;                            \
-  }
+  DISPATCH_CASE(TOKEN_T, PROB_T, INT_T, HAS_SCALE)
 
 #define HANDLE_TOKEN_TYPE(PROB_T, INT_T)          \
   if (DTYPE_CASE(X.dtype(), BFLOAT16)) {          \
diff --git a/slm/model_zoo/gpt-3/external_ops/token_dispatcher_utils/tokens_unzip_and_zip.cu b/slm/model_zoo/gpt-3/external_ops/token_dispatcher_utils/tokens_unzip_and_zip.cu
index efdfa0e87668..4c29a65e1674 100644
--- a/slm/model_zoo/gpt-3/external_ops/token_dispatcher_utils/tokens_unzip_and_zip.cu
+++ b/slm/model_zoo/gpt-3/external_ops/token_dispatcher_utils/tokens_unzip_and_zip.cu
@@ -14,6 +14,9 @@
 #include "utils.h"
 
+#ifndef MAX_NUM_EXPERTS
+#define MAX_NUM_EXPERTS 32
+#endif
-template <int num_experts, int topk>
+template <bool MP>
 __global__ void tokens_zip_kernel(
     const phi::bfloat16 *__restrict__ unzipped_tokens_in,
     const int *__restrict__ zipped_expertwise_rowmap,
     const int *__restrict__ expert_routemap_topk,
-    const float *__restrict__ unzipped_token_probs,
+    const phi::bfloat16 *__restrict__ unzipped_token_probs,
     phi::bfloat16 *__restrict__ zipped_tokens_out,
-    float *__restrict__ zipped_probs_topk,
+    phi::bfloat16 *__restrict__ zipped_probs_topk,
     const int total_zipped_tokens_num,
-    const int token_length) {
+    const int token_length,
+    const int num_experts,
+    const int topk) {
   const int this_row = blockIdx.x;
   if (this_row >= total_zipped_tokens_num) return;
@@ -260,7 +265,7 @@ __global__ void tokens_zip_kernel(
   __nv_bfloat16 *zipped_tokens =
       reinterpret_cast<__nv_bfloat16 *>(zipped_tokens_out);
 
-  int local_row_fetchlist[num_experts];
+  int local_row_fetchlist[MAX_NUM_EXPERTS];
 
 // ------------------------- initialize the fetch table ------------------------
 #pragma unroll
@@ -365,19 +370,146 @@ __global__ void tokens_zip_kernel(
     }
   }
 }
-template <int num_experts, int topk>
+
+template <bool MP>
+__global__ void tokens_zip_kernel(
+    const phi::bfloat16 *__restrict__ unzipped_tokens_in,
+    const int *__restrict__ zipped_expertwise_rowmap,
+    const int *__restrict__ expert_routemap_topk,
+    const float *__restrict__ unzipped_token_probs,
+    phi::bfloat16 *__restrict__ zipped_tokens_out,
+    float *__restrict__ zipped_probs_topk,
+    const int total_zipped_tokens_num,
+    const int token_length,
+    const int num_experts,
+    const int topk) {
+  const int this_row = blockIdx.x;
+  if (this_row >= total_zipped_tokens_num) return;
+
+  const __nv_bfloat16 *unzipped_tokens =
+      reinterpret_cast<const __nv_bfloat16 *>(unzipped_tokens_in);
+  __nv_bfloat16 *zipped_tokens =
+      reinterpret_cast<__nv_bfloat16 *>(zipped_tokens_out);
+
+  int local_row_fetchlist[MAX_NUM_EXPERTS];
+
+// ------------------------- initialize the fetch table ------------------------
+#pragma unroll
+  for (int expert = 0; expert < num_experts; ++expert) {
+    const int fetch_row =
+        zipped_expertwise_rowmap[this_row * num_experts + expert];
+    local_row_fetchlist[expert] = fetch_row;
+  }
+
+#pragma unroll
+  for (int k = 0; k < topk; ++k) {
+    const int expert_idx = expert_routemap_topk[this_row * topk + k];
+    if (expert_idx < 0) [[likely]]
+      continue;
+    const int expert_fetch_row = local_row_fetchlist[expert_idx];
+    zipped_probs_topk[this_row * topk + k] =
+        unzipped_token_probs[expert_fetch_row];
+  }
+
+  constexpr int vecSize = 2;  // __nv_bfloat162 = 2 x bfloat16
+  const int num_full_vec = token_length / vecSize;
+  const int remaining_elems = token_length % vecSize;
+  const int thread_stride = blockDim.x * vecSize;
+
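+  // MP=true: promote each bf16 lane to fp32, accumulate, then round back to
+  // bf16 once at the end, so rounding error does not grow with the number of
+  // contributing experts. MP=false accumulates directly with bf16 intrinsics
+  // (faster, but every __hadd2 rounds).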
+  if constexpr (MP) {
+    // ------------------------ manual mixed precision ---------------------------------
+    // vectorized move over the aligned region
+    for (int x_offset = threadIdx.x * vecSize;
+         x_offset < num_full_vec * vecSize;
+         x_offset += thread_stride) {
+      float2 sum = {0.0f, 0.0f};
+      __nv_bfloat162 raw = {0, 0};
+      int aggreg_cnt = 0;
+      __nv_bfloat162 *out_ptr = reinterpret_cast<__nv_bfloat162 *>(
+          &zipped_tokens[this_row * token_length + x_offset]);
+#pragma unroll
+      for (int expert = 0; expert < num_experts; ++expert) {
+        const int fetch_row = local_row_fetchlist[expert];
+        if (fetch_row < 0) continue;
+        aggreg_cnt++;
+        // manual type promotion
+        raw = *reinterpret_cast<const __nv_bfloat162 *>(
+            &unzipped_tokens[fetch_row * token_length + x_offset]);
+        float2 token_vec = __bfloat1622float2(raw);
+        sum.x = __fadd_rn(token_vec.x, sum.x);
+        sum.y = __fadd_rn(token_vec.y, sum.y);
+      }
+      // selectively demote back to the original precision
+      *out_ptr = (aggreg_cnt > 1) ? __float22bfloat162_rn(sum) : raw;
+    }
+
+    // handle the remaining elements
+    for (int i = num_full_vec * vecSize + threadIdx.x; i < token_length;
+         i += blockDim.x) {
+      float sum = 0.0f;
+      __nv_bfloat16 raw = 0;
+      int aggreg_cnt = 0;
+#pragma unroll
+      for (int expert = 0; expert < num_experts; ++expert) {
+        int fetch_row = local_row_fetchlist[expert];
+        if (fetch_row < 0) continue;
+        aggreg_cnt++;
+        raw = unzipped_tokens[fetch_row * token_length + i];
+        float token_val = __bfloat162float(raw);
+        sum = __fadd_rn(token_val, sum);
+      }
+      zipped_tokens[this_row * token_length + i] =
+          (aggreg_cnt > 1) ? __float2bfloat16_rn(sum) : raw;
+    }
+  } else {
+    // ------------------------ weighted accumulation with BF16 intrinsics -----------------------
+    // vectorized move over the aligned region
+    for (int x_offset = threadIdx.x * vecSize;
+         x_offset < num_full_vec * vecSize;
+         x_offset += thread_stride) {
+      __nv_bfloat162 sum = {0, 0};
+      __nv_bfloat162 *out_ptr = reinterpret_cast<__nv_bfloat162 *>(
+          &zipped_tokens[this_row * token_length + x_offset]);
+#pragma unroll
+      for (int expert = 0; expert < num_experts; ++expert) {
+        const int fetch_row = local_row_fetchlist[expert];
+        if (fetch_row < 0) continue;
+        __nv_bfloat162 token_vec = *reinterpret_cast<const __nv_bfloat162 *>(
+            &unzipped_tokens[fetch_row * token_length + x_offset]);
+        sum = __hadd2(sum, token_vec);
+      }
+      *out_ptr = sum;
+    }
+
+    // handle the remaining elements
+    for (int i = num_full_vec * vecSize + threadIdx.x; i < token_length;
+         i += blockDim.x) {
+      __nv_bfloat16 sum = (__nv_bfloat16)0;
+#pragma unroll
+      for (int expert = 0; expert < num_experts; ++expert) {
+        int fetch_row = local_row_fetchlist[expert];
+        if (fetch_row < 0) continue;
+        __nv_bfloat16 token_val = unzipped_tokens[fetch_row * token_length + i];
+        sum = __hadd(sum, token_val);
+      }
+      zipped_tokens[this_row * token_length + i] = sum;
+    }
+  }
+}
 __global__ void tokens_zip_kernel(
-    const float*__restrict__ unzipped_tokens,
+    const float *__restrict__ unzipped_tokens,
     const int *__restrict__ zipped_expertwise_rowmap,
     const int *__restrict__ expert_routemap_topk,
     const float *__restrict__ unzipped_token_probs,
     float *__restrict__ zipped_tokens,
     float *__restrict__ zipped_probs_topk,
     const int total_zipped_tokens_num,
-    const int token_length) {
+    const int token_length,
+    const int num_experts,
+    const int topk) {
   const int this_row = blockIdx.x;
   if (this_row >= total_zipped_tokens_num) return;
 
-  int local_row_fetchlist[num_experts];
+  int local_row_fetchlist[MAX_NUM_EXPERTS];
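+  // Full-precision variant: tokens and probs are already fp32, so the sums
+  // below are accumulated directly with no promote/demote step.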
 
 // ------------------------- initialize the fetch table ------------------------
 #pragma unroll
@@ -402,7 +534,7 @@ __global__ void tokens_zip_kernel(
     // ------------------------ manual mixed precision ---------------------------------
     // vectorized move over the aligned region
   for (int x_offset = threadIdx.x; x_offset < token_length;
-        x_offset += thread_stride) {
+       x_offset += thread_stride) {
     float sum = 0.0f;
 #pragma unroll
     for (int expert = 0; expert < num_experts; ++expert) {
@@ -506,9 +638,9 @@ void dispatch_tokens_zip(const paddle::Tensor &unzipped_tokens,
   block.x = 256;
 
   // Map data types to C++ types
-  if (topk == 8 && num_experts == 4) {
-    if (unzipped_tokens.dtype() == paddle::DataType::BFLOAT16){
-      tokens_zip_kernel<8, 4><<<grid, block, 0, unzipped_tokens.stream()>>>(
+  if (unzipped_tokens.dtype() == paddle::DataType::BFLOAT16) {
+    if (zipped_probs_topk.dtype() == paddle::DataType::FLOAT32) {
+      tokens_zip_kernel<true><<<grid, block, 0, unzipped_tokens.stream()>>>(
           unzipped_tokens.data<phi::bfloat16>(),
           zipped_expertwise_rowmap.data<int>(),
           expert_routemap_topk.data<int>(),
@@ -516,18 +648,34 @@ void dispatch_tokens_zip(const paddle::Tensor &unzipped_tokens,
           zipped_tokens.data<phi::bfloat16>(),
           zipped_probs_topk.data<float>(),
           total_zipped_tokens_num,
-          token_length);
-    }else if (unzipped_tokens.dtype() == paddle::DataType::FLOAT32){
-      tokens_zip_kernel<8, 4><<<grid, block, 0, unzipped_tokens.stream()>>>(
-          unzipped_tokens.data<float>(),
+          token_length,
+          num_experts,
+          topk);
+    } else if (zipped_probs_topk.dtype() == paddle::DataType::BFLOAT16) {
+      tokens_zip_kernel<true><<<grid, block, 0, unzipped_tokens.stream()>>>(
+          unzipped_tokens.data<phi::bfloat16>(),
           zipped_expertwise_rowmap.data<int>(),
           expert_routemap_topk.data<int>(),
-          unzipped_token_probs.data<float>(),
-          zipped_tokens.data<float>(),
-          zipped_probs_topk.data<float>(),
+          unzipped_token_probs.data<phi::bfloat16>(),
+          zipped_tokens.data<phi::bfloat16>(),
+          zipped_probs_topk.data<phi::bfloat16>(),
           total_zipped_tokens_num,
-          token_length);
-    }
+          token_length,
+          num_experts,
+          topk);
+    }
+  } else if (unzipped_tokens.dtype() == paddle::DataType::FLOAT32) {
+    tokens_zip_kernel<<<grid, block, 0, unzipped_tokens.stream()>>>(
+        unzipped_tokens.data<float>(),
+        zipped_expertwise_rowmap.data<int>(),
+        expert_routemap_topk.data<int>(),
+        unzipped_token_probs.data<float>(),
+        zipped_tokens.data<float>(),
+        zipped_probs_topk.data<float>(),
+        total_zipped_tokens_num,
+        token_length,
+        num_experts,
+        topk);
   }
 }
@@ -595,7 +743,8 @@ std::vector<paddle::Tensor> tokens_zip(
     const paddle::Tensor &unzipped_token_probs,
     const int &total_zipped_tokens_num,
     const int &num_experts) {
-  PD_CHECK(unzipped_tokens.dtype() == paddle::DataType::BFLOAT16 || unzipped_tokens.dtype() == paddle::DataType::FLOAT32);
+  PD_CHECK(unzipped_tokens.dtype() == paddle::DataType::BFLOAT16 ||
+           unzipped_tokens.dtype() == paddle::DataType::FLOAT32);
   const int rows = unzipped_tokens.shape()[0];       // seqlen
   const int cols = unzipped_tokens.shape()[1];       // typically 7168
   const int topk = expert_routemap_topk.shape()[1];  // typically 8
@@ -609,13 +758,21 @@ std::vector<paddle::Tensor> tokens_zip(
       unzipped_token_probs.dtype(),
       unzipped_token_probs.place());
   // ----------------------- zero-initialize zipped_probs_topk ------------------
-  void *zipped_probs_topk_ptr =
-      reinterpret_cast<void *>(zipped_probs_topk.data<float>());
-  cudaMemsetAsync(zipped_probs_topk_ptr,
-                  0,
-                  sizeof(float) * total_zipped_tokens_num * topk,
-                  unzipped_token_probs.stream());
-
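+  // The kernels only write the routed slots, so the buffer must be zeroed
+  // first; the byte count depends on the element size of the prob dtype.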
+  if (unzipped_token_probs.dtype() == paddle::DataType::FLOAT32) {
+    void *zipped_probs_topk_ptr =
+        reinterpret_cast<void *>(zipped_probs_topk.data<float>());
+    cudaMemsetAsync(zipped_probs_topk_ptr,
+                    0,
+                    sizeof(float) * total_zipped_tokens_num * topk,
+                    unzipped_token_probs.stream());
+  } else if (unzipped_token_probs.dtype() == paddle::DataType::BFLOAT16) {
+    void *zipped_probs_topk_ptr =
+        reinterpret_cast<void *>(zipped_probs_topk.data<phi::bfloat16>());
+    cudaMemsetAsync(zipped_probs_topk_ptr,
+                    0,
+                    sizeof(phi::bfloat16) * total_zipped_tokens_num * topk,
+                    unzipped_token_probs.stream());
+  }
   dispatch_tokens_zip(unzipped_tokens,
                       zipped_expertwise_rowmap,
                       expert_routemap_topk,
diff --git a/tests/ops/test_unzip_zip.py b/tests/ops/test_unzip_zip.py
new file mode 100644
index 000000000000..6318979115db
--- /dev/null
+++ b/tests/ops/test_unzip_zip.py
@@ -0,0 +1,139 @@
+import numpy as np
+import paddle
+import TokenDispatcherUtils as TDU
+
+
+def fabricate_dispatch_result(
+    seqlen, token_length, topk, num_experts, data_type="bfloat16", broadcast_ratio=0.5
+):
+    tokens = paddle.randn([seqlen, token_length], dtype=data_type)
+
+    tokens_scale = paddle.empty([0])
+    if data_type == "float8_e4m3fn":
+        scale_cols = (token_length + 127) // 128
+        tokens_scale = paddle.randn([seqlen, scale_cols], dtype="float32")
+
+    # Expected number of experts each token selects
+    expected_experts = max(1, min(broadcast_ratio * num_experts, topk))
+
+    # Draw the per-token expert count from a normal distribution so it stays
+    # concentrated near the expected value (std dev = expected value / 6)
+    std_dev = max(1, expected_experts / 6)
+    experts_count = paddle.normal(expected_experts, std_dev, [seqlen])
+    # Round and clip to a valid range
+    experts_count = paddle.clip(paddle.round(experts_count), 1, min(topk, num_experts))
+    experts_count = paddle.cast(experts_count, "int32")
+
+    # Preallocate the result arrays
+    dispatched_indices = paddle.full([seqlen, topk], -1, dtype="int32")
+    dispatched_probs = paddle.zeros([seqlen, topk], dtype="float32")
+
+    # Generate expert indices and probabilities row by row
+    for i in range(seqlen):
+        count = experts_count[i].item()
+
+        # Draw non-repeating random expert indices
+        indices = paddle.randperm(num_experts)[:count]
+        dispatched_indices[i, :count] = indices
+
+        # Uniform probability over the selected experts
+        prob_value = 1.0 / count
+        dispatched_probs[i, :count] = paddle.full([count], prob_value, dtype="float32")
+
+    # Compute the max token count per expert:
+    # flatten and mask out the unrouted (-1) slots
+    valid_indices = dispatched_indices.reshape([-1])
+    valid_mask = valid_indices >= 0
+    valid_experts = valid_indices[valid_mask]
+
+    # Count tokens per expert with a histogram
+    expert_counts = paddle.histogram(
+        valid_experts, bins=num_experts, min=0, max=num_experts - 1
+    )
+    expert_counts = paddle.cast(expert_counts, "int32")
+    print("expert counts: ", expert_counts.numpy())
+    max_tokens_per_expert = expert_counts.max().item()
+
+    return (
+        tokens,
+        tokens_scale,
+        dispatched_indices,
+        dispatched_probs,
+        max_tokens_per_expert,
+    )
+
+
+def tensor_max_abs_rel_err(a, b, eps=1e-8):
+    max_abs_err = paddle.max(paddle.abs(a - b))
+    denom = paddle.maximum(paddle.abs(a), paddle.abs(b))
+    denom = paddle.maximum(denom, paddle.to_tensor(eps, dtype=denom.dtype))
+    max_rel_err = paddle.max(paddle.abs(a - b) / denom)
+    return max_abs_err, max_rel_err
+
+
+def test_unzip_zip():
+    SEQLEN = 16384
+    TOKEN_LEN = 7168
+    for dt in ["bfloat16"]:
+        for expert_num in [2, 4, 8, 16, 32]:
+            for topk in [4, 8, 12]:
+                print("###################################")
+                print(
+                    "testing with {} experts and topk {}, datatype is {}".format(
+                        expert_num, topk, dt
+                    )
+                )
+                (
+                    tokens,
+                    tokens_scale,
+                    dispatched_indices,
+                    dispatched_probs,
+                    max_tokens_per_expert,
+                ) = fabricate_dispatch_result(
+                    SEQLEN,
+                    TOKEN_LEN,
+                    topk,
+                    expert_num,
+                    data_type=dt,
+                    broadcast_ratio=0.5,
+                )
+                if dt == "bfloat16":
+                    tokens_scale = None
+                (
+                    unzipped_tokens,
+                    zipped_expertwise_rowmap,
+                    unzipped_probs,
+                    unzipped_scales,
+                ) = TDU.tokens_unzip_stable(
+                    tokens,
+                    tokens_scale,
+                    dispatched_indices,
+                    dispatched_probs,
+                    topk=topk,
+                    num_experts=expert_num,
+                    max_tokens_per_expert=max_tokens_per_expert,
+                )
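+                # Each row's probs sum to 1 (prob_value = 1/count), so zipping
+                # the prob-weighted unzipped copies should reconstruct the
+                # original tokens up to bf16 rounding error.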
+                tokens_recovered, probs_recovered = TDU.tokens_zip(
+                    (unzipped_tokens * unzipped_probs.unsqueeze(-1)).astype("bfloat16"),
+                    zipped_expertwise_rowmap,
+                    dispatched_indices,
+                    unzipped_probs,
+                    total_zipped_tokens=SEQLEN,
+                    num_experts=expert_num,
+                )
+                print(
+                    "unzip-zip tokens max abs err: {}, max rel err: {}".format(
+                        *tensor_max_abs_rel_err(tokens, tokens_recovered)
+                    )
+                )
+                print(
+                    "unzip-zip probs max abs err: {}, max rel err: {}".format(
+                        *tensor_max_abs_rel_err(dispatched_probs, probs_recovered)
+                    )
+                )
+
+
+# core.nvprof_enable_record_event()
+
+if __name__ == "__main__":
+    test_unzip_zip()