Fix QAT range learning, ensure scales get gradients #2280

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged: 1 commit, Jun 3, 2025
63 changes: 20 additions & 43 deletions test/quantization/test_qat.py
@@ -49,7 +49,6 @@
from torchao.quantization.qat.utils import (
_fake_quantize_per_channel_group,
_fake_quantize_per_token,
_GenericFakeQuantize,
_get_qmin_qmax,
)
from torchao.quantization.quant_api import (
@@ -585,42 +584,6 @@ def test_qat_8da4w_quantizer_gradients(self):
quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=16)
self._test_qat_quantized_gradients(quantizer)

@unittest.skipIf(
not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
)
def test_qat_generic_fake_quantize(self):
"""
Test that the generic fake quantize used in 8da4w QAT matches
the numerics of existing fake quantize ops in Pytorch in both
the forward and the backward passes.
"""
(qmin, qmax) = _get_qmin_qmax(4)
py_input = torch.randn(16, 64).float().requires_grad_()
py_s = torch.randn(16).float()
py_zp = torch.randint(qmax, size=(16,), dtype=torch.int32)
py_out = torch.fake_quantize_per_channel_affine(
py_input, py_s, py_zp, 0, qmin, qmax
)
py_out.sum().backward()

ao_input = copy.deepcopy(py_input)
ao_input.grad.data.zero_()
block_size = (1, ao_input.shape[-1])
ao_s = copy.deepcopy(py_s)
ao_zp = copy.deepcopy(py_zp)
ao_out = _GenericFakeQuantize.apply(
ao_input, block_size, ao_s, ao_zp, qmin, qmax
)
ao_out.sum().backward()

torch.testing.assert_close(py_out, ao_out, atol=0, rtol=0)

# Test that gradients are close enough
num_grads = py_input.grad.numel()
num_equal_grads = torch.eq(py_input.grad, ao_input.grad).flatten().sum().item()
num_equal_grad_threshold = 0.8
self.assertGreaterEqual(num_equal_grads / num_grads, num_equal_grad_threshold)

def _assert_close_4w(self, val, ref):
# Note: for int4 weight-only quantization, we do not expect exact match
# because torch._weight_int4pack_mm and torch.mm do not match exactly.
@@ -1700,16 +1663,30 @@ def test_qat_range_learning(self):
m(*example_inputs)

# Simulate training
num_steps = 10
optimizer = torch.optim.SGD(
m.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5
)
loss_fn = torch.nn.CrossEntropyLoss()
target = torch.randn(1, 512).float()
out = m(*example_inputs)
loss = loss_fn(out, target)
optimizer.zero_grad()
loss.backward()
optimizer.step()
for i in range(num_steps):
prev_scale = copy.deepcopy(m.linear1.weight_fake_quantizer.scale)
prev_weight = copy.deepcopy(m.linear1.weight)
optimizer.zero_grad()
target = torch.randn(1, 512).float()
out = m(*example_inputs)
loss = loss_fn(out, target)
loss.backward()
optimizer.step()
# Assert that scales have valid gradients and are being updated
new_scale = m.linear1.weight_fake_quantizer.scale
self.assertIsNotNone(new_scale.grad)
self.assertNotEqual(torch.count_nonzero(new_scale.grad), 0)
self.assertFalse(torch.equal(new_scale, prev_scale))
# Assert that weights have valid gradients and are being updated
new_weight = m.linear1.weight
self.assertIsNotNone(new_weight.grad)
self.assertNotEqual(torch.count_nonzero(new_weight.grad), 0)
self.assertFalse(torch.equal(new_weight, prev_weight))


if __name__ == "__main__":
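The rewritten `test_qat_range_learning` above runs ten optimizer steps and, on every step, asserts that both the fake quantizer's `scale` and the linear weight have non-None, nonzero gradients and actually change after `optimizer.step()`. The snippet below is a minimal, self-contained sketch of that pattern; `RoundSTE` and `fake_quant` are illustrative names, not torchao APIs, so the gradient flow to a learnable scale can be seen in isolation.

```python
import torch

class RoundSTE(torch.autograd.Function):
    """Round in the forward pass, pass the gradient straight through in backward."""

    @staticmethod
    def forward(ctx, x):
        return torch.round(x)

    @staticmethod
    def backward(ctx, gy):
        return gy


def fake_quant(x, scale, qmin=-8, qmax=7):
    # Quantize with a differentiable round, clamp to the int4 range, dequantize.
    q = torch.clamp(RoundSTE.apply(x / scale), qmin, qmax)
    return q * scale


weight = torch.randn(16, 16, requires_grad=True)
scale = torch.full((16, 1), 0.5, requires_grad=True)  # learnable per-row scale
optimizer = torch.optim.SGD([weight, scale], lr=1e-3)

for _ in range(10):
    prev_scale = scale.detach().clone()
    optimizer.zero_grad()
    out = torch.randn(4, 16) @ fake_quant(weight, scale).t()
    loss = out.pow(2).mean()
    loss.backward()
    # The scale participates in the forward pass, so it receives a gradient...
    assert scale.grad is not None and torch.count_nonzero(scale.grad) > 0
    optimizer.step()
    # ...and is updated by the optimizer, which is what the test verifies.
    assert not torch.equal(scale.detach(), prev_scale)
```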
11 changes: 6 additions & 5 deletions torchao/quantization/qat/affine_fake_quantized_tensor.py
@@ -16,11 +16,11 @@
choose_qparams_affine,
choose_qparams_affine_dont_preserve_zero,
choose_qparams_affine_tinygemm,
fake_quantize_affine,
)
from torchao.utils import TorchAOBaseTensor

from .utils import (
_GenericFakeQuantize,
_UnwrapAffineFakeQuantizedTensor,
)

@@ -90,14 +90,15 @@ def apply_fake_quant_fn(t: torch.Tensor):
scale_dtype,
zero_point_dtype,
)
fq = _GenericFakeQuantize.apply(
fq = fake_quantize_affine(
t,
block_size,
scale,
zero_point,
qmin,
qmax,
zero_point_domain,
quant_dtype=torch.int32,
quant_min=qmin,
quant_max=qmax,
zero_point_domain=zero_point_domain,
)
return fq

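Here `apply_fake_quant_fn` now calls the public `fake_quantize_affine` primitive instead of the private `_GenericFakeQuantize` autograd function, passing `quant_dtype=torch.int32` and keyword `quant_min` / `quant_max` / `zero_point_domain`. Because the primitive itself is differentiable after this change (its rounding goes through `_Round`), no custom backward is needed for gradients to reach the scale. Below is a hedged usage sketch of the same call pattern; the shapes and quantization bounds are illustrative, not taken from the PR.

```python
import torch

from torchao.quantization.quant_primitives import fake_quantize_affine

x = torch.randn(8, 64, requires_grad=True)
scale = torch.full((8, 1), 0.05, requires_grad=True)  # one scale per row
zero_point = torch.zeros(8, 1, dtype=torch.int32)     # integer zero point

fq = fake_quantize_affine(
    x,
    (1, 64),                 # block_size: each row is one quantization block
    scale,
    zero_point,
    quant_dtype=torch.int32,
    quant_min=-8,            # int4 bounds, matching _get_qmin_qmax(4)
    quant_max=7,
)
fq.sum().backward()
print(x.grad.shape, scale.grad.shape)  # both populated; the scale is learnable
```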
2 changes: 1 addition & 1 deletion torchao/quantization/qat/fake_quantizer.py
@@ -17,6 +17,7 @@
_DTYPE_TO_BIT_WIDTH,
_DTYPE_TO_QVALUE_BOUNDS,
MappingType,
_Round,
choose_qparams_affine,
)
from torchao.quantization.utils import (
@@ -31,7 +32,6 @@
from .utils import (
_fake_quantize_per_channel_group,
_fake_quantize_per_token,
_Round,
)


83 changes: 11 additions & 72 deletions torchao/quantization/qat/utils.py
@@ -4,68 +4,19 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import List

import torch

from torchao.quantization.quant_primitives import (
ZeroPointDomain,
fake_quantize_affine_cachemask,
fake_quantize_affine,
)
from torchao.quantization.utils import (
_get_per_token_block_size,
)


class _GenericFakeQuantize(torch.autograd.Function):
"""
Implementation of generic fake quantize with backward STE.

With the appropriate input tensor shape, this can be used to express
grouped per channel fake quantize or per token fake quantize.
"""

@staticmethod
def forward(
ctx: torch.autograd.function.FunctionCtx,
input: torch.Tensor,
block_size: List[int],
scales: torch.Tensor,
zero_points: torch.Tensor,
quant_min: int,
quant_max: int,
zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
) -> torch.Tensor:
# avoid circular dependencies
from torchao.quantization.qat.affine_fake_quantized_tensor import (
AffineFakeQuantizedTensor,
)

if isinstance(input, AffineFakeQuantizedTensor):
_input = input.original_tensor
else:
_input = input

(fq, mask) = fake_quantize_affine_cachemask(
_input,
block_size,
scales,
zero_points,
torch.int32,
quant_min,
quant_max,
zero_point_domain,
)

ctx.save_for_backward(mask)
return fq

@staticmethod
def backward(ctx, gy):
(mask,) = ctx.saved_tensors
return gy * mask, None, None, None, None, None, None


# TODO: delete?
class _UnwrapAffineFakeQuantizedTensor(torch.autograd.Function):
"""
Helper autograd function to unwrap `AffineFakeQuantizedTensor` while ensuring
@@ -91,20 +42,6 @@ def backward(ctx, gy):
return (gy,)


class _Round(torch.autograd.Function):
"""
Implementation of generic round operation with backward STE.
"""

@staticmethod
def forward(ctx, x: torch.Tensor) -> torch.Tensor:
return torch.round(x)

@staticmethod
def backward(ctx, gy: torch.Tensor) -> torch.Tensor:
return gy


def _fake_quantize_per_channel_group(
input: torch.Tensor,
scales: torch.Tensor,
@@ -118,14 +55,15 @@ def _fake_quantize_per_channel_group(
assert input.shape[-1] % group_size == 0
assert input.dim() == 2
block_size = (1, group_size)
return _GenericFakeQuantize.apply(
return fake_quantize_affine(
input,
block_size,
scales,
zero_points,
quant_min,
quant_max,
zero_point_domain,
quant_dtype=torch.int32,
quant_min=quant_min,
quant_max=quant_max,
zero_point_domain=zero_point_domain,
)


@@ -140,13 +78,14 @@ def _fake_quantize_per_token(

_per_token_quant_qparam_dim_check(input, scales, zero_points)
block_size = _get_per_token_block_size(input)
fq = _GenericFakeQuantize.apply(
fq = fake_quantize_affine(
input,
block_size,
scales,
zero_points,
quant_min,
quant_max,
quant_dtype=torch.int32,
quant_min=quant_min,
quant_max=quant_max,
)
return fq.reshape_as(input).to(input.dtype)

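Both helpers above now delegate to `fake_quantize_affine`; the main difference between them is the `block_size` they pass, which tells the primitive how many elements share one `(scale, zero_point)` pair. A small sketch of the two conventions follows; the shapes are illustrative, and `per_token_block_size` mirrors what `_get_per_token_block_size` is assumed to return.

```python
import torch

def per_token_block_size(x: torch.Tensor):
    # One (scale, zero_point) pair per token: the block spans only the last dim.
    return [1] * (x.dim() - 1) + [x.shape[-1]]

activations = torch.randn(2, 5, 64)        # (batch, seq_len, hidden)
print(per_token_block_size(activations))   # [1, 1, 64] -> 2 * 5 = 10 pairs

weights = torch.randn(32, 128)             # (out_features, in_features)
group_size = 32
block_size = (1, group_size)               # per-channel-group quantization
print(block_size)                          # 32 rows * 4 groups = 128 pairs
```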
41 changes: 27 additions & 14 deletions torchao/quantization/quant_primitives.py
@@ -212,6 +212,20 @@
register_custom_op = _register_custom_op(quant_lib)


class _Round(torch.autograd.Function):
"""
Implementation of generic round operation with backward STE.
"""

@staticmethod
def forward(ctx, x: torch.Tensor) -> torch.Tensor:
return torch.round(x)

@staticmethod
def backward(ctx, gy: torch.Tensor) -> torch.Tensor:
return gy


# TODO: decide on if we want to allow custom quant_min/quant_max here
def _get_and_check_qmin_qmax(dtype, quant_min, quant_max):
"""Get quant_min and quant_max args based on dtype and also
@@ -407,7 +421,7 @@ def _quantize_affine_no_dtype_cast(
zero_point = None

quant = torch.clamp(
torch.round(input * (1.0 / scale)) + zero_point, quant_min, quant_max
_Round.apply(input * (1.0 / scale)) + zero_point, quant_min, quant_max
)
quant = quant.view(original_shape)

@@ -493,7 +507,7 @@ def _quantize_affine_float_zero_point_no_dtype_cast(

mid_point = (quant_max + quant_min + 1) / 2
min_val = zero_point - scale * mid_point
quant = torch.clamp(torch.round((input - min_val) / scale), quant_min, quant_max)
quant = torch.clamp(_Round.apply((input - min_val) / scale), quant_min, quant_max)
quant = quant.view(original_shape)

return quant
@@ -577,7 +591,7 @@ def _quantize_affine_no_zero_point_no_dtype_cast(
# with numel=0 which we handle by unifying the two
zero_point = None

quant = torch.clamp(torch.round(input * (1.0 / scale)), quant_min, quant_max)
quant = torch.clamp(_Round.apply(input * (1.0 / scale)), quant_min, quant_max)
quant = quant.view(original_shape)

return quant
@@ -692,10 +706,9 @@ def _dequantize_affine_no_dtype_check(

# Force a copy to avoid input modification due
# to upcoming in-place operations.
dequant = input.to(torch.int32, copy=True)
dequant = input.to(output_dtype, copy=True)
if zero_point is not None:
dequant = dequant - zero_point.to(torch.int32)
dequant = dequant.to(output_dtype)
dequant = dequant - zero_point.to(output_dtype)
dequant = dequant * scale

return dequant.view(original_shape).to(output_dtype)
@@ -1202,7 +1215,7 @@ def choose_qparams_affine_dont_preserve_zero(
scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)
scale = torch.clamp(scale, min=eps)
# Zero point is int
zero_point = quant_min - torch.round(min_val_neg / scale)
zero_point = quant_min - _Round.apply(min_val_neg / scale)
zero_point = torch.clamp(zero_point, quant_min, quant_max)
if zero_point_dtype is None:
zero_point_dtype = torch.int32
@@ -1308,7 +1321,7 @@ def choose_qparams_affine_with_min_max(
if zero_point_domain == ZeroPointDomain.NONE:
zero_point = None
elif zero_point_domain == ZeroPointDomain.INT:
zero_point = quant_min - torch.round(min_val_neg / scale)
zero_point = quant_min - _Round.apply(min_val_neg / scale)
zero_point = torch.clamp(zero_point, quant_min, quant_max)
if zero_point_dtype is None:
zero_point_dtype = torch.int32
@@ -1400,7 +1413,7 @@ def _choose_qparams_affine(
assert mapping_type == MappingType.ASYMMETRIC.name
scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)
scale = torch.clamp(scale, min=eps)
zero_point = quant_min - torch.round(min_val_neg / scale)
zero_point = quant_min - _Round.apply(min_val_neg / scale)
zero_point = torch.clamp(zero_point, quant_min, quant_max)
if zero_point_dtype is None:
zero_point_dtype = torch.int32
@@ -1434,7 +1447,7 @@ def choose_qparams_and_quantize_affine_qqq(
s_group *= 2 / max_q_val # 2 => symmetric

# Quantize
q_w = torch.round(w / s_group).int()
q_w = _Round.apply(w / s_group).int()
q_w += half_q_val
q_w = torch.clamp(q_w, 0, max_q_val)
# Compute ref (dequantized)
@@ -1467,7 +1480,7 @@ def reshape_w(w):
s_channel /= max_q_val

# Quantize
q_w = torch.round(w / s_channel).int()
q_w = _Round.apply(w / s_channel).int()
q_w = torch.clamp(q_w, -max_q_val, max_q_val)
# Compute ref (dequantized)
w_ref = q_w.half() * s_channel
@@ -1871,7 +1884,7 @@ def choose_qparams_and_quantize_affine_hqq(

# Round zero as in: https://github.com/casper-hansen/AutoAWQ/blob/main/awq/quantize/quantizer.py#L42C9-L42C14
if nbits in [4]:
zero = torch.round(zero)
zero = _Round.apply(zero)

# Fine-tune weights
if optimize:
@@ -1887,7 +1900,7 @@
else:
zero = zero.to(compute_dtype)
scale = scale.to(compute_dtype)
W_q = torch.round(W * scale + zero).clamp(min_max[0], min_max[1])
W_q = _Round.apply(W * scale + zero).clamp(min_max[0], min_max[1])

# Store meta-data (we invert the scale for dequantization)
scale = 1.0 / scale
@@ -2004,7 +2017,7 @@ def choose_qparams_affine_float8(
if scale_dtype is not torch.float32:
# Shielding for Version > 2.8
assert scale_dtype is torch.float8_e8m0fnu, "Only float8_e8m0fnuz is supported"
scale = torch.exp2(torch.round(torch.log2(scale)))
scale = torch.exp2(_Round.apply(torch.log2(scale)))
return scale.to(dtype=torch.float32)


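The core of this file's change is the new `_Round` autograd function and the replacement of `torch.round` with `_Round.apply` in the qparam and quantize paths. `torch.round` has a zero derivative almost everywhere, so any scale or zero point computed through it is cut off from gradient flow; the straight-through estimator passes the upstream gradient through unchanged. A minimal sketch of the difference follows; the import assumes `_Round` is available from `quant_primitives`, as added in this PR.

```python
import torch

from torchao.quantization.quant_primitives import _Round

x = torch.randn(8)

scale = torch.tensor(0.1, requires_grad=True)
torch.round(x / scale).sum().backward()
print(scale.grad)    # tensor(0.): torch.round kills the gradient

scale = torch.tensor(0.1, requires_grad=True)
_Round.apply(x / scale).sum().backward()
print(scale.grad)    # nonzero: -x.sum() / scale**2 flows straight through
```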