Fix qwen2.5-omni use_audio_in_video #3987

Merged: 8 commits, Apr 24, 2025
2 changes: 1 addition & 1 deletion requirements/framework.txt
@@ -28,7 +28,7 @@ sentencepiece
tensorboard
tiktoken
tqdm
transformers>=4.33,<4.52
transformers>=4.33,<4.53
transformers_stream_generator
trl>=0.13,<0.17
uvicorn
1 change: 1 addition & 0 deletions swift/llm/model/utils.py
@@ -39,6 +39,7 @@ def update_attn_impl(config: PretrainedConfig,
attn_impl_keys: Optional[List[str]] = None) -> None:
if attn_impl is None:
return
logger.info(f'attn_impl: {attn_impl}')
use_flash_attn = AttnImpl.to_use_flash_attn(attn_impl)
if use_flash_attn:
attn_impl = 'flash_attention_2'
2 changes: 1 addition & 1 deletion swift/llm/template/base.py
@@ -739,7 +739,7 @@ def _pre_tokenize_images(self, context_list: List[Context], loss_scale_list: Lis
if context == '<image>' and inputs.is_multimodal and inputs.image_idx < len(inputs.images):
c_list = self.replace_tag('image', inputs.image_idx, inputs)
inputs.image_idx += 1
loss_scale = 0.
loss_scale = 0. if self.template_backend == 'swift' else 1.
else:
c_list = [context]
res += c_list
2 changes: 1 addition & 1 deletion swift/llm/template/template/gemma.py
@@ -109,7 +109,7 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
input_ids = encoded['input_ids']
labels = encoded['labels']
idx_list = findall(input_ids, self.boi_token_id)
img_tokens = self.tokenizer.encode(self.processor.full_image_sequence)
img_tokens = self._tokenize(self.processor.full_image_sequence)
input_ids, labels = self._extend_tokens(input_ids, labels, idx_list, lambda _: img_tokens)

# TODO: customize
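
The gemma.py change above swaps `self.tokenizer.encode(...)` for `self._tokenize(...)` when expanding the image placeholder. A minimal sketch of the difference, assuming `_tokenize` encodes without special tokens and that the Gemma-3 tokenizer prepends a BOS id by default; the `LLM-Research/gemma-3-4b-it` checkpoint id is taken from the tests below, and `'<start_of_image>'` is only a stand-in for `processor.full_image_sequence`:

```python
from modelscope import AutoTokenizer

# Assumed checkpoint id (same as the ModelScope id used in the tests below).
tok = AutoTokenizer.from_pretrained('LLM-Research/gemma-3-4b-it')
text = '<start_of_image>'  # stand-in for processor.full_image_sequence

with_special = tok.encode(text)                               # encode() is assumed to prepend BOS
without_special = tok.encode(text, add_special_tokens=False)  # what _tokenize is assumed to return
print(with_special, without_special)  # the first list carries an extra leading id
```

With `encode()`, every expanded `<image>` block would pick up that extra id, so the placeholder count no longer matches what the processor produced; tokenizing without special tokens keeps the two in sync.
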
113 changes: 84 additions & 29 deletions swift/llm/template/template/qwen.py
@@ -1,5 +1,4 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from contextlib import contextmanager
from dataclasses import dataclass, field
from functools import partial
from typing import Any, Dict, List, Literal, Optional, Tuple
@@ -384,32 +383,50 @@ class Qwen2_5OmniTemplate(Qwen2_5VLTemplate):
version = 'omni'
placeholder_tokens = ['<|IMAGE|>', '<|AUDIO|>', '<|VIDEO|>']

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
from transformers.models.qwen2_5_omni.processing_qwen2_5_omni import Qwen2_5OmniProcessorKwargs
default = Qwen2_5OmniProcessorKwargs._defaults
self.seconds_per_chunk = default['videos_kwargs']['seconds_per_chunk']
self.position_id_per_seconds = default['videos_kwargs']['position_id_per_seconds']
self.use_audio_in_video = get_env_args('use_audio_in_video', bool, False)
self.sampling_rate = get_env_args('sampling_rate', int, self.processor.feature_extractor.sampling_rate)

def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
inputs: StdTemplateInputs) -> List[Context]:
from qwen_omni_utils import fetch_image, fetch_video
sampling_rate = self.processor.feature_extractor.sampling_rate
if media_type == 'image':
inputs.images[index] = fetch_image({'image': inputs.images[index]})
return ['<|vision_bos|><|IMAGE|><|vision_eos|>']
elif media_type == 'audio':
sampling_rate = get_env_args('sampling_rate', int, sampling_rate)
inputs.audios[index] = load_audio(inputs.audios[index], sampling_rate)
inputs.audios[index] = load_audio(inputs.audios[index], self.sampling_rate)
return ['<|audio_bos|><|AUDIO|><|audio_eos|>']
elif media_type == 'video':
video = inputs.videos[index]
inputs.videos[index] = fetch_video({'video': video}).to(torch.uint8)
use_audio_in_video = get_env_args('use_audio_in_video', bool, False)
if use_audio_in_video:
if self.use_audio_in_video:
import librosa
sampling_rate = get_env_args('sampling_rate', int, sampling_rate)
video = librosa.load(video, sr=sampling_rate)[0]
inputs.audios.insert(inputs.audio_idx, video)
if video.startswith('http://') or video.startswith('https://'):
import audioread
video = audioread.ffdec.FFmpegAudioFile(video)
video = librosa.load(video, sr=self.sampling_rate)[0]
inputs.audios.insert(inputs.audio_idx, (video, 'video'))
inputs.audio_idx += 1
return ['<|vision_bos|><|audio_bos|><|VIDEO|><|audio_eos|><|vision_eos|>']
return ['<|vision_bos|><|VIDEO|><|vision_eos|>']

def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
encoded = Template._encode(self, inputs)
media_inputs = self.processor(
processor = self.processor
video_audios_mask = []
for i, audio in enumerate(inputs.audios):
if isinstance(audio, tuple) and audio[1] == 'video':
inputs.audios[i] = audio[0]
video_audios_mask.append(True)
else:
video_audios_mask.append(False)
video_audios_mask = torch.tensor(video_audios_mask)
media_inputs = processor(
text='',
audio=inputs.audios or None,
images=inputs.images or None,
@@ -420,31 +437,70 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
media_inputs = to_float_dtype(media_inputs, self.model_info.torch_dtype)
input_ids = encoded['input_ids']
labels = encoded['labels']
# audio
audio_token_id = self._tokenize('<|AUDIO|>')
idx_list = findall(input_ids, audio_token_id)
feature_attention_mask = media_inputs.get('feature_attention_mask')
if feature_attention_mask is not None:
audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)
audio_lengths = (((audio_feature_lengths - 1) // 2 + 1 - 2) // 2 + 1)
else:
audio_lengths = None
audio_lengths_origin = audio_lengths
if idx_list:
if self.use_audio_in_video:
audio_lengths = audio_lengths[~video_audios_mask]

def _get_new_audio_tokens(i):
return audio_token_id * audio_lengths[i]

input_ids, labels = self._extend_tokens(input_ids, labels, idx_list, _get_new_audio_tokens)

for media_type in ['image', 'video']:
token = f'<|{media_type.upper()}|>'
token_id = self._tokenize(token)
idx_list = findall(input_ids, token_id)
if idx_list:
merge_length = self.processor.image_processor.merge_size**2
merge_size = processor.image_processor.merge_size
media_grid_thw = media_inputs.get(f'{media_type}_grid_thw')
if media_type == 'video' and self.use_audio_in_video:
audio_lengths = audio_lengths_origin[video_audios_mask]
video_second_per_grid = media_inputs['video_second_per_grid']

def _get_new_tokens_use_audio_in_video(i):
audio_token_indices = torch.arange(audio_lengths[i])
grid_thw = media_grid_thw[i]
height = grid_thw[1] // merge_size
width = grid_thw[2] // merge_size
video_token_indices = torch.arange(grid_thw[0]).reshape(-1, 1, 1)
video_token_indices = torch.broadcast_to(
video_token_indices, (video_token_indices.shape[0], height, width)).reshape(-1)
video_token_indices = (
video_token_indices * video_second_per_grid[i] * self.position_id_per_seconds)
tokens_per_chunk = int(self.position_id_per_seconds * self.seconds_per_chunk)
video_chunk_indexes = processor.get_chunked_index(video_token_indices, tokens_per_chunk)
audio_chunk_indexes = processor.get_chunked_index(audio_token_indices, tokens_per_chunk)

res = []
for j in range(max(len(video_chunk_indexes), len(audio_chunk_indexes))):
if j < len(video_chunk_indexes):
video_seq_length = video_chunk_indexes[j][1] - video_chunk_indexes[j][0]
res += token_id * video_seq_length
if j < len(audio_chunk_indexes):
audio_seq_length = audio_chunk_indexes[j][1] - audio_chunk_indexes[j][0]
res += audio_token_id * audio_seq_length
return res

input_ids, labels = self._extend_tokens(input_ids, labels, idx_list,
_get_new_tokens_use_audio_in_video)

def _get_new_tokens(i):
token_len = (media_grid_thw[i].prod() // merge_length)
return token_id * token_len

input_ids, labels = self._extend_tokens(input_ids, labels, idx_list, _get_new_tokens)
# audio
feature_attention_mask = media_inputs.get('feature_attention_mask')
if feature_attention_mask is not None:
audio_feature_lengths = torch.sum(feature_attention_mask, dim=1).tolist()
token_id = self._tokenize('<|AUDIO|>')
idx_list = findall(input_ids, token_id)
else:

def _get_new_tokens(i):
place_num = ((audio_feature_lengths[i] - 1) // 2 + 1 - 2) // 2 + 1
return token_id * place_num
def _get_new_tokens(i):
token_len = (media_grid_thw[i].prod() // (merge_size**2))
return token_id * token_len

input_ids, labels = self._extend_tokens(input_ids, labels, idx_list, _get_new_tokens)
input_ids, labels = self._extend_tokens(input_ids, labels, idx_list, _get_new_tokens)

encoded['input_ids'] = input_ids
encoded['labels'] = labels
Expand All @@ -460,7 +516,6 @@ def _get_position_ids(self, inputs: Dict[str, Any]):
audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)
else:
audio_feature_lengths = None
use_audio_in_video = get_env_args('use_audio_in_video', bool, False)
video_second_per_grid = inputs.pop('video_second_per_grid', None)
input_ids = inputs['input_ids']
attention_mask = inputs.get('attention_mask')
@@ -471,7 +526,7 @@ def _get_position_ids(self, inputs: Dict[str, Any]):
inputs.get('image_grid_thw'),
inputs.get('video_grid_thw'),
attention_mask,
use_audio_in_video,
self.use_audio_in_video,
audio_feature_lengths,
video_second_per_grid,
)
@@ -493,7 +548,7 @@ def _data_collator_mm_data(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:

def generate(self, model, *args, **kwargs):
if kwargs.get('video_grid_thw') is not None:
kwargs['use_audio_in_video'] = get_env_args('use_audio_in_video', bool, False)
kwargs['use_audio_in_video'] = self.use_audio_in_video
return super().generate(model, *args, **kwargs)
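
For readers following the new `_get_new_tokens_use_audio_in_video` path above: when `use_audio_in_video` is enabled, the video placeholder tokens and the audio tokens extracted from that same video are interleaved chunk by chunk along a shared timeline, rather than emitted as two separate runs. A self-contained sketch of that interleaving; the `chunked_index` helper is a simplified stand-in for the processor's `get_chunked_index`, and the token ids are dummies, so treat it as an illustration under those assumptions, not the library's implementation:

```python
import torch


def chunked_index(token_indices: torch.Tensor, tokens_per_chunk: int):
    # Simplified stand-in for the processor's get_chunked_index: split a
    # non-decreasing index sequence into (start, end) spans, one span per time chunk.
    spans, start = [], 0
    for end in range(1, len(token_indices) + 1):
        if end == len(token_indices) or (int(token_indices[end]) // tokens_per_chunk
                                         != int(token_indices[start]) // tokens_per_chunk):
            spans.append((start, end))
            start = end
    return spans


def interleave_video_audio(video_token_indices, audio_token_indices, tokens_per_chunk,
                           video_token_id, audio_token_id):
    # Mirrors the idea of _get_new_tokens_use_audio_in_video: emit one chunk of video
    # placeholder tokens, then the audio placeholder tokens covering the same time window.
    video_chunks = chunked_index(video_token_indices, tokens_per_chunk)
    audio_chunks = chunked_index(audio_token_indices, tokens_per_chunk)
    res = []
    for j in range(max(len(video_chunks), len(audio_chunks))):
        if j < len(video_chunks):
            res += video_token_id * (video_chunks[j][1] - video_chunks[j][0])
        if j < len(audio_chunks):
            res += audio_token_id * (audio_chunks[j][1] - audio_chunks[j][0])
    return res


# Dummy ids and lengths, purely for illustration (not the real vocabulary ids):
VIDEO_ID, AUDIO_ID = [2], [1]
video_idx = torch.arange(8) * 12   # hypothetical per-token time positions (position ids)
audio_idx = torch.arange(100)      # one position id per audio token
print(interleave_video_audio(video_idx, audio_idx, 50, VIDEO_ID, AUDIO_ID))
```

The result alternates blocks of `VIDEO_ID` and `AUDIO_ID`, which is what keeps the expanded placeholder sequence consistent with the position ids computed later in `_get_position_ids`.
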


2 changes: 2 additions & 0 deletions tests/test_align/test_template/test_audio.py
@@ -56,6 +56,8 @@ def test_step_audio_chat():


def test_qwen2_5_omni():
USE_AUDIO_IN_VIDEO = True
os.environ['USE_AUDIO_IN_VIDEO'] = str(USE_AUDIO_IN_VIDEO)
pt_engine = PtEngine('Qwen/Qwen2.5-Omni-7B')
response = _infer_model(pt_engine)
pt_engine.default_template.template_backend = 'jinja'
17 changes: 6 additions & 11 deletions tests/test_align/test_template/test_video.py
@@ -131,10 +131,9 @@ def test_qwen2_5_vl():


def test_qwen2_5_omni():
os.environ['VIDEO_MAX_PIXELS'] = str(28 * 28 * 64)
USE_AUDIO_IN_VIDEO = False
USE_AUDIO_IN_VIDEO = True
os.environ['USE_AUDIO_IN_VIDEO'] = str(USE_AUDIO_IN_VIDEO)
pt_engine = PtEngine('Qwen/Qwen2.5-Omni-7B')
pt_engine = PtEngine('Qwen/Qwen2.5-Omni-7B', attn_impl='flash_attn')
system = ('You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, '
'capable of perceiving auditory and visual inputs, as well as generating text and speech.')
messages = [{'role': 'system', 'content': system}, {'role': 'user', 'content': '<video>'}]
@@ -143,15 +142,11 @@ def test_qwen2_5_omni():
pt_engine.default_template.template_backend = 'jinja'
response2 = _infer_model(pt_engine, messages=messages, videos=videos)
if USE_AUDIO_IN_VIDEO:

ground_truth = ('Oh, that sounds like a really cool project! Are you using a specific app on the tablet for '
"drawing? And what kind of details are you adding to the guitar? It'd be interesting to hear "
'more about your creative process.')
ground_truth = ("Oh, that's a really cool drawing! It looks like a guitar. You've got the body "
'and the neck drawn in a simple yet effective way. The lines are clean and the '
'shape is well-defined. What made you choose to draw a guitar?')
else:
ground_truth = (
"Oh, that sounds like a really cool project! So, you're using a tablet to draw a guitar and a key? "
"That's a creative way to combine two different things. Have you thought about what you'll do "
'with the final drawing? Maybe could use it for a poster or something? Let me know how it turns out!')
ground_truth = ('嗯,你是在用平板画画呢。你画的这把吉他,看起来很简洁明了。你用的笔触也很流畅,线条很清晰。你对颜色的运用也很不错,整体看起来很协调。你要是还有啥想法或者问题,随时跟我说哈。')
assert response == response2 == ground_truth
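
Since the template now reads the flag once in `__init__` via `get_env_args`, enabling audio-in-video outside these tests is a matter of exporting `USE_AUDIO_IN_VIDEO` (and, if needed, `SAMPLING_RATE`) before the engine is built. A hedged usage sketch; the `InferRequest`/`RequestConfig` calls follow ms-swift's documented inference API, and the video path is a placeholder:

```python
import os

# Set before the engine is constructed: the template reads USE_AUDIO_IN_VIDEO in __init__.
os.environ['USE_AUDIO_IN_VIDEO'] = 'True'

from swift.llm import InferRequest, PtEngine, RequestConfig

engine = PtEngine('Qwen/Qwen2.5-Omni-7B')
request = InferRequest(
    messages=[{'role': 'user', 'content': '<video>What is happening in this clip?'}],
    videos=['/path/to/clip.mp4'])  # placeholder path
resp = engine.infer([request], RequestConfig(max_tokens=128))[0]
print(resp.choices[0].message.content)
```
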


11 changes: 6 additions & 5 deletions tests/test_align/test_template/test_vision.py
@@ -509,10 +509,11 @@ def test_phi4_vision():

def test_gemma3_vision():
pt_engine = PtEngine('LLM-Research/gemma-3-4b-it')
response = _infer_model(pt_engine)
response = _infer_model(pt_engine, messages=[{'role': 'user', 'content': '<image>Describe this image in detail.'}])
pt_engine.default_template.template_backend = 'jinja'
response2 = _infer_model(pt_engine)
assert response == response2
response2 = _infer_model(pt_engine, messages=[{'role': 'user', 'content': '<image>Describe this image in detail.'}])
assert response[:80] == response2[:80] == (
"Here's a detailed description of the image:\n\n**Overall Impression:**\n\nThe image ")


def test_mistral_2503():
@@ -596,9 +597,9 @@ def test_kimi_vl():
# test_minicpmo()
# test_valley()
# test_ui_tars()
# test_gemma3_vision()
test_gemma3_vision()
# test_mistral_2503()
# test_llama4()
# test_internvl3_8b()
# test_internvl3_9b()
test_kimi_vl()
# test_kimi_vl()