# Copyright (c) OpenMMLab. All rights reserved.
import json
+import warnings

from mmcv.runner import DefaultOptimizerConstructor, get_dist_info

from mmseg.utils import get_root_logger
from ..builder import OPTIMIZER_BUILDERS


-def get_num_layer_layer_wise(var_name, num_max_layer=12):
+def get_layer_id_for_convnext(var_name, max_layer_id):
    """Get the layer id to set the different learning rates in ``layer_wise``
    decay_type.

    Args:
        var_name (str): The key of the model.
-        num_max_layer (int): Maximum number of backbone layers.
+        max_layer_id (int): Maximum number of backbone layers.

    Returns:
        int: The id number corresponding to different learning rate in
@@ -32,7 +33,7 @@ def get_num_layer_layer_wise(var_name, num_max_layer=12):
        elif stage_id == 2:
            layer_id = 3
        elif stage_id == 3:
-            layer_id = num_max_layer
+            layer_id = max_layer_id
        return layer_id
    elif var_name.startswith('backbone.stages'):
        stage_id = int(var_name.split('.')[2])
@@ -44,19 +45,20 @@ def get_num_layer_layer_wise(var_name, num_max_layer=12):
        elif stage_id == 2:
            layer_id = 3 + block_id // 3
        elif stage_id == 3:
-            layer_id = num_max_layer
+            layer_id = max_layer_id
        return layer_id
    else:
-        return num_max_layer + 1
+        return max_layer_id + 1


-def get_num_layer_stage_wise(var_name, num_max_layer):
-    """Get the layer id to set the different learning rates in ``stage_wise``
+def get_stage_id_for_convnext(var_name, max_stage_id):
+    """Get the stage id to set the different learning rates in ``stage_wise``
    decay_type.

    Args:
        var_name (str): The key of the model.
-        num_max_layer (int): Maximum number of backbone layers.
+        max_stage_id (int): Maximum number of backbone stages.
+
    Returns:
        int: The id number corresponding to different learning rate in
        ``LearningRateDecayOptimizerConstructor``.
@@ -71,14 +73,41 @@ def get_num_layer_stage_wise(var_name, num_max_layer):
        stage_id = int(var_name.split('.')[2])
        return stage_id + 1
    else:
-        return num_max_layer - 1
+        return max_stage_id - 1
+
+
+def get_layer_id_for_vit(var_name, max_layer_id):
+    """Get the layer id to set the different learning rates.
+
+    Args:
+        var_name (str): The key of the model.
+        max_layer_id (int): Maximum number of backbone layers.
+
+    Returns:
+        int: Returns the layer id of the key.
+    """
+
+    if var_name in ('backbone.cls_token', 'backbone.mask_token',
+                    'backbone.pos_embed'):
+        return 0
+    elif var_name.startswith('backbone.patch_embed'):
+        return 0
+    elif var_name.startswith('backbone.layers'):
+        layer_id = int(var_name.split('.')[2])
+        return layer_id + 1
+    else:
+        return max_layer_id - 1


@OPTIMIZER_BUILDERS.register_module()
class LearningRateDecayOptimizerConstructor(DefaultOptimizerConstructor):
-    """Different learning rates are set for different layers of backbone."""
+    """Different learning rates are set for different layers of backbone.

-    def add_params(self, params, module):
+    Note: Currently, this optimizer constructor is built for ConvNeXt
+    and BEiT.
+    """
+
+    def add_params(self, params, module, **kwargs):
        """Add all parameters of module to the params list.

        The parameters of the given module will be added to the list of param
@@ -99,7 +128,6 @@ def add_params(self, params, module):
        logger.info('Build LearningRateDecayOptimizerConstructor '
                    f'{decay_type} {decay_rate} - {num_layers}')
        weight_decay = self.base_wd
-
        for name, param in module.named_parameters():
            if not param.requires_grad:
                continue  # frozen weights
@@ -110,14 +138,22 @@ def add_params(self, params, module):
            else:
                group_name = 'decay'
                this_weight_decay = weight_decay
-
-            if decay_type == 'layer_wise':
-                layer_id = get_num_layer_layer_wise(
-                    name, self.paramwise_cfg.get('num_layers'))
-                logger.info(f'set param {name} as id {layer_id}')
+            if 'layer_wise' in decay_type:
+                if 'ConvNeXt' in module.backbone.__class__.__name__:
+                    layer_id = get_layer_id_for_convnext(
+                        name, self.paramwise_cfg.get('num_layers'))
+                    logger.info(f'set param {name} as id {layer_id}')
+                elif 'BEiT' in module.backbone.__class__.__name__:
+                    layer_id = get_layer_id_for_vit(name, num_layers)
+                    logger.info(f'set param {name} as id {layer_id}')
+                else:
+                    raise NotImplementedError()
            elif decay_type == 'stage_wise':
-                layer_id = get_num_layer_stage_wise(name, num_layers)
-                logger.info(f'set param {name} as id {layer_id}')
+                if 'ConvNeXt' in module.backbone.__class__.__name__:
+                    layer_id = get_stage_id_for_convnext(name, num_layers)
+                    logger.info(f'set param {name} as id {layer_id}')
+                else:
+                    raise NotImplementedError()
            group_name = f'layer_{layer_id}_{group_name}'

            if group_name not in parameter_groups:
@@ -146,3 +182,26 @@ def add_params(self, params, module):
                }
            logger.info(f'Param groups = {json.dumps(to_display, indent=2)}')
        params.extend(parameter_groups.values())
+
+
+@OPTIMIZER_BUILDERS.register_module()
+class LayerDecayOptimizerConstructor(LearningRateDecayOptimizerConstructor):
+    """Different learning rates are set for different layers of backbone.
+
+    Note: Currently, this optimizer constructor is built for BEiT,
+    and it will be deprecated.
+    Please use ``LearningRateDecayOptimizerConstructor`` instead.
+    """
+
+    def __init__(self, optimizer_cfg, paramwise_cfg):
+        warnings.warn('DeprecationWarning: Original '
+                      'LayerDecayOptimizerConstructor of BEiT '
+                      'will be deprecated. Please use '
+                      'LearningRateDecayOptimizerConstructor instead, '
+                      'and set decay_type = layer_wise_vit in paramwise_cfg.')
+        paramwise_cfg.update({'decay_type': 'layer_wise_vit'})
+        warnings.warn('DeprecationWarning: Layer_decay_rate will '
+                      'be deleted, please use decay_rate instead.')
+        paramwise_cfg['decay_rate'] = paramwise_cfg.pop('layer_decay_rate')
+        super(LayerDecayOptimizerConstructor,
+              self).__init__(optimizer_cfg, paramwise_cfg)
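
Not part of the diff above: as a quick sanity check on the new ViT/BEiT mapping, the branches of get_layer_id_for_vit imply the assignments below. The parameter names and the value max_layer_id=14 are illustrative assumptions (in add_params the value comes from num_layers), and the snippet assumes the helpers of this module have been imported (the exact import path is not shown in this diff).

# Illustrative only: expected layer ids, following the branches of
# get_layer_id_for_vit added in this diff.
assert get_layer_id_for_vit('backbone.cls_token', max_layer_id=14) == 0
assert get_layer_id_for_vit('backbone.patch_embed.projection.weight', 14) == 0
# 'backbone.layers.<i>.*' maps to i + 1
assert get_layer_id_for_vit('backbone.layers.3.attn.qkv.weight', 14) == 4
# anything else (e.g. decode_head parameters) falls back to max_layer_id - 1
assert get_layer_id_for_vit('decode_head.conv_seg.weight', 14) == 13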
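For reference, a minimal sketch of how this constructor is typically selected from a config. Only decay_type, decay_rate and num_layers are read by add_params above; the optimizer type, learning rate and other values shown here are assumed placeholders, not taken from this diff.

optimizer = dict(
    type='AdamW',
    lr=0.0001,
    betas=(0.9, 0.999),
    weight_decay=0.05,
    # Tell the mmcv optimizer builder to use the constructor defined above.
    constructor='LearningRateDecayOptimizerConstructor',
    paramwise_cfg=dict(
        # 'layer_wise' dispatches to get_layer_id_for_convnext /
        # get_layer_id_for_vit; 'stage_wise' to get_stage_id_for_convnext.
        decay_type='layer_wise',
        decay_rate=0.9,
        num_layers=12))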