|
| 1 | +"""Modified from https://github.com/open- |
| 2 | +mmlab/mmdetection/blob/master/tools/analysis_tools/analyze_logs.py.""" |
| 3 | +import argparse |
| 4 | +import json |
| 5 | +from collections import defaultdict |
| 6 | + |
| 7 | +import matplotlib.pyplot as plt |
| 8 | +import seaborn as sns |
| 9 | + |
| 10 | + |
| 11 | +def plot_curve(log_dicts, args): |
| 12 | + if args.backend is not None: |
| 13 | + plt.switch_backend(args.backend) |
| 14 | + sns.set_style(args.style) |
| 15 | + # if legend is None, use {filename}_{key} as legend |
| 16 | + legend = args.legend |
| 17 | + if legend is None: |
| 18 | + legend = [] |
| 19 | + for json_log in args.json_logs: |
| 20 | + for metric in args.keys: |
| 21 | + legend.append(f'{json_log}_{metric}') |
| 22 | + assert len(legend) == (len(args.json_logs) * len(args.keys)) |
| 23 | + metrics = args.keys |
| 24 | + |
| 25 | + num_metrics = len(metrics) |
| 26 | + for i, log_dict in enumerate(log_dicts): |
| 27 | + epochs = list(log_dict.keys()) |
| 28 | + for j, metric in enumerate(metrics): |
| 29 | + print(f'plot curve of {args.json_logs[i]}, metric is {metric}') |
| 30 | + plot_epochs = [] |
| 31 | + plot_iters = [] |
| 32 | + plot_values = [] |
| 33 | + for epoch in epochs: |
| 34 | + epoch_logs = log_dict[epoch] |
| 35 | + if metric not in epoch_logs.keys(): |
| 36 | + continue |
| 37 | + if metric in ['mIoU', 'mAcc', 'aAcc']: |
| 38 | + plot_epochs.append(epoch) |
| 39 | + plot_values.append(epoch_logs[metric][0]) |
| 40 | + else: |
| 41 | + for idx in range(len(epoch_logs[metric])): |
| 42 | + plot_iters.append(epoch_logs['iter'][idx]) |
| 43 | + plot_values.append(epoch_logs[metric][idx]) |
| 44 | + ax = plt.gca() |
| 45 | + label = legend[i * num_metrics + j] |
| 46 | + if metric in ['mIoU', 'mAcc', 'aAcc']: |
| 47 | + ax.set_xticks(plot_epochs) |
| 48 | + plt.xlabel('epoch') |
| 49 | + plt.plot(plot_epochs, plot_values, label=label, marker='o') |
| 50 | + else: |
| 51 | + plt.xlabel('iter') |
| 52 | + plt.plot(plot_iters, plot_values, label=label, linewidth=0.5) |
| 53 | + plt.legend() |
| 54 | + if args.title is not None: |
| 55 | + plt.title(args.title) |
| 56 | + if args.out is None: |
| 57 | + plt.show() |
| 58 | + else: |
| 59 | + print(f'save curve to: {args.out}') |
| 60 | + plt.savefig(args.out) |
| 61 | + plt.cla() |
| 62 | + |
| 63 | + |
| 64 | +def parse_args(): |
| 65 | + parser = argparse.ArgumentParser(description='Analyze Json Log') |
| 66 | + parser.add_argument( |
| 67 | + 'json_logs', |
| 68 | + type=str, |
| 69 | + nargs='+', |
| 70 | + help='path of train log in json format') |
| 71 | + parser.add_argument( |
| 72 | + '--keys', |
| 73 | + type=str, |
| 74 | + nargs='+', |
| 75 | + default=['mIoU'], |
| 76 | + help='the metric that you want to plot') |
| 77 | + parser.add_argument('--title', type=str, help='title of figure') |
| 78 | + parser.add_argument( |
| 79 | + '--legend', |
| 80 | + type=str, |
| 81 | + nargs='+', |
| 82 | + default=None, |
| 83 | + help='legend of each plot') |
| 84 | + parser.add_argument( |
| 85 | + '--backend', type=str, default=None, help='backend of plt') |
| 86 | + parser.add_argument( |
| 87 | + '--style', type=str, default='dark', help='style of plt') |
| 88 | + parser.add_argument('--out', type=str, default=None) |
| 89 | + args = parser.parse_args() |
| 90 | + return args |
| 91 | + |
| 92 | + |
| 93 | +def load_json_logs(json_logs): |
| 94 | + # load and convert json_logs to log_dict, key is epoch, value is a sub dict |
| 95 | + # keys of sub dict is different metrics |
| 96 | + # value of sub dict is a list of corresponding values of all iterations |
| 97 | + log_dicts = [dict() for _ in json_logs] |
| 98 | + for json_log, log_dict in zip(json_logs, log_dicts): |
| 99 | + with open(json_log, 'r') as log_file: |
| 100 | + for line in log_file: |
| 101 | + log = json.loads(line.strip()) |
| 102 | + # skip lines without `epoch` field |
| 103 | + if 'epoch' not in log: |
| 104 | + continue |
| 105 | + epoch = log.pop('epoch') |
| 106 | + if epoch not in log_dict: |
| 107 | + log_dict[epoch] = defaultdict(list) |
| 108 | + for k, v in log.items(): |
| 109 | + log_dict[epoch][k].append(v) |
| 110 | + return log_dicts |
| 111 | + |
| 112 | + |
| 113 | +def main(): |
| 114 | + args = parse_args() |
| 115 | + json_logs = args.json_logs |
| 116 | + for json_log in json_logs: |
| 117 | + assert json_log.endswith('.json') |
| 118 | + log_dicts = load_json_logs(json_logs) |
| 119 | + plot_curve(log_dicts, args) |
| 120 | + |
| 121 | + |
| 122 | +if __name__ == '__main__': |
| 123 | + main() |
0 commit comments