
Some question about FP8 performance profile #39

@JimpleM

I used Nsight Compute to profile the example code (tma_gemm.py, listed below), but the result I measured is only about 695.29 GB/s.

  • TMA instruction count: 19.97K (same as the reference result)
  • L2 to shared memory: 2.70 TB/s (not the same; see the rough traffic estimate sketched after this list)
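
To sanity-check these two numbers, here is a rough back-of-envelope estimate of my own (not from the reference) for the problem size and config used below (M=128, N=4096, K=4096, 64x64x256 tiles). It counts the bytes each program requests through its TMA loads, which should show up as L2-to-shared-memory traffic; the unique data that has to come from DRAM is much smaller because the B tiles are reused across programs.

# traffic_estimate.py -- my own sketch, not part of the original report
M, N, K = 128, 4096, 4096
block_m, block_n, block_k = 64, 64, 256
elem_size = 1          # float8_e4m3fn is 1 byte per element
out_elem_size = 2      # C is stored as float16

num_programs = (M // block_m) * (N // block_n)   # 2 * 64 = 128 programs
k_iters = K // block_k                           # 16 K-iterations per program

# Bytes requested through TMA loads (L2 -> shared memory traffic).
a_tile_bytes = block_m * block_k * elem_size
b_tile_bytes = block_n * block_k * elem_size
l2_to_smem_bytes = num_programs * k_iters * (a_tile_bytes + b_tile_bytes)

# Unique data that must cross DRAM at least once (A + B + C).
dram_bytes = (M * K + K * N) * elem_size + M * N * out_elem_size

print(f"L2 -> SMEM traffic : {l2_to_smem_bytes / 1e6:.1f} MB")   # ~67.1 MB
print(f"unique DRAM bytes  : {dram_bytes / 1e6:.1f} MB")         # ~18.4 MB

If the 695.29 GB/s I measured and the 2.70 TB/s in the reference are both L2-to-shared-memory rates over this same traffic, the gap corresponds to roughly a 4x difference in kernel time, so the first thing I would compare is the kernel duration Nsight Compute reports for the two runs. I may be misreading which memory level each counter refers to, though.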

My environment is as follows:

pytorch-triton           3.2.0+git0d4682f0
torch                    2.7.0.dev20250116+cu126
torchaudio               2.6.0.dev20250116+cu126
torchvision              0.22.0.dev20250116+cu126

# tma_gemm.py
import triton
import triton.language as tl
import numpy as np
import torch

@triton.jit
def gemm_kernel_tma(a_desc_ptr, b_desc_ptr, c_desc_ptr,  #
                      prob_m, prob_n, prob_k, block_m: tl.constexpr, block_n: tl.constexpr, block_k: tl.constexpr):
    
    # Map the 1-D program id onto an (m, n) output tile.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(prob_m, block_m)
    num_pid_k = tl.cdiv(prob_k, block_k)
    pid_m = pid % num_pid_m
    pid_n = pid // num_pid_m
    offs_am = pid_m * block_m
    offs_bn = pid_n * block_n
    offs_k = 0

    accumulator = tl.zeros((block_m, block_n), dtype=tl.float32)
    for kk in range(0, num_pid_k):
        # Load one (block_m, block_k) tile of A and one (block_n, block_k) tile of B via TMA.
        a = tl._experimental_descriptor_load(a_desc_ptr, [offs_am, offs_k], [block_m, block_k], tl.float8e4nv)
        b = tl._experimental_descriptor_load(b_desc_ptr, [offs_bn, offs_k], [block_n, block_k], tl.float8e4nv)

        accumulator = tl.dot(a, b.T, acc=accumulator, out_dtype=tl.float32)
        offs_k += block_k

    accumulator = accumulator.to(tl.float16)
    tl._experimental_descriptor_store(c_desc_ptr, accumulator, [offs_am, offs_bn])


def matmul(a, b, config=None):

    m, _ = a.shape
    n, k = b.shape
    
    block_m = 64
    block_n = 64
    block_k = 128
    num_warps = 4
    num_stages = 4
    TMA_SIZE = 128

    if config:
        block_m = config["block_m"]
        block_n = config["block_n"]
        block_k = config["block_k"]
        num_warps = config["num_warps"]
        num_stages = config["num_stages"]
        TMA_SIZE = config["TMA_SIZE"]

    print(block_m, block_n, block_k, num_warps, num_stages, TMA_SIZE)

    # Host-side byte buffers that the driver fills with 2-D TMA descriptors.
    desc_a = np.empty(TMA_SIZE, dtype=np.int8)
    desc_b = np.empty(TMA_SIZE, dtype=np.int8)
    desc_c = np.empty(TMA_SIZE, dtype=np.int8)

    c = torch.empty((m, n), dtype=torch.float16, device='cuda')
    triton.runtime.driver.active.utils.fill_2d_tma_descriptor(a.data_ptr(), m, k, block_m, block_k, a.element_size(), desc_a)
    triton.runtime.driver.active.utils.fill_2d_tma_descriptor(b.data_ptr(), n, k, block_n, block_k, b.element_size(), desc_b)
    triton.runtime.driver.active.utils.fill_2d_tma_descriptor(c.data_ptr(), m, n, block_m, block_n, c.element_size(), desc_c)
    desc_a = torch.tensor(desc_a, device='cuda')
    desc_b = torch.tensor(desc_b, device='cuda')
    desc_c = torch.tensor(desc_c, device='cuda')

    total_blocks_m = triton.cdiv(m, block_m)
    total_blocks_n = triton.cdiv(n, block_n)
    
    grid = (total_blocks_m * total_blocks_n, 1, 1)
    k = gemm_kernel_tma[grid](
        desc_a, desc_b, desc_c,
        m, n, k,
        block_m,
        block_n,
        block_k,
        num_warps=num_warps,
        num_stages=num_stages,
    )

    # with open('tma_fp8.ttgir', 'w') as f:
    #      print(k.asm['ttgir'], file=f)

    # with open('tma_fp8.ptx', 'w') as f:
    #      print(k.asm['ptx'], file=f)

    return c


if __name__ == '__main__':

    M = 128
    N = 4096
    K = 4096

    torch.manual_seed(0)
    a = torch.randn((M, K), device="cuda", dtype=torch.float16).to(torch.float8_e4m3fn)
    b = torch.randn((K, N), device="cuda", dtype=torch.float16).to(torch.float8_e4m3fn)
    b = b.T.contiguous()

    config = {
        "block_m":64,
        "block_n":64,
        "block_k":256,
        "num_warps":4,
        "num_stages":4,
        "TMA_SIZE":512
    }

    c = matmul(a, b, config=config)
    print(c)
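
For completeness, this is how I time the kernel and turn the runtime into bandwidth/compute figures. It is only a sketch of my own: it assumes the script above is saved as tma_gemm.py and that triton.testing.do_bench is available in this Triton build.

# bench_tma_gemm.py -- my own sketch, not part of the original script
import torch
import triton

from tma_gemm import matmul   # assumes the script above is saved as tma_gemm.py

M, N, K = 128, 4096, 4096
torch.manual_seed(0)
a = torch.randn((M, K), device="cuda", dtype=torch.float16).to(torch.float8_e4m3fn)
b = torch.randn((K, N), device="cuda", dtype=torch.float16).to(torch.float8_e4m3fn)
b = b.T.contiguous()

config = {"block_m": 64, "block_n": 64, "block_k": 256,
          "num_warps": 4, "num_stages": 4, "TMA_SIZE": 512}

# do_bench runs the callable many times and returns the runtime in ms.
# (matmul() prints its config on every call, so expect some log spam.)
ms = triton.testing.do_bench(lambda: matmul(a, b, config=config))

# Lower bound on DRAM traffic (unique A + B + fp16 C); Nsight Compute
# counters may report something different.
unique_bytes = (M * K + K * N) * a.element_size() + M * N * 2
tflops = 2 * M * N * K / (ms * 1e-3) / 1e12

print(f"runtime      : {ms * 1e3:.1f} us")
print(f"DRAM-level BW: {unique_bytes / (ms * 1e-3) / 1e9:.1f} GB/s")
print(f"compute      : {tflops:.1f} TFLOP/s")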
