Skip to content

Commit 6c86a0c

Browse files
committed
Remove Constraint for sm89 hardware
stack-info: PR: #2281, branch: drisspg/stack/61
1 parent c4250a4 commit 6c86a0c

File tree

1 file changed

+7
-10
lines changed

1 file changed

+7
-10
lines changed

torchao/float8/inference.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
from torchao.utils import (
2020
is_MI300,
2121
is_sm_at_least_89,
22-
is_sm_at_least_90,
2322
)
2423

2524
Tensor = torch.Tensor
@@ -168,13 +167,11 @@ def _check_hardware_support(
168167
ValueError: If invalid granularity type is provided
169168
"""
170169
for _granularity in granularities:
171-
if isinstance(_granularity, PerTensor):
172-
assert is_sm_at_least_89() or is_MI300(), (
173-
"PerTensor quantization only works for CUDA>=8.9 and MI300+"
174-
)
175-
elif isinstance(_granularity, PerRow):
176-
assert is_sm_at_least_90() or is_MI300(), (
177-
"PerRow quantization only works for CUDA>=9.0 and MI300+"
170+
if not isinstance(_granularity, (PerTensor, PerRow)):
171+
raise ValueError(
172+
f"Invalid granularity type: {_granularity}, only PerTensor or PerRow are supported."
178173
)
179-
else:
180-
raise ValueError(f"Invalid granularity type: {_granularity}")
174+
175+
assert is_sm_at_least_89() or is_MI300(), (
176+
"Float8 dynamic quantization requires CUDA compute capability ≥8.9 or MI300+."
177+
)

0 commit comments

Comments
 (0)