Commit e1e2c3c

[Feature]: Add UT
1 parent 729edf8 commit e1e2c3c

File tree

2 files changed: +248 -46 lines changed


mmseg/models/backbones/mae.py

Lines changed: 66 additions & 46 deletions
@@ -2,8 +2,8 @@
 import math
 import warnings
 
+import numpy as np
 import torch
-import torch.distributed as dist
 import torch.nn as nn
 from mmcv.cnn import build_norm_layer
 from mmcv.cnn.utils.weight_init import (constant_init, kaiming_init,
@@ -17,6 +17,11 @@
 from ..utils import PatchEmbed
 from .beit import BEiTTransformerEncoderLayer
 
+try:
+    from scipy import interpolate
+except ImportError:
+    interpolate = None
+
 
 @BACKBONES.register_module()
 class MAE(BaseModule):
@@ -61,8 +66,8 @@ class MAE(BaseModule):
         with_cp (bool): Use checkpoint or not. Using checkpoint will save
             some memory while slowing down the training speed. Default: False.
         pretrained (str, optional): model pretrained path. Default: None.
-        init_values (float): Initialize the values of MAEAttention and FFN
-            with learnable scaling.
+        init_values (float): Initialize the values of Attention and FFN
+            with learnable scaling. Defaults to 0.1.
         init_cfg (dict or list[dict], optional): Initialization config dict.
             Default: None.
     """
@@ -91,7 +96,7 @@ def __init__(self,
                  norm_eval=False,
                  with_cp=False,
                  pretrained=None,
-                 init_values=None,
+                 init_values=0.1,
                  init_cfg=None):
         super(MAE, self).__init__(init_cfg=init_cfg)
 
@@ -166,7 +171,7 @@ def __init__(self,
                     attn_drop_rate=attn_drop_rate,
                     drop_path_rate=dpr[i],
                     num_fcs=num_fcs,
-                    qkv_bias='qv_bias' if qv_bias else False,
+                    bias='qv_bias' if qv_bias else False,
                     act_cfg=act_cfg,
                     norm_cfg=norm_cfg,
                     window_size=window_size,
@@ -191,6 +196,57 @@ def rescale(param, layer_id):
             rescale(layer.attn.proj.weight.data, layer_id + 1)
             rescale(layer.ffn.layers[1].weight.data, layer_id + 1)
 
+    def _geometric_sequence_interpolation(self, src_size, dst_size, sequence,
+                                          num):
+        """Get a new sequence via geometric sequence interpolation.
+
+        Args:
+            src_size (int): Pos_embedding size in the pre-trained model.
+            dst_size (int): Pos_embedding size in the current model.
+            sequence (tensor): The relative position bias of the pre-trained
+                model after removing the extra tokens.
+            num (int): Number of attention heads.
+        Returns:
+            new_sequence (tensor): The pre-trained relative position bias
+                interpolated via geometric sequence interpolation to the
+                size of the current model.
+        """
+
+        def geometric_progression(a, r, n):
+            return a * (1.0 - r**n) / (1.0 - r)
+
+        # Bisection search for the geometric-progression ratio q.
+        left, right = 1.01, 1.5
+        while right - left > 1e-6:
+            q = (left + right) / 2.0
+            gp = geometric_progression(1, q, src_size // 2)
+            if gp > dst_size // 2:
+                right = q
+            else:
+                left = q
+        # The position of each interpolated point is determined
+        # by the ratio obtained from the bisection above.
+        dis = []
+        cur = 1
+        for i in range(src_size // 2):
+            dis.append(cur)
+            cur += q**(i + 1)
+        r_ids = [-_ for _ in reversed(dis)]
+        x = r_ids + [0] + dis
+        y = r_ids + [0] + dis
+        t = dst_size // 2.0
+        dx = np.arange(-t, t + 0.1, 1.0)
+        dy = np.arange(-t, t + 0.1, 1.0)
+        # Fit a 2D interpolator per head and resample it on the new grid.
+        new_sequence = []
+        for i in range(num):
+            z = sequence[:, i].view(src_size, src_size).float().numpy()
+            f = interpolate.interp2d(x, y, z, kind='cubic')
+            new_sequence.append(
+                torch.Tensor(f(dx, dy)).contiguous().view(-1, 1).to(sequence))
+        new_sequence = torch.cat(new_sequence, dim=-1)
+        return new_sequence
+
     def init_weights(self):
 
         def _init_weights(m):
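
For reference, here is a minimal standalone sketch (not part of the commit) of the resampling that _geometric_sequence_interpolation performs on a single attention head. The sizes (a 27x27 source grid resampled to 31x31) and the dummy bias table are illustrative assumptions; it uses the same scipy.interpolate.interp2d call as the backbone, which is deprecated in recent SciPy releases.

# Illustrative sketch only: resample a dummy one-head relative position bias
# table from a 27x27 grid to a 31x31 grid, mirroring the method above.
import numpy as np
import torch
from scipy import interpolate  # interp2d is deprecated in recent SciPy

src_size, dst_size, num_heads = 27, 31, 1      # assumed sizes for illustration
bias = torch.randn(src_size * src_size, num_heads)

def geometric_progression(a, r, n):
    return a * (1.0 - r**n) / (1.0 - r)

# Bisection for the ratio q whose geometric sums cover the destination radius.
left, right = 1.01, 1.5
while right - left > 1e-6:
    q = (left + right) / 2.0
    if geometric_progression(1, q, src_size // 2) > dst_size // 2:
        right = q
    else:
        left = q

# Source grid: geometrically spaced coordinates, mirrored around zero.
dis, cur = [], 1
for i in range(src_size // 2):
    dis.append(cur)
    cur += q**(i + 1)
x = [-d for d in reversed(dis)] + [0] + dis

# Destination grid: unit spacing, centred at zero.
t = dst_size // 2.0
dx = np.arange(-t, t + 0.1, 1.0)

z = bias[:, 0].view(src_size, src_size).numpy()
f = interpolate.interp2d(x, x, z, kind='cubic')
resized = torch.Tensor(f(dx, dx)).view(-1, 1)
print(resized.shape)  # torch.Size([961, 1]) == (31 * 31, 1)
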
@@ -210,51 +266,15 @@ def _init_weights(m):
             logger = get_root_logger()
             checkpoint = _load_checkpoint(
                 self.init_cfg['checkpoint'], logger=logger, map_location='cpu')
-
-            if 'state_dict' in checkpoint:
-                state_dict = checkpoint['state_dict']
-                state_dict = {
-                    key.replace('backbone.', ''): val
-                    for key, val in state_dict.items()
-                }
-            else:
-                state_dict = checkpoint
-
-            if 'pos_embed' in state_dict:
-                pos_embed_checkpoint = state_dict['pos_embed']
-                embedding_size = pos_embed_checkpoint.shape[-1]
-                num_extra_tokens = self.pos_embed.shape[-2] - self.num_patches
-                # height (== width) for the checkpoint position embedding
-                orig_size = int(
-                    (pos_embed_checkpoint.shape[-2] - num_extra_tokens)**0.5)
-                # height (== width) for the new position embedding
-                new_size = int(self.num_patches**0.5)
-                # class_token and dist_token are kept unchanged
-                if orig_size != new_size:
-                    if dist.get_rank() == 0:
-                        print('Position interpolate from %dx%d to %dx%d' %
-                              (orig_size, orig_size, new_size, new_size))
-                    extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
-                    # only the position tokens are interpolated
-                    pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
-                    pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size,
-                                                    embedding_size).permute(
-                                                        0, 3, 1, 2)
-                    pos_tokens = torch.nn.functional.interpolate(
-                        pos_tokens,
-                        size=(new_size, new_size),
-                        mode='bicubic',
-                        align_corners=False)
-                    pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
-                    new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
-                    state_dict['pos_embed'] = new_pos_embed
-
+            state_dict = self.resize_rel_pos_embed(checkpoint)
             self.load_state_dict(state_dict, False)
-
         elif self.init_cfg is not None:
             super(MAE, self).init_weights()
         else:
-            trunc_normal_(self.pos_embed, std=.02)
+            # We only implement the 'jax_impl' initialization implemented at
+            # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353  # noqa: E501
+            # Copyright 2019 Ross Wightman
+            # Licensed under the Apache License, Version 2.0 (the "License")
             trunc_normal_(self.cls_token, std=.02)
             for n, m in self.named_modules():
                 if isinstance(m, nn.Linear):
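
A hedged usage sketch (not part of the diff) of how the new 'Pretrained' branch of init_weights() is typically exercised; the checkpoint path is a placeholder assumption.

from mmseg.models.backbones.mae import MAE

# Placeholder checkpoint path; any MAE/BEiT-style checkpoint containing
# relative_position_bias_table entries would be resized on load.
model = MAE(
    img_size=(512, 512),
    init_cfg=dict(type='Pretrained', checkpoint='path/to/mae_pretrain.pth'))
# init_weights() loads the checkpoint, routes it through resize_rel_pos_embed()
# (which is expected to call _geometric_sequence_interpolation when the table
# sizes differ), then calls load_state_dict(state_dict, False).
model.init_weights()
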
New unit test file: 182 additions & 0 deletions
@@ -0,0 +1,182 @@
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch

from mmseg.models.backbones.mae import MAE
from .utils import check_norm_state


def test_mae_backbone():
    with pytest.raises(TypeError):
        # pretrained must be a string path
        model = MAE()
        model.init_weights(pretrained=0)

    with pytest.raises(TypeError):
        # img_size must be int or tuple
        model = MAE(img_size=512.0)

    with pytest.raises(TypeError):
        # out_indices must be int, list or tuple
        model = MAE(out_indices=1.)

    with pytest.raises(AssertionError):
        # The length of the img_size tuple must be at most 2.
        MAE(img_size=(224, 224, 224))

    with pytest.raises(TypeError):
        # pretrained must be None or str
        MAE(pretrained=123)

    # Test img_size as a one-element tuple
    imgs = torch.randn(1, 3, 224, 224)
    model = MAE(img_size=(224, ))
    model.init_weights()
    model(imgs)

    # Test img_size as a two-element tuple
    imgs = torch.randn(1, 3, 224, 224)
    model = MAE(img_size=(224, 224))
    model(imgs)

    # Test norm_eval = True
    model = MAE(norm_eval=True)
    model.train()

    # Test MAE backbone with input size of 224 and patch size of 16
    model = MAE()
    model.init_weights()
    model.train()

    # Test qv_bias
    model = MAE(qv_bias=False)
    model.train()

    # Test out_indices = list
    model = MAE(out_indices=[2, 4, 8, 12])
    model.train()

    assert check_norm_state(model.modules(), True)

    # Test image size = (224, 224)
    imgs = torch.randn(1, 3, 224, 224)
    feat = model(imgs)
    assert feat[-1].shape == (1, 768, 14, 14)

    # Test MAE backbone with input size of 256 and patch size of 16
    model = MAE(img_size=(256, 256))
    model.init_weights()
    model.train()
    imgs = torch.randn(1, 3, 256, 256)
    feat = model(imgs)
    assert feat[-1].shape == (1, 768, 16, 16)

    # Test MAE backbone with input size of 32 and patch size of 16
    model = MAE(img_size=(32, 32))
    model.init_weights()
    model.train()
    imgs = torch.randn(1, 3, 32, 32)
    feat = model(imgs)
    assert feat[-1].shape == (1, 768, 2, 2)

    # Test unbalanced size input image
    model = MAE(img_size=(112, 224))
    model.init_weights()
    model.train()
    imgs = torch.randn(1, 3, 112, 224)
    feat = model(imgs)
    assert feat[-1].shape == (1, 768, 7, 14)

    # Test irregular input image
    model = MAE(img_size=(234, 345))
    model.init_weights()
    model.train()
    imgs = torch.randn(1, 3, 234, 345)
    feat = model(imgs)
    assert feat[-1].shape == (1, 768, 14, 21)

    # Test init_values=0
    model = MAE(init_values=0)
    imgs = torch.randn(1, 3, 224, 224)
    feat = model(imgs)
    assert feat[-1].shape == (1, 768, 14, 14)

    # Test final norm
    model = MAE(final_norm=True)
    imgs = torch.randn(1, 3, 224, 224)
    feat = model(imgs)
    assert feat[-1].shape == (1, 768, 14, 14)

    # Test patch norm
    model = MAE(patch_norm=True)
    imgs = torch.randn(1, 3, 224, 224)
    feat = model(imgs)
    assert feat[-1].shape == (1, 768, 14, 14)


def test_mae_init():
    path = 'PATH_THAT_DO_NOT_EXIST'
    # Test all combinations of pretrained and init_cfg
    # pretrained=None, init_cfg=None
    model = MAE(pretrained=None, init_cfg=None)
    assert model.init_cfg is None
    model.init_weights()

    # pretrained=None
    # init_cfg loads the pretrained weights from a non-existent file
    model = MAE(
        pretrained=None, init_cfg=dict(type='Pretrained', checkpoint=path))
    assert model.init_cfg == dict(type='Pretrained', checkpoint=path)
    # Test loading a checkpoint from a non-existent file
    with pytest.raises(OSError):
        model.init_weights()

    # Test resize_rel_pos_embed
    value = torch.randn(732, 16)
    ckpt = {
        'state_dict': {
            'layers.0.attn.relative_position_index': 0,
            'layers.0.attn.relative_position_bias_table': value
        }
    }
    model = MAE(img_size=(512, 512))
    with pytest.raises(AttributeError):
        model.resize_rel_pos_embed(ckpt)

    # pretrained=None
    # init_cfg=123, whose type is unsupported
    model = MAE(pretrained=None, init_cfg=123)
    with pytest.raises(TypeError):
        model.init_weights()

    # pretrained loads the pretrained weights from a non-existent file
    # init_cfg=None
    model = MAE(pretrained=path, init_cfg=None)
    assert model.init_cfg == dict(type='Pretrained', checkpoint=path)
    # Test loading a checkpoint from a non-existent file
    with pytest.raises(OSError):
        model.init_weights()

    # pretrained loads the pretrained weights from a non-existent file
    # init_cfg loads the pretrained weights from a non-existent file
    with pytest.raises(AssertionError):
        model = MAE(
            pretrained=path, init_cfg=dict(type='Pretrained', checkpoint=path))
    with pytest.raises(AssertionError):
        model = MAE(pretrained=path, init_cfg=123)

    # pretrained=123, whose type is unsupported
    # init_cfg=None
    with pytest.raises(TypeError):
        model = MAE(pretrained=123, init_cfg=None)

    # pretrained=123, whose type is unsupported
    # init_cfg loads the pretrained weights from a non-existent file
    with pytest.raises(AssertionError):
        model = MAE(
            pretrained=123, init_cfg=dict(type='Pretrained', checkpoint=path))

    # pretrained=123, whose type is unsupported
    # init_cfg=123, whose type is unsupported
    with pytest.raises(AssertionError):
        model = MAE(pretrained=123, init_cfg=123)
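
The check_norm_state helper imported from .utils is not part of this diff. A minimal sketch consistent with how the test calls it (it should return whether every batch-norm layer's training flag matches the expected state) might look like the following; the actual helper in the repository may differ.

from torch.nn.modules.batchnorm import _BatchNorm

def check_norm_state(modules, train_state):
    """Check that all _BatchNorm layers are in the expected train/eval state."""
    for mod in modules:
        if isinstance(mod, _BatchNorm):
            if mod.training != train_state:
                return False
    return True
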
