
Commit 4132fd8

support layer_decay in optim_factory
1 parent a086e4e commit 4132fd8

File tree

3 files changed: +167 −2 lines changed


config.py

Lines changed: 2 additions & 0 deletions
@@ -167,6 +167,8 @@ def create_parser():
                        help='Whether use clip grad (default=False)')
     group.add_argument('--clip_value', type=float, default=15.0,
                        help='Clip value (default=15.0)')
+    group.add_argument('--layer_decay', type=float, default=None,
+                       help='layer-wise learning rate decay (default: None)')
     group.add_argument('--gradient_accumulation_steps', type=int, default=1,
                        help="Accumulate the gradients of n batches before update.")
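
For reference, the new option behaves like any other float argparse flag: left out, it stays None and layer-wise decay is disabled; given a value, that value is forwarded to the optimizer factory. A standalone sketch of the same add_argument call (not the actual mindcv parser; the 0.65 value is only illustrative):

import argparse

# Standalone sketch mirroring the add_argument call added above (illustrative only).
parser = argparse.ArgumentParser()
parser.add_argument('--layer_decay', type=float, default=None,
                    help='layer-wise learning rate decay (default: None)')

args = parser.parse_args(['--layer_decay', '0.65'])
print(args.layer_decay)  # 0.65; remains None when the flag is not passed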

mindcv/optim/optim_factory.py

Lines changed: 163 additions & 1 deletion
@@ -1,7 +1,11 @@
 """ optim factory """
+import collections
 import logging
 import os
-from typing import Optional
+import re
+from collections import defaultdict
+from itertools import chain, islice
+from typing import Any, Callable, Dict, Iterator, Optional, Tuple, Union

 from mindspore import load_checkpoint, load_param_into_net, nn

@@ -14,6 +18,8 @@

 _logger = logging.getLogger(__name__)

+MATCH_PREV_GROUP = [9]
+

 def init_group_params(params, weight_decay, weight_decay_filter, no_weight_decay):
     if weight_decay_filter == "disable":
@@ -37,6 +43,152 @@ def init_group_params(params, weight_decay, weight_decay_filter, no_weight_decay
     ]


+def param_groups_layer_decay(
+    model: nn.Cell,
+    lr: Optional[float] = 1e-3,
+    weight_decay: float = 0.05,
+    no_weight_decay_list: Tuple[str] = (),
+    layer_decay: float = 0.75,
+):
+    """
+    Parameter groups for layer-wise lr decay & weight decay
+    """
+    no_weight_decay_list = set(no_weight_decay_list)
+    param_group_names = {}  # NOTE for debugging
+    param_groups = {}
+    if hasattr(model, "group_matcher"):
+        layer_map = group_with_matcher(model.trainable_params(), model.group_matcher(coarse=False), reverse=True)
+    else:
+        layer_map = _layer_map(model)
+
+    num_layers = max(layer_map.values()) + 1
+    layer_max = num_layers - 1
+    layer_scales = list(layer_decay ** (layer_max - i) for i in range(num_layers))
+
+    for name, param in model.parameters_and_names():
+        if not param.requires_grad:
+            continue
+
+        # no decay: all 1D parameters and model specific ones
+        if param.ndim == 1 or name in no_weight_decay_list:
+            g_decay = "no_decay"
+            this_decay = 0.0
+        else:
+            g_decay = "decay"
+            this_decay = weight_decay
+
+        layer_id = layer_map.get(name, layer_max)
+        group_name = "layer_%d_%s" % (layer_id, g_decay)
+
+        if group_name not in param_groups:
+            this_scale = layer_scales[layer_id]
+            param_group_names[group_name] = {
+                "lr": [learning_rate * this_scale for learning_rate in lr],
+                "weight_decay": this_decay,
+                "param_names": [],
+            }
+            param_groups[group_name] = {
+                "lr": [learning_rate * this_scale for learning_rate in lr],
+                "weight_decay": this_decay,
+                "params": [],
+            }
+
+        param_group_names[group_name]["param_names"].append(name)
+        param_groups[group_name]["params"].append(param)
+
+    return list(param_groups.values())
+
+
+MATCH_PREV_GROUP = (99999,)
+
+
+def group_with_matcher(
+    named_objects: Iterator[Tuple[str, Any]], group_matcher: Union[Dict, Callable], reverse: bool = False
+):
+    if isinstance(group_matcher, dict):
+        # dictionary matcher contains a dict of raw-string regex expr that must be compiled
+        compiled = []
+        for group_ordinal, (_, mspec) in enumerate(group_matcher.items()):
+            if mspec is None:
+                continue
+            # map all matching specifications into 3-tuple (compiled re, prefix, suffix)
+            if isinstance(mspec, (tuple, list)):
+                # multi-entry match specifications require each sub-spec to be a 2-tuple (re, suffix)
+                for sspec in mspec:
+                    compiled += [(re.compile(sspec[0]), (group_ordinal,), sspec[1])]
+            else:
+                compiled += [(re.compile(mspec), (group_ordinal,), None)]
+        group_matcher = compiled
+
+    def _get_grouping(name):
+        if isinstance(group_matcher, (list, tuple)):
+            for match_fn, prefix, suffix in group_matcher:
+                r = match_fn.match(name)
+                if r:
+                    parts = (prefix, r.groups(), suffix)
+                    # map all tuple elem to int for numeric sort, filter out None entries
+                    return tuple(map(float, chain.from_iterable(filter(None, parts))))
+            return (float("inf"),)  # un-matched layers (neck, head) mapped to largest ordinal
+        else:
+            ord = group_matcher(name)
+            if not isinstance(ord, collections.abc.Iterable):
+                return (ord,)
+            return tuple(ord)
+
+    grouping = defaultdict(list)
+    for param in named_objects:
+        grouping[_get_grouping(param.name)].append(param.name)
+    # remap to integers
+    layer_id_to_param = defaultdict(list)
+    lid = -1
+    for k in sorted(filter(lambda x: x is not None, grouping.keys())):
+        if lid < 0 or k[-1] != MATCH_PREV_GROUP[0]:
+            lid += 1
+        layer_id_to_param[lid].extend(grouping[k])
+
+    if reverse:
+        # output reverse mapping
+        param_to_layer_id = {}
+        for lid, lm in layer_id_to_param.items():
+            for n in lm:
+                param_to_layer_id[n] = lid
+        return param_to_layer_id
+
+    return layer_id_to_param
+
+
+def _group(it, size):
+    it = iter(it)
+    return iter(lambda: tuple(islice(it, size)), ())
+
+
+def _layer_map(model, layers_per_group=12, num_groups=None):
+    def _in_head(n, hp):
+        if not hp:
+            return True
+        elif isinstance(hp, (tuple, list)):
+            return any([n.startswith(hpi) for hpi in hp])
+        else:
+            return n.startswith(hp)
+
+    # attention: need to add pretrained_cfg attr to model
+    head_prefix = getattr(model, "pretrained_cfg", {}).get("classifier", None)
+    names_trunk = []
+    names_head = []
+    for n, _ in model.parameters_and_names():
+        names_head.append(n) if _in_head(n, head_prefix) else names_trunk.append(n)
+
+    # group non-head layers
+    num_trunk_layers = len(names_trunk)
+    if num_groups is not None:
+        layers_per_group = -(num_trunk_layers // -num_groups)
+    names_trunk = list(_group(names_trunk, layers_per_group))
+    num_trunk_groups = len(names_trunk)
+    layer_map = {n: i for i, l in enumerate(names_trunk) for n in l}
+    layer_map.update({n: num_trunk_groups for n in names_head})
+    return layer_map
+
+
 def create_optimizer(
     model_or_params,
     opt: str = "adam",
@@ -45,6 +197,7 @@ def create_optimizer(
     momentum: float = 0.9,
     nesterov: bool = False,
     weight_decay_filter: str = "disable",
+    layer_decay: Optional[float] = None,
     loss_scale: float = 1.0,
     schedule_decay: float = 4e-3,
     checkpoint_path: str = "",
@@ -95,6 +248,15 @@ def create_optimizer(
             "when creating an mindspore.nn.Optimizer instance. "
             "NOTE: mindspore.nn.Optimizer will filter Norm parmas from weight decay. "
         )
+    elif layer_decay is not None and isinstance(model_or_params, nn.Cell):
+        params = param_groups_layer_decay(
+            model_or_params,
+            lr=lr,
+            weight_decay=weight_decay,
+            layer_decay=layer_decay,
+            no_weight_decay_list=no_weight_decay,
+        )
+        weight_decay = 0.0
     elif weight_decay_filter == "disable" or "norm_and_bias":
         params = init_group_params(params, weight_decay, weight_decay_filter, no_weight_decay)
         weight_decay = 0.0
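
To see what the added grouping does, note the layer_scales line in param_groups_layer_decay: with N layer groups, group i gets the multiplier layer_decay ** (N - 1 - i), so the last (head) group keeps the full learning rate and earlier layers are scaled down geometrically, while each group also carries its own weight-decay setting (0.0 for 1-D parameters and names in no_weight_decay_list). A standalone sketch of just that computation, with illustrative numbers:

# Illustration of the layer_scales computation from param_groups_layer_decay above.
layer_decay = 0.75            # illustrative value; the CLI default added in config.py is None
num_layers = 4                # e.g. a model whose layer_map yields 4 groups
layer_max = num_layers - 1
layer_scales = [layer_decay ** (layer_max - i) for i in range(num_layers)]
print(layer_scales)           # [0.421875, 0.5625, 0.75, 1.0] -- the head group keeps the full lr

Also note that each group's "lr" entry is built by iterating over lr ([learning_rate * this_scale for learning_rate in lr]), so the factory assumes a per-step learning-rate list (as train.py passes via lr_scheduler) rather than a scalar.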

train.py

Lines changed: 2 additions & 1 deletion
@@ -210,13 +210,14 @@ def main():
     else:
         optimizer_loss_scale = 1.0
     optimizer = create_optimizer(
-        network.trainable_params(),
+        network,
         opt=args.opt,
         lr=lr_scheduler,
         weight_decay=args.weight_decay,
         momentum=args.momentum,
         nesterov=args.use_nesterov,
         weight_decay_filter=args.weight_decay_filter,
+        layer_decay=args.layer_decay,
         loss_scale=optimizer_loss_scale,
         checkpoint_path=opt_ckpt_path,
         eps=args.eps,
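
With this change, train.py hands the whole network Cell (rather than network.trainable_params()) to create_optimizer, so the factory can derive a layer map: from the model's group_matcher if it defines one, otherwise via the _layer_map fallback, which, per the comment in the diff, expects a pretrained_cfg attribute on the model. A hedged usage sketch, assuming mindcv's create_model/create_optimizer entry points and illustrative model name and hyper-parameters:

from mindcv.models import create_model      # assumed mindcv entry point
from mindcv.optim import create_optimizer   # assumed export of the factory shown above

network = create_model("resnet50")           # illustrative model choice
num_steps = 1000                             # illustrative schedule length
lr_schedule = [1e-3] * num_steps             # the grouping code iterates over lr, so a per-step list is assumed

# Passing the Cell itself (as train.py now does) enables layer-wise decay;
# layer_decay=None keeps the previous behaviour.
optimizer = create_optimizer(
    network,
    opt="adamw",
    lr=lr_schedule,
    weight_decay=0.05,
    layer_decay=0.65,                        # illustrative value
)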

0 commit comments

Comments
 (0)