
Commit 050291b

compile, but you need pytorch/torchtitan-ep
1 parent 08c1ff1 commit 050291b

File tree

1 file changed: +17 −0 lines changed

torchtitan/experiments/llama4/infra/parallelize.py

Lines changed: 17 additions & 0 deletions

@@ -34,6 +34,23 @@
 )
 
 
+def apply_compile(model: nn.Module):
+    """
+    Apply torch.compile to each TransformerBlock, which makes compilation efficient due to
+    repeated structure. Alternatively one can compile the whole model (after applying DP).
+    """
+    torch._dynamo.config.fail_on_recompile_limit_hit = True
+    for layer_id, transformer_block in model.layers.named_children():
+        fullgraph = True
+        if transformer_block.moe_enabled:
+            # Allow graph break for MoE layers
+            fullgraph = False
+        transformer_block = torch.compile(transformer_block, fullgraph=fullgraph)
+        model.layers.register_module(layer_id, transformer_block)
+
+    logger.info("Compiling each TransformerBlock with torch.compile")
+
+
 def parallelize_llama(
     model: nn.Module,
     world_mesh: DeviceMesh,
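For reference, a minimal sketch of how this helper behaves, assuming `apply_compile` from the diff above (and torchtitan's `logger`) is in scope. The `TransformerBlock` and `ToyModel` classes here are hypothetical stand-ins, not torchtitan's Llama 4 model; they are kept just detailed enough to satisfy what `apply_compile` touches: a `layers` container with named children, each exposing a `moe_enabled` flag.

    import torch
    import torch.nn as nn

    class TransformerBlock(nn.Module):
        # Hypothetical stand-in: apply_compile only reads the moe_enabled flag.
        def __init__(self, dim: int, moe_enabled: bool = False):
            super().__init__()
            self.moe_enabled = moe_enabled
            self.ffn = nn.Sequential(
                nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim)
            )

        def forward(self, x):
            return x + self.ffn(x)

    class ToyModel(nn.Module):
        # named_children()/register_module() in apply_compile assume that
        # model.layers is a module holding named TransformerBlock children.
        def __init__(self, dim: int = 64, n_layers: int = 4):
            super().__init__()
            self.layers = nn.ModuleDict(
                {str(i): TransformerBlock(dim, moe_enabled=(i % 2 == 1))
                 for i in range(n_layers)}
            )

        def forward(self, x):
            for block in self.layers.values():
                x = block(x)
            return x

    model = ToyModel()
    apply_compile(model)  # each block is replaced by a compiled OptimizedModule
    out = model(torch.randn(2, 16, 64))  # first call triggers per-block compilation

Note the design choice the diff encodes: dense blocks are compiled with fullgraph=True, so any unexpected graph break fails loudly, while MoE blocks use fullgraph=False because their data-dependent expert routing is allowed to graph-break.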
