@@ -1,28 +1,17 @@
-import math
-import os
-from contextlib import contextmanager
 from functools import partial
 from types import MethodType
-from typing import Any, Dict, Iterator, List, Optional, Tuple
+from typing import Any, Optional, Tuple
 
-import datasets
-import numpy as np
 import torch
 import torch.distributed as dist
-import trl
 from packaging import version
-from torch.distributed.device_mesh import init_device_mesh
-from torch.nn import CrossEntropyLoss
-from torch.utils.data import DataLoader, Sampler
-from trl.extras.profiling import profiling_decorator
 
-from swift.llm import DataLoaderDispatcher, DataLoaderShard, get_llm_model, to_device
-from swift.utils import get_current_device, get_device, get_dist_setting, seed_worker
+from swift.llm import get_llm_model
 from .base import CommonSequenceParallel
-from .utils import (ChunkedCrossEntropyLoss, GatherLoss, SequenceParallelDispatcher, SequenceParallelSampler,
-                    _get_per_token_logps_grpo, _get_train_sampler_grpo, _prepare_inputs, _prepare_inputs_grpo,
-                    get_common_dataloader, get_per_token_logps, loss_scale_sp_func, old_policy_grpo,
-                    padding_free_context_grpo, setup_compute_acc, split_by_mini_batches_grpo)
+from .utils import (SequenceParallelDispatcher, SequenceParallelSampler, _get_per_token_logps_grpo,
+                    _get_train_sampler_grpo, _prepare_inputs, _prepare_inputs_grpo, get_common_dataloader,
+                    get_per_token_logps, loss_scale_sp_func, old_policy_grpo, setup_compute_acc,
+                    split_by_mini_batches_grpo)
 
 assert version.parse(torch.__version__) >= version.parse('2.0.0')
 torch._dynamo.config.capture_dynamic_output_shape_ops = True
@@ -289,6 +278,7 @@ def prepare_trainer(self, trainer):
             trainer.get_per_token_logps = partial(get_per_token_logps, sp_instance=self)
 
         elif trainer.__class__.__name__ == 'GRPOTrainer':
+            import trl
             assert version.parse(trl.__version__) >= version.parse('0.18.0')
             trainer.ulysses = self
             trainer.args.gradient_accumulation_steps = trainer.args.gradient_accumulation_steps * self.sp_world_size
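
Two things worth noting in this hunk. First, `import trl` moves from module scope into the `GRPOTrainer` branch, so trl is only imported when that code path actually runs. A minimal sketch of the pattern, assuming (my reading, not repo code) the goal is to make trl an optional dependency:

    from packaging import version

    def check_grpo_requirements():
        # Deferred import: the module as a whole no longer hard-depends on trl;
        # an ImportError can only surface on the GRPO path that needs it.
        import trl
        assert version.parse(trl.__version__) >= version.parse('0.18.0')

Second, scaling `gradient_accumulation_steps` by `self.sp_world_size` presumably compensates for sequence parallelism shrinking the effective data-parallel size (ranks in one sequence-parallel group cooperate on a single sequence rather than consuming distinct batches), so the global batch per optimizer step stays constant: e.g. 4 accumulation steps with `sp_world_size=2` become 8.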