11""" optim factory """
2+ import collections
23import logging
34import os
4- from typing import Optional
5+ import re
6+ from collections import defaultdict
7+ from itertools import chain , islice
8+ from typing import Any , Callable , Dict , Iterator , Optional , Tuple , Union
59
610from mindspore import load_checkpoint , load_param_into_net , nn
711
1418
1519_logger = logging .getLogger (__name__ )
1620
21+ MATCH_PREV_GROUP = [9 ]
22+
1723
1824def init_group_params (params , weight_decay , weight_decay_filter , no_weight_decay ):
    if weight_decay_filter == "disable":
@@ -37,6 +43,152 @@ def init_group_params(params, weight_decay, weight_decay_filter, no_weight_decay
    ]


+def param_groups_layer_decay(
+    model: nn.Cell,
+    lr: Optional[float] = 1e-3,
+    weight_decay: float = 0.05,
+    no_weight_decay_list: Tuple[str] = (),
+    layer_decay: float = 0.75,
+):
+    """
+    Parameter groups for layer-wise lr decay & weight decay.
+    `lr` is expected to be an iterable of per-step learning rates: each group scales every
+    entry by its layer's decay factor.
+    """
+    no_weight_decay_list = set(no_weight_decay_list)
+    param_group_names = {}  # NOTE for debugging
+    param_groups = {}
+    if hasattr(model, "group_matcher"):
+        # use the model's own grouping rules when it provides them
+        layer_map = group_with_matcher(model.trainable_params(), model.group_matcher(coarse=False), reverse=True)
+    else:
+        # otherwise fall back to a generic, name-order based grouping
+        layer_map = _layer_map(model)
+
+    num_layers = max(layer_map.values()) + 1
+    layer_max = num_layers - 1
+    layer_scales = list(layer_decay ** (layer_max - i) for i in range(num_layers))
+
+    for name, param in model.parameters_and_names():
+        if not param.requires_grad:
+            continue
+
+        # no decay: all 1D parameters and model specific ones
+        if param.ndim == 1 or name in no_weight_decay_list:
+            g_decay = "no_decay"
+            this_decay = 0.0
+        else:
+            g_decay = "decay"
+            this_decay = weight_decay
+
+        layer_id = layer_map.get(name, layer_max)
+        group_name = "layer_%d_%s" % (layer_id, g_decay)
+
+        if group_name not in param_groups:
+            this_scale = layer_scales[layer_id]
+            param_group_names[group_name] = {
+                "lr": [learning_rate * this_scale for learning_rate in lr],
+                "weight_decay": this_decay,
+                "param_names": [],
+            }
+            param_groups[group_name] = {
+                "lr": [learning_rate * this_scale for learning_rate in lr],
+                "weight_decay": this_decay,
+                "params": [],
+            }
+
+        param_group_names[group_name]["param_names"].append(name)
+        param_groups[group_name]["params"].append(param)
+
+    return list(param_groups.values())
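# A minimal usage sketch (illustrative only; `net`, `total_steps` and the constant schedule
# are placeholders, not part of this module). The list of group dicts returned above
# ("params" / "lr" / "weight_decay") can be passed directly to a MindSpore optimizer:
#
#   lr_schedule = [1e-3] * total_steps   # stands in for a real per-step lr schedule
#   groups = param_groups_layer_decay(net, lr=lr_schedule, weight_decay=0.05, layer_decay=0.75)
#   optimizer = nn.Momentum(groups, learning_rate=lr_schedule, momentum=0.9)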
+
+
+# sentinel: a matcher that maps parameters to this value keeps them in the previous group
+MATCH_PREV_GROUP = (99999,)
+
+
+def group_with_matcher(
+    named_objects: Iterator[Tuple[str, Any]], group_matcher: Union[Dict, Callable], reverse: bool = False
+):
+    if isinstance(group_matcher, dict):
+        # dictionary matcher contains a dict of raw-string regex expr that must be compiled
+        compiled = []
+        for group_ordinal, (_, mspec) in enumerate(group_matcher.items()):
+            if mspec is None:
+                continue
+            # map all matching specifications into 3-tuple (compiled re, prefix, suffix)
+            if isinstance(mspec, (tuple, list)):
+                # multi-entry match specifications require each sub-spec to be a 2-tuple (re, suffix)
+                for sspec in mspec:
+                    compiled += [(re.compile(sspec[0]), (group_ordinal,), sspec[1])]
+            else:
+                compiled += [(re.compile(mspec), (group_ordinal,), None)]
+        group_matcher = compiled
+
+    def _get_grouping(name):
+        if isinstance(group_matcher, (list, tuple)):
+            for match_fn, prefix, suffix in group_matcher:
+                r = match_fn.match(name)
+                if r:
+                    parts = (prefix, r.groups(), suffix)
+                    # map all tuple elem to int for numeric sort, filter out None entries
+                    return tuple(map(float, chain.from_iterable(filter(None, parts))))
+            return (float("inf"),)  # un-matched layers (neck, head) mapped to largest ordinal
+        else:
+            ord = group_matcher(name)
+            if not isinstance(ord, collections.abc.Iterable):
+                return (ord,)
+            return tuple(ord)
+
+    grouping = defaultdict(list)
+    for param in named_objects:
+        grouping[_get_grouping(param.name)].append(param.name)
+    # remap to integers
+    layer_id_to_param = defaultdict(list)
+    lid = -1
+    for k in sorted(filter(lambda x: x is not None, grouping.keys())):
+        if lid < 0 or k[-1] != MATCH_PREV_GROUP[0]:
+            lid += 1
+        layer_id_to_param[lid].extend(grouping[k])
+
+    if reverse:
+        # output reverse mapping
+        param_to_layer_id = {}
+        for lid, lm in layer_id_to_param.items():
+            for n in lm:
+                param_to_layer_id[n] = lid
+        return param_to_layer_id
+
+    return layer_id_to_param
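# An illustrative sketch of the matcher convention this function expects (the patterns are
# hypothetical, modeled on a ViT-style backbone; real patterns come from a model's own
# `group_matcher()` method). With `reverse=True` it yields a name -> layer-id mapping:
#
#   matcher = {
#       "stem": r"^(?:cls_token|pos_embed|patch_embed)",
#       "blocks": r"^blocks\.(\d+)",
#   }
#   layer_map = group_with_matcher(model.trainable_params(), matcher, reverse=True)
#   # e.g. "patch_embed.proj.weight" -> 0 and "blocks.5.attn.qkv.weight" -> 6 (when blocks 0-5 exist)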
+
+
+def _group(it, size):
+    # chunk an iterable into fixed-size tuples (the last chunk may be shorter),
+    # e.g. list(_group("abcde", 2)) -> [("a", "b"), ("c", "d"), ("e",)]
+    it = iter(it)
+    return iter(lambda: tuple(islice(it, size)), ())
+
+
+def _layer_map(model, layers_per_group=12, num_groups=None):
+    def _in_head(n, hp):
+        if not hp:
+            return True
+        elif isinstance(hp, (tuple, list)):
+            return any([n.startswith(hpi) for hpi in hp])
+        else:
+            return n.startswith(hp)
+
+    # NOTE: relies on the model exposing a `pretrained_cfg` attribute with a "classifier" entry
+    head_prefix = getattr(model, "pretrained_cfg", {}).get("classifier", None)
+    names_trunk = []
+    names_head = []
+    for n, _ in model.parameters_and_names():
+        names_head.append(n) if _in_head(n, head_prefix) else names_trunk.append(n)
+
+    # group non-head layers
+    num_trunk_layers = len(names_trunk)
+    if num_groups is not None:
+        layers_per_group = -(num_trunk_layers // -num_groups)  # ceil division
+    names_trunk = list(_group(names_trunk, layers_per_group))
+    num_trunk_groups = len(names_trunk)
+    layer_map = {n: i for i, l in enumerate(names_trunk) for n in l}
+    layer_map.update({n: num_trunk_groups for n in names_head})
+    return layer_map
+
+
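# Rough shape of what `_layer_map` above returns (hypothetical parameter names): with
# layers_per_group=2, trunk parameters ["conv1.weight", "bn1.gamma", "layer1.0.weight"]
# and head prefix "classifier", the result is
#   {"conv1.weight": 0, "bn1.gamma": 0, "layer1.0.weight": 1, "classifier.weight": 2}.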
def create_optimizer(
    model_or_params,
    opt: str = "adam",
@@ -45,6 +197,7 @@ def create_optimizer(
    momentum: float = 0.9,
    nesterov: bool = False,
    weight_decay_filter: str = "disable",
+    layer_decay: Optional[float] = None,
    loss_scale: float = 1.0,
    schedule_decay: float = 4e-3,
    checkpoint_path: str = "",
@@ -95,6 +248,15 @@ def create_optimizer(
95248 "when creating an mindspore.nn.Optimizer instance. "
96249 "NOTE: mindspore.nn.Optimizer will filter Norm parmas from weight decay. "
97250 )
251+ elif layer_decay is not None and isinstance (model_or_params , nn .Cell ):
252+ params = param_groups_layer_decay (
253+ model_or_params ,
254+ lr = lr ,
255+ weight_decay = weight_decay ,
256+ layer_decay = layer_decay ,
257+ no_weight_decay_list = no_weight_decay ,
258+ )
259+ weight_decay = 0.0
98260 elif weight_decay_filter == "disable" or "norm_and_bias" :
99261 params = init_group_params (params , weight_decay , weight_decay_filter , no_weight_decay )
100262 weight_decay = 0.0
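# A minimal sketch of turning layer-wise lr decay on through this factory (illustrative;
# `net` is a placeholder and `lr_schedule` stands for the per-step learning-rate list that
# the factory's `lr` argument normally receives from the scheduler):
#
#   optimizer = create_optimizer(
#       net, opt="adam", lr=lr_schedule, weight_decay=0.05, layer_decay=0.75
#   )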