Add support for saving HF format tensors with DCP #1351


Open · wants to merge 21 commits into main

Conversation


@ankitageorge commented Jun 27, 2025

If checkpoint.enable_save_safetensors_format is set, the final save goes through the DCP HuggingFace components, which write the checkpoint as .safetensors files instead of the regular DCP format. On load, we decide which type of load to perform based on the checkpoint type.
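For readers unfamiliar with the DCP HuggingFace components, here is a minimal sketch of the save path described above. It is illustrative, not the PR's exact code: the function name and argument plumbing are mine, while HuggingFaceStorageWriter(path=..., save_sharded=True) and the dcp.save/dcp.async_save calls match the snippets reviewed below.

import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.hf_storage import HuggingFaceStorageWriter

def save_with_optional_safetensors(state_dict, checkpoint_id, hf_safetensors_format, is_async=False, pg=None):
    # When the safetensors flag is set, route the save through the HF writer so the
    # final checkpoint is written as .safetensors files; otherwise use plain DCP.
    storage_writer = (
        HuggingFaceStorageWriter(path=checkpoint_id, save_sharded=True)
        if hf_safetensors_format
        else None
    )
    checkpoint_id = checkpoint_id if storage_writer is None else None
    if is_async:
        return dcp.async_save(
            state_dict, storage_writer=storage_writer, checkpoint_id=checkpoint_id, process_group=pg
        )
    return dcp.save(state_dict, storage_writer=storage_writer, checkpoint_id=checkpoint_id)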

Successful save:

(titan) [[email protected] /data/users/ankitageorge/torchtitan (dcp-hf)]$ CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh
+ NGPU=8
+ export LOG_RANK=0,1,2
+ LOG_RANK=0,1,2
+ CONFIG_FILE=./torchtitan/models/llama3/train_configs/llama3_8b.toml
+ overrides=
+ '[' 0 -ne 0 ']'
+ TORCHFT_LIGHTHOUSE=http://localhost:29510/
+ PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
+ TORCHFT_LIGHTHOUSE=http://localhost:29510/
+ torchrun --nproc_per_node=8 --rdzv_backend c10d --rdzv_endpoint=localhost:0 --local-ranks-filter 0,1,2 --role rank --tee 3 -m torchtitan.train --job.config_file ./torchtitan/models/llama3/train_configs/llama3_8b.toml
W0710 19:20:37.727000 1310423 site-packages/torch/distributed/run.py:774] 
W0710 19:20:37.727000 1310423 site-packages/torch/distributed/run.py:774] *****************************************
W0710 19:20:37.727000 1310423 site-packages/torch/distributed/run.py:774] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
W0710 19:20:37.727000 1310423 site-packages/torch/distributed/run.py:774] *****************************************
[rank0]:[titan] 2025-07-10 19:20:49,848 - root - INFO - Starting job: Llama 3 8B training
[rank1]:[titan] 2025-07-10 19:20:49,985 - root - INFO - Starting job: Llama 3 8B training
[rank2]:[titan] 2025-07-10 19:20:51,188 - root - INFO - Starting job: Llama 3 8B training
[rank0]:[titan] 2025-07-10 19:20:52,644 - root - WARNING - ENV[TORCH_NCCL_ASYNC_ERROR_HANDLING] = 1 will be overridden to 3 based on job config
[rank0]:[titan] 2025-07-10 19:20:52,646 - root - INFO - Building 1-D device mesh with ['dp_shard'], [8]
[rank0]:[titan] 2025-07-10 19:20:52,650 - root - INFO - [GC] Initial GC collection. 0.00 seconds.
[rank0]:NCCL version 2.27.5+cuda12.9
[rank1]:[titan] 2025-07-10 19:20:52,976 - root - WARNING - ENV[TORCH_NCCL_ASYNC_ERROR_HANDLING] = 1 will be overridden to 3 based on job config
[rank1]:[titan] 2025-07-10 19:20:52,979 - root - INFO - Building 1-D device mesh with ['dp_shard'], [8]
[rank1]:[titan] 2025-07-10 19:20:52,984 - root - INFO - [GC] Initial GC collection. 0.00 seconds.
[rank2]:[titan] 2025-07-10 19:20:53,902 - root - WARNING - ENV[TORCH_NCCL_ASYNC_ERROR_HANDLING] = 1 will be overridden to 3 based on job config
[rank2]:[titan] 2025-07-10 19:20:53,905 - root - INFO - Building 1-D device mesh with ['dp_shard'], [8]
[rank2]:[titan] 2025-07-10 19:20:53,910 - root - INFO - [GC] Initial GC collection. 0.00 seconds.
[rank0]:[titan] 2025-07-10 19:20:56,568 - root - INFO - TikTokenizer built: #words 128256, BOS ID 128000, EOS ID 128001
[rank0]:[titan] 2025-07-10 19:20:56,568 - root - INFO - Preparing c4 dataset from allenai/c4
[rank2]:[titan] 2025-07-10 19:20:56,593 - root - INFO - TikTokenizer built: #words 128256, BOS ID 128000, EOS ID 128001
[rank2]:[titan] 2025-07-10 19:20:56,593 - root - INFO - Preparing c4 dataset from allenai/c4
[rank1]:[titan] 2025-07-10 19:20:56,616 - root - INFO - TikTokenizer built: #words 128256, BOS ID 128000, EOS ID 128001
[rank1]:[titan] 2025-07-10 19:20:56,616 - root - INFO - Preparing c4 dataset from allenai/c4
[rank2]:[titan] 2025-07-10 19:21:02,550 - root - INFO - Building llama3 8B with TransformerModelArgs(_enforced='This field is used to enforce all fields have defaults.', dim=4096, n_layers=32, n_heads=32, n_kv_heads=8, vocab_size=128256, multiple_of=1024, ffn_dim_multiplier=1.3, norm_eps=1e-05, rope_theta=500000, max_seq_len=8192, depth_init=True, use_flex_attn=False, attn_mask_type='causal', eos_id=128001)
[rank2]:[titan] 2025-07-10 19:21:02,944 - root - INFO - CUDA capacity: NVIDIA PG509-210 with 79.14GiB memory
[rank2]:[titan] 2025-07-10 19:21:02,968 - root - WARNING - Peak flops undefined for: NVIDIA PG509-210, fallback to A100
[rank2]:[titan] 2025-07-10 19:21:02,969 - root - INFO - Model llama3 8B size: 8,030,261,248 total parameters
[rank2]:[titan] 2025-07-10 19:21:02,970 - root - INFO - Applied selective activation checkpointing to the model
[rank1]:[titan] 2025-07-10 19:21:03,101 - root - INFO - Building llama3 8B with TransformerModelArgs(_enforced='This field is used to enforce all fields have defaults.', dim=4096, n_layers=32, n_heads=32, n_kv_heads=8, vocab_size=128256, multiple_of=1024, ffn_dim_multiplier=1.3, norm_eps=1e-05, rope_theta=500000, max_seq_len=8192, depth_init=True, use_flex_attn=False, attn_mask_type='causal', eos_id=128001)
[rank0]:[titan] 2025-07-10 19:21:03,142 - root - INFO - Building llama3 8B with TransformerModelArgs(_enforced='This field is used to enforce all fields have defaults.', dim=4096, n_layers=32, n_heads=32, n_kv_heads=8, vocab_size=128256, multiple_of=1024, ffn_dim_multiplier=1.3, norm_eps=1e-05, rope_theta=500000, max_seq_len=8192, depth_init=True, use_flex_attn=False, attn_mask_type='causal', eos_id=128001)
[rank2]:[titan] 2025-07-10 19:21:03,123 - root - INFO - Applied FSDP to the model
[rank1]:[titan] 2025-07-10 19:21:03,491 - root - INFO - CUDA capacity: NVIDIA PG509-210 with 79.14GiB memory
[rank1]:[titan] 2025-07-10 19:21:03,515 - root - WARNING - Peak flops undefined for: NVIDIA PG509-210, fallback to A100
[rank1]:[titan] 2025-07-10 19:21:03,516 - root - INFO - Model llama3 8B size: 8,030,261,248 total parameters
[rank1]:[titan] 2025-07-10 19:21:03,517 - root - INFO - Applied selective activation checkpointing to the model
[rank0]:[titan] 2025-07-10 19:21:03,550 - root - INFO - TensorBoard logging enabled. Logs will be saved at ./outputs/tb/20250710-1921
[rank0]:[titan] 2025-07-10 19:21:03,551 - root - INFO - CUDA capacity: NVIDIA PG509-210 with 79.14GiB memory
[rank0]:[titan] 2025-07-10 19:21:03,574 - root - WARNING - Peak flops undefined for: NVIDIA PG509-210, fallback to A100
[rank0]:[titan] 2025-07-10 19:21:03,575 - root - INFO - Model llama3 8B size: 8,030,261,248 total parameters
[rank0]:[titan] 2025-07-10 19:21:03,576 - root - INFO - Applied selective activation checkpointing to the model
[rank1]:[titan] 2025-07-10 19:21:03,675 - root - INFO - Applied FSDP to the model
[rank0]:[titan] 2025-07-10 19:21:03,732 - root - INFO - Applied FSDP to the model
[rank2]:[titan] 2025-07-10 19:21:03,813 - root - WARNING - Peak flops undefined for: NVIDIA PG509-210, fallback to A100
[rank2]:[titan] 2025-07-10 19:21:03,813 - root - INFO - Peak FLOPS used for computing MFU: 3.120e+14
[rank2]:[titan] 2025-07-10 19:21:03,814 - root - INFO - CUDA memory usage for model: 3.77GiB(4.77%)
[rank2]:[titan] 2025-07-10 19:21:03,817 - root - WARNING - Warmup steps (200) exceed total training steps (2). Adjusting warmup steps to 2.
[rank2]:[titan] 2025-07-10 19:21:03,876 - root - INFO - Checkpointing active. Checkpoints will be loaded from and saved to ./outputs/checkpoint
[rank2]:[titan] 2025-07-10 19:21:03,876 - root - INFO - Mixed precision training is handled by fully_shard
[rank2]:[titan] 2025-07-10 19:21:03,876 - root - INFO - Trainer is initialized with local batch size 1, global batch size 8, gradient accumulation steps 1, sequence length 8192, total steps 2 (warmup 200).
[rank2]:[titan] 2025-07-10 19:21:03,877 - root - INFO - Training starts at step 1.
[rank2]:[titan] 2025-07-10 19:21:03,877 - root - INFO - Profiling active. Traces will be saved at ./outputs/profile_trace
[rank1]:[titan] 2025-07-10 19:21:04,369 - root - WARNING - Peak flops undefined for: NVIDIA PG509-210, fallback to A100
[rank1]:[titan] 2025-07-10 19:21:04,370 - root - INFO - Peak FLOPS used for computing MFU: 3.120e+14
[rank1]:[titan] 2025-07-10 19:21:04,370 - root - INFO - CUDA memory usage for model: 3.77GiB(4.77%)
[rank1]:[titan] 2025-07-10 19:21:04,373 - root - WARNING - Warmup steps (200) exceed total training steps (2). Adjusting warmup steps to 2.
[rank0]:[titan] 2025-07-10 19:21:04,335 - root - WARNING - Peak flops undefined for: NVIDIA PG509-210, fallback to A100
[rank0]:[titan] 2025-07-10 19:21:04,336 - root - INFO - Peak FLOPS used for computing MFU: 3.120e+14
[rank0]:[titan] 2025-07-10 19:21:04,336 - root - INFO - CUDA memory usage for model: 3.77GiB(4.77%)
[rank0]:[titan] 2025-07-10 19:21:04,340 - root - WARNING - Warmup steps (200) exceed total training steps (2). Adjusting warmup steps to 2.
[rank1]:[titan] 2025-07-10 19:21:04,430 - root - INFO - Checkpointing active. Checkpoints will be loaded from and saved to ./outputs/checkpoint
[rank1]:[titan] 2025-07-10 19:21:04,430 - root - INFO - Mixed precision training is handled by fully_shard
[rank0]:[titan] 2025-07-10 19:21:04,415 - root - INFO - Checkpointing active. Checkpoints will be loaded from and saved to ./outputs/checkpoint
[rank0]:[titan] 2025-07-10 19:21:04,415 - root - INFO - Mixed precision training is handled by fully_shard
[rank1]:[titan] 2025-07-10 19:21:04,431 - root - INFO - Trainer is initialized with local batch size 1, global batch size 8, gradient accumulation steps 1, sequence length 8192, total steps 2 (warmup 200).
[rank1]:[titan] 2025-07-10 19:21:04,431 - root - INFO - Training starts at step 1.
[rank1]:[titan] 2025-07-10 19:21:04,431 - root - INFO - Profiling active. Traces will be saved at ./outputs/profile_trace
[rank0]:[titan] 2025-07-10 19:21:04,416 - root - INFO - Trainer is initialized with local batch size 1, global batch size 8, gradient accumulation steps 1, sequence length 8192, total steps 2 (warmup 200).
[rank0]:[titan] 2025-07-10 19:21:04,416 - root - INFO - Training starts at step 1.
[rank0]:[titan] 2025-07-10 19:21:04,416 - root - INFO - Profiling active. Traces will be saved at ./outputs/profile_trace
[rank0]:[titan] 2025-07-10 19:21:11,407 - root - INFO - step:  1  loss: 12.2520  grad_norm:  4.0543  memory: 42.12GiB(53.23%)  tps: 1,046  tflops: 60.58  mfu: 19.42%
[rank0]:[titan] 2025-07-10 19:21:11,407 - root - INFO - Calling checkpoint save after step 1
[rank0]:[titan] 2025-07-10 19:21:11,408 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank2]:[titan] 2025-07-10 19:21:11,406 - root - INFO - step:  1  loss: 12.2520  grad_norm:  4.0543  memory: 42.12GiB(53.23%)  tps: 971  tflops: 56.23  mfu: 18.02%
[rank2]:[titan] 2025-07-10 19:21:11,406 - root - INFO - Calling checkpoint save after step 1
[rank2]:[titan] 2025-07-10 19:21:11,407 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank1]:[titan] 2025-07-10 19:21:11,406 - root - INFO - step:  1  loss: 12.2520  grad_norm:  4.0543  memory: 42.12GiB(53.23%)  tps: 1,038  tflops: 60.13  mfu: 19.27%
[rank1]:[titan] 2025-07-10 19:21:11,407 - root - INFO - Calling checkpoint save after step 1
[rank1]:[titan] 2025-07-10 19:21:11,407 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank2]:[titan] 2025-07-10 19:21:14,016 - root - INFO - Calling checkpoint save after step 2
[rank2]:[titan] 2025-07-10 19:21:14,017 - root - INFO - Saving the checkpoint (or staging if async is enabled).
[rank2]:[titan] 2025-07-10 19:21:14,017 - root - INFO - Saving a model weights only checkpoint in torch.float32 at last step, step 2.
[rank2]:[titan] 2025-07-10 19:21:14,017 - root - INFO - Num keys before parsing 291, after 291
[rank2]:[titan] 2025-07-10 19:21:14,017 - root - INFO - key tok_embeddings.weight, shape torch.Size([128256, 4096])
[rank0]:[titan] 2025-07-10 19:21:14,015 - root - INFO - Calling checkpoint save after step 2
[rank0]:[titan] 2025-07-10 19:21:14,016 - root - INFO - Saving the checkpoint (or staging if async is enabled).
[rank0]:[titan] 2025-07-10 19:21:14,016 - root - INFO - Saving a model weights only checkpoint in torch.float32 at last step, step 2.
[rank0]:[titan] 2025-07-10 19:21:14,017 - root - INFO - Num keys before parsing 291, after 291
[rank0]:[titan] 2025-07-10 19:21:14,017 - root - INFO - key tok_embeddings.weight, shape torch.Size([128256, 4096])
[rank1]:[titan] 2025-07-10 19:21:14,023 - root - INFO - Calling checkpoint save after step 2
[rank1]:[titan] 2025-07-10 19:21:14,024 - root - INFO - Saving the checkpoint (or staging if async is enabled).
[rank1]:[titan] 2025-07-10 19:21:14,024 - root - INFO - Saving a model weights only checkpoint in torch.float32 at last step, step 2.
[rank1]:[titan] 2025-07-10 19:21:14,024 - root - INFO - Num keys before parsing 291, after 291
[rank1]:[titan] 2025-07-10 19:21:14,024 - root - INFO - key tok_embeddings.weight, shape torch.Size([128256, 4096])
[rank0]:Done writing metadata. Took %.2f secs. 0.026559114456176758
[rank0]:Done writing data. Took %.2f secs. 66.62590146064758
[rank0]:Done consolidating. Took %.2f secs. 66.62735033035278
[rank0]:time taken for all reduce:  141.72666668891907
[rank1]:time taken for all reduce:  141.73284125328064
[rank2]:time taken for all reduce:  141.72900009155273
[rank2]:[titan] 2025-07-10 19:23:36,832 - root - INFO - [GC] GC collection invoked by checkpointer. 0.03 seconds.
[rank2]:[titan] 2025-07-10 19:23:36,832 - root - INFO - Training completed
[rank2]:[titan] 2025-07-10 19:23:36,832 - root - INFO - Destroying the purge thread.
[rank0]:[titan] 2025-07-10 19:23:36,827 - root - INFO - [GC] GC collection invoked by checkpointer. 0.02 seconds.
[rank0]:[titan] 2025-07-10 19:23:36,828 - root - INFO - Sleeping 2 seconds for other ranks to complete
[rank1]:[titan] 2025-07-10 19:23:36,837 - root - INFO - [GC] GC collection invoked by checkpointer. 0.03 seconds.
[rank1]:[titan] 2025-07-10 19:23:36,837 - root - INFO - Training completed
[rank1]:[titan] 2025-07-10 19:23:36,837 - root - INFO - Destroying the purge thread.
[rank2]:[titan] 2025-07-10 19:23:37,243 - root - INFO - Process group destroyed.
[rank0]:[titan] 2025-07-10 19:23:38,828 - root - INFO - Training completed
[rank0]:[titan] 2025-07-10 19:23:38,829 - root - INFO - Destroying the purge thread.
[rank1]:[titan] 2025-07-10 19:23:39,503 - root - INFO - Process group destroyed.
[rank0]:[titan] 2025-07-10 19:23:39,705 - root - INFO - Process group destroyed.

Successful load:

(titan) [[email protected] /data/users/ankitageorge/torchtitan (dcp-hf)]$ CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh
+ NGPU=8
+ export LOG_RANK=0
+ LOG_RANK=0
+ CONFIG_FILE=./torchtitan/models/llama3/train_configs/llama3_8b.toml
+ overrides=
+ '[' 0 -ne 0 ']'
+ TORCHFT_LIGHTHOUSE=http://localhost:29510/
+ PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
+ TORCHFT_LIGHTHOUSE=http://localhost:29510/
+ torchrun --nproc_per_node=8 --rdzv_backend c10d --rdzv_endpoint=localhost:0 --local-ranks-filter 0 --role rank --tee 3 -m torchtitan.train --job.config_file ./torchtitan/models/llama3/train_configs/llama3_8b.toml
W0710 20:56:16.459000 1982701 site-packages/torch/distributed/run.py:774] 
W0710 20:56:16.459000 1982701 site-packages/torch/distributed/run.py:774] *****************************************
W0710 20:56:16.459000 1982701 site-packages/torch/distributed/run.py:774] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
W0710 20:56:16.459000 1982701 site-packages/torch/distributed/run.py:774] *****************************************
[rank0]:[titan] 2025-07-10 20:56:24,765 - root - INFO - Starting job: Llama 3 8B training
[rank0]:NCCL version 2.27.5+cuda12.9
[rank0]:[titan] 2025-07-10 20:56:27,746 - root - WARNING - ENV[TORCH_NCCL_ASYNC_ERROR_HANDLING] = 1 will be overridden to 3 based on job config
[rank0]:[titan] 2025-07-10 20:56:27,748 - root - INFO - Building 1-D device mesh with ['dp_shard'], [8]
[rank0]:[titan] 2025-07-10 20:56:27,753 - root - INFO - [GC] Initial GC collection. 0.00 seconds.
[rank0]:[titan] 2025-07-10 20:56:30,608 - root - INFO - TikTokenizer built: #words 128256, BOS ID 128000, EOS ID 128001
[rank0]:[titan] 2025-07-10 20:56:30,608 - root - INFO - Preparing c4 dataset from allenai/c4
[rank0]:[titan] 2025-07-10 20:56:36,070 - root - INFO - Building llama3 8B with TransformerModelArgs(_enforced='This field is used to enforce all fields have defaults.', dim=4096, n_layers=32, n_heads=32, n_kv_heads=8, vocab_size=128256, multiple_of=1024, ffn_dim_multiplier=1.3, norm_eps=1e-05, rope_theta=500000, max_seq_len=8192, depth_init=True, use_flex_attn=False, attn_mask_type='causal', eos_id=128001)
[rank0]:[titan] 2025-07-10 20:56:36,430 - root - INFO - TensorBoard logging enabled. Logs will be saved at ./outputs/tb/20250710-2056
[rank0]:[titan] 2025-07-10 20:56:36,431 - root - INFO - CUDA capacity: NVIDIA PG509-210 with 79.14GiB memory
[rank0]:[titan] 2025-07-10 20:56:36,452 - root - WARNING - Peak flops undefined for: NVIDIA PG509-210, fallback to A100
[rank0]:[titan] 2025-07-10 20:56:36,454 - root - INFO - Model llama3 8B size: 8,030,261,248 total parameters
[rank0]:[titan] 2025-07-10 20:56:36,455 - root - INFO - Applied selective activation checkpointing to the model
[rank0]:[titan] 2025-07-10 20:56:36,598 - root - INFO - Applied FSDP to the model
[rank0]:[titan] 2025-07-10 20:56:37,138 - root - WARNING - Peak flops undefined for: NVIDIA PG509-210, fallback to A100
[rank0]:[titan] 2025-07-10 20:56:37,138 - root - INFO - Peak FLOPS used for computing MFU: 3.120e+14
[rank0]:[titan] 2025-07-10 20:56:37,138 - root - INFO - CUDA memory usage for model: 3.77GiB(4.77%)
[rank0]:[titan] 2025-07-10 20:56:37,190 - root - INFO - Checkpointing active. Checkpoints will be loaded from and saved to ./outputs/checkpoint
[rank0]:[titan] 2025-07-10 20:56:37,190 - root - INFO - Mixed precision training is handled by fully_shard
[rank0]:[titan] 2025-07-10 20:56:37,190 - root - INFO - Trainer is initialized with local batch size 1, global batch size 8, gradient accumulation steps 1, sequence length 8192, total steps 1000 (warmup 200).
[rank0]:[titan] 2025-07-10 20:56:37,190 - root - INFO - Loading the checkpoint from ./outputs/checkpoint/step-3.
[rank0]:/home/ankitageorge/.conda/envs/titan/lib/python3.13/site-packages/torch/distributed/checkpoint/hf_storage.py:259: UserWarning: The given buffer is not writable, and PyTorch does not support non-writable tensors. This means you can write to the underlying (supposedly non-writable) buffer using the tensor. You may want to copy the buffer to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at /pytorch/torch/csrc/utils/tensor_new.cpp:1579.)
[rank0]:  tensor = torch.frombuffer(
[rank0]:[titan] 2025-07-10 20:57:04,398 - root - INFO - [GC] GC collection for checkpoint loading. 0.01 seconds.
[rank0]:[titan] 2025-07-10 20:57:04,398 - root - INFO - Finished loading the checkpoint in 27.21 seconds.
[rank0]:[titan] 2025-07-10 20:57:04,398 - root - INFO - Training starts at step 1.
[rank0]:[titan] 2025-07-10 20:57:04,398 - root - INFO - Profiling active. Traces will be saved at ./outputs/profile_trace
[rank0]:[titan] 2025-07-10 20:57:11,168 - root - INFO - step:  1  loss: 12.0247  grad_norm: 42.7524  memory: 42.12GiB(53.23%)  tps: 236  tflops: 13.67  mfu: 4.38%
[rank0]:[titan] 2025-07-10 20:57:11,168 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40

@facebook-github-bot added the CLA Signed label (managed by the Meta Open Source bot) on Jun 27, 2025
@ankitageorge changed the title from "Dcp hf" to "Add support for saving HF format tensors with DCP" on Jun 27, 2025
@fegin (Contributor) commented Jun 27, 2025

@Saiteja64 This will conflict with your PR.

@fegin (Contributor) left a comment

Overall the logic LGTM. Please address the comments and ensure that this PR doesn't conflict with the PR from @Saiteja64. Please also add a test result: save an HF checkpoint, load it back, and check the accuracy.
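A hedged sketch of the kind of round-trip check being asked for (the helper name, tolerances, and the assumption that both state dicts hold full tensors on a single rank are mine, not the PR's test code):

import torch

def check_roundtrip(saved_state, loaded_state, atol=1e-6, rtol=1e-5):
    # Compare every tensor of the reloaded checkpoint against the state dict that was saved.
    assert saved_state.keys() == loaded_state.keys(), "key sets differ after reload"
    for key, ref in saved_state.items():
        got = loaded_state[key]
        assert got.shape == ref.shape, f"shape mismatch for {key}"
        assert torch.allclose(got.float(), ref.float(), atol=atol, rtol=rtol), f"value mismatch for {key}"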

Comment on lines 116 to 130
if hf_safetensors_format:
    storage_writer = HuggingFaceStorageWriter(path=checkpoint_id, save_sharded=True)
    if is_async:
        return dcp.async_save(
            state_dict, storage_writer=storage_writer, process_group=pg
        )
    else:
        return dcp.save(state_dict, storage_writer=storage_writer)
else:
    if is_async:
        return dcp.async_save(
            state_dict, checkpoint_id=checkpoint_id, process_group=pg
        )
    else:
        return dcp.save(state_dict, checkpoint_id=checkpoint_id)

We should simplify the function as follows:

Suggested change
storage_writer = HuggingFaceStorageWriter(path=checkpoint_id, save_sharded=True) if hf_safetensors_format else None
checkpoint_id = checkpoint_id if not hf_safetensors_format else None
if is_async:
    return dcp.async_save(
        state_dict, storage_writer=storage_writer, checkpoint_id=checkpoint_id, process_group=pg
    )
else:
    return dcp.save(state_dict, storage_writer=storage_writer, checkpoint_id=checkpoint_id)

enable_hf_safetensors_format: bool = False
"""
Enable the use of safetensors format for checkpointing. This will save checkpoints
in safetensors format instead of the default DCP format. The default value is False.

Can we also mention the possible performance penalty? It's not cost free, right?
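For reference, enabling the new flag in a job config would look roughly like this. This is a sketch: only enable_hf_safetensors_format comes from this PR's diff; the [checkpoint] section name and overall layout are assumptions based on how torchtitan's TOML configs are organized.

[checkpoint]
# Field name taken from the diff above; the surrounding section layout is assumed.
enable_hf_safetensors_format = true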

@ankitageorge marked this pull request as ready for review on July 11, 2025
@wwwjn (Contributor) left a comment

Overall LGTM! Thanks for working on this! From the logging, saving a Llama 3 8B model checkpoint in HF format takes ~200s and loading an HF checkpoint takes ~30s, is this correct?

in safetensors format instead of the default DCP format. There will be a performance
cost in using this as we need to consolidate the sharded tensors to full tensors as
a separate step. Last_save_model_weights must be true because safetensors doesn't
support saving non tensors. On load, this argument isn't needed as we will detect

From the code, it also supports loading from an HF checkpoint. Should we change the parameter name, say to "enable_safetensors_format", to avoid confusion?

@ankitageorge (Author) replied:

But this argument is only for saving. On load, it will load whatever the latest checkpoint is, regardless of the type, so no extra argument is needed.
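A minimal sketch of what that load-side detection could look like. The CheckpointType names appear in the diff snippets below; the detect-by-file-extension rule here is an assumption for illustration, not necessarily how the PR implements it.

import os
from enum import Enum

class CheckpointType(str, Enum):
    DCP = "dcp"
    SAFETENSORS = "safetensors"

def detect_checkpoint_type(checkpoint_id: str) -> CheckpointType:
    # Treat a folder containing .safetensors files as an HF-format checkpoint;
    # otherwise assume the regular DCP format.
    for fname in os.listdir(checkpoint_id):
        if fname.endswith(".safetensors"):
            return CheckpointType.SAFETENSORS
    return CheckpointType.DCP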

@fegin (Contributor) left a comment

Overall, LGTM, please fix the remaining comments.

Comment on lines +525 to +526
if checkpoint_type == CheckpointType.SAFETENSORS:
    model_only = True

We should assert if model_only is not True, rather than silently changing the model_only value.
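A sketch of that suggestion (not the PR's final code):

if checkpoint_type == CheckpointType.SAFETENSORS:
    # Safetensors checkpoints hold model weights only, so a full (non model-only)
    # load cannot succeed; fail loudly instead of silently flipping the flag.
    assert model_only, "loading a SAFETENSORS checkpoint requires model_only=True"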

@ankitageorge (Author) replied:

How should I change the logic to allow for this? Right now, if os.path.exists(self.folder) is true, there is no way for model_only to be True other than when step == 0. self.initial_load_model_weights_only isn't used in this code path either. It's also already being silently changed on line 516, which is why I did it this way, but I'm happy to change it in whatever way you think is best.

self.dcp_load(
    self.ft_states,
    checkpoint_id=checkpoint_id,
    checkpoint_type=CheckpointType.DCP,  # FT checkpoints are always DCP

Please add one more line, "because FT checkpoint currently only save/load dataloader.".

@ankitageorge (Author) replied, quoting @wwwjn's question:

> Overall LGTM! Thanks for working on this! From the logging, saving a Llama 3 8B model checkpoint in HF format takes ~200s and loading an HF checkpoint takes ~30s, is this correct?

Save is actually faster than that: this run took ~140 seconds, but that was before I added the num_threads argument, so it should be faster now.

Labels: CLA Signed · 4 participants