Commit 0e97eaa

[megatron] compat megatron-core main branch (#4606)

1 parent 6627364 · commit 0e97eaa

File tree

6 files changed: +51 −27 lines

docs/source/Instruction/命令行参数.md

Lines changed: 1 addition & 0 deletions

@@ -494,6 +494,7 @@ soft overlong reward parameters
 - 🔥infer_backend: Inference acceleration backend; supports the 'pt', 'vllm', and 'lmdeploy' inference engines. Defaults to 'pt'.
 - 🔥max_batch_size: Takes effect when infer_backend is 'pt'; used for batch inference. Defaults to 1. If set to -1, there is no limit.
 - 🔥result_path: Path for storing inference results (jsonl). Defaults to None, meaning results are saved in the checkpoint directory (the one containing the args.json file) or the './result' directory; the final storage path is printed on the command line.
+  - Note: If the `result_path` file already exists, it will be appended to.
 - write_batch_size: The batch size for writing results to `result_path`. Defaults to 1000. If set to -1, there is no limit.
 - metric: Evaluate the inference results; currently supports 'acc' and 'rouge'. Defaults to None, meaning no evaluation is performed.
 - val_dataset_sample: Number of samples drawn from the inference dataset. Defaults to None.

docs/source_en/Instruction/Command-line-parameters.md

Lines changed: 1 addition & 0 deletions

@@ -515,6 +515,7 @@ Inference arguments include the [base arguments](#base-arguments), [merge argume
 - 🔥infer_backend: Inference acceleration backend, supporting three inference engines: 'pt', 'vllm', and 'lmdeploy'. The default is 'pt'.
 - 🔥max_batch_size: Effective when infer_backend is set to 'pt'; used for batch inference, with a default value of 1. If set to -1, there is no restriction.
 - 🔥result_path: Path to store inference results (jsonl). The default is None, meaning results are saved in the checkpoint directory (with args.json file) or './result' directory. The final storage path will be printed in the command line.
+  - Note: If the `result_path` file already exists, it will be appended to.
 - write_batch_size: The batch size for writing results to result_path. Defaults to 1000. If set to -1, there is no restriction.
 - metric: Evaluate the results of the inference, currently supporting 'acc' and 'rouge'. The default is None, meaning no evaluation is performed.
 - val_dataset_sample: Number of samples from the inference dataset, default is None.
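
The note added in both docs changes how repeated runs interact with an existing result file: new records are appended to the end of the jsonl rather than overwriting it. Below is a minimal sketch of reading back such an appended file; the file name and the per-run handling are illustrative assumptions, not swift's actual output schema.

import json
from pathlib import Path

# Hypothetical result file shared by two successive inference runs via the
# same result_path; per the docs note, the second run appends to the file.
result_path = Path('./result/infer_result.jsonl')

records = []
with result_path.open('r', encoding='utf-8') as f:
    for line in f:
        line = line.strip()
        if line:
            records.append(json.loads(line))

# Records from all runs are present; filter or deduplicate as needed.
print(f'total records across runs: {len(records)}')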

swift/megatron/argument/megatron_args.py

Lines changed: 2 additions & 0 deletions

@@ -225,6 +225,8 @@ def _init_mixed_precision(self):
         os.environ['NVTE_APPLY_QK_LAYER_SCALING'] = '1'

     def _init_moe(self):
+        if self.num_experts is None:
+            return
         if self.moe_shared_expert_intermediate_size == 0:
             self.moe_shared_expert_intermediate_size = None
         if self.moe_ffn_hidden_size is None:
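
The two added lines make `_init_moe` a no-op for dense models (num_experts is None), so MoE-only defaults are never touched. A standalone sketch of the guard pattern, using a simplified stand-in class rather than swift's actual MegatronArguments (the moe_ffn_hidden_size fallback below is an assumption for illustration):

from dataclasses import dataclass
from typing import Optional

@dataclass
class MoeArgsSketch:
    # Simplified stand-in for the MoE-related fields handled by _init_moe.
    num_experts: Optional[int] = None
    moe_shared_expert_intermediate_size: Optional[int] = 0
    moe_ffn_hidden_size: Optional[int] = None
    ffn_hidden_size: int = 4096

    def _init_moe(self) -> None:
        if self.num_experts is None:
            # Dense model: skip MoE post-processing entirely (the new early return).
            return
        if self.moe_shared_expert_intermediate_size == 0:
            self.moe_shared_expert_intermediate_size = None
        if self.moe_ffn_hidden_size is None:
            # Illustrative fallback; the real defaulting lives in megatron_args.py.
            self.moe_ffn_hidden_size = self.ffn_hidden_size

dense = MoeArgsSketch()
dense._init_moe()  # returns immediately; MoE fields stay untouched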

swift/megatron/init.py

Lines changed: 1 addition & 1 deletion

@@ -203,7 +203,7 @@ def training_log(loss_dict, total_loss_dict, learning_rate, decoupled_learning_r
     if args.mtp_num_layers is not None:
         mtp_loss_scale = 1 / get_num_microbatches()
         MTPLossLoggingHelper.track_mtp_metrics(mtp_loss_scale, iteration, writer, wandb_writer, total_loss_dict)
-    if iteration % args.log_interval == 0:
+    if iteration % args.log_interval == 0 or iteration == 1:
         if args.record_memory_history and is_last_rank():
             snapshot = torch.cuda.memory._snapshot()
             from pickle import dump
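
With the added `or iteration == 1`, training_log also fires on the very first iteration, so metrics appear immediately instead of only after log_interval steps. A small sketch of the resulting schedule (the log_interval value is chosen only for illustration):

log_interval = 5  # illustrative value

def should_log(iteration: int) -> bool:
    # Mirrors the patched condition in training_log.
    return iteration % log_interval == 0 or iteration == 1

print([it for it in range(1, 16) if should_log(it)])  # [1, 5, 10, 15]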

swift/megatron/train/trainers/dpo_trainer.py

Lines changed: 4 additions & 5 deletions

@@ -146,11 +146,10 @@ def loss_func(self, output_tensor: torch.Tensor, *, ref_logps: torch.Tensor, lab
             torch.distributed.all_reduce(
                 reporting_metric, torch.distributed.ReduceOp.AVG, group=mpu.get_data_parallel_group())
             reporting_metric = {k: reporting_metric[i] for i, k in enumerate(metric.keys())}
-        return (
-            # fix megatron-lm bug
-            # https://github.com/NVIDIA/Megatron-LM/blob/core_r0.12.0/megatron/core/pipeline_parallel/schedules.py#L291
-            loss / mpu.get_context_parallel_world_size(),
-            reporting_metric)
+        # fix megatron-lm bug
+        # https://github.com/NVIDIA/Megatron-LM/blob/core_r0.12.0/megatron/core/pipeline_parallel/schedules.py#L291
+        loss = loss / mpu.get_context_parallel_world_size()
+        return (loss, reporting_metric)

     def _replace_data_iterator(self, data_iterator):
         args = get_args()

swift/megatron/train/trainers/trainer.py

Lines changed: 42 additions & 21 deletions

@@ -4,6 +4,7 @@
 from contextlib import contextmanager
 from functools import partial

+import megatron.core
 import torch
 from megatron.core import mpu
 from megatron.core.enums import ModelType
@@ -12,6 +13,7 @@
 from megatron.core.rerun_state_machine import RerunMode, get_rerun_state_machine
 from megatron.core.utils import StragglerDetector
 from megatron.training import ft_integration, get_args, get_timers, is_last_rank, pretrain, print_rank_0, training
+from packaging import version
 from torch.distributed.nn import all_reduce

 from swift.utils import get_logger
@@ -129,7 +131,7 @@ def evaluate(self,
         # make validation batch size independent from training batch size
         eval_batch_size = args.global_batch_size
         eval_num_microbatches = eval_batch_size // (args.micro_batch_size * args.data_parallel_size)
-
+        megatron_core_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0')
         with torch.no_grad():
             iteration = 0
             if verbose:
@@ -161,19 +163,35 @@ def evaluate(self,
                 torch.cuda.empty_cache()

             if mpu.is_pipeline_last_stage(ignore_virtual=True):
-                # Reduce across processes.
-                for loss_dict in loss_dicts:
-                    for key in loss_dict:
+                if megatron_core_013:
+                    for key in loss_dicts[0].keys():
                         if key not in total_loss_dict:
                             total_loss_dict[key] = torch.tensor([0.0, 0.0], dtype=torch.float).cuda()
-                        val = loss_dict[key]
-                        if isinstance(val, tuple) or isinstance(val, list):
-                            total_loss_dict[key][0] += val[0]
-                            total_loss_dict[key][1] += val[1]
-                        else:
+                        val = [x[key].view(-1) for x in loss_dicts]
+                        if val[0].numel() == 2:
+                            val = torch.vstack(val).sum(dim=0)
+                            torch.distributed.all_reduce(
+                                val, group=mpu.get_data_parallel_group(with_context_parallel=True))
+                            total_loss_dict[key] += val
+                        elif val[0].numel() == 1:
+                            val = torch.cat(val).sum()
                             total_loss_dict[key][0] += val
-                            total_loss_dict[key][1] += 1
-
+                            total_loss_dict[key][1] += len(loss_dicts)
+                        else:
+                            raise ValueError(f'Invalid value shape: {val[0].shape} for key {key}')
+                else:
+                    # Reduce across processes.
+                    for loss_dict in loss_dicts:
+                        for key in loss_dict:
+                            if key not in total_loss_dict:
+                                total_loss_dict[key] = torch.tensor([0.0, 0.0], dtype=torch.float).cuda()
+                            val = loss_dict[key]
+                            if isinstance(val, tuple) or isinstance(val, list):
+                                total_loss_dict[key][0] += val[0]
+                                total_loss_dict[key][1] += val[1]
+                            else:
+                                total_loss_dict[key][0] += val
+                                total_loss_dict[key][1] += 1
             args.consumed_valid_samples += eval_batch_size

             if args.exit_duration_in_mins:
@@ -250,7 +268,8 @@ def loss_func(self, output_tensor: torch.Tensor, *, loss_mask: torch.Tensor):
         total_tokens = loss_mask.sum()
         loss = torch.cat([torch.sum(losses.view(-1) * loss_mask).view(1), total_tokens.view(1)])

-        if args.context_parallel_size > 1:
+        megatron_core_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0')
+        if args.context_parallel_size > 1 and not megatron_core_013:
             loss = all_reduce(loss, group=mpu.get_context_parallel_group())

         # Check individual rank losses are not NaN prior to DP all-reduce.
@@ -287,19 +306,21 @@ def loss_func(self, output_tensor: torch.Tensor, *, loss_mask: torch.Tensor):
         )
         # Reduce loss for logging.
         reporting_loss = loss.clone().detach()
-        torch.distributed.all_reduce(reporting_loss, group=mpu.get_data_parallel_group())
-
-        # loss[0] is a view of loss, so it has ._base not None, which triggers assert error
-        # in core/pipeline_parallel/schedule.py::deallocate_output_tensor, calling .clone()
-        # on loss[0] fixes this
-        local_num_tokens = loss[1].clone().detach().to(torch.int)
-        return (
+        lm_loss = loss[0]
+        if not megatron_core_013:
             # fix megatron-lm bug
             # https://github.com/NVIDIA/Megatron-LM/blob/core_r0.12.0/megatron/core/pipeline_parallel/schedules.py#L291
-            loss[0] / mpu.get_context_parallel_world_size(),
+            torch.distributed.all_reduce(reporting_loss, group=mpu.get_data_parallel_group())
+            lm_loss = lm_loss / mpu.get_context_parallel_world_size()
+            reporting_loss = (reporting_loss[0], reporting_loss[1])
+        else:
+            lm_loss = lm_loss.clone()
+        local_num_tokens = loss[1].clone().detach().to(torch.int)
+        return (
+            lm_loss,
             local_num_tokens,
             {
-                'lm loss': (reporting_loss[0], reporting_loss[1])
+                'lm loss': reporting_loss
             },
         )
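
All three trainer.py changes key off the same probe of the installed megatron-core version. A standalone sketch of that gate, constructed exactly as in the diff (assumes megatron-core and packaging are importable; the branch bodies are placeholders):

import megatron.core
from packaging import version

# True on megatron-core >= 0.13.0rc0 (including the current main branch).
megatron_core_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0')

if megatron_core_013:
    # New path: evaluate() stacks per-key loss tensors and all-reduces them over the
    # data-parallel group (with context parallel), and loss_func skips the manual
    # context-parallel reductions.
    ...
else:
    # Legacy path: keep the megatron-core <= 0.12.x behavior, including the workaround
    # for the schedules.py loss-scaling bug referenced in the comments above.
    ...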
