
update default sm num #10586

Merged · 1 commit · May 13, 2025

Commit 78b950883c259c99d419357ac2bac9a07e985040
update default sm num
phlrain committed May 13, 2025

ops/csrc/fp8/deep_gemm/jit_kernels/gemm.py (2 changes: 1 addition & 1 deletion)

@@ -171,7 +171,7 @@ def auto_tuning_with_compilation(m, n, k, num_sms):
     return runtime, num_sms, smem_size


-def gemm_fp8_fp8_bf16_nt(lhs: Tuple[Tensor, Tensor], rhs: Tuple[Tensor, Tensor], out: Tensor, num_sms=112) -> None:
+def gemm_fp8_fp8_bf16_nt(lhs: Tuple[Tensor, Tensor], rhs: Tuple[Tensor, Tensor], out: Tensor, num_sms=132) -> None:
     """
     Do a normal GEMM with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling.
     LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format.
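
The new default of 132 matches the SM count of Hopper SXM parts (H100/H800 SXM report 132 SMs). As a minimal sketch, and assuming the Paddle CUDA runtime this repo appears to target, a caller can also query the device instead of relying on any hard-coded default; the helper name below is illustrative, not part of the PR.

# Sketch only (not part of this PR): read the SM count from the device rather
# than relying on the hard-coded default of 132.
import paddle


def query_num_sms(device_id: int = 0) -> int:
    # multi_processor_count is the number of streaming multiprocessors (SMs)
    # reported by the driver for the given GPU.
    props = paddle.device.cuda.get_device_properties(device_id)
    return props.multi_processor_count


# Illustrative call: override the default explicitly.
# gemm_fp8_fp8_bf16_nt((lhs, lhs_scales), (rhs, rhs_scales), out, num_sms=query_num_sms())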

ops/csrc/fp8/deep_gemm/jit_kernels/m_grouped_gemm.py (4 changes: 2 additions & 2 deletions)

@@ -98,7 +98,7 @@ def auto_tuning_with_compilation_grouped_gemm_contiguous(m, n, k, num_groups, nu


 def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
-    lhs: Tuple[Tensor, Tensor], rhs: Tuple[Tensor, Tensor], out: Tensor, m_indices: Tensor, num_sms=112
+    lhs: Tuple[Tensor, Tensor], rhs: Tuple[Tensor, Tensor], out: Tensor, m_indices: Tensor, num_sms=132
 ) -> None:
     """
     Do a grouped GEMM (contiguous format) with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling.

@@ -215,7 +215,7 @@ def auto_tuning_with_compilation_grouped_gemm_masked(m, expected_m, n, k, num_gr


 def m_grouped_gemm_fp8_fp8_bf16_nt_masked(
-    lhs: Tuple[Tensor, Tensor], rhs: Tuple[Tensor, Tensor], out: Tensor, masked_m: Tensor, expected_m: int, num_sms=112
+    lhs: Tuple[Tensor, Tensor], rhs: Tuple[Tensor, Tensor], out: Tensor, masked_m: Tensor, expected_m: int, num_sms=132
 ) -> None:
     """
     Do a grouped GEMM (masked format) with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling.
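
Both grouped variants pick up the same 132 default. On GPUs with fewer SMs (an A100 reports 108, for example), a conservative caller may want to clamp the requested value to what the hardware actually provides. The helper below is a hypothetical sketch in the same spirit as the one above, not part of this PR, and again assumes a Paddle CUDA runtime.

# Hypothetical sketch (not part of this PR): cap num_sms at the device's real
# SM count so the grouped kernels are never configured for more SMs than exist.
import paddle


def clamp_num_sms(requested: int = 132, device_id: int = 0) -> int:
    available = paddle.device.cuda.get_device_properties(device_id).multi_processor_count
    return min(requested, available)


# Illustrative call: pass the clamped value explicitly.
# m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs, rhs, out, m_indices, num_sms=clamp_num_sms())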