Skip to content

[megatron] Support Qwen3 #3995

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Apr 26, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
</p>

<p align="center">
<a href="https://arxiv.org/abs/2408.05517">Paper</a> &nbsp | <a href="https://swift.readthedocs.io/en/latest/">Swift3.x En Doc</a> &nbsp | &nbsp <a href="https://swift.readthedocs.io/zh-cn/latest/">Swift3.x中文文档</a> &nbsp
<a href="https://arxiv.org/abs/2408.05517">Paper</a> &nbsp; | <a href="https://swift.readthedocs.io/en/latest/">English Documentation</a> &nbsp; | &nbsp; <a href="https://swift.readthedocs.io/zh-cn/latest/">中文文档</a> &nbsp;
</p>

## 📖 Table of Contents
Expand Down
2 changes: 1 addition & 1 deletion README_CN.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
</p>

<p align="center">
<a href="https://arxiv.org/abs/2408.05517">论文</a> &nbsp | <a href="https://swift.readthedocs.io/en/latest/">Swift3.x En Doc</a> &nbsp | &nbsp <a href="https://swift.readthedocs.io/zh-cn/latest/">Swift3.x中文文档</a> &nbsp
<a href="https://arxiv.org/abs/2408.05517">论文</a> &nbsp; | <a href="https://swift.readthedocs.io/en/latest/">English Documentation</a> &nbsp; | &nbsp; <a href="https://swift.readthedocs.io/zh-cn/latest/">中文文档</a> &nbsp;
</p>

## 📖 目录
Expand Down
18 changes: 8 additions & 10 deletions swift/llm/train/sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from datasets import Dataset as HfDataset

from swift.plugin import extra_callbacks, get_loss_func, get_metric
from swift.trainers import IntervalStrategy, TrainerFactory
from swift.trainers import TrainerFactory
from swift.utils import (append_to_jsonl, get_logger, get_model_parameter_info, is_master, plot_images, stat_array,
use_torchacc)
from ..argument import TrainArguments
Expand Down Expand Up @@ -108,7 +108,7 @@ def _get_data_collator(self):

@staticmethod
def _save_val_dataset(output_dir: str, val_dataset):
if isinstance(val_dataset, HfDataset):
if is_master() and isinstance(val_dataset, HfDataset):
os.makedirs(output_dir, exist_ok=True)
val_dataset_path = os.path.join(output_dir, 'val_dataset.jsonl')
append_to_jsonl(val_dataset_path, val_dataset.to_list())
Expand All @@ -118,8 +118,11 @@ def run(self):
args = self.args

train_dataset, val_dataset = self._get_dataset()
self._save_val_dataset(args.output_dir, val_dataset)
train_dataset, val_dataset = self._encode_dataset(train_dataset, val_dataset)

if args.task_type == 'seq_cls':
args.problem_type = args.problem_type or getattr(self.model.config, 'problem_type', None)
logger.info(f'args.problem_type: {args.problem_type}')
args.save_args()

data_collator = self._get_data_collator()
Expand Down Expand Up @@ -239,6 +242,8 @@ def _stat_dataset(self, dataset: HfDataset):
def _encode_dataset(self, train_dataset, val_dataset):
template = self.template
args = self.args
output_dir = getattr(args, 'output_dir', None) or getattr(args, 'save')
self._save_val_dataset(output_dir, val_dataset)
is_grpo = hasattr(args, 'rlhf_type') and args.rlhf_type == 'grpo'
predict_with_generate = getattr(args, 'predict_with_generate', False)
if not is_grpo:
Expand Down Expand Up @@ -269,13 +274,6 @@ def _encode_dataset(self, train_dataset, val_dataset):
if val_dataset is not None and not predict_with_generate:
self.train_msg['val_dataset'] = self._stat_dataset(val_dataset)

if val_dataset is None and hasattr(args, 'training_args'):
args.training_args.evaluation_strategy = IntervalStrategy.NO
args.training_args.eval_strategy = IntervalStrategy.NO

if args.task_type == 'seq_cls':
args.problem_type = args.problem_type or getattr(self.model.config, 'problem_type', None)
logger.info(f'args.problem_type: {args.problem_type}')
return train_dataset, val_dataset


Expand Down
5 changes: 5 additions & 0 deletions swift/megatron/argument/megatron_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ class ExtraMegatronArguments:
padded_vocab_size: Optional[int] = None
rope_scaling: Optional[Union[dict, str]] = None
torch_dtype: Optional[torch.dtype] = None
model_type: Optional[str] = None


@dataclass
Expand Down Expand Up @@ -102,6 +103,8 @@ class MegatronArguments(ExtraMegatronArguments):
add_qkv_bias: bool = True
attention_dropout: float = 0.
hidden_dropout: float = 0.
kv_channels: Optional[int] = None
qk_layernorm: bool = False
transformer_impl: Literal['local', 'transformer_engine'] = 'transformer_engine'

# mixed precision
Expand Down Expand Up @@ -136,6 +139,8 @@ def _init_mixed_precision(self):
ModelArguments._init_mixed_precision(self)
if self.apply_query_key_layer_scaling is None:
self.apply_query_key_layer_scaling = self.fp16
if self.apply_query_key_layer_scaling:
os.environ['NVTE_APPLY_QK_LAYER_SCALING'] = '1'

def __post_init__(self):
from swift.llm.argument.base_args.model_args import ModelArguments
Expand Down
4 changes: 3 additions & 1 deletion swift/megatron/model/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
'untie_embeddings_and_output_weights': ['tie_word_embeddings'],
'swiglu': ['hidden_act'],
'add_qkv_bias': ['attention_bias'],
'disable_bias_linear': ['mlp_bias']
'disable_bias_linear': ['mlp_bias'],
'kv_channels': ['head_dim'],
'model_type': ['model_type'],
}


Expand Down
5 changes: 3 additions & 2 deletions swift/megatron/model/gpt/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from swift.llm import ModelType
from ..config import convert_hf_config
from ..constant import MegatronModelType
from ..register import MegatronModelMeta, register_megatron_model
from .config import convert_gpt_hf_config
from .hf2mcore import convert_hf2mcore
from .mcore2hf import convert_mcore2hf
from .model import model_provider
Expand Down Expand Up @@ -34,4 +34,5 @@
ModelType.numina,
ModelType.ziya,
ModelType.mengzi3,
], model_provider, convert_hf_config, convert_mcore2hf, convert_hf2mcore))
ModelType.qwen3,
], model_provider, convert_gpt_hf_config, convert_mcore2hf, convert_hf2mcore))
10 changes: 10 additions & 0 deletions swift/megatron/model/gpt/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from typing import Any, Dict

from ..config import convert_hf_config


def convert_gpt_hf_config(config) -> Dict[str, Any]:
    """Translate a HF GPT-family config into Megatron argument overrides.

    Delegates to the generic ``convert_hf_config`` mapping, then applies
    GPT-specific adjustments: Qwen3 models apply RMSNorm to the query/key
    projections, so ``qk_layernorm`` is switched on for them.

    Args:
        config: The Hugging Face model config object to convert.

    Returns:
        A dict of Megatron argument names to values.
    """
    converted = convert_hf_config(config)
    # Qwen3 is the only GPT variant here that needs per-head QK layernorm.
    if converted.get('model_type') == 'qwen3':
        converted['qk_layernorm'] = True
    return converted
3 changes: 3 additions & 0 deletions swift/megatron/model/gpt/hf2mcore.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ def set_attn_state(args, mg_layer, hf_layer):
hf_attn.v_proj.bias.reshape((num_query_groups, -1)),
],
dim=1).reshape(-1))
if args.qk_layernorm:
mg_attn.q_layernorm.weight.data.copy_(hf_attn.q_norm.weight)
mg_attn.k_layernorm.weight.data.copy_(hf_attn.k_norm.weight)


def set_mlp_state(args, mg_layer, hf_layer):
Expand Down
4 changes: 4 additions & 0 deletions swift/megatron/model/gpt/mcore2hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ def set_attn_state(args, mg_layer, hf_layer):
hf_attn.k_proj.bias.data.copy_(mg_attn_bias[:, q_dim:-kv_dim].reshape(-1))
hf_attn.v_proj.bias.data.copy_(mg_attn_bias[:, -kv_dim:].reshape(-1))

if args.qk_layernorm:
hf_attn.q_norm.weight.data.copy_(mg_attn.q_layernorm.weight)
hf_attn.k_norm.weight.data.copy_(mg_attn.k_layernorm.weight)


def set_mlp_state(args, mg_layer, hf_layer):
hf_layer.mlp.gate_proj.weight.data.copy_(mg_layer.mlp.linear_fc1.weight[:args.ffn_hidden_size])
Expand Down
6 changes: 5 additions & 1 deletion swift/trainers/mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.modeling_utils import unwrap_model
from transformers.trainer import TrainerCallback
from transformers.trainer_utils import EvalPrediction
from transformers.trainer_utils import EvalPrediction, IntervalStrategy
from transformers.utils import is_torch_npu_available

from swift.hub import get_hub
Expand Down Expand Up @@ -68,6 +68,10 @@ def __init__(self,
'invoked_by': 'local_trainer',
'third_party': 'swift',
})
if eval_dataset is None and args:
args.evaluation_strategy = IntervalStrategy.NO
args.eval_strategy = IntervalStrategy.NO

self._custom_metrics = {}
self.template = template
self.max_memory = 0
Expand Down
7 changes: 6 additions & 1 deletion tests/megatron/test_align/test_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ def test_llama3_2():
_test_model('LLM-Research/Llama-3.2-1B-Instruct')


def test_qwen3():
    # NOTE(review): presumably round-trips Qwen3 weights through the shared
    # HF<->Megatron alignment harness (_test_model) — confirm its contract.
    _test_model('Qwen/Qwen3-0.6B-Base')


if __name__ == '__main__':
# test_llama2()
# test_llama3()
Expand All @@ -54,4 +58,5 @@ def test_llama3_2():
# test_yi()
# test_megrez()
# test_llama3_1()
test_llama3_2()
# test_llama3_2()
test_qwen3()