Skip to content

Commit b818946

Browse files
author
谢昕辰
authored
add plot_logs tool (open-mmlab#426)
* Support plot logs * add plot log docs
1 parent 0c31afe commit b818946

File tree

3 files changed

+146
-1
lines changed

3 files changed

+146
-1
lines changed

docs/useful_tools.md

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,3 +62,25 @@ python tools/pytorch2onnx.py ${CONFIG_FILE} --checkpoint ${CHECKPOINT_FILE} --ou
6262
```shell
6363
python tools/print_config.py ${CONFIG} [-h] [--options ${OPTIONS [OPTIONS...]}]
6464
```
65+
66+
### Plot training logs
67+
68+
`tools/analyze_logs.py` plot s loss/mIoU curves given a training log file. `pip install seaborn` first to install the dependency.
69+
70+
```shell
71+
python tools/analyze_logs.py xxx.log.json [--keys ${KEYS}] [--legend ${LEGEND}] [--backend ${BACKEND}] [--style ${STYLE}] [--out ${OUT_FILE}]
72+
```
73+
74+
Examples:
75+
76+
- Plot the mIoU, mAcc, aAcc metrics.
77+
78+
```shell
79+
python tools/analyze_logs.py log.json --keys mIoU mAcc aAcc --legend mIoU mAcc aAcc
80+
```
81+
82+
- Plot loss metric.
83+
84+
```shell
85+
python tools/analyze_logs.py log.json --keys loss --legend loss
86+
```

setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,6 @@ line_length = 79
88
multi_line_output = 0
99
known_standard_library = setuptools
1010
known_first_party = mmseg
11-
known_third_party = PIL,cityscapesscripts,cv2,detail,matplotlib,mmcv,numpy,onnxruntime,oss2,pytest,scipy,terminaltables,torch
11+
known_third_party = PIL,cityscapesscripts,cv2,detail,matplotlib,mmcv,numpy,onnxruntime,oss2,pytest,scipy,seaborn,terminaltables,torch
1212
no_lines_before = STDLIB,LOCALFOLDER
1313
default_section = THIRDPARTY

tools/analyze_logs.py

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
"""Modified from https://github.com/open-
2+
mmlab/mmdetection/blob/master/tools/analysis_tools/analyze_logs.py."""
3+
import argparse
4+
import json
5+
from collections import defaultdict
6+
7+
import matplotlib.pyplot as plt
8+
import seaborn as sns
9+
10+
11+
def plot_curve(log_dicts, args):
12+
if args.backend is not None:
13+
plt.switch_backend(args.backend)
14+
sns.set_style(args.style)
15+
# if legend is None, use {filename}_{key} as legend
16+
legend = args.legend
17+
if legend is None:
18+
legend = []
19+
for json_log in args.json_logs:
20+
for metric in args.keys:
21+
legend.append(f'{json_log}_{metric}')
22+
assert len(legend) == (len(args.json_logs) * len(args.keys))
23+
metrics = args.keys
24+
25+
num_metrics = len(metrics)
26+
for i, log_dict in enumerate(log_dicts):
27+
epochs = list(log_dict.keys())
28+
for j, metric in enumerate(metrics):
29+
print(f'plot curve of {args.json_logs[i]}, metric is {metric}')
30+
plot_epochs = []
31+
plot_iters = []
32+
plot_values = []
33+
for epoch in epochs:
34+
epoch_logs = log_dict[epoch]
35+
if metric not in epoch_logs.keys():
36+
continue
37+
if metric in ['mIoU', 'mAcc', 'aAcc']:
38+
plot_epochs.append(epoch)
39+
plot_values.append(epoch_logs[metric][0])
40+
else:
41+
for idx in range(len(epoch_logs[metric])):
42+
plot_iters.append(epoch_logs['iter'][idx])
43+
plot_values.append(epoch_logs[metric][idx])
44+
ax = plt.gca()
45+
label = legend[i * num_metrics + j]
46+
if metric in ['mIoU', 'mAcc', 'aAcc']:
47+
ax.set_xticks(plot_epochs)
48+
plt.xlabel('epoch')
49+
plt.plot(plot_epochs, plot_values, label=label, marker='o')
50+
else:
51+
plt.xlabel('iter')
52+
plt.plot(plot_iters, plot_values, label=label, linewidth=0.5)
53+
plt.legend()
54+
if args.title is not None:
55+
plt.title(args.title)
56+
if args.out is None:
57+
plt.show()
58+
else:
59+
print(f'save curve to: {args.out}')
60+
plt.savefig(args.out)
61+
plt.cla()
62+
63+
64+
def parse_args():
65+
parser = argparse.ArgumentParser(description='Analyze Json Log')
66+
parser.add_argument(
67+
'json_logs',
68+
type=str,
69+
nargs='+',
70+
help='path of train log in json format')
71+
parser.add_argument(
72+
'--keys',
73+
type=str,
74+
nargs='+',
75+
default=['mIoU'],
76+
help='the metric that you want to plot')
77+
parser.add_argument('--title', type=str, help='title of figure')
78+
parser.add_argument(
79+
'--legend',
80+
type=str,
81+
nargs='+',
82+
default=None,
83+
help='legend of each plot')
84+
parser.add_argument(
85+
'--backend', type=str, default=None, help='backend of plt')
86+
parser.add_argument(
87+
'--style', type=str, default='dark', help='style of plt')
88+
parser.add_argument('--out', type=str, default=None)
89+
args = parser.parse_args()
90+
return args
91+
92+
93+
def load_json_logs(json_logs):
94+
# load and convert json_logs to log_dict, key is epoch, value is a sub dict
95+
# keys of sub dict is different metrics
96+
# value of sub dict is a list of corresponding values of all iterations
97+
log_dicts = [dict() for _ in json_logs]
98+
for json_log, log_dict in zip(json_logs, log_dicts):
99+
with open(json_log, 'r') as log_file:
100+
for line in log_file:
101+
log = json.loads(line.strip())
102+
# skip lines without `epoch` field
103+
if 'epoch' not in log:
104+
continue
105+
epoch = log.pop('epoch')
106+
if epoch not in log_dict:
107+
log_dict[epoch] = defaultdict(list)
108+
for k, v in log.items():
109+
log_dict[epoch][k].append(v)
110+
return log_dicts
111+
112+
113+
def main():
114+
args = parse_args()
115+
json_logs = args.json_logs
116+
for json_log in json_logs:
117+
assert json_log.endswith('.json')
118+
log_dicts = load_json_logs(json_logs)
119+
plot_curve(log_dicts, args)
120+
121+
122+
if __name__ == '__main__':
123+
main()

0 commit comments

Comments
 (0)