Skip to content

fix omni aligner #4117

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 7, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/train/long_text/zero3.sh
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Env: 4 * A100
# https://github.com/modelscope/ms-swift/blob/main/examples/train/megatron/long_text.sh
# Max Length: 16K
# GPU Memory: 4 * 56GB, Training Speed 10s/it
NPROC_PER_NODE=4 \
1 change: 1 addition & 0 deletions swift/llm/model/model_arch.py
Original file line number Diff line number Diff line change
@@ -471,6 +471,7 @@ def register_model_arch(model_arch: ModelKeys, *, exist_ok: bool = False) -> Non
MLLMModelArch.qwen2_5_omni,
language_model='thinker.model',
vision_tower=['thinker.audio_tower', 'thinker.visual'],
aligner=['thinker.audio_tower.proj', 'thinker.visual.merger'],
generator=['talker', 'token2wav'],
))

29 changes: 16 additions & 13 deletions swift/llm/train/tuner.py
Original file line number Diff line number Diff line change
@@ -9,7 +9,7 @@
from packaging import version
from transformers import TrainingArguments

from swift.llm import TrainArguments, get_model_arch
from swift.llm import TrainArguments, deep_getattr, get_model_arch
from swift.plugin import Tuner, extra_tuners
from swift.tuners import Swift
from swift.utils import (activate_parameters, find_all_linears, find_embedding, find_norm, freeze_parameters,
@@ -59,30 +59,33 @@ def get_multimodal_target_regex(
) -> str:
model_arch = get_model_arch(model.model_meta.model_arch)
modules = []
rejected_modules = []
if not freeze_llm:
modules += model_arch.language_model
if not freeze_vit:
modules += model_arch.vision_tower
if not freeze_aligner:
modules += model_arch.aligner
elif not freeze_vit:
rejected_modules += model_arch.aligner

assert len(modules) > 0, f'modules: {modules}'
prefix_pattern = '|'.join(modules)
rejected_pattern = '|'.join(rejected_modules)

extra_layers = []
if include_embedding:
extra_layers.append(nn.Embedding)
target_modules = []
res = []
for module in modules:
target_modules += find_all_linears(model, model_arch, extra_layers, sub_module=module)
target_regex = rf'^({prefix_pattern}).*\.({"|".join(target_modules)})$'
if rejected_pattern:
target_regex = rf'(?!^({rejected_pattern}))' + target_regex
return target_regex
rejected_modules = []
if not freeze_vit:
for aligner in model_arch.aligner:
if aligner.startswith(f'{module}.'):
rejected_modules.append(aligner)

sub_module = deep_getattr(model, module)
target_modules = find_all_linears(sub_module, model_arch, extra_layers)
target_modules = [tm for tm in target_modules if tm]
target_pattern = rf'.*\.({"|".join(target_modules)})' if target_modules else ''
rejected_pattern = rf'(?!({"|".join(rejected_modules)}))' if rejected_modules else ''
res.append(rf'{rejected_pattern}{module}{target_pattern}')

return rf'^({"|".join(res)})$'


def get_target_modules(args, model) -> Union[str, List[str]]:
2 changes: 1 addition & 1 deletion swift/utils/torch_utils.py
Original file line number Diff line number Diff line change
@@ -174,7 +174,7 @@ def find_layers(
target_module_names = set()
for name, module in sub_module.named_modules():
if sub_module_str:
name = f'{sub_module_str}.{name}'
name = f'{sub_module_str}.{name}' if name else sub_module_str
if cond(name, module):
module_name_list = name.split('.')
module_name = module_name_list.pop()
Loading