GemmaCausalLM fails to load if TensorFlow NumPy behavior is enabled #2136

Open
@t-kalinowski

Describe the bug
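Calling GemmaCausalLM.from_preset() raises a ValueError during weight loading if TensorFlow's NumPy behavior is enabled with tf.experimental.numpy.experimental_enable_numpy_behavior(). The failure does not appear to depend on the Keras backend: it occurs with both the TensorFlow and JAX backends.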

To Reproduce

Given the script bug2.py:

# /// script
# dependencies = [
#   "keras",
#   "keras-hub",
#   "tensorflow"
# ]
# ///


import tensorflow as tf
tf.experimental.numpy.experimental_enable_numpy_behavior(dtype_conversion_mode="safe")

import os
os.environ["KERAS_BACKEND"] = "tensorflow"

import keras
keras.config.set_dtype_policy("float16")

import json
with open(os.path.expanduser("~/.kaggle/kaggle.json"), "r") as f:
    kaggle_credentials = json.load(f)
os.environ["KAGGLE_USERNAME"] = kaggle_credentials["username"]
os.environ["KAGGLE_KEY"] = kaggle_credentials["key"]

import keras_hub

gemma_lm = keras_hub.models.GemmaCausalLM.from_preset("gemma_2b_en")

Calling uv run --python 3.11 bug2.py produces:

tomasz@tomaszkalinows-WQVX deep_learning_with_r_3e % uv run --python 3.11 bug2.py
WARNING:tensorflow:UserWarning: enabling the new type promotion must happen at the beginning of the program. Please ensure no TF APIs have been used yet.
Traceback (most recent call last):
  File "/Users/tomasz/github/t-kalinowski/deep_learning_with_r_3e/bug2.py", line 27, in <module>
    gemma_lm = keras_hub.models.GemmaCausalLM.from_preset("gemma_2b_en")
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/tomasz/.cache/uv/environments-v2/bug2-977c5b9b87b33b67/lib/python3.11/site-packages/keras_hub/src/models/task.py", line 198, in from_preset
    return loader.load_task(cls, load_weights, load_task_weights, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/tomasz/.cache/uv/environments-v2/bug2-977c5b9b87b33b67/lib/python3.11/site-packages/keras_hub/src/utils/preset_utils.py", line 670, in load_task
    return super().load_task(
           ^^^^^^^^^^^^^^^^^^
  File "/Users/tomasz/.cache/uv/environments-v2/bug2-977c5b9b87b33b67/lib/python3.11/site-packages/keras_hub/src/utils/preset_utils.py", line 618, in load_task
    kwargs["backbone"] = self.load_backbone(
                         ^^^^^^^^^^^^^^^^^^^
  File "/Users/tomasz/.cache/uv/environments-v2/bug2-977c5b9b87b33b67/lib/python3.11/site-packages/keras_hub/src/utils/preset_utils.py", line 648, in load_backbone
    backbone.load_weights(get_file(self.preset, MODEL_WEIGHTS_FILE))
  File "/Users/tomasz/.cache/uv/environments-v2/bug2-977c5b9b87b33b67/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py", line 122, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/Users/tomasz/.cache/uv/environments-v2/bug2-977c5b9b87b33b67/lib/python3.11/site-packages/keras/src/saving/saving_lib.py", line 631, in _raise_loading_failure
    raise ValueError(msg)
ValueError: A total of 127 objects could not be loaded. Example error message for object <ReversibleEmbedding name=token_embedding, built=True>:

Layer 'token_embedding' expected 1 variables, but received 0 variables during loading. Expected: ['embeddings']

List of objects that could not be loaded:
[<ReversibleEmbedding name=token_embedding, built=True>, <EinsumDense name=key, built=True>, <EinsumDense name=attention_output, built=True>, <EinsumDense name=query, built=True>, <EinsumDense name=value, built=True>, <EinsumDense name=ffw_linear, built=True>, <EinsumDense name=ffw_gating, built=True>, <EinsumDense name=ffw_gating_2, built=True>, <EinsumDense name=key, built=True>, <EinsumDense name=attention_output, built=True>, <EinsumDense name=query, built=True>, <EinsumDense name=value, built=True>, <EinsumDense name=ffw_linear, built=True>, <EinsumDense name=ffw_gating, built=True>, <EinsumDense name=ffw_gating_2, built=True>, <EinsumDense name=key, built=True>, <EinsumDense name=attention_output, built=True>, <EinsumDense name=query, built=True>, <EinsumDense name=value, built=True>, <EinsumDense name=ffw_linear, built=True>, <EinsumDense name=ffw_gating, built=True>, <EinsumDense name=ffw_gating_2, built=True>, <EinsumDense name=key, built=True>, <EinsumDense name=attention_output, built=True>, <EinsumDense name=query, built=True>, <EinsumDense name=value, built=True>, <EinsumDense name=ffw_linear, built=True>, <EinsumDense name=ffw_gating, built=True>, <EinsumDense name=ffw_gating_2, built=True>, <EinsumDense name=key, built=True>, <EinsumDense name=attention_output, built=True>, <EinsumDense name=query, built=True>, <EinsumDense name=value, built=True>, <EinsumDense name=ffw_linear, built=True>, <EinsumDense name=ffw_gating, built=True>, <EinsumDense name=ffw_gating_2, built=True>, <EinsumDense name=key, built=True>, <EinsumDense name=attention_output, built=True>, <EinsumDense name=query, built=True>, <EinsumDense name=value, built=True>, <EinsumDense name=ffw_linear, built=True>, <EinsumDense name=ffw_gating, built=True>, <EinsumDense name=ffw_gating_2, built=True>, <EinsumDense name=key, built=True>, <EinsumDense name=attention_output, built=True>, <EinsumDense name=query, built=True>, <EinsumDense name=value, built=True>, <EinsumDense name=ffw_linear, 
built=True>, <EinsumDense name=ffw_gating, built=True>, <EinsumDense name=ffw_gating_2, built=True>, <EinsumDense name=key, built=True>, <EinsumDense name=attention_output, built=True>, <EinsumDense name=query, built=True>, <EinsumDense name=value, built=True>, <EinsumDense name=ffw_linear, built=True>, <EinsumDense name=ffw_gating, built=True>, <EinsumDense name=ffw_gating_2, built=True>, <EinsumDense name=key, built=True>, <EinsumDense name=attention_output, built=True>, <EinsumDense name=query, built=True>, <EinsumDense name=value, built=True>, <EinsumDense name=ffw_linear, built=True>, <EinsumDense name=ffw_gating, built=True>, <EinsumDense name=ffw_gating_2, built=True>, <EinsumDense name=key, built=True>, <EinsumDense name=attention_output, built=True>, <EinsumDense name=query, built=True>, <EinsumDense name=value, built=True>, <EinsumDense name=ffw_linear, built=True>, <EinsumDense name=ffw_gating, built=True>, <EinsumDense name=ffw_gating_2, built=True>, <EinsumDense name=key, built=True>, <EinsumDense name=attention_output, built=True>, <EinsumDense name=query, built=True>, <EinsumDense name=value, built=True>, <EinsumDense name=ffw_linear, built=True>, <EinsumDense name=ffw_gating, built=True>, <EinsumDense name=ffw_gating_2, built=True>, <EinsumDense name=key, built=True>, <EinsumDense name=attention_output, built=True>, <EinsumDense name=query, built=True>, <EinsumDense name=value, built=True>, <EinsumDense name=ffw_linear, built=True>, <EinsumDense name=ffw_gating, built=True>, <EinsumDense name=ffw_gating_2, built=True>, <EinsumDense name=key, built=True>, <EinsumDense name=attention_output, built=True>, <EinsumDense name=query, built=True>, <EinsumDense name=value, built=True>, <EinsumDense name=ffw_linear, built=True>, <EinsumDense name=ffw_gating, built=True>, <EinsumDense name=ffw_gating_2, built=True>, <EinsumDense name=key, built=True>, <EinsumDense name=attention_output, built=True>, <EinsumDense name=query, built=True>, <EinsumDense 
name=value, built=True>, <EinsumDense name=ffw_linear, built=True>, <EinsumDense name=ffw_gating, built=True>, <EinsumDense name=ffw_gating_2, built=True>, <EinsumDense name=key, built=True>, <EinsumDense name=attention_output, built=True>, <EinsumDense name=query, built=True>, <EinsumDense name=value, built=True>, <EinsumDense name=ffw_linear, built=True>, <EinsumDense name=ffw_gating, built=True>, <EinsumDense name=ffw_gating_2, built=True>, <EinsumDense name=key, built=True>, <EinsumDense name=attention_output, built=True>, <EinsumDense name=query, built=True>, <EinsumDense name=value, built=True>, <EinsumDense name=ffw_linear, built=True>, <EinsumDense name=ffw_gating, built=True>, <EinsumDense name=ffw_gating_2, built=True>, <EinsumDense name=key, built=True>, <EinsumDense name=attention_output, built=True>, <EinsumDense name=query, built=True>, <EinsumDense name=value, built=True>, <EinsumDense name=ffw_linear, built=True>, <EinsumDense name=ffw_gating, built=True>, <EinsumDense name=ffw_gating_2, built=True>, <EinsumDense name=key, built=True>, <EinsumDense name=attention_output, built=True>, <EinsumDense name=query, built=True>, <EinsumDense name=value, built=True>, <EinsumDense name=ffw_linear, built=True>, <EinsumDense name=ffw_gating, built=True>, <EinsumDense name=ffw_gating_2, built=True>]
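The list above is hard to read at a glance. As a minimal sketch (the helper name `tally_failed_layers` is hypothetical and not part of Keras or KerasHub), the failed objects can be tallied by class and layer name from the captured error text. If the pattern visible in the list holds, the total of 127 would break down as 1 `token_embedding` plus 7 `EinsumDense` names repeated 18 times, which suggests that every listed weight-bearing layer received no variables rather than a single layer being at fault.

```python
import re
from collections import Counter

# Matches entries such as "<EinsumDense name=query, built=True>".
ENTRY = re.compile(r"<(\w+) name=(\w+), built=True>")

# Text that precedes the list in the ValueError message shown above.
MARKER = "List of objects that could not be loaded:"


def tally_failed_layers(message):
    """Count (class name, layer name) pairs listed in a loading-failure message."""
    # Only look at the list itself, so the "Example error message" line
    # is not counted a second time.
    _, _, listing = message.partition(MARKER)
    # Terminal wrapping can split an entry across lines; collapse whitespace.
    flat = re.sub(r"\s+", " ", listing)
    return Counter(ENTRY.findall(flat))
```

Passing the captured stderr text to `tally_failed_layers` and summing the counts should reproduce the total reported in the first line of the error, which is a quick way to check that the list was captured in full.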

A similar error happens if KERAS_BACKEND='jax' is configured.
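To compare backends without editing the script each time, a small runner along these lines could be used. This is only a sketch under assumptions: the helper names `backend_env` and `run_repro` are hypothetical, the script is launched with the current Python interpreter rather than through uv, and the hard-coded `os.environ["KERAS_BACKEND"] = "tensorflow"` line in the script would have to be removed (or changed to `os.environ.setdefault(...)`) so that the value set by the runner takes effect.

```python
import os
import subprocess
import sys

# The two backends mentioned in this report.
BACKENDS = ("tensorflow", "jax")


def backend_env(backend, base=None):
    """Return a copy of the environment with KERAS_BACKEND set to `backend`."""
    env = dict(os.environ if base is None else base)
    env["KERAS_BACKEND"] = backend
    return env


def run_repro(script, backend):
    """Run `script` under one backend.

    Returns (exit code, whether the loading-failure text appears in stderr).
    """
    proc = subprocess.run(
        [sys.executable, script],
        env=backend_env(backend),
        capture_output=True,
        text=True,
    )
    return proc.returncode, "could not be loaded" in proc.stderr
```

Calling `run_repro("bug2.py", backend)` for each name in `BACKENDS` gives one result per backend, so a failure that shows up under both can be told apart from one that is specific to a single backend.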

Labels: Gemma (Gemma model specific issues), type:Bug (Something isn't working)
