```python
# if training args is specified, it will override the one specified in the accelerate config
if self.half_precision_backend != "apex":
    mixed_precision_dtype = os.environ.get("ACCELERATE_MIXED_PRECISION", "no")
    if self.fp16:
        mixed_precision_dtype = "fp16"
    elif self.bf16:
        mixed_precision_dtype = "bf16"
    os.environ["ACCELERATE_MIXED_PRECISION"] = mixed_precision_dtype
```
```python
def _init_torch_dtype(self) -> None:
    """If torch_dtype is None, find a proper dtype by the train_type/GPU"""
    from swift.llm import TrainArguments
    if self.torch_dtype is None and isinstance(self, TrainArguments):
        # Compatible with --fp16/--bf16
        for key in ['fp16', 'bf16']:
            value = getattr(self, key)
            if value:
                self.torch_dtype = {'fp16': 'float16', 'bf16': 'bfloat16'}[key]

    self.torch_dtype: Optional[torch.dtype] = HfConfigFactory.to_torch_dtype(self.torch_dtype)
    self.torch_dtype: torch.dtype = self._init_model_info()
    # Mixed Precision Training
    if isinstance(self, TrainArguments) and not is_torch_mps_available():
        if self.torch_dtype in {torch.float16, torch.float32}:
            if hasattr(self, 'fp16') and self.fp16 is False:  # proposed extra check
                self.fp16, self.bf16 = False, False
            else:
                self.fp16, self.bf16 = True, False
        elif self.torch_dtype == torch.bfloat16:
            self.fp16, self.bf16 = False, True
        else:
            raise ValueError(f'args.torch_dtype: {self.torch_dtype}')
```
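A hypothetical side-by-side of just that mixed-precision branch, showing why the extra check matters. The `Args` dataclass and `resolve_flags` helper below are invented for this sketch and are not swift APIs:

```python
from dataclasses import dataclass

import torch

@dataclass
class Args:
    torch_dtype: torch.dtype
    fp16: bool = False
    bf16: bool = False

def resolve_flags(args: Args, patched: bool) -> tuple:
    """Replicate only the mixed-precision branch of _init_torch_dtype."""
    if args.torch_dtype in {torch.float16, torch.float32}:
        if patched and args.fp16 is False:
            # proposed behaviour: respect an explicit `--fp16 false`
            args.fp16, args.bf16 = False, False
        else:
            # current behaviour: float32 training still turns fp16 on
            args.fp16, args.bf16 = True, False
    elif args.torch_dtype == torch.bfloat16:
        args.fp16, args.bf16 = False, True
    return args.fp16, args.bf16

# float32 training with fp16 explicitly disabled on the CLI:
print(resolve_flags(Args(torch.float32, fp16=False), patched=False))  # (True, False)  -> forced fp16
print(resolve_flags(Args(torch.float32, fp16=False), patched=True))   # (False, False) -> stays full precision
```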
Describe the feature
Currently, setting `fp16` to false on the CLI is not passed through to transformers' `TrainingArguments`; its `fp16` ends up being set to `True`.
When `TrainingArguments.fp16` is `True` on a V100 GPU, `torch.matmul(A, B)` casts `torch.float32` inputs down to `torch.float16`, causing precision loss and possible NaN errors, which shows up as inaccurate `nn.Linear` outputs.
`--tf32` appears to pass through correctly, but `--fp16` does not; please fix this.
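A rough illustration of the precision concern described above, assuming a CUDA device (e.g. the V100 from this report); the matrix sizes and scales are made up purely to make the fp16 overflow visible:

```python
import torch

# Large activations overflow the fp16 range (max ~65504) once the matmul
# is silently run in float16, so the result contains inf/nan.
A = torch.randn(4096, 4096, device='cuda') * 50
B = torch.randn(4096, 4096, device='cuda') * 50

ref = A @ B                            # float32 reference
half = (A.half() @ B.half()).float()   # what an fp16-autocast matmul effectively computes

finite = torch.isfinite(half)
print('max |float32 result|:', ref.abs().max().item())              # well above 65504
print('inf/nan in fp16 result:', (~finite).any().item())            # True -> overflowed entries
print('max abs error on finite entries:', (ref[finite] - half[finite]).abs().max().item())
```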
The override quoted at the top is in `python3.11/site-packages/transformers/training_args.py`.