@@ -4,6 +4,7 @@
 from contextlib import contextmanager
 from functools import partial
 
+import megatron.core
 import torch
 from megatron.core import mpu
 from megatron.core.enums import ModelType
@@ -12,6 +13,7 @@
 from megatron.core.rerun_state_machine import RerunMode, get_rerun_state_machine
 from megatron.core.utils import StragglerDetector
 from megatron.training import ft_integration, get_args, get_timers, is_last_rank, pretrain, print_rank_0, training
+from packaging import version
 from torch.distributed.nn import all_reduce
 
 from swift.utils import get_logger
@@ -129,7 +131,7 @@ def evaluate(self,
         # make validation batch size independent from training batch size
         eval_batch_size = args.global_batch_size
         eval_num_microbatches = eval_batch_size // (args.micro_batch_size * args.data_parallel_size)
-
+        megatron_core_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0')
         with torch.no_grad():
             iteration = 0
             if verbose:
@@ -161,19 +163,35 @@ def evaluate(self,
                     torch.cuda.empty_cache()
 
                 if mpu.is_pipeline_last_stage(ignore_virtual=True):
-                    # Reduce across processes.
-                    for loss_dict in loss_dicts:
-                        for key in loss_dict:
+                    if megatron_core_013:
+                        for key in loss_dicts[0].keys():
                             if key not in total_loss_dict:
                                 total_loss_dict[key] = torch.tensor([0.0, 0.0], dtype=torch.float).cuda()
-                            val = loss_dict[key]
-                            if isinstance(val, tuple) or isinstance(val, list):
-                                total_loss_dict[key][0] += val[0]
-                                total_loss_dict[key][1] += val[1]
-                            else:
+                            val = [x[key].view(-1) for x in loss_dicts]
+                            if val[0].numel() == 2:
+                                val = torch.vstack(val).sum(dim=0)
+                                torch.distributed.all_reduce(
+                                    val, group=mpu.get_data_parallel_group(with_context_parallel=True))
+                                total_loss_dict[key] += val
+                            elif val[0].numel() == 1:
+                                val = torch.cat(val).sum()
                                 total_loss_dict[key][0] += val
-                                total_loss_dict[key][1] += 1
-
+                                total_loss_dict[key][1] += len(loss_dicts)
+                            else:
+                                raise ValueError(f'Invalid value shape: {val[0].shape} for key {key}')
+                    else:
+                        # Reduce across processes.
+                        for loss_dict in loss_dicts:
+                            for key in loss_dict:
+                                if key not in total_loss_dict:
+                                    total_loss_dict[key] = torch.tensor([0.0, 0.0], dtype=torch.float).cuda()
+                                val = loss_dict[key]
+                                if isinstance(val, tuple) or isinstance(val, list):
+                                    total_loss_dict[key][0] += val[0]
+                                    total_loss_dict[key][1] += val[1]
+                                else:
+                                    total_loss_dict[key][0] += val
+                                    total_loss_dict[key][1] += 1
                 args.consumed_valid_samples += eval_batch_size
 
                 if args.exit_duration_in_mins:
@@ -250,7 +268,8 @@ def loss_func(self, output_tensor: torch.Tensor, *, loss_mask: torch.Tensor):
         total_tokens = loss_mask.sum()
         loss = torch.cat([torch.sum(losses.view(-1) * loss_mask).view(1), total_tokens.view(1)])
 
-        if args.context_parallel_size > 1:
+        megatron_core_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0')
+        if args.context_parallel_size > 1 and not megatron_core_013:
             loss = all_reduce(loss, group=mpu.get_context_parallel_group())
 
         # Check individual rank losses are not NaN prior to DP all-reduce.
@@ -287,19 +306,21 @@ def loss_func(self, output_tensor: torch.Tensor, *, loss_mask: torch.Tensor):
         )
         # Reduce loss for logging.
         reporting_loss = loss.clone().detach()
-        torch.distributed.all_reduce(reporting_loss, group=mpu.get_data_parallel_group())
-
-        # loss[0] is a view of loss, so it has ._base not None, which triggers assert error
-        # in core/pipeline_parallel/schedule.py::deallocate_output_tensor, calling .clone()
-        # on loss[0] fixes this
-        local_num_tokens = loss[1].clone().detach().to(torch.int)
-        return (
+        lm_loss = loss[0]
+        if not megatron_core_013:
             # fix megatron-lm bug
             # https://github.com/NVIDIA/Megatron-LM/blob/core_r0.12.0/megatron/core/pipeline_parallel/schedules.py#L291
-            loss[0] / mpu.get_context_parallel_world_size(),
+            torch.distributed.all_reduce(reporting_loss, group=mpu.get_data_parallel_group())
+            lm_loss = lm_loss / mpu.get_context_parallel_world_size()
+            reporting_loss = (reporting_loss[0], reporting_loss[1])
+        else:
+            lm_loss = lm_loss.clone()
+        local_num_tokens = loss[1].clone().detach().to(torch.int)
+        return (
+            lm_loss,
             local_num_tokens,
             {
-                'lm loss': (reporting_loss[0], reporting_loss[1])
+                'lm loss': reporting_loss
             },
         )
 
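
For reference, the new megatron-core >= 0.13 branch of `evaluate` treats each metric either as a (loss_sum, num_tokens) pair that is summed across microbatches and then all-reduced over the data-parallel (with context-parallel) group, or as a scalar that is accumulated together with a microbatch count. Below is a minimal CPU-only sketch of that bookkeeping; the `aggregate_eval_losses` helper and the sample values are invented for illustration, and the all-reduce and `.cuda()` placement are omitted so it runs standalone:

    import torch

    def aggregate_eval_losses(loss_dicts, total_loss_dict):
        """Accumulate per-microbatch loss dicts into running [sum, count] pairs."""
        for key in loss_dicts[0].keys():
            if key not in total_loss_dict:
                total_loss_dict[key] = torch.tensor([0.0, 0.0], dtype=torch.float)
            val = [x[key].view(-1) for x in loss_dicts]
            if val[0].numel() == 2:
                # Each microbatch already reports (loss_sum, num_tokens): sum element-wise.
                # The real code then all-reduces this pair over the data-parallel
                # (with context-parallel) group before accumulating.
                total_loss_dict[key] += torch.vstack(val).sum(dim=0)
            elif val[0].numel() == 1:
                # Scalar metric: accumulate the sum and count the microbatches.
                total_loss_dict[key][0] += torch.cat(val).sum()
                total_loss_dict[key][1] += len(loss_dicts)
            else:
                raise ValueError(f'Invalid value shape: {val[0].shape} for key {key}')

    # Hypothetical outputs of two microbatches.
    loss_dicts = [
        {'lm loss': torch.tensor([12.3, 512.0])},   # (loss_sum, num_tokens)
        {'lm loss': torch.tensor([11.8, 508.0])},
    ]
    totals = {}
    aggregate_eval_losses(loss_dicts, totals)
    print(totals['lm loss'])  # -> loss sum 24.1 over 1020 tokens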
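
On the `loss_func` side, the loss tensor packs the masked loss sum and the valid-token count into a single 2-element vector, so one reduction carries both values. The CPU-only sketch below (token losses, mask, and the shard split are invented) shows why summing the per-rank vectors over the context-parallel group, as the pre-0.13 branch still does explicitly via `all_reduce`, reproduces the statistics of the full sequence:

    import torch

    def local_loss_vector(losses, loss_mask):
        # Same bookkeeping as loss_func: pack the masked loss sum and the
        # valid-token count into one 2-element vector.
        total_tokens = loss_mask.sum()
        return torch.cat([torch.sum(losses.view(-1) * loss_mask).view(1), total_tokens.view(1)])

    # Hypothetical per-token losses and mask for one 8-token sequence.
    losses = torch.tensor([2.0, 1.5, 3.0, 0.5, 1.0, 2.5, 0.0, 4.0])
    loss_mask = torch.tensor([1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 0.0, 1.0])

    # Statistics computed on the whole sequence by a single rank.
    full = local_loss_vector(losses, loss_mask)

    # With context parallelism the sequence is sharded across ranks; summing the
    # per-shard vectors (what the all_reduce over the context-parallel group does)
    # is equivalent to computing the vector on the unsharded sequence.
    shard_a = local_loss_vector(losses[:4], loss_mask[:4])
    shard_b = local_loss_vector(losses[4:], loss_mask[4:])
    assert torch.allclose(full, shard_a + shard_b)
    print(full)  # -> loss sum 11.5 over 6 tokens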