|
5 | 5 | import numpy as np |
6 | 6 | import torch |
7 | 7 | from mmcv.parallel import MMDataParallel, MMDistributedDataParallel |
8 | | -from mmcv.runner import build_optimizer, build_runner |
| 8 | +from mmcv.runner import HOOKS, build_optimizer, build_runner |
| 9 | +from mmcv.utils import build_from_cfg |
9 | 10 |
|
10 | 11 | from mmseg.core import DistEvalHook, EvalHook |
11 | 12 | from mmseg.datasets import build_dataloader, build_dataset |
@@ -113,6 +114,20 @@ def train_segmentor(model, |
113 | 114 | runner.register_hook( |
114 | 115 | eval_hook(val_dataloader, **eval_cfg), priority='LOW') |
115 | 116 |
|
| 117 | + # user-defined hooks |
| 118 | + if cfg.get('custom_hooks', None): |
| 119 | + custom_hooks = cfg.custom_hooks |
| 120 | + assert isinstance(custom_hooks, list), \ |
| 121 | + f'custom_hooks expect list type, but got {type(custom_hooks)}' |
| 122 | + for hook_cfg in cfg.custom_hooks: |
| 123 | + assert isinstance(hook_cfg, dict), \ |
| 124 | + 'Each item in custom_hooks expects dict type, but got ' \ |
| 125 | + f'{type(hook_cfg)}' |
| 126 | + hook_cfg = hook_cfg.copy() |
| 127 | + priority = hook_cfg.pop('priority', 'NORMAL') |
| 128 | + hook = build_from_cfg(hook_cfg, HOOKS) |
| 129 | + runner.register_hook(hook, priority=priority) |
| 130 | + |
116 | 131 | if cfg.resume_from: |
117 | 132 | runner.resume(cfg.resume_from) |
118 | 133 | elif cfg.load_from: |
|
0 commit comments