
Commit 33c4b27

Add support for fbgemm fp8 kernels
Summary: FP8 per-row weight quantization with FP8 dynamic per-row activation quantization only for now.

Test Plan:

python test/dtypes/test_fbgemm_fp8.py

In the torchao/_models/llama folder:

export CHECKPOINT_PATH=../../../checkpoints  # path to checkpoints folder
export MODEL_REPO=meta-llama/Meta-Llama-3.1-8B-Instruct
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization fbgemm-fp8 --batch_size 1

Reviewers:

Subscribers:

Tasks:

Tags:
1 parent d963a88 commit 33c4b27

8 files changed: +234 −48 lines changed

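Before the per-file diffs, here is a minimal sketch of the new path described in the summary, written against the APIs this commit touches. It assumes an sm90+ GPU with fbgemm-gpu-genai installed and is illustrative rather than code taken from the commit.

import torch
from torchao.float8.config import e4m3_dtype
from torchao.quantization import FbgemmConfig, quantize_

# fp8 per-row weight quantization + dynamic fp8 per-row activation quantization
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
quantize_(linear, FbgemmConfig(e4m3_dtype, e4m3_dtype, torch.bfloat16))

x = torch.randn(1, 128, dtype=torch.bfloat16, device="cuda")
y = linear(x)  # runs the fbgemm fp8 rowwise kernel under the hood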

test/dtypes/test_affine_quantized.py

Lines changed: 7 additions & 0 deletions
@@ -24,7 +24,9 @@
     to_affine_quantized_intx,
     to_affine_quantized_intx_static,
 )
+from torchao.float8.config import e4m3_dtype
 from torchao.quantization import (
+    FbgemmConfig,
     GemliteUIntXWeightOnlyConfig,
     Int4WeightOnlyConfig,
     Int8DynamicActivationInt8WeightConfig,
@@ -45,6 +47,7 @@
     is_fbcode,
     is_ROCM,
     is_sm_at_least_89,
+    is_sm_at_least_90,
 )

 is_cusparselt_available = (
@@ -99,6 +102,10 @@ def get_quantization_functions(
     if is_sm_at_least_89():
         base_functions.append(float8_weight_only())

+    if is_sm_at_least_90():
+        base_functions.append(FbgemmConfig(torch.bfloat16, torch.int4, torch.bfloat16))
+        base_functions.append(FbgemmConfig(e4m3_dtype, e4m3_dtype, torch.bfloat16))
+
     return base_functions


test/dtypes/test_fbgemm_quantized_tensor.py renamed to test/dtypes/test_fbgemm_fp8.py

Lines changed: 5 additions & 9 deletions
@@ -12,32 +12,28 @@
     run_tests,
 )

+from torchao.float8.config import e4m3_dtype
 from torchao.quantization import (
     FbgemmConfig,
     quantize_,
 )
 from torchao.quantization.utils import compute_error
-from torchao.utils import (
-    TORCH_VERSION_AT_LEAST_2_6,
-    is_sm_at_least_90,
-)
+from torchao.utils import is_sm_at_least_90


-class TestFbgemmInt4Tensor(TestCase):
+class TestFbgemmFp8Tensor(TestCase):
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     @unittest.skipIf(not is_sm_at_least_90(), "Nedd sm90+")
-    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "Need torch >= 2.6")
     def test_linear(self):
         dtype = torch.bfloat16
         device = "cuda"
         input = torch.randn(1, 128, dtype=dtype, device=device)
         linear = torch.nn.Linear(128, 256, dtype=dtype, device=device)
         original = linear(input)
         config = FbgemmConfig(
-            input_dtype=torch.bfloat16,
-            weight_dtype=torch.int4,
+            input_dtype=e4m3_dtype,
+            weight_dtype=e4m3_dtype,
             output_dtype=torch.bfloat16,
-            block_size=[1, 128],
         )
         quantize_(linear, config)
         quantized = linear(input)
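The hunk above ends at the quantized forward pass. Given the compute_error (SQNR) import the test keeps, a plausible continuation of the test body might look like the sketch below; the 20 dB threshold is an assumption for illustration, not taken from the file.

        # hypothetical accuracy check following the hunk above; the threshold is illustrative
        sqnr = compute_error(original, quantized)
        self.assertTrue(sqnr > 20)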

test/dtypes/test_fbgemm_quantized.py renamed to test/dtypes/test_fbgemm_int4.py

Lines changed: 1 addition & 1 deletion
@@ -33,7 +33,7 @@ def test_linear(self):
             input_dtype=torch.bfloat16,
             weight_dtype=torch.int4,
             output_dtype=torch.bfloat16,
-            block_size=(1, 128),
+            block_size=[1, 128],
         )
         quantize_(linear, config)
         quantized = linear(input)

torchao/_models/llama/generate.py

Lines changed: 14 additions & 12 deletions
@@ -442,23 +442,25 @@ def ffn_or_attn_only(mod, fqn):
                     f"int4wo group_size needs to be one of [32,64,128,256] but got {group_size}"
                 )
             quantize_(model, int4_weight_only(group_size=group_size, use_hqq=use_hqq))
-        elif "fbgemm" in quantization:
+        elif "fbgemm" in quantization and "int4" in quantization:
             from torchao.quantization import FbgemmConfig

             _, precision, group_size = quantization.split("-")
             group_size = int(group_size)
             block_size = [1, group_size]
-            if precision == "int4":
-                quantize_(
-                    model,
-                    FbgemmConfig(
-                        torch.bfloat16, torch.int4, torch.bfloat16, block_size
-                    ),
-                )
-            else:
-                raise NotImplementedError(
-                    f"FbegemmConfig({precision=}) not supported yet"
-                )
+            assert precision == "int4", f"FbegemmConfig({precision=}) not supported yet"
+            quantize_(
+                model,
+                FbgemmConfig(torch.bfloat16, torch.int4, torch.bfloat16, block_size),
+            )
+        elif "fbgemm" in quantization and "fp8" in quantization:
+            from torchao.float8.config import e4m3_dtype
+            from torchao.quantization import FbgemmConfig
+
+            quantize_(
+                model,
+                FbgemmConfig(e4m3_dtype, e4m3_dtype, torch.bfloat16),
+            )
         elif "int4dq-" in quantization:
             from torchao.dtypes import CutlassInt4PackedLayout
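For reference, the two branches above are selected purely by substring matching on the --quantization flag. A small sketch of how the flag values are interpreted, based on the parsing code in the hunk; the fbgemm-int4-128 string is an assumption inferred from that parsing, while fbgemm-fp8 comes from the test plan.

# "fbgemm-int4-<group_size>" -> existing int4 path
quantization = "fbgemm-int4-128"
_, precision, group_size = quantization.split("-")  # ("fbgemm", "int4", "128")
block_size = [1, int(group_size)]                   # groups of 128 columns per row

# "fbgemm-fp8" -> new fp8 path; no group size needed, since scales are per row,
# so FbgemmConfig(e4m3_dtype, e4m3_dtype, torch.bfloat16) is built without block_size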

torchao/dtypes/__init__.py

Lines changed: 4 additions & 2 deletions
@@ -8,7 +8,8 @@
     to_affine_quantized_intx,
     to_affine_quantized_intx_static,
 )
-from .fbgemm_quantized_tensor import to_fbgemm_quantized
+from .fbgemm_fp8_tensor import to_fbgemm_fp8
+from .fbgemm_int4_tensor import to_fbgemm_int4
 from .floatx import (
     CutlassSemiSparseLayout,
     Float8Layout,
@@ -62,5 +63,6 @@
     "PackedLinearInt8DynamicActivationIntxWeightLayout",
     "to_affine_quantized_packed_linear_int8_dynamic_activation_intx_weight",
     "Int4XPULayout",
-    "to_fbgemm_quantized",
+    "to_fbgemm_int4",
+    "to_fbgemm_fp8",
 ]
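Call sites that previously imported the single helper now pick one of the two exports; a minimal sketch of the assumed usage after this rename:

# before: from torchao.dtypes import to_fbgemm_quantized
from torchao.dtypes import to_fbgemm_fp8, to_fbgemm_int4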

torchao/dtypes/fbgemm_fp8_tensor.py

Lines changed: 160 additions & 0 deletions
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.


from typing import Optional

import torch
from torch.utils._python_dispatch import return_and_correct_aliasing

from torchao.utils import (
    TORCH_VERSION_AT_LEAST_2_5,
    TorchAOBaseTensor,
)

__all__ = [
    "to_fbgemm_fp8",
]

aten = torch.ops.aten


class FbgemmFp8Tensor(TorchAOBaseTensor):
    tensor_data_attrs = ["float8_data", "scale", "activation_scale_ub"]
    tensor_attributes = ["dtype"]

    def __new__(cls, float8_data, scale, activation_scale_ub, dtype):
        shape = float8_data.shape
        kwargs = {}
        kwargs["device"] = float8_data.device
        kwargs["dtype"] = dtype
        kwargs["requires_grad"] = False
        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]

    def __init__(self, float8_data, scale, activation_scale_ub, dtype):
        self.float8_data = float8_data
        self.scale = scale
        self.activation_scale_ub = activation_scale_ub

    def __tensor_flatten__(self):
        return self.tensor_data_attrs, [
            getattr(self, attr) for attr in self.tensor_attributes
        ]

    @classmethod
    def __tensor_unflatten__(
        cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
    ):
        return cls(
            *[tensor_data_dict[name] for name in cls.tensor_data_attrs],
            *tensor_attributes,
        )

    def _apply_fn_to_data(self, fn):
        return self.__class__(
            *[fn(getattr(self, attr)) for attr in self.tensor_data_attrs],
            *[getattr(self, attr) for attr in self.tensor_attributes],
        )

    def __repr__(self):
        return (
            f"{self.__class__.__name__}(weight={self.float8_data}, scale={self.scale}, "
            f"activation_scale_ub={self.activation_scale_ub}, "
            f"shape={self.shape}, device={self.device}, dtype={self.dtype}, requires_grad={self.requires_grad})"
        )

    def _quantization_type(self):
        return f"shape={self.shape}, activation_scale_ub={self.activation_scale_ub}, device={self.device}"

    @classmethod
    def from_float(
        cls,
        w: torch.Tensor,
        activation_scale_ub: Optional[float] = None,
    ):
        if activation_scale_ub is None:
            activation_scale_ub = 1200.0

        activation_scale_ub = torch.tensor(
            [activation_scale_ub],
            dtype=torch.float,
            device=w.device,
        )
        wq, w_scale = torch.ops.triton.quantize_fp8_row(w)
        # wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_row(w)
        dtype = w.dtype
        del w
        return FbgemmFp8Tensor(
            wq,
            w_scale,
            activation_scale_ub=activation_scale_ub,
            dtype=dtype,
        )


implements = FbgemmFp8Tensor.implements


@implements([torch.nn.functional.linear, aten.linear.default])
def _(func, types, args, kwargs):
    input_tensor, weight_tensor, bias = (
        args[0],
        args[1],
        args[2] if len(args) > 2 else None,
    )
    if not input_tensor.is_floating_point():
        raise NotImplementedError(
            f"{func} is not implemented for non floating point input"
        )

    orig_act_size = input_tensor.size()
    orig_out_features = weight_tensor.shape[-2]

    # not used?
    # num_tokens = torch.empty([input_tensor.size(0)], device=input_tensor.device)
    # xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(
    #     input_tensor, num_tokens, weight_tensor.activation_scale_ub
    # )
    xq, x_scale = torch.ops.triton.quantize_fp8_row(
        input_tensor, weight_tensor.activation_scale_ub
    )
    print("xq shape:", xq.shape, " xscale shape:", x_scale.shape)
    print("weight tensor:", weight_tensor.shape)
    res = torch.ops.fbgemm.f8f8bf16_rowwise(
        xq,
        weight_tensor.float8_data,
        x_scale,
        weight_tensor.scale,
        use_fast_accum=True,
    )
    print("res shape:", res.shape)
    res = res.reshape(*orig_act_size[:-1], orig_out_features)
    if bias is not None:
        res = res + bias

    return res


@implements([aten.detach.default, aten.alias.default])
def _(func, types, args, kwargs):
    return return_and_correct_aliasing(
        func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
    )


@implements([aten.clone.default, aten.copy_.default])
def _(func, types, args, kwargs):
    return return_and_correct_aliasing(
        func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
    )


to_fbgemm_fp8 = FbgemmFp8Tensor.from_float


if TORCH_VERSION_AT_LEAST_2_5:
    # Allow a model with FbgemmFp8Tensor weights to be loaded with `weights_only=True`
    torch.serialization.add_safe_globals([FbgemmFp8Tensor])
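A sketch of exercising the new subclass directly, based on the from_float signature and the linear dispatch above. It assumes an sm90+ GPU and fbgemm-gpu-genai providing the triton.quantize_fp8_row and fbgemm.f8f8bf16_rowwise ops; the printed shapes and dtypes are expectations, not verified output.

import torch
from torchao.dtypes import to_fbgemm_fp8

w = torch.randn(256, 128, dtype=torch.bfloat16, device="cuda")
w_fp8 = to_fbgemm_fp8(w)           # per-row fp8 weight; activation_scale_ub defaults to 1200.0
print(w_fp8.float8_data.dtype)     # expected: a float8 e4m3 dtype
print(w_fp8.scale.shape)           # expected: one scale per output row, i.e. (256,)

x = torch.randn(4, 128, dtype=torch.bfloat16, device="cuda")
y = torch.nn.functional.linear(x, w_fp8)  # dispatches to fbgemm.f8f8bf16_rowwise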

torchao/dtypes/fbgemm_quantized_tensor.py renamed to torchao/dtypes/fbgemm_int4_tensor.py

Lines changed: 19 additions & 14 deletions
@@ -11,10 +11,13 @@
 import torch
 from torch.utils._python_dispatch import return_and_correct_aliasing

-from torchao.utils import TorchAOBaseTensor
+from torchao.utils import (
+    TORCH_VERSION_AT_LEAST_2_5,
+    TorchAOBaseTensor,
+)

 __all__ = [
-    "to_fbgemm_quantized",
+    "to_fbgemm_int4",
 ]

 aten = torch.ops.aten
@@ -71,25 +74,22 @@ def __repr__(self):
             f"shape={self.shape}, device={self.device}, dtype={self.dtype}, requires_grad={self.requires_grad})"
         )

+    def _quantization_type(self):
+        return f"shape={self.shape}, group_size={self.group_size}, device={self.device}"
+
     @classmethod
     def from_float(
         cls,
         w: torch.Tensor,
-        input_dtype: torch.dtype,
-        weight_dtype: torch.dtype,
-        output_dtype: torch.dtype,
         block_size: List[int],
     ):
         assert len(block_size) == w.ndim, (
             f"Expecting the length of block_size to be equal to the dimension of the weight, got {block_size=} and {w.ndim=}"
         )
-        group_size = block_size[-1]
+        if int4_row_quantize_zp is None:
+            raise ImportError("Requires fbgemm-gpu-genai >= 1.2.0")

-        assert (input_dtype, weight_dtype, output_dtype) == (
-            torch.bfloat16,
-            torch.int4,
-            torch.bfloat16,
-        )
+        group_size = block_size[-1]

         if w.ndim >= 3:
             wq, scale, zero_point = zip(
@@ -138,9 +138,10 @@ def _(func, types, args, kwargs):
         weight_tensor.scale,
         weight_tensor.zero_point,
     )
+    res = res.reshape(*orig_act_size[:-1], orig_out_features)
     if bias is not None:
         res = res + bias
-    return res.reshape(*orig_act_size[:-1], orig_out_features)
+    return res


 @implements([aten.detach.default, aten.alias.default])
@@ -157,5 +158,9 @@ def _(func, types, args, kwargs):
     )


-# We can have `to_fbgemm_tensor` to dispatch to different Fbgemm tensors later
-to_fbgemm_quantized = FbgemmInt4Tensor.from_float
+to_fbgemm_int4 = FbgemmInt4Tensor.from_float
+
+
+if TORCH_VERSION_AT_LEAST_2_5:
+    # Allow a model with FbgemmInt4Tensor weights to be loaded with `weights_only=True`
+    torch.serialization.add_safe_globals([FbgemmInt4Tensor])
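After this change, from_float takes only the weight and a block_size (the dtype arguments and the bf16/int4 assertion are gone, replaced by an import check for fbgemm-gpu-genai >= 1.2.0). A minimal sketch of the renamed helper under those assumptions, with a CUDA device; the printed output is an expectation for illustration.

import torch
from torchao.dtypes import to_fbgemm_int4

w = torch.randn(256, 128, dtype=torch.bfloat16, device="cuda")
w_int4 = to_fbgemm_int4(w, block_size=[1, 128])  # group_size = block_size[-1] = 128
print(w_int4._quantization_type())               # expected: shape=..., group_size=128, device=cuda:0

x = torch.randn(4, 128, dtype=torch.bfloat16, device="cuda")
y = torch.nn.functional.linear(x, w_int4)        # dispatches to the fbgemm int4 rowwise kernel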
