Skip to content

Commit 506f39a

Browse files
authored
enable 1 case on XPU (huggingface#11219)
enable case on XPU: 1. tests/quantization/bnb/test_mixed_int8.py::BnB8bitTrainingTests::test_training Signed-off-by: YAO Matrix <[email protected]>
1 parent 8ad68c1 commit 506f39a

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

tests/quantization/bnb/test_mixed_int8.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -379,7 +379,7 @@ def test_training(self):
379379
model_inputs.update({k: v for k, v in input_dict_for_transformer.items() if k not in model_inputs})
380380

381381
# Step 4: Check if the gradient is not None
382-
with torch.amp.autocast("cuda", dtype=torch.float16):
382+
with torch.amp.autocast(torch_device, dtype=torch.float16):
383383
out = self.model_8bit(**model_inputs)[0]
384384
out.norm().backward()
385385

0 commit comments

Comments
 (0)