
Request: pass the sft argument fp16: false through to the TrainingArguments in transformers, to avoid getting NaN #3896


Closed
njzheng opened this issue Apr 16, 2025 · 3 comments
Labels
bug Something isn't working

Comments

njzheng commented Apr 16, 2025

Describe the feature
Currently, passing fp16 as false in the CLI is not forwarded to transformers' TrainingArguments; its fp16 ends up being set to True.
When fp16 is True in transformers' TrainingArguments, on a V100 GPU torch.matmul(A, B) converts torch.float32 inputs to torch.float16, causing precision loss and possibly NaN errors, which shows up as inaccurate nn.Linear results.

CUDA_VISIBLE_DEVICES=6 \
swift sft \
    --model $model_id_or_path \
    --dataset $dataset_path \
    --val_dataset $val_dataset_path \
    --tuner_backend peft \
    --train_type full \
    --torch_dtype float32 \
      ...
    --fp16 false 

--tf32 seems to be passed through correctly, but fp16 is not. Please update this.
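
As an illustration of the downcast described above (a minimal sketch of mine, not from the report, assuming a CUDA device is available): under float16 autocast, matmul inputs are cast from float32 to float16, and results beyond the float16 range overflow to inf, which can then turn into NaN downstream.

import torch

# Illustrative values chosen so the product (360000) exceeds the float16 max (~65504).
if torch.cuda.is_available():
    a = torch.full((4, 4), 300.0, device='cuda')  # float32 inputs
    b = torch.full((4, 4), 300.0, device='cuda')
    print((a @ b)[0, 0].item())  # 360000.0 in float32
    with torch.autocast(device_type='cuda', dtype=torch.float16):
        out = a @ b              # inputs are downcast to float16 here
    print(out.dtype)             # torch.float16
    print(out[0, 0].item())      # inf -- can propagate to NaN in later layers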

The relevant code in python3.11/site-packages/transformers/training_args.py:

        # if training args is specified, it will override the one specified in the accelerate config
        if self.half_precision_backend != "apex":
            mixed_precision_dtype = os.environ.get("ACCELERATE_MIXED_PRECISION", "no")
            if self.fp16:
                mixed_precision_dtype = "fp16"
            elif self.bf16:
                mixed_precision_dtype = "bf16"
            os.environ["ACCELERATE_MIXED_PRECISION"] = mixed_precision_dtype
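
As a quick sanity check of that logic (a hypothetical snippet, assuming transformers is installed and a CUDA GPU is present, since fp16=True is rejected on CPU-only machines), constructing the arguments in a fresh process shows how the flag drives the accelerate setting:

import os
from transformers import TrainingArguments

# '/tmp/mp_check' is just a placeholder output_dir needed to construct the arguments.
TrainingArguments(output_dir='/tmp/mp_check', fp16=False, bf16=False)
print(os.environ.get('ACCELERATE_MIXED_PRECISION'))  # expected: 'no'

TrainingArguments(output_dir='/tmp/mp_check', fp16=True)
print(os.environ.get('ACCELERATE_MIXED_PRECISION'))  # expected: 'fp16'
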
@Jintao-Huang Jintao-Huang added the bug Something isn't working label Apr 16, 2025
njzheng (Author) commented Apr 16, 2025

This can be fixed as follows in swift/llm/argument/base_args/model_args.py:

    def _init_torch_dtype(self) -> None:
        """If torch_dtype is None, find a proper dtype by the train_type/GPU."""
        from swift.llm import TrainArguments
        if self.torch_dtype is None and isinstance(self, TrainArguments):
            # Compatible with --fp16/--bf16
            for key in ['fp16', 'bf16']:
                value = getattr(self, key)
                if value:
                    self.torch_dtype = {'fp16': 'float16', 'bf16': 'bfloat16'}[key]

        self.torch_dtype: Optional[torch.dtype] = HfConfigFactory.to_torch_dtype(self.torch_dtype)
        self.torch_dtype: torch.dtype = self._init_model_info()
        # Mixed Precision Training
        if isinstance(self, TrainArguments) and not is_torch_mps_available():
            if self.torch_dtype in {torch.float16, torch.float32}:
                if hasattr(self, 'fp16') and self.fp16 is False:  # added line: respect an explicit fp16=False
                    self.fp16, self.bf16 = False, False
                else:
                    self.fp16, self.bf16 = True, False
            elif self.torch_dtype == torch.bfloat16:
                self.fp16, self.bf16 = False, True
            else:
                raise ValueError(f'args.torch_dtype: {self.torch_dtype}')
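
A standalone paraphrase of the resulting decision table (my sketch, not SWIFT's actual code), showing that an explicit fp16=False now keeps pure float32 training instead of being forced into fp16 AMP:

import torch

def resolve_mixed_precision(torch_dtype, fp16_flag=None):
    # Mirrors the patched branch above: returns the (fp16, bf16) flags handed to TrainingArguments.
    if torch_dtype in {torch.float16, torch.float32}:
        if fp16_flag is False:   # the added check
            return False, False  # pure float32, no AMP
        return True, False       # previous behavior: always enable fp16 AMP
    if torch_dtype == torch.bfloat16:
        return False, True
    raise ValueError(f'torch_dtype: {torch_dtype}')

print(resolve_mixed_precision(torch.float32, fp16_flag=False))  # (False, False)
print(resolve_mixed_precision(torch.float32))                   # (True, False)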

Jintao-Huang (Collaborator) commented

Fixed on the main branch.

wellhowtosay commented

A question: the V100 does not support bf16, so when training, whether full fine-tuning or LoRA, is setting fp32 the best choice?
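
(For reference, a quick way to confirm the bf16 support of a given GPU, assuming PyTorch with CUDA:)

import torch

if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))
    # bf16 needs compute capability >= 8.0 (Ampere); the V100 is 7.0, so this prints False.
    print(torch.cuda.is_bf16_supported())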

@njzheng njzheng closed this as completed Apr 22, 2025