Support dynamic int8 quantization for Gemma #1612


Closed
wants to merge 4 commits

Conversation

james77777778
Collaborator

I've reopened this PR, which originated from #1591, due to the API generation issue.

Notes:

  • These patches might seem inelegant, but I've struggled to find a better way to pass dtype_policy to each layer in subclasses.
  • ReversibleEmbedding is imported based on config.keras_3(). Is there a better way to support both Keras 2 and Keras 3?
  • I'm not sure how to properly test the quantized Gemma for keras.distribution. Can someone provide guidance?

Here are some numbers:

| Model | Mem. (bfloat16) | Mem. (int8) | Weights (kagglehub) | Weights (int8) | Notes |
| --- | --- | --- | --- | --- | --- |
| "gemma_1.1_instruct_2b_en" | 5.69GB | 2.82GB | 4.7GB | 2.4GB | |
| "gemma_1.1_instruct_7b_en" | 18.97GB | 8.69GB | 16.0GB | 8.0GB | Run on CPU |

Model outputs:

# "gemma_1.1_instruct_2b_en" int8 version
What is Keras?

Keras is an open-source machine learning library and framework that provides a high-level interface for building and training deep learning models. It is built on top of TensorFlow, allowing users to leverage the vast resources and capabilities of the TensorFlow ecosystem.

**Key features of Keras:**

- High-level API for building and training models
- Support for a wide range of deep learning algorithms
- Optimized for performance and scalability
- Integration with TensorFlow ecosystem for seamless data loading and processing


**Benefits of using Keras:**

- **Simplified model building:** Keras provides a user-friendly interface for constructing deep learning

# "gemma_1.1_instruct_7b_en" int8 version
What is Keras?

**Keras** is a high-level API for TensorFlow and other machine learning libraries. It provides a user-friendly and modular interface for building, training, and evaluating deep learning models. Keras is designed to be accessible to beginners and experienced ML engineers alike.

**Key features of Keras:**

- **Modular design:** Allows for easy composition of different layers and models.
- **TensorFlow compatibility:** Leverages the power of TensorFlow for backend computation.
- **Python API:** Written in Python, making it easy to use and integrate with other Python libraries.
- **Wide range of layers:**

Standalone script:

python3 gemma_int8.py --model gemma_1.1_instruct_2b_en --save
python3 gemma_int8.py --model gemma_1.1_instruct_2b_en

# Use CPU for the following commands
python3 gemma_int8.py --model gemma_1.1_instruct_7b_en --save
python3 gemma_int8.py --model gemma_1.1_instruct_7b_en
gemma_int8.py
import argparse
import json
import os
import pathlib

import keras
import psutil
import tensorflow as tf

import keras_nlp

# Set up Kaggle credentials
os.environ["KAGGLE_USERNAME"] = "xxxxx"
os.environ["KAGGLE_KEY"] = "xxxxx"


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model",
        default="gemma_1.1_instruct_2b_en",
        choices=["gemma_1.1_instruct_2b_en", "gemma_1.1_instruct_7b_en"],
        help="Which model to demonstrate",
    )
    parser.add_argument(
        "--path",
        default=".",
        help="Path to save and load the model",
    )
    parser.add_argument(
        "--save",
        action="store_true",
        help="Quantize and save the model",
    )
    args = parser.parse_args()
    return args


def get_memory_usage():
    # Report peak GPU:0 memory if available, otherwise fall back to process RSS (CPU)
    try:
        memory_stats = tf.config.experimental.get_memory_info("GPU:0")
        peak_usage = memory_stats["peak"] / (2**30)
    except Exception:
        memory_usage = psutil.Process().memory_info().rss
        peak_usage = memory_usage / (2**30)
    return peak_usage


def save_int8_model(
    preset: str, model: keras_nlp.models.GemmaCausalLM, path: pathlib.Path
):
    model.quantize("int8")
    model.summary()
    # Save config
    config = keras.saving.serialize_keras_object(model)
    with open(path / f"{preset}_int8.json", "w") as f:
        f.write(json.dumps(config))
    # Save weights
    model.save_weights(path / f"{preset}_int8.weights.h5")


def load(config_path: pathlib.Path, weights_path: pathlib.Path, preset: str):
    # Load by config file
    with open(config_path, "r") as f:
        config = json.loads(f.read())
    model: keras_nlp.models.GemmaCausalLM = (
        keras.saving.deserialize_keras_object(config)
    )
    # Load weights
    model.load_weights(weights_path)
    # Load preset assets
    model.preprocessor.tokenizer.load_preset_assets(preset)
    return model


if __name__ == "__main__":
    keras.config.set_dtype_policy("bfloat16")
    x = keras.ops.ones([1]) * keras.ops.ones([1])  # Trigger TF dummy logs

    args = get_args()
    path = pathlib.Path(args.path)
    print(f"Peak memory usage (init): {get_memory_usage():.3f} GB")

    # Save
    if args.save:
        model = keras_nlp.models.GemmaCausalLM.from_preset(args.model)
        model.summary()
        print(
            "Peak memory usage (loaded float model): "
            f"{get_memory_usage():.3f} GB"
        )
        save_int8_model(args.model, model, path)
    # Load
    else:
        config_path = path / f"{args.model}_int8.json"
        weights_path = path / f"{args.model}_int8.weights.h5"
        model = load(config_path, weights_path, args.model)
        print(
            "Peak memory usage (loaded int8 model): "
            f"{get_memory_usage():.3f} GB"
        )

        print(model.generate("What is Keras?", max_length=128))

@github-actions github-actions bot added the Gemma Gemma model specific issues label May 2, 2024
@mattdangerw mattdangerw self-requested a review May 2, 2024 03:08
@james77777778
Collaborator Author

I've heard from @awsaf49 that this PR worked nicely to significantly reduce the memory footprint for running Gemma 7B on a single 16GB GPU.
If we distribute the model across 2 GPUs, we can further increase the max_length in model.generate to 1024.

Ref:
https://www.kaggle.com/code/awsaf49/gemma-1-1-7b-int8-load
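
For reference, a minimal sketch of the kind of two-GPU setup described above, using keras.distribution's ModelParallel and GemmaBackbone.get_layout_map; the mesh shape, preset, and prompt are illustrative assumptions and not taken from the linked notebook:

import keras
import keras_nlp

# Sketch only: shard the Gemma weights across 2 GPUs along the "model" axis.
device_mesh = keras.distribution.DeviceMesh(
    shape=(1, 2),
    axis_names=["batch", "model"],
    devices=keras.distribution.list_devices(),
)
layout_map = keras_nlp.models.GemmaBackbone.get_layout_map(device_mesh)
distribution = keras.distribution.ModelParallel(
    device_mesh, layout_map, batch_dim_name="batch"
)
keras.distribution.set_distribution(distribution)

model = keras_nlp.models.GemmaCausalLM.from_preset("gemma_1.1_instruct_7b_en")
print(model.generate("What is Keras?", max_length=1024))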

@mattdangerw
Member

@james77777778 thanks very much for this! Sorry for the delay reviewing, started poking around last week and should be done later today.

Member

@mattdangerw mattdangerw left a comment

OK done finally! Thanks for starting this off, this is a super important feature for the library I think.

I have some concrete suggestions for the reversible_embedding.py diff, but for the modeling changes I'm still not sure exactly what we want. Identified some problems below that I'm not sure how exactly to solve. I'll keep thinking here but please jump in with thoughts too!

@@ -31,11 +31,15 @@ def __init__(
dropout=0,
**kwargs,
):
dtype_policies = kwargs.pop("dtype_policies", None) or {}
Member

How does dtype_policies work? Can you link relevant code or an explainer?

@@ -129,6 +132,7 @@ def __init__(
num_key_value_heads=num_key_value_heads,
dropout=dropout,
dtype=dtype,
dtype_policies=dtype_policies.get(f"decoder_block_{i}"),
Member

@mattdangerw mattdangerw May 6, 2024

How come we don't need a default here in the get call?

@@ -186,6 +191,17 @@ def get_config(self):
"dropout": self.dropout,
}
)
# Ensure the serialization of the dtype policies
dtype_policies = {
Member

I don't think this will work as is for new uploads of models to Kaggle and elsewhere...

Currently, we use keras.saving.serialize_keras_object(backbone) to save the config we upload to Kaggle. We want a few things for that config option...

  1. For it to be fairly minimal, so avoiding a nested dict structure always getting stuck in there might be good.
  2. To allow casting of float dtypes automatically...
# Set the float dtype of the model, cast weights if necessary on load.
keras.config.set_floatx("bfloat16")
backbone = keras_nlp.models.GemmaBackbone.from_preset("...")
# Set the float dtype of the model, cast weights if necessary on load.
keras.config.set_floatx("float32")
backbone = keras_nlp.models.GemmaBackbone.from_preset("...")

I'm not sure how exactly this should work for quantized saves. For a quantized checkpoint, we probably don't want to allow any automatic casting of quantized weights. It wouldn't really make sense.

This code as is would eventually break our Kaggle uploads, because new config.json files would always include a complete dtype spec in the form of a nested dtype policy object, meaning there would be no concise way to load a float32 checkpoint at bfloat16 or vice versa (which is something we are currently relying on quite a bit for guides etc.).

Not sure exactly what we want here though. Let's think it through.

}
)
# Ensure the serialization of the dtype policies
dtype_policies = {
Member

Ideally, we land this not just for one model like Gemma, but for all KerasNLP models. We want our model-specific code to be fairly concise when possible, and this would add a lot of extra code to every model we have.

Let's think if there's a way we can do this where most of the logic is consolidated to base classes like backbone.py.

@james77777778
Collaborator Author

Hi @mattdangerw
Thank you for reviewing.
Using config.keras_3 to branch the logic in ReversibleEmbedding is definitely feasible. I will fix it soon.

How does dtype_policies work? Can you link relevant code or an explainer?

dtype_policies will be a dict[str, dtype_policy]. It is required because, after quantization, we have an int8 EinsumDense mixed with other floating-point layers in CachedGemmaAttention. We need these layers to save and load the correct configuration for each dtype policy.
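
For illustration, here is a minimal toy layer following that pattern (the layer and sublayer names are made up; this is not the actual Gemma code):

import keras


class TinyAttention(keras.layers.Layer):
    """Toy composite layer showing per-sublayer dtype policies."""

    def __init__(self, hidden_dim, **kwargs):
        # Optional mapping of sublayer name -> dtype policy; missing entries
        # fall back to this layer's own dtype policy.
        dtype_policies = kwargs.pop("dtype_policies", None) or {}
        super().__init__(**kwargs)
        self.query_dense = keras.layers.EinsumDense(
            "btd,dh->bth",
            output_shape=(None, hidden_dim),
            dtype=dtype_policies.get("query_dense", self.dtype_policy),
        )

    def build(self, inputs_shape):
        self.query_dense.build(inputs_shape)

    def call(self, inputs):
        return self.query_dense(inputs)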

I don't think this will work as is for new uploads of models to Kaggle and elsewhere...
2. To allow casting of float dtypes automatically...

I haven't thought of it before. Thanks for pointing it out.

To prevent the complete dtype spec from being uploaded, we can add the logic to detect whether it is a quantized model. We'll include the dtype spec if it's quantized and skip it otherwise. What do you think?
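
A rough sketch of that idea (hypothetical class and helper names, not the final KerasNLP implementation) could look like this:

import keras


class QuantizableBackbone(keras.Model):
    """Hypothetical base class, only to illustrate the idea above."""

    def get_config(self):
        config = super().get_config()
        # Collect dtype policies only for sublayers that were quantized, so a
        # plain float checkpoint keeps a minimal config and can still be
        # loaded at a different float dtype. Only direct sublayers are
        # inspected here, for brevity.
        quantized_policies = {
            layer.name: keras.saving.serialize_keras_object(layer.dtype_policy)
            for layer in self.layers
            if "int8" in layer.dtype_policy.name
        }
        if quantized_policies:
            config["dtype_policies"] = quantized_policies
        return config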

Let's think if there's a way we can do this where most of the logic is consolidated to base classes like backbone.py.

We could incorporate the detection I mentioned into backbone.py.
However, it is necessary to add specific logic to each model (and even each layer) to support more fine-grained dtype policy control.

I had some discussions with @fchollet here:
keras-team/keras#19381

The main challenge with subclasses is identifying which layers (and their sublayers) are quantized and enabling serialization/deserialization for them.
We can't apply the setter trick I mentioned in that discussion to Gemma because the model is already built by the time from_config runs.

@mattdangerw
Member

To prevent the complete dtype spec from being uploaded, we can add the logic to detect whether it is a quantized model. We'll include the dtype spec if it's quantized and skip it otherwise. What do you think?

Yeah I think this is the right call? We also want to allow users to save their own quantized versions, upload them and share them where others could get the quantized config.

It feels slightly awkward/implicit, but I think it's the right way to preserve our current usages while still leaving room for what we want for quantization.

Long winded way to say sounds good to me, let's try it :)

Still want to think about how we can add this support to every model in the library without as much of a code diff, but haven't had time to think on that question yet.

@james77777778
Collaborator Author

Long winded way to say sounds good to me, let's try it :)

Great, now that we've reached a consensus, these changes shouldn't take long.

Still want to think about how we can add this support to every model in the library without as much of a code diff, but haven't had time to think on that question yet.

I don't have a better idea at the moment, and I think a refactor of dtype policy control in core Keras might be necessary, especially for the subclasses.
However, considering backward compatibility, this could be a challenging task.

@james77777778
Collaborator Author

I will continue this PR once we have a new release of Keras (>3.3.3).

There is an update to DTypePolicy that makes it possible to have a more flexible quantized dtype policy.
For example, QuantizedDTypePolicy("int8", source_name=None) will infer the source dtype policy from keras.dtype_policies.dtype_policy().
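
Assuming that updated behavior, usage might look roughly like this (a sketch based on the description above, not verified against a specific release):

import keras

# With a global policy set, omitting source_name should let the quantized
# policy infer its source dtype from keras.dtype_policies.dtype_policy().
keras.config.set_dtype_policy("bfloat16")
policy = keras.dtype_policies.QuantizedDTypePolicy("int8")
print(policy)  # expected to resolve to an int8-from-bfloat16 policy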

@martin-gorner
Contributor

"a refactor for dtype policy control in Core Keras might be necessary, especially for the subclasses.
However, considering the backward compatibility, this could be a challenging task."

My 2c: now is the time to do it as KerasNLP is still a fairly new library. As usage picks up, things will only get harder.

@james77777778
Collaborator Author

My 2c: now is the time to do it as KerasNLP is still a fairly new library. As usage picks up, things will only get harder.

Actually, I don't have a concrete idea for the refactoring at the moment.

I can point out the issue:
We rely on a single dtype argument for subclass creation. So, without complex and verbose logic in __init__, it is hard to support the mixing of floating / quantized dtype policies.

Maybe we can implicitly pass a dict called sublayer_policies when a layer contains sublayers to tackle this. However, there is currently no way to automatically map between the policy and the sublayer in __init__. I don't think it's a good idea to require users to hard-code it.

@mattdangerw
Member

mattdangerw commented May 30, 2024

I think at the very least, we would probably want to handle this at the Backbone level. E.g. automatically fill in a dict config for dtype by listing all direct sublayers by name, something like that. Maybe add some common functionality to help with layer construction. This is purely from a code maintainability standpoint--we want adding a new model to this repo to be relatively low friction, and we should factor out common logic where we can. This is a large addition of "diff" for each model.

I can point out the issue: We rely on a single dtype argument for subclass creation. So, without complex and verbose logic in init, it is hard to support the mixing of floating / quantized dtype polices.

A half-formed thought: this is kind of similar to our distribution LayoutMap problem. Basically, for distributing variables across machines, we do this...

layout_map = keras.distribution.LayoutMap(device_mesh)
layout_map["token_embedding/embeddings"] = ...
layout_map["decoder_block.*attention.*(query|key|value).*kernel"] = ...
layout_map["decoder_block.*attention_output.*kernel"] = ...
distribution = keras.distribution.ModelParallel(device_mesh, layout_map)
# Globally.
keras.distribution.set_distribution(distribution)
# Or locally.
with distribution.scope():
   ...

A somewhat parallel API might be...

dtype_map = keras.dtype_policies.DtypeMap(default="bfloat16")
dtype_map["token_embedding/embeddings"] = "float32"
dtype_map["decoder_block.*attention.*(query|key|value).*kernel"] = "int8"
policy = keras.dtype_policies.DTypePolicy(dtype_map)
# Globally.
keras.dtype_policies.set_dtype_policy(policy)
# Or locally.
model = SomeModel(dtype=policy)

Then the __init__ logic can stay simple...

self.sublayer = SomeLayer(..., dtype=self.dtype_policy)

Then we'd have one policy that we could pass around that gives a whole mapping of dtypes all the way down. That also gives a relatively quick way to specify things like a dtype for all query projections, say. Would still take some figuring out during saving, etc.

@james77777778 wdyt?

@james77777778
Collaborator Author

james77777778 commented May 31, 2024

Hi @mattdangerw

I think this idea is great and it should be easy to implement.
Please refer to this PR for more details:
keras-team/keras#19783

I have changed the value type in the mapping from "dtype" to "dtype_policy". This adjustment should make more sense because we rely on that for the behavior of the quantized layers.

In that PR, the saving and loading issues have been solved. It wasn't a difficult one :)
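
For anyone following along, a rough sketch of how such a mapping might be used, assuming the DTypePolicyMap API in keras-team/keras#19783 landed roughly as described (the constructor, key pattern, and regex matching below are assumptions, not confirmed API details):

import keras

# Assumed API: keys are layer paths (regexes), values are dtype policies;
# unmatched layers keep their existing policy.
policy_map = keras.dtype_policies.DTypePolicyMap()
policy_map["decoder_block.*attention.*query"] = (
    keras.dtype_policies.QuantizedDTypePolicy("int8", "bfloat16")
)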

@james77777778
Collaborator Author

I have submitted a new PR for this:
#1670

I'm closing this PR now.

@james77777778 james77777778 deleted the int8-gemma branch August 22, 2024 08:44