Skip to content

Commit 26dba5b

Browse files
committed
Remove Constraint for sm89 hardware
stack-info: PR: #2281, branch: drisspg/stack/61
1 parent c4250a4 commit 26dba5b

File tree

3 files changed

+11
-19
lines changed

3 files changed

+11
-19
lines changed

.github/workflows/float8_test.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,3 +55,4 @@ jobs:
5555
pip install .
5656
pytest test/float8 --verbose -s
5757
pytest test/integration --verbose -s
58+
pytest test/dtypes/test_affine_quantized_float.py --verbose -s

test/dtypes/test_affine_quantized_float.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -76,9 +76,7 @@ class TestAffineQuantizedFloat8Compile(InductorTestCase):
7676
@common_utils.parametrize("dtype", [torch.bfloat16, torch.float32])
7777
@common_utils.parametrize("mode", ["dynamic", "weight-only", "static"])
7878
@common_utils.parametrize("compile", [True, False])
79-
@common_utils.parametrize(
80-
"granularity", [PerTensor(), PerRow()] if is_sm_at_least_90() else [PerTensor()]
81-
)
79+
@common_utils.parametrize("granularity", [PerTensor(), PerRow()])
8280
# Inputs are (M,..), K, N
8381
@common_utils.parametrize(
8482
"sizes",
@@ -420,9 +418,7 @@ def test_dequantize_affine_float8_scale_broadcasting(self):
420418
@unittest.skipIf(
421419
not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
422420
)
423-
@common_utils.parametrize(
424-
"granularity", [PerTensor(), PerRow()] if is_sm_at_least_90() else [PerTensor()]
425-
)
421+
@common_utils.parametrize("granularity", [PerTensor(), PerRow()])
426422
def test_float8_tensor_slicing_basic(self, granularity):
427423
"""Test basic slicing operations on Float8 tensors"""
428424
device = "cuda"
@@ -555,9 +551,7 @@ def test_float8_tensor_slicing_edge_cases(self):
555551
@unittest.skipIf(
556552
not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
557553
)
558-
@common_utils.parametrize(
559-
"granularity", [PerTensor(), PerRow()] if is_sm_at_least_90() else [PerTensor()]
560-
)
554+
@common_utils.parametrize("granularity", [PerTensor(), PerRow()])
561555
def test_float8_tensor_slicing_functional_correctness(self, granularity):
562556
"""Test that sliced tensors produce correct results in computations"""
563557
device = "cuda"

torchao/float8/inference.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
from torchao.utils import (
2020
is_MI300,
2121
is_sm_at_least_89,
22-
is_sm_at_least_90,
2322
)
2423

2524
Tensor = torch.Tensor
@@ -168,13 +167,11 @@ def _check_hardware_support(
168167
ValueError: If invalid granularity type is provided
169168
"""
170169
for _granularity in granularities:
171-
if isinstance(_granularity, PerTensor):
172-
assert is_sm_at_least_89() or is_MI300(), (
173-
"PerTensor quantization only works for CUDA>=8.9 and MI300+"
174-
)
175-
elif isinstance(_granularity, PerRow):
176-
assert is_sm_at_least_90() or is_MI300(), (
177-
"PerRow quantization only works for CUDA>=9.0 and MI300+"
170+
if not isinstance(_granularity, (PerTensor, PerRow)):
171+
raise ValueError(
172+
f"Invalid granularity type: {_granularity}, only PerTensor or PerRow are supported."
178173
)
179-
else:
180-
raise ValueError(f"Invalid granularity type: {_granularity}")
174+
175+
assert is_sm_at_least_89() or is_MI300(), (
176+
"Float8 dynamic quantization requires CUDA compute capability ≥8.9 or MI300+."
177+
)

0 commit comments

Comments
 (0)