@@ -12,10 +12,36 @@
 import torch
 import torch.nn as nn
 from torch.distributed._composable.replicate import replicate
+
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
-    checkpoint_wrapper as ptd_checkpoint_wrapper,
+    ActivationWrapper,
 )

+
+class CheckpointWrapper(ActivationWrapper):
+    def __init__(self, mod: torch.nn.Module, **kwargs):
+        super().__init__(mod)
+        self._checkpoint_wrapped_module = mod
+        self._make_policy_fn = kwargs.get("make_policy_fn", None)
+
+    def forward(self, *args, **kwargs):
+        from ac_experimental import apply_ac_policy_fn
+
+        if self._make_policy_fn is None:
+            return apply_ac_policy_fn(
+                self._checkpoint_wrapped_module, *args, **kwargs, policy_fn="recompute_all"
+            )
+        else:
+            # Pass is_factory=True so that a new instance of policy_fn is created per AC invocation
+            return apply_ac_policy_fn(
+                self._checkpoint_wrapped_module, *args, **kwargs, policy_fn=self._make_policy_fn, is_factory=True
+            )
+
+
+def ptd_checkpoint_wrapper(mod, **kwargs):
+    return CheckpointWrapper(mod, **kwargs)
+
+
 from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.fsdp import CPUOffloadPolicy, fully_shard, MixedPrecisionPolicy
 from torch.distributed.tensor import Replicate, Shard
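For reference, a minimal usage sketch of the shim introduced above (not part of the diff). It assumes the experimental ac_experimental.apply_ac_policy_fn API imported inside CheckpointWrapper.forward is available on the path; the wrapped nn.Linear is chosen purely for illustration.

    import torch
    import torch.nn as nn

    # Wrap a module with the shim; with no make_policy_fn it falls back to
    # policy_fn="recompute_all", i.e. full activation checkpointing.
    model = nn.Linear(16, 16)
    wrapped = ptd_checkpoint_wrapper(model)

    x = torch.randn(4, 16, requires_grad=True)
    loss = wrapped(x).sum()
    loss.backward()  # activations inside `model` are recomputed during backward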
@@ -226,6 +252,29 @@ def apply_tp(
     torch.ops.aten.max.default,
 }

+from torch.utils.checkpoint import CheckpointPolicy
+
+# If you want your policy to have state, pass a class. Make sure to
+# create it in global scope to avoid new instances triggering recompiles.
+class CustomPolicy:
+    def __init__(self):
+        super().__init__()
+        self.meta = dict()
+
+    def __call__(self, ctx, out, func, *args, **kwargs):
+        mm_count_key = "mm_count"
+        if func == torch.ops.aten.mm.default:
+            self.meta[mm_count_key] = self.meta.get(mm_count_key, 0) + 1
+
+        # Saves output of all compute ops, except every second mm
+        to_save = func in _save_list and not (
+            func == torch.ops.aten.mm.default and self.meta[mm_count_key] % 2 == 0
+        )
+        return (
+            CheckpointPolicy.MUST_SAVE
+            if to_save
+            else CheckpointPolicy.PREFER_RECOMPUTE
+        )

 def _apply_ac_to_transformer_block(module: nn.Module, ac_config):
     valid_ac_modes = ("full", "selective")
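To see the save/recompute pattern the policy above encodes, here is a small sketch (not part of the diff) that calls CustomPolicy directly. It assumes a stand-in _save_list containing just mm and max, since the real module-level set is only partially visible in this hunk; ctx and out are ignored by this policy, so None placeholders suffice.

    import torch
    from torch.utils.checkpoint import CheckpointPolicy

    # Stand-in for the module-level _save_list referenced by CustomPolicy.
    _save_list = {torch.ops.aten.mm.default, torch.ops.aten.max.default}

    policy = CustomPolicy()  # one fresh instance per AC invocation when is_factory=True
    mm = torch.ops.aten.mm.default

    print(policy(None, None, mm))  # 1st mm -> CheckpointPolicy.MUST_SAVE
    print(policy(None, None, mm))  # 2nd mm -> CheckpointPolicy.PREFER_RECOMPUTE
    print(policy(None, None, torch.ops.aten.relu.default))  # not in _save_list -> PREFER_RECOMPUTE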
@@ -246,38 +295,11 @@ def _apply_ac_to_transformer_block(module: nn.Module, ac_config):
             f"Valid options: 'op' or a positive int representing layer frequency"
         )
     if use_op_sac:
-        from torch.utils.checkpoint import (
-            CheckpointPolicy,
-            create_selective_checkpoint_contexts,
-        )
-
-        def _get_custom_policy(meta):
-            def _custom_policy(ctx, func, *args, **kwargs):
-                mode = "recompute" if ctx.is_recompute else "forward"
-                mm_count_key = f"{mode}_mm_count"
-                if func == torch.ops.aten.mm.default:
-                    meta[mm_count_key] += 1
-                # Saves output of all compute ops, except every second mm
-                to_save = func in _save_list and not (
-                    func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 0
-                )
-                return (
-                    CheckpointPolicy.MUST_SAVE
-                    if to_save
-                    else CheckpointPolicy.PREFER_RECOMPUTE
-                )
-
-            return _custom_policy
-
-        def selective_checkpointing_context_fn():
-            meta = defaultdict(int)
-            return create_selective_checkpoint_contexts(_get_custom_policy(meta))
-
         return ptd_checkpoint_wrapper(
             module,
-            context_fn=selective_checkpointing_context_fn,
-            preserve_rng_state=False,
+            make_policy_fn=CustomPolicy,
         )
+
     elif use_layer_sac:
         # Checkpoint every `ac_freq` of the modules passed to this function
         ac_freq = int(ac_config.selective_ac_option)
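And a sketch of how the rewritten branch is reached from config (not part of the diff): the field names mode and selective_ac_option are assumed from the checks visible in _apply_ac_to_transformer_block, and SimpleNamespace stands in for the real config class.

    from types import SimpleNamespace

    import torch.nn as nn

    # Hypothetical config: "selective" mode with op-level SAC selects the
    # ptd_checkpoint_wrapper(module, make_policy_fn=CustomPolicy) branch above.
    ac_config = SimpleNamespace(mode="selective", selective_ac_option="op")

    block = nn.Sequential(nn.Linear(16, 64), nn.ReLU(), nn.Linear(64, 16))
    block = _apply_ac_to_transformer_block(block, ac_config)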