Commit 83da502

[not for land] Use new AC
ghstack-source-id: 944d6bc
Pull-Request-resolved: #1294
1 parent ed2bbc0 commit 83da502

2 files changed: +54 -30 lines


.github/workflows/integration_test_8gpu.yaml

Lines changed: 2 additions & 0 deletions
@@ -36,6 +36,8 @@ jobs:
 
           pip config --user set global.progress_bar off
 
+          git clone https://github.com/soulitzer/ac-experimental.git && cd ac-experimental && pip install -e . && cd ..
+
           python -m pip install --force-reinstall --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126
 
           USE_CPP=0 python -m pip install --pre torchao --index-url https://download.pytorch.org/whl/nightly/cu126
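
The added CI step builds soulitzer's ac-experimental package from source so that the import inside the new wrapper below resolves at runtime. A minimal post-install sanity check (a sketch; it only assumes the entry point that parallelize_llama.py imports) would be:

    # Sketch: verify the experimental AC package installed by the CI step above is importable.
    # apply_ac_policy_fn is the only symbol the new CheckpointWrapper relies on.
    from ac_experimental import apply_ac_policy_fn
    print(apply_ac_policy_fn)  # raises ImportError if the editable install did not succeed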

torchtitan/models/llama3/parallelize_llama.py

Lines changed: 52 additions & 30 deletions
@@ -12,10 +12,36 @@
 import torch
 import torch.nn as nn
 from torch.distributed._composable.replicate import replicate
+
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
-    checkpoint_wrapper as ptd_checkpoint_wrapper,
+    ActivationWrapper,
 )
 
+
+class CheckpointWrapper(ActivationWrapper):
+    def __init__(self, mod: torch.nn.Module, **kwargs):
+        super().__init__(mod)
+        self._checkpoint_wrapped_module = mod
+        self._make_policy_fn = kwargs.get("make_policy_fn", None)
+
+    def forward(self, *args, **kwargs):
+        from ac_experimental import apply_ac_policy_fn
+
+        if self._make_policy_fn is None:
+            return apply_ac_policy_fn(
+                self._checkpoint_wrapped_module, *args, **kwargs, policy_fn="recompute_all"
+            )
+        else:
+            # Pass is_factory=True so that a new instance of policy_fn is created per AC invocation
+            return apply_ac_policy_fn(
+                self._checkpoint_wrapped_module, *args, **kwargs, policy_fn=self._make_policy_fn, is_factory=True
+            )
+
+
+def ptd_checkpoint_wrapper(mod, **kwargs):
+    return CheckpointWrapper(mod, **kwargs)
+
+
 from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.fsdp import CPUOffloadPolicy, fully_shard, MixedPrecisionPolicy
 from torch.distributed.tensor import Replicate, Shard
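
In plain terms, this hunk swaps the stock checkpoint_wrapper for a thin shim over ac_experimental: with no make_policy_fn the wrapper falls back to policy_fn="recompute_all" (full AC), otherwise it forwards the policy class with is_factory=True so every checkpoint invocation gets a fresh policy instance. A usage sketch (hypothetical module; CustomPolicy is defined in the next hunk):

    # Sketch of how the shim is intended to be called; names other than the
    # wrapper and policy are illustrative only.
    import torch.nn as nn

    block = nn.Linear(16, 16)  # stand-in for a TransformerBlock

    # Full activation checkpointing: everything is recomputed in backward.
    full_ac_block = ptd_checkpoint_wrapper(block)

    # Selective activation checkpointing: the policy class (not an instance) is
    # forwarded to apply_ac_policy_fn with is_factory=True.
    selective_ac_block = ptd_checkpoint_wrapper(block, make_policy_fn=CustomPolicy)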
@@ -226,6 +252,29 @@ def apply_tp(
     torch.ops.aten.max.default,
 }
 
+from torch.utils.checkpoint import CheckpointPolicy
+
+# If you want your policy to have state, pass a class. Make sure to
+# create it in global scope to avoid new instances triggering recompiles.
+class CustomPolicy:
+    def __init__(self):
+        super().__init__()
+        self.meta = dict()
+
+    def __call__(self, ctx, out, func, *args, **kwargs):
+        mm_count_key = f"mm_count"
+        if func == torch.ops.aten.mm.default:
+            self.meta[mm_count_key] = self.meta.get(mm_count_key, 0) + 1
+
+        # Saves output of all compute ops, except every second mm
+        to_save = func in _save_list and not (
+            func == torch.ops.aten.mm.default and self.meta[mm_count_key] % 2 == 0
+        )
+        return (
+            CheckpointPolicy.MUST_SAVE
+            if to_save
+            else CheckpointPolicy.PREFER_RECOMPUTE
+        )
 
 def _apply_ac_to_transformer_block(module: nn.Module, ac_config):
     valid_ac_modes = ("full", "selective")
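
The policy added above is stateful: it counts aten.mm calls and keeps the output of every op in _save_list except each even-numbered mm, which gets recomputed. A small illustration of that counting behaviour (a sketch, assuming torch.ops.aten.mm.default is in _save_list as the comment implies; ctx and out are unused by this policy, so None stands in):

    # Sketch: consecutive mm ops alternate between being saved and recomputed.
    policy = CustomPolicy()
    mm = torch.ops.aten.mm.default
    decisions = [policy(None, None, mm) for _ in range(4)]
    # decisions == [MUST_SAVE, PREFER_RECOMPUTE, MUST_SAVE, PREFER_RECOMPUTE]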
@@ -246,38 +295,11 @@ def _apply_ac_to_transformer_block(module: nn.Module, ac_config):
             f"Valid options: 'op' or a positive int representing layer frequency"
         )
     if use_op_sac:
-        from torch.utils.checkpoint import (
-            CheckpointPolicy,
-            create_selective_checkpoint_contexts,
-        )
-
-        def _get_custom_policy(meta):
-            def _custom_policy(ctx, func, *args, **kwargs):
-                mode = "recompute" if ctx.is_recompute else "forward"
-                mm_count_key = f"{mode}_mm_count"
-                if func == torch.ops.aten.mm.default:
-                    meta[mm_count_key] += 1
-                # Saves output of all compute ops, except every second mm
-                to_save = func in _save_list and not (
-                    func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 0
-                )
-                return (
-                    CheckpointPolicy.MUST_SAVE
-                    if to_save
-                    else CheckpointPolicy.PREFER_RECOMPUTE
-                )
-
-            return _custom_policy
-
-        def selective_checkpointing_context_fn():
-            meta = defaultdict(int)
-            return create_selective_checkpoint_contexts(_get_custom_policy(meta))
-
         return ptd_checkpoint_wrapper(
             module,
-            context_fn=selective_checkpointing_context_fn,
-            preserve_rng_state=False,
+            make_policy_fn=CustomPolicy,
         )
+
     elif use_layer_sac:
         # Checkpoint every `ac_freq` of the modules passed to this function
         ac_freq = int(ac_config.selective_ac_option)
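
Net effect of this hunk: the op-level SAC branch no longer builds a context_fn through create_selective_checkpoint_contexts with separate "forward"/"recompute" counters; it simply hands the module and make_policy_fn=CustomPolicy to the new wrapper. An end-to-end sketch of that path (hypothetical shapes; it assumes apply_ac_policy_fn runs the wrapped module under the policy and returns its output, as the wrapper's forward implies):

    # Sketch: forward/backward through a block wrapped the way the op-SAC branch now wraps it.
    import torch
    import torch.nn as nn

    block = nn.Linear(16, 16)  # stand-in for a TransformerBlock
    wrapped = ptd_checkpoint_wrapper(block, make_policy_fn=CustomPolicy)

    out = wrapped(torch.randn(4, 16, requires_grad=True))
    out.sum().backward()  # saved mm outputs are reused; the rest are recomputed here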
