# LICENSE file in the root directory of this source tree.

import functools
+ from itertools import chain
from typing import Any, Generic, Iterator, TypeVar

import torch
    StateDictOptions,
)
from torch.distributed.checkpoint.stateful import Stateful
+ from torch.distributed.device_mesh import DeviceMesh
+ from torch.distributed.tensor import DTensor
from torch.optim import Optimizer

from torchtitan.components.ft import FTManager, has_torchft
from torchtitan.config_manager import JobConfig
+ from torchtitan.distributed import ParallelDims

__all__ = [
    "OptimizersContainer",
@@ -238,9 +242,85 @@ def zero_grad(self, *args, **kwargs) -> None:
        super().zero_grad(*args, **kwargs)


+ class ExpertParallelOptimizersContainer(OptimizersContainer):
+     """
+     This class supports a fused optimizer implementation for Expert Parallel (EP).
+     With EP, not all parameters are sharded on the same DeviceMesh, so the base
+     OptimizersContainer cannot perform a fused optimizer step over all DTensor
+     parameters at once. This class therefore creates two optimizers per model part,
+     one for EP params and one for non-EP params. All parameters in a given
+     optimizer live on the same DeviceMesh, so the fused step can be applied.
+     """
+
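+     # For example, with dense params sharded over a 2-D dense mesh
+     # (dense_params_mesh_ndim=2) and expert weights placed on a mesh with a
+     # different ndim, expert weights end up in ep_optimizers and everything
+     # else in non_ep_optimizers, each stepped by its own optimizer instance.
+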
+     def __init__(
+         self,
+         model_parts: list[nn.Module],
+         optimizer_cls: type[T],
+         optimizer_kwargs: dict[str, Any],
+         dense_params_mesh_ndim: int,
+     ) -> None:
+         ep_params, non_ep_params = [], []
+         self.ep_optimizers = []
+         self.non_ep_optimizers = []
+
+         self.model_parts = model_parts
+         # This is still needed to
+         # 1. reuse other OptimizersContainer's methods other than state dict save / load
+         # 2. define LR schedulers
+         self.optimizers = []
+
+         for model in self.model_parts:
+             for p in model.parameters():
+                 if not p.requires_grad:
+                     continue
+                 assert isinstance(p, DTensor)
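+                 # Dense params are sharded over the full dense mesh
+                 # (ndim == dense_params_mesh_ndim); EP params live on a mesh
+                 # with a different ndim and must go into a separate optimizer.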
+                 if p.device_mesh.ndim == dense_params_mesh_ndim:
+                     non_ep_params.append(p)
+                 else:
+                     ep_params.append(p)
+
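+             # Two optimizers per model part: each one only holds params that
+             # share a DeviceMesh, which is what allows a fused step on DTensors.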
+             ep_optimizer = optimizer_cls(ep_params, **optimizer_kwargs)
+             non_ep_optimizers = optimizer_cls(non_ep_params, **optimizer_kwargs)
+             self.ep_optimizers.append(ep_optimizer)
+             self.non_ep_optimizers.append(non_ep_optimizers)
+             self.optimizers.append(ep_optimizer)
+             self.optimizers.append(non_ep_optimizers)
+
+         # NOTE: each model part has two optimizers, one for ep params
+         # and the other for non-ep params
+         self._validate_length(len(self.model_parts) * 2)
+         self._post_init(ep_params, optimizer_kwargs)
+         self._post_init(non_ep_params, optimizer_kwargs)
+
+     def state_dict(self) -> dict[str, Any]:
+         func = functools.partial(
+             get_optimizer_state_dict,
+             options=StateDictOptions(flatten_optimizer_state_dict=True),
+         )
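+         # Merge the flattened optimizer state dicts from the EP and non-EP
+         # optimizers of every model part into a single dict.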
+         return {
+             k: v
+             for sd in chain(
+                 map(func, self.model_parts, self.ep_optimizers),
+                 map(func, self.model_parts, self.non_ep_optimizers),
+             )
+             for k, v in sd.items()
+         }
+
+     def load_state_dict(self, state_dict: dict[str, Any]) -> None:
+         func = functools.partial(
+             set_optimizer_state_dict,
+             optim_state_dict=state_dict,
+             options=StateDictOptions(flatten_optimizer_state_dict=True),
+         )
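+         # map() is lazy; wrapping it in list() forces set_optimizer_state_dict
+         # to run for each (model_part, optimizer) pair.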
+         list(map(func, self.model_parts, self.ep_optimizers))
+         list(map(func, self.model_parts, self.non_ep_optimizers))
+
+
def build_optimizers(
    model_parts: list[nn.Module],
    job_config: JobConfig,
+     parallel_dims: ParallelDims,
+     world_mesh: DeviceMesh,
    ft_manager: FTManager,
) -> OptimizersContainer:
    """Create an OptimizersContainer for the given model parts and job config.
@@ -259,12 +339,23 @@ def build_optimizers(
    Args:
        model_parts (List[nn.Module]): List of model parts to be optimized.
        job_config (JobConfig): Job config containing the optimizer name and parameters.
+         parallel_dims (ParallelDims): Parallel dimensions for the model.
    """
    optim_in_bwd = job_config.optimizer.early_step_in_backward
-     if optim_in_bwd and job_config.parallelism.pipeline_parallel_degree > 1:
-         raise NotImplementedError(
-             "Optimizers in backward is not supported with pipeline parallelism."
-         )
+     if optim_in_bwd:
+         if parallel_dims.ep_enabled:
+             raise NotImplementedError(
+                 "Optimizers in backward is not supported with Expert Parallel."
+             )
+         if parallel_dims.pp_enabled:
+             raise NotImplementedError(
+                 "Optimizers in backward is not supported with Pipeline Parallel."
+             )
+         if ft_manager.enabled:
+             raise NotImplementedError(
+                 "TorchFT is not supported with optimizers in backward."
+             )
+
    name = job_config.optimizer.name
    lr = job_config.optimizer.lr
    beta1 = job_config.optimizer.beta1
@@ -295,19 +386,31 @@ def build_optimizers(
        raise NotImplementedError(f"Optimizer {name} not added.")
    optimizer_cls = optimizer_classes[name]

-     if optim_in_bwd and ft_manager.enabled:
-         raise ValueError("TorchFT is not supported with optimizers in backward.")
-     elif optim_in_bwd:
+     if optim_in_bwd:
        return OptimizersInBackwardContainer(
            model_parts, optimizer_cls, optimizer_kwargs
        )
-     elif ft_manager.enabled:
+
+     if ft_manager.enabled:
        return FTOptimizersContainer(
            model_parts,
            optimizer_cls,
            optimizer_kwargs,
            ft_manager.manager,
            use_ft_optimizer=job_config.fault_tolerance.semi_sync_method is None,
        )
-     else:
-         return OptimizersContainer(model_parts, optimizer_cls, optimizer_kwargs)
+
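+     # When EP is enabled and the fused optimizer path is requested, use the
+     # EP-aware container so that each fused step only touches params that
+     # live on a single DeviceMesh.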
+     if parallel_dims.ep_enabled and fused:
+         if ft_manager.enabled:
+             raise NotImplementedError(
+                 "Expert Parallel with fused optimizer implementation "
+                 "is not supported with TorchFT yet."
+             )
+         return ExpertParallelOptimizersContainer(
+             model_parts,
+             optimizer_cls,
+             optimizer_kwargs,
+             parallel_dims.dense_params_mesh_ndim,
+         )
+
+     return OptimizersContainer(model_parts, optimizer_cls, optimizer_kwargs)