Skip to content

Commit 5b605b0

Browse files
[Fix] Register optimizer constructor with mmseg (open-mmlab#1456)
* [fix] register optimizer onstructor with mmseg * fix lint * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * Update tests/test_core/test_optimizer.py * Update tests/test_core/test_optimizer.py * Update tests/test_core/test_optimizer.py * Update tests/test_core/test_optimizer.py * fix lint Co-authored-by: Miao Zheng <[email protected]>
1 parent 737e7e6 commit 5b605b0

File tree

6 files changed

+106
-7
lines changed

6 files changed

+106
-7
lines changed

mmseg/apis/train.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,11 @@
88
import torch.distributed as dist
99
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
1010
from mmcv.runner import (HOOKS, DistSamplerSeedHook, EpochBasedRunner,
11-
build_optimizer, build_runner, get_dist_info)
11+
build_runner, get_dist_info)
1212
from mmcv.utils import build_from_cfg
1313

1414
from mmseg import digit_version
15-
from mmseg.core import DistEvalHook, EvalHook
15+
from mmseg.core import DistEvalHook, EvalHook, build_optimizer
1616
from mmseg.datasets import build_dataloader, build_dataset
1717
from mmseg.utils import find_latest_checkpoint, get_root_logger
1818

mmseg/core/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,13 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
2+
from .builder import (OPTIMIZER_BUILDERS, build_optimizer,
3+
build_optimizer_constructor)
24
from .evaluation import * # noqa: F401, F403
35
from .layer_decay_optimizer_constructor import \
46
LayerDecayOptimizerConstructor # noqa: F401
57
from .seg import * # noqa: F401, F403
68
from .utils import * # noqa: F401, F403
9+
10+
__all__ = [
11+
'LayerDecayOptimizerConstructor', 'OPTIMIZER_BUILDERS', 'build_optimizer',
12+
'build_optimizer_constructor'
13+
]

mmseg/core/builder.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.
2+
import copy
3+
4+
from mmcv.runner.optimizer import OPTIMIZER_BUILDERS as MMCV_OPTIMIZER_BUILDERS
5+
from mmcv.utils import Registry, build_from_cfg
6+
7+
OPTIMIZER_BUILDERS = Registry(
8+
'optimizer builder', parent=MMCV_OPTIMIZER_BUILDERS)
9+
10+
11+
def build_optimizer_constructor(cfg):
12+
constructor_type = cfg.get('type')
13+
if constructor_type in OPTIMIZER_BUILDERS:
14+
return build_from_cfg(cfg, OPTIMIZER_BUILDERS)
15+
elif constructor_type in MMCV_OPTIMIZER_BUILDERS:
16+
return build_from_cfg(cfg, MMCV_OPTIMIZER_BUILDERS)
17+
else:
18+
raise KeyError(f'{constructor_type} is not registered '
19+
'in the optimizer builder registry.')
20+
21+
22+
def build_optimizer(model, cfg):
23+
optimizer_cfg = copy.deepcopy(cfg)
24+
constructor_type = optimizer_cfg.pop('constructor',
25+
'DefaultOptimizerConstructor')
26+
paramwise_cfg = optimizer_cfg.pop('paramwise_cfg', None)
27+
optim_constructor = build_optimizer_constructor(
28+
dict(
29+
type=constructor_type,
30+
optimizer_cfg=optimizer_cfg,
31+
paramwise_cfg=paramwise_cfg))
32+
optimizer = optim_constructor(model)
33+
return optimizer

mmseg/core/layer_decay_optimizer_constructor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
2-
from mmcv.runner import (OPTIMIZER_BUILDERS, DefaultOptimizerConstructor,
3-
get_dist_info)
2+
from mmcv.runner import DefaultOptimizerConstructor, get_dist_info
43

54
from mmseg.utils import get_root_logger
5+
from .builder import OPTIMIZER_BUILDERS
66

77

88
def get_num_layer_for_vit(var_name, num_max_layer):

mmseg/core/utils/layer_decay_optimizer_constructor.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
22
import json
33

4-
from mmcv.runner import (OPTIMIZER_BUILDERS, DefaultOptimizerConstructor,
5-
get_dist_info)
4+
from mmcv.runner import DefaultOptimizerConstructor, get_dist_info
65

7-
from ...utils import get_root_logger
6+
from mmseg.utils import get_root_logger
7+
from ..builder import OPTIMIZER_BUILDERS
88

99

1010
def get_num_layer_layer_wise(var_name, num_max_layer=12):

tests/test_core/test_optimizer.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.
2+
import pytest
3+
import torch
4+
import torch.nn as nn
5+
from mmcv.runner import DefaultOptimizerConstructor
6+
7+
from mmseg.core.builder import (OPTIMIZER_BUILDERS, build_optimizer,
8+
build_optimizer_constructor)
9+
10+
11+
class ExampleModel(nn.Module):
12+
13+
def __init__(self):
14+
super().__init__()
15+
self.param1 = nn.Parameter(torch.ones(1))
16+
self.conv1 = nn.Conv2d(3, 4, kernel_size=1, bias=False)
17+
self.conv2 = nn.Conv2d(4, 2, kernel_size=1)
18+
self.bn = nn.BatchNorm2d(2)
19+
20+
def forward(self, x):
21+
return x
22+
23+
24+
base_lr = 0.01
25+
base_wd = 0.0001
26+
momentum = 0.9
27+
28+
29+
def test_build_optimizer_constructor():
30+
optimizer_cfg = dict(
31+
type='SGD', lr=base_lr, weight_decay=base_wd, momentum=momentum)
32+
optim_constructor_cfg = dict(
33+
type='DefaultOptimizerConstructor', optimizer_cfg=optimizer_cfg)
34+
optim_constructor = build_optimizer_constructor(optim_constructor_cfg)
35+
# Test whether optimizer constructor can be built from parent.
36+
assert type(optim_constructor) is DefaultOptimizerConstructor
37+
38+
@OPTIMIZER_BUILDERS.register_module()
39+
class MyOptimizerConstructor(DefaultOptimizerConstructor):
40+
pass
41+
42+
optim_constructor_cfg = dict(
43+
type='MyOptimizerConstructor', optimizer_cfg=optimizer_cfg)
44+
optim_constructor = build_optimizer_constructor(optim_constructor_cfg)
45+
# Test optimizer constructor can be built from child registry.
46+
assert type(optim_constructor) is MyOptimizerConstructor
47+
48+
# Test unregistered constructor cannot be built
49+
with pytest.raises(KeyError):
50+
build_optimizer_constructor(dict(type='A'))
51+
52+
53+
def test_build_optimizer():
54+
model = ExampleModel()
55+
optimizer_cfg = dict(
56+
type='SGD', lr=base_lr, weight_decay=base_wd, momentum=momentum)
57+
optimizer = build_optimizer(model, optimizer_cfg)
58+
# test whether optimizer is successfully built from parent.
59+
assert isinstance(optimizer, torch.optim.SGD)

0 commit comments

Comments
 (0)