Skip to content

Conversation

xyang16
Copy link
Contributor

@xyang16 xyang16 commented Aug 18, 2025

Purpose

This PR add routed_scaling_factor to MoE grouped topk.

  • In vllm/model_executor/layers/fused_moe/layer.py, add routed_scaling_factor parameter to grouped_topk().
  • In vllm/model_executor/layers/quantization/fp8.py, pass routed_scaling_factor to torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8() instead of hardcoded 1.0.

Note: rocm aiter grouped topk already has routed_scaling_factor, see here.

transformers reference: https://github.com/huggingface/transformers/blob/v4.55.3/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py#L145

(Optional) Documentation Update


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.

Copy link

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the rocm Related to AMD ROCm label Aug 18, 2025
Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces the routed_scaling_factor to the MoE grouped top-k logic, which is a valuable addition for controlling expert contributions. The changes are propagated through many layers, including various quantization methods.

However, I've noticed some inconsistencies in the application of routed_scaling_factor. While it's correctly applied in some paths (like grouped_topk and the standard top-k in SGLFusedMOE), it's missing in others. For example, in vllm/model_executor/layers/fused_moe/cpu_fused_moe.py, the routed_scaling_factor is added to IPEXFusedMOE.__call__ but is not passed to the underlying ipex_fusion function, making it an unused parameter. Similar omissions exist for non-grouped top-k paths in other files, which are outside the scope of the current diff but should be addressed for consistency.

For correctness and consistency, I recommend applying the routed_scaling_factor to all routing paths or clarifying if this omission is intentional. The issue with IPEXFusedMOE should also be addressed.

@mergify mergify bot added the deepseek Related to DeepSeek models label Aug 18, 2025
@xyang16 xyang16 force-pushed the moe branch 2 times, most recently from 2a6cdfb to a2d24d5 Compare August 20, 2025 16:32
@xyang16 xyang16 force-pushed the moe branch 2 times, most recently from f8820de to 9f16b11 Compare August 25, 2025 19:12
Copy link
Member

@mgoin mgoin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM overall, just would like some assertions or implementations added where the arg is ignored at the moment. I think this would be helped out if we made it an Optional[float] and defaulted to None. This would also allow us to skip the multiply in the noop case

@xyang16 xyang16 force-pushed the moe branch 4 times, most recently from 6ef2b16 to 03254df Compare August 26, 2025 01:39
@xyang16
Copy link
Contributor Author

xyang16 commented Aug 26, 2025

@mgoin Since the routed_scaling_factor doesn't default to None in model config, see here, I still keep routed_scaling_factor default to 1.0, to keep the type consistent.

But I have added the check to skip the multiply if routed_scaling_factor is 1.0:

        if routed_scaling_factor != 1.0:
            topk_weights = topk_weights * routed_scaling_factor

Please let me know if this works. Thanks!

@mgoin
Copy link
Member

mgoin commented Aug 29, 2025

Could you run an eval on deepseek? A bit worried of leaving something behind here since we don't have great quantized moe tests at the moment

@xyang16
Copy link
Contributor Author

xyang16 commented Aug 29, 2025

Thanks for your review! I have run the eval based on this PR and #23274 combined. Eval result posted in #23274.

Pasting result here as well:

lm_eval --model local-completions \
  --model_args model=deepseek-ai/DeepSeek-R1,base_url=http://127.0.0.1:8000/v1/completions,num_concurrent=8 \
  --tasks gsm8k

Baseline:

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9538|±  |0.0058|
|     |       |strict-match    |     5|exact_match|↑  |0.9538|±  |0.0058|

This PR:

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.956|±  |0.0056|
|     |       |strict-match    |     5|exact_match|↑  |0.956|±  |0.0056|

Copy link
Member

@yewentao256 yewentao256 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks for the work!

@mgoin mgoin added the ready ONLY add when PR is ready to merge/full CI is needed label Aug 29, 2025
@DarkLight1337 DarkLight1337 enabled auto-merge (squash) August 29, 2025 23:44
@vllm-bot vllm-bot merged commit 8fb85b7 into vllm-project:main Aug 30, 2025
45 of 50 checks passed
@xyang16 xyang16 deleted the moe branch August 31, 2025 01:54
@josephrocca
Copy link

I noticed some buggy/weird outputs when building from main and running DeepSeek w4a16, and eventually narrowed it down to this commit. I've tested on H100, H200, and B200, and the issue is the same. To reproduce:

vllm serve RedHatAI/DeepSeek-R1-0528-quantized.w4a16 --tensor-parallel-size 8 --max-model-len 8192

And send a simple request like "Respond with 'cat' and nothing more", and about 10% of the time it'll go haywire and output a whole essay about cats or something.

Then test building vLLM before this commit, and it works fine - i.e. responds with 'cat' 100% of the time.

@yewentao256
Copy link
Member

#24118
Found this is the commit that introduces this accuracy issue for R1

@xyang16
Copy link
Contributor Author

xyang16 commented Sep 4, 2025

@yewentao256 I can revert this PR, please let me know if it works. Thanks.

@yewentao256
Copy link
Member

yewentao256 commented Sep 4, 2025

@yewentao256 I can revert this PR, please let me know if it works. Thanks.

No worries, it is been fixed now @xyang16 #23123

eicherseiji pushed a commit to eicherseiji/vllm that referenced this pull request Sep 9, 2025
Signed-off-by: Xin Yang <[email protected]>
Co-authored-by: Michael Goin <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>
FeiDaLI pushed a commit to FeiDaLI/vllm that referenced this pull request Sep 25, 2025
Signed-off-by: Xin Yang <[email protected]>
Co-authored-by: Michael Goin <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

deepseek Related to DeepSeek models ready ONLY add when PR is ready to merge/full CI is needed rocm Related to AMD ROCm

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants