1 file changed: torchtitan/experiments/llama4/infra (+17 −0 lines changed)

 )


+def apply_compile(model: nn.Module):
+    """
+    Apply torch.compile to each TransformerBlock, which makes compilation efficient due to
+    repeated structure. Alternatively one can compile the whole model (after applying DP).
+    """
+    torch._dynamo.config.fail_on_recompile_limit_hit = True
+    for layer_id, transformer_block in model.layers.named_children():
+        fullgraph = True
+        if transformer_block.moe_enabled:
+            # Allow graph break for MoE layers
+            fullgraph = False
+        transformer_block = torch.compile(transformer_block, fullgraph=fullgraph)
+        model.layers.register_module(layer_id, transformer_block)
+
+    logger.info("Compiling each TransformerBlock with torch.compile")
+
+
 def parallelize_llama(
     model: nn.Module,
     world_mesh: DeviceMesh,
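For readers outside the torchtitan codebase, the sketch below is a minimal, self-contained illustration of the same per-block compilation pattern: each repeated block is compiled separately and re-registered under its original name, rather than compiling the whole model at once. ToyBlock, ToyModel, and their dimensions are hypothetical stand-ins for TransformerBlock and the Llama model, not part of this diff; the moe_enabled/fullgraph switch and the fail_on_recompile_limit_hit setting from the diff are omitted for brevity.

import torch
import torch.nn as nn


class ToyBlock(nn.Module):
    # Hypothetical stand-in for a TransformerBlock.
    def __init__(self, dim: int = 16):
        super().__init__()
        self.linear = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.linear(x))


class ToyModel(nn.Module):
    # Hypothetical stand-in for the Llama model; `layers` mirrors the ModuleDict layout
    # that apply_compile iterates over via named_children().
    def __init__(self, n_layers: int = 4, dim: int = 16):
        super().__init__()
        self.layers = nn.ModuleDict({str(i): ToyBlock(dim) for i in range(n_layers)})

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for layer in self.layers.values():
            x = layer(x)
        return x


model = ToyModel()
for layer_id, block in model.layers.named_children():
    # fullgraph=True asks Dynamo to raise instead of silently graph-breaking;
    # the diff relaxes this to False only for MoE blocks.
    compiled = torch.compile(block, fullgraph=True)
    # Re-register the compiled block under the same name, as the diff does.
    model.layers.register_module(layer_id, compiled)

out = model(torch.randn(2, 16))  # each block is compiled on first use

Compiling per block keeps compilation cost proportional to one block (reused across the repeated structure) instead of the full model, which is the rationale stated in the docstring above.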