
Enabling EmbeddingQuantizer and SharedEmbeddingQuantizer #1525


Open · wants to merge 9 commits into base: main

Conversation

dillondesilva

Overview

This PR enables the use of EmbeddingQuantizer and SharedEmbeddingQuantizer as quantization configuration options.

Running lintrunner appears to have changed several lines in this file. However, the edits made specifically to enable these new experimental quantizers can be found on the following lines:

  • Lines 46-49: Imports for EmbeddingQuantizer and SharedEmbeddingQuantizer
  • Lines 202-234: Logic for setting EmbeddingQuantizer and SharedEmbeddingQuantizer options
  • Lines 1033-1034: Entries mapping the quantization config types to the corresponding quantizers (see the sketch below)
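
For context, here is a minimal sketch of what that wiring might look like. The import path and the exact shape of quantizer_class_dict are assumptions (they depend on the pinned torchao version and on the surrounding torchchat code); the "experimental:embedding" / "experimental:shared" keys are the config strings used later in this thread.

# Illustrative sketch only, not the literal diff.
from torchao.experimental.quant_api import (  # import path may vary by torchao version
    EmbeddingQuantizer,
    SharedEmbeddingQuantizer,
)

# quantizer_class_dict maps --quantize JSON keys to quantizer classes.
quantizer_class_dict = {
    # ... existing torchchat entries ...
    "experimental:embedding": EmbeddingQuantizer,
    "experimental:shared": SharedEmbeddingQuantizer,
}

With that mapping in place, the new quantizers can be selected from the CLI, e.g. --quantize '{"experimental:embedding": {"bitwidth": 4, "groupsize": 32, "has_weight_zeros": true}}'.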


pytorch-bot bot commented Apr 13, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchchat/1525

Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 New Failures, 1 Cancelled Job, 1 Unrelated Failure

As of commit 5dbf1ab with merge base a37b08a:

NEW FAILURES - The following jobs have failed:

CANCELLED JOB - The following job was cancelled. Please retry:

BROKEN TRUNK - The following job failed but was already failing on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Meta Open Source bot. label Apr 13, 2025
@Jack-Khuu
Contributor

Looks like the imports aren't happy. I wonder if we need a torchao pin bump?
Wanna give that a try?

weight_dtype = getattr(torch, f"int{bit_width}")

try:
    quantize_(
        model,
        int8_dynamic_activation_intx_weight(
            weight_dtype=weight_dtype,
            granularity=granularity,
Contributor

@metascroy Apr 15, 2025


granularity => weight_granularity
has_weight_zeros=True => weight_mapping_type=MappingType.ASYMMETRIC
has_weight_zeros=False => weight_mapping_type=MappingType.SYMMETRIC
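
Applying that rename, the surrounding call would read roughly as follows. This is only a sketch based on the suggestion above: the weight_dtype / granularity / has_weight_zeros variables come from the quoted hunk, and the MappingType import path may differ depending on the pinned torchao version.

from torchao.quantization.quant_primitives import MappingType  # path may vary by torchao version

# Translate the old has_weight_zeros flag into the new mapping-type argument.
weight_mapping_type = (
    MappingType.ASYMMETRIC if has_weight_zeros else MappingType.SYMMETRIC
)

quantize_(
    model,
    int8_dynamic_activation_intx_weight(
        weight_dtype=weight_dtype,
        weight_granularity=granularity,
        weight_mapping_type=weight_mapping_type,
    ),
)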

@@ -154,45 +170,86 @@ def quantize_model(
print("Encountered error during quantization: {e}")
print("Trying with PlainLayout")
Contributor

Use QDQLayout instead

@metascroy
Contributor

Looks like the imports aren't happy. I wonder if we need a torchao pin bump? Wanna give that a try?

Yeah, you will need to update the torchao pin to something more recent (just pick the latest commit in torchao): https://github.com/pytorch/torchchat/blob/main/install/.pins/torchao-pin.txt

@dillondesilva
Author

@Jack-Khuu I think it's ready to be merged (hopefully haha). Thanks so much for helping out - really appreciate the support from you and Scott :)

Contributor

@Jack-Khuu left a comment

Looks legit, just some small style nits

@metascroy can you give it a glance?

Comment on lines 71 to 85

import inspect


def get_named_parameters(func: Callable) -> List[str]:
    # Get the signature of the function

    signature = inspect.signature(func)

    # Extract the parameters from the signature

    parameters = signature.parameters

    # Filter and return named parameters

Contributor

Mind undoing the whitespaces?

@@ -110,23 +125,73 @@ def quantize_model(

    if isinstance(quantize_options, str):
        quantize_options = json.loads(quantize_options)

    for quantizer, q_kwargs in quantize_options.items():
        if quantizer not in quantizer_class_dict:
            raise RuntimeError(f"unknown quantizer {quantizer} specified")
        else:
            # Use tensor subclass API for int4 weight only.
Contributor

Whoops, comment got split off from: if (device in ["cuda", "xpu", "npu"]) and quantizer == "linear:int4":

# default setup for affine quantization of activations

Contributor

Ditto on the white spaces :)

@metascroy
Contributor

The main concern I have is that shared embedding quantization must be done first. Not sure how to ensure that ordering in torchchat. cc @Jack-Khuu
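
One way to guarantee that ordering (a hypothetical sketch, not something in this PR: the helper below and its name are assumptions, and the "experimental:shared" key is the config string used elsewhere in this thread) would be to sort the parsed options so the shared-embedding entry is processed before everything else:

def iter_quantize_options(quantize_options: dict):
    # Handle the shared-embedding config first so SharedEmbeddingQuantizer
    # still sees the original tied embedding/unembedding weights.
    return sorted(
        quantize_options.items(),
        key=lambda kv: 0 if kv[0] == "experimental:shared" else 1,
    )

for quantizer, q_kwargs in iter_quantize_options(quantize_options):
    # existing per-quantizer handling goes here
    ...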

@Jack-Khuu
Contributor

Jack-Khuu commented Apr 25, 2025

Last little bit: can you add the new quants into the CI https://github.com/pytorch/torchchat/blob/main/.github/workflows/pull.yml

Essentially wherever you see embedding:wx in pull.yml, just add another call in that same test, but using your new quants instead

@metascroy
Contributor

Last little bit: can you add the new quants into the CI https://github.com/pytorch/torchchat/blob/main/.github/workflows/pull.yml

Essentially wherever you see embedding:wx in pull.yml, just add another call in that same test, but using your new quants instead

To really test shared embedding, you need to test a model that has embeddings shared with unembeddings. stories110M (currently used in CI) is not one of them.

Some examples: llama1B, llama3B, phi4-mini, etc.

@Jack-Khuu
Contributor

Hmmm ok let's do this @dillondesilva you can ignore the comments about testing for now
(just hit the style nits + quant order comment)

We'll make a different PR for testing, since it's a tad more involved

Comment on lines 133 to 157
if quantizer == "experimental:embedding":
group_size = q_kwargs["groupsize"]
bit_width = q_kwargs["bitwidth"]
has_weight_zeros = q_kwargs["has_weight_zeros"]
weight_granularity = (
PerAxis() if group_size == -1 else PerGroup(group_size)
)
weight_dtype = getattr(torch, f"int{bit_width}")
weight_mapping_type = (
MappingType.ASYMMETRIC
if has_weight_zeros
else MappingType.SYMMETRIC
)

try:
model = EmbeddingQuantizer(
weight_dtype=weight_dtype,
granularity=weight_granularity,
mapping_type=weight_mapping_type,
use_fallback=False,
).quantize(model)
except Exception as e:
print(
"Encountered error during quantization with experimental EmbeddingQuantization: {e}"
)

Why would you not put this into the EmbeddingQuantizer class, or a subclass derived from EmbeddingQuantizer?
At a minimum, lines 134-146 seem to be copy-pasta that's replicated multiple times, e.g., right below at L159-171, and 195-204, 241-252, ...

Author

@dillondesilva May 12, 2025

Good point! Were you thinking of wrapping the quantizers from torchao (as opposed to directly updating them) for our specific use case, kind of like this?

class TCQuantizer:
    def __init__(self, q_kwargs, quantizer):
        self.q_kwargs = q_kwargs
        self.quantizer = quantizer_class_dict[quantizer]

    def quantize(self, model):
        group_size = self.q_kwargs["groupsize"]
        bit_width = self.q_kwargs["bitwidth"]
        has_weight_zeros = self.q_kwargs["has_weight_zeros"]
        # Other configuration code
        try:
            model = self.quantizer(
                weight_dtype=weight_dtype,
                granularity=weight_granularity,
                mapping_type=weight_mapping_type,
                use_fallback=False,
            ).quantize(model)
        except Exception as e:
            print(
                f"Encountered error during quantization with quantizer: {e}"
            )
        return model

embedding_quantizer = TCQuantizer(q_kwargs, "EmbeddingQuantizer")

@dillondesilva
Author

Hmmm ok let's do this @dillondesilva you can ignore the comments about testing for now (just hit the style nits + quant order comment)

We'll make a different PR for testing, since it's a tad more involved

Yep no worries - sounds like a plan! I'll hop onto these changes soon.

@dillondesilva
Author

Hmmm ok let's do this @dillondesilva you can ignore the comments about testing for now (just hit the style nits + quant order comment)

We'll make a different PR for testing, since it's a tad more involved

Sounds good! Just addressed the style nits + quant order here

@Jack-Khuu
Contributor

Thanks @dillondesilva I'll try to review (and hopefully merge) this today 😃

@Jack-Khuu
Contributor

Jack-Khuu commented May 13, 2025

Digging into this a bit:

python3 torchchat.py generate llama3.2-1b --dtype float16 --quantize '{"experimental:shared": {"bitwidth": 4, "groupsize": 32, "has_weight_zeros": true}}' --prompt "Once upon a time,"

Seems to need a little bit of debugging. Mind tracing through this? Left some initial findings:

  • I think at the end of the conditionals (e.g. if quantizer == "experimental:shared"), it needs a "continue" call, since those branches use a slightly different signature than the normal quantizer calls (see the sketch below)
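
For illustration, a minimal sketch of that control flow, reusing the setup from the experimental:embedding branch quoted earlier (variable names and the exact placement inside quantize_model are assumptions, not the literal diff):

for quantizer, q_kwargs in quantize_options.items():
    if quantizer == "experimental:shared":
        group_size = q_kwargs["groupsize"]
        weight_granularity = PerAxis() if group_size == -1 else PerGroup(group_size)
        weight_dtype = getattr(torch, f"int{q_kwargs['bitwidth']}")
        weight_mapping_type = (
            MappingType.ASYMMETRIC
            if q_kwargs["has_weight_zeros"]
            else MappingType.SYMMETRIC
        )
        model = SharedEmbeddingQuantizer(
            weight_dtype=weight_dtype,
            granularity=weight_granularity,
            mapping_type=weight_mapping_type,
        ).quantize(model)
        continue  # skip the generic handling below; this API takes different arguments

    # generic path (quantizer_class_dict lookup, etc.) continues here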

        ).quantize(model)
    except Exception as e:
        print(
            "Encountered error during quantization with experimental SharedEmbeddingQuantization: {e}"
Contributor

Suggested change
"Encountered error during quantization with experimental SharedEmbeddingQuantization: {e}"
f"Encountered error during quantization with experimental SharedEmbeddingQuantization: {e}"

        ).quantize(model)
    except Exception as e:
        print(
            "Encountered error during quantization with experimental EmbeddingQuantization: {e}"
Contributor

Suggested change
"Encountered error during quantization with experimental EmbeddingQuantization: {e}"
f"Encountered error during quantization with experimental EmbeddingQuantization: {e}"

weight_dtype=weight_dtype,
granularity=weight_granularity,
mapping_type=weight_mapping_type,
use_fallback=False,
Contributor

Suggested change
use_fallback=False,

Contributor

SharedEmbeddingQuantizer doesn't have a fallback arg

Contributor

@Jack-Khuu left a comment

.
