compile: turn off fullgraph=True to support llama4 #1182
base: gh/bdhirsh/3/base
Conversation
```
@@ -304,7 +304,7 @@ def apply_compile(model: nn.Module):
    repeated structure. Alternatively one can compile the whole model (after applying DP).
    """
    for layer_id, transformer_block in model.layers.named_children():
        transformer_block = torch.compile(transformer_block, fullgraph=True)
```
Can we add a comment/TODO to remind us to turn it back on when issues are resolved?
Oh, will do. There are two related things here:
(1) grouped_mm support in compile, which torchtitan uses in llama4. I added basic support in core in this PR: pytorch/pytorch#153384
(2) E2E llama4 + compile in torchtitan. The reason this completely blows up today is that torchtitan's llama4 + FSDP2 integration requires wrapping the MoE layer, which means installing backward hooks around it. Compile does not support compiling backward hooks (we graph break), so we need to pick one of these options:
(a) allow the graph break (turn off `fullgraph=True`)
(b) tweak torchtitan so that instead of compiling each transformer layer, we compile the MoE layers separately, and compile the rest of the transformer block separately as well (see the sketch below)
I also mentioned this to @tianyu-l but calling it out here: (a) is easier to do, so I'm doing it here, but it does have the risk that if any changes are made to core that increase the number of graph breaks in torchtitan, we won't error as loudly (we may see a perf drop instead). (b) is probably better to do at some point; I'm just doing the simpler thing here.
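(For reference, a minimal sketch of what option (b) could look like. The `moe` and `attention` attribute names and the way submodules are swapped in are assumptions for illustration, not torchtitan's actual block layout.)

```python
import torch
import torch.nn as nn


def apply_compile_split(model: nn.Module):
    """Option (b) sketch: compile MoE layers and the rest of each block separately,
    so the backward hooks installed around the MoE layer stay outside any
    fullgraph region."""
    for layer_id, transformer_block in model.layers.named_children():
        moe = getattr(transformer_block, "moe", None)  # assumed attribute name
        if moe is not None:
            # Compile the MoE layer on its own; hooks sit at its module boundary.
            transformer_block.moe = torch.compile(moe, fullgraph=True)
            # Compile the remaining dense submodules individually (illustrative).
            transformer_block.attention = torch.compile(
                transformer_block.attention, fullgraph=True
            )
        else:
            transformer_block = torch.compile(transformer_block, fullgraph=True)
        model.layers.register_module(layer_id, transformer_block)
```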
Are folks working on torchtitan interested in running benchmarks for titan + llama4 (with compile on/off)?
@fegin I actually tweaked the PR so that we still compile the "regular" transformer blocks with fullgraph=True, and only use fullgraph=False for the blocks with MoE layers. I think this should reduce the risk of hitting regressions, so it may be a reasonable long-term solution (when using FSDP2 in torchtitan), as long as we see reasonable perf numbers.
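(A minimal sketch of that selective scheme; `moe_enabled` is an assumed way to identify MoE blocks here, not necessarily the attribute torchtitan uses.)

```python
import torch
import torch.nn as nn


def apply_compile(model: nn.Module):
    """Compile every TransformerBlock, but only allow graph breaks in MoE blocks."""
    for layer_id, transformer_block in model.layers.named_children():
        # Assumption: MoE blocks are identifiable via an attribute on the block.
        has_moe = getattr(transformer_block, "moe_enabled", False)
        # MoE blocks carry backward hooks (from FSDP2 wrapping) that dynamo
        # cannot trace, so they are compiled with fullgraph=False; dense blocks
        # keep fullgraph=True so unexpected graph breaks still error loudly.
        compiled_block = torch.compile(transformer_block, fullgraph=not has_moe)
        model.layers.register_module(layer_id, compiled_block)
```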
@bdhirsh Thanks for the thorough explanation. The 8-GPU integration test timed out. Since llama4 is not in the integration test, the timeout should not be caused by this PR. I relaunched the test anyway, but feel free to land it.
This PR + pytorch/pytorch#153384 is enough to get torchtitan running for me with llama4 and compile:

```
CONFIG_FILE="./torchtitan/experiments/llama4/train_configs/debug_model.toml" ./run_train.sh --training.compile
```
> The reason this completely blows up today is that torchtitan's llama4 + FSDP2 integration requires wrapping the MoE layer, which means installing backward hooks around it.

Hmm, let's be more careful here. This is only true when EP is used, specifically dp2ep (e.g. in #732). Currently in Llama 4, EP is not supported yet, which means we are doing homogeneous FSDP2 wrapping of the transformer blocks only (not of MoE modules). So I suppose full-graph compilation shouldn't be violated. If we set `fullgraph=True`, where would it break?
Locally, I was seeing that when we compile each transformer block, dynamo was trying (and failing) to graph break, because something was installing backward hooks inside one of the transformer blocks. If that's surprising to you, I can try to find the code that installs the backward hook.
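(For context, a tiny self-contained illustration of the general pattern: a module with a backward hook compiled with `fullgraph=True`. This is not torchtitan's actual hook, and the exact failure mode may vary by PyTorch version, but a registered backward hook is the kind of thing dynamo graph-breaks on, which `fullgraph=True` turns into a hard error.)

```python
import torch
import torch.nn as nn


class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.inner = nn.Linear(8, 8)
        # A backward hook on a submodule: dynamo cannot trace hook execution,
        # so it must graph break, which fullgraph=True disallows.
        self.inner.register_full_backward_hook(
            lambda mod, grad_in, grad_out: None
        )

    def forward(self, x):
        return self.inner(x).relu()


block = torch.compile(Block(), fullgraph=True)
try:
    out = block(torch.randn(2, 8, requires_grad=True))
    out.sum().backward()
except Exception as err:
    print(f"compile failed as expected: {type(err).__name__}")
```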
I think the backward hooks are from the auxiliary-loss-free load balancing (#1114). The load balancing algorithm keeps a bias term for each expert, updated based on the number of tokens that expert has seen so far.
Using forward, forward pre, or backward pre hooks would conflict with activation checkpointing.
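(For readers unfamiliar with #1114, a rough sketch of the idea: a non-learned per-expert bias buffer, nudged after each step based on how many tokens each expert received. The names and the update rule here are illustrative; the real implementation, and where exactly the hook is registered, live in #1114.)

```python
import torch
import torch.nn as nn


class BiasedTopKRouter(nn.Module):
    """Top-k router with a non-learned per-expert bias used only for expert selection."""

    def __init__(self, dim: int, num_experts: int, top_k: int = 1):
        super().__init__()
        self.gate = nn.Linear(dim, num_experts, bias=False)
        self.top_k = top_k
        self.register_buffer("expert_bias", torch.zeros(num_experts))
        self.register_buffer("tokens_per_expert", torch.zeros(num_experts))

    def forward(self, x: torch.Tensor):
        scores = self.gate(x)  # [num_tokens, num_experts]
        # The bias only affects which experts are picked, not the routing weights.
        _, picked = (scores + self.expert_bias).topk(self.top_k, dim=-1)
        self.tokens_per_expert += torch.bincount(
            picked.flatten(), minlength=self.expert_bias.numel()
        ).to(self.tokens_per_expert.dtype)
        weights = scores.gather(-1, picked).softmax(dim=-1)
        return picked, weights

    @torch.no_grad()
    def update_bias(self, step_size: float = 1e-3):
        # Nudge under-used experts up and over-used experts down, then reset counts.
        # In practice this once-per-step update is what gets wired through a module
        # hook, which is where the interaction with compile and AC comes from.
        mean_load = self.tokens_per_expert.mean()
        self.expert_bias += step_size * torch.sign(mean_load - self.tokens_per_expert)
        self.tokens_per_expert.zero_()
```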
This PR + pytorch/pytorch#153384 is enough to get torchtitan running for me with llama4 and compile.
Stack from ghstack (oldest at bottom):