Conversation

Contributor

@ilmarkov ilmarkov commented Jun 12, 2025

Additional performance optimizations after #18778

Tune CUTLASS configs for M >= 128 (a structural sketch of the resulting M-range dispatch follows the list below).
For Llama 8B on B200, these tunings deliver GEMM speedups of:

  • 1.13 to 1.32x for M=128
  • up to 1.4x for M=256
  • up to 1.2x for M>=512
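
For context, here is a minimal structural sketch of the M-range dispatch these tunings plug into. The config names follow the kernel snippets quoted in the review below; the function name, the `Cutlass3xGemm` member, `cutlass_gemm_caller`, and `next_pow_2` semantics are assumptions about the surrounding vLLM code, and the branches for smaller M are omitted. This is a sketch, not the exact code changed in this PR.

```cpp
// Structural sketch only, not the code changed in this PR.
template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue,
          typename... EpilogueArgs>
void cutlass_gemm_sm100_fp8_dispatch(torch::Tensor& out, torch::Tensor const& a,
                                     torch::Tensor const& b,
                                     EpilogueArgs&&... args) {
  using Cutlass3xGemmDefault =
      typename sm100_fp8_config_default<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM256 =
      typename sm100_fp8_config_M256<InType, OutType, Epilogue>::Cutlass3xGemm;

  uint32_t const m = a.size(0);        // M dimension == number of tokens
  uint32_t const mp2 = next_pow_2(m);  // assumed: round M up to the next power of two

  if (mp2 <= 256) {
    // m in (128, 256]: config range introduced by this PR
    // (smaller-M buckets, handled before this point, are omitted here)
    cutlass_gemm_caller<Cutlass3xGemmM256>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else {
    // m in (256, inf): retuned default config
    cutlass_gemm_caller<Cutlass3xGemmDefault>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  }
}
```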

Kernel benchmarks using #17126 on B200.
python benchmarks/kernels/bench_fp8_gemm.py --model meta-llama/Llama-3.1-8B-Instruct --tp-sizes 1

meta-llama/Llama-3.1-8B-Instruct, N=6144 K=4096, BF16 vs FP8 GEMMs TFLOP/s

| batch_size | fp8-channel-w-token-a-noquant (before) | fp8-channel-w-token-a-noquant (after) | speedup (channel) | fp8-tensor-w-tensor-a-noquant (before) | fp8-tensor-w-tensor-a-noquant (after) | speedup (tensor) |
|---|---|---|---|---|---|---|
| 128 | 915.05 | 1045.85 | 1.14 | 905.81 | 1020.35 | 1.13 |
| 256 | 1478.03 | 1386.12 | 0.94 | 1464.20 | 1362.24 | 0.93 |
| 512 | 1795.06 | 1891.02 | 1.05 | 1752.24 | 1838.72 | 1.05 |
| 1024 | 2496.48 | 2805.51 | 1.12 | 2255.31 | 2341.43 | 1.04 |
| 2048 | 2780.10 | 2974.92 | 1.07 | 2303.57 | 2623.87 | 1.14 |
| 4096 | 2131.02 | 3096.40 | 1.45 | 2329.41 | 2464.44 | 1.06 |
| 8192 | 3024.36 | 3141.42 | 1.04 | 2297.25 | 2747.69 | 1.20 |
| 16384 | 2447.03 | 2526.07 | 1.03 | 2544.68 | 2547.96 | 1.00 |

meta-llama/Llama-3.1-8B-Instruct, N=4096 K=4096, BF16 vs FP8 GEMMs TFLOP/s

| batch_size | fp8-channel-w-token-a-noquant (before) | fp8-channel-w-token-a-noquant (after) | speedup (channel) | fp8-tensor-w-tensor-a-noquant (before) | fp8-tensor-w-tensor-a-noquant (after) | speedup (tensor) |
|---|---|---|---|---|---|---|
| 128 | 612.35 | 707.66 | 1.16 | 610.76 | 703.11 | 1.15 |
| 256 | 993.06 | 1256.10 | 1.27 | 992.84 | 1240.99 | 1.25 |
| 512 | 1906.16 | 1911.99 | 1.00 | 1832.05 | 1887.05 | 1.03 |
| 1024 | 2354.26 | 2366.93 | 1.01 | 2076.66 | 2373.50 | 1.14 |
| 2048 | 2353.71 | 2895.17 | 1.23 | 2459.18 | 2358.63 | 0.96 |
| 4096 | 2767.91 | 3074.91 | 1.11 | 2284.51 | 2684.30 | 1.17 |
| 8192 | 2785.40 | 2690.46 | 0.97 | 2452.45 | 2534.51 | 1.03 |
| 16384 | 2876.28 | 3106.67 | 1.08 | 2483.94 | 2746.92 | 1.11 |

meta-llama/Llama-3.1-8B-Instruct, N=28672 K=4096, BF16 vs FP8 GEMMs TFLOP/s

| batch_size | fp8-channel-w-token-a-noquant (before) | fp8-channel-w-token-a-noquant (after) | speedup (channel) | fp8-tensor-w-tensor-a-noquant (before) | fp8-tensor-w-tensor-a-noquant (after) | speedup (tensor) |
|---|---|---|---|---|---|---|
| 128 | 1041.23 | 1193.94 | 1.15 | 1025.66 | 1195.29 | 1.17 |
| 256 | 1803.80 | 1917.11 | 1.06 | 1816.42 | 1730.27 | 0.95 |
| 512 | 2458.94 | 2598.96 | 1.06 | 2065.77 | 2306.82 | 1.12 |
| 1024 | 2761.78 | 2750.88 | 1.00 | 2369.36 | 2528.56 | 1.07 |
| 2048 | 2658.87 | 3203.53 | 1.20 | 2318.98 | 2557.95 | 1.10 |
| 4096 | 2933.46 | 3143.50 | 1.07 | 2376.12 | 2758.24 | 1.16 |
| 8192 | 2892.25 | 3245.67 | 1.12 | 2375.57 | 2743.85 | 1.16 |
| 16384 | 2516.17 | 3128.93 | 1.24 | 2330.53 | 2461.43 | 1.06 |

meta-llama/Llama-3.1-8B-Instruct, N=4096 K=14336, BF16 vs FP8 GEMMs TFLOP/s

| batch_size | fp8-channel-w-token-a-noquant (before) | fp8-channel-w-token-a-noquant (after) | speedup (channel) | fp8-tensor-w-tensor-a-noquant (before) | fp8-tensor-w-tensor-a-noquant (after) | speedup (tensor) |
|---|---|---|---|---|---|---|
| 128 | 933.90 | 1235.35 | 1.32 | 931.74 | 1226.11 | 1.32 |
| 256 | 1456.98 | 2058.74 | 1.41 | 1455.69 | 1914.42 | 1.32 |
| 512 | 2655.40 | 2781.46 | 1.05 | 2257.84 | 2294.90 | 1.02 |
| 1024 | 2628.98 | 2842.13 | 1.08 | 2297.41 | 2513.93 | 1.09 |
| 2048 | 2350.75 | 2980.99 | 1.27 | 2573.86 | 2766.97 | 1.07 |
| 4096 | 2868.60 | 2962.87 | 1.03 | 2383.17 | 2532.65 | 1.06 |
| 8192 | 2873.71 | 3012.48 | 1.05 | 2463.67 | 2598.51 | 1.05 |
| 16384 | 3059.99 | 3047.74 | 1.00 | 2456.67 | 2771.49 | 1.13 |

Raw results:

# B200 original tunings
Final fp8 results

meta-llama/Llama-3.1-8B-Instruct, N=6144 K=4096, BF16 vs FP8 GEMMs TFLOP/s:
BF16 vs FP8 GEMMs:
    batch_size   torch-bf16  fp8-tensor-w-tensor-a  fp8-channel-w-token-a  fp8-tensor-w-tensor-a-noquant  fp8-channel-w-token-a-noquant
0          1.0     5.916485               5.753715               5.330122                       8.427494                       8.492594
1         16.0    94.814923              95.310200              87.932563                     137.824707                     139.242251
2         64.0   402.507107             372.321682             343.232151                     533.888100                     542.573360
3        128.0   729.184544             649.331042             601.438116                     905.808861                     915.048684
4        256.0  1132.751952            1040.795575            1001.030304                    1464.199092                    1478.029512
5        512.0  1321.631173            1346.470214            1287.622689                    1752.236884                    1795.059333
6       1024.0  1455.638913            1889.658170            1827.840432                    2255.309570                    2496.482565
7       2048.0  1376.049153            1936.787426            1967.928953                    2303.567956                    2780.100566
8       4096.0  1478.027416            2157.047029            2101.412171                    2329.406419                    2131.023902
9       8192.0  1379.167675            1930.552235            2169.440446                    2297.250630                    3024.361306
10     16384.0  1512.843919            2225.828721            2248.795494                    2544.683145                    2447.031661
meta-llama/Llama-3.1-8B-Instruct, N=4096 K=4096, BF16 vs FP8 GEMMs TFLOP/s:
BF16 vs FP8 GEMMs:
    batch_size   torch-bf16  fp8-tensor-w-tensor-a  fp8-channel-w-token-a  fp8-tensor-w-tensor-a-noquant  fp8-channel-w-token-a-noquant
0          1.0     2.735701               3.960087               3.638411                       5.857841                       5.863624
1         16.0    72.857795              63.743366              58.346328                      93.357504                      93.512478
2         64.0   286.303008             252.376350             231.058220                     367.518648                     368.735161
3        128.0   561.330810             436.344018             403.591567                     610.759122                     612.352349
4        256.0   892.904450             722.014118             673.181730                     992.838121                     993.062243
5        512.0  1335.789318            1244.028423            1168.758716                    1832.046345                    1906.160681
6       1024.0  1362.674710            1595.984515            1507.854834                    2076.664305                    2354.263940
7       2048.0  1444.418219            1754.939659            1742.990726                    2459.176463                    2353.708208
8       4096.0  1401.484139            1880.406713            1855.367146                    2284.509454                    2767.905774
9       8192.0  1398.889877            1955.402314            1911.768342                    2452.454092                    2785.403049
10     16384.0  1460.068325            1966.350816            1920.191376                    2483.944616                    2876.278402
meta-llama/Llama-3.1-8B-Instruct, N=28672 K=4096, BF16 vs FP8 GEMMs TFLOP/s:
BF16 vs FP8 GEMMs:
    batch_size   torch-bf16  fp8-tensor-w-tensor-a  fp8-channel-w-token-a  fp8-tensor-w-tensor-a-noquant  fp8-channel-w-token-a-noquant
0          1.0     5.874869               9.026715               8.794950                      10.324740                      10.312664
1         16.0    83.964112             142.523231             138.664943                     163.408862                     163.315558
2         64.0   324.667142             547.798568             536.520240                     625.708992                     626.914940
3        128.0   625.926013             931.562074             921.267795                    1025.664907                    1041.226725
4        256.0  1050.271848            1592.183316            1679.035360                    1816.422508                    1803.803735
5        512.0  1167.108021            1742.851040            2179.777668                    2065.771791                    2458.939473
6       1024.0  1183.067620            2015.266752            2015.761440                    2369.364333                    2761.776963
7       2048.0  1349.371659            2267.846249            2524.110341                    2318.975770                    2658.874489
8       4096.0  1373.148830            2493.818778            2769.848289                    2376.119108                    2933.456316
9       8192.0  1444.179074            2467.145220            2832.606598                    2375.569830                    2892.251298
10     16384.0  1640.933397            2179.746326            2779.491915                    2330.529514                    2516.174763
meta-llama/Llama-3.1-8B-Instruct, N=4096 K=14336, BF16 vs FP8 GEMMs TFLOP/s:
BF16 vs FP8 GEMMs:
    batch_size   torch-bf16  fp8-tensor-w-tensor-a  fp8-channel-w-token-a  fp8-tensor-w-tensor-a-noquant  fp8-channel-w-token-a-noquant
0          1.0     4.791137               6.954432               6.315147                      10.151014                      10.180551
1         16.0    80.863465             110.846927             101.118255                     161.630774                     162.609635
2         64.0   304.159817             436.142282             399.738175                     638.139034                     641.458877
3        128.0   597.426272             686.429752             642.180653                     931.739553                     933.902498
4        256.0   937.115129            1068.134125            1032.962169                    1455.692971                    1456.978210
5        512.0  1213.062456            1562.788481            1620.324747                    2257.837265                    2655.399612
6       1024.0  1305.196388            1894.087803            1877.716799                    2297.411520                    2628.976703
7       2048.0  1561.465072            2085.494941            2025.323917                    2573.858329                    2350.745222
8       4096.0  1420.453839            1842.320062            1901.444339                    2383.170734                    2868.599415
9       8192.0  1448.519465            1907.495063            2139.841349                    2463.669148                    2873.705030
10     16384.0  1688.603781            1946.983184            2163.580978                    2456.667968                    3059.988651

# B200 new tunings
meta-llama/Llama-3.1-8B-Instruct, N=6144 K=4096, BF16 vs FP8 GEMMs TFLOP/s:
BF16 vs FP8 GEMMs:
    batch_size   torch-bf16  fp8-tensor-w-tensor-a  fp8-channel-w-token-a  fp8-tensor-w-tensor-a-noquant  fp8-channel-w-token-a-noquant
0          1.0     5.916331               5.752971               5.331652                       8.431181                       8.492365
1         16.0    94.832504              95.340008              87.978731                     137.806246                     139.224049
2         64.0   402.560830             372.396811             343.311274                     534.094811                     542.683472
3        128.0   729.298288             712.367899             649.664917                    1020.351690                    1045.847411
4        256.0  1100.851307            1012.834663             943.266136                    1362.239253                    1386.120373
5        512.0  1291.479695            1421.702930            1347.557716                    1838.718713                    1891.015141
6       1024.0  1379.674717            1944.752836            1926.004290                    2341.433773                    2805.510343
7       2048.0  1309.916283            1933.263790            2076.636611                    2623.874093                    2974.917372
8       4096.0  1360.217076            2272.197151            2188.802666                    2464.444326                    3096.404425
9       8192.0  1524.163460            2074.211280            2194.345177                    2747.688572                    3141.422329
10     16384.0  1417.419230            2204.637002            2467.752537                    2547.963118                    2526.073582
meta-llama/Llama-3.1-8B-Instruct, N=4096 K=4096, BF16 vs FP8 GEMMs TFLOP/s:
BF16 vs FP8 GEMMs: 
    batch_size   torch-bf16  fp8-tensor-w-tensor-a  fp8-channel-w-token-a  fp8-tensor-w-tensor-a-noquant  fp8-channel-w-token-a-noquant
0          1.0     2.735327               3.961801               3.639225                       5.857957                       5.864160
1         16.0    72.850743              63.756404              58.353337                      93.363137                      93.533756
2         64.0   286.291697             252.373232             231.051912                     367.701854                     368.926455
3        128.0   561.452649             488.928025             445.815221                     703.111253                     707.661004
4        256.0   877.802644             857.525203             785.913864                    1240.990405                    1256.104390
5        512.0  1328.234177            1252.490708            1162.430576                    1887.046609                    1911.990213
6       1024.0  1359.444243            1654.777509            1570.108659                    2373.502210                    2366.933668
7       2048.0  1340.725112            1945.152203            1815.402300                    2358.633534                    2895.170557
8       4096.0  1399.770130            2075.425013            1929.797405                    2684.297310                    3074.905032
9       8192.0  1380.581480            1990.289004            2022.863034                    2534.505820                    2690.458973
10     16384.0  1597.126436            1886.911935            2099.725676                    2746.915436                    3106.673213
meta-llama/Llama-3.1-8B-Instruct, N=28672 K=4096, BF16 vs FP8 GEMMs TFLOP/s:
BF16 vs FP8 GEMMs:
    batch_size   torch-bf16  fp8-tensor-w-tensor-a  fp8-channel-w-token-a  fp8-tensor-w-tensor-a-noquant  fp8-channel-w-token-a-noquant
0          1.0     5.872733               9.056824               8.831987                      10.328445                      10.324831
1         16.0    83.907331             142.358127             138.957479                     163.077286                     162.950189
2         64.0   325.719953             547.661152             536.217431                     625.238793                     626.667364
3        128.0   627.114404            1036.629114            1026.800390                    1195.285892                    1193.937513
4        256.0  1063.593126            1578.888886            1680.148068                    1730.266538                    1917.106395
5        512.0  1211.245299            2137.538200            2221.104601                    2306.823109                    2598.962932
6       1024.0  1261.211925            2265.508692            2678.558117                    2528.556402                    2750.880922
7       2048.0  1479.313044            2519.701912            2487.436762                    2557.952353                    3203.530597
8       4096.0  1411.236934            2548.094918            3050.396496                    2758.244470                    3143.503067
9       8192.0  1518.417110            2485.848437            3054.438074                    2743.848234                    3245.666514
10     16384.0  1486.188091            2505.924104            2727.295027                    2461.434182                    3128.932457
meta-llama/Llama-3.1-8B-Instruct, N=4096 K=14336, BF16 vs FP8 GEMMs TFLOP/s:
BF16 vs FP8 GEMMs:
    batch_size   torch-bf16  fp8-tensor-w-tensor-a  fp8-channel-w-token-a  fp8-tensor-w-tensor-a-noquant  fp8-channel-w-token-a-noquant
0          1.0     4.789876               7.018027               6.362752                      10.150914                      10.182799                                                                                                                                               
1         16.0    80.944360             110.808334             101.124611                     161.641379                     162.232945
2         64.0   304.287137             436.538206             399.630826                     637.981253                     641.965898
3        128.0   581.212977             830.911039             773.134201                    1226.113255                    1235.353949
4        256.0   923.419097            1318.384174            1252.268795                    1914.419550                    2058.737917
5        512.0  1192.748590            1760.185693            1672.037736                    2294.904028                    2781.462489
6       1024.0  1390.944196            1972.711955            1846.207296                    2513.925690                    2842.126215
7       2048.0  1526.332614            2086.987818            2131.239790                    2766.971373                    2980.985212
8       4096.0  1420.221997            1947.634285            2079.659206                    2532.649174                    2962.865672
9       8192.0  1423.135510            2158.495821            2122.273246                    2598.510939                    3012.476439
10     16384.0  1547.783987            2071.583046            2076.774398                    2771.493920                    3047.740912


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of CI tests to quickly catch errors. You can run additional CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Summary of Changes

Hello @ilmarkov, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request implements additional performance tunings for the SM100 FP8 CUTLASS GEMM kernel. The changes involve adjusting kernel configurations based on the matrix M dimension (batch size) to improve TFLOP/s, particularly for larger batch sizes, as evidenced by the included benchmark data.

Highlights

  • Performance Tuning: Further optimized FP8 GEMM kernels for SM100 (Blackwell) GPUs, building on previous work.
  • Targeted Optimizations: Introduced specific CUTLASS configurations (TileShape, ClusterShape) tailored for different ranges of the M dimension (batch size), particularly focusing on M >= 128; an illustrative config sketch follows this list.
  • Benchmark Results: Provided detailed benchmark results on B200 hardware using Llama-3.1-8B-Instruct, demonstrating speedups ranging from 1.13x to 1.4x for various matrix dimensions and batch sizes, especially in the M=128 and M=256 ranges.
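
Roughly, each such configuration bundles a tile shape and cluster shape for one M range. The sketch below follows the struct quoted in the review comments, but the shape values and the `cutlass_3x_gemm_sm100` builder alias are placeholders and assumptions, not the values chosen by this PR.

```cpp
// Illustrative only: the real TileShape / ClusterShape values live in the
// modified SM100 FP8 kernel file; cutlass_3x_gemm_sm100 is an assumed helper.
template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm100_fp8_config_M256 {
  // M in (128, 256]
  using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
  using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
  using TileShape = cute::Shape<cute::_128, cute::_128, cute::_128>;  // placeholder
  using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;     // placeholder
  using Cutlass3xGemm =
      cutlass_3x_gemm_sm100<InType, OutType, Epilogue, TileShape, ClusterShape,
                            KernelSchedule, EpilogueSchedule>;
};
```
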
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request by creating an issue comment (i.e. a comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

| Feature | Command | Description |
|---|---|---|
| Code Review | /gemini review | Performs a code review for the current pull request in its current state. |
| Pull Request Summary | /gemini summary | Provides a summary of the current pull request in its current state. |
| Comment | @gemini-code-assist | Responds in comments when explicitly tagged, both in issue comments and review comments. |
| Help | /gemini help | Displays a list of available commands. |

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist is currently in preview and may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments to provide feedback.

You can also get AI-powered code generation, chat, and code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double-check it and use code with caution.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces performance optimizations for SM100 FP8 CUTLASS kernels. It includes adjustments to tile and cluster shapes, and a new configuration for M values where 128 < next_pow_2(M) <= 256. The changes are well-structured and the dispatch logic has been updated correctly.

```diff
 template <typename InType, typename OutType,
           template <typename, typename, typename> typename Epilogue>
 struct sm100_fp8_config_default {
-  // M in (128, inf)
+  // M in (256, inf)
```
Contributor


medium

The comment indicates that this configuration applies to M in the range (256, inf). Double check that this is the intended range, and that the upper bound is indeed unbounded.

```cpp
template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm100_fp8_config_M256 {
  // M in (128, 256]
```
Contributor


medium

The comment indicates that this configuration applies to M in the range (128, 256]. Consider adding a unit test to verify that this config is used when M is 129, 192, and 256.

```diff
         out, a, b, std::forward<EpilogueArgs>(args)...);
   } else {
-    // m in (128, inf)
+    // m in (256, inf)
```
Contributor


medium

The comment indicates that this configuration applies to M in the range (256, inf). Consider adding a unit test to verify that this config is used when M is 257, 512, and 1024.
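
A minimal, self-contained check in that spirit (not part of this PR; it only exercises the next_pow_2 bucketing described above, assuming next_pow_2 rounds M up to the next power of two) might look like:

```cpp
#include <cassert>
#include <cstdint>
#include <initializer_list>

// Assumed behaviour of the helper used by the dispatch logic.
static uint32_t next_pow_2(uint32_t x) {
  uint32_t p = 1;
  while (p < x) p <<= 1;
  return p;
}

int main() {
  // sm100_fp8_config_M256 bucket: 128 < next_pow_2(M) <= 256
  for (uint32_t m : {129u, 192u, 256u}) {
    uint32_t p = next_pow_2(m);
    assert(p > 128 && p <= 256);
  }
  // sm100_fp8_config_default bucket: next_pow_2(M) > 256
  for (uint32_t m : {257u, 512u, 1024u}) {
    assert(next_pow_2(m) > 256);
  }
  return 0;
}
```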

Collaborator

@houseroad houseroad left a comment


Looks good.

@houseroad houseroad added the ready label (ONLY add when PR is ready to merge/full CI is needed) Jun 13, 2025
Member

@mgoin mgoin left a comment


Thank you @ilmarkov !

@simon-mo simon-mo merged commit e13945f into vllm-project:main Jun 15, 2025
103 of 105 checks passed
@mgoin mgoin deleted the imarkov/fp8_cutlass_configs branch June 15, 2025 02:21

Labels

ready (ONLY add when PR is ready to merge/full CI is needed)
