Add support for saving HF format tensors with DCP #1351
base: main
Conversation
@Saiteja64 This will conflict with your PR.
Overall the logic LGTM. Please address the comments and make sure this PR doesn't conflict with the PR from @Saiteja64. Please also add a test result: save an HF checkpoint, load it back, and check the accuracy.
torchtitan/components/checkpoint.py (Outdated)
if hf_safetensors_format:
    storage_writer = HuggingFaceStorageWriter(path=checkpoint_id, save_sharded=True)
    if is_async:
        return dcp.async_save(
            state_dict, storage_writer=storage_writer, process_group=pg
        )
    else:
        return dcp.save(state_dict, storage_writer=storage_writer)
else:
    if is_async:
        return dcp.async_save(
            state_dict, checkpoint_id=checkpoint_id, process_group=pg
        )
    else:
        return dcp.save(state_dict, checkpoint_id=checkpoint_id)
We should simplify the function as follows:
storage_writer = HuggingFaceStorageWriter(path=checkpoint_id, save_sharded=True) if hf_safetensors_format else None
checkpoint_id = checkpoint_id if not hf_safetensors_format else None
if is_async:
    return dcp.async_save(
        state_dict, storage_writer=storage_writer, checkpoint_id=checkpoint_id, process_group=pg
    )
else:
    return dcp.save(state_dict, storage_writer=storage_writer, checkpoint_id=checkpoint_id)
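For context, a self-contained version of the simplified helper might look like the sketch below. The function name save_checkpoint and its full signature are assumptions reconstructed from the excerpt, and the HuggingFaceStorageWriter import location can vary across PyTorch versions; only the body mirrors the suggestion above.

from typing import Optional

import torch.distributed as dist
import torch.distributed.checkpoint as dcp
# Assumption: recent PyTorch releases expose HuggingFaceStorageWriter under
# torch.distributed.checkpoint; the exact import path may differ by version.
from torch.distributed.checkpoint import HuggingFaceStorageWriter


def save_checkpoint(
    state_dict: dict,
    checkpoint_id: str,
    is_async: bool = False,
    pg: Optional[dist.ProcessGroup] = None,
    hf_safetensors_format: bool = False,
):
    # With the HF safetensors format, all I/O goes through the storage writer,
    # so checkpoint_id must be left unset; otherwise DCP derives the on-disk
    # layout from checkpoint_id as usual.
    storage_writer = (
        HuggingFaceStorageWriter(path=checkpoint_id, save_sharded=True)
        if hf_safetensors_format
        else None
    )
    checkpoint_id = checkpoint_id if not hf_safetensors_format else None
    if is_async:
        return dcp.async_save(
            state_dict,
            storage_writer=storage_writer,
            checkpoint_id=checkpoint_id,
            process_group=pg,
        )
    return dcp.save(state_dict, storage_writer=storage_writer, checkpoint_id=checkpoint_id)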
torchtitan/config_manager.py (Outdated)
enable_hf_safetensors_format: bool = False
"""
Enable the use of safetensors format for checkpointing. This will save checkpoints
in safetensors format instead of the default DCP format. The default value is False.
Can we also mention the possible performance penalty? It's not cost-free, right?
Overall LGTM! Thanks for working on this! So from the logging, saving a Llama 3 8B model checkpoint in HF format takes ~200s, and loading an HF checkpoint takes ~30s. Is that correct?
in safetensors format instead of the default DCP format. There will be a performance
cost in using this, as we need to consolidate the sharded tensors into full tensors as
a separate step. last_save_model_weights must be true because safetensors doesn't
support saving non-tensors. On load, this argument isn't needed, as we will detect
the checkpoint type automatically.
From the code, it also supports loading from an HF checkpoint. Should we change the parameter name, say to "enable_safetensors_format", to avoid confusion?
But this argument is only for saving. On load, it will load whatever the latest checkpoint is, regardless of the type, so no extra argument is needed.
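Related to the save-only semantics: since safetensors can only store tensors, the last_save_model_weights constraint from the docstring could be enforced with a small validation step. The sketch below is illustrative only; the helper and attribute names mirror the excerpt and are assumptions, not the PR's actual code.

def validate_checkpoint_config(cfg) -> None:
    # Sketch: fail fast on the config combination the docstring rules out.
    # safetensors files can only store tensors, so non-tensor state
    # (optimizer, dataloader, etc.) cannot be saved in this format.
    if cfg.enable_hf_safetensors_format and not cfg.last_save_model_weights:
        raise ValueError(
            "enable_hf_safetensors_format requires last_save_model_weights=True, "
            "because safetensors cannot store non-tensor state."
        )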
Overall LGTM, please fix the remaining comments.
if checkpoint_type == CheckpointType.SAFETENSORS:
    model_only = True
We should assert that model_only is True, rather than silently changing its value.
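In other words, something along these lines (a sketch of the suggested behavior; the assertion message is illustrative):

if checkpoint_type == CheckpointType.SAFETENSORS:
    # Fail loudly instead of silently overriding the caller's model_only.
    assert model_only, (
        "HF safetensors checkpoints contain only model weights, "
        "so they must be loaded with model_only=True."
    )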
How should I change the logic to allow for this? Right now, if os.path.exists(self.folder), there is no way for model_only to be True other than when step == 0. self.initial_load_model_weights_only isn't used in this code path either. It's also already being silently changed in line 516, which is why I did it this way, but I'm happy to change it in whichever way you think is best.
self.dcp_load(
    self.ft_states,
    checkpoint_id=checkpoint_id,
    checkpoint_type=CheckpointType.DCP,  # FT checkpoints are always DCP
)
Please add one more line: "because FT checkpoints currently only save/load the dataloader."
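Applied, the excerpt would read roughly as follows:

self.dcp_load(
    self.ft_states,
    checkpoint_id=checkpoint_id,
    # FT checkpoints are always DCP, because FT checkpoints currently
    # only save/load the dataloader.
    checkpoint_type=CheckpointType.DCP,
)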
Save is actually faster than that: this run was ~140 seconds, but that was before I added the num_threads argument, so it should be faster now.
If checkpoint.enable_save_safetensors_format is set, the final checkpoint is saved with the DCP HuggingFace components, which write .safetensors files instead of the regular DCP format. On load, we can decide which type of load to do based on the checkpoint type.
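A minimal sketch of how that load-time detection could work, assuming the checkpoint type is inferred from the files on disk (the helper name and the detection rule are assumptions, not necessarily the PR's implementation):

import os
from enum import Enum


class CheckpointType(str, Enum):
    DCP = "dcp"
    SAFETENSORS = "safetensors"


def detect_checkpoint_type(checkpoint_id: str) -> CheckpointType:
    # Treat any checkpoint directory containing .safetensors files as an
    # HF-format checkpoint; everything else falls back to regular DCP.
    for fname in os.listdir(checkpoint_id):
        if fname.endswith(".safetensors"):
            return CheckpointType.SAFETENSORS
    return CheckpointType.DCP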
Successful save:
Successful load: