Commit 0c8400a

Use a separate communicator for AMAX reduction (NVIDIA#1585)
* add amax reduction group for fp8 training
* use a separate communicator for amax reduction across DP-ranks
* reflect suggestion
1 parent: 6943fd2 | commit: 0c8400a
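
For context, here is a minimal, hedged sketch of how the new option might be enabled. It assumes the default process group has already been created (e.g. by torchrun calling torch.distributed.init_process_group); only the use_fp8_ keyword and the get_data_parallel_amax_reduction_group() accessor come from this commit, everything else is illustrative.

# Minimal sketch: enable the extra amax-reduction communicator added by this commit.
# Assumes torch.distributed.init_process_group() has already been called (e.g. via torchrun);
# default tensor/pipeline parallel sizes are used purely for illustration.
from apex.transformer import parallel_state

parallel_state.initialize_model_parallel(use_fp8_=True)

# New accessor from this commit: the dedicated data-parallel group for amax reductions.
amax_group = parallel_state.get_data_parallel_amax_reduction_group()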

File tree: 1 file changed (+20 lines, -0 lines)

apex/transformer/parallel_state.py

Lines changed: 20 additions & 0 deletions
@@ -48,6 +48,8 @@
 _DECODER_RELATIVE_POSITION_EMBEDDING_GROUP = None
 # Data parallel group that the current rank belongs to.
 _DATA_PARALLEL_GROUP = None
+# Data parallel AMAX reduction group that the current rank belongs to.
+_DATA_PARALLEL_GROUP_FOR_AMAX_REDUCTION = None
 
 _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None
 _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
@@ -84,6 +86,7 @@ def initialize_model_parallel(
     pipeline_model_parallel_size_: int = 1,
     virtual_pipeline_model_parallel_size_: Optional[int] = None,
     pipeline_model_parallel_split_rank_: Optional[int] = None,
+    use_fp8_: bool = False,
     *,
     default_backend: Optional[str] = None,
     p2p_backend: Optional[str] = None,
@@ -96,6 +99,7 @@ def initialize_model_parallel(
         pipeline_model_parallel_size: number of GPUs used to parallelize model pipeline.
         virtual_pipeline_model_parallel_size: number of virtual stages (interleaved pipeline).
         pipeline_model_parallel_split_rank: for models with both encoder and decoder, rank in pipeline with split point.
+        use_fp8_: FP8 training that needs AMAX reduction across data-parallel ranks.
     Keyword Arguments:
         default_backend: Backend of process groups except for pipeline parallel ones.
             If :obj:`None`, the backend specified in `torch.distributed.init_process_group` will be used.
@@ -186,6 +190,9 @@ def initialize_model_parallel(
     # Build the data-parallel groups.
     global _DATA_PARALLEL_GROUP
     assert _DATA_PARALLEL_GROUP is None, "data parallel group is already initialized"
+    if use_fp8_:
+        global _DATA_PARALLEL_GROUP_FOR_AMAX_REDUCTION
+        assert _DATA_PARALLEL_GROUP_FOR_AMAX_REDUCTION is None, "amax reduction group is already initialized"
     all_data_parallel_group_ranks = []
     for i in range(pipeline_model_parallel_size):
         start_rank = i * num_pipeline_model_parallel_groups
@@ -196,6 +203,10 @@ def initialize_model_parallel(
             group = torch.distributed.new_group(ranks, backend=default_backend)
             if rank in ranks:
                 _DATA_PARALLEL_GROUP = group
+            if use_fp8_:
+                group = torch.distributed.new_group(ranks, backend=default_backend)
+                if rank in ranks:
+                    _DATA_PARALLEL_GROUP_FOR_AMAX_REDUCTION = group
 
     # Build the model-parallel groups.
     global _MODEL_PARALLEL_GROUP
@@ -363,6 +374,13 @@ def get_data_parallel_group():
     return _DATA_PARALLEL_GROUP
 
 
+def get_data_parallel_amax_reduction_group():
+    """Get the amax reduction group the caller rank belongs to."""
+    assert _DATA_PARALLEL_GROUP_FOR_AMAX_REDUCTION is not None, \
+        "AMAX reduction group is not initialized"
+    return _DATA_PARALLEL_GROUP_FOR_AMAX_REDUCTION
+
+
 def get_embedding_group():
     """Get the embedding group the caller rank belongs to."""
     assert _EMBEDDING_GROUP is not None, "embedding group is not initialized"
@@ -655,6 +673,8 @@ def destroy_model_parallel():
     _PIPELINE_MODEL_PARALLEL_GROUP = None
     global _DATA_PARALLEL_GROUP
     _DATA_PARALLEL_GROUP = None
+    global _DATA_PARALLEL_GROUP_FOR_AMAX_REDUCTION
+    _DATA_PARALLEL_GROUP_FOR_AMAX_REDUCTION = None
     global _EMBEDDING_GROUP
     _EMBEDDING_GROUP = None
     global _POSITION_EMBEDDING_GROUP
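
Below the diff, a hedged sketch of how downstream FP8 code might consume the new group. The helper name, the amax buffer, and the MAX all-reduce are assumptions for illustration and are not part of this commit; only get_data_parallel_amax_reduction_group() is defined by the diff above.

# Sketch of a possible consumer (not part of this commit): reduce per-tensor amax
# values across data-parallel ranks over the dedicated communicator, so the FP8
# scaling factors stay consistent between replicas.
import torch
from apex.transformer import parallel_state

def reduce_amax_across_dp_ranks(amax_buffer: torch.Tensor) -> None:
    # Hypothetical helper; `amax_buffer` would hold the running amax values.
    group = parallel_state.get_data_parallel_amax_reduction_group()
    torch.distributed.all_reduce(
        amax_buffer,
        op=torch.distributed.ReduceOp.MAX,
        group=group,
    )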
