Commit 9a66467

rebase onto main branch
1 parent a1a4f6c

3 files changed: +44, -9 lines

torchtitan/models/deepseek_v3/README.md (2 additions, 0 deletions)

````diff
@@ -1,3 +1,5 @@
+# DeepSeek-V3 in torchtitan
+
 Download tokenizer:

 ```
````

torchtitan/models/deepseek_v3/infra/parallelize.py (40 additions, 8 deletions)

```diff
@@ -18,18 +18,18 @@
 from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
 from torchtitan.distributed import ParallelDims
 from torchtitan.experiments.llama4.infra.expert_parallel import NoParallel
-from torchtitan.experiments.llama4.infra.parallelize import apply_moe_ep_tp
-from torchtitan.models.llama3.infra.parallelize import apply_ac, apply_fsdp
+from torchtitan.experiments.llama4.infra.parallelize import apply_fsdp, apply_moe_ep_tp
+from torchtitan.models.llama3.infra.parallelize import apply_ac, apply_ddp
 from torchtitan.tools.logging import logger


+# Adapted from llama4/infra/parallelize.py
 def parallelize_deepseekv3(
     model: nn.Module,
     world_mesh: DeviceMesh,
     parallel_dims: ParallelDims,
     job_config: JobConfig,
 ):
-
     if parallel_dims.tp_enabled:
         if job_config.parallelism.enable_async_tensor_parallel:
             # TODO(jianiw): This branch needs to be tested and enabled
@@ -59,6 +59,7 @@ def parallelize_deepseekv3(
             enable_async_tp=False,
         )

+    if parallel_dims.tp_enabled or parallel_dims.ep_enabled:
         apply_moe_ep_tp(
             model,
             tp_mesh=world_mesh["tp"] if parallel_dims.tp_enabled else None,
```
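The first hunk swaps in the llama4 variant of `apply_fsdp` (which, per the later hunks, takes a separate mesh for MoE parameters) and pulls in `apply_ddp` for a new data-parallel fallback. The second hunk widens the guard so MoE sharding is applied whenever either tensor or expert parallelism is enabled, passing `None` for any mesh not in use. For reference, a standalone sketch of the `world_mesh["tp"]` slicing pattern used there; the 2×2×2 shape and dim names are illustrative assumptions, not the mesh torchtitan builds (run under `torchrun --nproc_per_node=8`):

```python
# Illustrative sketch of named DeviceMesh slicing, mirroring the
# tp_mesh=world_mesh["tp"] pattern in the hunk above. The shape and
# dim names are assumptions for this example only.
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

# 8 ranks arranged as 2 (data shard) x 2 (expert) x 2 (tensor)
world_mesh = init_device_mesh(
    "cuda", (2, 2, 2), mesh_dim_names=("dp_shard", "ep", "tp")
)

# Indexing by dim name returns the 1-D sub-mesh this rank belongs to.
tp_mesh = world_mesh["tp"]
ep_mesh = world_mesh["ep"]
print(f"rank {dist.get_rank()}: tp_mesh={tp_mesh}, ep_mesh={ep_mesh}")
```

The remaining hunks: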
```diff
@@ -73,16 +74,26 @@
     if job_config.activation_checkpoint.mode != "none":
         apply_ac(model, job_config.activation_checkpoint)

+    # turn on per-TransformerBlock compile after AC wrapping and before FSDP
+    if job_config.training.compile:
+        raise NotImplementedError("torch.compile is not supported yet for deepseekv3")
+
     dp_mesh: DeviceMesh | None = None
-    if (
-        parallel_dims.dp_shard_enabled
-    ): # apply FSDP or HSDP, potentially with Context Parallel
+    if parallel_dims.fsdp_enabled or parallel_dims.ep_enabled:
+        # apply FSDP or HSDP, potentially with Context Parallel
         if parallel_dims.dp_replicate_enabled:
-            dp_mesh_dim_names = ("dp_replicate", "dp_shard")
+            dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp")
         else:
-            dp_mesh_dim_names = ("dp_shard",)
+            dp_mesh_dim_names = ("dp_shard_cp",)
         dp_mesh = world_mesh[tuple(dp_mesh_dim_names)]

+        # the mesh dim names of which the MoE params are sharded on via FSDP/HSDP
+        dp_mod_ep_mesh_dim_names = []
+        if parallel_dims.ep_enabled:
+            if parallel_dims.dp_replicate_enabled:
+                dp_mod_ep_mesh_dim_names.append("dp_replicate")
+            dp_mod_ep_mesh_dim_names.append("dp_shard_mod_ep")
+
         apply_fsdp(
             model,
             dp_mesh,
@@ -91,13 +102,34 @@
             pp_enabled=parallel_dims.pp_enabled,
             cpu_offload=job_config.training.enable_cpu_offload,
             reshard_after_forward_policy=job_config.parallelism.fsdp_reshard_after_forward,
+            dp_mod_ep_mesh=(
+                world_mesh[tuple(dp_mod_ep_mesh_dim_names)]
+                if dp_mod_ep_mesh_dim_names
+                else None
+            ),
         )

         if parallel_dims.dp_replicate_enabled:
             logger.info("Applied HSDP to the model")
         else:
             logger.info("Applied FSDP to the model")

+        if parallel_dims.cp_enabled:
+            logger.info("Applied Context Parallel to the model")
+
+        if job_config.training.enable_cpu_offload:
+            logger.info("Applied CPU Offloading to the model")
+    elif parallel_dims.dp_replicate_enabled:
+        if world_mesh.ndim > 1:
+            raise RuntimeError("DDP has not supported > 1D parallelism")
+        dp_mesh = world_mesh
+        apply_ddp(
+            model,
+            dp_mesh,
+            enable_compile=job_config.training.compile,
+            enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd,
+        )
+
     return model
```
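Net effect of the last two hunks: dense parameters are sharded over `dp_shard_cp` (with `dp_replicate` on top for HSDP), while MoE expert parameters, when EP is enabled, are sharded over `dp_shard_mod_ep` and handed to `apply_fsdp` via the new `dp_mod_ep_mesh` argument; plain DDP remains as a 1-D fallback. A minimal sketch of the dim-name selection, lifted out of the diff; the helper `pick_mesh_dim_names` is hypothetical, for illustration only:

```python
# Hypothetical helper restating the mesh-dim-name selection in the diff above.
def pick_mesh_dim_names(dp_replicate_enabled: bool, ep_enabled: bool):
    # Dense params: FSDP over dp_shard_cp, plus dp_replicate for HSDP.
    if dp_replicate_enabled:
        dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp")
    else:
        dp_mesh_dim_names = ("dp_shard_cp",)

    # MoE expert params: EP claims part of the shard dimension, so experts
    # are sharded over dp_shard_mod_ep instead of dp_shard_cp.
    dp_mod_ep_mesh_dim_names = []
    if ep_enabled:
        if dp_replicate_enabled:
            dp_mod_ep_mesh_dim_names.append("dp_replicate")
        dp_mod_ep_mesh_dim_names.append("dp_shard_mod_ep")

    return dp_mesh_dim_names, tuple(dp_mod_ep_mesh_dim_names)


# HSDP + EP: dense on (dp_replicate, dp_shard_cp),
#            experts on (dp_replicate, dp_shard_mod_ep)
print(pick_mesh_dim_names(dp_replicate_enabled=True, ep_enabled=True))
# FSDP without EP: dense on (dp_shard_cp,), no separate expert mesh
print(pick_mesh_dim_names(dp_replicate_enabled=False, ep_enabled=False))
```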
torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml (2 additions, 1 deletion)

```diff
@@ -49,8 +49,9 @@ dataset = "c4" # supported datasets: c4_test (2K), c4 (177M)
 data_parallel_replicate_degree = 1
 data_parallel_shard_degree = -1
 fsdp_reshard_after_forward = "default" # default / never / always
-tensor_parallel_degree = 2
+tensor_parallel_degree = 1
 enable_async_tensor_parallel = false
+expert_parallel_degree = 2

 [checkpoint]
 enable_checkpoint = false
```
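The config change matches: tensor parallelism is switched off (`tensor_parallel_degree = 1`) and 2-way expert parallelism is enabled instead. As a rough sanity check on how the degrees compose, here is a sketch assuming an 8-GPU launch and torchtitan's convention that `data_parallel_shard_degree = -1` means "fill the remaining world size"; the `dp_shard_mod_ep` naming in the diff suggests EP borrows ranks from the shard dimension rather than multiplying the world size:

```python
# Hypothetical sanity check of the parallelism degrees in the config above.
# Assumption: 8-GPU launch, e.g. torchrun --nproc_per_node=8.
world_size = 8

pp = 1            # pipeline parallel (not set in this config)
dp_replicate = 1  # data_parallel_replicate_degree
cp = 1            # context parallel (not set in this config)
tp = 1            # tensor_parallel_degree (was 2 before this commit)
ep = 2            # expert_parallel_degree (new in this commit)

# data_parallel_shard_degree = -1 -> infer from what the other degrees leave
dp_shard = world_size // (pp * dp_replicate * cp * tp)
assert pp * dp_replicate * dp_shard * cp * tp == world_size

# EP reuses ranks from the dp_shard/cp dimensions (hence "dp_shard_mod_ep"),
# so it must divide them evenly rather than add another factor.
assert (dp_shard * cp) % ep == 0
print(f"inferred data_parallel_shard_degree = {dp_shard}")  # -> 8
```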
