Commit e7fa57e

Extract eh_proj Layer from ParallelLMHead for MTP to Avoid Weight Transposition Issue (#2707)
* fix mtp eh_proj layer
* fix mtp update_cfg function
* fix docstring
* simplify class name
1 parent a5ae88d commit e7fa57e

File tree (3 files changed: +136 −4)

fastdeploy/model_executor/layers/mtp_linear.py
fastdeploy/model_executor/models/ernie4_5_mtp.py
fastdeploy/spec_decode/mtp.py

fastdeploy/model_executor/layers/mtp_linear.py

Lines changed: 133 additions & 0 deletions
@@ -0,0 +1,133 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

import paddle
from paddle import nn
from paddle.distributed import fleet

from .utils import get_tensor


class ParallelEHProjection(nn.Layer):
    """
    Parallelized Embedding Hidden States Projection.
    """

    def __init__(
        self,
        fd_config,
        num_embeddings,
        embedding_dim,
        prefix="",
        with_bias=False,
    ):
        """
        Parallelized Embedding Hidden States Projection.

        Args:
            fd_config (FDConfig): Arguments related to inference, containing
                attributes such as weight_dtype, act_dtype, mp_size, hidden_size, head_dim,
                num_attention_heads, and ffn_hidden_size.
            num_embeddings (int): vocabulary size.
            embedding_dim (int): size of hidden state.
            prefix (str): full name of the layer in the state dict.
        """
        super(ParallelEHProjection, self).__init__()
        self.linear_weight_key = prefix + ".weight"
        if with_bias:
            self.linear_bias_key = prefix + ".bias"
        else:
            self.linear_bias_key = None
        self.use_ep = fd_config.parallel_config.use_ep
        self.column_cut = True

        ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
        RowParallelLinear = fleet.meta_parallel.RowParallelLinear

        if self.use_ep:
            self.weight = self.create_parameter(
                shape=[embedding_dim, num_embeddings],
                dtype=paddle.get_default_dtype(),
                is_bias=False,
            )
        else:
            if self.column_cut:
                need_gather = True
                self.out_linear = ColumnParallelLinear(
                    embedding_dim,
                    num_embeddings,
                    mp_group=fleet.get_hybrid_communicate_group().
                    get_model_parallel_group(),
                    weight_attr=None,
                    has_bias=True
                    if self.linear_bias_key is not None else False,
                    gather_output=need_gather,
                    fuse_matmul_bias=False,  # False yields a smaller numerical diff
                )
            else:
                self.out_linear = RowParallelLinear(
                    embedding_dim,
                    num_embeddings,
                    mp_group=fleet.get_hybrid_communicate_group().
                    get_model_parallel_group(),
                    weight_attr=None,
                    has_bias=True
                    if self.linear_bias_key is not None else False,
                    input_is_parallel=False,
                    fuse_matmul_bias=False,  # False yields a smaller numerical diff
                )

    def load_state_dict(self, state_dict):
        """
        Load the checkpoint state dictionary into the layer.

        Args:
            state_dict (dict): A dictionary containing the checkpoint weights and biases.
        """

        if self.use_ep:
            self.weight.set_value(
                get_tensor(state_dict.pop(self.linear_weight_key)).astype(
                    paddle.get_default_dtype()))
        else:
            weight_tensor = get_tensor(
                state_dict.pop(self.linear_weight_key)).astype(
                    paddle.get_default_dtype())
            if self.out_linear.weight.shape != weight_tensor.shape:
                weight_tensor = weight_tensor.transpose([1, 0])
            self.out_linear.weight.set_value(weight_tensor)

            if self.linear_bias_key is not None:
                bias = get_tensor(state_dict.pop(self.linear_bias_key)).astype(
                    paddle.get_default_dtype())
                self.out_linear.bias.set_value(bias)

    def forward(self, input):
        """
        Defines the forward computation of the layer.

        Args:
            input (Tensor): The input tensor to the layer.

        Returns:
            Tensor: The output tensor after processing through the layer.
        """
        logits = input
        if self.use_ep:
            logits = paddle.matmul(logits, self.weight)
        else:
            logits = self.out_linear(logits)
        return logits
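
Note on the load path above: it is what avoids the weight-transposition issue named in the commit title — if a checkpoint stores the eh_proj weight with its two axes swapped relative to the layer's parameter, the weight is transposed on load rather than raising a shape error. A minimal standalone sketch of the same check, with a plain paddle.nn.Linear and toy dimensions standing in for the distributed ColumnParallelLinear:

import paddle

hidden_size = 8                          # toy stand-in for the real hidden size
proj = paddle.nn.Linear(hidden_size * 2, hidden_size)

# Paddle keeps Linear weights as [in_features, out_features] = [16, 8].
# Suppose the checkpoint saved the projection the other way around, as [8, 16].
ckpt_weight = paddle.randn([hidden_size, hidden_size * 2])

# Same rule as ParallelEHProjection.load_state_dict: transpose only on shape mismatch.
if proj.weight.shape != ckpt_weight.shape:
    ckpt_weight = ckpt_weight.transpose([1, 0])
proj.weight.set_value(ckpt_weight)
print(proj.weight.shape)                 # [16, 8]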

fastdeploy/model_executor/models/ernie4_5_mtp.py

Lines changed: 2 additions & 2 deletions
@@ -26,7 +26,7 @@
 from paddleformers.utils.log import logger
 
 from fastdeploy.config import FDConfig, ModelConfig
-from fastdeploy.model_executor.layers.lm_head import ParallelLMHead
+from fastdeploy.model_executor.layers.mtp_linear import ParallelEHProjection
 from fastdeploy.model_executor.layers.normalization import RMSNorm
 from fastdeploy.model_executor.models.ernie4_5_moe import Ernie4_5_DecoderLayer
 from fastdeploy.model_executor.models.model_base import ModelForCasualLM
@@ -286,7 +286,7 @@ def __init__(
             prefix="ernie.mtp_hidden_norm.0",
         )
 
-        self.eh_proj = ParallelLMHead(
+        self.eh_proj = ParallelEHProjection(
             fd_config=fd_config,
             num_embeddings=fd_config.model_config.hidden_size,
             embedding_dim=fd_config.model_config.hidden_size * 2,
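
In this wiring, embedding_dim (hidden_size * 2) is the input width and num_embeddings (hidden_size) is the output width, so the non-EP branch builds a parallel linear whose full weight is [2 * hidden_size, hidden_size] — presumably projecting the concatenated embedding and hidden state back down to hidden_size, as the class name suggests. A quick shape check with a plain nn.Linear standing in for the sharded layer (illustrative dimensions only):

import paddle
from paddle import nn

hidden_size = 8                          # stand-in for fd_config.model_config.hidden_size
eh_proj = nn.Linear(hidden_size * 2, hidden_size)

x = paddle.randn([4, hidden_size * 2])   # e.g. a batch of concatenated [embedding ; hidden_state]
print(eh_proj.weight.shape)              # [16, 8]
print(eh_proj(x).shape)                  # [4, 8]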

fastdeploy/spec_decode/mtp.py

Lines changed: 1 addition & 2 deletions
@@ -68,8 +68,7 @@ def _update_cfg(self, main_model):
         """
         Update config for MTP from global config
         """
-        self.model_config.architectures[0] = self.model_config.architectures[
-            0].replace("MoeForCausalLM", "MTPForCausalLM")
+        self.model_config.architectures[0] = "Ernie4_5_MTPForCausalLM"
         self.speculative_config.sharing_model = main_model
         self.model_config.num_layers = 1
         self.parallel_config.model_name_or_path = (
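
A likely reason for pinning the architecture name instead of deriving it via replace (the commit message does not state one): str.replace is a silent no-op when the pattern is absent, so any architectures entry without the "MoeForCausalLM" substring would pass through unchanged. A tiny illustration with hypothetical architecture strings:

# Old approach: rewrite the name by substring replacement.
print("Ernie4_5_MoeForCausalLM".replace("MoeForCausalLM", "MTPForCausalLM"))
# -> Ernie4_5_MTPForCausalLM

# If the entry lacks that substring (e.g. a dense checkpoint), replace() changes nothing.
print("Ernie4_5ForCausalLM".replace("MoeForCausalLM", "MTPForCausalLM"))
# -> Ernie4_5ForCausalLM  (unchanged)

# New approach: set the MTP architecture explicitly.
arch = "Ernie4_5_MTPForCausalLM"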
