
Safetensors conversion #2290


Open
wants to merge 3 commits into master

Conversation


@Bond099 Bond099 commented Jun 6, 2025

Description of the change

Reference

Colab Notebook

https://colab.research.google.com/drive/1naqf0sO2J40skndWbVMeQismjL7MuEjd?usp=sharing&authuser=4#scrollTo=sT1sxZZW-_eg

Checklist

  • I have added all the necessary unit tests for my change.
  • I have verified that my change does not break existing code and works with all backends (TensorFlow, JAX, and PyTorch).
  • My PR is based on the latest changes of the main branch (if unsure, rebase the code).
  • I have followed the Keras Hub Model contribution guidelines in making these changes.
  • I have followed the Keras Hub API design guidelines in making these changes.
  • I have signed the Contributor License Agreement.

@abheesht17
Collaborator

Thanks for the PR, will take a look in a bit :)

Member

@mattdangerw mattdangerw left a comment

Thanks! Just left some initial comments.


# Load weights into Hugging Face model
print("Loading weights into Hugging Face model...")
hf_model.load_state_dict(weights_dict, strict=False)
Member

This is the line we most likely need to avoid to prevent the double allocation. Can we try writing the weights out directly with the safetensors library here, and skip this call?

That will mean we might need to save the config.json separately, outside of save_pretrained.

Alternatively, we could see whether the torch meta device would avoid allocating memory, but I don't know if that would work when we actually need to save the model.
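
For illustration, a minimal sketch of the direct-safetensors idea, assuming weights_dict already maps Hugging Face parameter names to torch tensors (save_hf_checkpoint and hf_config are placeholder names, not part of the PR):

import json
import os

from safetensors.torch import save_file


def save_hf_checkpoint(weights_dict, hf_config, path):
    # Write the tensors straight to disk; no HF model object is built,
    # so the weights only ever exist in memory once.
    os.makedirs(path, exist_ok=True)
    save_file(
        weights_dict,
        os.path.join(path, "model.safetensors"),
        # transformers checks this metadata key when loading safetensors.
        metadata={"format": "pt"},
    )
    # save_pretrained is skipped, so write config.json by hand.
    with open(os.path.join(path, "config.json"), "w") as f:
        json.dump(hf_config.to_dict(), f, indent=2)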

Collaborator

@abheesht17 abheesht17 Jun 6, 2025

+1. Can we just save the weights dict, like so? We won't have to load the model that way.

https://github.com/google/tunix/blob/17999b0b653cc61b22ab486fe952c76620ca5ebf/examples/grpo_demo.ipynb?short_path=c135e8d#L1016-L1034.
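
In that spirit, a rough sketch of building the dict straight from converted arrays and saving it, with the name mapping left as a placeholder:

import torch
from safetensors.torch import save_file

# `converted_arrays` is a placeholder: HF parameter name -> numpy array,
# produced by whatever name mapping the converter script uses.
weights_dict = {
    name: torch.from_numpy(array) for name, array in converted_arrays.items()
}
save_file(weights_dict, "model.safetensors", metadata={"format": "pt"})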

Author

Yes, sorry. Looks like I loaded the weights three times into memory: keras_model, weights_dict, and hf_model.


# Save model
hf_model.save_pretrained(path, safe_serialization=True)
print(f"Model and tokenizer saved to {path}")
Member

Let's try to make this look more like a library util (which is the eventual intent). No print statements. Just expose export_to_hf in this file. Make a separate test file that does what is happening below, in a unit test annotated with pytest.mark.large, that converts the model and compares outputs.
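
A sketch of the requested shape, with export_to_hf as the only public entry point (convert_weights and convert_config are hypothetical helpers standing in for the PR's conversion logic):

def export_to_hf(keras_model, path):
    """Export a Keras Hub model to a Hugging Face checkpoint directory."""
    # Hypothetical helpers standing in for this file's conversion code.
    weights_dict = convert_weights(keras_model)
    hf_config = convert_config(keras_model)
    # Reuses the direct-safetensors save sketched earlier in the thread.
    save_hf_checkpoint(weights_dict, hf_config, path)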

Member

@mattdangerw mattdangerw left a comment

Let's add a unit test that calls this util, tries loading the result with transformers, and checks that it works. OK to add transformers to our CI environment here: https://github.com/keras-team/keras-hub/blob/master/requirements-common.txt
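
A rough shape for that test (the preset name is a placeholder, and export_to_hf is the hypothetical util sketched above):

import keras
import numpy as np
import pytest
import torch
import transformers

import keras_hub


@pytest.mark.large
def test_export_to_hf_round_trip(tmp_path):
    # Placeholder preset; a real test would pick the smallest one available.
    keras_model = keras_hub.models.CausalLM.from_preset("some_tiny_preset")
    export_to_hf(keras_model, str(tmp_path))

    # Reload with transformers and compare logits on the same token ids.
    hf_model = transformers.AutoModelForCausalLM.from_pretrained(str(tmp_path))
    token_ids = np.array([[1, 2, 3, 4, 5]])
    keras_logits = keras_model(
        {"token_ids": token_ids, "padding_mask": np.ones_like(token_ids)}
    )
    with torch.no_grad():
        hf_logits = hf_model(torch.tensor(token_ids)).logits
    np.testing.assert_allclose(
        keras.ops.convert_to_numpy(keras_logits),
        hf_logits.numpy(),
        atol=1e-3,
    )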

from safetensors.torch import save_file

# Set the Keras backend to jax
os.environ["KERAS_BACKEND"] = "jax"
Member

Let's not do this; this is something we are going to export as part of the library, and we actually need it to be able to run on all backends.
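
One backend-agnostic alternative (a sketch: keras.ops.convert_to_numpy handles tensors from any of the three backends, so nothing needs to be forced via KERAS_BACKEND):

import keras


def collect_numpy_weights(keras_model, name_map):
    # name_map is a hypothetical dict: Keras variable path -> HF parameter
    # name. convert_to_numpy works whether the underlying tensor comes
    # from TensorFlow, JAX, or torch.
    variables = {v.path: v for v in keras_model.weights}
    return {
        hf_name: keras.ops.convert_to_numpy(variables[keras_path])
        for keras_path, hf_name in name_map.items()
    }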

import os

import torch
from safetensors.torch import save_file
Member

Does this work on all backends? Or do we need to flip between implementations depending on the backend? Worth testing out.
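
If the torch version turns out not to, safetensors also ships a numpy interface, which would sidestep the backend question entirely (a sketch, untested here):

import keras
from safetensors.numpy import save_file


def save_numpy_safetensors(variables, filename):
    # variables: hypothetical dict of HF parameter name -> Keras variable.
    tensors = {
        name: keras.ops.convert_to_numpy(v) for name, v in variables.items()
    }
    # The on-disk format is the same either way; the "pt" metadata tag is
    # what transformers looks for when loading (an assumption worth verifying).
    save_file(tensors, filename, metadata={"format": "pt"})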
