
Commit d5988d0

[DSV3] Add PP support for DSV3
1 parent 7aff172 commit d5988d0

File tree

7 files changed (+337, −10 lines)

torchtitan/models/deepseek_v3/README.md

Lines changed: 1 addition & 1 deletion

```diff
@@ -33,14 +33,14 @@ CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml"
 - Activation checkpointing
 - Tensor Parallel (TP)
 - Expert Parallel (EP)
+- Pipeline Parallel (PP)


 ## To be added
 - Modeling
   - Merge DeepSeek-V3 and Llama4 MoE common components
 - Parallelism
   - Context Parallel support for DeepSeek-V3
-  - PP support for DeepSeek-V3
 - torch.compile
 - Quantization
 - Testing
```

torchtitan/models/deepseek_v3/__init__.py

Lines changed: 2 additions & 1 deletion

```diff
@@ -15,6 +15,7 @@
 from torchtitan.protocols.train_spec import register_train_spec, TrainSpec

 from .infra.parallelize import parallelize_deepseekv3
+from .infra.pipeline import pipeline_deepseekv3
 from .model.args import DeepSeekV3ModelArgs
 from .model.model import DeepSeekV3Model

@@ -116,7 +117,7 @@
         cls=DeepSeekV3Model,
         config=deepseekv3_configs,
         parallelize_fn=parallelize_deepseekv3,
-        pipelining_fn=None,
+        pipelining_fn=pipeline_deepseekv3,
         build_optimizers_fn=build_llama4_optimizers,  # use optimizer hooks to update expert weights
         build_lr_schedulers_fn=build_lr_schedulers,
         build_dataloader_fn=build_hf_dataloader,
```
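For context on how this registration is consumed: when pipeline parallelism is enabled, the trainer looks up `pipelining_fn` on the registered TrainSpec and calls it with the signature that `pipeline_deepseekv3` (below) defines. A hedged paraphrase of that dispatch, with illustrative variable names rather than a quote of torchtitan's train loop:

```python
# Hedged paraphrase of the trainer-side dispatch (illustrative, not a quote
# of torchtitan's train loop). The argument list mirrors the signature of
# pipeline_deepseekv3 in the new file below.
if parallel_dims.pp_enabled:
    (
        pp_schedule,
        model_parts,
        has_first_stage,
        has_last_stage,
    ) = train_spec.pipelining_fn(
        model,
        world_mesh,
        parallel_dims,
        job_config,
        device,
        model_args,
        train_spec.parallelize_fn,
        loss_fn,
    )
```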
torchtitan/models/deepseek_v3/infra/pipeline.py (new file)

Lines changed: 310 additions & 0 deletions

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# This file applies the PT-D pipeline parallelism to the DeepSeek-V3 model.

import copy

import torch.nn as nn
from torch.distributed import DeviceMesh
from torch.distributed.pipelining import PipelineStage
from torch.distributed.pipelining.schedules import (
    _PipelineSchedule,
    get_schedule_class,
    PipelineScheduleSingle,
    ScheduleZBVZeroBubble,
)

from torchtitan.components.loss import LossFunction
from torchtitan.config_manager import JobConfig
from torchtitan.distributed import ParallelDims
from torchtitan.distributed.pipeline import build_pipeline_schedule, stage_ids_this_rank
from torchtitan.protocols.train_spec import DeviceType, ParallelizeFunction
from torchtitan.tools.logging import logger

from ..model.args import DeepSeekV3ModelArgs


def generate_module_names_per_stage(
    num_stages: int,
    num_layers: int,
    input_weight: int = 1,
    output_weight: int = 1,
) -> list[list[str]]:
    """
    Programmatically generates module names per stage for pipeline parallelism with weighting.

    Args:
        num_stages: Number of pipeline stages
        num_layers: Total number of transformer layers in the model
        input_weight: Weight for input modules (tok_embeddings) in layer calculation
        output_weight: Weight for output modules (norm + output) in layer calculation

    Returns:
        List of lists containing module names for each stage

    Example:
        generate_module_names_per_stage(2, 3, input_weight=2, output_weight=2)
        treats embeddings as 2 layers and norm+output as 2 layers for distribution
    """
    if num_stages < 1:
        raise ValueError("Number of stages must be at least 1")

    if num_stages == 1:
        # Single stage gets everything
        layer_names = [f"layers.{i}" for i in range(num_layers)]
        return [["tok_embeddings"] + layer_names + ["norm", "output"]]

    # Calculate effective layers including weights
    num_effective_layers = num_layers + input_weight + output_weight

    if num_stages > num_effective_layers:
        raise ValueError(
            f"Number of stages ({num_stages}) cannot be greater than effective layers ({num_effective_layers})"
        )

    # Calculate layers per stage (distribute evenly)
    layers_per_stage = num_effective_layers // num_stages
    extra_layers = num_effective_layers % num_stages

    # Ensure each stage gets at least the weight of input/output modules
    if layers_per_stage < max(input_weight, output_weight):
        raise ValueError(
            f"Layers per stage ({layers_per_stage}) must be >= max(input_weight={input_weight}, output_weight={output_weight})"
        )

    module_names_per_stage = []
    current_layer = 0

    for stage_idx in range(num_stages):
        stage_modules = []

        # Calculate effective layers for this stage
        effective_layers_for_stage = layers_per_stage
        if stage_idx < extra_layers:
            effective_layers_for_stage += 1

        # First stage: handle input modules with weighting
        if stage_idx == 0:
            stage_modules.append("tok_embeddings")
            # Account for input weight in layer distribution
            remaining_layers_for_stage = effective_layers_for_stage - input_weight

            # Add transformer layers
            for _ in range(remaining_layers_for_stage):
                if current_layer < num_layers:
                    stage_modules.append(f"layers.{current_layer}")
                    current_layer += 1

        # Last stage: handle output modules with weighting
        elif stage_idx == num_stages - 1:
            # Account for output weight in layer distribution
            remaining_layers_for_stage = effective_layers_for_stage - output_weight

            # Add transformer layers
            for _ in range(remaining_layers_for_stage):
                if current_layer < num_layers:
                    stage_modules.append(f"layers.{current_layer}")
                    current_layer += 1

            # Add output modules
            stage_modules.extend(["norm", "output"])

        # Middle stages: only transformer layers
        else:
            for _ in range(effective_layers_for_stage):
                if current_layer < num_layers:
                    stage_modules.append(f"layers.{current_layer}")
                    current_layer += 1

        module_names_per_stage.append(stage_modules)

    return module_names_per_stage
```
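To make the weighted split concrete, here is a small usage sketch. The import path is an assumption inferred from the `from .infra.pipeline import pipeline_deepseekv3` line added to `__init__.py`. With 4 virtual stages, 10 layers, and the default weights, there are 10 + 1 + 1 = 12 effective layers, so each stage receives 3:

```python
# Usage sketch (import path inferred from __init__.py, not stated in the diff).
from torchtitan.models.deepseek_v3.infra.pipeline import (
    generate_module_names_per_stage,
)

# 10 transformer layers over 4 virtual stages; tok_embeddings and norm+output
# each count as one "effective layer", so 12 effective layers -> 3 per stage.
for i, names in enumerate(generate_module_names_per_stage(4, 10)):
    print(i, names)
# 0 ['tok_embeddings', 'layers.0', 'layers.1']
# 1 ['layers.2', 'layers.3', 'layers.4']
# 2 ['layers.5', 'layers.6', 'layers.7']
# 3 ['layers.8', 'layers.9', 'norm', 'output']
```

The file continues with the entry point registered as `pipelining_fn`: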
```python
def pipeline_deepseekv3(
    model: nn.Module,
    world_mesh: DeviceMesh,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
    device: DeviceType,
    model_config: DeepSeekV3ModelArgs,
    parallelize_fn: ParallelizeFunction,
    loss_fn: LossFunction,
) -> tuple[_PipelineSchedule, list[nn.Module], bool, bool]:
    pp_mesh = world_mesh["pp"]

    # Determine the number of virtual stages based on schedule type
    schedule_class = get_schedule_class(
        job_config.parallelism.pipeline_parallel_schedule
    )
    is_single_stage_schedule = issubclass(schedule_class, PipelineScheduleSingle)

    # For multi-stage schedules, default is 2 virtual stages per rank
    # For single-stage schedules, default is 1 virtual stage per rank
    stages_per_rank = 1 if is_single_stage_schedule else 2
    num_virtual_stages = parallel_dims.pp * stages_per_rank

    # Generate module names per stage programmatically with weighting
    num_layers = model_config.n_layers

    # You can adjust these weights based on the computational cost of embeddings and output layers
    # Higher weights mean these modules are treated as "heavier" in the distribution
    input_weight = 1  # Weight for tok_embeddings
    output_weight = 1  # Weight for norm + output layers

    module_names_per_stage = generate_module_names_per_stage(
        num_virtual_stages, num_layers, input_weight, output_weight
    )
    for i, stage_ms in enumerate(module_names_per_stage):
        logger.info(f"Stage {i}: {stage_ms}")

    stages, model_parts = pipeline_module_split(
        model,
        pp_mesh,
        job_config.parallelism.pipeline_parallel_schedule,
        device,
        module_names_per_stage,
    )

    # For PP with looped schedules, each item in model_parts is one stage-model-chunk.
    # We need to iterate through model_parts to apply SPMD parallelisms, compilation,
    # optimizer, and checkpointing
    for i, m in enumerate(model_parts):
        # apply SPMD-style PT-D techniques
        m = parallelize_fn(m, world_mesh, parallel_dims, job_config)
        model_parts[i] = m
        # NOTE: this is to update the model in the stage
        # in case the model is modified e.g. by torch.compile
        stages[i].submod = m

    pp_schedule = build_pipeline_schedule(job_config, stages, loss_fn)

    # This is used in the train loop to determine whether to pass in the input_ids and labels
    has_first_stage = False
    has_last_stage = False
    for stage in stages:
        if stage.is_first:
            has_first_stage = True
        if stage.is_last:
            has_last_stage = True

    return pp_schedule, model_parts, has_first_stage, has_last_stage
```
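The two boolean flags matter because, with looped or V-shaped schedules, a rank can host the first stage, the last stage, both, or neither. A minimal sketch of how a train loop typically consumes them, assuming `input_ids`, `labels`, and `pp_schedule` exist in scope; this paraphrases the usual torchtitan pattern rather than quoting it:

```python
import torch

# Hedged sketch of the consuming train loop (illustrative names, not the
# commit's code). Only the rank holding the first stage feeds input tokens;
# only the rank holding the last stage collects per-microbatch losses.
targets, losses = (labels, []) if has_last_stage else (None, None)
if has_first_stage:
    pp_schedule.step(input_ids, target=targets, losses=losses)
else:
    pp_schedule.step(target=targets, losses=losses)
loss = torch.mean(torch.stack(losses)) if has_last_stage else torch.tensor(-1.0)
```

The module-level splitting that the schedule relies on is implemented next: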
```python
def pipeline_module_split(
    whole_model: nn.Module,
    pp_mesh: DeviceMesh,
    pp_schedule: str,
    device: DeviceType,
    module_names_per_stage: list[list[str]],
) -> tuple[list[PipelineStage], list[nn.Module]]:
    """
    This API creates pipeline stages based on specified module names for each stage.

    Args:
        whole_model: The complete model to be split
        pp_mesh: Pipeline parallel device mesh
        pp_schedule: Name of pipeline parallelism schedule
        device: Device type
        module_names_per_stage: List of lists, where each inner list contains the module names
            that should be included in that stage. Module names should be dot-separated paths.
            Examples:
            - "tok_embeddings" for token embeddings
            - "layers.0", "layers.1" for specific transformer layers
            - "norm" for the final normalization layer
            - "output" for the output projection layer

    Returns:
        Tuple of (stages, models) where stages are PipelineStage objects and models are the
        corresponding model chunks

    Example usage:
        module_names_per_stage = [
            ["tok_embeddings", "layers.0"],  # Stage 0: embeddings + first layer
            ["layers.1", "layers.2"],        # Stage 1: middle layers
            ["norm", "output"],              # Stage 2: final norm + output
        ]
    """
    pp_rank = pp_mesh.get_local_rank()
    pp_size = pp_mesh.size()

    def _build_stage_from_modules(
        stage_idx: int, module_names: list[str], num_stages: int
    ) -> tuple[PipelineStage, nn.Module]:
        model = copy.deepcopy(whole_model)

        # Create a set of modules to keep for faster lookup
        modules_to_keep = set(module_names)
        logger.info(f"Stage {stage_idx}: Modules to keep: {modules_to_keep}")
        for module_name, module_value in model.named_children():
            # Handle layer-like structures (e.g., "layers.0", "layers.1")
            if isinstance(module_value, (nn.ModuleDict, nn.ModuleList)):
                layers_to_keep = {
                    name.split(".", 1)[1]
                    for name in modules_to_keep
                    if name.startswith(f"{module_name}.")
                }
                if layers_to_keep:
                    # Keep only specified layers
                    if isinstance(module_value, nn.ModuleDict):
                        for layer_name in list(module_value.keys()):
                            if layer_name not in layers_to_keep:
                                del module_value[layer_name]
                    elif isinstance(module_value, nn.ModuleList):
                        indices_to_keep = {
                            int(idx) for idx in layers_to_keep if idx.isdigit()
                        }
                        new_layers = nn.ModuleList(
                            [
                                layer
                                for i, layer in enumerate(module_value)
                                if i in indices_to_keep
                            ]
                        )
                        setattr(model, module_name, new_layers)
                else:
                    # No layers from this structure needed, set to empty structure
                    if isinstance(module_value, nn.ModuleDict):
                        setattr(model, module_name, nn.ModuleDict())
                    elif isinstance(module_value, nn.ModuleList):
                        setattr(model, module_name, nn.ModuleList())
            # Handle simple module attributes (e.g., "linear", "norm")
            elif module_name not in modules_to_keep:
                # Replace with identity module instead of None
                setattr(model, module_name, nn.Identity())

        stage = PipelineStage(
            model,
            stage_idx,
            num_stages,
            device,
            group=pp_mesh.get_group("pp"),
        )
        return stage, model

    num_stages = len(module_names_per_stage)
    stages = []
    models = []

    schedule_class = get_schedule_class(pp_schedule)
    style = "v" if schedule_class == ScheduleZBVZeroBubble else "loop"

    for stage_idx in stage_ids_this_rank(pp_rank, pp_size, num_stages, style=style):
        module_names = module_names_per_stage[stage_idx]
        stage, model_chunk = _build_stage_from_modules(
            stage_idx,
            module_names,
            num_stages,
        )
        logger.info(
            f"PP rank {pp_rank} is building stage_idx {stage_idx} "
            f"with modules {module_names}"
        )
        stages.append(stage)
        models.append(model_chunk)

    return stages, models
```
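The split works by deep-copying the whole model and deleting what a stage does not own, rather than tracing the graph. A toy, self-contained illustration of the same pruning idea (the `ToyModel` here is hypothetical, not the DeepSeek-V3 model):

```python
# Toy illustration (not part of the commit) of the pruning strategy used by
# _build_stage_from_modules: deep-copy the whole model, keep only the named
# modules for this stage, and stub everything else out.
import copy

import torch.nn as nn


class ToyModel(nn.Module):
    """Hypothetical stand-in with the attribute layout the splitter expects."""

    def __init__(self) -> None:
        super().__init__()
        self.tok_embeddings = nn.Embedding(100, 16)
        self.layers = nn.ModuleList(nn.Linear(16, 16) for _ in range(4))
        self.norm = nn.LayerNorm(16)
        self.output = nn.Linear(16, 100)


keep = {"layers.2", "layers.3", "norm", "output"}  # a last-stage chunk

chunk = copy.deepcopy(ToyModel())
# Prune the ModuleList down to the owned layers ...
chunk.layers = nn.ModuleList(
    layer for i, layer in enumerate(chunk.layers) if f"layers.{i}" in keep
)
# ... and stub out plain submodules this stage does not own.
for name in ("tok_embeddings", "norm", "output"):
    if name not in keep:
        setattr(chunk, name, nn.Identity())
```

One design consequence worth noting: a rebuilt `nn.ModuleList` reindexes its entries from zero, which is equally true of `_build_stage_from_modules` above, so per-stage parameter names no longer match the whole-model layer indices.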

torchtitan/models/deepseek_v3/model/args.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -75,7 +75,7 @@ class DeepSeekV3ModelArgs(BaseModelArgs):
     n_limited_groups: int = 1
     score_func: Literal["softmax", "sigmoid"] = "softmax"
     route_scale: float = 1.0
-    use_grouped_mm: bool = True
+    use_grouped_mm: bool = False
     load_balance_coeff: float = 1e-3
     # Multi-Head Latent Attention (MLA)
     q_lora_rank: int = 0
```
