Fix QAT range learning, ensure scales get gradients #2280
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2280
⏳ No Failures, 1 Pending as of commit 9584596 with merge base a2c5ca1. This comment was automatically generated by Dr. CI and updates every 15 minutes.
@andrewor14 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
looks like there are some CI failures
**Summary:** The previous `_GenericFakeQuantized` nulled all gradients except the ones for the input. This is problematic for range learning because scales and zero points are now `nn.Parameter`s and actually require gradients. This commit fixes this by reducing the scope of the `autograd.Function` to `torch.round` only, so QAT can just call the fake quantization primitives directly.

Note: Part of the dequantize math currently casts the inputs and the zero points to int32. However, autograd doesn't work with integer math, and this part of the code path is now visible to autograd. To make this work, this commit also removes this dtype cast.

Note: This change means we no longer do cachemask, so our numerics no longer match those of pytorch/pytorch's fake quantization ops.

**Test Plan:** Updated the following test to check that scales and weights are updated:

`python test/quantization/test_qat.py -k test_qat_range_learning`
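For illustration, here is a minimal sketch of the idea described above, assuming a per-tensor affine scheme; the names `_RoundSTE` and `fake_quantize` are hypothetical and not the actual torchao APIs. Limiting the custom `autograd.Function` to `torch.round` (with a straight-through gradient) and keeping the dequantize math in floating point lets autograd reach the learnable `scale` and `zero_point`:

```python
import torch


class _RoundSTE(torch.autograd.Function):
    """Round with a straight-through estimator: identity gradient."""

    @staticmethod
    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
        return torch.round(x)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
        # Pass the gradient through unchanged; do not null out any inputs.
        return grad_output


def fake_quantize(x, scale, zero_point, quant_min, quant_max):
    # Only the round op goes through the custom autograd.Function.
    q = _RoundSTE.apply(x / scale) + zero_point
    q = torch.clamp(q, quant_min, quant_max)
    # Dequantize in floating point (no int32 cast), so autograd can
    # differentiate with respect to scale and zero_point as well.
    return (q - zero_point) * scale


# Range learning: scale and zero point are trainable parameters.
x = torch.randn(8, 16)
scale = torch.nn.Parameter(torch.tensor(0.1))
zero_point = torch.nn.Parameter(torch.tensor(0.0))
out = fake_quantize(x, scale, zero_point, quant_min=-128, quant_max=127)
out.sum().backward()
print(scale.grad is not None, zero_point.grad is not None)  # True True
```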