Skip to content

Build mxfp4 kernel for sm120a #2285

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft

Conversation

gau-nernst
Copy link
Collaborator

Just making some quick changes here to see if I can build mxfp4 kernel on 5090 (sm120). Eventually this will be put under torchao._C_cutlass_120a?

Setting -DCUTLASS_DEBUG_TRACE_LEVEL=1 so I can see debug trace.

To build (using torch==2.8.0.dev20250530+cu128)

TORCH_CUDA_ARCH_LIST=12.0a uv pip install -e . -v --no-build-isolation

Running pytest test/prototype/mx_formats/test_mx_mm.py -v

/home/thien/code/ao/third_party/cutlass/include/cutlass/gemm/device/gemm_universal_adapter.h:244    workspace_bytes: 0
/home/thien/code/ao/third_party/cutlass/include/cutlass/gemm/device/gemm_universal_adapter.h:312  GemmUniversal::initialize() - workspace 0, stream: null
/home/thien/code/ao/third_party/cutlass/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp:201  to_underlying_arguments():
/home/thien/code/ao/third_party/cutlass/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp:214    WARNING: Arguments do not include a valid SM count.
  For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.
/home/thien/code/ao/third_party/cutlass/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp:218  to_underlying_arguments(): Setting persistent grid SM count to 170
/home/thien/code/ao/third_party/cutlass/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp:224    WARNING: Arguments do not include a valid max cluster count.
  For optimal performance, populate the arguments KernelHardwareInfo struct with the max_active_clusters.
/home/thien/code/ao/third_party/cutlass/include/cutlass/gemm/device/gemm_universal_adapter.h:336    Setting smem size to 101376
/home/thien/code/ao/third_party/cutlass/include/cutlass/gemm/device/gemm_universal_adapter.h:343    cudaFuncSetAttribute() returned error: invalid resource handle

cudaFuncSetAttribute() returned error: invalid resource handle means that the function is invalid? https://github.com/NVIDIA/cutlass/blob/ad7b2f5e84fcfa124cb02b91d5bd26d238c0459e/include/cutlass/gemm/device/gemm_universal_adapter.h#L338, which is quite strange...

For reference, I can build and run the example from Cutlass here https://github.com/NVIDIA/cutlass/blob/v3.9.2/examples/79_blackwell_geforce_gemm/79a_blackwell_geforce_nvfp4_bf16_gemm.cu. The changes in this PR has been taken from this example. When building with CUTLASS_DEBUG_TRACE_LEVEL=1, there are also warnings in sm90_gemm_tma_warpspecialized_cooperative.hpp, so that is probably not the issue.

@drisspg

cc @alexsamardzic in case you faced this error with Cutlass before

Copy link

pytorch-bot bot commented May 31, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2285

Note: Links to docs will display an error until the docs builds have been completed.

❌ 10 New Failures, 1 Unrelated Failure

As of commit 5abfe97 with merge base e51ffd9 (image):

NEW FAILURES - The following jobs have failed:

BROKEN TRUNK - The following job failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label May 31, 2025
@drisspg
Copy link
Contributor

drisspg commented May 31, 2025

The first thing that comes to mind is that example is doing NVfp4 where all our recipes are doing MXfp4, e.g. https://github.com/pytorch/ao/pull/2285/files#diff-e155558499c3b1fbab1b5d3b60f032bf1e636908a8ef50a1de33bff518107019R240-R241 needs to change as well. For inference we have MXFP8 and MXFP4 support I am planning to add an NVFP4 scaling recipe next, that being said I would imagine that MXFP4 is supported on 5090..

cc @syed-ahmed

@gau-nernst
Copy link
Collaborator Author

I noticed that as well

  • Changing the torchao kernel to nvfp4 results in the same error
  • Changing the cutlass example to mxfp4 still works

😭

@syed-ahmed
Copy link
Contributor

Per cutlass docs, I believe MXFP4 is supported in 5090: https://github.com/NVIDIA/cutlass/blob/9d165a3b8ef446a7ff3db198413f82bcb83f46fe/media/docs/cpp/blackwell_functionality.md#blackwell-sm120-gemms

However note the section that talks about the differences with sm100. So it's possible we need more changes to the kernel in torch ao. Also what CUDA version are you using? I'd assume you'd need a fairly recent CUDA version. I'll try to guide more next week.

@gau-nernst
Copy link
Collaborator Author

@syed-ahmed I'm using CUDA 12.9

The strange thing is that the cutlass example works, but the one in torchao doesn't. I carefully compared the two, and I don't spot any difference in the template arguments.

@syed-ahmed
Copy link
Contributor

How about the test? Are the inputs similar to the cutlass example?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants