Implement initial_load_path for checkpointer #1236
@@ -81,6 +81,12 @@ class SaveDone:
     pass


+# For now, we will manually pop the freqs_cis buffer, as we made this permanent
+# temporarily and we don't want to include it in the exported state_dict.
+# Context: https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama3/model.py#L404
+excluded_parameters_for_model_only = {"freqs_cis"}

Review thread on excluded_parameters_for_model_only:
- If we don't save [...] In https://github.com/pytorch/torchtitan/blob/main/scripts/convert_llama_to_dcp.py#L127 [...]
- Since we pop out the [...]
- synced offline and understood the reason: [...]

+
+
 @torch.no_grad()
 def save_with_gc(state, checkpoint_id):
     dcp.save(state, checkpoint_id=checkpoint_id)
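The review thread above concerns how this exclusion interacts with the weight conversion scripts; the underlying mechanics are simple enough to sketch. The snippet below is illustrative only (TinyModel is a made-up stand-in, not torchtitan code): a buffer registered as persistent shows up in state_dict(), so it has to be popped explicitly before a model-weights-only export or load.

import torch
import torch.nn as nn

excluded_parameters_for_model_only = {"freqs_cis"}

class TinyModel(nn.Module):
    # Made-up stand-in: registers a persistent buffer the way llama3/model.py
    # currently keeps freqs_cis persistent.
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)
        self.register_buffer("freqs_cis", torch.zeros(4))

sd = TinyModel().state_dict()
print(sorted(sd))          # ['freqs_cis', 'linear.bias', 'linear.weight']
for k in excluded_parameters_for_model_only:
    sd.pop(k, None)        # drop the buffer before exporting model-only weights
print(sorted(sd))          # ['linear.bias', 'linear.weight']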
@@ -267,6 +273,10 @@ def load_state_dict(state_dict):
         self.staging_stream = torch.cuda.Stream() if self.enable_staging else None

         self.folder = os.path.join(job_config.job.dump_folder, ckpt_config.folder)
+        self.initial_load_path = ckpt_config.initial_load_path
+        self.initial_load_model_weights_only = (
+            ckpt_config.initial_load_model_weights_only
+        )
         self.interval = ckpt_config.interval
         async_mode = ckpt_config.async_mode.lower()
         if async_mode == AsyncMode.ASYNC or self.ft_manager:

@@ -287,7 +297,7 @@ def load_state_dict(state_dict):
         else:
             self.purge_thread = None

-        self.model_weights_only = ckpt_config.model_weights_only
+        self.last_save_model_weights_only = ckpt_config.last_save_model_weights_only
         self.export_dtype = TORCH_DTYPE_MAP[ckpt_config.export_dtype]
         self.exclude_from_loading = ckpt_config.exclude_from_loading

@@ -408,21 +418,38 @@ def load(self, step: int = -1) -> bool:
         if self.ft_manager:
             self._ft_load()

-        if not self.enable_checkpoint or not os.path.isdir(self.folder):
+        if not self.enable_checkpoint:
             return False

-        if step == -1:
-            step = self._find_load_step()
+        model_only = False
+        if not os.path.exists(self.folder):
+            if self.initial_load_path:
+                checkpoint_id = self.initial_load_path
+                if not os.path.isdir(checkpoint_id):
+                    raise ValueError(
+                        "initial_load_full_checkpoint is specified but the path is not valid."
+                    )
+                model_only = self.initial_load_model_weights_only
+            else:
+                return False
+        else:
+            if self.initial_load_path:
+                logger.info(
+                    "`initial_load_path` is provided but the checkpoint folder exists. "
+                    "Checkpointer will use the checkpoints from the checkpoint folder."
+                )
+            step = self._find_load_step() if step == -1 else step
+            if step == -1:
+                return False
+            model_only = step == 0
+            checkpoint_id = self._create_checkpoint_id(step)

-        checkpoint_id = self._create_checkpoint_id(step)
-        if not os.path.isdir(checkpoint_id):
-            return False
+        if not os.path.isdir(checkpoint_id):
+            return False

-        logger.info(f"Loading the checkpoint at step {step}.")
+        logger.info(f"Loading the checkpoint from {checkpoint_id}.")
         begin = time.monotonic()
-        states = self._states_to_load(step)
+        states = self._states_to_load(model_only)
         dcp.load(states, checkpoint_id=checkpoint_id)
         GarbageCollection.collect("GC collection for checkpoint loading.")
         logger.info(
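Condensed out of the diff above, the source-resolution logic in load() now behaves roughly as follows. This is a simplified sketch for illustration, not the PR code; the helper name and the callable parameters are invented.

import os

def resolve_checkpoint_source(
    folder, initial_load_path, initial_load_model_weights_only,
    find_load_step, create_checkpoint_id, step=-1,
):
    # Returns (checkpoint_id, model_only), or None when there is nothing to load.
    if not os.path.exists(folder):
        # Fresh run: the run's own checkpoint folder does not exist yet, so fall
        # back to the user-provided initial checkpoint, if any.
        if not initial_load_path:
            return None
        if not os.path.isdir(initial_load_path):
            raise ValueError("initial_load_path is specified but the path is not valid.")
        return initial_load_path, initial_load_model_weights_only
    # The run's own folder exists: it always takes priority over initial_load_path.
    step = find_load_step() if step == -1 else step
    if step == -1:
        return None
    # step == 0 conventionally marks a seed checkpoint, which is loaded model-only.
    return create_checkpoint_id(step), step == 0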
@@ -521,28 +548,36 @@ def _ft_load(self) -> None:
             f"Finished loading the ft checkpoint in {time.monotonic() - begin:.2f} seconds."
         )

-    def _states_to_load(self, step: int) -> dict[str, Any]:
+    def _states_to_load(self, model_only: bool) -> dict[str, Any]:
         """Determines which states to load for the given step.

-        When checkpointer determines which step of the checkpoint to load, this API is
-        used to determine which states to load based on the step.
+        This API is used to determine which states to load based on the
+        configurations.

         Args:
-            step (int): The step to load the checkpoint for.
+            model_only (bool): Whether to load the model only.

         Returns:
             Dict[str, Any]: The states to load for the given step.
         """
-        # For the first step, we will only load the model weights.
-        states = {MODEL: self.states[MODEL]} if step == 0 else self.states
-        states_to_load = {
-            k: v for k, v in states.items() if k not in self.exclude_from_loading
-        }
+        if model_only:
+            sd = self.states[MODEL].state_dict()
+            for k in excluded_parameters_for_model_only:
+                sd.pop(k, None)
+            return sd

         for exclude_key in self.exclude_from_loading:
-            if exclude_key not in states:
+            if exclude_key not in self.states:
                 raise ValueError(f"{exclude_key} not found in state_dict.")

+        states_to_load = {
+            k: v for k, v in self.states.items() if k not in self.exclude_from_loading
+        }
+
         if self.ft_manager:
             states_to_load.pop(DATALOADER)

         return states_to_load

     def _save_last_step(self, curr_step: int) -> None:
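Put differently, _states_to_load now hands dcp.load one of two shapes depending on model_only. A rough illustration with a hypothetical stand-in model (not torchtitan code):

import torch
import torch.nn as nn

model = nn.Linear(4, 4)                            # stand-in for self.states[MODEL]
optimizer = torch.optim.AdamW(model.parameters())

# model_only=True: a flat tensor state dict, with any key in
# excluded_parameters_for_model_only already popped.
model_only_states = model.state_dict()

# model_only=False: the full training state keyed by component; in the real
# checkpointer the values are Stateful wrappers rather than raw state dicts,
# and exclude_from_loading keys are filtered out first.
full_states = {"model": model.state_dict(), "optimizer": optimizer.state_dict()}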
@@ -551,18 +586,16 @@ def _save_last_step(self, curr_step: int) -> None:
         # dtype conversion when we are checkpoint model weights only and the
         # current dtype is not the same as the export dtype at the end of the training.

-        if self.model_weights_only:
+        if self.last_save_model_weights_only:
             # We update self.states to keep the model only.
             # After this update, self.states = {
             #     'tok_embeddings.weight': ...,
             #     'layers.0.attention.wq.weight': ...
             # }.
             self.states = self.states[MODEL].state_dict()

-            # For now, we will manually pop the freqs_cis buffer, as we made this permanent
-            # temporarily and we don't want to include it in the exported state_dict.
-            # Context: https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama3/model.py#L404
-            self.states.pop("freqs_cis", None)
+            for k in excluded_parameters_for_model_only:
+                self.states.pop(k, None)

             if self.export_dtype != torch.float32:
                 self.states = {
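For reference, the full last-save export path (keep model weights, drop excluded buffers, optionally downcast) amounts to something like the sketch below; it is illustrative only and uses a tiny stand-in model rather than the real training states.

import torch
import torch.nn as nn

excluded_parameters_for_model_only = {"freqs_cis"}
export_dtype = torch.bfloat16        # e.g. from --checkpoint.export_dtype
model = nn.Linear(8, 8)              # stand-in for self.states[MODEL]

states = model.state_dict()
for k in excluded_parameters_for_model_only:
    states.pop(k, None)              # same exclusion set used by _states_to_load

if export_dtype != torch.float32:
    # Downcast floating-point tensors for the exported, weights-only checkpoint.
    states = {
        k: v.to(export_dtype) if torch.is_floating_point(v) else v
        for k, v in states.items()
    }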
@@ -373,21 +373,47 @@ class Checkpoint:
     When enable_checkpoint is set to true, checkpoints will be in {--job.dump_folder}/{--checkpoint.folder}.
     """

+    initial_load_path: str | None = None
+    """
+    This option specifies the path to the initial checkpoint to load, which is
+    particularly useful for resuming training from a previous run with a
+    different output path or when loading a checkpoint from a pre-trained model.
+    If the checkpoint folder for the current run is not empty,
+    located at {--job.dump_folder}/{--checkpoint.folder}, this option will be ignored.

Review thread on lines +381 to +382:
- I would prefer a UI with the opposite load priority: [...] Pros: users have explicit control over which folder to load the checkpoint from, without checking if the [...] The pros of the current PR: [...] In my proposal, in such a scenario users can just remove the [...]
- Agreed, thanks Chien-Chin for making it so clear! Currently it's very hard for users to realize that the folder to load the checkpoint from is {--job.dump_folder}/{--checkpoint.folder}.
- The issue is that users may not have a chance to remove the [...] Suppose users are doing the following training flow: [...] The issue is that, for step 4, the recovery process may be controlled by the cluster (this is not uncommon, and Meta's cluster does this), without user intervention. So the config for the [...]
- I see, makes sense. Let's maybe add this targeted use case to the helper message.

+    This feature allows users to load an initial checkpoint from a different folder and
+    continue training, saving new checkpoints to the specified folder without affecting
+    the existing ones.
+
+    Note that the path should contain the full path to the checkpoint folder,
+    including the step number, if any; for example,
+    "//pre_train/checkpoints/llama3/llama3_8b/step_10000".
+    """

+    initial_load_model_weights_only: bool = True
+    """
+    This option specifies if only the model weights should be loaded during the initial
+    checkpoint load. The option is only used when `initial_load_path` is specified.
+    If False, the checkpoint at `initial_load_path` is treated as a standard training
+    checkpoint, including optimizer and training states.
+    The default setting for this option is True. Note that you will have to use
+    `--checkpoint.no_initial_load_model_weights_only` to override the default setting.
+    """
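To make the intended workflow concrete: a hedged sketch of setting the two new options programmatically. The field names come from this diff, but the trimmed dataclass and the instantiation below are illustrative, not the PR's config plumbing.

from dataclasses import dataclass

@dataclass
class Checkpoint:                      # trimmed to the fields relevant here
    enable_checkpoint: bool = False
    folder: str = "checkpoint"
    initial_load_path: str | None = None
    initial_load_model_weights_only: bool = True

# Warm-start a new run from a pre-trained checkpoint: on the first run the
# run's own {--job.dump_folder}/{--checkpoint.folder} does not exist yet, so
# the checkpointer loads from initial_load_path (model weights only by
# default); on later restarts the run's own folder exists and takes priority.
ckpt_config = Checkpoint(
    enable_checkpoint=True,
    initial_load_path="//pre_train/checkpoints/llama3/llama3_8b/step_10000",
)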

     interval: int = 500
     """Checkpointing interval in steps."""

-    model_weights_only: bool = False
+    last_save_model_weights_only: bool = False
     """
-    When model_weights_only=True, only model weights will be saved at the end of training.
-    With this, checkpoints can be loaded using `torch.load(..., weights_only=True)` after conversion.
-    When model_weights_only=False, the full checkpoint will be saved.
+    When last_save_model_weights_only=True, only model weights will be saved at the end of training,
+    the last save. With this, checkpoints can be loaded using `torch.load(..., weights_only=True)`
+    after conversion. When last_save_model_weights_only=False, the full checkpoint will be saved.
     A full checkpoint includes model, optimizer and train_state, which can be used to resume training.
     The default value is false.
     """

     export_dtype: Literal["float16", "bfloat16", "float32"] = "float32"
     """
-    Converts to the specified precision when training completes and model_weights_only=true.
+    Converts to the specified precision when training completes and last_save_model_weights_only=true.
     """

     create_seed_checkpoint: bool = False
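As the `last_save_model_weights_only` docstring notes, the exported weights are meant to be consumable with plain `torch.load` after they have been converted out of the DCP layout. A minimal sketch of that consumption step; the file name is hypothetical and the conversion itself is not shown.

import torch

state_dict = torch.load(
    "llama3_8b_weights.pt",   # hypothetical artifact produced by a DCP-to-torch conversion
    weights_only=True,        # safe load: plain tensors and containers only
    map_location="cpu",
)
# model.load_state_dict(state_dict)  # then load into a compatible model instance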