
Commit d1ada1d

Add option for selective op AC to filter mm shapes based on fqn
1 parent 01f4e50 commit d1ada1d

File tree

2 files changed: +28 -0 lines changed


torchtitan/config_manager.py

Lines changed: 12 additions & 0 deletions
@@ -487,6 +487,18 @@ class ActivationCheckpoint:
     'int' (e.g., 2) for every nth layer, or 'op' for op level ac.
     """

+    selective_op_ac_force_recompute_mm_shapes_by_fqns: list[str] = field(default_factory=lambda: [])
+    """
+    When per-op selective ac is used, this list of fully qualified names (relative
+    to the module at which AC is applied) is used to determine which mm shapes to
+    force recompute, rather than being considered by the rest of the sac policy,
+    e.g. save every other mm. Only nn.Linear modules are supported today.
+
+    Note: this config affects mms beyond those coming from the specified fqns,
+    e.g. if "moe.router.gate", corresponding to Linear(in, out), is specified,
+    ANY mm with shape matching (*, in) x (in, out) will be force recomputed.
+    """
+

 @dataclass
 class Float8:
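
The docstring above can be made concrete with a small sketch (illustrative only, not part of the commit; the module names here are hypothetical): an fqn such as "moe.router.gate" is resolved relative to the module that AC wraps, and only an nn.Linear match contributes its (in_features, out_features) shape to the force-recompute set.

import torch.nn as nn

# Hypothetical stand-in for a transformer block that AC would wrap.
block = nn.Module()
block.moe = nn.Module()
block.moe.router = nn.Module()
block.moe.router.gate = nn.Linear(256, 8, bias=False)

fqn = "moe.router.gate"  # relative to `block`, as described in the docstring
submod = dict(block.named_modules()).get(fqn, None)
assert isinstance(submod, nn.Linear)

# nn.Linear stores its weight as (out_features, in_features); the shape the
# SAC policy will compare against is the transposed (in_features, out_features).
out_f, in_f = submod.weight.shape
print((in_f, out_f))  # (256, 8): any mm of shape (*, 256) x (256, 8) is recomputed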

torchtitan/models/llama3/infra/parallelize.py

Lines changed: 16 additions & 0 deletions
@@ -261,11 +261,27 @@ def _apply_ac_to_transformer_block(module: nn.Module, ac_config):
            create_selective_checkpoint_contexts,
        )

+        mm_recompute_shapes = set()
+        for fqn in ac_config.selective_op_ac_force_recompute_mm_shapes_by_fqns:
+            submod = dict(module.named_modules()).get(fqn, None)
+            if submod is None:
+                continue
+            if not isinstance(submod, nn.Linear):
+                raise ValueError(
+                    "selective_op_ac_force_recompute_mm_shapes_by_fqns expected to match "
+                    f"a nn.Linear, but got: {submod}"
+                )
+            out_f, in_f = submod.weight.shape
+            mm_recompute_shapes.add((in_f, out_f))
+        logger.debug(f"Selective op AC force recomputing mms with rhs shapes {mm_recompute_shapes}")
+
        def _get_custom_policy(meta):
            def _custom_policy(ctx, func, *args, **kwargs):
                mode = "recompute" if ctx.is_recompute else "forward"
                mm_count_key = f"{mode}_mm_count"
                if func == torch.ops.aten.mm.default:
+                    if args[1].shape in mm_recompute_shapes:
+                        return CheckpointPolicy.PREFER_RECOMPUTE
                    meta[mm_count_key] += 1
                    # Saves output of all compute ops, except every second mm
                    to_save = func in _save_list and not (
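
As a side note (a sketch under assumptions, not code from the commit): the reason the policy can key on args[1].shape is that a 2-D nn.Linear forward without bias lowers to aten.mm(input, weight.t()), so the right-hand operand has shape (in_features, out_features), exactly the tuple stored in mm_recompute_shapes above.

import torch
import torch.nn as nn

lin = nn.Linear(16, 32, bias=False)  # weight.shape == (32, 16)
x = torch.randn(4, 16)

out_f, in_f = lin.weight.shape
rhs_shape = (in_f, out_f)  # (16, 32), the key that would be added to mm_recompute_shapes

# The same mm the dispatcher sees during lin(x): the rhs is the transposed weight.
rhs = lin.weight.t()
assert rhs.shape == rhs_shape
y = torch.ops.aten.mm.default(x, rhs)
assert torch.allclose(y, lin(x))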
