Commit aefe15a

Always ignore freqs_cis (#1338)
We should always ignore freqs_cis and the other parameters listed in excluded_parameters_for_model_only, to avoid confusion. TODO: will this break PP with a seed checkpoint?
1 parent 71b07ad commit aefe15a
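In effect, the exclusion of freqs_cis moves from ad-hoc pops at each call site into ModelWrapper itself, so every model state dict the checkpointer produces (cached, saved, or exported) is already filtered. A minimal standalone sketch of that pattern — ToyModel is hypothetical; only the set name mirrors torchtitan/components/checkpoint.py:

```python
# Minimal sketch of the centralized exclusion pattern in this commit.
# ToyModel is a made-up stand-in, not torchtitan's real model.
import torch
import torch.nn as nn

excluded_parameters_for_model_only = {"freqs_cis"}


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(2, 2))
        # A persistent buffer shows up in state_dict() by default.
        self.register_buffer("freqs_cis", torch.randn(4, 2))


def filtered_state_dict(model: nn.Module) -> dict[str, torch.Tensor]:
    # Build the state dict once, then drop excluded keys, so every consumer
    # (cache, save, export) sees the same filtered view.
    sd = dict(model.state_dict())
    for k in excluded_parameters_for_model_only:
        sd.pop(k, None)
    return sd


sd = filtered_state_dict(ToyModel())
assert "freqs_cis" not in sd and "weight" in sd
```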

File tree

6 files changed (+64, -38 lines)

scripts/convert_llama_to_dcp.py

Lines changed: 0 additions & 8 deletions
```diff
@@ -10,7 +10,6 @@
 
 import torch
 import torch.distributed.checkpoint as DCP
-from torchtitan.models.llama.model import precompute_freqs_cis
 from torchtitan.tools.logging import init_logger, logger
 
 
@@ -123,13 +122,6 @@ def convert_llama_weights(input_dir, output_dir, max_seq_len: int):
     for i in range(len(shards)):
         del shards[i]["output.weight"]
 
-    # NOTE: precompute freqs_cis because must be persisted by default in torchtitan
-    state_dict["freqs_cis"] = precompute_freqs_cis(
-        dims_per_head,
-        max_seq_len,
-        params.get("rope_theta", 500000),
-    )
-
     logger.info(f"Writing to DCP at '{output_dir}'")
     output_dir.mkdir(parents=True, exist_ok=True)
     storage_writer = DCP.filesystem.FileSystemWriter(output_dir, thread_count=8)
```
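With the exclusion now enforced centrally, the conversion script no longer needs to bake freqs_cis into the DCP output: the buffer is fully determined by model hyperparameters and is rebuilt at model initialization. For orientation, a Llama-style precompute_freqs_cis looks roughly like this — a paraphrase for illustration, not torchtitan's exact code:

```python
import torch


def precompute_freqs_cis(dim: int, end: int, theta: float = 500000.0) -> torch.Tensor:
    # Rotary-embedding angles as complex exponentials. Everything here is a
    # pure function of the config, so there is nothing checkpoint-worthy.
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    t = torch.arange(end)
    angles = torch.outer(t, freqs)
    return torch.polar(torch.ones_like(angles), angles)


freqs_cis = precompute_freqs_cis(dim=128, end=2048)
assert freqs_cis.shape == (2048, 64)
```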

scripts/generate/test_generate.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -24,8 +24,8 @@
     parallelize_module,
     RowwiseParallel,
 )
+from torchtitan.components.checkpoint import excluded_parameters_for_model_only
 from torchtitan.components.metrics import build_device_memory_monitor
-
 from torchtitan.config_manager import ConfigManager
 from torchtitan.distributed import ParallelDims, utils as dist_utils
 from torchtitan.protocols.train_spec import get_train_spec
@@ -142,6 +142,8 @@ def test_generate(
     model.eval()
 
     state_dict = {"model": model.state_dict()}
+    for k in excluded_parameters_for_model_only:
+        state_dict["model"].pop(k, None)
 
     # Checkpoint Loading
     begin = time.monotonic()
```
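The pop before loading matters because DCP resolves a load against the keys of the state dict you pass in; checkpoints written after this change contain no freqs_cis entry, so requesting one would fail. A self-contained round trip illustrating this, assuming a recent PyTorch where dcp.save/dcp.load work single-process without an initialized process group:

```python
import tempfile

import torch
import torch.nn as nn
import torch.distributed.checkpoint as dcp

excluded_parameters_for_model_only = {"freqs_cis"}  # mirrors the import above


class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(2, 2))
        self.register_buffer("freqs_cis", torch.randn(4, 2))


def filtered(model: nn.Module) -> dict[str, torch.Tensor]:
    sd = dict(model.state_dict())
    for k in excluded_parameters_for_model_only:
        sd.pop(k, None)
    return sd


with tempfile.TemporaryDirectory() as ckpt_dir:
    # Save without freqs_cis, matching what CheckpointManager now writes.
    dcp.save({"model": filtered(Toy())}, checkpoint_id=ckpt_dir)
    # On load, request only the keys the checkpoint actually contains.
    target = {"model": filtered(Toy())}
    dcp.load(target, checkpoint_id=ckpt_dir)
```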

tests/unit_tests/test_checkpoint.py

Lines changed: 46 additions & 0 deletions
```diff
@@ -534,6 +534,52 @@ def test_enable_first_step_checkpoint(self, mock_save, mock_rank):
 
         manager2.close()
 
+    @mock.patch("torch.distributed.get_rank", return_value=0)
+    @mock.patch("torchtitan.components.checkpoint.dcp.save")
+    def test_excluded_parameters_not_saved(self, mock_save, mock_rank):
+        """Test that freqs_cis is not saved"""
+
+        # Create a fake model with freqs_cis and other parameters
+        class FakeModelWithFreqsCis(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.weight = nn.Parameter(torch.randn(2, 2))
+                self.bias = nn.Parameter(torch.randn(2))
+                # Register freqs_cis as a buffer (common pattern in transformer models)
+                self.register_buffer("freqs_cis", torch.randn(10, 5))
+                self.other_param = nn.Parameter(torch.randn(3, 3))
+
+        fake_model = FakeModelWithFreqsCis()
+        mock_save.side_effect = self.fake_save
+
+        cfg = self.job_config.checkpoint
+        cfg.keep_latest_k = 0  # Disable purging
+
+        manager = CheckpointManager(
+            dataloader=self.data_loader,
+            model_parts=[fake_model],
+            optimizers=self.optimizers,
+            lr_schedulers=self.lr_schedulers,
+            states=self.states,
+            job_config=self.job_config,
+            ft_manager=self.ft_manager,
+        )
+
+        manager.save(curr_step=1)
+        self.assertEqual(mock_save.call_count, 1)
+        checkpoint_path = os.path.join(self.test_folder, "step-1", "state_dict.pt")
+        saved_data = torch.load(checkpoint_path, weights_only=False)
+        model_state_dict = saved_data[MODEL]
+
+        # Verify that freqs_cis is NOT in the saved state dict
+        self.assertNotIn("freqs_cis", model_state_dict)
+        # Verify that other parameters ARE in the saved state dict
+        self.assertIn("weight", model_state_dict)
+        self.assertIn("bias", model_state_dict)
+        self.assertIn("other_param", model_state_dict)
+
+        manager.close()
+
 
 if __name__ == "__main__":
     unittest.main()
```
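The new test should run in isolation with the usual unittest filter, e.g. `python -m unittest tests.unit_tests.test_checkpoint -k test_excluded_parameters_not_saved` (assuming the repository root is on PYTHONPATH).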

torchtitan/components/checkpoint.py

Lines changed: 15 additions & 15 deletions
```diff
@@ -49,12 +49,25 @@ class AsyncMode(str, enum.Enum):
     ASYNC_WITH_PINNED_MEM = "async_with_pinned_mem"
 
 
+# For now, we will manually pop the freqs_cis buffer, as we made this permanent
+# temporarily and we don't want to include it in the exported state_dict.
+# Context: https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama3/model.py#L404
+excluded_parameters_for_model_only = {"freqs_cis"}
+
+
 class ModelWrapper(Stateful):
     def __init__(self, model: nn.Module | list[nn.Module]) -> None:
         self.model = [model] if isinstance(model, nn.Module) else model
-        self.cache_state_dict = {
+        self.cache_state_dict = self._get_state_dict()
+
+    def _get_state_dict(self) -> dict[str, Any]:
+        state_dict = {
             k: v for sd in map(get_model_state_dict, self.model) for k, v in sd.items()
         }
+        # Exclude parameters that should not be saved
+        for excluded_key in excluded_parameters_for_model_only:
+            state_dict.pop(excluded_key, None)
+        return state_dict
 
     def state_dict(self) -> dict[str, Any]:
         return self.cache_state_dict
@@ -68,9 +81,7 @@ def load_state_dict(self, state_dict: dict[str, Any]) -> None:
         list(map(func, self.model))
         # `set_model_state_dict()` does change the keys of the input state_dict,
         # we will need to reinitialize the cache_state_dict.
-        self.cache_state_dict = {
-            k: v for sd in map(get_model_state_dict, self.model) for k, v in sd.items()
-        }
+        self.cache_state_dict = self._get_state_dict()
 
 
 class Terminate:
@@ -81,12 +92,6 @@ class SaveDone:
     pass
 
 
-# For now, we will manually pop the freqs_cis buffer, as we made this permanent
-# temporarily and we don't want to include it in the exported state_dict.
-# Context: https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama3/model.py#L404
-excluded_parameters_for_model_only = {"freqs_cis"}
-
-
 @torch.no_grad()
 def save_with_gc(state, checkpoint_id):
     dcp.save(state, checkpoint_id=checkpoint_id)
@@ -569,8 +574,6 @@ def _states_to_load(self, model_only: bool) -> dict[str, Any]:
         # For the first step, we will only load the model weights.
         if model_only:
             sd = self.states[MODEL].state_dict()
-            for k in excluded_parameters_for_model_only:
-                sd.pop(k, None)
             return sd
 
         for exclude_key in self.exclude_from_loading:
@@ -600,9 +603,6 @@ def _save_last_step(self, curr_step: int) -> None:
         # }.
         self.states = self.states[MODEL].state_dict()
 
-        for k in excluded_parameters_for_model_only:
-            self.states.pop(k, None)
-
         if self.export_dtype != torch.float32:
             self.states = {
                 k: v.to(self.export_dtype) for k, v in self.states.items()
```
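The relocated comment still describes the persistent freqs_cis buffer as a temporary state of affairs. The eventual alternative it hints at (an assumption on my part, not part of this commit) would be to register the buffer as non-persistent, letting PyTorch keep it out of state_dict() with no manual exclusion set at all:

```python
import torch
import torch.nn as nn


class RoPEBlock(nn.Module):  # hypothetical module, for illustration only
    def __init__(self):
        super().__init__()
        self.wq = nn.Linear(8, 8, bias=False)
        # persistent=False keeps the buffer out of state_dict() entirely.
        self.register_buffer("freqs_cis", torch.randn(16, 4), persistent=False)


assert "freqs_cis" not in RoPEBlock().state_dict()
assert "wq.weight" in RoPEBlock().state_dict()
```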

torchtitan/experiments/llama4/scripts/convert_hf_to_dcp_with_gpus.py

Lines changed: 0 additions & 7 deletions
```diff
@@ -497,11 +497,6 @@ class MyJobConfig:
         size += v.numel() * v.element_size()
     logger.info(f"Total size of the model: {size / 1e9:.2f} GB")
 
-    # Do not support PP yet, we will need to iterate over the PP dimension and
-    # extract the corresponding state_dict and device_mesh.
-    if "freqs_cis" in state_dict:
-        state_dict.pop("freqs_cis")
-
     # Our tokenizer is not up-to-date yet.
     tok_embeddings_weight = state_dict.pop("tok_embeddings.weight")
     output_weight = state_dict.pop("output.weight")
@@ -531,8 +526,6 @@ def state_dict(self) -> dict[str, torch.Tensor]:
         dist.barrier()
         logger.info(f"Verifies state_dict {time.time() - begin}.")
     else:
-        # oh, this is pretty bad, when can we get rid of the freqs_cis issue?
-        state_dict["freqs_cis"] = None
         trainer.checkpointer.states[MODEL] = DummyModel(state_dict)
         trainer.checkpointer.last_save_model_weights_only = True
         trainer.checkpointer.export_dtype = next(iter(state_dict.values())).dtype
```

torchtitan/experiments/llama4/scripts/convert_meta_to_dcp_with_gpus.py

Lines changed: 0 additions & 7 deletions
```diff
@@ -498,11 +498,6 @@ class MyJobConfig:
         size += v.numel() * v.element_size()
     logger.info(f"Total size of the model: {size / 1e9:.2f} GB")
 
-    # Do not support PP yet, we will need to iterate over the PP dimension and
-    # extract the corresponding state_dict and device_mesh.
-    if "freq_cis" in state_dict:
-        state_dict.pop("freqs_cis")
-
     state_dict = CheckpointConverter(
         process_group=trainer.world_mesh.get_group(),
         path=config.checkpoint.convert_path,
@@ -526,8 +521,6 @@ def state_dict(self) -> dict[str, torch.Tensor]:
         dist.barrier()
         logger.info(f"Verifies state_dict {time.time() - begin}.")
     else:
-        # oh, this is pretty bad, when can we get rid of the freqs_cis issue?
-        state_dict["freqs_cis"] = None
         trainer.checkpointer.states[MODEL] = DummyModel(state_dict)
         trainer.checkpointer.last_save_model_weights_only = True
         trainer.checkpointer.export_dtype = next(iter(state_dict.values())).dtype
```
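Both conversion scripts drop their manual freqs_cis handling because the checkpointer now filters it. A quick sanity check that a converted checkpoint really carries no such entry is to read the DCP metadata without loading any tensor data (the path below is a placeholder):

```python
import torch.distributed.checkpoint as dcp

# FileSystemReader.read_metadata() returns the checkpoint's key metadata only.
reader = dcp.FileSystemReader("/path/to/dcp/checkpoint")
metadata = reader.read_metadata()
offending = [k for k in metadata.state_dict_metadata if "freqs_cis" in k]
assert not offending, f"freqs_cis leaked into the checkpoint: {offending}"
```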
