
Refactor SP #4121

Open · wants to merge 7 commits into base: main

1 change: 1 addition & 0 deletions docs/source/Instruction/Megatron-SWIFT训练.md
@@ -7,6 +7,7 @@ SWIFT introduces Megatron's parallelism techniques to accelerate the training of large models, including data
To use Megatron-SWIFT, in addition to installing the swift dependencies, you also need to install the following:

```shell
# Recommended torch version: 2.5 / 2.6
pip install pybind11
# transformer_engine
# If an installation error occurs, you can refer to this issue for a solution: https://github.com/modelscope/ms-swift/issues/3793
1 change: 1 addition & 0 deletions docs/source/Instruction/命令行参数.md
@@ -74,6 +74,7 @@
- Note: If you train the deepseek-r1/qwq models on a dataset that does not contain `<think>...</think>`, additionally pass `--response_prefix ''` when running inference on the trained model.
- padding_side: The padding side when training with `batch_size>=2`. Options are 'left' and 'right'; the default is 'right'. (When the inference batch_size>=2, only left padding is applied.)
- loss_scale: The loss weighting of training tokens. The default is `'default'`, meaning all response tokens (including history) are given a cross-entropy loss weight of 1. Options are 'default', 'last_round', 'all', plus the agent-specific loss scales 'react', 'agentflan', 'alpha_umi', and 'qwen'. 'last_round' computes the loss only on the last round's response, while 'all' computes the loss on all tokens. For the agent options, see [Pluginization](../Customization/插件化.md) and the [Agent documentation](./Agent支持.md).
- sequence_parallel_size: Sequence parallelism size; the default is 1. Currently supported for pt/sft/dpo.
- use_chat_template: Whether to use the chat template or the generation template; the default is `True`. `swift pt` automatically switches to the generation template.
- template_backend: The template backend to use. Options are 'swift' and 'jinja'; the default is 'swift'. If jinja is used, transformers' `apply_chat_template` is applied.
- Note: The jinja template backend only supports inference, not training.
1 change: 1 addition & 0 deletions docs/source_en/Instruction/Command-line-parameters.md
@@ -77,6 +77,7 @@ Hints:
- Note: If you are training the deepseek-r1/qwq model with a dataset that does not include `<think>...</think>`, additionally pass `--response_prefix ''` when running inference on the trained model.
- padding_side: Padding side when `batch_size>=2` during training. Options are 'left' and 'right', with 'right' as the default. (For inference with batch_size>=2, only left padding is applied.)
- loss_scale: Setting for the loss weight of training tokens. Default is `'default'`, meaning all responses (including history) are calculated with a cross-entropy loss of 1. Options are 'default', 'last_round', 'all', and agent-specific loss scales: 'react', 'agentflan', 'alpha_umi', and 'qwen'. 'last_round' means calculating only the loss of the last round's response, and 'all' calculates the loss for all tokens. For agent parts, see [Pluginization](../Customization/Pluginization.md) and [Agent Training](./Agent-support.md).
- sequence_parallel_size: Sequence parallelism size, default is 1. Currently supported in pt/sft/dpo.
- use_chat_template: Whether to use the chat template or the generation template; the default is `True`. `swift pt` automatically switches to the generation template.
- template_backend: Selection of the template backend. Options are 'swift' and 'jinja', with 'swift' as the default. If jinja is used, transformers' `apply_chat_template` is applied.
- Note: The jinja template backend supports only inference, not training.
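For orientation only (this is not part of the diff), the flags documented above could be combined in a `swift sft` call along the following lines; the model and dataset values are placeholders:

```shell
# Hypothetical sketch: 2-way sequence parallelism plus the documented template/loss flags.
swift sft \
    --model <model-id> \
    --dataset <dataset-id> \
    --padding_side right \
    --loss_scale last_round \
    --sequence_parallel_size 2 \
    --use_chat_template true \
    --template_backend swift
```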
1 change: 1 addition & 0 deletions docs/source_en/Instruction/Megatron-SWIFT-Training.md
@@ -8,6 +8,7 @@ SWIFT incorporates Megatron's parallelization techniques to accelerate the train
To use Megatron-SWIFT, in addition to installing the `swift` dependencies, you also need to install the following:

```shell
# Recommended PyTorch version: 2.5 / 2.6
pip install pybind11
# transformer_engine
# If an installation error occurs, you can refer to this issue for resolution: https://github.com/modelscope/ms-swift/issues/3793
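As a side note (not part of the diff), if you want to pin PyTorch to the recommended range before installing the extras above, a sketch like the following should work; the exact version bounds are an assumption based on the comment:

```shell
# Hypothetical: keep torch within the recommended 2.5/2.6 range, then install pybind11.
pip install "torch>=2.5,<2.7" pybind11
```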
19 changes: 13 additions & 6 deletions swift/llm/template/base.py
@@ -9,7 +9,7 @@
from dataclasses import asdict
from functools import partial, wraps
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Tuple, Union

import torch.distributed as dist
import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -1331,7 +1331,7 @@ def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[in
        keys = [
            'input_ids', 'inputs_embeds', 'attention_mask', 'labels', 'loss_scale', 'position_ids', 'token_type_ids'
        ]
        pad_value = [self.tokenizer.pad_token_id, 0., 0, -100, 0., 0., 0]
        pad_values = [self.tokenizer.pad_token_id, 0., 0, -100, 0., 0., 0]
        # Convert to tensor and remove unnecessary dimensions.
        seq_lens = None
        for key in keys:
@@ -1352,8 +1352,18 @@ def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[in

        if self.use_megatron:
            padding_to = math.ceil(max(seq_lens) / 128) * 128
        elif self.sequence_parallel_size > 1:
            seq_len = max(seq_lens)
            sp_seq_len = math.ceil(seq_len / self.sequence_parallel_size)
            from swift.trainers.sequence_parallel import sequence_parallel
            for k, v in res.items():
                sp_rank = dist.get_rank(sequence_parallel.sp_group)
                new_v = torch.split(v, sp_seq_len, dim=1)
                res[k] = new_v[sp_rank].contiguous()

            padding_to = sp_seq_len + 1

        for key, pad_value in zip(keys, pad_value):
        for key, pad_value in zip(keys, pad_values):
            if key not in res:
                continue
            if padding_to is not None:
@@ -1365,9 +1375,6 @@ def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[in

        # multimodal
        res.update(self._data_collator_mm_data(batch))
        if use_torchacc() or self.sequence_parallel_size > 1:
            res = self._torchacc_xtuner_data_collator(res, padding_to, self.tokenizer, padding_side)

        return res

    def _data_collator_mm_data(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:
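To make the new `elif self.sequence_parallel_size > 1` branch above concrete, here is a minimal, self-contained sketch of the per-rank chunking; the sequence-parallel size and rank are hard-coded stand-ins for `self.sequence_parallel_size` and `dist.get_rank(sequence_parallel.sp_group)`:

```python
import math

import torch

sequence_parallel_size = 2   # assumed SP world size
sp_rank = 1                  # stand-in for dist.get_rank(sp_group)

# A toy collated batch: [batch=2, seq=10]
res = {'input_ids': torch.arange(20).reshape(2, 10)}

seq_len = res['input_ids'].shape[1]
sp_seq_len = math.ceil(seq_len / sequence_parallel_size)   # tokens held by each SP rank

for k, v in res.items():
    chunks = torch.split(v, sp_seq_len, dim=1)   # split along the sequence dimension
    res[k] = chunks[sp_rank].contiguous()        # keep only this rank's slice

print(res['input_ids'])   # columns 5..9 of each row in the original batch
```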
68 changes: 67 additions & 1 deletion swift/trainers/mixin.py
@@ -12,13 +12,15 @@

import safetensors
import torch
import torch.distributed as dist
import torch.nn as nn
import transformers
from datasets import Dataset as HfDataset
from modelscope import check_local_model_is_latest
from packaging import version
from peft import PeftModel
from torch.nn import Module
from torch.utils.data import DataLoader
from transformers import PreTrainedModel
from transformers.data.data_collator import DataCollator
from transformers.integrations import is_deepspeed_zero3_enabled
@@ -28,7 +30,7 @@
from transformers.utils import is_torch_npu_available

from swift.hub import get_hub
from swift.llm import Template
from swift.llm import BatchSamplerShard, DataLoaderDispatcher, DataLoaderShard, Template
from swift.plugin import MeanMetric, compute_acc, extra_tuners
from swift.tuners import SwiftModel
from swift.utils import get_logger, is_mp_ddp, use_torchacc
@@ -100,6 +102,9 @@ def __init__(self,
        self.can_return_loss = can_return_loss(model)
        self.label_names = self.label_names or ['labels']
        self.start_time = time.time()
        if self.template.sequence_parallel_size > 1:
            from swift.trainers.sequence_parallel import sequence_parallel
            sequence_parallel.prepare_trainer(self)

    def _save_initial_model(self, output_dir):
        # pissa/olora/lora-ga
@@ -431,3 +436,64 @@ def _evalscope_eval(self):

        self.model.train()
        return eval_dict

    def get_batch_samples(self, *args, **kwargs):
        res = super().get_batch_samples(*args, **kwargs)
        if self.template.sequence_parallel_size == 1:
            return res
        batch_samples, num_items_in_batch = res
        dist.all_reduce(num_items_in_batch, dist.ReduceOp.SUM)
        return batch_samples, num_items_in_batch
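With sequence parallelism each rank's collator output holds only its slice of the tokens, so the per-batch item count has to be summed across ranks before it is used to normalize the loss; that is what the `dist.all_reduce` above does. A single-process sketch of the collective call (the process-group setup is illustrative only):

```python
import torch
import torch.distributed as dist

# World size 1 purely for illustration; in real training the launcher creates the group.
dist.init_process_group('gloo', init_method='tcp://127.0.0.1:29500', rank=0, world_size=1)

num_items_in_batch = torch.tensor(128)                    # label tokens seen by this rank
dist.all_reduce(num_items_in_batch, dist.ReduceOp.SUM)    # global count across all ranks
print(num_items_in_batch.item())                          # 128, since world_size == 1

dist.destroy_process_group()
```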


class DataLoaderMixin:

    def get_train_dataloader(self):
        dataloader = None
        if self.template.sequence_parallel_size > 1:
            from swift.trainers.sequence_parallel import sequence_parallel
            dataloader = sequence_parallel.get_dataloader(self, self.train_dataset, self._train_batch_size)
        if dataloader is None:
            # Higher efficiency
            if self.train_dataset is None:
                raise ValueError('Trainer: training requires a train_dataset.')
            args = self.args
            train_dataset = self.train_dataset

            dataloader_params = {
                'collate_fn': self.data_collator,
                'num_workers': args.dataloader_num_workers,
                'pin_memory': args.dataloader_pin_memory,
                'persistent_workers': args.dataloader_persistent_workers,
                'prefetch_factor': args.dataloader_prefetch_factor
            }
            batch_sampler_params = {
                'drop_last': args.dataloader_drop_last,
                'shuffle': args.train_dataloader_shuffle,
                'data_seed': args.data_seed,
            }

            if hasattr(train_dataset, '__len__'):
                batch_sampler = BatchSamplerShard(
                    len(train_dataset), batch_size=self._train_batch_size, **batch_sampler_params)
                dataloader = DataLoaderShard(train_dataset, batch_sampler, **dataloader_params)
            else:
                # IterableDataset
                if dist.is_initialized():
                    dataloader_params['prefetch_factor'] = dataloader_params['prefetch_factor'] * dist.get_world_size()
                dataloader = DataLoader(train_dataset, batch_size=self._train_batch_size, **dataloader_params)
                dataloader = DataLoaderDispatcher(dataloader)

        return dataloader

    def get_eval_dataloader(self, eval_dataset=None):
        dataloader = None
        if self.template.sequence_parallel_size > 1:
            from swift.trainers.sequence_parallel import sequence_parallel
            if eval_dataset is None and self.eval_dataset is None:
                raise ValueError('Trainer: evaluation requires an eval_dataset.')
            eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
            dataloader = sequence_parallel.get_dataloader(self, eval_dataset, self.args.eval_batch_size)
        if dataloader is None:
            return super().get_eval_dataloader(eval_dataset=eval_dataset)
        return dataloader
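The intent of the new mixin (mirrored by the `Seq2SeqTrainer` and `DPOTrainer` changes below) is that trainers opt into the shared dataloader logic purely through inheritance order. A sketch, with `MyTrainer` as a hypothetical name and the import path assumed from the file location:

```python
from transformers import Trainer as HfTrainer

from swift.trainers.mixin import DataLoaderMixin, SwiftMixin


class MyTrainer(SwiftMixin, DataLoaderMixin, HfTrainer):
    # DataLoaderMixin precedes HfTrainer in the MRO, so its get_train_dataloader /
    # get_eval_dataloader override the Hugging Face defaults and route through
    # sequence_parallel.get_dataloader whenever template.sequence_parallel_size > 1.
    pass
```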
28 changes: 2 additions & 26 deletions swift/trainers/rlhf_trainer/dpo_trainer.py
@@ -8,13 +8,13 @@
from transformers import PreTrainedModel
from trl import DPOTrainer as HFDPOTrainer

from ..mixin import SwiftMixin
from ..mixin import DataLoaderMixin, SwiftMixin
from .rlhf_mixin import RLHFTrainerMixin

del HFDPOTrainer.__init__


class DPOTrainer(RLHFTrainerMixin, SwiftMixin, HFDPOTrainer):
class DPOTrainer(RLHFTrainerMixin, SwiftMixin, DataLoaderMixin, HFDPOTrainer):

    def __init__(self,
                 model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
@@ -35,30 +35,6 @@ def __init__(self,
        self.use_weighting = False

        super().__init__(model, ref_model, *_args, **kwargs)
        if self.template.sequence_parallel_size > 1:
            from swift.trainers.sequence_parallel import sequence_parallel
            sequence_parallel.prepare_trainer(self)

    def get_train_dataloader(self):
        dataloader = None
        if self.template.sequence_parallel_size > 1:
            from swift.trainers.sequence_parallel import sequence_parallel
            dataloader = sequence_parallel.get_dataloader(self, self.train_dataset, self._train_batch_size)
        if dataloader is None:
            return super().get_train_dataloader()
        return dataloader

    def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None):
        if eval_dataset is None and self.eval_dataset is None:
            raise ValueError('Trainer: evaluation requires an eval_dataset.')
        eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
        dataloader = None
        if self.template.sequence_parallel_size > 1:
            from swift.trainers.sequence_parallel import sequence_parallel
            dataloader = sequence_parallel.get_dataloader(self, eval_dataset, self.args.eval_batch_size)
        if dataloader is None:
            return super().get_eval_dataloader(eval_dataset=eval_dataset)
        return dataloader

    def get_nll_loss(self, logits, labels):
        if not self.is_encoder_decoder:
34 changes: 0 additions & 34 deletions swift/trainers/sequence_parallel/ulysses.py
@@ -510,7 +510,6 @@ def prepare_trainer(self, trainer):
        if trainer.train_dataset is None:
            raise ValueError('Trainer: training requires a train_dataset.')

        trainer.compute_loss_func = partial(loss_scale_sp_func, process_group=self.sp_group)
        if hasattr(trainer, 'get_batch_logps'):
            trainer.get_batch_logps = partial(get_batch_logps, process_group=self.sp_group)
        if hasattr(trainer, 'get_nll_loss'):
@@ -519,36 +518,3 @@ def rlhf_loss_scale_sp_func(_, *args, **kwargs):
                return loss_scale_sp_func(*args, process_group=self.sp_group, **kwargs)

            trainer.get_nll_loss = MethodType(rlhf_loss_scale_sp_func, trainer)

        def _compute_acc(trainer, outputs, labels) -> None:
            args = trainer.args
            acc_steps = args.acc_steps
            preds = outputs.logits.argmax(dim=-1)
            if trainer.state.global_step % acc_steps == 0:
                # Gather preds and labels across the sp group
                shape0 = preds.shape[0]
                preds_output = torch.empty((shape0 * self.sp_world_size, preds.shape[1]),
                                           dtype=preds.dtype,
                                           device=preds.device)
                dist.all_gather_into_tensor(preds_output, preds, group=self.sp_group)
                preds_output = torch.cat(preds_output.split(shape0, dim=0), dim=1)
                shape0 = labels.shape[0]
                labels_output = torch.empty((shape0 * self.sp_world_size, labels.shape[1]),
                                            dtype=labels.dtype,
                                            device=labels.device)
                dist.all_gather_into_tensor(labels_output, labels, group=self.sp_group)
                labels_output = torch.cat(labels_output.split(shape0, dim=0), dim=1)
                # roll back to fit compute_acc
                labels_output = torch.roll(labels_output, shifts=1, dims=1)
                from swift.plugin import MeanMetric, compute_acc
                metrics = compute_acc(
                    preds_output,
                    labels_output,
                    acc_strategy=args.acc_strategy,
                    is_encoder_decoder=trainer.template.is_encoder_decoder)
                for k, v in metrics.items():
                    if k not in trainer._custom_metrics:
                        trainer._custom_metrics[k] = MeanMetric(nan_value=None)
                    trainer._custom_metrics[k].update(v)

        trainer._compute_acc = MethodType(_compute_acc, trainer)
57 changes: 2 additions & 55 deletions swift/trainers/trainers.py
@@ -6,21 +6,18 @@
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.distributed as dist
from peft import PeftModel
from torch import nn
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
from transformers import EvalPrediction
from transformers import Seq2SeqTrainer as HfSeq2SeqTrainer
from transformers import Trainer as HfTrainer
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
from transformers.utils import is_peft_available

from swift.llm import BatchSamplerShard, DataLoaderDispatcher, DataLoaderShard
from swift.utils import JsonlWriter, Serializer, gc_collect
from .arguments import Seq2SeqTrainingArguments, TrainingArguments
from .mixin import SwiftMixin
from .mixin import DataLoaderMixin, SwiftMixin


class Trainer(SwiftMixin, HfTrainer):
@@ -80,7 +77,7 @@ def calculate_metric(self, eval_prediction: EvalPrediction) -> Dict[str, float]:
        return calculate_paired_metrics(eval_prediction.predictions, eval_prediction.label_ids)


class Seq2SeqTrainer(SwiftMixin, HfSeq2SeqTrainer):
class Seq2SeqTrainer(SwiftMixin, DataLoaderMixin, HfSeq2SeqTrainer):
    args: Seq2SeqTrainingArguments

    def __init__(self, *args, **kwargs):
@@ -91,9 +88,6 @@ def __init__(self, *args, **kwargs):
            self.infer_engine = PtEngine.from_model_template(
                self.model, self.template, max_batch_size=self.args.per_device_eval_batch_size)
        self.jsonl_writer = JsonlWriter(os.path.join(self.args.output_dir, 'predict.jsonl'))
        if self.template.sequence_parallel_size > 1:
            from swift.trainers.sequence_parallel import sequence_parallel
            sequence_parallel.prepare_trainer(self)

    @staticmethod
    def _predict_data_collator(batch):
@@ -117,53 +111,6 @@ def _patch_predict_with_generate(self):
            self.data_collator = origin_data_collator
            self.template.set_mode(origin_mode)

    def get_eval_dataloader(self, eval_dataset=None):
        dataloader = None
        if self.template.sequence_parallel_size > 1:
            from swift.trainers.sequence_parallel import sequence_parallel
            dataloader = sequence_parallel.get_dataloader(self, eval_dataset, self.args.eval_batch_size)
        if dataloader is None:
            return super().get_eval_dataloader(eval_dataset)
        return dataloader

    def get_train_dataloader(self):
        dataloader = None
        if self.template.sequence_parallel_size > 1:
            from swift.trainers.sequence_parallel import sequence_parallel
            dataloader = sequence_parallel.get_dataloader(self, self.train_dataset, self._train_batch_size)
        if dataloader is None:
            # Higher efficiency
            if self.train_dataset is None:
                raise ValueError('Trainer: training requires a train_dataset.')
            args = self.args
            train_dataset = self.train_dataset

            dataloader_params = {
                'collate_fn': self.data_collator,
                'num_workers': args.dataloader_num_workers,
                'pin_memory': args.dataloader_pin_memory,
                'persistent_workers': args.dataloader_persistent_workers,
                'prefetch_factor': args.dataloader_prefetch_factor
            }
            batch_sampler_params = {
                'drop_last': args.dataloader_drop_last,
                'shuffle': args.train_dataloader_shuffle,
                'data_seed': args.data_seed,
            }

            if hasattr(train_dataset, '__len__'):
                batch_sampler = BatchSamplerShard(
                    len(train_dataset), batch_size=self._train_batch_size, **batch_sampler_params)
                dataloader = DataLoaderShard(train_dataset, batch_sampler, **dataloader_params)
            else:
                # IterableDataset
                if dist.is_initialized():
                    dataloader_params['prefetch_factor'] = dataloader_params['prefetch_factor'] * dist.get_world_size()
                dataloader = DataLoader(train_dataset, batch_size=self._train_batch_size, **dataloader_params)
                dataloader = DataLoaderDispatcher(dataloader)

        return dataloader

    def evaluate(self, *args, **kwargs):
        context = self._patch_predict_with_generate() if self.args.predict_with_generate else nullcontext()
        with context: