Skip to content

updated ReversibleEmbedding call method to handle proper conversion to tensors. #2295

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: master
Choose a base branch
from

Conversation

pctablet505
Copy link
Collaborator

@pctablet505 pctablet505 commented Jun 10, 2025

What does this PR do?

Ensures that the call method of ReversibleEmbedding always converts its input to a TensorFlow tensor. This change improves compatibility when TensorFlow's NumPy behavior is enabled (tf.experimental.numpy.experimental_enable_numpy_behavior), which can cause type inconsistencies if inputs are not explicitly converted.

Why is this needed?

Previously, when NumPy behavior was enabled, the input to ReversibleEmbedding could be a numpy array rather than a tensor, leading to errors during model inference or weight loading. This fix resolves failures such as those observed in keras-hub#2136 and ensures robust operation regardless of backend configuration.

Related Issues/PRs

How was this tested?

  • Confirmed that GemmaCausalLM and similar models now load and run correctly with and without NumPy behavior enabled.

# Ensure embeddings is properly converted to a tensor
embeddings_tensor = self.embeddings
# If it's a Keras variable, get its value
if hasattr(embeddings_tensor, "value"):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so what's the error here? embeddings_tensor has no value when tf numpy is enabled? what's the error?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

GemmaCausalLM fails to load if TensorFlow NumPy behavior isenabled
3 participants