11import warnings
22
3- from mmcv .utils import Registry , build_from_cfg
4- from torch import nn
3+ from mmcv .cnn import MODELS as MMCV_MODELS
4+ from mmcv . utils import Registry
55
6- BACKBONES = Registry ('backbone' )
7- NECKS = Registry ('neck' )
8- HEADS = Registry ('head' )
9- LOSSES = Registry ('loss' )
10- SEGMENTORS = Registry ('segmentor' )
6+ MODELS = Registry ('models' , parent = MMCV_MODELS )
117
12-
13- def build (cfg , registry , default_args = None ):
14- """Build a module.
15-
16- Args:
17- cfg (dict, list[dict]): The config of modules, is is either a dict
18- or a list of configs.
19- registry (:obj:`Registry`): A registry the module belongs to.
20- default_args (dict, optional): Default arguments to build the module.
21- Defaults to None.
22-
23- Returns:
24- nn.Module: A built nn module.
25- """
26-
27- if isinstance (cfg , list ):
28- modules = [
29- build_from_cfg (cfg_ , registry , default_args ) for cfg_ in cfg
30- ]
31- return nn .Sequential (* modules )
32- else :
33- return build_from_cfg (cfg , registry , default_args )
8+ BACKBONES = MODELS
9+ NECKS = MODELS
10+ HEADS = MODELS
11+ LOSSES = MODELS
12+ SEGMENTORS = MODELS
3413
3514
3615def build_backbone (cfg ):
3716 """Build backbone."""
38- return build (cfg , BACKBONES )
17+ return BACKBONES . build (cfg )
3918
4019
4120def build_neck (cfg ):
4221 """Build neck."""
43- return build (cfg , NECKS )
22+ return NECKS . build (cfg )
4423
4524
4625def build_head (cfg ):
4726 """Build head."""
48- return build (cfg , HEADS )
27+ return HEADS . build (cfg )
4928
5029
5130def build_loss (cfg ):
5231 """Build loss."""
53- return build (cfg , LOSSES )
32+ return LOSSES . build (cfg )
5433
5534
5635def build_segmentor (cfg , train_cfg = None , test_cfg = None ):
@@ -63,4 +42,5 @@ def build_segmentor(cfg, train_cfg=None, test_cfg=None):
6342 'train_cfg specified in both outer field and model field '
6443 assert cfg .get ('test_cfg' ) is None or test_cfg is None , \
6544 'test_cfg specified in both outer field and model field '
66- return build (cfg , SEGMENTORS , dict (train_cfg = train_cfg , test_cfg = test_cfg ))
45+ return SEGMENTORS .build (
46+ cfg , default_args = dict (train_cfg = train_cfg , test_cfg = test_cfg ))
0 commit comments