
Request: pass the sft argument fp16: false through to transformers' TrainingArguments to avoid NaN #3896

Closed
@njzheng

Description


Describe the feature
Currently, setting fp16 to false in the CLI is not passed through to transformers' TrainingArguments; its fp16 ends up set to True.
When fp16 is True in transformers' TrainingArguments, running on a V100 GPU causes torch.matmul(A, B) to cast torch.float32 inputs to torch.float16, leading to precision loss and possible NaN errors, which shows up as inaccurate nn.Linear outputs.
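The precision loss is inherent to float16's narrow significand and 65504 maximum value; a minimal stdlib sketch (no GPU needed, round-tripping through Python's IEEE 754 half-precision 'e' struct format) illustrates both effects:

```python
import struct

def to_fp16(x: float) -> float:
    # Round-trip a Python float through IEEE 754 half precision
    # ('e' struct format, available since Python 3.6).
    return struct.unpack('<e', struct.pack('<e', x))[0]

# Rounding: near 1000 the spacing between representable fp16 values is 0.5,
# so small differences are silently lost.
print(to_fp16(1000.1))  # 1000.0

# Overflow: fp16 tops out at 65504. Matmul partial sums past that become
# inf, and inf - inf or 0 * inf then produces NaN.
try:
    to_fp16(70000.0)
except OverflowError:
    print("70000.0 does not fit in fp16 (max 65504)")
```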

CUDA_VISIBLE_DEVICES=6 \
swift sft \
    --model $model_id_or_path \
    --dataset $dataset_path \
    --val_dataset $val_dataset_path \
    --tuner_backend peft \
    --train_type full \
    --torch_dtype float32 \
      ...
    --fp16 false 

--tf32 appears to be passed through correctly, but fp16 is not. Please fix this.
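One common cause of this kind of non-passthrough is naive boolean parsing: argparse's type=bool treats any non-empty string, including "false", as True. A hypothetical converter (an illustration, not ms-swift's actual code) shows the fix:

```python
import argparse

def str2bool(v: str) -> bool:
    # type=bool would return True for the string "false";
    # an explicit converter is needed for flags like `--fp16 false`.
    if v.lower() in ("true", "1", "yes"):
        return True
    if v.lower() in ("false", "0", "no"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {v!r}")

parser = argparse.ArgumentParser()
# default=None lets downstream code tell "not set" apart from "set to False"
parser.add_argument("--fp16", type=str2bool, default=None)

print(parser.parse_args(["--fp16", "false"]).fp16)  # False
print(bool("false"))  # True -- the pitfall with type=bool
```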

python3.11/site-packages/transformers/training_args.py

        # if training args is specified, it will override the one specified in the accelerate config
        if self.half_precision_backend != "apex":
            mixed_precision_dtype = os.environ.get("ACCELERATE_MIXED_PRECISION", "no")
            if self.fp16:
                mixed_precision_dtype = "fp16"
            elif self.bf16:
                mixed_precision_dtype = "bf16"
            os.environ["ACCELERATE_MIXED_PRECISION"] = mixed_precision_dtype
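A standalone reproduction of that logic (the function name and defaults are mine, not transformers') shows why fp16 must arrive as False before this code runs: once the environment variable has been set to "fp16", a later call with fp16=False keeps the stale value rather than resetting it.

```python
import os

def set_mixed_precision(fp16: bool, bf16: bool, backend: str = "auto") -> str:
    # Mirrors the transformers snippet above.
    if backend != "apex":
        mixed_precision_dtype = os.environ.get("ACCELERATE_MIXED_PRECISION", "no")
        if fp16:
            mixed_precision_dtype = "fp16"
        elif bf16:
            mixed_precision_dtype = "bf16"
        os.environ["ACCELERATE_MIXED_PRECISION"] = mixed_precision_dtype
    return os.environ["ACCELERATE_MIXED_PRECISION"]

os.environ.pop("ACCELERATE_MIXED_PRECISION", None)
print(set_mixed_precision(fp16=False, bf16=False))  # "no"   -- full fp32 training
print(set_mixed_precision(fp16=True,  bf16=False))  # "fp16" -- the V100 NaN case
print(set_mixed_precision(fp16=False, bf16=False))  # "fp16" -- stale value persists
```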

Metadata


Assignees

No one assigned

    Labels

    bug (Something isn't working)
