Skip to content

Commit 845098b

Browse files
authored
Update train.py (#428)
* Update train.py Add user-defined hooks. * Update train.py * Update train.py
1 parent 9feaa7c commit 845098b

File tree

1 file changed

+16
-1
lines changed

1 file changed

+16
-1
lines changed

mmseg/apis/train.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
import numpy as np
66
import torch
77
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
8-
from mmcv.runner import build_optimizer, build_runner
8+
from mmcv.runner import HOOKS, build_optimizer, build_runner
9+
from mmcv.utils import build_from_cfg
910

1011
from mmseg.core import DistEvalHook, EvalHook
1112
from mmseg.datasets import build_dataloader, build_dataset
@@ -113,6 +114,20 @@ def train_segmentor(model,
113114
runner.register_hook(
114115
eval_hook(val_dataloader, **eval_cfg), priority='LOW')
115116

117+
# user-defined hooks
118+
if cfg.get('custom_hooks', None):
119+
custom_hooks = cfg.custom_hooks
120+
assert isinstance(custom_hooks, list), \
121+
f'custom_hooks expect list type, but got {type(custom_hooks)}'
122+
for hook_cfg in cfg.custom_hooks:
123+
assert isinstance(hook_cfg, dict), \
124+
'Each item in custom_hooks expects dict type, but got ' \
125+
f'{type(hook_cfg)}'
126+
hook_cfg = hook_cfg.copy()
127+
priority = hook_cfg.pop('priority', 'NORMAL')
128+
hook = build_from_cfg(hook_cfg, HOOKS)
129+
runner.register_hook(hook, priority=priority)
130+
116131
if cfg.resume_from:
117132
runner.resume(cfg.resume_from)
118133
elif cfg.load_from:

0 commit comments

Comments
 (0)