
Commit 31395a8

Nourollah and MeowZheng authored
[Enhancement] .dev Python files updated to get better performance and syntax (open-mmlab#2020)
* logger hooks samples updated
* [Docs] Details for WandBLoggerHook Added
* [Docs] lint test pass
* [Enhancement] .dev Python files updated to get better performance and quality
* [Docs] Details for WandBLoggerHook Added
* [Docs] lint test pass
* [Enhancement] .dev Python files updated to get better performance and quality
* [Enhancement] lint test passed
* [Enhancement] Change Some Line from Previous to Support Python<3.9
* Update .dev/gather_models.py

Co-authored-by: Miao Zheng <[email protected]>
1 parent ecd1ecb commit 31395a8

12 files changed: 70 additions, 91 deletions

.dev/benchmark_inference.py

Lines changed: 4 additions & 6 deletions
@@ -53,8 +53,7 @@ def parse_args():
         '-s', '--show', action='store_true', help='show results')
     parser.add_argument(
         '-d', '--device', default='cuda:0', help='Device used for inference')
-    args = parser.parse_args()
-    return args
+    return parser.parse_args()


 def inference_model(config_name, checkpoint, args, logger=None):
@@ -66,11 +65,10 @@ def inference_model(config_name, checkpoint, args, logger=None):
                 0.5, 0.75, 1.0, 1.25, 1.5, 1.75
             ]
             cfg.data.test.pipeline[1].flip = True
+        elif logger is None:
+            print(f'{config_name}: unable to start aug test', flush=True)
         else:
-            if logger is not None:
-                logger.error(f'{config_name}: unable to start aug test')
-            else:
-                print(f'{config_name}: unable to start aug test', flush=True)
+            logger.error(f'{config_name}: unable to start aug test')

     model = init_segmentor(cfg, checkpoint, device=args.device)
     # test a single image
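
Note: the change above replaces an `else` branch that only wraps another `if`/`else` with two sibling branches. A minimal sketch of the pattern with illustrative names (not code from the repository):

# Before: the fallback message lives in a conditional nested inside `else`.
def notify(ok, msg, logger=None):
    if ok:
        print('all good')
    else:
        if logger is not None:
            logger.error(msg)
        else:
            print(msg, flush=True)

# After: the nested branches become siblings in a flat if/elif/else chain.
def notify_flat(ok, msg, logger=None):
    if ok:
        print('all good')
    elif logger is None:
        print(msg, flush=True)
    else:
        logger.error(msg)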

.dev/check_urls.py

Lines changed: 2 additions & 6 deletions
@@ -18,12 +18,9 @@ def check_url(url):
     Returns:
         int, bool: status code and check flag.
     """
-    flag = True
     r = requests.head(url)
     status_code = r.status_code
-    if status_code == 403 or status_code == 404:
-        flag = False
-
+    flag = status_code not in [403, 404]
     return status_code, flag


@@ -35,8 +32,7 @@ def parse_args():
         type=str,
         help='Select the model needed to check')

-    args = parser.parse_args()
-    return args
+    return parser.parse_args()


 def main():
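
Note: the new single-expression flag is a plain membership test. A small self-contained sketch of the resulting function, using the real `requests.head` API (the usage line is illustrative):

import requests

def check_url(url):
    """Return (status_code, flag); flag is False for 403/404 responses."""
    r = requests.head(url)
    status_code = r.status_code
    # One membership test replaces flag = True / if 403 or 404: flag = False.
    flag = status_code not in [403, 404]
    return status_code, flag

# Illustrative usage:
# code, ok = check_url('https://github.com/open-mmlab/mmsegmentation')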

.dev/gather_benchmark_evaluation_results.py

Lines changed: 2 additions & 2 deletions
@@ -62,8 +62,8 @@ def parse_args():
             continue

         # Compare between new benchmark results and previous metrics
-        differential_results = dict()
-        new_metrics = dict()
+        differential_results = {}
+        new_metrics = {}
         for record_metric_key in previous_metrics:
             if record_metric_key not in metric['metric']:
                 raise KeyError('record_metric_key not exist, please '
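
Note: replacing dict() with the {} literal is one of the small performance tweaks the commit title refers to; the literal avoids a global name lookup and a constructor call. A quick, hedged way to compare on your own machine (absolute numbers vary):

import timeit

# The literal is usually measurably faster than calling the dict() builtin.
print('literal    :', timeit.timeit('{}', number=1_000_000))
print('constructor:', timeit.timeit('dict()', number=1_000_000))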

.dev/gather_benchmark_train_results.py

Lines changed: 3 additions & 3 deletions
@@ -72,9 +72,9 @@ def parse_args():
             print(f'log file error: {log_json_path}')
             continue

-        differential_results = dict()
-        old_results = dict()
-        new_results = dict()
+        differential_results = {}
+        old_results = {}
+        new_results = {}
         for metric_key in model_performance:
             if metric_key in ['mIoU']:
                 metric = round(model_performance[metric_key] * 100, 2)

.dev/gather_models.py

Lines changed: 10 additions & 11 deletions
@@ -33,7 +33,7 @@ def process_checkpoint(in_file, out_file):
     # The hash code calculation and rename command differ on different system
     # platform.
     sha = calculate_file_sha256(out_file)
-    final_file = out_file.rstrip('.pth') + '-{}.pth'.format(sha[:8])
+    final_file = out_file.rstrip('.pth') + f'-{sha[:8]}.pth'
     os.rename(out_file, final_file)

     # Remove prefix and suffix
@@ -50,25 +50,23 @@ def get_final_iter(config):


 def get_final_results(log_json_path, iter_num):
-    result_dict = dict()
+    result_dict = {}
     last_iter = 0
     with open(log_json_path, 'r') as f:
-        for line in f.readlines():
+        for line in f:
             log_line = json.loads(line)
             if 'mode' not in log_line.keys():
                 continue
-
             # When evaluation, the 'iter' of new log json is the evaluation
             # steps on single gpu.
-            flag1 = ('aAcc' in log_line) or (log_line['mode'] == 'val')
-            flag2 = (last_iter == iter_num - 50) or (last_iter == iter_num)
+            flag1 = 'aAcc' in log_line or log_line['mode'] == 'val'
+            flag2 = last_iter in [iter_num - 50, iter_num]
             if flag1 and flag2:
                 result_dict.update({
                     key: log_line[key]
                     for key in RESULTS_LUT if key in log_line
                 })
                 return result_dict
-
             last_iter = log_line['iter']


@@ -123,7 +121,7 @@ def main():
         exp_dir = osp.join(work_dir, config_name)
         # check whether the exps is finished
         final_iter = get_final_iter(used_config)
-        final_model = 'iter_{}.pth'.format(final_iter)
+        final_model = f'iter_{final_iter}.pth'
         model_path = osp.join(exp_dir, final_model)

         # skip if the model is still training
@@ -135,7 +133,7 @@ def main():
         log_json_paths = glob.glob(osp.join(exp_dir, '*.log.json'))
         log_json_path = log_json_paths[0]
         model_performance = None
-        for idx, _log_json_path in enumerate(log_json_paths):
+        for _log_json_path in log_json_paths:
             model_performance = get_final_results(_log_json_path, final_iter)
             if model_performance is not None:
                 log_json_path = _log_json_path
@@ -161,9 +159,10 @@ def main():
         model_publish_dir = osp.join(collect_dir, config_name)

         publish_model_path = osp.join(model_publish_dir,
-                                      config_name + '_' + model['model_time'])
+                                      f'{config_name}_' + model['model_time'])
+
         trained_model_path = osp.join(work_dir, config_name,
-                                      'iter_{}.pth'.format(model['iters']))
+                                      f'iter_{model["iters"]}.pth')
         if osp.exists(model_publish_dir):
             for file in os.listdir(model_publish_dir):
                 if file.endswith('.pth'):
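
Note: two idioms recur in this file's changes: f-strings instead of str.format, and iterating the file object directly instead of f.readlines(), which avoids loading the whole log into memory. A minimal sketch with hypothetical names (not the repository's helper):

import json

def last_iter_in_log(log_json_path):
    """Hypothetical helper: return the last 'iter' value in a JSON-lines log."""
    last_iter = 0
    with open(log_json_path, 'r') as f:
        for line in f:  # streams line by line; no readlines() list in memory
            log_line = json.loads(line)
            if 'iter' in log_line:
                last_iter = log_line['iter']
    return last_iter

sha = '0123456789abcdef'
final_file = 'model' + f'-{sha[:8]}.pth'  # f-string instead of '-{}.pth'.format(sha[:8])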

.dev/generate_benchmark_evaluation_script.py

Lines changed: 4 additions & 6 deletions
@@ -20,8 +20,7 @@ def parse_args():
         default='.dev/benchmark_evaluation.sh',
         help='path to save model benchmark script')

-    args = parser.parse_args()
-    return args
+    return parser.parse_args()


 def process_model_info(model_info, work_dir):
@@ -30,10 +29,9 @@ def process_model_info(model_info, work_dir):
     job_name = fname
     checkpoint = model_info['checkpoint'].strip()
     work_dir = osp.join(work_dir, fname)
-    if not isinstance(model_info['eval'], list):
-        evals = [model_info['eval']]
-    else:
-        evals = model_info['eval']
+    evals = model_info['eval'] if isinstance(model_info['eval'],
+                                             list) else [model_info['eval']]
+
     eval = ' '.join(evals)
     return dict(
         config=config,
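
Note: the conditional expression above normalizes a value that may be either a single item or a list. The same pattern as a tiny standalone sketch (names are illustrative):

def as_list(value):
    # Keep lists as-is, wrap anything else in a single-element list.
    return value if isinstance(value, list) else [value]

assert as_list('mIoU') == ['mIoU']
assert as_list(['mIoU', 'mDice']) == ['mIoU', 'mDice']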

.dev/generate_benchmark_train_script.py

Lines changed: 2 additions & 5 deletions
@@ -69,14 +69,11 @@ def main():
     port = args.port
     partition_name = 'PARTITION=$1'

-    commands = []
-    commands.append(partition_name)
-    commands.append('\n')
-    commands.append('\n')
+    commands = [partition_name, '\n', '\n']

     with open(args.txt_path, 'r') as f:
         model_cfgs = f.readlines()
-        for i, cfg in enumerate(model_cfgs):
+        for cfg in model_cfgs:
             create_train_bash_info(commands, cfg, script_name, '$PARTITION',
                                    port)
             port += 1

.dev/log_collector/log_collector.py

Lines changed: 2 additions & 6 deletions
@@ -27,15 +27,11 @@
 def parse_args():
     parser = argparse.ArgumentParser(description='extract info from log.json')
     parser.add_argument('config_dir')
-    args = parser.parse_args()
-    return args
+    return parser.parse_args()


 def has_keyword(name: str, keywords: list):
-    for a_keyword in keywords:
-        if a_keyword in name:
-            return True
-    return False
+    return any(a_keyword in name for a_keyword in keywords)


 def main():
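
Note: the any() rewrite short-circuits exactly like the original loop, returning True at the first matching keyword. A tiny equivalence check (file names are made up):

def has_keyword(name: str, keywords: list) -> bool:
    # Generator expression stops as soon as one keyword is found in `name`.
    return any(a_keyword in name for a_keyword in keywords)

assert has_keyword('20220101_120000_train.log.json', ['train', 'val'])
assert not has_keyword('config.py', ['train', 'val'])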

.dev/upload_modelzoo.py

Lines changed: 1 addition & 2 deletions
@@ -19,8 +19,7 @@ def parse_args():
         type=str,
         default='mmsegmentation/v0.5',
         help='destination folder')
-    args = parser.parse_args()
-    return args
+    return parser.parse_args()


 def main():
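
Note: the same `return parser.parse_args()` cleanup is applied across the .dev scripts; it drops a throwaway local without changing behaviour. A minimal standalone sketch (argument names are illustrative):

import argparse

def parse_args():
    parser = argparse.ArgumentParser(description='example parser')
    parser.add_argument('-d', '--device', default='cuda:0', help='device')
    # Return directly instead of binding a temporary `args` variable first.
    return parser.parse_args()

if __name__ == '__main__':
    args = parse_args()
    print(args.device)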

docs/en/tutorials/config.md

Lines changed: 5 additions & 1 deletion
@@ -221,9 +221,13 @@ log_config = dict( # config to register logger hook
     hooks=[
         dict(type='TextLoggerHook', by_epoch=False),
         dict(type='TensorboardLoggerHook', by_epoch=False),
-        dict(type='MMSegWandbHook', by_epoch=False, init_kwargs={'entity': entity, 'project': project, 'config': cfg_dict}), # The Wandb logger is also supported, It requires `wandb` to be installed.
+        dict(type='MMSegWandbHook', by_epoch=False, # The Wandb logger is also supported, It requires `wandb` to be installed.
+             init_kwargs={'entity': "OpenMMLab", # The entity used to log on Wandb
+                          'project': "MMSeg", # Project name in WandB
+                          'config': cfg_dict}), # Check https://docs.wandb.ai/ref/python/init for more init arguments.
         # MMSegWandbHook is mmseg implementation of WandbLoggerHook. ClearMLLoggerHook, DvcliveLoggerHook, MlflowLoggerHook, NeptuneLoggerHook, PaviLoggerHook, SegmindLoggerHook are also supported based on MMCV implementation.
     ])
+
 dist_params = dict(backend='nccl') # Parameters to setup distributed training, the port can also be set.
 log_level = 'INFO' # The level of logging.
 load_from = None # load models as a pre-trained model from a given path. This will not resume training.
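
Note: putting the documented hook into a runnable shape, a minimal sketch of a log_config entry with MMSegWandbHook; the interval, entity, and project values are placeholders, cfg_dict stands for your experiment config as a plain dict, and `wandb` must be installed and logged in:

cfg_dict = {}  # placeholder for the experiment config as a plain dict
log_config = dict(
    interval=50,  # logging interval (placeholder value)
    hooks=[
        dict(type='TextLoggerHook', by_epoch=False),
        dict(
            type='MMSegWandbHook',
            by_epoch=False,
            init_kwargs=dict(
                entity='OpenMMLab',  # your WandB entity
                project='MMSeg',     # your WandB project
                config=cfg_dict)),   # experiment config logged to WandB
    ])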
