
fix loss_scale bug when meeting <image>,<audio>,<video> #4922


Open · wants to merge 2 commits into main

Conversation


@CrownStar7 CrownStar7 commented Jul 11, 2025

PR type

  • [✓] Bug Fix
  • [ ] New Feature
  • [ ] Document Updates
  • [ ] More Models or Datasets Support

PR information

While computing custom loss weights, I ran into a loss_scale dimension-mismatch problem. A search of the PR history showed that this issue had been reported before but never resolved, so this PR is proposed to improve the Swift framework.

Based on #3036 and the latest 3.7.0.dev0 code, this PR does the following:

  • reproduces the problem
  • adds the missing code
  • verifies the results of the affected modules

With this work, the Swift framework should offer more reliable support for custom loss-weight computation.

Bug description:

For multimodal models with Input: media + text and Output: text, when the user content of the training data contains media (e.g. an image) and both

  • --loss_scale <custom LossScale class>
  • --loss_type loss_scale

are enabled, the following happens:

Only the input_ids and labels tokens are expanded, after which one of two cases occurs:

  1. Case 1: loss_scale is not expanded accordingly, so when loss_scale_func(output, labels) is called, labels and loss_scale have mismatched dimensions and the custom weighted loss cannot be computed.
  2. Case 2: loss_scale is neither expanded nor passed along, so it is lost entirely; when loss_scale_func(output, labels) is called, loss_scale is None. Training still runs, but the custom weighted loss is silently not applied.

The custom LossScale class is shown below:

import re  
from swift.plugin.loss_scale.loss_scale import LossScale, loss_scale_map
from swift.llm.template.utils import ContextType  
  
class ColspanRowspanLossScale(LossScale):    
    def get_loss_scale(self, context: str, context_type: ContextType, is_last_round: bool, **kwargs):    
        if context_type == ContextType.RESPONSE:   
            pattern = r'(colspan="[^"]*"|rowspan="[^"]*")'  
            matches = list(re.finditer(pattern, context))  
              
            if not matches:  
                return [context], [1.0]  
    
            parts = []  
            weights = []  
            last_end = 0

            for match in matches:  
                start, end = match.span()    
                if start > last_end:  
                    parts.append(context[last_end:start])  
                    weights.append(1.0)  
 
                parts.append(context[start:end])  
                weights.append(3.0)  
                  
                last_end = end  
            if last_end < len(context):  
                parts.append(context[last_end:])  
                weights.append(1.0)
            return parts, weights  
            
        return super().get_loss_scale(context, context_type, is_last_round)  
  
loss_scale_map['colspan_rowspan_loss_scale'] = ColspanRowspanLossScale
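As a quick sanity check, the RESPONSE-branch splitting logic above can be exercised in isolation. This is a standalone sketch: split_with_weights is a hypothetical helper that mirrors get_loss_scale and is not part of Swift.

```python
import re

def split_with_weights(context):
    # Split the response text around colspan/rowspan attributes and
    # assign weight 3.0 to the attributes, 1.0 to everything else.
    pattern = r'(colspan="[^"]*"|rowspan="[^"]*")'
    parts, weights, last_end = [], [], 0
    for m in re.finditer(pattern, context):
        start, end = m.span()
        if start > last_end:
            parts.append(context[last_end:start])
            weights.append(1.0)
        parts.append(context[start:end])
        weights.append(3.0)
        last_end = end
    if last_end < len(context):
        parts.append(context[last_end:])
        weights.append(1.0)
    return (parts, weights) if parts else ([context], [1.0])

parts, weights = split_with_weights('<td colspan="2">x</td>')
print(parts)    # ['<td ', 'colspan="2"', '>x</td>']
print(weights)  # [1.0, 3.0, 1.0]
```

Each returned part is later tokenized, and its weight is repeated per token, which is exactly why loss_scale must stay length-aligned with input_ids.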

The token-extension function that is missing the corresponding loss_scale expansion:

#  swift/llm/template/base.py  
    def _extend_tokens(input_ids: List[int], labels: Optional[List[int]], replace_idx_list: List[int],
                       get_new_tokens: Callable[[int], List[int]]) -> Tuple[List[int], Optional[List[int]]]:
        added_tokens_len = 0
        for i, idx in enumerate(replace_idx_list):
            new_tokens = get_new_tokens(i)
            token_len = len(new_tokens)
            input_ids = input_ids[:idx + added_tokens_len] + new_tokens + input_ids[added_tokens_len + idx + 1:]
            if labels:
                labels = labels[:idx + added_tokens_len] + [-100] * token_len + labels[added_tokens_len + idx + 1:]
            
            # loss_scale is not expanded to match here
            added_tokens_len += token_len - 1
        return input_ids, labels
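The resulting mismatch can be reproduced with a minimal standalone sketch (extend_tokens below copies the logic of _extend_tokens; the token ids and weights are made-up example values):

```python
def extend_tokens(input_ids, labels, replace_idx_list, get_new_tokens):
    # Same logic as _extend_tokens: each placeholder token is replaced
    # by several new tokens, so input_ids and labels grow in length.
    added = 0
    for i, idx in enumerate(replace_idx_list):
        new_tokens = get_new_tokens(i)
        input_ids = input_ids[:idx + added] + new_tokens + input_ids[idx + added + 1:]
        if labels:
            labels = labels[:idx + added] + [-100] * len(new_tokens) + labels[idx + added + 1:]
        added += len(new_tokens) - 1
    return input_ids, labels

input_ids = [1, 999, 2, 3]           # 999 stands in for an <image> placeholder
labels = [-100, -100, 2, 3]
loss_scale = [0.0, 0.0, 1.0, 3.0]    # one weight per original token

input_ids, labels = extend_tokens(input_ids, labels, [1], lambda i: [50, 51, 52])
print(len(input_ids), len(labels), len(loss_scale))  # 6 6 4
```

input_ids and labels end up with 6 positions while loss_scale is still 4, which is exactly the mismatch hit later in loss_scale_func.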

A template function that fails to pass loss_scale through (excerpt):

class mPlugOwl3Template(Template):
    version = None

    def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
        encoded = super()._encode(inputs)
        images = inputs.images
        videos = inputs.videos
        cut_enable = not videos
        input_ids = encoded['input_ids']
        labels = encoded['labels']
        .....
        # loss_scale is lost here: encoded is rebuilt without it
        encoded = {}
        if images:
            .....

            input_ids, labels = self._extend_tokens(input_ids, labels, idx_list, _get_new_tokens)
            image_token_idx = torch.tensor(findall(input_ids, image_token_list))
            if self.version == '241101':
                media_offset = image_token_idx
            else:
                _range = torch.arange(len(input_ids))[:, None]
                matrix = (_range > image_token_idx[None]).sum(dim=1)
                media_offset = torch.stack([torch.zeros(matrix.shape[0], dtype=torch.long), matrix], dim=-1)[None]
            encoded.update({
                'pixel_values': image_inputs['pixel_values'],
                'media_offset': media_offset,
            })

        # loss_scale is lost here: it is never written back into encoded
        encoded['input_ids'] = input_ids
        encoded['labels'] = labels
        return encoded

The loss function where the custom weighting fails:

@register_loss_func(LossType.loss_scale)
def loss_scale_func(outputs, labels, loss_scale=None, num_items_in_batch=None) -> torch.Tensor:
    """Loss func

    Args:
        outputs: The model outputs
        labels: The labels
        loss_scale: The loss scale
        num_items_in_batch: Number of tokens in the labels of gradient accumulation round that are not -100.

    Returns:
        The computed loss tensor.
    """
    loss, masks = ce_loss_func(outputs, labels)
    if loss_scale is not None:
        shift_scale = loss_scale[..., 1:].to(masks.device)
        
        # the error occurs here: shift_scale and masks have mismatched shapes
        shift_scale = shift_scale[masks]
        loss = (shift_scale * loss)
    if num_items_in_batch is None:
        loss = loss.mean()
    else:
        # compat transformers>=4.46
        loss = loss.sum() / num_items_in_batch
    return loss
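The shape requirement can be illustrated without torch. Below is a plain-Python analog of the shift-and-mask step above; the lists are made-up example values:

```python
# Both labels and loss_scale drop one position when shifted, so they must
# have the same length *after* multimodal token expansion for the boolean
# mask to line up.
labels = [-100, -100, -100, -100, 2, 3]       # expanded to 6 positions
loss_scale = [0.0, 0.0, 0.0, 0.0, 1.0, 3.0]   # must also have 6 positions

shift_labels = labels[1:]        # analog of labels[..., 1:]
shift_scale = loss_scale[1:]     # analog of loss_scale[..., 1:]
mask = [lab != -100 for lab in shift_labels]
kept = [s for s, m in zip(shift_scale, mask) if m]
print(kept)  # [1.0, 3.0]
```

If loss_scale still had its unexpanded length, the mask (built from labels) would be longer than shift_scale and the indexing would fail; in torch this surfaces as the shape-mismatch error above.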

Solution

Add a _extend_loss_scale() function:

#  swift/llm/template/base.py  
    @staticmethod
    def _extend_loss_scale(loss_scale: Optional[List[float]], replace_idx_list: List[int],
                           get_new_tokens: Callable[[int], List[int]]) -> Optional[List[float]]:
        if loss_scale:
            added_tokens_len = 0
            for i, idx in enumerate(replace_idx_list):
                new_tokens = get_new_tokens(i)
                token_len = len(new_tokens)

                # repeat the weight of the replaced placeholder token for every
                # new token, mirroring the expansion done in _extend_tokens
                scale_idx = loss_scale[idx + added_tokens_len]
                loss_scale = (loss_scale[:idx + added_tokens_len] + [scale_idx] * token_len
                              + loss_scale[added_tokens_len + idx + 1:])

                added_tokens_len += token_len - 1

        return loss_scale
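The behavior of the new function can be checked with a standalone sketch (extend_loss_scale below copies the logic of _extend_loss_scale; the weights are made-up example values):

```python
def extend_loss_scale(loss_scale, replace_idx_list, get_new_tokens):
    # Same logic as _extend_loss_scale: repeat the placeholder's weight
    # once per replacement token, keeping loss_scale length-aligned with
    # the expanded input_ids.
    if loss_scale:
        added = 0
        for i, idx in enumerate(replace_idx_list):
            token_len = len(get_new_tokens(i))
            scale = loss_scale[idx + added]
            loss_scale = loss_scale[:idx + added] + [scale] * token_len + loss_scale[idx + added + 1:]
            added += token_len - 1
    return loss_scale

loss_scale = [0.0, 0.0, 1.0, 3.0]  # weight of the placeholder at index 1 is 0.0
extended = extend_loss_scale(loss_scale, [1], lambda i: [50, 51, 52])
print(extended)  # [0.0, 0.0, 0.0, 0.0, 1.0, 3.0]
```

After the same 3-token image expansion as before, loss_scale now has 6 positions, matching input_ids and labels, so loss_scale_func can apply the custom weights.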

Experiment results

Modules affected by the added code

A keyword search across the whole project found 15 multimodal template classes whose _encode() functions are involved, as shown below:
image-1

They are:

  • class Qwen2_5OmniTemplate(Qwen2_5VLTemplate)
  • class Qwen2VLTemplate(Template)
  • class PixtralTemplate(Template)
  • class mPlugOwl3Template(Template)
  • class KimiVLTemplate(Template)
  • class Mistral2503Template(Template)
  • class MiniCPMV2_6Template(MiniCPMVTemplate)
  • class Phi4MMTemplate(Template)
  • class MegrezOmniTemplate(Template)
  • class Llama4Template(Template)
  • class Internvl2Template(InternvlTemplate)
  • class Gemma3VisionTemplate(Gemma3Template)
  • class Gemma3nTemplate(Gemma3Template)
  • class Emu3ChatTemplate(Template)
  • class KeyeVLTemplate(Template)

Reproducing the problem and verifying the fix

Experiment setup

  1. To simplify verification, the smallest-parameter model available for each template was downloaded.
  2. Models were fine-tuned via the CLI.
  3. The custom LossScale class was loaded via --external_plugins.
  4. Only image inputs were verified; videos and audio were not, on the assumption that the expansion principle is the same. Corrections are welcome if this understanding is wrong.

Summary of results

① The following templates could not be verified due to hardware limits, but the expansion logic is believed to be the same (corrections welcome):

  • Emu3ChatTemplate: Emu3-Chat
  • KimiVLTemplate: Kimi-VL-A3B-Thinking
  • Llama4Template: Llama-Guard-4-12B
  • Phi4MMTemplate: microsoft-Phi-4-multimodal-instruct

② The following templates neither expand nor pass loss_scale, so the custom weighted loss cannot be computed:

  • MiniCPMV2_6Template
  • mPlugOwl3Template

③ All other templates pass loss_scale but do not expand it, so the custom weighted loss cannot be computed either.

Detailed results:

1. Qwen2VLTemplate

Model:
allenai/olmOCR-7B-0225-preview

Missing expansion

Problem reproduction

image-2 image-3 image-4 image-5 image-6 image-7 image-8

Verification after the fix

image-9

2. Qwen2_5OmniTemplate

Model:
Qwen2.5-Omni-3B

Missing expansion

Problem reproduction

image-10 image-11

Verification after the fix

image-12

3. PixtralTemplate

Model:
Pixtral-12B-2409-bnb-4bit

loss_scale dimension mismatch

Problem reproduction

image-41 image-42 image-44 image-45

Verification after the fix

image-46

4. mPlugOwl3Template

Model:
mPLUG-Owl3-2B-241014

loss_scale lost

Problem reproduction

image-29 image-30

Verification after the fix

image-28

5. KimiVLTemplate

Model:

On ModelScope, the model.safetensors.index.json of mlx-community/Kimi-VL-A3B-Thinking-4bit contains the configuration of the unquantized model, not the quantized version.

The official model could not be verified due to hardware limits, but the expansion logic is the same.

Problem reproduction

pass

Verification after the fix

pass

6. Mistral2503Template

Model:
Mistral-Small-3.1-24B-Base-2503-bnb-4bit

loss_scale dimension mismatch

Problem reproduction

image-31 image-32 image-33 image-34 image-35

Verification after the fix

image-36

7. MiniCPMV2_6Template

Model:
MiniCPM-V_2_6

loss_scale lost

Problem reproduction

image-22 image-23 image-24 image-25

Verification after the fix

image-26 image-27

8. Phi4MMTemplate

Model:

Problem reproduction

pass

Verification after the fix

pass

9. MegrezOmniTemplate

Model:
AI-ModelScope/Megrez-3B-Omni

Problem reproduction

image-18 image-19 image-20

Verification after the fix

image-21

10. Llama4Template

Model:
ResearchLlama-4-Scout-Dense-12B

Problem reproduction

pass

Verification after the fix

pass

11. Internvl2Template

Model:
InternVL2_5-1B

Missing expansion

Problem reproduction

image-17 image-14 image-15 image-16

Verification after the fix

image-13

12. Gemma3VisionTemplate

Model:
gemma-3-4b-it

Problem reproduction

image-47 image-48 image-49 image-50

Verification after the fix

image-51

13. Gemma3nTemplate

Model:
gemma-3n-E4B

Problem reproduction

image-52 image-53 image-54 image-55 image-56

Verification after the fix

image-57

14. KeyeVLTemplate

Model:
Keye-VL-8B-Preview

Problem reproduction

image-60 image-61 image-62

Verification after the fix

image-58 image-59

15. Emu3ChatTemplate

Model:

Problem reproduction

pass

Verification after the fix

pass
