Adding transposed scale support for dtaq op #10582

Open
wants to merge 2 commits into base: dsv3_dev
@@ -23,7 +23,11 @@ constexpr int64_t TILE_SIZE = 128; // each block processes a 128x128 tile of elements


#define BLOCK_SIZE 128
template <typename OutT, bool using_pow2_scaling, bool padding_last_dim_to_8x>
template <typename OutT,
bool using_pow2_scaling,
bool padding_last_dim_to_8x,
bool input_scale_transpose = true,
bool output_scale_transpose = true>
__global__ void FusedActDequantTransposeActQuant(
const phi::float8_e4m3fn *__restrict__ Xin,
const float *__restrict__ Xscale,
@@ -46,10 +50,18 @@ __global__ void FusedActDequantTransposeActQuant(
// Load the original fp8 scales into smem_max for the later dequant
// ------------------------------
if (threadIdx.y == 0) {
for (int i = threadIdx.x; i < BLOCK_SIZE; i += blockDim.x) {
smem_max[i] = Xscale[i];
for (int y_offset = threadIdx.x; y_offset < BLOCK_SIZE;
y_offset += blockDim.x) {
if constexpr (input_scale_transpose) {
smem_max[y_offset] =
Xscale[blockIdx.x * gridDim.y + g_block_y_offset + y_offset];
} else {
smem_max[y_offset] =
Xscale[(g_block_y_offset + y_offset) * gridDim.x + blockIdx.x];
}
}
}

__syncthreads(); // the Xscale data in smem_tile is now ready

// Stage 1:
Expand Down Expand Up @@ -101,6 +113,7 @@ __global__ void FusedActDequantTransposeActQuant(
if (threadIdx.x == 0)
smem_max[y_offset] = local_max; // sequential writes along x0, reused, no bank conflicts
}
__syncthreads();

// Stage 3:
// Scale/cast the output and write the scales back
@@ -129,8 +142,13 @@ __global__ void FusedActDequantTransposeActQuant(
g_output_x_offset < g_output_inner_stride) {
out[g_output_y_offset * g_output_inner_stride + g_output_x_offset] =
(g_output_x_offset < rows) ? output_scaled_fp8 : (OutT)0;
scales[g_output_y_offset * g_scale_inner_stride +
g_output_x_offset / 128] = scale_on_fp8_to_inputT;
if constexpr (output_scale_transpose) {
scales[g_output_x_offset / 128 * cols + g_output_y_offset] =
scale_on_fp8_to_inputT;
} else {
scales[g_output_y_offset * g_scale_inner_stride +
g_output_x_offset / 128] = scale_on_fp8_to_inputT;
}
}
}
}
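Note: a minimal NumPy sketch of the flat-index math behind the two output-scale layouts the output_scale_transpose path switches between. The sizes and the assumption that g_scale_inner_stride equals ceil(rows / 128) are illustrative; only the index expressions mirror the kernel's write-back above.

# Sketch only: per-128-element output scales under the old and new layouts.
import numpy as np

rows, cols = 512, 384                      # illustrative sizes; the output itself is [cols, rows]
num_row_blocks = (rows + 127) // 128       # one scale per 128 elements of the inner dimension

scales_plain = np.arange(cols * num_row_blocks, dtype=np.float32)
scales_t = np.empty_like(scales_plain)

for y in range(cols):                      # plays the role of g_output_y_offset
    for x_block in range(num_row_blocks):  # plays the role of g_output_x_offset / 128
        v = scales_plain[y * num_row_blocks + x_block]   # old layout: [cols, num_row_blocks]
        scales_t[x_block * cols + y] = v                 # new layout: [num_row_blocks, cols]

# Reshaped, the two buffers are transposes of each other.
assert np.array_equal(scales_t.reshape(num_row_blocks, cols),
                      scales_plain.reshape(cols, num_row_blocks).T)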
@@ -179,8 +197,10 @@ std::vector<paddle::Tensor> fused_act_dequant_transpose_act_quant(
8; // pad up to a multiple of 8; since 128 is a multiple of 8, the scale shape is unaffected
}
out = paddle::empty({cols, rows}, paddle::DataType::FLOAT8_E4M3FN, X.place());
// scale = paddle::empty(
// {cols, (rows + 127) / 128}, paddle::DataType::FLOAT32, X.place());
scale = paddle::empty(
{cols, (rows + 127) / 128}, paddle::DataType::FLOAT32, X.place());
{(rows + 127) / 128, cols}, paddle::DataType::FLOAT32, X.place());

dispatch_fused_act_dequant_transpose_act_quant<phi::float8_e4m3fn>(
X,
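Note: with this change the returned scale tensor is allocated as {(rows + 127) / 128, cols} instead of {cols, (rows + 127) / 128}, so consumers of the old row-major layout need a transpose, which is exactly what the updated test does. A hedged Python sketch of the caller-side usage (shapes are illustrative; the call signatures mirror the test below, not a documented API):

# Sketch of caller-side usage under the new transposed-scale layout.
import paddle
import FusedQuantOps as FQO  # custom-op module used by the tests

x = paddle.randn([1024, 7168]).astype("bfloat16")
x_fp8, x_scale = FQO.fused_act_quant(
    x, transpose_output=False, padding_last_dim_to_8x=False, using_pow2_scaling=False)

# The op now takes/returns scales in the transposed layout, so a row-major
# scale from fused_act_quant is transposed before the call (as the test does).
out_fp8, out_scales = FQO.fused_act_dequant_transpose_act_quant(
    x_fp8, x_scale.T.contiguous(),
    padding_last_dim_to_8x=False, using_pow2_scaling=False)

# out_scales follows the new [(rows + 127) / 128, cols] allocation; transpose
# back for consumers that still expect the previous [cols, ceil(rows / 128)] layout.
legacy_scales = out_scales.T.contiguous()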
41 changes: 30 additions & 11 deletions tests/ops/test_fused_act_dequant_transpose_act_quant.py
@@ -3,6 +3,8 @@
import paddle.incubate.nn.functional as F
#import test_quant
import FusedQuantOps as FQO
from paddle.base import core
REP=1

'''
Swiglu Function:
@@ -51,8 +53,8 @@ def compare_tensors(a, b):
print("[最大相对差距]" f"位置: {max_rel_idx}")
print(f"a[{max_rel_idx}] = {a[max_rel_idx]:.6g}" + f"\t b[{max_rel_idx}] = {b[max_rel_idx]:.6g}" + f"\t 相对差值: {max_rel_val:.6g}\n")
print("周围元素比较-a:")
print(f"{a[max_rel_idx[0], (max_rel_idx[1] - 10):(max_rel_idx[1] + 10)]} ")
print("周围元素比较-b:")
print(f"{a[max_rel_idx[0], (max_rel_idx[1] - 10):(max_rel_idx[1] + 10)]} ")
print(f"{b[max_rel_idx[0], (max_rel_idx[1] - 10):(max_rel_idx[1] + 10)]} ")

# Return structured results
@@ -79,30 +81,47 @@ def printany(te):
print("-"*20)

def verify_swiglu_quant_result():
for width in [130, 4098, 7168]:
for height in [256, 1026, 4098]:
print("#"*60 + f" Testing width:{width}, height:{height} " + "#"*60)
for width in [7168]:
for height in [32768]:
#print("#"*60 + f" Testing width:{width}, height:{height} " + "#"*60)
x= paddle.clip(paddle.randn([height, width]).astype("bfloat16"), min=-50, max=50)
for padding in [False]:
pad_tag = "Padded: True" if padding is not None else "Padded: False"
print("-" * 20 + f"Testing with {pad_tag}" + "-" * 20)
x_fp8, scale = FQO.fused_act_quant(x, transpose_output=True, padding_last_dim_to_8x=padding, using_pow2_scaling=False)
fused_res, fused_scales = FQO.fused_act_dequant_transpose_act_quant(x_fp8,scale,padding_last_dim_to_8x=padding,using_pow2_scaling=False)
x_fp8, scale = FQO.fused_act_quant(x, transpose_output=False, padding_last_dim_to_8x=padding, using_pow2_scaling=False)
scale_t = scale.T.contiguous()
core.nvprof_start()
for i in range(REP):
if i > 1: core.nvprof_nvtx_push("fused_dtaq")
fused_res, fused_scales = FQO.fused_act_dequant_transpose_act_quant(x_fp8,scale_t,padding_last_dim_to_8x=padding,using_pow2_scaling=False)
if i > 1: core.nvprof_nvtx_pop()
np_results=[]
golden_res = x
np_results.append(golden_res.astype("float").numpy())
for i in range(REP):
if i > 1: core.nvprof_nvtx_push("original")
golden_res, golden_scales = FQO.fused_act_quant(FQO.fused_act_dequant(x_fp8,scale).T.contiguous(),transpose_output=False, padding_last_dim_to_8x=padding, using_pow2_scaling=False)
if i > 1: core.nvprof_nvtx_pop()

golden_res = dequantize_fp8_to_bf16(golden_res, golden_scales)
#golden_res = x.T.contiguous()
np_results.append(golden_res.astype("float32").numpy())
if padding:
dequanted_sliced_result = dequantize_fp8_to_bf16(fused_res, fused_scales)
np_results.append(dequanted_sliced_result[:, :height].numpy())
np_results.append(dequanted_sliced_result[:, :height].astype("float32").numpy())
else:
np_results.append(dequantize_fp8_to_bf16(fused_res, fused_scales).numpy())
np_results.append(dequantize_fp8_to_bf16(fused_res, fused_scales.T.contiguous().astype("float32")).numpy())
nan_cnt_golden, nan_cnt_fused= np.sum(np.isnan(np_results[0])), np.sum(np.isnan(np_results[1]))
print(f"Nan count of Golden result: {nan_cnt_golden}; Nan count of Fused result: {nan_cnt_fused}")
try:
np.testing.assert_allclose(np_results[0], np_results[1], rtol=0.01, atol=1) # truncation error is expected; atol=1, typically around 1e-6
np.testing.assert_allclose(np_results[0], np_results[1], rtol=0.01, atol=1e-2) # truncation error is expected; atol=1e-2, typically around 1e-6
print("+++++++ Passed ++++++++")
print(np_results[0])
print("-----------------------")
print(np_results[1])
except AssertionError as err:
print(err)
print(np_results[0])
print("-----------------------")
print(np_results[1])
compare_tensors(np_results[0], np_results[1])

def run():
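Note: the helper dequantize_fp8_to_bf16 used by this test is not part of the diff; a plausible reference implementation, assuming one float32 scale per 128 elements of the last dimension (which is how the shapes in the test line up), could look like this:

# Assumed reference only, not the project's actual helper.
import paddle

def dequantize_fp8_to_bf16_ref(x_fp8, scales):
    # x_fp8:  [n, m] float8_e4m3fn values
    # scales: [n, ceil(m / 128)] float32, one scale per 128-column block
    n, m = x_fp8.shape
    expanded = paddle.repeat_interleave(scales, 128, axis=1)[:, :m]
    return (x_fp8.astype("float32") * expanded).astype("bfloat16")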