
Conversation

@jeejeelee (Collaborator) commented Sep 2, 2025

Purpose

When deploying the Qwen3-Coder model with the following script (8×H20):

vllm serve Qwen3-Coder/Qwen3-Coder-480B-A35B-Instruct-FP8 \
  --max-model-len 32000 \
  --enable-expert-parallel \
  --tensor-parallel-size 8 \
  --enable-auto-tool-choice \
  --enforce-eager \
  --tool-call-parser qwen3_coder

the following error is raised.

(VllmWorker TP5 pid=42616) INFO 09-01 15:06:46 [gpu_model_runner.py:1986] Model loading took 56.3173 GiB and 15.050946 seconds
(VllmWorker TP6 pid=42617) INFO 09-01 15:06:46 [gpu_model_runner.py:1986] Model loading took 56.3173 GiB and 15.214502 seconds
(VllmWorker TP1 pid=42612) INFO 09-01 15:06:47 [fused_moe.py:709] Using configuration from /root/Code/vllm_dev/vllm/vllm/model_executor/layers/fused_moe/configs/triton_3_4_0/E=20,N=2560,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128,128].json for MoE layer.
(VllmWorker TP2 pid=42613) INFO 09-01 15:06:47 [fused_moe.py:709] Using configuration from /root/Code/vllm_dev/vllm/vllm/model_executor/layers/fused_moe/configs/triton_3_4_0/E=20,N=2560,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128,128].json for MoE layer.
(VllmWorker TP4 pid=42615) INFO 09-01 15:06:47 [fused_moe.py:709] Using configuration from /root/Code/vllm_dev/vllm/vllm/model_executor/layers/fused_moe/configs/triton_3_4_0/E=20,N=2560,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128,128].json for MoE layer.
(VllmWorker TP7 pid=42618) INFO 09-01 15:06:47 [fused_moe.py:709] Using configuration from /root/Code/vllm_dev/vllm/vllm/model_executor/layers/fused_moe/configs/triton_3_4_0/E=20,N=2560,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128,128].json for MoE layer.
(VllmWorker TP3 pid=42614) INFO 09-01 15:06:47 [fused_moe.py:709] Using configuration from /root/Code/vllm_dev/vllm/vllm/model_executor/layers/fused_moe/configs/triton_3_4_0/E=20,N=2560,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128,128].json for MoE layer.
(VllmWorker TP5 pid=42616) INFO 09-01 15:06:47 [fused_moe.py:709] Using configuration from /root/Code/vllm_dev/vllm/vllm/model_executor/layers/fused_moe/configs/triton_3_4_0/E=20,N=2560,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128,128].json for MoE layer.
(VllmWorker TP6 pid=42617) INFO 09-01 15:06:47 [fused_moe.py:709] Using configuration from /root/Code/vllm_dev/vllm/vllm/model_executor/layers/fused_moe/configs/triton_3_4_0/E=20,N=2560,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128,128].json for MoE layer.
(VllmWorker TP1 pid=42612) ERROR 09-01 15:06:47 [multiproc_executor.py:611] WorkerProc hit an exception.
(VllmWorker TP1 pid=42612) ERROR 09-01 15:06:47 [multiproc_executor.py:611] Traceback (most recent call last):
(VllmWorker TP1 pid=42612) ERROR 09-01 15:06:47 [multiproc_executor.py:611]   File "/root/Code/vllm_dev/vllm/vllm/v1/executor/multiproc_executor.py", line 606, in worker_busy_loop
(VllmWorker TP1 pid=42612) ERROR 09-01 15:06:47 [multiproc_executor.py:611]     output = func(*args, **kwargs)
(VllmWorker TP1 pid=42612) ERROR 09-01 15:06:47 [multiproc_executor.py:611]              ^^^^^^^^^^^^^^^^^^^^^
(VllmWorker TP1 pid=42612) ERROR 09-01 15:06:47 [multiproc_executor.py:611]   File "/root/Soft/miniconda3/envs/py311_vllm_dev/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
(VllmWorker TP1 pid=42612) ERROR 09-01 15:06:47 [multiproc_executor.py:611]     return func(*args, **kwargs)
(VllmWorker TP1 pid=42612) ERROR 09-01 15:06:47 [multiproc_executor.py:611]            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorker TP1 pid=42612) ERROR 09-01 15:06:47 [multiproc_executor.py:611]   File "/root/Code/vllm_dev/vllm/vllm/v1/worker/gpu_worker.py", line 244, in determine_available_memory
(VllmWorker TP1 pid=42612) ERROR 09-01 15:06:47 [multiproc_executor.py:611]     self.model_runner.profile_run()
(VllmWorker TP1 pid=42612) ERROR 09-01 15:06:47 [multiproc_executor.py:611]   File "/root/Code/vllm_dev/vllm/vllm/v1/worker/gpu_model_runner.py", line 2601, in profile_run
(VllmWorker TP1 pid=42612) ERROR 09-01 15:06:47 [multiproc_executor.py:611]     = self._dummy_run(self.max_num_tokens, is_profile=True)
(VllmWorker TP1 pid=42612) ERROR 09-01 15:06:47 [multiproc_executor.py:611]       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker TP1 pid=42612) ERROR 09-01 15:06:47 [multiproc_executor.py:611]   File "/root/Soft/miniconda3/envs/py311_vllm_dev/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
(VllmWorker TP1 pid=42612) ERROR 09-01 15:06:47 [multiproc_executor.py:611]     return func(*args, **kwargs)
(VllmWorker TP1 pid=42612) ERROR 09-01 15:06:47 [multiproc_executor.py:611]            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorker TP1 pid=42612) ERROR 09-01 15:06:47 [multiproc_executor.py:611]   File "/root/Code/vllm_dev/vllm/vllm/v1/worker/gpu_model_runner.py", line 2378, in _dummy_run
(VllmWorker TP1 pid=42612) ERROR 09-01 15:06:47 [multiproc_executor.py:611]     outputs = self.model(
(VllmWorker TP1 pid=42612) ERROR 09-01 15:06:47 [multiproc_executor.py:611]               ^^^^^^^^^^^
(VllmWorker TP1 pid=42612) ERROR 09-01 15:06:47 [multiproc_executor.py:611]   File "/root/Soft/miniconda3/envs/py311_vllm_dev/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
(VllmWorker TP1 pid=42612) ERROR 09-01 15:06:47 [multiproc_executor.py:611]     return self._call_impl(*args, **kwargs)
(VllmWorker TP1 pid=42612) ERROR 09-01 15:06:47 [multiproc_executor.py:611]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker TP1 pid=42612) ERROR 09-01 15:06:47 [multiproc_executor.py:611]   File "/root/Soft/miniconda3/envs/py311_vllm_dev/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
(VllmWorker TP1 pid=42612) ERROR 09-01 15:06:47 [multiproc_executor.py:611]     return forward_call(*args, **kwargs)
(VllmWorker TP1 pid=42612) ERROR 09-01 15:06:47 [multiproc_executor.py:611]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker TP1 pid=42612) ERROR 09-01 15:06:47 [multiproc_executor.py:611]   File "/root/Code/vllm_dev/vllm/vllm/model_executor/models/qwen3_moe.py", line 682, in forward
(VllmWorker TP1 pid=42612) ERROR 09-01 15:06:47 [multiproc_executor.py:611]     hidden_states = self.model(input_ids, positions, intermediate_tensors,
(VllmWorker TP1 pid=42612) ERROR 09-01 15:06:47 [multiproc_executor.py:611]                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker TP1 pid=42612) ERROR 09-01 15:06:47 [multiproc_executor.py:611]   File "/root/Code/vllm_dev/vllm/vllm/compilation/decorators.py", line 223, in __call__
(VllmWorker TP1 pid=42612) ERROR 09-01 15:06:47 [multiproc_executor.py:611]     return self.forward(*args, **kwargs)
(VllmWorker TP1 pid=42612) ERROR 09-01 15:06:47 [multiproc_executor.py:611]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker TP1 pid=42612) ERROR 09-01 15:06:47 [multiproc_executor.py:611]   File "/root/Code/vllm_dev/vllm/vllm/model_executor/models/qwen3_moe.py", line 429, in forward
(VllmWorker TP1 pid=42612) ERROR 09-01 15:06:47 [multiproc_executor.py:611]     hidden_states, residual = layer(positions, hidden_states, residual)
(VllmWorker TP1 pid=42612) ERROR 09-01 15:06:47 [multiproc_executor.py:611]                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker TP1 pid=42612) ERROR 09-01 15:06:47 [multiproc_executor.py:611]   File "/root/Soft/miniconda3/envs/py311_vllm_dev/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
(VllmWorker TP1 pid=42612) ERROR 09-01 15:06:47 [multiproc_executor.py:611]     return self._call_impl(*args, **kwargs)
(VllmWorker TP1 pid=42612) ERROR 09-01 15:06:47 [multiproc_executor.py:611]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker TP1 pid=42612) ERROR 09-01 15:06:47 [multiproc_executor.py:611]   File "/root/Soft/miniconda3/envs/py311_vllm_dev/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
(VllmWorker TP1 pid=42612) ERROR 09-01 15:06:47 [multiproc_executor.py:611]     return forward_call(*args, **kwargs)
(VllmWorker TP1 pid=42612) ERROR 09-01 15:06:47 [multiproc_executor.py:611]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker TP1 pid=42612) ERROR 09-01 15:06:47 [multiproc_executor.py:611]   File "/root/Code/vllm_dev/vllm/vllm/model_executor/models/qwen3_moe.py", line 368, in forward
(VllmWorker TP1 pid=42612) ERROR 09-01 15:06:47 [multiproc_executor.py:611]     hidden_states = self.mlp(hidden_states)
(VllmWorker TP1 pid=42612) ERROR 09-01 15:06:47 [multiproc_executor.py:611]                     ^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker TP1 pid=42612) ERROR 09-01 15:06:47 [multiproc_executor.py:611]   File "/root/Soft/miniconda3/envs/py311_vllm_dev/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
(VllmWorker TP1 pid=42612) ERROR 09-01 15:06:47 [multiproc_executor.py:611]     return self._call_impl(*args, **kwargs)
(VllmWorker TP1 pid=42612) ERROR 09-01 15:06:47 [multiproc_executor.py:611]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker TP1 pid=42612) ERROR 09-01 15:06:47 [multiproc_executor.py:611]   File "/root/Soft/miniconda3/envs/py311_vllm_dev/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
(VllmWorker TP1 pid=42612) ERROR 09-01 15:06:47 [multiproc_executor.py:611]     return forward_call(*args, **kwargs)
(VllmWorker TP1 pid=42612) ERROR 09-01 15:06:47 [multiproc_executor.py:611]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker TP1 pid=42612) ERROR 09-01 15:06:47 [multiproc_executor.py:611]   File "/root/Code/vllm_dev/vllm/vllm/model_executor/models/qwen3_moe.py", line 180, in forward
(VllmWorker TP1 pid=42612) ERROR 09-01 15:06:47 [multiproc_executor.py:611]     final_hidden_states = self.experts(hidden_states=hidden_states,
(VllmWorker TP1 pid=42612) ERROR 09-01 15:06:47 [multiproc_executor.py:611]                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker TP1 pid=42612) ERROR 09-01 15:06:47 [multiproc_executor.py:611]   File "/root/Soft/miniconda3/envs/py311_vllm_dev/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
(VllmWorker TP1 pid=42612) ERROR 09-01 15:06:47 [multiproc_executor.py:611]     return self._call_impl(*args, **kwargs)
(VllmWorker TP1 pid=42612) ERROR 09-01 15:06:47 [multiproc_executor.py:611]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker TP1 pid=42612) ERROR 09-01 15:06:47 [multiproc_executor.py:611]   File "/root/Soft/miniconda3/envs/py311_vllm_dev/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
(VllmWorker TP1 pid=42612) ERROR 09-01 15:06:47 [multiproc_executor.py:611]     return forward_call(*args, **kwargs)
(VllmWorker TP1 pid=42612) ERROR 09-01 15:06:47 [multiproc_executor.py:611]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker TP1 pid=42612) ERROR 09-01 15:06:47 [multiproc_executor.py:611]   File "/root/Code/vllm_dev/vllm/vllm/model_executor/layers/fused_moe/layer.py", line 1598, in forward
(VllmWorker TP1 pid=42612) ERROR 09-01 15:06:47 [multiproc_executor.py:611]     return torch.ops.vllm.moe_forward(
(VllmWorker TP1 pid=42612) ERROR 09-01 15:06:47 [multiproc_executor.py:611]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker TP1 pid=42612) ERROR 09-01 15:06:47 [multiproc_executor.py:611]   File "/root/Soft/miniconda3/envs/py311_vllm_dev/lib/python3.11/site-packages/torch/_ops.py", line 1243, in __call__
(VllmWorker TP1 pid=42612) ERROR 09-01 15:06:47 [multiproc_executor.py:611]     return self._op(*args, **kwargs)
(VllmWorker TP1 pid=42612) ERROR 09-01 15:06:47 [multiproc_executor.py:611]            ^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker TP1 pid=42612) ERROR 09-01 15:06:47 [multiproc_executor.py:611]   File "/root/Code/vllm_dev/vllm/vllm/model_executor/layers/fused_moe/layer.py", line 1793, in moe_forward
(VllmWorker TP1 pid=42612) ERROR 09-01 15:06:47 [multiproc_executor.py:611]     return self.forward_impl(hidden_states, router_logits)
(VllmWorker TP1 pid=42612) ERROR 09-01 15:06:47 [multiproc_executor.py:611]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker TP1 pid=42612) ERROR 09-01 15:06:47 [multiproc_executor.py:611]   File "/root/Code/vllm_dev/vllm/vllm/model_executor/layers/fused_moe/layer.py", line 1702, in forward_impl
(VllmWorker TP1 pid=42612) ERROR 09-01 15:06:47 [multiproc_executor.py:611]     final_hidden_states = self.quant_method.apply(
(VllmWorker TP1 pid=42612) ERROR 09-01 15:06:47 [multiproc_executor.py:611]                           ^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker TP1 pid=42612) ERROR 09-01 15:06:47 [multiproc_executor.py:611]   File "/root/Code/vllm_dev/vllm/vllm/model_executor/layers/quantization/fp8.py", line 1128, in apply
(VllmWorker TP1 pid=42612) ERROR 09-01 15:06:47 [multiproc_executor.py:611]     return fused_experts(
(VllmWorker TP1 pid=42612) ERROR 09-01 15:06:47 [multiproc_executor.py:611]            ^^^^^^^^^^^^^^
(VllmWorker TP1 pid=42612) ERROR 09-01 15:06:47 [multiproc_executor.py:611]   File "/root/Code/vllm_dev/vllm/vllm/model_executor/layers/fused_moe/fused_moe.py", line 1495, in fused_experts
(VllmWorker TP1 pid=42612) ERROR 09-01 15:06:47 [multiproc_executor.py:611]     return dispatch_fused_experts_func(inplace)(
(VllmWorker TP1 pid=42612) ERROR 09-01 15:06:47 [multiproc_executor.py:611]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker TP1 pid=42612) ERROR 09-01 15:06:47 [multiproc_executor.py:611]   File "/root/Code/vllm_dev/vllm/vllm/model_executor/layers/fused_moe/fused_moe.py", line 1410, in torch_vllm_inplace_fused_experts
(VllmWorker TP1 pid=42612) ERROR 09-01 15:06:47 [multiproc_executor.py:611]     torch.ops.vllm.inplace_fused_experts(**kwargs)
(VllmWorker TP1 pid=42612) ERROR 09-01 15:06:47 [multiproc_executor.py:611]   File "/root/Soft/miniconda3/envs/py311_vllm_dev/lib/python3.11/site-packages/torch/_ops.py", line 1243, in __call__
(VllmWorker TP1 pid=42612) ERROR 09-01 15:06:47 [multiproc_executor.py:611]     return self._op(*args, **kwargs)
(VllmWorker TP1 pid=42612) ERROR 09-01 15:06:47 [multiproc_executor.py:611]            ^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker TP1 pid=42612) ERROR 09-01 15:06:47 [multiproc_executor.py:611]   File "/root/Code/vllm_dev/vllm/vllm/model_executor/layers/fused_moe/fused_moe.py", line 1120, in inplace_fused_experts
(VllmWorker TP1 pid=42612) ERROR 09-01 15:06:47 [multiproc_executor.py:611]     fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True,
(VllmWorker TP1 pid=42612) ERROR 09-01 15:06:47 [multiproc_executor.py:611]   File "/root/Code/vllm_dev/vllm/vllm/model_executor/layers/fused_moe/fused_moe.py", line 1672, in fused_experts_impl
(VllmWorker TP1 pid=42612) ERROR 09-01 15:06:47 [multiproc_executor.py:611]     invoke_fused_moe_kernel(qcurr_hidden_states,
(VllmWorker TP1 pid=42612) ERROR 09-01 15:06:47 [multiproc_executor.py:611]   File "/root/Code/vllm_dev/vllm/vllm/model_executor/layers/fused_moe/fused_moe.py", line 620, in invoke_fused_moe_kernel
(VllmWorker TP1 pid=42612) ERROR 09-01 15:06:47 [multiproc_executor.py:611]     fused_moe_kernel[grid](
(VllmWorker TP1 pid=42612) ERROR 09-01 15:06:47 [multiproc_executor.py:611]   File "/root/Soft/miniconda3/envs/py311_vllm_dev/lib/python3.11/site-packages/triton/runtime/jit.py", line 390, in <lambda>
(VllmWorker TP1 pid=42612) ERROR 09-01 15:06:47 [multiproc_executor.py:611]     return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
(VllmWorker TP1 pid=42612) ERROR 09-01 15:06:47 [multiproc_executor.py:611]                                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker TP1 pid=42612) ERROR 09-01 15:06:47 [multiproc_executor.py:611]   File "/root/Soft/miniconda3/envs/py311_vllm_dev/lib/python3.11/site-packages/triton/runtime/jit.py", line 617, in run
(VllmWorker TP1 pid=42612) ERROR 09-01 15:06:47 [multiproc_executor.py:611]     kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, launch_metadata,
(VllmWorker TP1 pid=42612) ERROR 09-01 15:06:47 [multiproc_executor.py:611]     ^^^^^^^^^^
(VllmWorker TP1 pid=42612) ERROR 09-01 15:06:47 [multiproc_executor.py:611]   File "/root/Soft/miniconda3/envs/py311_vllm_dev/lib/python3.11/site-packages/triton/compiler/compiler.py", line 498, in __getattribute__
(VllmWorker TP1 pid=42612) ERROR 09-01 15:06:47 [multiproc_executor.py:611]     self._init_handles()
(VllmWorker TP1 pid=42612) ERROR 09-01 15:06:47 [multiproc_executor.py:611]   File "/root/Soft/miniconda3/envs/py311_vllm_dev/lib/python3.11/site-packages/triton/compiler/compiler.py", line 483, in _init_handles
(VllmWorker TP1 pid=42612) ERROR 09-01 15:06:47 [multiproc_executor.py:611]     raise OutOfResources(self.metadata.shared, max_shared, "shared memory")

The root cause appears to be that #21700 tuned the MoE kernel with use-deep-gemm enabled, resulting in an incorrect tuned config.

This PR re-tunes the related configuration and, to avoid similar problems in the future, forces use_deep_gemm=False when --tune is set.
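
As a rough illustration of that guard, here is a minimal sketch of how the tune path of benchmark_moe.py can refuse the flag (the function and helper names below are assumptions for illustration, not the actual code; per the review discussion, the merged version raises an error instead of silently overriding the flag):

    # Hypothetical outline of the tuning entry point in benchmark_moe.py.
    def tune(search_space, use_deep_gemm: bool = False):
        # Tuning times the Triton fused-MoE kernel only; benchmarking with
        # DeepGEMM would measure a different kernel and could save configs
        # that the Triton kernel cannot even launch (e.g. the shared-memory
        # OutOfResources failure in the traceback above).
        if use_deep_gemm:
            raise ValueError(
                "Tuning with --use-deep-gemm is not supported; tuning only "
                "targets the Triton MoE kernel. Please remove the flag.")
        best_config, best_latency = None, float("inf")
        for config in search_space:
            latency = benchmark_config(config)  # hypothetical helper
            if latency < best_latency:
                best_config, best_latency = config, latency
        return best_config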

Test Plan

Test Result



Signed-off-by: Jee Jee Li <[email protected]>
@mergify bot added the performance (Performance-related issues) and qwen (Related to Qwen models) labels on Sep 2, 2025
@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request fixes a crash that occurs when deploying the Qwen3-Coder model with MoE. The root cause was an incorrect tuned configuration file generated with use_deep_gemm. The fix includes the correctly re-tuned configuration and a safeguard in the benchmark_moe.py script to prevent use_deep_gemm during tuning, as it's only supported for Triton kernels.

The changes are logical and address the issue. I have one suggestion on benchmark_moe.py to make the safeguard more explicit by raising an error instead of silently correcting the flag, which will prevent potential user confusion and use of suboptimal configurations. Overall, this is a good fix.

Comment on lines 682 to 683 of benchmark_moe.py:
print("Only supports tuning triton kernels, set use_deep_gemm=False.")
use_deep_gemm = False
gemini-code-assist bot (severity: high):

While this change correctly prevents the crash by disabling use_deep_gemm during tuning, it might be better to fail explicitly rather than silently changing the behavior. A user might intend to tune for DeepGEMM and miss the print statement in the logs, leading them to unknowingly use a Triton-tuned configuration for DeepGEMM workloads, which could be suboptimal.

Consider raising a ValueError to make the incompatibility explicit and guide the user to the correct usage.

Suggested change:

- print("Only supports tuning triton kernels, set use_deep_gemm=False.")
- use_deep_gemm = False
+ raise ValueError(
+     "Tuning with --use-deep-gemm is not supported as it only tunes Triton kernels. Please remove the flag.")

@yewentao256 (Member) left a comment

I'm not sure it is a good idea to ban deep_gemm for tuning. I think the use_deep_gemm option comes from the previous PR #13932.
CC @bnellnm

@jeejeelee (Collaborator, Author) commented

IIUC, because tuning only targets the Triton MoE kernel, using use_deep_gemm during tuning will give incorrect results.
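
For context, a tuned entry in those config files looks roughly like the sketch below (the field names match the fused-MoE config JSONs referenced in the startup log above, e.g. E=20,N=2560,...,block_shape=[128,128].json, but the values here are invented). Every field is a Triton launch parameter, so an entry selected while timing DeepGEMM says nothing about the Triton kernel and can even request more shared memory than the GPU permits, which is exactly the OutOfResources failure in the traceback:

    # Illustrative tuned-config entry, keyed by token count (values invented).
    moe_config_entry = {
        "1024": {
            "BLOCK_SIZE_M": 128,
            "BLOCK_SIZE_N": 256,
            "BLOCK_SIZE_K": 128,
            "GROUP_SIZE_M": 8,
            "num_warps": 8,
            "num_stages": 5,
        }
    }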

@bnellnm (Contributor) commented Sep 4, 2025

I'm not sure it is a good idea to ban deep_gemm for tuning. I think the use_deep_gemm option comes from the previous PR #13932. CC @bnellnm

I think it is reasonable to disable deep_gemm here since we still might end up running triton if deep_gemm isn't available.

@yewentao256 (Member) commented

I'm not sure it is a good idea to ban deep_gemm for tuning. I think the use_deep_gemm option comes from the previous PR #13932. CC @bnellnm

I think it is reasonable to disable deep_gemm here since we still might end up running triton if deep_gemm isn't available.

@bnellnm Curious why we introduced deep_gemm to the tuning script in the first place? If it's no longer needed, shall we remove it from the script entirely?

@jeejeelee (Collaborator, Author) commented

This script can not only tune the MoE kernel but also benchmark it.

@yewentao256 (Member) left a comment

Makes sense to me then. Could you also update it as Gemini suggests? We should avoid implicitly changing the param.

@bnellnm (Contributor) commented Sep 5, 2025

I'm not sure it is a good idea to ban deep_gemm for tuning. I think the use_deep_gemm option comes from the previous PR #13932. CC @bnellnm

I think it is reasonable to disable deep_gemm here since we still might end up running triton if deep_gemm isn't available.

@bnellnm Curious why we introduced deep_gemm to the tuning script in the first place? If it's no longer needed, shall we remove it from the script entirely?

I probably just added the flag everywhere mechanically.

@yewentao256 added the ready (ONLY add when PR is ready to merge/full CI is needed) label on Sep 5, 2025
Signed-off-by: Jee Jee Li <[email protected]>
@jeejeelee (Collaborator, Author) commented

Makes sense to me then. Could you also update it as Gemini suggests? We should avoid implicitly changing the param.

Done in cc160c9

@yewentao256 (Member) left a comment

LGTM, thanks for the work!

@yewentao256 yewentao256 enabled auto-merge (squash) September 6, 2025 14:09
@yewentao256 yewentao256 merged commit 62f66be into vllm-project:main Sep 7, 2025
42 checks passed
@jeejeelee jeejeelee deleted the fix-benchmark-moe branch September 7, 2025 11:04
eicherseiji pushed a commit to eicherseiji/vllm that referenced this pull request Sep 9, 2025
skyloevil pushed a commit to skyloevil/vllm that referenced this pull request Sep 13, 2025
FeiDaLI pushed a commit to FeiDaLI/vllm that referenced this pull request Sep 25, 2025
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Oct 10, 2025