
Commit 68aceef

tastelikefeet authored and Jintao-Huang committed
optimize imports (#4883)
1 parent 0edb218 commit 68aceef

File tree

3 files changed: +14 -36 lines changed


swift/trainers/sequence_parallel/ring_attention.py

Lines changed: 4 additions & 15 deletions
@@ -1,28 +1,17 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
-import math
 import os
 from functools import partial
 from types import MethodType
-from typing import Any, Dict, Iterator, List
 
-import datasets
-import numpy as np
 import torch
-import torch.distributed as dist
-import torch.nn as nn
 import torch.nn.functional as F
-from datasets import Dataset
 from packaging import version
-from torch.distributed.device_mesh import init_device_mesh
-from torch.nn import CrossEntropyLoss
-from torch.utils.data import DataLoader, Sampler
 
-from swift.llm import DataLoaderDispatcher, DataLoaderShard, get_llm_model, to_device
-from swift.utils import get_current_device, get_device, get_dist_setting, seed_worker
+from swift.llm import get_llm_model
 from .base import CommonSequenceParallel
-from .utils import (ChunkedCrossEntropyLoss, GatherLoss, SequenceParallelDispatcher, SequenceParallelSampler,
-                    _get_per_token_logps_grpo, _get_train_sampler_grpo, _prepare_inputs, _prepare_inputs_grpo,
-                    get_common_dataloader, get_per_token_logps, loss_scale_sp_func, old_policy_grpo, setup_compute_acc,
+from .utils import (SequenceParallelDispatcher, SequenceParallelSampler, _get_per_token_logps_grpo,
+                    _get_train_sampler_grpo, _prepare_inputs, _prepare_inputs_grpo, get_common_dataloader,
+                    get_per_token_logps, loss_scale_sp_func, old_policy_grpo, setup_compute_acc,
                     split_by_mini_batches_grpo)
 
 RING_ATTN_GROUP = None

swift/trainers/sequence_parallel/ulysses.py

Lines changed: 7 additions & 17 deletions
@@ -1,28 +1,17 @@
-import math
-import os
-from contextlib import contextmanager
 from functools import partial
 from types import MethodType
-from typing import Any, Dict, Iterator, List, Optional, Tuple
+from typing import Any, Optional, Tuple
 
-import datasets
-import numpy as np
 import torch
 import torch.distributed as dist
-import trl
 from packaging import version
-from torch.distributed.device_mesh import init_device_mesh
-from torch.nn import CrossEntropyLoss
-from torch.utils.data import DataLoader, Sampler
-from trl.extras.profiling import profiling_decorator
 
-from swift.llm import DataLoaderDispatcher, DataLoaderShard, get_llm_model, to_device
-from swift.utils import get_current_device, get_device, get_dist_setting, seed_worker
+from swift.llm import get_llm_model
 from .base import CommonSequenceParallel
-from .utils import (ChunkedCrossEntropyLoss, GatherLoss, SequenceParallelDispatcher, SequenceParallelSampler,
-                    _get_per_token_logps_grpo, _get_train_sampler_grpo, _prepare_inputs, _prepare_inputs_grpo,
-                    get_common_dataloader, get_per_token_logps, loss_scale_sp_func, old_policy_grpo,
-                    padding_free_context_grpo, setup_compute_acc, split_by_mini_batches_grpo)
+from .utils import (SequenceParallelDispatcher, SequenceParallelSampler, _get_per_token_logps_grpo,
+                    _get_train_sampler_grpo, _prepare_inputs, _prepare_inputs_grpo, get_common_dataloader,
+                    get_per_token_logps, loss_scale_sp_func, old_policy_grpo, setup_compute_acc,
+                    split_by_mini_batches_grpo)
 
 assert version.parse(torch.__version__) >= version.parse('2.0.0')
 torch._dynamo.config.capture_dynamic_output_shape_ops = True
@@ -289,6 +278,7 @@ def prepare_trainer(self, trainer):
             trainer.get_per_token_logps = partial(get_per_token_logps, sp_instance=self)
 
         elif trainer.__class__.__name__ == 'GRPOTrainer':
+            import trl
             assert version.parse(trl.__version__) >= version.parse('0.18.0')
             trainer.ulysses = self
             trainer.args.gradient_accumulation_steps = trainer.args.gradient_accumulation_steps * self.sp_world_size
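The second hunk shows the other half of the cleanup: `import trl` moves from module scope into the only branch that uses it, so importing ulysses.py no longer hard-requires trl to be installed. A minimal sketch of this deferred-import pattern, with a hypothetical `prepare` function standing in for `prepare_trainer`:

from packaging import version


def prepare(trainer_name: str) -> str:
    if trainer_name == 'GRPOTrainer':
        # Deferred import: trl is only loaded (and only required to be
        # installed) when this branch actually runs.
        import trl
        assert version.parse(trl.__version__) >= version.parse('0.18.0')
        return 'grpo'
    return 'default'

Keeping the version check next to the import means an incompatible trl fails loudly at the point of use rather than at module import time.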

swift/trainers/sequence_parallel/utils.py

Lines changed: 3 additions & 4 deletions
@@ -3,8 +3,7 @@
 import os
 from contextlib import contextmanager
 from functools import partial
-from types import MethodType
-from typing import Any, Dict, Iterator, List, Optional, Tuple
+from typing import Any, Dict, Iterator, List, Tuple
 
 import datasets
 import numpy as np
@@ -13,8 +12,8 @@
 from torch.nn import CrossEntropyLoss
 from torch.utils.data import DataLoader, Sampler
 
-from swift.llm import DataLoaderDispatcher, DataLoaderShard, get_llm_model, to_device
-from swift.utils import get_current_device, get_device, get_dist_setting, seed_worker
+from swift.llm import DataLoaderDispatcher, DataLoaderShard, get_llm_model
+from swift.utils import get_current_device, get_dist_setting, seed_worker
 
 # Conditional import for profiling decorator
 try:
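The trailing context lines show the complementary technique used in utils.py: when a symbol from an optional dependency is needed at module scope, guard the import with try/except and supply a fallback. A minimal sketch, assuming a no-op decorator is an acceptable substitute when trl is missing (the actual fallback in utils.py is not visible in this diff):

# Conditional import for profiling decorator
try:
    from trl.extras.profiling import profiling_decorator
except ImportError:
    # Assumed fallback: keep decorated functions usable, just unprofiled.
    def profiling_decorator(func):
        return func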
