 from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
 from torchtitan.distributed import ParallelDims
 from torchtitan.experiments.llama4.infra.expert_parallel import NoParallel
-from torchtitan.experiments.llama4.infra.parallelize import apply_moe_ep_tp
-from torchtitan.models.llama3.infra.parallelize import apply_ac, apply_fsdp
+from torchtitan.experiments.llama4.infra.parallelize import apply_fsdp, apply_moe_ep_tp
+from torchtitan.models.llama3.infra.parallelize import apply_ac, apply_ddp
 from torchtitan.tools.logging import logger


+# Adapted from llama4/infra/parallelize.py
 def parallelize_deepseekv3(
     model: nn.Module,
     world_mesh: DeviceMesh,
     parallel_dims: ParallelDims,
     job_config: JobConfig,
 ):
-
     if parallel_dims.tp_enabled:
         if job_config.parallelism.enable_async_tensor_parallel:
             # TODO(jianiw): This branch needs to be tested and enabled
@@ -59,6 +59,7 @@ def parallelize_deepseekv3(
             enable_async_tp=False,
         )

+    if parallel_dims.tp_enabled or parallel_dims.ep_enabled:
         apply_moe_ep_tp(
             model,
             tp_mesh=world_mesh["tp"] if parallel_dims.tp_enabled else None,
@@ -73,16 +74,26 @@ def parallelize_deepseekv3(
     if job_config.activation_checkpoint.mode != "none":
         apply_ac(model, job_config.activation_checkpoint)

+    # turn on per-TransformerBlock compile after AC wrapping and before FSDP
+    if job_config.training.compile:
+        raise NotImplementedError("torch.compile is not supported yet for deepseekv3")
+
     dp_mesh: DeviceMesh | None = None
-    if (
-        parallel_dims.dp_shard_enabled
-    ):  # apply FSDP or HSDP, potentially with Context Parallel
+    if parallel_dims.fsdp_enabled or parallel_dims.ep_enabled:
+        # apply FSDP or HSDP, potentially with Context Parallel
         if parallel_dims.dp_replicate_enabled:
-            dp_mesh_dim_names = ("dp_replicate", "dp_shard")
+            dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp")
         else:
-            dp_mesh_dim_names = ("dp_shard",)
+            dp_mesh_dim_names = ("dp_shard_cp",)
         dp_mesh = world_mesh[tuple(dp_mesh_dim_names)]

+        # the mesh dim names of which the MoE params are sharded on via FSDP/HSDP
+        dp_mod_ep_mesh_dim_names = []
+        if parallel_dims.ep_enabled:
+            if parallel_dims.dp_replicate_enabled:
+                dp_mod_ep_mesh_dim_names.append("dp_replicate")
+            dp_mod_ep_mesh_dim_names.append("dp_shard_mod_ep")
+
         apply_fsdp(
             model,
             dp_mesh,
@@ -91,13 +102,34 @@ def parallelize_deepseekv3(
             pp_enabled=parallel_dims.pp_enabled,
             cpu_offload=job_config.training.enable_cpu_offload,
             reshard_after_forward_policy=job_config.parallelism.fsdp_reshard_after_forward,
+            dp_mod_ep_mesh=(
+                world_mesh[tuple(dp_mod_ep_mesh_dim_names)]
+                if dp_mod_ep_mesh_dim_names
+                else None
+            ),
         )

         if parallel_dims.dp_replicate_enabled:
             logger.info("Applied HSDP to the model")
         else:
             logger.info("Applied FSDP to the model")

+        if parallel_dims.cp_enabled:
+            logger.info("Applied Context Parallel to the model")
+
+        if job_config.training.enable_cpu_offload:
+            logger.info("Applied CPU Offloading to the model")
+    elif parallel_dims.dp_replicate_enabled:
+        if world_mesh.ndim > 1:
+            raise RuntimeError("DDP has not supported > 1D parallelism")
+        dp_mesh = world_mesh
+        apply_ddp(
+            model,
+            dp_mesh,
+            enable_compile=job_config.training.compile,
+            enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd,
+        )
+
     return model
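For reference, a minimal standalone sketch (hypothetical helper, not part of the torchtitan code above) of the mesh-dim selection this change introduces: dense parameters are sharded by FSDP/HSDP over "dp_shard_cp" (plus "dp_replicate" under HSDP), while MoE expert parameters, when expert parallelism is enabled, are instead sharded over "dp_shard_mod_ep", which is the mesh passed to apply_fsdp as dp_mod_ep_mesh.

def pick_fsdp_mesh_dims(dp_replicate_enabled: bool, ep_enabled: bool):
    """Return (dense_dims, moe_dims): mesh dim names for dense vs. MoE params."""
    dense_dims = (
        ("dp_replicate", "dp_shard_cp") if dp_replicate_enabled else ("dp_shard_cp",)
    )
    moe_dims = None
    if ep_enabled:
        # MoE params are sharded over dp_shard_mod_ep instead of dp_shard_cp
        moe_dims = (
            ("dp_replicate", "dp_shard_mod_ep")
            if dp_replicate_enabled
            else ("dp_shard_mod_ep",)
        )
    return dense_dims, moe_dims


if __name__ == "__main__":
    # HSDP + EP: (('dp_replicate', 'dp_shard_cp'), ('dp_replicate', 'dp_shard_mod_ep'))
    print(pick_fsdp_mesh_dims(dp_replicate_enabled=True, ep_enabled=True))
    # FSDP only, no EP: (('dp_shard_cp',), None)
    print(pick_fsdp_mesh_dims(dp_replicate_enabled=False, ep_enabled=False))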