
Fix QAT range learning, ensure scales get gradients #2280


Merged · 1 commit merged into main · Jun 3, 2025

Conversation

@andrewor14 (Contributor) commented on May 30, 2025:

Summary: The previous `_GenericFakeQuantized` zeroed out all gradients except those for the input. This is problematic for range learning, because scales and zero points are now `nn.Parameter`s and actually require gradients. This commit fixes the issue by reducing the scope of the `autograd.Function` to `torch.round` only, so QAT can call the fake quantization primitives directly.
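
For illustration, here is a minimal sketch of the approach described above (not the actual torchao implementation; `_RoundSTE` and `fake_quantize` are placeholder names): the custom `autograd.Function` wraps only `torch.round` as a straight-through estimator, so gradients can reach learnable scales and zero points through the rest of the fake-quantize math.

```python
import torch

class _RoundSTE(torch.autograd.Function):
    """Round with a straight-through estimator: backward treats round()
    as the identity, so gradients keep flowing past the rounding."""

    @staticmethod
    def forward(ctx, x):
        return torch.round(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Pass gradients through the rounding unchanged.
        return grad_output

def fake_quantize(x, scale, zero_point, quant_min, quant_max):
    # scale and zero_point may be nn.Parameters; everything here except
    # round() is differentiable w.r.t. them, and round() uses the STE above.
    q = _RoundSTE.apply(x / scale) + zero_point
    q = torch.clamp(q, quant_min, quant_max)
    return (q - zero_point) * scale
```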

Note: Part of the dequantize math currently casts the inputs and the zero points to int32. However, autograd does not work with integer math, and this part of the code path is now visible to autograd. To make this work, this commit also removes that dtype cast.
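
A small illustration of why the cast matters (a sketch, not the torchao code; the tensor names are made up): keeping the dequantize math in floating point lets autograd produce gradients for the scale and zero point, whereas integer tensors cannot carry gradients.

```python
import torch

scale = torch.tensor(0.05, requires_grad=True)
zero_point = torch.tensor(2.0, requires_grad=True)
q = torch.tensor([10.0, 20.0, 30.0])  # already-quantized values, kept as float

# All-float dequantize: fully visible to autograd.
dq = (q - zero_point) * scale
dq.sum().backward()

print(scale.grad)       # populated
print(zero_point.grad)  # populated

# The old path cast q and zero_point to int32 before subtracting. Integer
# tensors cannot carry gradients, so that cast breaks the gradient path now
# that this code is visible to autograd; hence it was removed.
```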

Note: This change means we no longer use the cachemask variant, so our numerics no longer match those of pytorch/pytorch's fake quantization ops.

Test Plan:
Updated the following test to check that scales and weights are updated:
python test/quantization/test_qat.py -k test_qat_range_learning
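
As a rough sketch of what such a check could look like (reusing the hypothetical `fake_quantize` from the earlier sketch, not the actual `test_qat_range_learning` test): run one training step and assert that the learnable scale and zero point received gradients and that the scale changed.

```python
import torch

# Assumes the fake_quantize sketch shown earlier is in scope.
scale = torch.nn.Parameter(torch.tensor(0.1))
zero_point = torch.nn.Parameter(torch.tensor(0.0))
x = torch.randn(16)

optimizer = torch.optim.SGD([scale, zero_point], lr=0.1)
scale_before = scale.detach().clone()

# One step of "training" through the fake-quantize path.
out = fake_quantize(x, scale, zero_point, quant_min=-128, quant_max=127)
out.pow(2).sum().backward()
optimizer.step()

assert scale.grad is not None and zero_point.grad is not None
assert not torch.equal(scale.detach(), scale_before)  # the scale was updated
```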

pytorch-bot (bot) commented on May 30, 2025:

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2280

Note: Links to docs will display an error until the docs builds have been completed.

⏳ No Failures, 1 Pending

As of commit 9584596 with merge base a2c5ca1:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

facebook-github-bot added the CLA Signed label on May 30, 2025.
andrewor14 marked this pull request as draft on May 30, 2025, 19:45.
andrewor14 force-pushed the fix-range-learning branch 3 times, most recently from 5f20d3d to 9671274, on May 30, 2025, 20:58.
andrewor14 added the topic: improvement label on May 30, 2025.
andrewor14 requested a review from jerryzh168 on May 30, 2025, 21:04.
andrewor14 marked this pull request as ready for review on May 30, 2025, 21:04.
@facebook-github-bot (Contributor) commented:

@andrewor14 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

andrewor14 force-pushed the fix-range-learning branch from 9671274 to eec9be6 on May 31, 2025, 01:13.

@jerryzh168 (Contributor) commented:

looks like there are some CI failures

andrewor14 marked this pull request as draft on June 2, 2025, 01:44.
andrewor14 force-pushed the fix-range-learning branch from eec9be6 to 8e5e478 on June 2, 2025, 21:00.
andrewor14 marked this pull request as ready for review on June 2, 2025, 21:00.
andrewor14 force-pushed the fix-range-learning branch from 8e5e478 to febe4de on June 2, 2025, 21:03.

andrewor14 force-pushed the fix-range-learning branch 2 times, most recently from c4ad425 to 94344ae, on June 2, 2025, 21:07.

andrewor14 force-pushed the fix-range-learning branch from 94344ae to 1744bfb on June 3, 2025, 14:00.

andrewor14 force-pushed the fix-range-learning branch from 1744bfb to 9584596 on June 3, 2025, 19:06.

andrewor14 merged commit 2ef656e into main on Jun 3, 2025.
19 checks passed