
Commit 9dae7e1

[Enhancement] Revise docstring and UT for layer decay LR (open-mmlab#1540)
* fix docstring
* fix ut for optimizer constructor
1 parent d1281a0 commit 9dae7e1

2 files changed: +28 -4 lines changed


mmseg/core/optimizers/layer_decay_optimizer_constructor.py

Lines changed: 4 additions & 4 deletions

@@ -17,7 +17,7 @@ def get_layer_id_for_convnext(var_name, max_layer_id):
         max_layer_id (int): Maximum number of backbone layers.

     Returns:
-        int: The id number corresponding to different learning rate in
+        int: The id number corresponding to different learning rate in
         ``LearningRateDecayOptimizerConstructor``.
     """

@@ -60,7 +60,7 @@ def get_stage_id_for_convnext(var_name, max_stage_id):
         max_stage_id (int): Maximum number of backbone layers.

     Returns:
-        int: The id number corresponding to different learning rate in
+        int: The id number corresponding to different learning rate in
         ``LearningRateDecayOptimizerConstructor``.
     """

@@ -103,8 +103,8 @@ def get_layer_id_for_vit(var_name, max_layer_id):
 class LearningRateDecayOptimizerConstructor(DefaultOptimizerConstructor):
     """Different learning rates are set for different layers of backbone.

-    Note: Currently, this optimizer constructor is built for ConvNeXt
-    and BEiT.
+    Note: Currently, this optimizer constructor is built for ConvNeXt,
+    BEiT and MAE.
     """

     def add_params(self, params, module, **kwargs):
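
For context, the constructor revised above is normally selected through an optimizer config rather than built by hand. Below is a minimal sketch of such a config, assuming an AdamW optimizer; the lr, weight-decay, and num_layers values are illustrative placeholders, not values taken from this commit.

# Hypothetical optimizer config wiring in LearningRateDecayOptimizerConstructor.
# decay_type can be 'layer_wise' or 'stage_wise'; num_layers should match the
# backbone depth (all numbers below are illustrative placeholders).
optimizer = dict(
    constructor='LearningRateDecayOptimizerConstructor',
    type='AdamW',
    lr=1e-4,
    betas=(0.9, 0.999),
    weight_decay=0.05,
    paramwise_cfg=dict(
        decay_rate=0.9,
        decay_type='layer_wise',
        num_layers=12))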

tests/test_core/test_layer_decay_optimizer_constructor.py

Lines changed: 24 additions & 0 deletions

@@ -157,6 +157,19 @@ def __init__(self):
             self.layers.append(layer)


+class ToyMAE(nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        # add some variables to meet unit test coverage rate
+        self.cls_token = nn.Parameter(torch.ones(1))
+        self.patch_embed = nn.Parameter(torch.ones(1))
+        self.layers = nn.ModuleList()
+        for _ in range(3):
+            layer = nn.Conv2d(3, 3, 1)
+            self.layers.append(layer)
+
+
 class ToySegmentor(nn.Module):

     def __init__(self, backbone):
@@ -236,6 +249,17 @@ def test_learning_rate_decay_optimizer_constructor():
         optimizer_cfg, stagewise_paramwise_cfg)
     optimizer = optim_constructor(model)

+    # Test lr wd for MAE
+    backbone = ToyMAE()
+    model = PseudoDataParallel(ToySegmentor(backbone))
+
+    layerwise_paramwise_cfg = dict(
+        decay_rate=decay_rate, decay_type='layer_wise', num_layers=3)
+    optim_constructor = LearningRateDecayOptimizerConstructor(
+        optimizer_cfg, layerwise_paramwise_cfg)
+    optimizer = optim_constructor(model)
+    check_optimizer_lr_wd(optimizer, expected_layer_wise_wd_lr_beit)
+

 def test_beit_layer_decay_optimizer_constructor():

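For intuition about what the new assertion checks, layer-wise decay gives each parameter group a learning-rate multiplier that shrinks geometrically toward the earliest layers. The standalone sketch below illustrates that scaling under the values used in the test; the extra two groups for embedding and remaining parameters are an assumption modeled on common implementations, not code from this commit.

# Illustration only: geometric lr scaling for decay_type='layer_wise'.
base_lr = 1e-4    # stand-in for the lr in optimizer_cfg
decay_rate = 0.9  # same name as in the test's paramwise_cfg
num_layers = 3    # matches num_layers=3 in the ToyMAE test above

# Assumed grouping: layer_id 0 covers cls_token/patch_embed, ids 1..num_layers
# cover backbone layers, and the last id (multiplier 1.0) covers the rest.
for layer_id in range(num_layers + 2):
    scale = decay_rate ** (num_layers + 1 - layer_id)
    print(f'layer_id={layer_id}: lr scale={scale:.4f}, lr={base_lr * scale:.2e}')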