Commit 01f317c

[DSV3] Add PP support for DSV3
1 parent b74918a commit 01f317c

File tree

5 files changed: +386 -9 lines changed

Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
Download tokenizer:

```
# DeepSeek tokenizer (automatically downloads tokenizer.json and tokenizer_config.json)
python scripts/download_tokenizer.py --repo_id deepseek-ai/DeepSeek-V3
```

Run:

Single GPU - debug_model
```
NGPU=1 LOG_RANK=0 CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" ./run_train.sh
```

FSDP:

```
NGPU=8 LOG_RANK=0 CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" ./run_train.sh --parallelism.data_parallel_shard_degree 8

# OOM
NGPU=8 LOG_RANK=0 CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml" ./run_train.sh --parallelism.data_parallel_shard_degree 8
```

PP:

For additional logging, use `TORCH_LOGS=+pp`.

```
NGPU=2 LOG_RANK=0,1 CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" ./run_train.sh --parallelism.pipeline_parallel_degree 2

NGPU=4 LOG_RANK=0,3 CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" ./run_train.sh --parallelism.pipeline_parallel_degree 4

# works with AC=none, but why doesn't this work with AC=full?
NGPU=8 LOG_RANK=0,7 CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml" ./run_train.sh --parallelism.pipeline_parallel_degree 8 --parallelism.pipeline_parallel_schedule Interleaved1F1B
```
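As a quick illustration of how these PP runs split the model, here is a minimal sketch (not part of this commit) that calls the `generate_module_names_per_stage` helper added in the new pipeline module below; the 6-layer model used here is a made-up size for illustration, not the actual debug_model config:

```
# Hypothetical example: inspect the per-stage module split produced by the helper
# added in this commit. The 6-layer count is an assumed value for illustration.
from torchtitan.models.deepseek_v3.infra.pipeline import generate_module_names_per_stage

for stage, modules in enumerate(generate_module_names_per_stage(num_stages=2, num_layers=6)):
    print(stage, modules)

# With the default input_weight=1 and output_weight=1 this prints:
# 0 ['tok_embeddings', 'layers.0', 'layers.1', 'layers.2']
# 1 ['layers.3', 'layers.4', 'layers.5', 'norm', 'output']
```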

torchtitan/models/deepseek_v3/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -15,6 +15,7 @@
 from torchtitan.protocols.train_spec import register_train_spec, TrainSpec

 from .infra.parallelize import parallelize_deepseekv3
+from .infra.pipeline import pipeline_deepseekv3
 from .model.args import DeepSeekV3ModelArgs
 from .model.model import DeepSeekV3Model

@@ -116,7 +117,7 @@
         cls=DeepSeekV3Model,
         config=deepseekv3_configs,
         parallelize_fn=parallelize_deepseekv3,
-        pipelining_fn=None,
+        pipelining_fn=pipeline_deepseekv3,
         build_optimizers_fn=build_optimizers,
         build_lr_schedulers_fn=build_lr_schedulers,
         build_dataloader_fn=build_hf_dataloader,

Lines changed: 339 additions & 0 deletions
@@ -0,0 +1,339 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# This file applies PT-D pipeline parallelism to the DeepSeek-V3 model.

import copy

import torch
import torch.nn as nn
from torch.distributed import DeviceMesh
from torch.distributed.pipelining import PipelineStage
from torch.distributed.pipelining.schedules import (
    _PipelineSchedule,
    get_schedule_class,
    PipelineScheduleSingle,
    ScheduleZBVZeroBubble,
)

from torchtitan.components.loss import LossFunction
from torchtitan.config_manager import JobConfig
from torchtitan.distributed import ParallelDims
from torchtitan.distributed.pipeline import (
    build_pipeline_schedule,
    stage_ids_this_rank,
)
from torchtitan.protocols.train_spec import DeviceType, ParallelizeFunction
from torchtitan.tools.logging import logger

# Should I use BaseModelArgs instead of DeepSeekV3ModelArgs?
from ..model.args import DeepSeekV3ModelArgs


def _pipeline_friendly_forward(self, tokens: torch.Tensor):
    """
    Pipeline-friendly forward pass for the DeepSeekV3 model.
    This method is only used when pipeline parallelism is enabled.
    If model attributes are None, they are skipped in the forward pass.

    Args:
        tokens (torch.Tensor): Input tensor of token IDs with shape (batch_size, seq_len).

    Returns:
        torch.Tensor: Logits tensor of shape (batch_size, seq_len, vocab_size).
    """
    h = self.tok_embeddings(tokens) if self.tok_embeddings is not None else tokens
    # h: (batch_size, seq_len, dim)
    for layer in self.layers.values():
        h = layer(h, self.freqs_cis)
    h = self.norm(h) if self.norm is not None else h
    output = self.output(h) if self.output is not None else h
    return output


def _patch_model_for_pipeline(model: nn.Module):
    """
    Patches the model's forward method to be pipeline-friendly.
    This only affects models used in pipeline parallelism.

    Args:
        model: The model to patch
    """
    # Store the original forward method
    if not hasattr(model, "_original_forward"):
        model._original_forward = model.forward
    # Replace with pipeline-friendly version
    model.forward = _pipeline_friendly_forward.__get__(model, model.__class__)


def generate_module_names_per_stage(
    num_stages: int,
    num_layers: int,
    input_weight: int = 1,
    output_weight: int = 1,
) -> list[list[str]]:
    """
    Programmatically generates module names per stage for pipeline parallelism with weighting.

    Args:
        num_stages: Number of pipeline stages
        num_layers: Total number of transformer layers in the model
        input_weight: Weight for input modules (tok_embeddings) in layer calculation
        output_weight: Weight for output modules (norm + output) in layer calculation

    Returns:
        List of lists containing module names for each stage

    Example:
        generate_module_names_per_stage(2, 3, input_weight=2, output_weight=2)
        treats embeddings as 2 layers and norm+output as 2 layers for distribution
    """
    if num_stages < 1:
        raise ValueError("Number of stages must be at least 1")

    if num_stages == 1:
        # Single stage gets everything
        layer_names = [f"layers.{i}" for i in range(num_layers)]
        return [["tok_embeddings"] + layer_names + ["norm", "output"]]

    # Calculate effective layers including weights
    num_effective_layers = num_layers + input_weight + output_weight

    if num_stages > num_effective_layers:
        raise ValueError(
            f"Number of stages ({num_stages}) cannot be greater than effective layers ({num_effective_layers})"
        )

    # Calculate layers per stage (distribute evenly)
    layers_per_stage = num_effective_layers // num_stages
    extra_layers = num_effective_layers % num_stages

    # Ensure each stage gets at least the weight of input/output modules
    if layers_per_stage < max(input_weight, output_weight):
        raise ValueError(
            f"Layers per stage ({layers_per_stage}) must be >= max(input_weight={input_weight}, output_weight={output_weight})"
        )

    module_names_per_stage = []
    current_layer = 0

    for stage_idx in range(num_stages):
        stage_modules = []

        # Calculate effective layers for this stage
        effective_layers_for_stage = layers_per_stage
        if stage_idx < extra_layers:
            effective_layers_for_stage += 1

        # First stage: handle input modules with weighting
        if stage_idx == 0:
            stage_modules.append("tok_embeddings")
            # Account for input weight in layer distribution
            remaining_layers_for_stage = effective_layers_for_stage - input_weight

            # Add transformer layers
            for _ in range(remaining_layers_for_stage):
                if current_layer < num_layers:
                    stage_modules.append(f"layers.{current_layer}")
                    current_layer += 1

        # Last stage: handle output modules with weighting
        elif stage_idx == num_stages - 1:
            # Account for output weight in layer distribution
            remaining_layers_for_stage = effective_layers_for_stage - output_weight

            # Add transformer layers
            for _ in range(remaining_layers_for_stage):
                if current_layer < num_layers:
                    stage_modules.append(f"layers.{current_layer}")
                    current_layer += 1

            # Add output modules
            stage_modules.extend(["norm", "output"])

        # Middle stages: only transformer layers
        else:
            for _ in range(effective_layers_for_stage):
                if current_layer < num_layers:
                    stage_modules.append(f"layers.{current_layer}")
                    current_layer += 1

        module_names_per_stage.append(stage_modules)

    return module_names_per_stage


def pipeline_deepseekv3(
    model: nn.Module,
    world_mesh: DeviceMesh,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
    device: DeviceType,
    model_config: DeepSeekV3ModelArgs,
    parallelize_fn: ParallelizeFunction,
    loss_fn: LossFunction,
) -> tuple[_PipelineSchedule, list[nn.Module], bool, bool]:
    pp_mesh = world_mesh["pp"]

    # Determine the number of virtual stages based on schedule type
    schedule_class = get_schedule_class(
        job_config.parallelism.pipeline_parallel_schedule
    )
    is_single_stage_schedule = issubclass(schedule_class, PipelineScheduleSingle)

    # For multi-stage schedules, default is 2 virtual stages per rank
    # For single-stage schedules, default is 1 virtual stage per rank
    stages_per_rank = 1 if is_single_stage_schedule else 2
    num_virtual_stages = parallel_dims.pp * stages_per_rank

    # Generate module names per stage programmatically with weighting
    num_layers = model_config.n_layers

    # You can adjust these weights based on the computational cost of embeddings and output layers
    # Higher weights mean these modules are treated as "heavier" in the distribution
    input_weight = 1  # Weight for tok_embeddings
    output_weight = 1  # Weight for norm + output layers

    module_names_per_stage = generate_module_names_per_stage(
        num_virtual_stages, num_layers, input_weight, output_weight
    )
    for i, stage_ms in enumerate(module_names_per_stage):
        logger.info(f"Stage {i}: {stage_ms}")

    stages, model_parts = pipeline_deepseekv3_module_split(
        model, pp_mesh, parallel_dims, job_config, device, module_names_per_stage
    )

    # For PP with looped schedules, each item in model_parts is one stage-model-chunk.
    # We need to iterate through model_parts to apply SPMD parallelisms, compilation,
    # optimizer, and checkpointing
    for i, m in enumerate(model_parts):
        # apply SPMD-style PT-D techniques
        m = parallelize_fn(m, world_mesh, parallel_dims, job_config)
        model_parts[i] = m
        # NOTE: this is to update the model in the stage
        # in case the model is modified e.g. by torch.compile
        stages[i].submod = m

    pp_schedule = build_pipeline_schedule(job_config, stages, loss_fn)

    # This is used in the train loop to determine whether to pass in the input_ids and labels
    has_first_stage = False
    has_last_stage = False
    for stage in stages:
        if stage.is_first:
            has_first_stage = True
        if stage.is_last:
            has_last_stage = True

    return pp_schedule, model_parts, has_first_stage, has_last_stage


def pipeline_deepseekv3_module_split(
    whole_model: nn.Module,
    pp_mesh: DeviceMesh,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
    device: DeviceType,
    module_names_per_stage: list[list[str]],
) -> tuple[list[PipelineStage], list[nn.Module]]:
    """
    This API creates pipeline stages based on specified module names for each stage.

    Args:
        whole_model: The complete DeepSeekV3Model to be split
        pp_mesh: Pipeline parallel device mesh
        parallel_dims: Parallel dimensions configuration
        job_config: Job configuration
        device: Device type
        module_names_per_stage: List of lists, where each inner list contains the module names
                                that should be included in that stage. Module names should be
                                dot-separated paths. Examples:
                                - "tok_embeddings" for token embeddings
                                - "layers.0", "layers.1" for specific transformer layers
                                - "norm" for the final normalization layer
                                - "output" for the output projection layer

    Returns:
        Tuple of (stages, models) where stages are PipelineStage objects and models are the
        corresponding model chunks

    Example usage:
        module_names_per_stage = [
            ["tok_embeddings", "layers.0"],  # Stage 0: embeddings + first layer
            ["layers.1", "layers.2"],        # Stage 1: middle layers
            ["norm", "output"]               # Stage 2: final norm + output
        ]
    """
    pp_rank = pp_mesh.get_local_rank()
    pp_size = pp_mesh.size()
    parallelism_config = job_config.parallelism

    def _build_stage_from_modules(
        stage_idx: int,
        module_names: list[str],
        is_first: bool = False,
        is_last: bool = False,
    ) -> tuple[PipelineStage, nn.Module]:
        model = copy.deepcopy(whole_model)

        # Patch the model to use pipeline-friendly forward method
        _patch_model_for_pipeline(model)

        # Create a set of modules to keep for faster lookup
        modules_to_keep = set(module_names)

        # Handle embeddings - remove if not assigned to this stage
        if "tok_embeddings" not in modules_to_keep:
            model.tok_embeddings = None

        # Handle layers - remove layers not specified for this stage
        layers_to_keep = set()
        for name in modules_to_keep:
            if name.startswith("layers."):
                # Extract layer number (e.g., "layers.0" -> "0")
                layer_num = name.split(".", 1)[1]
                layers_to_keep.add(layer_num)

        # Remove layers not in this stage
        for layer_name in list(model.layers.keys()):
            if layer_name not in layers_to_keep:
                del model.layers[layer_name]

        # Handle final normalization layer
        if "norm" not in modules_to_keep:
            model.norm = None

        # Handle output projection layer
        if "output" not in modules_to_keep:
            model.output = None

        stage = PipelineStage(
            model,
            stage_idx,
            len(module_names_per_stage),
            device,
            group=pp_mesh.get_group("pp"),
        )
        return stage, model

    num_stages = len(module_names_per_stage)
    stages = []
    models = []

    schedule_class = get_schedule_class(parallelism_config.pipeline_parallel_schedule)
    style = "v" if schedule_class == ScheduleZBVZeroBubble else "loop"

    for stage_idx in stage_ids_this_rank(pp_rank, pp_size, num_stages, style=style):
        module_names = module_names_per_stage[stage_idx]
        stage, model_chunk = _build_stage_from_modules(
            stage_idx,
            module_names,
            is_first=stage_idx == 0,
            is_last=stage_idx == num_stages - 1,
        )
        logger.info(
            f"PP rank {pp_rank} is building stage_idx {stage_idx} "
            f"with modules {module_names}"
        )
        stages.append(stage)
        models.append(model_chunk)

    return stages, models
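The `has_first_stage` / `has_last_stage` flags returned by `pipeline_deepseekv3` tell the train loop whether this rank should feed in token inputs and labels (see the comment above the return). A rough, hypothetical sketch of that gating, built on the `target=`/`losses=` keywords of `_PipelineSchedule.step`; this is not the torchtitan trainer's actual code, and the function and variable names are illustrative:

```
# Hypothetical sketch (not the actual torchtitan train loop): how the flags returned
# by pipeline_deepseekv3 could gate what gets passed into the pipeline schedule.
import torch
from torch.distributed.pipelining.schedules import _PipelineSchedule


def pp_forward_backward(
    pp_schedule: _PipelineSchedule,
    has_first_stage: bool,
    has_last_stage: bool,
    input_ids: torch.Tensor,
    labels: torch.Tensor,
):
    # Only the last stage computes the loss, so only it needs targets and a losses list.
    targets, losses = (labels, []) if has_last_stage else (None, None)
    if has_first_stage:
        # The first stage consumes the token inputs; later stages receive
        # activations from the previous stage over P2P.
        pp_schedule.step(input_ids, target=targets, losses=losses)
    else:
        pp_schedule.step(target=targets, losses=losses)
    # Reduce the per-microbatch losses on the last stage; other stages report nothing.
    return torch.mean(torch.stack(losses)) if has_last_stage else None
```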

0 commit comments
