
Commit 08c1ff1

dp2ep Expert Parallel
1 parent 6b11290 commit 08c1ff1

File tree

19 files changed, +842 -304 lines changed


docs/checkpoint.md

Lines changed: 1 addition & 1 deletion
@@ -83,5 +83,5 @@ A seed checkpoint does initialization of the model on a single CPU, and can be l
 To create a seed checkpoint, use the same model config as you use for training.
 e.g.
 ```bash
-NGPU=1 CONFIG=<path_to_model_config> ./run_train.sh --checkpoint.enable_checkpoint --checkpoint.create_seed_checkpoint --parallelism.data_parallel_replicate_degree 1 --parallelism.data_parallel_shard_degree 1 --parallelism.tensor_parallel_degree 1 --parallelism.pipeline_parallel_degree 1 --parallelism.context_parallel_degree 1
+NGPU=1 CONFIG=<path_to_model_config> ./run_train.sh --checkpoint.enable_checkpoint --checkpoint.create_seed_checkpoint --parallelism.data_parallel_replicate_degree 1 --parallelism.data_parallel_shard_degree 1 --parallelism.tensor_parallel_degree 1 --parallelism.pipeline_parallel_degree 1 --parallelism.context_parallel_degree 1 --parallelism.expert_parallel_degree 1
 ```

scripts/estimate/estimation.py

Lines changed: 6 additions & 2 deletions
@@ -46,6 +46,7 @@ def estimate_memory(job_config: JobConfig):
         cp=parallelism_config.context_parallel_degree,
         tp=parallelism_config.tensor_parallel_degree,
         pp=parallelism_config.pipeline_parallel_degree,
+        ep=parallelism_config.expert_parallel_degree,
         world_size=world_size,
         enable_loss_parallel=not parallelism_config.disable_loss_parallel,
     )
@@ -56,8 +57,9 @@ def estimate_memory(job_config: JobConfig):
         or parallel_dims.tp_enabled
         or parallel_dims.pp_enabled
         or parallel_dims.cp_enabled
+        or parallel_dims.ep_enabled
     ):
-        logger.warning("DDP, TP, PP, CP are not supported yet.")
+        logger.warning("DDP, TP, PP, CP, EP are not supported yet.")
         return
     if not parallel_dims.dp_shard_enabled:
         logger.warning("FSDP or HSDP is not enabled. Skipping memory estimation.")
@@ -115,7 +117,9 @@ def estimate_memory(job_config: JobConfig):

     # build optimizer after applying parallelisms to the model
     ft_manager = init_ft_manager(job_config)
-    optimizers = build_optimizers([model], job_config, ft_manager)
+    optimizers = build_optimizers(
+        [model], job_config, parallel_dims, world_mesh, ft_manager
+    )
     lr_schedulers = build_lr_schedulers(optimizers.optimizers, job_config)
     # Post optimizer step model converters hook.
     # e.g. calculate float8 dynamic amax/scale for all-parameter for FSDP2
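
For reference, a minimal sketch of how `ParallelDims` is constructed with the new `ep` field (keyword names follow the hunks above; the concrete degree values are illustrative assumptions, not taken from this commit):

```python
# Sketch only: mirrors the ParallelDims construction shown in the hunks above.
# The degree values are illustrative assumptions, not taken from this commit.
from torchtitan.distributed import ParallelDims

parallel_dims = ParallelDims(
    dp_replicate=1,
    dp_shard=8,
    cp=1,
    tp=1,
    pp=1,
    ep=8,  # new in this commit: dp2ep Expert Parallel degree
    world_size=8,
    enable_loss_parallel=True,
)

# estimate_memory() now returns early when this is True, since EP is not
# supported by the memory estimator yet.
print(parallel_dims.ep_enabled)
```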

scripts/generate/test_generate.py

Lines changed: 1 addition & 0 deletions
@@ -125,6 +125,7 @@ def test_generate(
         cp=1,
         tp=world_size,
         pp=1,
+        ep=1,
         world_size=world_size,
         enable_loss_parallel=False,
     )

tests/unit_tests/test_model_converter.py

Lines changed: 1 addition & 0 deletions
@@ -21,6 +21,7 @@ def build_parallel_dims(job_config, world_size):
         cp=parallelism_config.context_parallel_degree,
         tp=parallelism_config.tensor_parallel_degree,
         pp=parallelism_config.pipeline_parallel_degree,
+        ep=parallelism_config.expert_parallel_degree,
         world_size=world_size,
         enable_loss_parallel=not parallelism_config.disable_loss_parallel,
     )

torchtitan/components/ft.py

Lines changed: 0 additions & 37 deletions
@@ -7,7 +7,6 @@
 import copy
 import importlib
 from contextlib import nullcontext
-from dataclasses import dataclass
 from typing import ContextManager, Optional, TYPE_CHECKING, Union

 import torch
@@ -18,7 +17,6 @@
 from torch.distributed.distributed_c10d import ReduceOp
 from torch.distributed.tensor import DTensor
 from torchtitan.config_manager import JobConfig
-from torchtitan.distributed import ParallelDims

 if importlib.util.find_spec("torchft") is not None:
     import torchft as ft
@@ -106,41 +104,6 @@ def init_ft_manager(job: JobConfig) -> FTManager:
     )


-@dataclass
-class FTParallelDims(ParallelDims):
-    ft_manager: FTManager
-
-    def build_mesh(self, device_type: str) -> DeviceMesh:
-        def func(
-            device_type: str, mesh_shape: list[int], mesh_dim_names: list[str]
-        ) -> DeviceMesh:
-            from torchft.process_group import ft_init_device_mesh
-
-            return ft_init_device_mesh(
-                device_type=device_type,
-                mesh_shape=mesh_shape,
-                mesh_dim_names=mesh_dim_names,
-                replicate_dim=mesh_dim_names.index("dp_replicate"),
-                manager=self.ft_manager.manager,
-            )
-
-        dims = []
-        names = []
-        for d, name in zip(
-            [self.pp, self.dp_replicate, self.dp_shard, self.cp, self.tp],
-            ["pp", "dp_replicate", "dp_shard", "cp", "tp"],
-        ):
-            if d > 1 or name == "dp_replicate":
-                dims.append(d)
-                names.append(name)
-
-        return self._build_mesh(device_type, dims, names, func)
-
-    @property
-    def dp_replicate_enabled(self):
-        return True
-
-
 def ft_dist_reduce(
     x: torch.Tensor, reduceOp: str, mesh: DeviceMesh
 ) -> tuple[torch.Tensor, str, DeviceMesh]:

torchtitan/components/optimizer.py

Lines changed: 113 additions & 10 deletions
@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.

 import functools
+from itertools import chain
 from typing import Any, Generic, Iterator, TypeVar

 import torch
@@ -15,10 +16,13 @@
     StateDictOptions,
 )
 from torch.distributed.checkpoint.stateful import Stateful
+from torch.distributed.device_mesh import DeviceMesh
+from torch.distributed.tensor import DTensor
 from torch.optim import Optimizer

 from torchtitan.components.ft import FTManager, has_torchft
 from torchtitan.config_manager import JobConfig
+from torchtitan.distributed import ParallelDims

 __all__ = [
     "OptimizersContainer",
@@ -238,9 +242,85 @@ def zero_grad(self, *args, **kwargs) -> None:
         super().zero_grad(*args, **kwargs)


+class ExpertParallelOptimizersContainer(OptimizersContainer):
+    """
+    This class is created to support fused optimizer implementation for Expert Parallel.
+    Since in EP, not all the parameters are sharded on the same DeviceMesh, the base
+    OptimizersContainer cannot perform fused optimizer steps on all DTensor parameters.
+    In this class, we create two optimizers for each model part, one for ep params and the
+    other for non-ep params. Parameters in the same optimizer are always on the same DeviceMesh,
+    so that fused optimizer can be performed.
+    """
+
+    def __init__(
+        self,
+        model_parts: list[nn.Module],
+        optimizer_cls: type[T],
+        optimizer_kwargs: dict[str, Any],
+        dense_params_mesh_ndim: int,
+    ) -> None:
+        ep_params, non_ep_params = [], []
+        self.ep_optimizers = []
+        self.non_ep_optimizers = []
+
+        self.model_parts = model_parts
+        # This is still needed to
+        # 1. reuse other OptimizersContainer's methods other than state dict save / load
+        # 2. define LR schedulers
+        self.optimizers = []
+
+        for model in self.model_parts:
+            for p in model.parameters():
+                if not p.requires_grad:
+                    continue
+                assert isinstance(p, DTensor)
+                if p.device_mesh.ndim == dense_params_mesh_ndim:
+                    non_ep_params.append(p)
+                else:
+                    ep_params.append(p)
+
+            ep_optimizer = optimizer_cls(ep_params, **optimizer_kwargs)
+            non_ep_optimizers = optimizer_cls(non_ep_params, **optimizer_kwargs)
+            self.ep_optimizers.append(ep_optimizer)
+            self.non_ep_optimizers.append(non_ep_optimizers)
+            self.optimizers.append(ep_optimizer)
+            self.optimizers.append(non_ep_optimizers)
+
+        # NOTE: each model part has two optimizers, one for ep params
+        # and the other for non-ep params
+        self._validate_length(len(self.model_parts) * 2)
+        self._post_init(ep_params, optimizer_kwargs)
+        self._post_init(non_ep_params, optimizer_kwargs)
+
+    def state_dict(self) -> dict[str, Any]:
+        func = functools.partial(
+            get_optimizer_state_dict,
+            options=StateDictOptions(flatten_optimizer_state_dict=True),
+        )
+        return {
+            k: v
+            for sd in chain(
+                map(func, self.model_parts, self.ep_optimizers),
+                map(func, self.model_parts, self.non_ep_optimizers),
+            )
+            for k, v in sd.items()
+        }
+
+    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
+        func = functools.partial(
+            set_optimizer_state_dict,
+            optim_state_dict=state_dict,
+            options=StateDictOptions(flatten_optimizer_state_dict=True),
+        )
+        list(map(func, self.model_parts, self.ep_optimizers))
+        list(map(func, self.model_parts, self.non_ep_optimizers))
+
+
 def build_optimizers(
     model_parts: list[nn.Module],
     job_config: JobConfig,
+    parallel_dims: ParallelDims,
+    world_mesh: DeviceMesh,
     ft_manager: FTManager,
 ) -> OptimizersContainer:
     """Create a OptimizersContainer for the given model parts and job config.
@@ -259,12 +339,23 @@ def build_optimizers(
     Args:
         model_parts (List[nn.Module]): List of model parts to be optimized.
         job_config (JobConfig): Job config containing the optimizer name and parameters.
+        parallel_dims (ParallelDims): Parallel dimensions for the model.
     """
     optim_in_bwd = job_config.optimizer.early_step_in_backward
-    if optim_in_bwd and job_config.parallelism.pipeline_parallel_degree > 1:
-        raise NotImplementedError(
-            "Optimizers in backward is not supported with pipeline parallelism."
-        )
+    if optim_in_bwd:
+        if parallel_dims.ep_enabled:
+            raise NotImplementedError(
+                "Optimizers in backward is not supported with Expert Parallel."
+            )
+        if parallel_dims.pp_enabled:
+            raise NotImplementedError(
+                "Optimizers in backward is not supported with Pipeline Parallel."
+            )
+        if ft_manager.enabled:
+            raise NotImplementedError(
+                "TorchFT is not supported with optimizers in backward."
+            )
+
     name = job_config.optimizer.name
     lr = job_config.optimizer.lr
     beta1 = job_config.optimizer.beta1
@@ -295,19 +386,31 @@ def build_optimizers(
         raise NotImplementedError(f"Optimizer {name} not added.")
     optimizer_cls = optimizer_classes[name]

-    if optim_in_bwd and ft_manager.enabled:
-        raise ValueError("TorchFT is not supported with optimizers in backward.")
-    elif optim_in_bwd:
+    if optim_in_bwd:
         return OptimizersInBackwardContainer(
             model_parts, optimizer_cls, optimizer_kwargs
         )
-    elif ft_manager.enabled:
+
+    if ft_manager.enabled:
         return FTOptimizersContainer(
             model_parts,
             optimizer_cls,
             optimizer_kwargs,
             ft_manager.manager,
             use_ft_optimizer=job_config.fault_tolerance.semi_sync_method is None,
         )
-    else:
-        return OptimizersContainer(model_parts, optimizer_cls, optimizer_kwargs)
+
+    if parallel_dims.ep_enabled and fused:
+        if ft_manager.enabled:
+            raise NotImplementedError(
+                "Expert Parallel with fused optimizer implementation "
+                "is not supported with TorchFT yet."
+            )
+        return ExpertParallelOptimizersContainer(
+            model_parts,
+            optimizer_cls,
+            optimizer_kwargs,
+            parallel_dims.dense_params_mesh_ndim,
+        )
+
+    return OptimizersContainer(model_parts, optimizer_cls, optimizer_kwargs)
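
The dispatch order in `build_optimizers` is now: optimizers-in-backward, then TorchFT, then the EP fused path, and finally the default container. A hedged sketch of the updated call site, matching the change in `scripts/estimate/estimation.py` above (`model`, `job_config`, `parallel_dims`, `world_mesh`, and `ft_manager` are assumed to be built earlier, as in that script):

```python
# Sketch only: the new build_optimizers signature as used in the
# estimation.py hunk above; all arguments are assumed to exist already.
optimizers = build_optimizers(
    model_parts=[model],
    job_config=job_config,
    parallel_dims=parallel_dims,
    world_mesh=world_mesh,
    ft_manager=ft_manager,
)

# When expert_parallel_degree > 1 and the optimizer is fused, this returns an
# ExpertParallelOptimizersContainer, which keeps EP and non-EP DTensor
# parameters in separate optimizers so each fused step only sees parameters
# sharded on a single DeviceMesh.
```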

torchtitan/config_manager.py

Lines changed: 8 additions & 0 deletions
@@ -363,6 +363,14 @@ class Parallelism:
     The default value is 'allgather'.
     """

+    expert_parallel_degree: int = 1
+    """
+    Expert parallelism degree. 1 means disabled.
+    Currently, only "dp2ep" is supported, with the following constraints:
+    context_parallel_degree <= expert_parallel_degree <= data_parallel_shard_degree * context_parallel_degree
+    Note that this is still an experimental feature.
+    """
+

 @dataclass
 class Checkpoint:
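
The constraint in the docstring is a simple inequality on the parallelism degrees. The helper below is an illustrative sketch, not code from this commit, and assumes `data_parallel_shard_degree` has already been resolved to a concrete positive value (the config allows -1 to mean "infer"):

```python
# Illustrative helper (not part of this commit): validates the dp2ep
# constraint documented in the Parallelism config above.
def check_dp2ep_constraint(parallelism) -> None:
    cp = parallelism.context_parallel_degree
    ep = parallelism.expert_parallel_degree
    dp_shard = parallelism.data_parallel_shard_degree  # assumed already resolved (> 0)
    if ep == 1:
        return  # EP disabled, nothing to check
    if not (cp <= ep <= dp_shard * cp):
        raise ValueError(
            f"expert_parallel_degree={ep} must satisfy "
            f"context_parallel_degree ({cp}) <= ep <= "
            f"data_parallel_shard_degree * context_parallel_degree ({dp_shard * cp})"
        )
```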
