Skip to content

[feat] support fa3 backend for pd disaggregated #2695

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jul 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 0 additions & 15 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,6 @@ default_stages:
- pre-commit # Run locally
# - manual # Run in CI
repos:
# 格式化
- repo: https://github.com/google/yapf
rev: v0.43.0
hooks:
- id: yapf
args: [--in-place, --verbose]
# 代码检查
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.11.7
Expand All @@ -29,15 +23,6 @@ repos:
rev: 6.0.1
hooks:
- id: isort
# # 格式化
# - repo: https://github.com/pre-commit/mirrors-clang-format
# rev: v20.1.3
# hooks:
# - id: clang-format
# # exclude: '.*'
# types_or: [c++, cuda]
# args: [--style=file, --verbose]

# markdown
- repo: https://github.com/jackdewinter/pymarkdown
rev: v0.9.29
Expand Down
117 changes: 58 additions & 59 deletions custom_ops/0001-DeepGEMM-95e81b3.patch
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ index 15b22ca..63e7fb7 100644
@@ -1,4 +1,4 @@
-import torch
+import paddle

from . import jit
from .jit_kernels import (
diff --git a/deep_gemm/include/deep_gemm/scheduler.cuh b/deep_gemm/include/deep_gemm/scheduler.cuh
Expand All @@ -53,7 +53,7 @@ index c17d466..6fdc52f 100644
-from torch.utils.cpp_extension import CUDA_HOME
+from ..paddle_utils import CUDA_HOME
from typing import Tuple

from . import interleave_ffma
diff --git a/deep_gemm/jit/interleave_ffma.py b/deep_gemm/jit/interleave_ffma.py
index fcb377e..db9d6f3 100644
Expand All @@ -65,8 +65,8 @@ index fcb377e..db9d6f3 100644
import subprocess
-from torch.utils.cpp_extension import CUDA_HOME
+from ..paddle_utils import CUDA_HOME


def run_cuobjdump(file_path):
diff --git a/deep_gemm/jit/runtime.py b/deep_gemm/jit/runtime.py
index 66c370a..4761426 100644
Expand All @@ -78,7 +78,7 @@ index 66c370a..4761426 100644
-import torch
+import paddle
from typing import Optional

from .template import map_ctype
@@ -35,7 +35,7 @@ class Runtime:
assert len(args) == len(self.args), f'Expected {len(self.args)} arguments, got {len(args)}'
Expand All @@ -100,8 +100,8 @@ index ead37f5..51b02c1 100644
-import torch
+import paddle
from typing import Any, Dict, Iterable, Tuple


# Name map for Python `eval`
typename_map: Dict[Any, str] = {
**{t: t.__name__ for t in (bool, int, float)},
Expand All @@ -116,15 +116,15 @@ index ead37f5..51b02c1 100644
+ paddle.float8_e4m3fn: 'paddle.float8_e4m3fn',
+ paddle.device.cuda.Stream: "paddle.device.cuda.Stream",
}

# `ctype` map for Python casting
ctype_map: Dict[Any, Any] = {
**{t: getattr(ctypes, f'c_{t.__name__}') for t in (bool, int, float)},
- **{t: ctypes.c_void_p for t in (torch.int, torch.float, torch.bfloat16, torch.float8_e4m3fn, torch.cuda.Stream)},
+ **{t: ctypes.c_void_p for t in (paddle.int32, paddle.float32, paddle.bfloat16, paddle.float8_e4m3fn, paddle.device.cuda.Stream)},
}


@@ -27,25 +27,25 @@ genc_map = {
bool: ('bool', 'bool'),
int: ('int', 'int'),
Expand All @@ -140,8 +140,8 @@ index ead37f5..51b02c1 100644
+ paddle.float8_e4m3fn: ('void*', '__nv_fp8_e4m3*'),
+ paddle.device.cuda.Stream: ('void*', 'cudaStream_t'),
}


def map_ctype(value: Any) -> Any:
if hasattr(value, 'data_ptr'):
- if value.dtype == torch.int:
Expand Down Expand Up @@ -171,11 +171,11 @@ index cb438b7..44aa0ed 100644
+import paddle
from functools import lru_cache
from typing import Tuple

@@ -166,20 +166,20 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
return num_min_sms, best_block_m, best_block_n, best_num_stages, best_tma_multicast_config, best_smem_config


-def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
- rhs: Tuple[torch.Tensor, torch.Tensor],
- out: torch.Tensor) -> None:
Expand All @@ -189,7 +189,7 @@ index cb438b7..44aa0ed 100644
The LHS scaling tensor requires TMA-aligned transposed format, if your input does not match the requirement,
- this function will do a transposing with a set of slow PyTorch operations.
+ this function will do a transposing with a set of slow paddle operations.

Arguments:
- lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m, k]`,
+ lhs: the first element is an FP8 tensor (typed `paddle.float8_e4m3fn`) of shape `[m, k]`,
Expand All @@ -202,10 +202,10 @@ index cb438b7..44aa0ed 100644
@@ -189,22 +189,22 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
n, k_ = rhs.shape
m_, n_ = out.shape

- assert n % 64 == 0 and k % 128 == 0
+ # assert n % 64 == 0 and k % 128 == 0

# Type and shape checks
- assert m == m_ and n == n_ and k == k_
- assert n > 0 and k > 0
Expand All @@ -223,13 +223,13 @@ index cb438b7..44aa0ed 100644
+ # assert rhs.dtype == paddle.float8_e4m3fn and rhs_scales.dtype == paddle.float32
+ # assert out.dtype == paddle.bfloat16
+ # assert lhs.is_contiguous() and rhs.is_contiguous() and out.is_contiguous()

# LHS scales must be transposed for TMA load, but not for RHS scales
# NOTES: `get_tma_aligned_lhs_scales` may launch a kernel if not processed by previous kernels
lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
- assert rhs_scales.is_contiguous()
+ # assert rhs_scales.is_contiguous()

# Do nothing if `m` is zero
if m == 0:
@@ -214,7 +214,7 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
Expand Down Expand Up @@ -264,12 +264,12 @@ index 3b518c9..ba776bd 100644
-import torch
+import paddle
from typing import Tuple

from .gemm import get_best_configs, get_block_n_padding_for_smem_d
@@ -37,25 +37,25 @@ gemm_t::run(out, rhs_scales, grouped_layout,
"""


-def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Tensor],
- rhs: Tuple[torch.Tensor, torch.Tensor],
- out: torch.Tensor, m_indices: torch.Tensor) -> None:
Expand All @@ -285,7 +285,7 @@ index 3b518c9..ba776bd 100644
+ this function will do a transposing with a set of slow paddle operations.
On the M axis, inputs are grouped into several batches, of which batch sizes aligned to
`get_m_alignment_for_contiguous_layout()` (128).

Arguments:
- lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m_sum, k]`,
+ lhs: the first element is an FP8 tensor (typed `paddle.float8_e4m3fn`) of shape `[m_sum, k]`,
Expand All @@ -301,7 +301,7 @@ index 3b518c9..ba776bd 100644
Values of `m_indices` in every-m-alignment-block must also be the same.
@@ -68,19 +68,19 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
m__ = m_indices.numel()

# Type and shape checks
- assert m == m_ == m__ and k == k_ and n == n_
- assert lhs_scales.shape == (m, (k + 127) // 128)
Expand All @@ -321,12 +321,12 @@ index 3b518c9..ba776bd 100644
+ # assert m_indices.dtype == paddle.int32
+ # assert lhs.is_contiguous() and rhs.is_contiguous()
+ # assert out.is_contiguous() and m_indices.is_contiguous()

# LHS scales must be transposed for TMA load, but not for RHS scales
lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
- assert rhs_scales.is_contiguous()
+ # assert rhs_scales.is_contiguous()

# Do nothing if `m` is zero
if m == 0:
@@ -92,7 +92,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
Expand Down Expand Up @@ -357,8 +357,8 @@ index 3b518c9..ba776bd 100644
)
@@ -118,22 +118,22 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
runtime(*args)


-def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor],
- rhs: Tuple[torch.Tensor, torch.Tensor],
- out: torch.Tensor, masked_m: torch.Tensor, expected_m: int) -> None:
Expand All @@ -374,7 +374,7 @@ index 3b518c9..ba776bd 100644
+ this function will do a transposing with a set of slow paddle operations.
Moreover, this alignment requirement is different with the contiguous-format kernel, as we require that each batch
should be separately transposed.

Arguments:
- lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[num_groups, m_max, k]`,
+ lhs: the first element is an FP8 tensor (typed `paddle.float8_e4m3fn`) of shape `[num_groups, m_max, k]`,
Expand All @@ -386,7 +386,7 @@ index 3b518c9..ba776bd 100644
masked_m: a tensor of shape `[num_groups]`, `masked_m[i]` records actual rows of the `lhs[i]` matrix to compute
@@ -149,21 +149,21 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
num_groups___ = masked_m.numel()

# Type and shape checks
- assert num_groups == num_groups_ == num_groups__ == num_groups___
- assert m == m_ and n == n_ and k == k_
Expand All @@ -410,16 +410,16 @@ index 3b518c9..ba776bd 100644
+ # assert masked_m.dtype == paddle.int32
+ # assert lhs.is_contiguous() and rhs.is_contiguous()
+ # assert out.is_contiguous() and masked_m.is_contiguous()

# LHS scales must be transposed for TMA load, but not for RHS scales
lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
- assert rhs_scales.is_contiguous()
+ # assert rhs_scales.is_contiguous()

# Auto-tuning with compilation
global includes, template
@@ -176,7 +176,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]

args = (lhs, lhs_scales, rhs, rhs_scales, out,
masked_m, m,
- torch.cuda.current_stream(), num_sms, smem_config[0])
Expand Down Expand Up @@ -454,11 +454,11 @@ index 6ed6749..9e1d70f 100644
-import torch
+import paddle
from typing import Any, Dict

from ..jit import build, cpp_format, generate, Runtime
@@ -51,10 +51,10 @@ class JITTuner:
continue

# Measure performance with L2 flush and a large GEMM kernel before to reduce overhead between kernels
- start_event = torch.cuda.Event(enable_timing=True)
- end_event = torch.cuda.Event(enable_timing=True)
Expand All @@ -478,39 +478,39 @@ index c6da56b..a17b1b1 100644
@@ -1,4 +1,4 @@
-import torch
+import paddle

_num_sms = None

@@ -11,7 +11,7 @@ def set_num_sms(num_sms: int) -> None:
num_sms: the desired maximum SM count for all GEMM kernels to use.
"""
global _num_sms
- assert 0 < num_sms <= torch.cuda.get_device_properties(device='cuda').multi_processor_count
+ assert 0 < num_sms <= paddle.device.cuda.get_device_properties().multi_processor_count
_num_sms = num_sms


@@ -25,7 +25,7 @@ def get_num_sms() -> int:
"""
global _num_sms
if _num_sms is None:
- _num_sms = torch.cuda.get_device_properties(device='cuda').multi_processor_count
+ _num_sms = paddle.device.cuda.get_device_properties().multi_processor_count
return _num_sms


@@ -74,9 +74,9 @@ def get_tma_aligned_size(x: int, element_size: int) -> int:
return ceil_div(x, alignment) * alignment


-def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
+def get_col_major_tma_aligned_tensor(x: paddle.Tensor) -> paddle.Tensor:
"""
- Returns TMA-aligned transposed format of the input tensor. `torch.transpose` will be called if necessary.
+ Returns TMA-aligned transposed format of the input tensor. `paddle.transpose` will be called if necessary.
If the input tensor is already column-major layout and 16-byte aligned along the M axis
(thus meets the requirement of LHS scaling tensor in DeepGEMM), this function will do nothing.

@@ -92,18 +92,20 @@ def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
m, n = x.shape[-2], x.shape[-1]
aligned_m = get_tma_aligned_size(m, x.element_size())
Expand All @@ -519,14 +519,14 @@ index c6da56b..a17b1b1 100644
+ if x.strides[0] == 1 and x.strides[1] == aligned_m:
return x
x, remove_dim = x.unsqueeze(0), True

b = x.shape[0]

# The last kernel gives a column-major TMA aligned layout
- if x.stride(0) == aligned_m * n and x.stride(1) == 1 and x.stride(2) == aligned_m:
+ if x.strides[0] == aligned_m * n and x.strides[1] == 1 and x.strides[2] == aligned_m:
return x.squeeze(0) if remove_dim else x

# Normal layout requires transposing
- aligned_x = torch.transpose(torch.empty((b, n, aligned_m), device=x.device, dtype=x.dtype), 1, 2)
+ aligned_x = paddle.transpose(
Expand Down Expand Up @@ -574,28 +574,28 @@ index d5cdd01..5237f09 100644
-import torch.distributed as dist
+import paddle
+import paddle.distributed as dist


def bench(fn, num_warmups: int = 5, num_tests: int = 10,
high_precision: bool = False):
# Flush L2 cache with 256 MB data
- torch.cuda.synchronize()
- cache = torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda')
+ paddle.device.cuda.synchronize()
+ paddle.device.synchronize()
+ cache = paddle.empty((int(256e6 // 4)), dtype=paddle.int32)
cache.zero_()

# Warmup
@@ -18,18 +18,18 @@ def bench(fn, num_warmups: int = 5, num_tests: int = 10,

# Add a large kernel to eliminate the CPU launch overhead
if high_precision:
- x = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
- y = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
+ x = paddle.randn((8192, 8192), dtype=paddle.float32)
+ y = paddle.randn((8192, 8192), dtype=paddle.float32)
x @ y

# Testing
- start_event = torch.cuda.Event(enable_timing=True)
- end_event = torch.cuda.Event(enable_timing=True)
Expand All @@ -607,9 +607,9 @@ index d5cdd01..5237f09 100644
end_event.record()
- torch.cuda.synchronize()
+ paddle.device.synchronize()

return start_event.elapsed_time(end_event) / num_tests

@@ -106,21 +106,21 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output:
# Profile
suppress = suppress_stdout_stderr if suppress_kineto_output and not using_nsys else empty_suppress
Expand All @@ -636,8 +636,7 @@ index d5cdd01..5237f09 100644
- torch.empty(flush_l2_size, dtype=torch.int, device='cuda').zero_()
+ paddle.empty(flush_l2_size, dtype=paddle.int32).zero_()
fn()

if not using_nsys:
--
--
2.43.0

Loading