refactor mm target_regex (compat peft/vllm) #3879

Merged · 9 commits · Apr 15, 2025
7 changes: 4 additions & 3 deletions examples/notebook/qwen2vl-ocr/ocr-sft.ipynb
@@ -117,10 +117,11 @@
"logger.info(f'model_info: {model.model_info}')\n",
"template = get_template(model.model_meta.template, processor, default_system=system, max_length=max_length)\n",
"template.set_mode('train')\n",
"if template.use_model:\n",
" template.model = model\n",
"\n",
"# Get target_modules and add trainable LoRA modules to the model.\n",
"model_arch = get_model_arch(model.model_meta.model_arch)\n",
"target_modules = get_multimodal_target_regex(model_arch, freeze_llm=freeze_llm, freeze_vit=freeze_vit, \n",
"target_modules = get_multimodal_target_regex(model, freeze_llm=freeze_llm, freeze_vit=freeze_vit, \n",
" freeze_aligner=freeze_aligner)\n",
"lora_config = LoraConfig(task_type='CAUSAL_LM', r=lora_rank, lora_alpha=lora_alpha,\n",
" target_modules=target_modules)\n",
@@ -217,7 +218,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.15"
"version": "3.11.11"
}
},
"nbformat": 4,
28 changes: 28 additions & 0 deletions examples/train/full/qwen2_5_32b.sh
@@ -0,0 +1,28 @@
# 8 * 80GiB
NPROC_PER_NODE=8 \
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
swift sft \
--model Qwen/Qwen2.5-32B \
--train_type full \
--dataset 'liucong/Chinese-DeepSeek-R1-Distill-data-110k-SFT' \
--torch_dtype bfloat16 \
--max_steps 2000 \
--streaming true \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--learning_rate 1e-5 \
--gradient_accumulation_steps 2 \
--packing true \
--eval_steps 200 \
--save_steps 200 \
--logging_steps 5 \
--max_length 8192 \
--warmup_ratio 0.05 \
--dataloader_num_workers 8 \
--dataset_num_proc 8 \
--save_total_limit 2 \
--save_only_model true \
--output_dir output/Qwen2.5-32B \
--deepspeed zero3 \
--use_liger_kernel true \
--attn_impl flash_attn
28 changes: 28 additions & 0 deletions examples/train/moe/llama4.sh
@@ -0,0 +1,28 @@
# Manually select `target_modules` to avoid 'all-linear' selecting 'router'
NPROC_PER_NODE=4 \
USE_HF=1 \
CUDA_VISIBLE_DEVICES=0,1,2,3 \
swift sft \
--model meta-llama/Llama-4-Scout-17B-16E-Instruct \
--dataset 'linxy/LaTeX_OCR:full#5000' \
--train_type lora \
--torch_dtype bfloat16 \
--num_train_epochs 1 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--learning_rate 1e-4 \
--lora_rank 8 \
--lora_alpha 32 \
--target_regex '^(language_model)\..*\.(q_proj|k_proj|v_proj|o_proj|gate_proj|up_proj|down_proj)$' \
--freeze_vit true \
--gradient_accumulation_steps 4 \
--gradient_checkpointing true \
--eval_steps 50 \
--save_steps 50 \
--save_total_limit 2 \
--logging_steps 5 \
--max_length 2048 \
--output_dir output \
--warmup_ratio 0.05 \
--deepspeed zero3 \
--dataloader_num_workers 4
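
The `--target_regex` in the llama4.sh recipe above is what its leading comment refers to: it limits LoRA to the language model's attention and MLP projections so the MoE router never receives an adapter. A minimal sketch (not part of this PR; the module names are illustrative and may differ from the real Llama-4 checkpoint) of what that pattern does and does not match:

```python
import re

# Hypothetical module names, used only to illustrate the llama4.sh --target_regex above.
pattern = re.compile(
    r'^(language_model)\..*\.(q_proj|k_proj|v_proj|o_proj|gate_proj|up_proj|down_proj)$')

candidates = [
    'language_model.model.layers.0.self_attn.q_proj',     # matched -> gets a LoRA adapter
    'language_model.model.layers.0.feed_forward.router',  # skipped -> MoE router stays frozen
    'vision_model.patch_embedding.linear',                # skipped -> vision tower untouched
]
for name in candidates:
    print(name, bool(pattern.match(name)))
```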
28 changes: 28 additions & 0 deletions examples/train/moe/qwen2_5_moe.sh
@@ -0,0 +1,28 @@
# Manually select `target_modules` to avoid 'all-linear' selecting 'gate'
CUDA_VISIBLE_DEVICES=0,1 \
swift sft \
--model Qwen/Qwen2-57B-A14B-Instruct \
--train_type lora \
--dataset 'AI-ModelScope/alpaca-gpt4-data-zh#500' \
'AI-ModelScope/alpaca-gpt4-data-en#500' \
'swift/self-cognition#500' \
--torch_dtype bfloat16 \
--num_train_epochs 1 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--learning_rate 1e-4 \
--lora_rank 8 \
--lora_alpha 32 \
--target_modules q_proj k_proj v_proj o_proj gate_proj up_proj down_proj \
--gradient_accumulation_steps 16 \
--eval_steps 50 \
--save_steps 50 \
--save_total_limit 2 \
--logging_steps 5 \
--max_length 2048 \
--output_dir output \
--system 'You are a helpful assistant.' \
--warmup_ratio 0.05 \
--dataloader_num_workers 4 \
--model_author swift \
--model_name swift-robot
38 changes: 17 additions & 21 deletions swift/llm/train/tuner.py
@@ -4,6 +4,7 @@
from typing import List, Union

import torch
import torch.nn as nn
import transformers
from packaging import version
from transformers import TrainingArguments
@@ -45,40 +46,36 @@ def apply_liger(model_type: str):


def get_multimodal_target_regex(
model_arch,
model,
*,
freeze_llm: bool = False,
freeze_vit: bool = True,
freeze_aligner: bool = True,
ignore_embedding: bool = True,
include_embedding: bool = True,
) -> str:
model_arch = get_model_arch(model.model_meta.model_arch)
modules = []
rejected_modules = []
if not freeze_llm:
modules += model_arch.language_model
if not freeze_vit:
modules += model_arch.vision_tower
if freeze_aligner:
rejected_modules += model_arch.aligner
else:
if not freeze_aligner:
modules += model_arch.aligner
elif not freeze_vit:
rejected_modules += model_arch.aligner

assert len(modules) > 0, f'modules: {modules}'
prefix_pattern = '|'.join(modules)
rejected_pattern = '|'.join(rejected_modules)

ignore_pattern = ['lora_A', 'lora_B', 'base_layer']
if ignore_embedding:
ignore_pattern += [r'\w*emb\w*', 'wte', 'shared']
ignore_pattern += model_arch.embedding or []
# lm_head
ignore_pattern += ['lm_head', 'output', 'score', 'v_head', 'classifier']
ignore_pattern += model_arch.lm_head or []
ignore_pattern = '|'.join(ignore_pattern)

target_regex = rf'^({prefix_pattern})'
if ignore_pattern:
target_regex += rf'(?!.*\b({ignore_pattern})\b).*'
extra_layers = []
if include_embedding:
extra_layers.append(nn.Embedding)
target_modules = []
for module in modules:
target_modules += find_all_linears(model, model_arch, extra_layers, sub_module=module)
target_regex = rf'^({prefix_pattern})\..*\.({"|".join(target_modules)})$'
if rejected_pattern:
target_regex = rf'(?!^({rejected_pattern}))' + target_regex
return target_regex
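
For reference, the regex returned by the refactored helper now has the shape `^(prefixes)\..*\.(linear names)$`, optionally guarded by a negative lookahead for rejected aligner prefixes, instead of the old embedded ignore list. A rough, illustrative sketch of that shape (the prefixes and linear names here are hypothetical; the real ones come from the model's architecture metadata and `find_all_linears`):

```python
import re

# Illustrative only: roughly what get_multimodal_target_regex could return for a
# Qwen2-VL-style model with freeze_llm=False, freeze_vit=False, freeze_aligner=True.
target_regex = (r'(?!^(visual\.merger))'                 # rejected aligner prefix
                r'^(model|visual)\..*\.'                  # allowed LLM / ViT prefixes
                r'(q_proj|k_proj|v_proj|o_proj|gate_proj|up_proj|down_proj|qkv|fc1|fc2)$')

print(bool(re.match(target_regex, 'model.layers.0.self_attn.q_proj')))  # True
print(bool(re.match(target_regex, 'visual.blocks.0.attn.qkv')))         # True
print(bool(re.match(target_regex, 'visual.merger.mlp.fc1')))            # False: aligner rejected
print(bool(re.match(target_regex, 'lm_head')))                          # False: never collected
```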
@@ -91,14 +88,13 @@ def get_target_modules(args, model) -> Union[str, List[str]]:
return args.target_modules
target_modules = args.target_modules.copy()
if 'all-linear' in target_modules:
model_arch = get_model_arch(args.model_meta.model_arch)
if model_meta.is_multimodal and model_arch:
if model_meta.is_multimodal:
return get_multimodal_target_regex(
model_arch,
model,
freeze_llm=args.freeze_llm,
freeze_vit=args.freeze_vit,
freeze_aligner=args.freeze_aligner,
ignore_embedding='all-embedding' not in target_modules)
include_embedding='all-embedding' in target_modules)
else:
target_modules.remove('all-linear')
target_modules += find_all_linears(model)
11 changes: 0 additions & 11 deletions swift/tuners/peft.py
@@ -86,17 +86,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, subfolder: Optional


def _create_and_replace_hook(self, peft_config, adapter_name, target, *args, **kwargs):
all_supported_names = ('linear', )
all_supported_types = (torch.nn.Embedding, torch.nn.Conv2d, transformers.pytorch_utils.Conv1D, lora.Linear)
target_modules = getattr(peft_config, 'target_modules', None)
if target is None:
return

if isinstance(target_modules, str) and not any(
[name in target.__class__.__name__.lower()
for name in all_supported_names]) and not any([isinstance(target, type_) for type_ in all_supported_types]):
return

if target.__class__.__name__ == 'NonDynamicallyQuantizableLinear':
return

70 changes: 31 additions & 39 deletions swift/utils/torch_utils.py
@@ -21,6 +21,7 @@

from .env import get_dist_setting, is_dist, is_dist_ta, is_master
from .logger import get_logger
from .utils import deep_getattr

logger = get_logger()

@@ -153,15 +154,23 @@ def _sync_max_memory(max_memory: Dict[Union[int, str], int]) -> Dict[Union[int,
return new_max_memory


def find_layers(model: nn.Module, cond: Callable[[str, nn.Module], bool]) -> List[str]:
def find_layers(model: nn.Module,
cond: Callable[[str, nn.Module], bool],
sub_module: Optional[str] = None) -> List[str]:
# The content of target_module_names cannot exist in inner_nodes.
sub_module_str = sub_module
if sub_module is None:
sub_module = model
else:
sub_module = deep_getattr(model, sub_module)
inner_nodes = set()
for name, module in model.named_modules():
name = re.sub(r'\d+\.', '{}.', name)
if not cond(name, module):
inner_nodes.add(name)
target_module_names = set()
for name, module in model.named_modules():
for name, module in sub_module.named_modules():
name = f'{sub_module_str}.{name}'
if cond(name, module):
module_name_list = name.split('.')
module_name = module_name_list.pop()
@@ -183,51 +192,34 @@ def find_embedding(model: nn.Module) -> List[str]:
return find_layers(model, lambda name, module: isinstance(module, torch.nn.Embedding))


def find_all_linears(model: nn.Module) -> List[str]:
from swift.llm import get_model_arch
model_info = model.model_info
model_arch = get_model_arch(model.model_meta.model_arch)
def find_all_linears(model, model_arch=None, extra_layers=None, sub_module=None):
if model_arch is None:
from swift.llm import get_model_arch
model_arch = get_model_arch(model.model_meta.model_arch)
# lm_head
if model_arch and model_arch.lm_head:
output = model_arch.lm_head
idx = output.rfind('.')
lm_head_name = output[idx + 1:]
else:
lm_head_name = 'lm_head'

quant_method = model_info.quant_method
quant_bits = model_info.quant_bits
if quant_method == 'bnb':
from bitsandbytes.nn import Linear4bit, Linear8bitLt
if quant_bits == 4:
linear_cls = [Linear4bit]
elif quant_bits == 8:
linear_cls = [Linear8bitLt]
elif quant_method == 'hqq':
from hqq.core.quantize import HQQLinear
linear_cls = [HQQLinear]
elif quant_method == 'eetq':
from eetq import EetqLinear
linear_cls = [EetqLinear]
elif quant_method == 'gptq':
from peft.utils import get_auto_gptq_quant_linear, get_quantization_config
gptq_quantization_config = get_quantization_config(model, 'gptq')
AutoGPTQQuantLinear = get_auto_gptq_quant_linear(gptq_quantization_config)
linear_cls = [AutoGPTQQuantLinear]
elif quant_method == 'awq':
from awq.modules.linear import WQLinear_GEMM
linear_cls = [WQLinear_GEMM]
elif quant_method == 'aqlm':
from aqlm import QuantizedLinear
linear_cls = [QuantizedLinear]
else:
linear_cls = [nn.Linear]

# 'score', 'classifier': classification model
# 'v_head': reward model
ignore_layers = [lm_head_name, 'score', 'v_head', 'classifier']
return find_layers(
model, lambda name, module: isinstance(module, tuple(linear_cls)) and all(layer not in name
for layer in ignore_layers))
ignore_layers = [lm_head_name, 'score', 'v_head', 'classifier'] + ['lora_A', 'lora_B', 'base_layer']
ignore_linear_cls = [
'glulinear' # phi4-mm
]

def _cond(name, module):
module_name = module.__class__.__name__.lower()
if (extra_layers and isinstance(module, tuple(extra_layers)) or
('linear' in module_name and all(linear_cls not in module_name
for linear_cls in ignore_linear_cls))) and all(layer not in name
for layer in ignore_layers):
return True
return False

return find_layers(model, _cond, sub_module=sub_module)
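
The rewritten `find_all_linears` replaces the per-backend list of quantized linear classes with a name-based check (`'linear'` in the lowercased class name), an explicit ignore list, plus optional `extra_layers` and `sub_module` arguments so `get_multimodal_target_regex` can collect linear names per architecture prefix. A small self-contained sketch of the class-name heuristic (the stub classes below are stand-ins, not the real quantized layers):

```python
import torch.nn as nn

ignore_linear_cls = ['glulinear']  # e.g. phi4-mm's gated linear

class Linear4bitStub(nn.Linear):   # stand-in for bitsandbytes.nn.Linear4bit
    pass

class GLULinear(nn.Linear):        # stand-in for phi4-mm's GLULinear
    pass

def is_linear_candidate(module: nn.Module) -> bool:
    # Mirrors the _cond check above: match by class name instead of enumerating
    # every quantization backend's linear class.
    name = module.__class__.__name__.lower()
    return 'linear' in name and all(cls not in name for cls in ignore_linear_cls)

print(is_linear_candidate(nn.Linear(4, 4)))       # True
print(is_linear_candidate(Linear4bitStub(4, 4)))  # True  -> quantized linears need no special case
print(is_linear_candidate(GLULinear(4, 4)))       # False -> explicitly ignored
print(is_linear_candidate(nn.Embedding(4, 4)))    # False -> only via extra_layers/include_embedding
```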


@contextmanager