Implement initial_load_path for checkpointer #1236

Merged · 11 commits · Jun 13, 2025
8 changes: 4 additions & 4 deletions docs/checkpoint.md
@@ -25,19 +25,19 @@ interval = 500


2. SAVE ONLY MODEL WEIGHTS
By setting `model_weights_only` to `True`, the checkpoint will only contain the model weights and exclude the optimizer state and extra train states, resulting in a smaller checkpoint size.
By setting `last_save_model_weights_only` to `True`, the checkpoint will only contain the model weights and exclude the optimizer state and extra train states, resulting in a smaller checkpoint size.
```
[checkpoint]
enable_checkpoint = true
model_weights_only = true
last_save_model_weights_only = true
```

3. CHOOSE DESIRED EXPORT PRECISION
The default model states are in `float32`. You can choose to export the checkpoint in a lower precision format such as `bfloat16`.
```
[checkpoint]
enable_checkpoint = true
model_weights_only = true
last_save_model_weights_only = true
export_dtype = "bfloat16"
```

@@ -48,7 +48,7 @@ enable_checkpoint = true
folder = "checkpoint"
interval = 10
load_step = 5
model_weights_only = true
last_save_model_weights_only = true
export_dtype = "bfloat16"
```
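A weights-only checkpoint exported this way is still in DCP (sharded) format. As an illustration only, not a documented torchtitan workflow, one possible way to flatten it into a single file and inspect it with plain `torch.load` is PyTorch's DCP format utility; the paths below are examples:

```python
import torch
from torch.distributed.checkpoint.format_utils import dcp_to_torch_save

# Example paths; the step folder name depends on when the last save happened.
dcp_dir = "outputs/checkpoint/step-1000"   # weights-only DCP checkpoint folder
out_file = "llama3_weights.pt"

# Convert the sharded DCP checkpoint into a single torch.save file, then load it
# back with weights_only=True on CPU.
dcp_to_torch_save(dcp_dir, out_file)
state_dict = torch.load(out_file, weights_only=True, map_location="cpu")
print(sorted(state_dict.keys())[:5])
```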

8 changes: 4 additions & 4 deletions tests/integration_tests.py
@@ -122,22 +122,22 @@ def build_test_list():
[
[
"--checkpoint.enable_checkpoint",
"--checkpoint.model_weights_only",
"--checkpoint.last_save_model_weights_only",
],
],
"Checkpoint Integration Test - Save Model Weights Only fp32",
"model_weights_only_fp32",
"last_save_model_weights_only_fp32",
),
OverrideDefinitions(
[
[
"--checkpoint.enable_checkpoint",
"--checkpoint.model_weights_only",
"--checkpoint.last_save_model_weights_only",
"--checkpoint.export_dtype bfloat16",
],
],
"Checkpoint Integration Test - Save Model Weights Only bf16",
"model_weights_only_bf16",
"last_save_model_weights_only_bf16",
),
OverrideDefinitions(
[
8 changes: 7 additions & 1 deletion tests/unit_tests/test_checkpoint.py
@@ -44,16 +44,22 @@ def fake_get_model_state_dict(model, *args, **kwargs):
return model.state_dict()


# TODO: The unittest is not well structured and does not cover enough paths.
# It should be refactored.


@dataclass
class DummyCheckpointConfig:
enable_checkpoint: bool = True
folder: str = "dummy_folder"
interval: int = 10
async_mode: str = "disabled"
keep_latest_k: int = 0
model_weights_only: bool = False
last_save_model_weights_only: bool = False
export_dtype: str = "float32"
exclude_from_loading = []
initial_load_model_weights_only: bool = False
initial_load_path: str = ""


@dataclass
79 changes: 56 additions & 23 deletions torchtitan/components/checkpoint.py
@@ -81,6 +81,12 @@ class SaveDone:
pass


# For now, we will manually pop the freqs_cis buffer, as we made this permanent
# temporarily and we don't want to include it in the exported state_dict.
# Context: https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama3/model.py#L404
excluded_parameters_for_model_only = {"freqs_cis"}
Contributor:
If we don't save freqs_cis, what will happen when loading such a checkpoint, given that the model marks it as a persistent buffer?

In https://github.com/pytorch/torchtitan/blob/main/scripts/convert_llama_to_dcp.py#L127 the script explicitly creates this field when converting the original Meta checkpoint. I thought that was because, without this field, the checkpoint wouldn't be loadable.

Contributor Author:
Since the TorchTitan checkpointer pops out freqs_cis when loading weights-only checkpoints, loading will succeed regardless of whether freqs_cis is saved. We can change the conversion script after this PR.

Contributor:
Synced offline and understood the reason:

  • DCP loading requires the checkpoint files to contain all the fields in the state dict -- it can have more, but not less.
  • After this change, we need to make sure we initialize the buffer properly in the train script even if we are going to load a checkpoint. Previously there was a proposal that we don't call init_weights when loading from a checkpoint; we can still separate this out by having another init_buffer function.
  • There are two issues on why freqs_cis is made persistent: https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama3/model.py#L404
    The second is less critical; we should fix the first one (the compile issue).
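To make the mechanism discussed in this thread concrete, here is a minimal self-contained sketch; TinyModel is a stand-in for the real llama3 model, and the constant name and pop logic mirror what this PR adds to the checkpointer:

```python
import torch
import torch.nn as nn

# Mirrors the module-level constant added in this PR.
excluded_parameters_for_model_only = {"freqs_cis"}


class TinyModel(nn.Module):
    """Stand-in for the llama3 model: it registers freqs_cis as a persistent buffer."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(4, 4)
        self.register_buffer("freqs_cis", torch.randn(4), persistent=True)


def model_only_state_dict(model: nn.Module) -> dict[str, torch.Tensor]:
    # Applied on both the model-only save path and the model-only load path, so
    # DCP never requires freqs_cis to be present in the checkpoint files.
    sd = model.state_dict()
    for key in excluded_parameters_for_model_only:
        sd.pop(key, None)
    return sd


sd = model_only_state_dict(TinyModel())
assert "freqs_cis" not in sd   # the buffer is neither exported nor required on load
assert "linear.weight" in sd   # the actual weights are
```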



@torch.no_grad()
def save_with_gc(state, checkpoint_id):
dcp.save(state, checkpoint_id=checkpoint_id)
@@ -267,6 +273,10 @@ def load_state_dict(state_dict):
self.staging_stream = torch.cuda.Stream() if self.enable_staging else None

self.folder = os.path.join(job_config.job.dump_folder, ckpt_config.folder)
self.initial_load_path = ckpt_config.initial_load_path
self.initial_load_model_weights_only = (
ckpt_config.initial_load_model_weights_only
)
self.interval = ckpt_config.interval
async_mode = ckpt_config.async_mode.lower()
if async_mode == AsyncMode.ASYNC or self.ft_manager:
@@ -287,7 +297,7 @@ def load_state_dict(state_dict):
else:
self.purge_thread = None

self.model_weights_only = ckpt_config.model_weights_only
self.last_save_model_weights_only = ckpt_config.last_save_model_weights_only
self.export_dtype = TORCH_DTYPE_MAP[ckpt_config.export_dtype]
self.exclude_from_loading = ckpt_config.exclude_from_loading

@@ -408,21 +418,38 @@ def load(self, step: int = -1) -> bool:
if self.ft_manager:
self._ft_load()

if not self.enable_checkpoint or not os.path.isdir(self.folder):
if not self.enable_checkpoint:
return False

if step == -1:
step = self._find_load_step()
model_only = False
if not os.path.exists(self.folder):
if self.initial_load_path:
checkpoint_id = self.initial_load_path
if not os.path.isdir(checkpoint_id):
raise ValueError(
"initial_load_full_checkpoint is specified but the path is not valid."
)
model_only = self.initial_load_model_weights_only
else:
return False
else:
if self.initial_load_path:
logger.info(
"`initial_load_path` is provided but the checkpoint folder exists. "
"Checkpointer will use the checkpoints from the checkpoint folder."
)
step = self._find_load_step() if step == -1 else step
if step == -1:
return False
model_only = step == 0
checkpoint_id = self._create_checkpoint_id(step)

checkpoint_id = self._create_checkpoint_id(step)
if not os.path.isdir(checkpoint_id):
return False
if not os.path.isdir(checkpoint_id):
return False

logger.info(f"Loading the checkpoint at step {step}.")
logger.info(f"Loading the checkpoint from {checkpoint_id}.")
begin = time.monotonic()
states = self._states_to_load(step)
states = self._states_to_load(model_only)
dcp.load(states, checkpoint_id=checkpoint_id)
GarbageCollection.collect("GC collection for checkpoint loading.")
logger.info(
@@ -521,28 +548,36 @@ def _ft_load(self) -> None:
f"Finished loading the ft checkpoint in {time.monotonic() - begin:.2f} seconds."
)

def _states_to_load(self, step: int) -> dict[str, Any]:
def _states_to_load(self, model_only: bool) -> dict[str, Any]:
"""Determines which states to load for the given step.

When checkpointer determines which step of the checkpoint to load, this API is
used to determine which states to load based on the step.
This API is used to determine which states to load based on the
configurations.

Args:
step (int): The step to load the checkpoint for.
model_only (bool): Whether to load the model only.

Returns:
Dict[str, Any]: The states to load for the given step.
"""
# For the first step, we will only load the model weights.
states = {MODEL: self.states[MODEL]} if step == 0 else self.states
states_to_load = {
k: v for k, v in states.items() if k not in self.exclude_from_loading
}
if model_only:
sd = self.states[MODEL].state_dict()
for k in excluded_parameters_for_model_only:
sd.pop(k, None)
return sd

for exclude_key in self.exclude_from_loading:
if exclude_key not in states:
if exclude_key not in self.states:
raise ValueError(f"{exclude_key} not found in state_dict.")

states_to_load = {
k: v for k, v in self.states.items() if k not in self.exclude_from_loading
}

if self.ft_manager:
states_to_load.pop(DATALOADER)

return states_to_load

def _save_last_step(self, curr_step: int) -> None:
Expand All @@ -551,18 +586,16 @@ def _save_last_step(self, curr_step: int) -> None:
# dtype conversion when we are checkpoint model weights only and the
# current dtype is not the same as the export dtype at the end of the training.

if self.model_weights_only:
if self.last_save_model_weights_only:
# We update self.states to keep the model only.
# After this update, self.states = {
# 'tok_embeddings.weight':...,
# 'layers.0.attention.wq.weight': ...
# }.
self.states = self.states[MODEL].state_dict()

# For now, we will manually pop the freqs_cis buffer, as we made this permanent
# temporarily and we don't want to include it in the exported state_dict.
# Context: https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama3/model.py#L404
self.states.pop("freqs_cis", None)
for k in excluded_parameters_for_model_only:
self.states.pop(k, None)

if self.export_dtype != torch.float32:
self.states = {
36 changes: 31 additions & 5 deletions torchtitan/config_manager.py
@@ -373,21 +373,47 @@ class Checkpoint:
When enable_checkpoint is set to true, checkpoints will be in {--job.dump_folder}/{--checkpoint.folder}.
"""

initial_load_path: str | None = None
"""
This option specifies the path to the initial checkpoint to load, which is
particularly useful for resuming training from a previous run with a
different output path or when loading a checkpoint from a pre-trained model.
If the checkpoint folder for the current run, located at
{--job.dump_folder}/{--checkpoint.folder}, is not empty, this option will be ignored.
Comment on lines +381 to +382
Contributor:
I would prefer a UI with the opposite load priority:
load_path specifies the path to load the checkpoint from. If None, then load from checkpoint.folder.

Pros: users have explicit control over which folder the checkpoint is loaded from, without checking whether the {--job.dump_folder}/{--checkpoint.folder} folder is empty or not.

The pro of the current PR:
If users are going to run a job repeatedly, loading from the initial checkpoint the first time but from the newly saved checkpoints on all following runs, they don't need to change configs from run to run.

In my proposal, in such a scenario users can just remove the load_path config from the 2nd run onward.

Contributor:
Agreed, thanks Chien-Chin for making it so clear! Currently it's very hard for users to realize that the folder checkpoints are loaded from is {--job.dump_folder}/{--checkpoint.folder}.

Contributor Author (@fegin, May 30, 2025):
The issue is that users may not have a chance to remove the load_path config from the subsequent run.

Suppose users are doing the following training flow:

  1. Use load_path to load from the pre-training checkpoint.
  2. Continue the training or fine-tuning, which periodically saves a checkpoint in checkpoint.folder.
  3. Nodes fail.
  4. Recover from the failure and intend to load the checkpoints.

The issue is that, for step 4, the recovery process may be controlled by the cluster (this is not uncommon; Meta's cluster does this) without user intervention. So the config for the second run will be exactly the same as in step 1, and the trainer will load the pre-training checkpoint instead of what was saved in step 2.
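A condensed sketch of the precedence this PR implements (simplified from the load() changes above, with step discovery and model-only handling omitted), showing why the restart in step 4 resumes from the run's own folder:

```python
import os


def resolve_load_source(folder: str, initial_load_path: str | None) -> str | None:
    """Illustrative only; simplified from CheckpointManager.load() in this PR."""
    if os.path.exists(folder):
        # Once the run has saved anything, its own folder wins, so a cluster-driven
        # restart with the identical config resumes from the checkpoints saved in
        # step 2 instead of re-loading the pre-training checkpoint.
        return folder
    if initial_load_path:
        # First run: nothing saved yet, so bootstrap from the initial checkpoint.
        return initial_load_path
    return None  # nothing to load; start from scratch
```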

Contributor:
I see, makes sense. Let's maybe add this targeted use case to the helper message.

This feature allows users to load an initial checkpoint from a different folder and
continue training, saving new checkpoints to the specified folder without affecting
the existing ones.

Note that the path should contain the full path to the checkpoint folder,
including the step number, if any; for example,
"//pre_train/checkpoints/llama3/llama3_8b/step_10000".
"""

initial_load_model_weights_only: bool = True
"""
This option specifies if only the model weights should be loaded during the initial
checkpoint load. The option is only used when `initial_load_path` is specified.
If False, the checkpoint at `initial_load_path` is treated as a standard training
checkpoint, including optimizer and training states.
The default setting for this option is True. Note that you will have to use
`--checkpoint.no_initial_load_model_weights_only` to override the default setting.
"""

interval: int = 500
"""Checkpointing interval in steps."""

model_weights_only: bool = False
last_save_model_weights_only: bool = False
"""
When model_weights_only=True, only model weights will be saved at the end of training.
With this, checkpoints can be loaded using `torch.load(..., weights_only=True)` after conversion.
When model_weights_only=False, the full checkpoint will be saved.
When last_save_model_weights_only=True, only model weights will be saved at the end of training,
the last save. With this, checkpoints can be loaded using `torch.load(..., weights_only=True)`
after conversion. When last_save_model_weights_only=False, the full checkpoint will be saved.
A full checkpoint includes model, optimizer and train_state, which can be used to resume training.
The default value is false.
"""

export_dtype: Literal["float16", "bfloat16", "float32"] = "float32"
"""
Converts to the specified precision when training completes and model_weights_only=true.
Converts to the specified precision when training completes and last_save_model_weights_only=true.
"""

create_seed_checkpoint: bool = False
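Putting the options from this section together, here is a hypothetical, trimmed-down stand-in for the checkpoint config (field names match the diff above; the pre-training path reuses the docstring's example), set up for the bootstrap-then-resume flow discussed in the review thread:

```python
from dataclasses import dataclass, field


@dataclass
class CheckpointSection:
    # Hypothetical subset of the Checkpoint dataclass, for illustration only.
    enable_checkpoint: bool = True
    folder: str = "checkpoint"
    interval: int = 500
    # Bootstrap the first run from a pre-training checkpoint (example path from the
    # docstring); ignored once {--job.dump_folder}/{--checkpoint.folder} is non-empty.
    initial_load_path: str | None = "//pre_train/checkpoints/llama3/llama3_8b/step_10000"
    initial_load_model_weights_only: bool = True
    # At the end of training, export only the model weights, converted to bf16.
    last_save_model_weights_only: bool = True
    export_dtype: str = "bfloat16"
    exclude_from_loading: list[str] = field(default_factory=list)
```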
4 changes: 2 additions & 2 deletions torchtitan/experiments/flux/tests/integration_tests.py
@@ -58,11 +58,11 @@ def build_test_list():
[
[
"--checkpoint.enable_checkpoint",
"--checkpoint.model_weights_only",
"--checkpoint.last_save_model_weights_only",
],
],
"Checkpoint Integration Test - Save Model Weights Only fp32",
"model_weights_only_fp32",
"last_save_model_weights_only_fp32",
),
# Parallelism tests.
OverrideDefinitions(
2 changes: 1 addition & 1 deletion torchtitan/experiments/flux/train_configs/debug_model.toml
@@ -68,6 +68,6 @@ mode = "full"
enable_checkpoint = false
folder = "checkpoint"
interval = 5
model_weights_only = false
last_save_model_weights_only = false
export_dtype = "float32"
async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]
@@ -67,6 +67,6 @@ mode = "full"
enable_checkpoint = false
folder = "checkpoint"
interval = 1_000
model_weights_only = false
last_save_model_weights_only = false
export_dtype = "float32"
async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]
@@ -67,6 +67,6 @@ mode = "full"
enable_checkpoint = false
folder = "checkpoint"
interval = 1_000
model_weights_only = false
last_save_model_weights_only = false
export_dtype = "float32"
async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]
@@ -322,7 +322,6 @@ def _reshard_send(
def _reshard_receive(
self, assignment: _Assignment, state_dict: dict[str, torch.Tensor]
) -> dict[str, torch.Tensor]:

flatten_tensor = torch.empty(
sum(math.prod(s) for s, d in zip(assignment.shapes, assignment.dtypes)),
dtype=assignment.dtypes[0],
@@ -535,7 +534,7 @@ def state_dict(self) -> dict[str, torch.Tensor]:
# oh, this is pretty bad, when can we get rid of the freqs_cis issue?
state_dict["freqs_cis"] = None
trainer.checkpointer.states[MODEL] = DummyModel(state_dict)
trainer.checkpointer.model_weights_only = True
trainer.checkpointer.last_save_model_weights_only = True
trainer.checkpointer.export_dtype = next(iter(state_dict.values())).dtype
trainer.checkpointer.save(curr_step=0, force=True)
time.sleep(2)
@@ -529,7 +529,7 @@ def state_dict(self) -> dict[str, torch.Tensor]:
# oh, this is pretty bad, when can we get rid of the freqs_cis issue?
state_dict["freqs_cis"] = None
trainer.checkpointer.states[MODEL] = DummyModel(state_dict)
trainer.checkpointer.model_weights_only = True
trainer.checkpointer.last_save_model_weights_only = True
trainer.checkpointer.export_dtype = next(iter(state_dict.values())).dtype
trainer.checkpointer.save(curr_step=0, force=True)
time.sleep(2)
@@ -57,7 +57,7 @@ context_parallel_degree = 1
enable_checkpoint = false
folder = "checkpoint"
interval = 10
model_weights_only = false
last_save_model_weights_only = false
export_dtype = "float32"
async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]

@@ -51,7 +51,7 @@ context_parallel_degree = 1
enable_checkpoint = false
folder = "checkpoint"
interval = 500
model_weights_only = false
last_save_model_weights_only = false
export_dtype = "float32"
async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]

@@ -49,7 +49,7 @@ context_parallel_degree = 1
enable_checkpoint = false
folder = "checkpoint"
interval = 500
model_weights_only = false
last_save_model_weights_only = false
export_dtype = "float32"
async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]

2 changes: 1 addition & 1 deletion torchtitan/models/llama3/train_configs/debug_model.toml
@@ -59,7 +59,7 @@ context_parallel_degree = 1
enable_checkpoint = false
folder = "checkpoint"
interval = 10
model_weights_only = false
last_save_model_weights_only = false
export_dtype = "float32"
async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]

2 changes: 1 addition & 1 deletion torchtitan/models/llama3/train_configs/llama3_405b.toml
@@ -49,7 +49,7 @@ context_parallel_degree = 1
enable_checkpoint = false
folder = "checkpoint"
interval = 500
model_weights_only = false
last_save_model_weights_only = false
export_dtype = "float32"
async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]
