Skip to content

Commit d568d06

Browse files
authored
[Refactor] Use MMCV MODEL_REGISTRY (open-mmlab#515)
* [Refactor] Use MMCV MODEL_REGISTRY * fixed args
1 parent 2da3da4 commit d568d06

File tree

1 file changed

+14
-34
lines changed

1 file changed

+14
-34
lines changed

mmseg/models/builder.py

Lines changed: 14 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,56 +1,35 @@
11
import warnings
22

3-
from mmcv.utils import Registry, build_from_cfg
4-
from torch import nn
3+
from mmcv.cnn import MODELS as MMCV_MODELS
4+
from mmcv.utils import Registry
55

6-
BACKBONES = Registry('backbone')
7-
NECKS = Registry('neck')
8-
HEADS = Registry('head')
9-
LOSSES = Registry('loss')
10-
SEGMENTORS = Registry('segmentor')
6+
MODELS = Registry('models', parent=MMCV_MODELS)
117

12-
13-
def build(cfg, registry, default_args=None):
14-
"""Build a module.
15-
16-
Args:
17-
cfg (dict, list[dict]): The config of modules, is is either a dict
18-
or a list of configs.
19-
registry (:obj:`Registry`): A registry the module belongs to.
20-
default_args (dict, optional): Default arguments to build the module.
21-
Defaults to None.
22-
23-
Returns:
24-
nn.Module: A built nn module.
25-
"""
26-
27-
if isinstance(cfg, list):
28-
modules = [
29-
build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
30-
]
31-
return nn.Sequential(*modules)
32-
else:
33-
return build_from_cfg(cfg, registry, default_args)
8+
BACKBONES = MODELS
9+
NECKS = MODELS
10+
HEADS = MODELS
11+
LOSSES = MODELS
12+
SEGMENTORS = MODELS
3413

3514

3615
def build_backbone(cfg):
3716
"""Build backbone."""
38-
return build(cfg, BACKBONES)
17+
return BACKBONES.build(cfg)
3918

4019

4120
def build_neck(cfg):
4221
"""Build neck."""
43-
return build(cfg, NECKS)
22+
return NECKS.build(cfg)
4423

4524

4625
def build_head(cfg):
4726
"""Build head."""
48-
return build(cfg, HEADS)
27+
return HEADS.build(cfg)
4928

5029

5130
def build_loss(cfg):
5231
"""Build loss."""
53-
return build(cfg, LOSSES)
32+
return LOSSES.build(cfg)
5433

5534

5635
def build_segmentor(cfg, train_cfg=None, test_cfg=None):
@@ -63,4 +42,5 @@ def build_segmentor(cfg, train_cfg=None, test_cfg=None):
6342
'train_cfg specified in both outer field and model field '
6443
assert cfg.get('test_cfg') is None or test_cfg is None, \
6544
'test_cfg specified in both outer field and model field '
66-
return build(cfg, SEGMENTORS, dict(train_cfg=train_cfg, test_cfg=test_cfg))
45+
return SEGMENTORS.build(
46+
cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg))

0 commit comments

Comments
 (0)