[quantization] use channel scales for w4a8 + misc fixes #23570
Conversation
assert bias is None, "bias not supported by CUTLASS W4A8"
c = self.config
w_q, w_s, _, _ = self._get_weight_params(layer)
w_ch_s = layer.weight_chan_scale
I feel this is better than modifying every place `self._get_weight_params` is called.
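For illustration, here is a minimal sketch of that trade-off. Apart from `weight_chan_scale` and `_get_weight_params`, the names below are hypothetical and not taken from the vLLM source; the point is only that the kernel-specific parameter is read directly off the layer, so the shared helper's signature (and every other caller of it) stays untouched.

```python
# Hypothetical sketch; not actual vLLM code.
import torch
from types import SimpleNamespace


def _get_weight_params(layer):
    # Shared helper with an unchanged signature, so existing callers
    # do not need to be modified.
    return layer.weight_packed, layer.weight_scale, None, None


def apply_w4a8(layer, x, bias=None):
    assert bias is None, "bias not supported by CUTLASS W4A8"
    w_q, w_s, _, _ = _get_weight_params(layer)
    # The W4A8-specific per-channel scale is read directly off the layer
    # instead of widening the return tuple for every kernel.
    w_ch_s = layer.weight_chan_scale
    return w_q, w_s, w_ch_s  # placeholder for the actual kernel call


# Stand-in layer just to show the access pattern end to end.
layer = SimpleNamespace(
    weight_packed=torch.zeros(16, 4, dtype=torch.int32),
    weight_scale=torch.ones(16, 2, dtype=torch.float32),
    weight_chan_scale=torch.ones(16, 1, dtype=torch.float32),
)
apply_w4a8(layer, torch.randn(1, 16))
```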
Agreed; we eventually need to refactor MPLinearKernel to get rid of _get_weight_params, which seems possible with torch 2.8.0, where we may not need to re-register the params in process_weights_after_loading and can just map everything to consistent names in MPLinearKernel.
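A hedged sketch of that direction (everything here is hypothetical and not tied to the actual MPLinearKernel interface): keep the parameters as loaded and only record a mapping to consistent internal names, instead of deleting and re-registering them.

```python
# Hypothetical sketch of the refactor direction discussed above; not vLLM code.
class MPLinearKernelSketch:
    # Kernel-internal name -> parameter name as registered on the layer.
    PARAM_NAMES = {
        "w_q": "weight_packed",
        "w_s": "weight_scale",
        "w_ch_s": "weight_chan_scale",
    }

    def process_weights_after_loading(self, layer):
        # Rather than re-registering parameters under kernel-specific
        # names, remember where each one lives under a consistent name.
        self.params = {
            internal: getattr(layer, registered, None)
            for internal, registered in self.PARAM_NAMES.items()
        }
```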
Nice work! Thanks for the contribution!
@LucasWilkinson can you attach the …
Purpose
Load per-channel scales for w4a8. This can recover the quality drop from naively casting bf16 scales to fp8 on certain benchmarks (MMLU Pro). Computationally, this is 'free' since the previous implementation used `torch.ones` as a placeholder.

The `fp8` group scales and `fp32` per-channel scales can be generated as a post-processing step once we have a w4a16 checkpoint with `bf16` scales. For testing we used an ad hoc workflow to generate the checkpoint, but will look at integrating it into `llm-compressor` as a next step (a rough sketch follows below):

1. `fp8_scales, fp32_chan_scales = quantfp8(bf16_scales)`
2. divide `fp8_scales` by 8 to avoid saturation when multiplied by int4
3. multiply `fp32_chan_scales` by 8 to compensate

Then pass the adjusted `fp8_scales`/`fp32_chan_scales` to the `w4a8` kernel.
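A rough sketch of that post-processing step. `quantfp8` here is a stand-in for whatever FP8 quantization routine is actually used; only the divide-by-8 / multiply-by-8 adjustment mirrors the description above, and the shapes are illustrative.

```python
import torch


def quantfp8(bf16_scales: torch.Tensor):
    # Stand-in: fold the per-channel dynamic range into an fp32 scale so
    # the remaining group scales fit comfortably in fp8.
    chan_scales = bf16_scales.abs().amax(dim=-1, keepdim=True).to(torch.float32)
    fp8_scales = (bf16_scales.to(torch.float32) / chan_scales).to(torch.float8_e4m3fn)
    return fp8_scales, chan_scales


# [out_channels, num_groups] bf16 group scales from a w4a16 checkpoint.
bf16_scales = torch.rand(4096, 32, dtype=torch.bfloat16)

fp8_scales, fp32_chan_scales = quantfp8(bf16_scales)

# Divide the fp8 group scales by 8 so the scale * int4 product cannot
# saturate fp8, and multiply the per-channel scales by 8 to compensate.
fp8_scales = (fp8_scales.to(torch.float32) / 8).to(torch.float8_e4m3fn)
fp32_chan_scales = fp32_chan_scales * 8

# The adjusted fp8_scales / fp32_chan_scales are what the w4a8 kernel consumes.
```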
Test Plan
- lm-eval (gsm8k, MMLU Pro), comparing against w4a16 and the previous w4a8 implementation (Cohere Command A)
- also add an example model, `czhu-cohere/TinyLlama-1.1B-Chat-v1.0-W4A8-e2e`, and a corresponding test in `tests/quantization/test_compressed_tensors.py` (a rough test sketch follows below)
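A rough sketch of what such a test could look like; the real test should follow whatever conventions `tests/quantization/test_compressed_tensors.py` already uses, and the test name and assertion here are hypothetical.

```python
# Hypothetical smoke test; not the actual test added in this PR.
import pytest
from vllm import LLM, SamplingParams


@pytest.mark.parametrize(
    "model", ["czhu-cohere/TinyLlama-1.1B-Chat-v1.0-W4A8-e2e"])
def test_w4a8_channel_scales_load_and_generate(model):
    # Loading the W4A8 example checkpoint exercises the per-channel scale
    # loading path end to end; a non-empty generation is the smoke check.
    llm = LLM(model=model)
    outputs = llm.generate(["Hello, my name is"],
                           SamplingParams(max_tokens=8))
    assert outputs and outputs[0].outputs[0].text
```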
Test Result
(Optional) Documentation Update