Commit c7601a7

refactor ParallelDims and CheckpointManager
1 parent 3ca7041 commit c7601a7

20 files changed: +99 -104 lines changed

scripts/estimate/estimation.py

Lines changed: 6 additions & 9 deletions
@@ -48,7 +48,6 @@ def estimate_memory(job_config: JobConfig):
         pp=parallelism_config.pipeline_parallel_degree,
         ep=parallelism_config.expert_parallel_degree,
         world_size=world_size,
-        enable_loss_parallel=not parallelism_config.disable_loss_parallel,
     )

     # only FSDP and HSDP are supported
@@ -76,14 +75,14 @@ def estimate_memory(job_config: JobConfig):

     train_spec = get_train_spec(job_config.model.name)

-    # build meshes
-    world_mesh = parallel_dims.build_mesh(device_type="cuda")
-
     # build tokenizer
     tokenizer = train_spec.build_tokenizer_fn(job_config)

+    loss_parallel_enabled = (
+        parallel_dims.tp_enabled and not parallelism_config.disable_loss_parallel
+    )
     train_context = dist_utils.get_train_context(
-        parallel_dims.loss_parallel_enabled,
+        loss_parallel_enabled,
         job_config.parallelism.enable_compiled_autograd,
     )

@@ -108,7 +107,7 @@ def estimate_memory(job_config: JobConfig):
     model_converters.convert(model)

     # apply PT-D DP/TP parallelisms and activation checkpointing
-    train_spec.parallelize_fn(model, world_mesh, parallel_dims, job_config)
+    train_spec.parallelize_fn(model, parallel_dims, job_config)

     model.to_empty(device="cuda")
     if not active_fake_mode():
@@ -117,9 +116,7 @@ def estimate_memory(job_config: JobConfig):

     # build optimizer after applying parallelisms to the model
     ft_manager = init_ft_manager(job_config)
-    optimizers = build_optimizers(
-        [model], job_config, parallel_dims, world_mesh, ft_manager
-    )
+    optimizers = build_optimizers([model], job_config, parallel_dims, ft_manager)
     lr_schedulers = build_lr_schedulers(optimizers.optimizers, job_config)
     # Post optimizer step model converters hook.
     # e.g. calculate float8 dynamic amax/scale for all-parameter for FSDP2
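
Taken together, the estimation path no longer builds a DeviceMesh up front. Below is a condensed, non-authoritative sketch of the updated flow, stitched from the hunks above; the data_parallel_*_degree, context_parallel_degree, and tensor_parallel_degree config names are assumed from the surrounding torchtitan config and do not appear in this diff.

# Condensed sketch of the new estimation flow; model setup and fake-mode
# handling from the real file are omitted.
parallel_dims = ParallelDims(
    dp_shard=parallelism_config.data_parallel_shard_degree,      # assumed name
    dp_replicate=parallelism_config.data_parallel_replicate_degree,  # assumed name
    cp=parallelism_config.context_parallel_degree,               # assumed name
    tp=parallelism_config.tensor_parallel_degree,                # assumed name
    pp=parallelism_config.pipeline_parallel_degree,
    ep=parallelism_config.expert_parallel_degree,
    world_size=world_size,
)  # no enable_loss_parallel field and no explicit build_mesh() call anymore

# loss-parallel enablement is now decided at the call site
loss_parallel_enabled = (
    parallel_dims.tp_enabled and not parallelism_config.disable_loss_parallel
)
train_context = dist_utils.get_train_context(
    loss_parallel_enabled,
    job_config.parallelism.enable_compiled_autograd,
)

# downstream APIs take ParallelDims and pull the mesh lazily from
# parallel_dims.world_mesh when they need it
train_spec.parallelize_fn(model, parallel_dims, job_config)
optimizers = build_optimizers([model], job_config, parallel_dims, ft_manager)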

scripts/generate/test_generate.py

Lines changed: 2 additions & 4 deletions
@@ -127,14 +127,12 @@ def test_generate(
         pp=1,
         ep=1,
         world_size=world_size,
-        enable_loss_parallel=False,
     )
-    # Build world mesh for parallelism
-    world_mesh = parallel_dims.build_mesh(device_type=device_type)
+    world_mesh = parallel_dims.world_mesh

     # apply_tp (with Sequence Parallel) on unevenly sharded
     # sequences would require https://github.com/pytorch/torchtitan/pull/686
-    apply_tp_minus_sp(model, world_mesh["tp"])
+    apply_tp_minus_sp(model, parallel_dims.world_mesh["tp"])

     dist_utils.set_determinism(world_mesh, device, seed, deterministic)

tests/unit_tests/test_model_converter.py

Lines changed: 0 additions & 1 deletion
@@ -23,7 +23,6 @@ def build_parallel_dims(job_config, world_size):
         pp=parallelism_config.pipeline_parallel_degree,
         ep=parallelism_config.expert_parallel_degree,
         world_size=world_size,
-        enable_loss_parallel=not parallelism_config.disable_loss_parallel,
     )
     return parallel_dims

torchtitan/components/checkpoint.py

Lines changed: 6 additions & 4 deletions
@@ -26,8 +26,8 @@
 )
 from torch.distributed.checkpoint.state_dict_saver import AsyncCheckpointerType
 from torch.distributed.checkpoint.stateful import Stateful
-from torch.utils.data import DataLoader

+from torchtitan.components.dataloader import BaseDataLoader
 from torchtitan.components.ft import FTManager
 from torchtitan.components.lr_scheduler import LRSchedulersContainer
 from torchtitan.components.optimizer import OptimizersContainer
@@ -180,17 +180,19 @@ class CheckpointManager:

     def __init__(
         self,
-        dataloader: DataLoader,
+        dataloader: BaseDataLoader | None,
         model_parts: list[nn.Module],
         optimizers: OptimizersContainer,
         lr_schedulers: LRSchedulersContainer,
         states: dict[str, Any],
         job_config: JobConfig,
-        ft_manager: FTManager,
+        ft_manager: FTManager | None = None,
     ) -> None:
         ckpt_config = job_config.checkpoint
         self.enable_checkpoint = ckpt_config.enable_checkpoint
-        self.ft_manager = ft_manager.manager if ft_manager.enabled else None
+        self.ft_manager = (
+            ft_manager.manager if ft_manager and ft_manager.enabled else None
+        )

         if self.ft_manager:
             optimizers.init_cache_state_dict()
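
With dataloader typed as BaseDataLoader | None and ft_manager defaulting to None, the manager can now be constructed in settings without fault tolerance or a resumable dataloader. A minimal hypothetical sketch; the surrounding objects are placeholders assumed to come from an existing training setup, not from this diff.

# Hypothetical usage sketch: model_parts, optimizers, lr_schedulers, states,
# and job_config are assumed to be built elsewhere (e.g., by a Trainer).
checkpointer = CheckpointManager(
    dataloader=None,          # e.g., an eval/generation script with no dataloader state
    model_parts=model_parts,
    optimizers=optimizers,
    lr_schedulers=lr_schedulers,
    states=states,
    job_config=job_config,
    # ft_manager omitted: with the new default, self.ft_manager resolves to None
)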

torchtitan/components/optimizer.py

Lines changed: 0 additions & 2 deletions
@@ -15,7 +15,6 @@
     StateDictOptions,
 )
 from torch.distributed.checkpoint.stateful import Stateful
-from torch.distributed.device_mesh import DeviceMesh
 from torch.optim import Optimizer

 from torchtitan.components.ft import FTManager, has_torchft
@@ -244,7 +243,6 @@ def build_optimizers(
     model_parts: list[nn.Module],
     job_config: JobConfig,
     parallel_dims: ParallelDims,
-    world_mesh: DeviceMesh,
     ft_manager: FTManager,
 ) -> OptimizersContainer:
     """Create a OptimizersContainer for the given model parts and job config.

torchtitan/components/validate.py

Lines changed: 8 additions & 8 deletions
@@ -50,14 +50,12 @@ def __init__(
         dp_rank: int,
         tokenizer: BaseTokenizer,
         parallel_dims: ParallelDims,
-        world_mesh: torch.distributed.DeviceMesh,
         loss_fn: LossFunction,
         validation_context: Generator[None, None, None],
         maybe_enable_amp: Generator[None, None, None],
     ):
         self.job_config = job_config
         self.parallel_dims = parallel_dims
-        self.world_mesh = world_mesh
         self.loss_fn = loss_fn
         self.validation_dataloader = build_hf_validation_dataloader(
             job_config=job_config,
@@ -78,6 +76,8 @@ def validate(
         model = model_parts[0]
         model.eval()

+        parallel_dims = self.parallel_dims
+
         accumulated_losses = []
         device_type = utils.device_type
         num_steps = 0
@@ -96,13 +96,13 @@ def validate(

             optional_context_parallel_ctx = (
                 dist_utils.create_context_parallel_ctx(
-                    cp_mesh=self.world_mesh["cp"],
+                    cp_mesh=parallel_dims.world_mesh["cp"],
                     cp_buffers=[inputs, labels] + [m.freqs_cis for m in model_parts],
                     cp_seq_dims=[1, 1] + [0 for _ in model_parts],
                     cp_no_restore_buffers={inputs, labels},
                     cp_rotate_method=self.job_config.parallelism.context_parallel_rotate_method,
                 )
-                if self.parallel_dims.cp_enabled
+                if parallel_dims.cp_enabled
                 else None
             )

@@ -119,8 +119,10 @@ def validate(
         # Compute average loss
         loss = torch.sum(torch.stack(accumulated_losses))
         loss /= num_steps
-        if self.parallel_dims.dp_cp_enabled:
-            global_avg_loss = dist_utils.dist_mean(loss, self.world_mesh["dp_cp"])
+        if parallel_dims.dp_cp_enabled:
+            global_avg_loss = dist_utils.dist_mean(
+                loss, parallel_dims.world_mesh["dp_cp"]
+            )
         else:
             global_avg_loss = loss

@@ -144,7 +146,6 @@ def build_validator(
     dp_rank: int,
     tokenizer: BaseTokenizer,
    parallel_dims: ParallelDims,
-    world_mesh: torch.distributed.DeviceMesh,
     loss_fn: LossFunction,
     validation_context: Generator[None, None, None],
     maybe_enable_amp: Generator[None, None, None],
@@ -156,7 +157,6 @@ def build_validator(
         dp_rank=dp_rank,
         tokenizer=tokenizer,
         parallel_dims=parallel_dims,
-        world_mesh=world_mesh,
         loss_fn=loss_fn,
         validation_context=validation_context,
         maybe_enable_amp=maybe_enable_amp,

torchtitan/distributed/parallel_dims.py

Lines changed: 29 additions & 13 deletions
@@ -10,6 +10,7 @@
 from torch.distributed.device_mesh import DeviceMesh, init_device_mesh

 from torchtitan.tools.logging import logger
+from torchtitan.tools.utils import device_type


 __all__ = ["ParallelDims"]
@@ -24,7 +25,8 @@ class ParallelDims:
     pp: int
     ep: int
     world_size: int
-    enable_loss_parallel: bool
+
+    _world_mesh: DeviceMesh = None

     def __post_init__(self):
         self._validate()
@@ -55,16 +57,16 @@ def _validate(self):
         # EP would borrow all cp and some dp_shard degree
         assert ep % cp == 0 and (dp_shard * cp) % ep == 0

-    def build_mesh(self, device_type: str) -> DeviceMesh:
+    def build_mesh(self) -> DeviceMesh:
         # TODO: Current implementation of ParallelDims for dp2ep Expert Parallel
         # is not very clean, due to the limited support from DeviceMesh
         # for creating two staggered meshes. Will improve.
         if self.ep > 1:
-            return self._build_mesh_with_ep(device_type)
+            return self._build_mesh_with_ep()
         else:
-            return self._build_mesh_without_ep(device_type)
+            return self._build_mesh_without_ep()

-    def _build_mesh_with_ep(self, device_type: str) -> DeviceMesh:
+    def _build_mesh_with_ep(self) -> DeviceMesh:
         # With ep, dp_shard and ep are derived submeshes:
         # dp_shard = dp_shard_mod_ep * dp_shard_in_ep
         # ep = dp_shard_in_ep * cp
@@ -128,7 +130,7 @@ def _build_mesh_with_ep(self, device_type: str) -> DeviceMesh:

         return mesh

-    def _build_mesh_without_ep(self, device_type: str) -> DeviceMesh:
+    def _build_mesh_without_ep(self) -> DeviceMesh:
         dims = []
         names = []
         for d, name in zip(
@@ -173,6 +175,14 @@ def _build_mesh_without_ep(self, device_type: str) -> DeviceMesh:

         return mesh

+    @property
+    def world_mesh(self) -> DeviceMesh:
+        # doing late init so ParallelDims can still be used as a lightweight
+        # dataclass without having to initialize the world mesh
+        if self._world_mesh is None:
+            self._world_mesh = self.build_mesh()
+        return self._world_mesh
+
     @property
     def dp_enabled(self):
         return self.dp_replicate > 1 or self.dp_shard > 1
@@ -206,18 +216,24 @@ def pp_enabled(self):
         return self.pp > 1

     @property
-    def loss_parallel_enabled(self):
-        return self.tp > 1 and self.enable_loss_parallel
+    def ep_enabled(self):
+        return self.ep > 1

     @cached_property
     def non_data_parallel_size(self):
         return self.cp * self.tp * self.pp

-    @property
-    def ep_enabled(self):
-        return self.ep > 1
+    @cached_property
+    def seq_len_divisor(self):
+        # Sequence Parallel requires that seq_len be divisible by TP degree.
+        # https://github.com/pytorch/torchtitan/pull/640#discussion_r1849481001

-    @property
+        # Context Parallel requires that seq_len be divisible by 2 * CP degree,
+        # when load balancing is enabled (by default).
+        # https://github.com/pytorch/pytorch/blob/4f62dcc/torch/distributed/tensor/experimental/_attention.py#L1246
+        return self.tp * (self.cp * 2)
+
+    @cached_property
     def dense_params_mesh_ndim(self):
-        # Note: EP params mesh ndim is 1 more due to the 'ep' mesh
+        # Note: In dp2ep EP, EP params mesh ndim is 1 more due to the 'ep' mesh
         return self.dp_replicate_enabled + self.fsdp_enabled + self.tp_enabled
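
The net effect for callers: ParallelDims is constructed as a plain dataclass, and the device mesh is built lazily on the first access of world_mesh and cached afterwards. A minimal sketch with hypothetical degrees, assuming torch.distributed is already initialized across 8 ranks; the 8192 sequence length is only an example.

from torchtitan.distributed.parallel_dims import ParallelDims

# hypothetical 8-GPU layout: FSDP over 4 ranks, TP over 2
parallel_dims = ParallelDims(
    dp_replicate=1, dp_shard=4, cp=1, tp=2, pp=1, ep=1, world_size=8,
)

# first access builds the mesh (the device type now comes from
# torchtitan.tools.utils instead of a build_mesh(device_type=...) argument)
# and caches it on the instance
world_mesh = parallel_dims.world_mesh
assert parallel_dims.world_mesh is world_mesh   # reused, not rebuilt
tp_mesh = world_mesh["tp"]                      # sub-meshes are sliced as before

# seq_len_divisor folds the Sequence Parallel (TP) and Context Parallel
# (2 * CP, load-balanced) divisibility requirements into one number: tp * cp * 2
assert 8192 % parallel_dims.seq_len_divisor == 0    # e.g., seq_len = 8192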

torchtitan/experiments/deepseek_v3/train_ds_real.py

Lines changed: 2 additions & 2 deletions
@@ -155,8 +155,8 @@ def run_full_model(
         pp=pp_size,
         cp=1,
         tp=1,
+        ep=1,
         world_size=world_mesh.size(),
-        enable_loss_parallel=False,
     )

     metrics_processor = build_metrics_processor(
@@ -180,7 +180,7 @@ def run_full_model(
     loss_fn = cross_entropy_loss  # torch.nn.functional.cross_entropy

     ft_manager = ft.init_ft_manager(config)
-    optimizer = build_optimizers([model], config, ft_manager)
+    optimizer = build_optimizers([model], config, proxy_parallel_dims, ft_manager)

     lr_scheduler = build_lr_schedulers(optimizer, config)

torchtitan/experiments/flux/infra/parallelize.py

Lines changed: 2 additions & 4 deletions
@@ -21,7 +21,6 @@

 def parallelize_flux(
     model: nn.Module,
-    world_mesh: DeviceMesh,
     parallel_dims: ParallelDims,
     job_config: JobConfig,
 ):
@@ -36,7 +35,7 @@ def parallelize_flux(

     apply_fsdp(
         model,
-        world_mesh[tuple(dp_mesh_dim_names)],
+        parallel_dims.world_mesh[tuple(dp_mesh_dim_names)],
         param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
         reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
         cpu_offload=job_config.training.enable_cpu_offload,
@@ -117,7 +116,6 @@ def apply_ac(model: nn.Module, ac_config):
 def parallelize_encoders(
     t5_model: nn.Module,
     clip_model: nn.Module,
-    world_mesh: DeviceMesh,
     parallel_dims: ParallelDims,
     job_config: JobConfig,
 ):
@@ -132,7 +130,7 @@ def parallelize_encoders(
         reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
     )
     fsdp_config = {
-        "mesh": world_mesh[tuple(dp_mesh_dim_names)],
+        "mesh": parallel_dims.world_mesh[tuple(dp_mesh_dim_names)],
         "mp_policy": mp_policy,
     }
     if job_config.training.enable_cpu_offload:

torchtitan/experiments/flux/train.py

Lines changed: 1 addition & 2 deletions
@@ -36,7 +36,7 @@ def __init__(self, job_config: JobConfig):
         # (mainly for debugging, expect perf loss).
         # For Flux model, we need distinct seed across FSDP ranks to ensure we randomly dropout prompts info in dataloader
         dist_utils.set_determinism(
-            self.world_mesh,
+            self.parallel_dims.world_mesh,
             self.device,
             job_config.training.seed,
             job_config.training.deterministic,
@@ -77,7 +77,6 @@ def __init__(self, job_config: JobConfig):
         self.t5_encoder, self.clip_encoder = parallelize_encoders(
             t5_model=self.t5_encoder,
             clip_model=self.clip_encoder,
-            world_mesh=self.world_mesh,
             parallel_dims=self.parallel_dims,
             job_config=job_config,
         )
