
compile: turn off fullgraph=True to support llama4 #1182

Open

bdhirsh wants to merge 4 commits into gh/bdhirsh/3/base

Conversation

@bdhirsh commented May 12, 2025

This PR + pytorch/pytorch#153384 is enough to get torchtitan running for me with llama4 and compile:

```
CONFIG_FILE="./torchtitan/experiments/llama4/train_configs/debug_model.toml" ./run_train.sh --training.compile
```

Stack from ghstack (oldest at bottom):

```diff
@@ -304,7 +304,7 @@ def apply_compile(model: nn.Module):
     repeated structure. Alternatively one can compile the whole model (after applying DP).
     """
     for layer_id, transformer_block in model.layers.named_children():
-        transformer_block = torch.compile(transformer_block, fullgraph=True)
```
Contributor

Can we add a comment/TODO to remind us to turn it back on when issues are resolved?

bdhirsh (Author)

Oh, will do. There are two related things here:

(1) grouped_mm support in compile, which torchtitan uses in llama4. I added basic support in core in this PR: pytorch/pytorch#153384

(2) E2E llama4 + compile in torchtitan. The current reason this completely blows up today is that torchtitan's llama4 + FSDP2 integration requires wrapping the MoE layer, which requires installing backward hooks around the MoE layer. Compile does not support compiling backward hooks (we graph break), and so we need to do one of these options:

(a) allow the graph break (turn off fullgraph=True)

(b) tweak torchtitan so that, instead of compiling each transformer block as one graph, we compile the MoE layers separately and compile the rest of each transformer block separately as well (a rough sketch follows below).
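To make option (b) concrete, here is a rough sketch (purely illustrative, not torchtitan's actual code; the block/submodule layout is an assumption): instead of wrapping a whole transformer block in a single `torch.compile` region, each child submodule of the block gets its own compiled region, so hooks installed around the MoE module fall between compiled regions rather than inside one.

```python
# Illustrative sketch of option (b); the block/submodule layout is an assumption,
# not torchtitan's real structure.
import torch
import torch.nn as nn


def apply_compile_per_submodule(model: nn.Module) -> None:
    for _, transformer_block in model.layers.named_children():
        for name, submodule in list(transformer_block.named_children()):
            # Compile each major submodule (attention, MoE, norms, ...) on its
            # own, so that backward hooks installed *around* the MoE module stay
            # outside any single compiled region.
            transformer_block.register_module(name, torch.compile(submodule, fullgraph=True))
```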

I also mentioned this to @tianyu-l, but calling it out here: (a) is easier to do, so I'm doing it here, but it carries the risk that if any change to core increases the number of graph breaks in torchtitan, we won't error as loudly (we may just see a perf drop instead). (b) is probably better to do at some point; I'm just doing the simpler thing here.

Are folks working on torchtitan interested in running benchmarks for titan + llama4 (with compile on/off)?

bdhirsh (Author)

@fegin I actually tweaked the PR so that we still compile the "regular" transformer blocks with fullgraph=True, and only use fullgraph=False for the blocks with MoE layers. I think this should reduce the risk that we hit regressions, so it may be a reasonable long-term solution (when using FSDP2 in torchtitan), as long as we see reasonable perf numbers.
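A minimal sketch of the per-block switch being described, assuming a hypothetical `moe_enabled` attribute on each block to detect MoE (the actual attribute torchtitan checks may differ, and this is not the exact PR diff):

```python
# Sketch only: keep fullgraph=True for dense transformer blocks, and fall back
# to fullgraph=False for blocks containing a MoE layer, whose backward hooks
# force a graph break. `moe_enabled` is an assumed attribute name.
import torch
import torch.nn as nn


def apply_compile(model: nn.Module) -> None:
    for layer_id, transformer_block in model.layers.named_children():
        has_moe = getattr(transformer_block, "moe_enabled", False)
        compiled_block = torch.compile(transformer_block, fullgraph=not has_moe)
        model.layers.register_module(layer_id, compiled_block)
```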

Contributor

@bdhirsh Thanks for the thorough explanation. The 8-GPU integration test timed out. Since llama4 is not in the integration test, the timeout should not be caused by this PR. I'll relaunch the test anyway, but feel free to land it.

@tianyu-l (Contributor) left a comment

> The current reason this completely blows up today is that torchtitan's llama4 + FSDP2 integration requires wrapping the MoE layer, which requires installing backward hooks around the MoE layer.

Hmm, let's be more careful here. This is only true when EP is used, specifically dp2ep (e.g. in #732). Currently EP is not supported in Llama 4 yet, which means we apply homogeneous FSDP2 wrapping to the transformer blocks only (not to MoE modules). So I suppose full-graph compilation shouldn't be violated. If we set fullgraph=True, where would it break?

@bdhirsh (Author) commented May 13, 2025

Locally, I was seeing that when we compile each transformer block, dynamo was trying (and failing) to graph break, because someone was attempting to install backward hooks inside one of the transformer blocks. If that's surprising to you, I can try to find the code that is installing the backward hook.
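For context, a minimal standalone repro of the failure mode being described (my own sketch, not torchtitan code): compiling a module with `fullgraph=True` after a full backward hook has been installed on one of its submodules should surface the graph break as a hard error rather than silently splitting the graph.

```python
# Minimal sketch: a backward hook on a submodule forces dynamo to graph break;
# with fullgraph=True the graph break becomes a hard error instead.
import torch
import torch.nn as nn


class Block(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.inner = nn.Linear(8, 8)
        # Stand-in for the hooks FSDP2 / load balancing installs around the MoE module.
        self.inner.register_full_backward_hook(lambda mod, grad_in, grad_out: None)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.inner(x)


compiled = torch.compile(Block(), fullgraph=True)
out = compiled(torch.randn(2, 8, requires_grad=True))  # expected to error on the graph break
out.sum().backward()
```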

@tianyu-l (Contributor)

I think the backward hooks are from the auxiliary-loss-free load balancing (#1114).

The load-balancing algorithm maintains a bias term for each expert, based on the number of tokens that expert has seen so far.

  1. The single-device algorithm needs a backward hook to update the bias term after each iteration.
  2. For multi-device, we need another backward hook to all-reduce the bias term across all DP ranks, since different DP ranks see different inputs.

Using forward, forward-pre, or backward-pre hooks would conflict with activation checkpointing.
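A rough sketch of the shape of such a hook (purely illustrative; the buffer names `expert_bias` / `tokens_per_expert`, the update rule, and folding both steps into one hook are assumptions, not the #1114 implementation):

```python
# Illustrative sketch only. Assumes the router keeps an `expert_bias` buffer
# (added to routing scores) and a `tokens_per_expert` counter; names are made up.
import torch
import torch.distributed as dist
import torch.nn as nn


def make_bias_update_hook(router: nn.Module, step_size: float = 1e-3):
    def hook(module, grad_input, grad_output):
        with torch.no_grad():
            counts = router.tokens_per_expert.float()
            if dist.is_available() and dist.is_initialized():
                # Multi-device: different DP ranks see different inputs, so
                # sum the per-expert token counts across DP ranks.
                dist.all_reduce(counts)
            # Nudge the bias toward under-loaded experts and away from
            # over-loaded ones, then reset the counter for the next iteration.
            router.expert_bias += step_size * torch.sign(counts.mean() - counts)
            router.tokens_per_expert.zero_()
    return hook


# e.g. router.register_full_backward_hook(make_bias_update_hook(router))
```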
