Safetensors conversion #2290
Conversation
Thanks for the PR, will take a look in a bit :)
Thanks! Just left some initial comments.
# Load weights into Hugging Face model
print("Loading weights into Hugging Face model...")
hf_model.load_state_dict(weights_dict, strict=False)
This is the line we most need to avoid in order to prevent the double allocation. Can we try writing the weights out directly with the safetensors library here, so we never call load_state_dict?
That will mean we might need to save the config.json separately, outside of save_pretrained.
Alternatively, we could see if the torch meta device lets us avoid allocating memory, but I don't know whether that would work when we need to actually save the model.
+1. Can we just save the weights dict, like so? We won't have to load the model that way.
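Something like this rough sketch; save_hf_checkpoint, weights_dict, and hf_config are illustrative names, not the final API. It assumes weights_dict already maps Hugging Face parameter names to torch tensors and hf_config is a transformers config built for the target architecture:

import os

from safetensors.torch import save_file

def save_hf_checkpoint(weights_dict, hf_config, path):
    os.makedirs(path, exist_ok=True)
    # Write the tensors straight to a safetensors file, skipping
    # load_state_dict() and the extra in-memory copy of the model.
    save_file(weights_dict, os.path.join(path, "model.safetensors"))
    # With no model instance to call save_pretrained() on, the config
    # has to be written out separately.
    hf_config.save_pretrained(path)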
Yes, sorry, it looks like I have loaded the weights three times into memory: keras_model, weights_dict, and hf_model.
# Save model
hf_model.save_pretrained(path, safe_serialization=True)
print(f"Model and tokenizer saved to {path}")
Let's try to make this look more like a library util (which is the eventual intent). No print statements; just expose export_to_hf in this file. Make a separate test file that does what is happening below, in a unit test annotated with pytest.mark.large, that converts the model and compares outputs.
Let's add a unit test that calls this util and tries loading the result with transformers to see whether it works. It's OK to add transformers to our CI environment here: https://github.com/keras-team/keras-hub/blob/master/requirements-common.txt
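Roughly what that test could look like; export_to_hf, its signature, and the preset name are placeholders for whatever this PR ends up exposing:

import pytest
import torch
from transformers import AutoModelForCausalLM

import keras_hub

@pytest.mark.large
def test_export_to_hf_loads_in_transformers(tmp_path):
    # Convert a preset and write a safetensors checkpoint to a temp dir.
    keras_model = keras_hub.models.GemmaCausalLM.from_preset("gemma_2b_en")
    export_to_hf(keras_model, str(tmp_path))

    # The exported directory should load cleanly with transformers.
    hf_model = AutoModelForCausalLM.from_pretrained(str(tmp_path))

    # A forward pass on a tiny input is enough to catch shape mismatches;
    # a fuller test would compare logits against the Keras model's outputs.
    logits = hf_model(torch.tensor([[1, 42, 7]])).logits
    assert logits.shape[-1] == hf_model.config.vocab_size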
from safetensors.torch import save_file

# Set the Keras backend to jax
os.environ["KERAS_BACKEND"] = "jax"
Let's not do this; this is something we are going to export as part of the library, and we actually need it to run on all backends.
import os

import torch
from safetensors.torch import save_file
Does this work on all backends, or do we need to flip between versions depending on the backend? Worth testing out.
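One backend-agnostic option would be to go through NumPy before building the torch state dict (or to switch to safetensors.numpy.save_file and drop the torch dependency entirely); a rough sketch with illustrative names:

import numpy as np
import torch
from keras import ops

def to_torch_state_dict(keras_weights):
    # ops.convert_to_numpy handles jax, tensorflow and torch tensors alike,
    # so the export path does not depend on which backend is active.
    return {
        name: torch.from_numpy(np.ascontiguousarray(ops.convert_to_numpy(w)))
        for name, w in keras_weights.items()
    }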
Description of the change
Reference
Colab Notebook
https://colab.research.google.com/drive/1naqf0sO2J40skndWbVMeQismjL7MuEjd?usp=sharing&authuser=4#scrollTo=sT1sxZZW-_eg
Checklist