 _DECODER_RELATIVE_POSITION_EMBEDDING_GROUP = None
 # Data parallel group that the current rank belongs to.
 _DATA_PARALLEL_GROUP = None
+# Data parallel AMAX reduction group that the current rank belongs to.
+_DATA_PARALLEL_GROUP_FOR_AMAX_REDUCTION = None
 
 _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None
 _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
@@ -84,6 +86,7 @@ def initialize_model_parallel(
     pipeline_model_parallel_size_: int = 1,
     virtual_pipeline_model_parallel_size_: Optional[int] = None,
     pipeline_model_parallel_split_rank_: Optional[int] = None,
+    use_fp8_: bool = False,
     *,
     default_backend: Optional[str] = None,
     p2p_backend: Optional[str] = None,
@@ -96,6 +99,7 @@ def initialize_model_parallel(
         pipeline_model_parallel_size: number of GPUs used to parallelize model pipeline.
         virtual_pipeline_model_parallel_size: number of virtual stages (interleaved pipeline).
         pipeline_model_parallel_split_rank: for models with both encoder and decoder, rank in pipeline with split point.
+        use_fp8_: whether FP8 training is enabled; FP8 training requires AMAX reduction across data-parallel ranks.
     Keyword Arguments:
         default_backend: Backend of process groups except for pipeline parallel ones.
             If :obj:`None`, the backend specified in `torch.distributed.init_process_group` will be used.
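
A minimal usage sketch of the new flag, assuming this file is apex's `parallel_state` module and that `torch.distributed` has already been initialized; the parallel sizes below are arbitrary examples, not values from this change:

    import torch
    from apex.transformer import parallel_state

    # Assumes torch.distributed.init_process_group(backend="nccl") has run and
    # that the world size is divisible by the product of the parallel sizes.
    parallel_state.initialize_model_parallel(
        tensor_model_parallel_size_=2,
        pipeline_model_parallel_size_=1,
        use_fp8_=True,  # also builds the data-parallel AMAX reduction group
    )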
@@ -186,6 +190,9 @@ def initialize_model_parallel(
     # Build the data-parallel groups.
     global _DATA_PARALLEL_GROUP
     assert _DATA_PARALLEL_GROUP is None, "data parallel group is already initialized"
+    if use_fp8_:
+        global _DATA_PARALLEL_GROUP_FOR_AMAX_REDUCTION
+        assert _DATA_PARALLEL_GROUP_FOR_AMAX_REDUCTION is None, "amax reduction group is already initialized"
     all_data_parallel_group_ranks = []
     for i in range(pipeline_model_parallel_size):
         start_rank = i * num_pipeline_model_parallel_groups
@@ -196,6 +203,10 @@ def initialize_model_parallel(
             group = torch.distributed.new_group(ranks, backend=default_backend)
             if rank in ranks:
                 _DATA_PARALLEL_GROUP = group
+            if use_fp8_:
+                group = torch.distributed.new_group(ranks, backend=default_backend)
+                if rank in ranks:
+                    _DATA_PARALLEL_GROUP_FOR_AMAX_REDUCTION = group
 
     # Build the model-parallel groups.
     global _MODEL_PARALLEL_GROUP
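
Note that when `use_fp8_` is set, `torch.distributed.new_group` is called a second time over exactly the same ranks. Each call returns a distinct process group, so FP8 AMAX reductions presumably get a communicator of their own rather than sharing the one used for gradient all-reduces over `_DATA_PARALLEL_GROUP`.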
@@ -363,6 +374,13 @@ def get_data_parallel_group():
     return _DATA_PARALLEL_GROUP
 
 
+def get_data_parallel_amax_reduction_group():
+    """Get the amax reduction group the caller rank belongs to."""
+    assert _DATA_PARALLEL_GROUP_FOR_AMAX_REDUCTION is not None, \
+        "AMAX reduction group is not initialized"
+    return _DATA_PARALLEL_GROUP_FOR_AMAX_REDUCTION
+
+
 def get_embedding_group():
     """Get the embedding group the caller rank belongs to."""
     assert _EMBEDDING_GROUP is not None, "embedding group is not initialized"
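
A hedged sketch of how the new accessor might be consumed, assuming the same module path as above; the amax tensor here is an illustrative placeholder, and MAX is used because amax is an absolute-maximum statistic:

    import torch
    from apex.transformer import parallel_state

    # Synchronize a per-tensor amax statistic across data-parallel ranks so
    # every rank derives the same FP8 scaling factor.
    amax = torch.tensor([0.0], device="cuda")  # illustrative placeholder
    torch.distributed.all_reduce(
        amax,
        op=torch.distributed.ReduceOp.MAX,
        group=parallel_state.get_data_parallel_amax_reduction_group(),
    )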
@@ -655,6 +673,8 @@ def destroy_model_parallel():
     _PIPELINE_MODEL_PARALLEL_GROUP = None
     global _DATA_PARALLEL_GROUP
     _DATA_PARALLEL_GROUP = None
+    global _DATA_PARALLEL_GROUP_FOR_AMAX_REDUCTION
+    _DATA_PARALLEL_GROUP_FOR_AMAX_REDUCTION = None
     global _EMBEDDING_GROUP
     _EMBEDDING_GROUP = None
     global _POSITION_EMBEDDING_GROUP