Support dynamic int8 quantization for Gemma #1612
Conversation
I've heard from @awsaf49 that this PR worked nicely to significantly reduce the memory footprint for running Gemma 7B on a single 16GB GPU. Ref:
@james77777778 thanks very much for this! Sorry for the delay reviewing; I started poking around last week and should be done later today.
OK done finally! Thanks for starting this off, this is a super important feature for the library I think.
I have some concrete suggestions for the reversible_embedding.py diff, but for the modeling changes I'm still not sure exactly what we want. I've identified some problems below that I'm not sure how to solve. I'll keep thinking here, but please jump in with thoughts too!
@@ -31,11 +31,15 @@ def __init__(
         dropout=0,
         **kwargs,
     ):
+        dtype_policies = kwargs.pop("dtype_policies", None) or {}
How does dtype_policies work? Can you link relevant code or an explainer?
@@ -129,6 +132,7 @@ def __init__(
                 num_key_value_heads=num_key_value_heads,
                 dropout=dropout,
                 dtype=dtype,
+                dtype_policies=dtype_policies.get(f"decoder_block_{i}"),
How come we don't need a default here in the get call?
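To make the lookup pattern concrete: dict.get returns None when a layer has no override, and passing dtype=None makes the layer fall back to the global dtype policy, so no explicit default is needed in the get call. A minimal, self-contained sketch with plain Dense layers standing in for Gemma's decoder blocks (illustrative only, not the PR's actual code):

import keras

# Global default for layers without an override.
keras.config.set_dtype_policy("bfloat16")

# Per-layer overrides keyed by layer name, mirroring the `dtype_policies`
# dict in the diff above.
dtype_policies = {"decoder_block_0": "float32"}

layers = []
for i in range(2):
    name = f"decoder_block_{i}"
    # `dict.get` returns None when there is no entry; passing dtype=None
    # simply falls back to the global policy set above.
    layers.append(
        keras.layers.Dense(8, name=name, dtype=dtype_policies.get(name))
    )

print([layer.dtype_policy.name for layer in layers])  # ['float32', 'bfloat16']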
@@ -186,6 +191,17 @@ def get_config(self):
                 "dropout": self.dropout,
             }
         )
+        # Ensure the serialization of the dtype policies
+        dtype_policies = {
I don't think this will work as is for new uploads of models to Kaggle and elsewhere...
Currently, we use keras.saving.serialize_keras_object(backbone) to save the config we upload to Kaggle. We want a few things for that config...
- For it to be fairly minimal, so avoiding a nested dict structure always getting stuck in there might be good.
- To allow casting of float dtypes automatically...
# Set the float dtype of the model, cast weights if necessary on load.
keras.config.set_floatx("bfloat16")
backbone = keras_nlp.models.GemmaBackbone.from_preset("...")
# Set the float dtype of the model, cast weights if necessary on load.
keras.config.set_floatx("float32")
backbone = keras_nlp.models.GemmaBackbone.from_preset("...")
I'm not sure how exactly this should work for quantized saves. For a quantized checkpoint, we probably don't want to allow any automatic casting of quantized weights. It wouldn't really make sense.
This code as is would eventually break our Kaggle uploads, because new config.json files would always include a complete dtype spec in the form of a nested dtype policy object. That means there would be no concise way to load a float32 checkpoint at bfloat16 or vice versa (something we currently rely on quite a bit for guides, etc.).
Not sure exactly what we want here though. Let's think it through.
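For reference, serialize_keras_object simply records whatever get_config() returns, so any nested dtype policy spec emitted there is exactly what would land in the uploaded config.json. A toy, runnable illustration with a plain layer standing in for the backbone (not KerasNLP code):

import json

import keras

# A plain layer standing in for the backbone; the mechanism is the same.
layer = keras.layers.Dense(4, dtype="bfloat16")
config = keras.saving.serialize_keras_object(layer)

# Everything `get_config()` returns ends up here, so a full per-layer dtype
# spec would be baked into every uploaded config.json.
print(json.dumps(config, indent=2, default=str))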
             }
         )
+        # Ensure the serialization of the dtype policies
+        dtype_policies = {
Ideally, we land this not just for one model like Gemma, but for all KerasNLP models. We want our model-specific code to be fairly concise when possible, and this would add a lot of extra code to every model we have.
Let's think if there's a way we can do this where most of the logic is consolidated into base classes like backbone.py.
Hi @mattdangerw
I hadn't thought of that before. Thanks for pointing it out. To prevent the complete dtype spec from being uploaded, we can add logic to detect whether the model is quantized: include the dtype spec if it's quantized and skip it otherwise. What do you think?
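To sketch what that detection could look like (a rough, hedged example, not the code that eventually landed): treat the model as quantized if any of its variables is stored as int8, and only then have a shared base class emit the per-layer policies.

import keras


def is_quantized(model):
    # Int8 quantization stores kernels (and their scales) as int8 variables,
    # while a float checkpoint has none, so this is a simple detection heuristic.
    return any(str(v.dtype) == "int8" for v in model.weights)


class Backbone(keras.Model):  # illustrative base class, not KerasNLP's actual one
    def get_config(self):
        config = super().get_config()
        # Only quantized saves carry the full dtype spec; float checkpoints keep
        # a minimal config so they can still be loaded at any floatx.
        if is_quantized(self):
            config["dtype_policies"] = {
                layer.name: layer.dtype_policy.name for layer in self.layers
            }
        return config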
We could incorporate the detection I mentioned above. I had some discussions with @fchollet about this. The main challenge with subclasses is identifying which layers (and their sublayers) are quantized and enabling serialization/deserialization for them.
Yeah I think this is the right call? We also want to allow users to save their own quantized versions, upload them, and share them where others could get the quantized config. It feels slightly awkward/implicit, but I think it's the right way to preserve our current usage while still leaving room for what we want for quantization. Long-winded way to say: sounds good to me, let's try it :) I still want to think about how we can add this support to every model in the library without as much of a code diff, but I haven't had time to think on that question yet.
Great, now that we've reached a consensus, these changes shouldn't take long.
I don't have a better idea at the moment, and I think that a refactor for dtype policy control in Core Keras might be necessary, especially for the subclasses.
I will continue this PR once we have a new release of Keras. There is an update in ...
"a refactor for dtype policy control in Core Keras might be necessary, especially for the subclasses. My 2c: now is the time to do it as KerasNLP is still a fairly new library. As usage picks up, things will only get harder. |
Actually, I don't have a concrete idea for the refactoring at the moment, but I can point out the issue. Maybe we can implicitly pass a dict called ...
I think at the very least, we would probably want to handle this at the ...

A half-formed thought is that this is kinda similar to our distribution API...

layout_map = keras.distribution.LayoutMap(device_mesh)
layout_map["token_embedding/embeddings"] = ...
layout_map["decoder_block.*attention.*(query|key|value).*kernel"] = ...
layout_map["decoder_block.*attention_output.*kernel"] = ...
distribution = keras.distribution.ModelParallel(device_mesh, layout_map)
# Globally.
keras.distribution.set_distribution(distribution)
# Or locally.
with distribution.scope():
    ...

A somewhat parallel API might be...

dtype_map = keras.dtype_policies.DtypeMap(default="bfloat16")
dtype_map["token_embedding/embeddings"] = "float32"
dtype_map["decoder_block.*attention.*(query|key|value).*kernel"] = "int8"
policy = keras.dtype_policies.DTypePolicy(dtype_map)
# Globally.
keras.dtype_policies.set_dtype_policy(policy)
# Or locally.
model = SomeModel(dtype=policy)

Then the init logic can stay simple...

self.sublayer = SomeLayer(..., dtype=self.dtype_policy)

Then we'd have one policy that we could pass around that gives a whole mapping of dtypes all the way down. That also gives a relatively quick way to specify things like a dtype for all query projections, say. It would still take some figuring out during saving, etc. @james77777778 wdyt?
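For context, the global and per-layer halves of this already exist in Keras 3; the path-to-policy map is the new piece being proposed. A small runnable contrast (illustrative layer choices only):

import keras

# Global default policy, analogous to the proposed DtypeMap(default=...).
keras.config.set_dtype_policy("bfloat16")

# Local per-layer override, analogous to a single entry of the proposed map;
# today it must be passed layer by layer rather than matched by a path regex.
embedding = keras.layers.Embedding(100, 8, dtype="float32")
dense = keras.layers.Dense(8)  # picks up the global "bfloat16" policy

print(embedding.dtype_policy.name, dense.dtype_policy.name)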
Hi @mattdangerw
I think this idea is great and it should be easy to implement. I have changed the value type in the mapping from "dtype" to "dtype_policy". This adjustment should make more sense, because we rely on the policy for the behavior of the quantized layers. In that PR, the saving and loading issues have been solved. It wasn't a difficult one :)
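A quick runnable illustration of why the mapping should carry a dtype policy rather than a bare dtype string: a policy splits variable and compute dtypes (and, for quantized layers, selects the quantization behavior), which a single string cannot express.

import keras

layer = keras.layers.Dense(4, dtype="mixed_bfloat16")

# A single dtype string cannot capture this split; the policy can.
print(layer.dtype)                       # variable dtype, e.g. "float32"
print(layer.dtype_policy.compute_dtype)  # compute dtype, e.g. "bfloat16"
print(layer.dtype_policy.name)           # full policy name, "mixed_bfloat16"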
I have submitted a new PR for this: ... I'm closing this PR now.
I reopen this PR, which originated from #1591, due to the API generation issue.

Notes:
- dtype_policy for each layer in subclasses
- ReversibleEmbedding is imported by the choice of config.keras_3(). Is there a better way to support both Keras 2 and Keras 3?
- keras.distribution. Can someone provide guidance?

Here are some numbers:

Model outputs:

Standalone script:

python3 gemma_int8.py --model gemma_1.1_instruct_2b_en --save
python3 gemma_int8.py --model gemma_1.1_instruct_2b_en
# Use CPU for the following commands
python3 gemma_int8.py --model gemma_1.1_instruct_7b_en --save
python3 gemma_int8.py --model gemma_1.1_instruct_7b_en

gemma_int8.py
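The attached gemma_int8.py is not reproduced in this thread; below is a hedged reconstruction of what such a driver might look like, matching the flags in the commands above (the save path, prompt, and overall flow are assumptions, not the author's script):

import argparse

import keras
import keras_nlp


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", default="gemma_1.1_instruct_2b_en")
    parser.add_argument("--save", action="store_true")
    args = parser.parse_args()

    path = f"{args.model}_int8.keras"  # hypothetical output path
    if args.save:
        # Load the float preset, quantize weights to int8 in place, and save.
        model = keras_nlp.models.GemmaCausalLM.from_preset(args.model)
        model.quantize("int8")
        model.save(path)
    else:
        # Reload the quantized model and run a quick generation to compare
        # outputs and memory use against the float checkpoint.
        model = keras.saving.load_model(path)
        print(model.generate("What is Keras?", max_length=64))


if __name__ == "__main__":
    main()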