Skip to content

Commit be8f073

Browse files
[Feature] Add with cp to mit and vit (open-mmlab#1431)
* Add with_cp to MiT and ViT
* Add unit tests
Co-authored-by: jiangyitong <[email protected]>
1 parent 17f8a96 commit be8f073

File tree

4 files changed: 62 additions (+62) and 10 deletions (-10).

mmseg/models/backbones/mit.py

Lines changed: 24 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@
44

55
import torch
66
import torch.nn as nn
7+
import torch.utils.checkpoint as cp
78
from mmcv.cnn import Conv2d, build_activation_layer, build_norm_layer
89
from mmcv.cnn.bricks.drop import build_dropout
910
from mmcv.cnn.bricks.transformer import MultiheadAttention
@@ -235,6 +236,8 @@ class TransformerEncoderLayer(BaseModule):
235236
Default:None.
236237
sr_ratio (int): The ratio of spatial reduction of Efficient Multi-head
237238
Attention of Segformer. Default: 1.
239+
with_cp (bool): Use checkpoint or not. Using checkpoint will save
240+
some memory while slowing down the training speed. Default: False.
238241
"""
239242

240243
def __init__(self,
@@ -248,7 +251,8 @@ def __init__(self,
248251
act_cfg=dict(type='GELU'),
249252
norm_cfg=dict(type='LN'),
250253
batch_first=True,
251-
sr_ratio=1):
254+
sr_ratio=1,
255+
with_cp=False):
252256
super(TransformerEncoderLayer, self).__init__()
253257

254258
# ret[0] of build_norm_layer is the norm name.
@@ -275,9 +279,19 @@ def __init__(self,
275279
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
276280
act_cfg=act_cfg)
277281

282+
self.with_cp = with_cp
283+
278284
def forward(self, x, hw_shape):
279-
x = self.attn(self.norm1(x), hw_shape, identity=x)
280-
x = self.ffn(self.norm2(x), hw_shape, identity=x)
285+
286+
def _inner_forward(x):
287+
x = self.attn(self.norm1(x), hw_shape, identity=x)
288+
x = self.ffn(self.norm2(x), hw_shape, identity=x)
289+
return x
290+
291+
if self.with_cp and x.requires_grad:
292+
x = cp.checkpoint(_inner_forward, x)
293+
else:
294+
x = _inner_forward(x)
281295
return x
282296

283297

@@ -319,6 +333,8 @@ class MixVisionTransformer(BaseModule):
319333
pretrained (str, optional): model pretrained path. Default: None.
320334
init_cfg (dict or list[dict], optional): Initialization config dict.
321335
Default: None.
336+
with_cp (bool): Use checkpoint or not. Using checkpoint will save
337+
some memory while slowing down the training speed. Default: False.
322338
"""
323339

324340
def __init__(self,
@@ -339,7 +355,8 @@ def __init__(self,
339355
act_cfg=dict(type='GELU'),
340356
norm_cfg=dict(type='LN', eps=1e-6),
341357
pretrained=None,
342-
init_cfg=None):
358+
init_cfg=None,
359+
with_cp=False):
343360
super(MixVisionTransformer, self).__init__(init_cfg=init_cfg)
344361

345362
assert not (init_cfg and pretrained), \
@@ -358,8 +375,9 @@ def __init__(self,
358375
self.patch_sizes = patch_sizes
359376
self.strides = strides
360377
self.sr_ratios = sr_ratios
378+
self.with_cp = with_cp
361379
assert num_stages == len(num_layers) == len(num_heads) \
362-
== len(patch_sizes) == len(strides) == len(sr_ratios)
380+
== len(patch_sizes) == len(strides) == len(sr_ratios)
363381

364382
self.out_indices = out_indices
365383
assert max(out_indices) < self.num_stages
@@ -392,6 +410,7 @@ def __init__(self,
392410
qkv_bias=qkv_bias,
393411
act_cfg=act_cfg,
394412
norm_cfg=norm_cfg,
413+
with_cp=with_cp,
395414
sr_ratio=sr_ratios[i]) for idx in range(num_layer)
396415
])
397416
in_channels = embed_dims_i

mmseg/models/backbones/vit.py

Lines changed: 18 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@
44

55
import torch
66
import torch.nn as nn
7+
import torch.utils.checkpoint as cp
78
from mmcv.cnn import build_norm_layer
89
from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention
910
from mmcv.cnn.utils.weight_init import (constant_init, kaiming_init,
@@ -41,6 +42,8 @@ class TransformerEncoderLayer(BaseModule):
4142
batch_first (bool): Key, Query and Value are shape of
4243
(batch, n, embed_dim)
4344
or (n, batch, embed_dim). Default: True.
45+
with_cp (bool): Use checkpoint or not. Using checkpoint will save
46+
some memory while slowing down the training speed. Default: False.
4447
"""
4548

4649
def __init__(self,
@@ -54,7 +57,8 @@ def __init__(self,
5457
qkv_bias=True,
5558
act_cfg=dict(type='GELU'),
5659
norm_cfg=dict(type='LN'),
57-
batch_first=True):
60+
batch_first=True,
61+
with_cp=False):
5862
super(TransformerEncoderLayer, self).__init__()
5963

6064
self.norm1_name, norm1 = build_norm_layer(
@@ -82,6 +86,8 @@ def __init__(self,
8286
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
8387
act_cfg=act_cfg)
8488

89+
self.with_cp = with_cp
90+
8591
@property
8692
def norm1(self):
8793
return getattr(self, self.norm1_name)
@@ -91,8 +97,16 @@ def norm2(self):
9197
return getattr(self, self.norm2_name)
9298

9399
def forward(self, x):
94-
x = self.attn(self.norm1(x), identity=x)
95-
x = self.ffn(self.norm2(x), identity=x)
100+
101+
def _inner_forward(x):
102+
x = self.attn(self.norm1(x), identity=x)
103+
x = self.ffn(self.norm2(x), identity=x)
104+
return x
105+
106+
if self.with_cp and x.requires_grad:
107+
x = cp.checkpoint(_inner_forward, x)
108+
else:
109+
x = _inner_forward(x)
96110
return x
97111

98112

@@ -251,6 +265,7 @@ def __init__(self,
251265
qkv_bias=qkv_bias,
252266
act_cfg=act_cfg,
253267
norm_cfg=norm_cfg,
268+
with_cp=with_cp,
254269
batch_first=True))
255270

256271
self.final_norm = final_norm

tests/test_models/test_backbones/test_mit.py

Lines changed: 10 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,8 @@
33
import torch
44

55
from mmseg.models.backbones import MixVisionTransformer
6-
from mmseg.models.backbones.mit import EfficientMultiheadAttention, MixFFN
6+
from mmseg.models.backbones.mit import (EfficientMultiheadAttention, MixFFN,
7+
TransformerEncoderLayer)
78

89

910
def test_mit():
@@ -56,6 +57,14 @@ def test_mit():
5657
outs = MHA(temp, hw_shape, temp)
5758
assert out.shape == (1, token_len, 64)
5859

60+
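# Test TransformerEncoderLayer built with with_cp=True. NOTE: x below does not require grad, so forward() takes the non-checkpoint branch; use requires_grad=True to exercise cp.checkpoint.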
61+
block = TransformerEncoderLayer(
62+
embed_dims=64, num_heads=4, feedforward_channels=256, with_cp=True)
63+
assert block.with_cp
64+
x = torch.randn(1, 56 * 56, 64)
65+
x_out = block(x, (56, 56))
66+
assert x_out.shape == torch.Size([1, 56 * 56, 64])
67+
5968

6069
def test_mit_init():
6170
path = 'PATH_THAT_DO_NOT_EXIST'

tests/test_models/test_backbones/test_vit.py

Lines changed: 10 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,8 @@
22
import pytest
33
import torch
44

5-
from mmseg.models.backbones.vit import VisionTransformer
5+
from mmseg.models.backbones.vit import (TransformerEncoderLayer,
6+
VisionTransformer)
67
from .utils import check_norm_state
78

89

@@ -119,6 +120,14 @@ def test_vit_backbone():
119120
assert feat[0][0].shape == (1, 768, 14, 14)
120121
assert feat[0][1].shape == (1, 768)
121122

123+
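# Test TransformerEncoderLayer built with with_cp=True. NOTE: x below does not require grad, so forward() takes the non-checkpoint branch; use requires_grad=True to exercise cp.checkpoint.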
124+
block = TransformerEncoderLayer(
125+
embed_dims=64, num_heads=4, feedforward_channels=256, with_cp=True)
126+
assert block.with_cp
127+
x = torch.randn(1, 56 * 56, 64)
128+
x_out = block(x)
129+
assert x_out.shape == torch.Size([1, 56 * 56, 64])
130+
122131

123132
def test_vit_init():
124133
path = 'PATH_THAT_DO_NOT_EXIST'

0 commit comments

Comments
 (0)