1
1
import warnings
2
2
3
- from mmcv .utils import Registry , build_from_cfg
4
- from torch import nn
3
+ from mmcv .cnn import MODELS as MMCV_MODELS
4
+ from mmcv . utils import Registry
5
5
6
- BACKBONES = Registry ('backbone' )
7
- NECKS = Registry ('neck' )
8
- HEADS = Registry ('head' )
9
- LOSSES = Registry ('loss' )
10
- SEGMENTORS = Registry ('segmentor' )
6
+ MODELS = Registry ('models' , parent = MMCV_MODELS )
11
7
12
-
13
- def build (cfg , registry , default_args = None ):
14
- """Build a module.
15
-
16
- Args:
17
- cfg (dict, list[dict]): The config of modules, is is either a dict
18
- or a list of configs.
19
- registry (:obj:`Registry`): A registry the module belongs to.
20
- default_args (dict, optional): Default arguments to build the module.
21
- Defaults to None.
22
-
23
- Returns:
24
- nn.Module: A built nn module.
25
- """
26
-
27
- if isinstance (cfg , list ):
28
- modules = [
29
- build_from_cfg (cfg_ , registry , default_args ) for cfg_ in cfg
30
- ]
31
- return nn .Sequential (* modules )
32
- else :
33
- return build_from_cfg (cfg , registry , default_args )
8
+ BACKBONES = MODELS
9
+ NECKS = MODELS
10
+ HEADS = MODELS
11
+ LOSSES = MODELS
12
+ SEGMENTORS = MODELS
34
13
35
14
36
15
def build_backbone (cfg ):
37
16
"""Build backbone."""
38
- return build (cfg , BACKBONES )
17
+ return BACKBONES . build (cfg )
39
18
40
19
41
20
def build_neck (cfg ):
42
21
"""Build neck."""
43
- return build (cfg , NECKS )
22
+ return NECKS . build (cfg )
44
23
45
24
46
25
def build_head (cfg ):
47
26
"""Build head."""
48
- return build (cfg , HEADS )
27
+ return HEADS . build (cfg )
49
28
50
29
51
30
def build_loss (cfg ):
52
31
"""Build loss."""
53
- return build (cfg , LOSSES )
32
+ return LOSSES . build (cfg )
54
33
55
34
56
35
def build_segmentor (cfg , train_cfg = None , test_cfg = None ):
@@ -63,4 +42,5 @@ def build_segmentor(cfg, train_cfg=None, test_cfg=None):
63
42
'train_cfg specified in both outer field and model field '
64
43
assert cfg .get ('test_cfg' ) is None or test_cfg is None , \
65
44
'test_cfg specified in both outer field and model field '
66
- return build (cfg , SEGMENTORS , dict (train_cfg = train_cfg , test_cfg = test_cfg ))
45
+ return SEGMENTORS .build (
46
+ cfg , default_args = dict (train_cfg = train_cfg , test_cfg = test_cfg ))
0 commit comments