[WIP][RFC] Always flatten model state_dict #1347
Conversation
could you help verify "save seed checkpoint -> load from seed checkpoint with different parallelism" still yields identical loss, as an additional integration test?
```diff
@@ -573,8 +590,7 @@ def _states_to_load(self, model_only: bool) -> dict[str, Any]:
         """
         # For the first step, we will only load the model weights.
         if model_only:
-            sd = self.states[MODEL].state_dict()
-            return sd
+            return self.states[MODEL].state_dict()
```
Does this mean "model" still exists in state_dict as a key -- we only flatten it in the checkpoint (and its load and save)?
In Checkpointer, we still keep separate keys in `self.states`, like MODEL and OPTIMIZER. This allows us to manipulate the different state_dicts. This line uses MODEL to access only the model state_dict, but it does not wrap the model state_dict, so there will be no `model.` prefix.
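To make the prefix behavior concrete, here is a minimal, self-contained sketch. The `checkpoint_keys` helper and the plain `nn.Linear` model are illustrative stand-ins (not torchtitan or DCP code) for how a checkpoint backend derives storage keys by flattening nested dicts:

```python
# Illustrative only: a tiny stand-in for how a checkpoint backend derives
# storage keys by flattening nested dicts with "." separators.
import torch.nn as nn

def checkpoint_keys(state_dict: dict, prefix: str = "") -> list[str]:
    """Recursively flatten nested dicts into dotted storage keys."""
    keys: list[str] = []
    for name, value in state_dict.items():
        full = f"{prefix}{name}"
        if isinstance(value, dict):
            keys.extend(checkpoint_keys(value, prefix=f"{full}."))
        else:
            keys.append(full)
    return keys

model = nn.Linear(4, 4)

# Wrapped under a "model" key: storage keys carry the "model." prefix.
print(checkpoint_keys({"model": dict(model.state_dict())}))
# ['model.weight', 'model.bias']

# Flattened (what this PR does for the model state_dict): bare parameter FQNs.
print(checkpoint_keys(dict(model.state_dict())))
# ['weight', 'bias']
```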
@tianyu-l The loss curves don't match, w/ or w/o freqs_cis in the seed checkpoint. The seed checkpoint may have been broken before.
In terms of loss curve matching:
local_batch_size is 2 when TP is 2; otherwise local_batch_size is 1.
To make sure the dataloader behaves consistently across multiple runs, we need to fix the DP degree (dp_replicate * dp_shard).
So a typical comparison you may run (keeping the overall DP degree at 4) could be
- FSDP 4
- DP 2, FSDP 2, TP 2
- DP 2, FSDP 2, CP 2, PP 2

For details and more examples, please see
https://github.com/pytorch/torchtitan/blob/main/docs/converging.md
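As a quick sanity check on these configurations (assuming DP maps to dp_replicate and FSDP to dp_shard, as in the list above), each layout keeps dp_replicate * dp_shard = 4 even though the world size differs:

```python
# Pure arithmetic check: every suggested layout keeps the overall DP degree
# (dp_replicate * dp_shard) at 4, so the dataloader sharding is identical.
configs = {
    "FSDP 4":                   {"dp_replicate": 1, "dp_shard": 4, "tp": 1, "cp": 1, "pp": 1},
    "DP 2, FSDP 2, TP 2":       {"dp_replicate": 2, "dp_shard": 2, "tp": 2, "cp": 1, "pp": 1},
    "DP 2, FSDP 2, CP 2, PP 2": {"dp_replicate": 2, "dp_shard": 2, "tp": 1, "cp": 2, "pp": 2},
}

for name, c in configs.items():
    dp_degree = c["dp_replicate"] * c["dp_shard"]
    world_size = dp_degree * c["tp"] * c["cp"] * c["pp"]
    print(f"{name}: DP degree = {dp_degree}, world size = {world_size}")
    assert dp_degree == 4
```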
@tianyu-l I also tried FSDP 4 vs. FSDP 4 TP 2 before I tried this setting, and the loss curves didn't match.
SGTM!
The model state_dict is unique compared to other state dictionaries (e.g., optimizer). It's the only one that will be exported outside of TorchTitan and imported from other sources. To ensure FQN consistency, we previously removed the prefix during the first checkpoint load and the last checkpoint save. However, this approach has caused confusion among users, despite the available options to control the behavior.

This PR aims to resolve the issue by always flattening the model state dictionary, eliminating the `"MODEL."` prefix from its keys. We decided not to flatten all components due to the risk of key collisions between different components. Instead, this PR only flattens the model state_dict, which is a special case. While this solution isn't perfect, as it introduces different handling for different components, it's a good compromise given the unique nature of the model state_dict.

Also see the discussion in #1321 (comment).

This is the pseudocode for the current state:

```
if model_only:
    state_dict = model.state_dict()
else:
    state_dict = {
        "MODEL": model,
        "OPTIMIZER": optimizer,
        ...
    }
```

This is the pseudocode after this PR is landed:

```
state_dict = model.state_dict()
if not model_only:
    state_dict.update(
        {"OPTIMIZER": optimizer, ...}
    )
```

FSDP 4 vs. FSDP 4 TP 2 loss curve with seed checkpoint and `--training.seed=42`.
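For illustration, a runnable version of the "after" pseudocode could look like the sketch below. `build_checkpoint_state` is a hypothetical, simplified helper; the real Checkpointer stores stateful wrapper objects rather than raw state_dicts:

```python
# Simplified, hypothetical sketch of the post-PR layout: model FQNs are
# top-level keys; other components keep their own namespaced keys.
import torch
import torch.nn as nn

OPTIMIZER = "optimizer"

def build_checkpoint_state(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    model_only: bool,
) -> dict:
    # Model parameter FQNs ("weight", "bias", ...) become top-level keys.
    state_dict = dict(model.state_dict())
    if not model_only:
        # Non-model components stay under a namespaced key, so they
        # cannot collide with model FQNs.
        state_dict[OPTIMIZER] = optimizer.state_dict()
    return state_dict

model = nn.Linear(4, 4)
opt = torch.optim.SGD(model.parameters(), lr=0.1)

print(sorted(build_checkpoint_state(model, opt, model_only=True)))
# ['bias', 'weight']
print(sorted(build_checkpoint_state(model, opt, model_only=False)))
# ['bias', 'optimizer', 'weight']
```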