
Add fused_transpose_quant op #10601


Open · wants to merge 1 commit into base: dsv3_dev
Conversation


@lshpku lshpku commented May 16, 2025

PR types

Performance optimization

PR changes

APIs

Description

Adds a fused_transpose_quant operator, equivalent to the following code:

def fused_transpose_quant(x):
    # x.shape = [N, L, C]; C must be a multiple of 128
    N, L, C = x.shape
    x = x.reshape([N, L, C // 128, 128]).astype('float32')
    # ComputeScale stands in for the per-group scale computation (one scale per 128 elements)
    scale = ComputeScale(x.abs().max(axis=-1))       # scale.shape = [N, L, C // 128]
    x = (x / scale.unsqueeze(-1)).astype('float8_e4m3fn')
    out = x.reshape([N, L, C]).transpose([0, 2, 1])  # out.shape = [N, C, L]
    return out, scale
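As a hedged, self-contained reference for the semantics above, here is a NumPy sketch. Since the PR does not show `ComputeScale`, this sketch assumes the common amax / 448 scaling (448 being the largest finite `float8_e4m3fn` value), and simulates the fp8 cast by clipping to the E4M3 range; both are illustrative assumptions, not the PR's implementation.

```python
import numpy as np

E4M3_MAX = 448.0  # largest finite float8_e4m3fn value (assumed scale target)

def fused_transpose_quant_ref(x):
    """NumPy reference; fp8 is simulated by clipping to the E4M3 range."""
    N, L, C = x.shape
    assert C % 128 == 0
    g = x.reshape(N, L, C // 128, 128).astype(np.float32)
    amax = np.abs(g).max(axis=-1)                  # [N, L, C // 128]
    scale = np.maximum(amax, 1e-12) / E4M3_MAX     # avoid division by zero
    q = np.clip(g / scale[..., None], -E4M3_MAX, E4M3_MAX)
    out = q.reshape(N, L, C).transpose(0, 2, 1)    # [N, C, L]
    return out, scale

x = np.random.randn(2, 3, 256).astype(np.float32)
out, scale = fused_transpose_quant_ref(x)
print(out.shape, scale.shape)  # (2, 256, 3) (2, 3, 2)
```

Multiplying the (un-transposed) quantized values back by their group scales recovers the input up to float32 rounding, which is a quick way to check any scale choice.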

Implementation highlights

  1. The transpose is tiled into 128x128 blocks, giving high parallelism
  2. Since the C dimension is guaranteed to be 128-aligned, reads always use 4x vectorization
  3. Since the row (L) dimension has no alignment guarantee, write-back dispatches to separate 1x/2x/4x vectorized kernel instances
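The dispatch in point 3 can be sketched as picking the widest vector width that divides the unaligned dimension. `pick_vector_width` is a hypothetical helper for illustration, not a function from this PR:

```python
def pick_vector_width(L):
    """Pick the widest write-back vectorization (in elements) that divides L.

    Hypothetical helper mirroring the 1x/2x/4x kernel-instance dispatch:
    a fully divisible L gets 4-wide stores, otherwise fall back to 2x or 1x.
    """
    for w in (4, 2, 1):
        if L % w == 0:
            return w

print(pick_vector_width(4096))      # 4
print(pick_vector_width(4096 + 2))  # 2
print(pick_vector_width(4096 + 1))  # 1
```

These three cases correspond to the three rows of the benchmark table below.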

Performance testing

Preliminary tests were run on an A100-40G. Since the A100 does not support fp8, the fp32-to-fp8 cast on write-back was simulated with fp32-to-int8; full testing will follow once an H-series GPU environment is set up.

| Input x.shape | Time (ns) | Bandwidth (GB/s) | Bandwidth utilization | Notes |
|---|---|---|---|---|
| [4, 4096, 7168] | 338,641 | 1051 | 67.6% | 4x vectorized |
| [4, 4096 + 2, 7168] | 383,741 | 928 | 59.7% | 2x vectorized |
| [4, 4096 + 1, 7168] | 411,918 | 864 | 55.6% | no vectorization |
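The first row can be sanity-checked arithmetically. Assuming the kernel reads a 2-byte (bf16) input and writes a 1-byte (int8-simulated fp8) output plus fp32 scales, and taking 1555 GB/s as the A100-40G peak DRAM bandwidth (these dtype and peak-bandwidth figures are assumptions, not stated in the PR), the table's numbers come out as:

```python
# Bytes moved for x.shape = [4, 4096, 7168]:
#   read 2 B/elem (bf16) + write 1 B/elem (fp8 simulated as int8),
#   plus one fp32 scale per 128 input elements.
N, L, C = 4, 4096, 7168
elems = N * L * C
bytes_moved = elems * 2 + elems * 1 + (elems // 128) * 4
time_s = 338_641e-9                  # table row 1
gbps = bytes_moved / time_s / 1e9
util = gbps / 1555                   # assumed A100-40G peak DRAM bandwidth
print(f"{gbps:.0f} GB/s, {util:.1%}")  # 1051 GB/s, 67.6%
```

Under these assumptions the computed bandwidth and utilization match the table's first row, which suggests the reported figures count input reads, output writes, and scale writes.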

Pcard-85711


paddle-bot bot commented May 16, 2025

Thanks for your contribution!

@lshpku force-pushed the fused-transpose-quant branch 4 times, most recently from 1eee742 to 5cde9d5 on May 16, 2025 08:41
@lshpku force-pushed the fused-transpose-quant branch from 5cde9d5 to a9fbd10 on May 16, 2025 09:15