
Commit b2724da

init commit
1 parent 0032f0b commit b2724da

430 files changed: +20454 -1 lines changed


.dev/clean_models.py

Lines changed: 125 additions & 0 deletions
import argparse
import glob
import json
import os
import os.path as osp

import mmcv

# build schedule look-up table to automatically find the final model
SCHEDULES_LUT = {
    '20ki': 20000,
    '40ki': 40000,
    '60ki': 60000,
    '80ki': 80000,
    '160ki': 160000
}
RESULTS_LUT = ['mIoU', 'mAcc', 'aAcc']


def get_final_iter(config):
    iter_num = SCHEDULES_LUT[config.split('_')[-2]]
    return iter_num


def get_final_results(log_json_path, iter_num):
    result_dict = dict()
    with open(log_json_path, 'r') as f:
        for line in f.readlines():
            log_line = json.loads(line)
            if 'mode' not in log_line.keys():
                continue

            if log_line['mode'] == 'train' and log_line['iter'] == iter_num:
                result_dict['memory'] = log_line['memory']

            if log_line['iter'] == iter_num:
                result_dict.update({
                    key: log_line[key]
                    for key in RESULTS_LUT if key in log_line
                })
    # return None when nothing was parsed so callers can skip the run
    return result_dict if result_dict else None


def parse_args():
    parser = argparse.ArgumentParser(description='Clean benchmarked models')
    parser.add_argument(
        'root',
        type=str,
        help='root path of benchmarked models to be gathered')
    parser.add_argument(
        'config',
        type=str,
        help='root path of benchmarked configs to be gathered')

    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    models_root = args.root
    config_name = args.config

    # find all models in the root directory to be gathered
    raw_configs = list(mmcv.scandir(config_name, '.py', recursive=True))

    # filter out configs that were not trained in the experiments dir
    used_configs = []
    for raw_config in raw_configs:
        work_dir = osp.splitext(osp.basename(raw_config))[0]
        if osp.exists(osp.join(models_root, work_dir)):
            used_configs.append(work_dir)
    print(f'Found {len(used_configs)} models to be gathered')

    # find the final checkpoint and log file for each trained config
    # and parse the best performance
    model_infos = []
    for used_config in used_configs:
        exp_dir = osp.join(models_root, used_config)
        # check whether the experiment is finished
        final_iter = get_final_iter(used_config)
        final_model = 'iter_{}.pth'.format(final_iter)
        model_path = osp.join(exp_dir, final_model)

        # skip if the model is still training
        if not osp.exists(model_path):
            print(f'{used_config} not finished yet')
            continue

        # get logs
        log_json_path = glob.glob(osp.join(exp_dir, '*.log.json'))[0]
        log_txt_path = glob.glob(osp.join(exp_dir, '*.log'))[0]
        model_performance = get_final_results(log_json_path, final_iter)

        if model_performance is None:
            print(f'{used_config} does not have performance')
            continue

        model_time = osp.split(log_txt_path)[-1].split('.')[0]
        model_infos.append(
            dict(
                config=used_config,
                results=model_performance,
                iters=final_iter,
                model_time=model_time,
                log_json_path=osp.split(log_json_path)[-1]))

    # remove over-saved checkpoints, keeping only the final and latest ones
    for model in model_infos:
        model_name = osp.split(model['config'])[-1].split('.')[0]
        model_name += '_' + model['model_time']
        for checkpoints in mmcv.scandir(
                osp.join(models_root, model['config']), suffix='.pth'):
            if checkpoints.endswith(f"iter_{model['iters']}.pth"
                                    ) or checkpoints.endswith('latest.pth'):
                continue
            print('removing {}'.format(
                osp.join(models_root, model['config'], checkpoints)))
            os.remove(osp.join(models_root, model['config'], checkpoints))


if __name__ == '__main__':
    main()
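
As a sanity check, here is a minimal, hypothetical sketch of the line-per-JSON-object log format that get_final_results above expects; the helper name comes from the script, while the file contents and metric values are invented:

import json
import tempfile

# Two invented records at the final iteration: a train line carrying GPU
# memory and a val line carrying the metrics listed in RESULTS_LUT.
records = [
    {'mode': 'train', 'iter': 40000, 'memory': 7433},
    {'mode': 'val', 'iter': 40000, 'mIoU': 0.774, 'mAcc': 0.852,
     'aAcc': 0.961},
]
with tempfile.NamedTemporaryFile('w', suffix='.log.json', delete=False) as f:
    f.write('\n'.join(json.dumps(r) for r in records))

# memory is taken from the train record, the metrics from any record
# logged at the same iteration:
print(get_final_results(f.name, 40000))
# {'memory': 7433, 'mIoU': 0.774, 'mAcc': 0.852, 'aAcc': 0.961}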

.dev/gather_models.py

Lines changed: 197 additions & 0 deletions
import argparse
import glob
import json
import os
import os.path as osp
import shutil
import subprocess

import mmcv
import torch

# evaluation metrics to collect from the training logs
RESULTS_LUT = ['mIoU', 'mAcc', 'aAcc']


def process_checkpoint(in_file, out_file):
    checkpoint = torch.load(in_file, map_location='cpu')
    # remove optimizer for smaller file size
    if 'optimizer' in checkpoint:
        del checkpoint['optimizer']
    # if it is necessary to remove some sensitive data in checkpoint['meta'],
    # add the code here.
    torch.save(checkpoint, out_file)
    sha = subprocess.check_output(['sha256sum', out_file]).decode()
    # strip a trailing '.pth' if present before appending the hash
    # (str.rstrip('.pth') would strip characters, not the suffix)
    stem = out_file[:-len('.pth')] if out_file.endswith('.pth') else out_file
    final_file = stem + '-{}.pth'.format(sha[:8])
    subprocess.Popen(['mv', out_file, final_file])
    return final_file


def get_final_iter(config):
    iter_num = config.split('_')[-2]
    assert iter_num.endswith('k')
    return int(iter_num[:-1]) * 1000


def get_final_results(log_json_path, iter_num):
    result_dict = dict()
    with open(log_json_path, 'r') as f:
        for line in f.readlines():
            log_line = json.loads(line)
            if 'mode' not in log_line.keys():
                continue

            if log_line['mode'] == 'train' and log_line['iter'] == iter_num:
                result_dict['memory'] = log_line['memory']

            if log_line['iter'] == iter_num:
                result_dict.update({
                    key: log_line[key]
                    for key in RESULTS_LUT if key in log_line
                })
    # return None when nothing was parsed so callers can try the next log
    return result_dict if result_dict else None


def parse_args():
    parser = argparse.ArgumentParser(description='Gather benchmarked models')
    parser.add_argument(
        'root',
        type=str,
        help='root path of benchmarked models to be gathered')
    parser.add_argument(
        'config',
        type=str,
        help='root path of benchmarked configs to be gathered')
    parser.add_argument(
        'out_dir',
        type=str,
        help='output path of gathered models to be stored')
    parser.add_argument('out_file', type=str, help='the output json file name')
    parser.add_argument(
        '--filter', type=str, nargs='+', default=[], help='config filter')
    parser.add_argument(
        '--all', action='store_true', help='whether include .py and .log')

    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    models_root = args.root
    models_out = args.out_dir
    config_name = args.config
    mmcv.mkdir_or_exist(models_out)

    # find all models in the root directory to be gathered
    raw_configs = list(mmcv.scandir(config_name, '.py', recursive=True))

    # filter out configs that were not trained in the experiments dir
    used_configs = []
    for raw_config in raw_configs:
        work_dir = osp.splitext(osp.basename(raw_config))[0]
        if osp.exists(osp.join(models_root, work_dir)):
            used_configs.append((work_dir, raw_config))
    print(f'Found {len(used_configs)} models to be gathered')

    # find the final checkpoint and log file for each trained config
    # and parse the best performance
    model_infos = []
    for used_config, raw_config in used_configs:
        # apply --filter patterns only when some were given; otherwise an
        # empty filter would skip every config
        if args.filter and not any(p in used_config for p in args.filter):
            continue
        exp_dir = osp.join(models_root, used_config)
        # check whether the experiment is finished
        final_iter = get_final_iter(used_config)
        final_model = 'iter_{}.pth'.format(final_iter)
        model_path = osp.join(exp_dir, final_model)

        # skip if the model is still training
        if not osp.exists(model_path):
            print(f'{used_config} training not finished yet')
            continue

        # get logs: use the first log json that contains the final results
        log_json_paths = glob.glob(osp.join(exp_dir, '*.log.json'))
        log_json_path = log_json_paths[0]
        model_performance = None
        for _log_json_path in log_json_paths:
            model_performance = get_final_results(_log_json_path, final_iter)
            if model_performance is not None:
                log_json_path = _log_json_path
                break

        if model_performance is None:
            print(f'{used_config} has no parsed performance')
            continue

        model_time = osp.split(log_json_path)[-1].split('.')[0]
        model_infos.append(
            dict(
                config=used_config,
                raw_config=raw_config,
                results=model_performance,
                iters=final_iter,
                model_time=model_time,
                log_json_path=osp.split(log_json_path)[-1]))

    # publish a model for each gathered checkpoint
    publish_model_infos = []
    for model in model_infos:
        model_publish_dir = osp.join(models_out,
                                     osp.splitext(model['raw_config'])[0])
        model_name = osp.split(model['config'])[-1].split('.')[0]

        publish_model_path = osp.join(model_publish_dir,
                                      model_name + '_' + model['model_time'])
        trained_model_path = osp.join(models_root, model['config'],
                                      'iter_{}.pth'.format(model['iters']))
        if osp.exists(model_publish_dir):
            for file in os.listdir(model_publish_dir):
                if file.endswith('.pth'):
                    print(f'model {file} found')
                    model['model_path'] = osp.abspath(
                        osp.join(model_publish_dir, file))
                    break
            if 'model_path' not in model:
                print(f'dir {model_publish_dir} exists, no model found')
        else:
            mmcv.mkdir_or_exist(model_publish_dir)

            # convert model
            final_model_path = process_checkpoint(trained_model_path,
                                                  publish_model_path)
            model['model_path'] = final_model_path

        new_json_path = f'{model_name}-{model["log_json_path"]}'
        # copy log
        shutil.copy(
            osp.join(models_root, model['config'], model['log_json_path']),
            osp.join(model_publish_dir, new_json_path))
        if args.all:
            # drop the '.json' suffix to get the plain-text log name
            new_txt_path = new_json_path[:-len('.json')]
            shutil.copy(
                osp.join(models_root, model['config'],
                         model['log_json_path'][:-len('.json')]),
                osp.join(model_publish_dir, new_txt_path))

        if args.all:
            # copy config to guarantee reproducibility
            raw_config = osp.join(config_name, model['raw_config'])
            mmcv.Config.fromfile(raw_config).dump(
                osp.join(model_publish_dir, osp.basename(raw_config)))

        publish_model_infos.append(model)

    models = dict(models=publish_model_infos)
    mmcv.dump(models, osp.join(models_out, args.out_file))


if __name__ == '__main__':
    main()
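
A hedged usage sketch for the gatherer: the invocation below is illustrative (the paths, output names, and the 'pspnet' filter are invented, not part of the commit), and the asserts mirror how get_final_iter parses the schedule token out of a config name:

# Hypothetical invocation (arguments are illustrative):
#   python .dev/gather_models.py work_dirs configs gathered models.json \
#       --filter pspnet --all

# The schedule token is the second-to-last '_'-separated field and must
# end with 'k', i.e. thousands of iterations:
assert get_final_iter('pspnet_r50-d8_512x1024_40k_cityscapes') == 40000
assert get_final_iter('fcn_r50-d8_769x769_80k_cityscapes') == 80000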
