Commit 1377131

modify vit
1 parent 0ae504d commit 1377131

File tree

3 files changed: +486, -379 lines changed


mmseg/models/backbones/helpers.py

Lines changed: 145 additions & 0 deletions
@@ -0,0 +1,145 @@
import torch
import torch.nn.functional as F
import math
import logging
import warnings
import errno
import os
import sys
import re
import zipfile
from urllib.parse import urlparse  # noqa: F401

HASH_REGEX = re.compile(r'-([a-f0-9]*)\.')
_logger = logging.getLogger(__name__)


def load_state_dict_from_url(url, model_dir=None, file_name=None, check_hash=False, progress=True, map_location=None):
    # Issue warning to move data if old env is set
    if os.getenv('TORCH_MODEL_ZOO'):
        warnings.warn(
            'TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead')

    if model_dir is None:
        hub_dir = torch.hub.get_dir()
        model_dir = os.path.join(hub_dir, 'checkpoints')
    try:
        os.makedirs(model_dir)
    except OSError as e:
        if e.errno == errno.EEXIST:
            # Directory already exists, ignore.
            pass
        else:
            # Unexpected OSError, re-raise.
            raise
    parts = urlparse(url)
    filename = os.path.basename(parts.path)
    if file_name is not None:
        filename = file_name
    cached_file = os.path.join(model_dir, filename)
    if not os.path.exists(cached_file):
        sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
        hash_prefix = HASH_REGEX.search(
            filename).group(1) if check_hash else None
        torch.hub.download_url_to_file(
            url, cached_file, hash_prefix, progress=progress)
    if zipfile.is_zipfile(cached_file):
        state_dict = torch.load(
            cached_file, map_location=map_location)['model']
    else:
        state_dict = torch.load(cached_file, map_location=map_location)
    return state_dict


def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None, strict=True, pos_embed_interp=False, num_patches=576, align_corners=False):
    if cfg is None:
        cfg = getattr(model, 'default_cfg')
    if cfg is None or 'url' not in cfg or not cfg['url']:
        _logger.warning(
            "Pretrained model URL is invalid, using random initialization.")
        return

    if 'pretrained_finetune' in cfg and cfg['pretrained_finetune']:
        state_dict = torch.load(cfg['pretrained_finetune'])
        print('load pre-trained weight from ' + cfg['pretrained_finetune'])
    else:
        state_dict = load_state_dict_from_url(
            cfg['url'], progress=False, map_location='cpu')
        print('load pre-trained weight from imagenet21k')

    if filter_fn is not None:
        state_dict = filter_fn(state_dict)

    if in_chans == 1:
        conv1_name = cfg['first_conv']
        _logger.info(
            'Converting first conv (%s) pretrained weights from 3 to 1 channel' % conv1_name)
        conv1_weight = state_dict[conv1_name + '.weight']
        # Some weights are in torch.half, ensure it's float for sum on CPU
        conv1_type = conv1_weight.dtype
        conv1_weight = conv1_weight.float()
        O, I, J, K = conv1_weight.shape
        if I > 3:
            assert conv1_weight.shape[1] % 3 == 0
            # For models with space2depth stems
            conv1_weight = conv1_weight.reshape(O, I // 3, 3, J, K)
            conv1_weight = conv1_weight.sum(dim=2, keepdim=False)
        else:
            conv1_weight = conv1_weight.sum(dim=1, keepdim=True)
        conv1_weight = conv1_weight.to(conv1_type)
        state_dict[conv1_name + '.weight'] = conv1_weight
    elif in_chans != 3:
        conv1_name = cfg['first_conv']
        conv1_weight = state_dict[conv1_name + '.weight']
        conv1_type = conv1_weight.dtype
        conv1_weight = conv1_weight.float()
        O, I, J, K = conv1_weight.shape
        if I == 3:
            _logger.warning(
                'Deleting first conv (%s) from pretrained weights.' % conv1_name)
            del state_dict[conv1_name + '.weight']
            strict = False
        else:
            # NOTE this strategy should be better than random init, but there could be other combinations of
            # the original RGB input layer weights that'd work better for specific cases.
            _logger.info(
                'Repeating first conv (%s) weights in channel dim.' % conv1_name)
            repeat = int(math.ceil(in_chans / 3))
            conv1_weight = conv1_weight.repeat(1, repeat, 1, 1)[
                :, :in_chans, :, :]
            conv1_weight *= (3 / float(in_chans))
            conv1_weight = conv1_weight.to(conv1_type)
            state_dict[conv1_name + '.weight'] = conv1_weight

    classifier_name = cfg['classifier']
    if num_classes == 1000 and cfg['num_classes'] == 1001:
        # special case for imagenet trained models with extra background class in pretrained weights
        classifier_weight = state_dict[classifier_name + '.weight']
        state_dict[classifier_name + '.weight'] = classifier_weight[1:]
        classifier_bias = state_dict[classifier_name + '.bias']
        state_dict[classifier_name + '.bias'] = classifier_bias[1:]
    elif num_classes != cfg['num_classes']:
        # completely discard fully connected for all other differences between pretrained and created model
        del state_dict[classifier_name + '.weight']
        del state_dict[classifier_name + '.bias']
        strict = False

    if pos_embed_interp:
        n, c, hw = state_dict['pos_embed'].transpose(1, 2).shape
        h = w = int(math.sqrt(hw))
        pos_embed_weight = state_dict['pos_embed'][:, (-h * w):]
        pos_embed_weight = pos_embed_weight.transpose(1, 2)
        n, c, hw = pos_embed_weight.shape
        h = w = int(math.sqrt(hw))
        pos_embed_weight = pos_embed_weight.view(n, c, h, w)

        pos_embed_weight = F.interpolate(pos_embed_weight, size=int(
            math.sqrt(num_patches)), mode='bilinear', align_corners=align_corners)
        pos_embed_weight = pos_embed_weight.view(n, c, -1).transpose(1, 2)

        cls_token_weight = state_dict['pos_embed'][:, 0].unsqueeze(1)

        state_dict['pos_embed'] = torch.cat(
            (cls_token_weight, pos_embed_weight), dim=1)

    model.load_state_dict(state_dict, strict=strict)
