Commit 5a58e40

huangzehao authored and RogerChern committed

Nasfpn refactor (tusen-ai#117)

* general backbone
* align with retinanet refactor
* fix

1 parent f63b499, commit 5a58e40

File tree

3 files changed (+54, −59 lines)


config/NASFPN/retina_r50v1b_fpn_640640_25epoch.py

Lines changed: 3 additions & 2 deletions
@@ -1,5 +1,5 @@
 from models.retinanet.builder import RetinaNet as Detector
-from models.NASFPN.builder import MSRAResNet50V1bFPN as Backbone
+from models.NASFPN.builder import ResNetV1bFPN as Backbone
 from models.NASFPN.builder import RetinaNetNeckWithBN as Neck
 from models.NASFPN.builder import RetinaNetHeadWithBN as RpnHead
 from mxnext.complicate import normalizer_factory
@@ -27,6 +27,7 @@ class NormalizeParam:
 class BackboneParam:
     fp16 = General.fp16
     normalizer = NormalizeParam.normalizer
+    depth = 50


 class NeckParam:
@@ -106,7 +107,7 @@ class ModelParam:
     memonger_until = "stage3_unit21_plus"

     class pretrain:
-        prefix = "pretrain_model/resnet50_v1b"
+        prefix = "pretrain_model/resnet%s_v1b" % BackboneParam.depth
         epoch = 0
         fixed_param = ["conv0"]

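The functional change in this config is that the backbone depth now comes from BackboneParam.depth and is reused to build the pretrain prefix, so the same file pattern works for any depth the generic ResNetV1bFPN backbone accepts. A minimal sketch of the prefix formatting follows; note the commit itself only sets depth = 50, and the other depths below are assumptions for illustration, not part of the commit.

    # Hypothetical illustration of the parameterized pretrain prefix; only
    # depth = 50 appears in this commit, 101/152 are assumed for illustration.
    for depth in (50, 101, 152):
        print("pretrain_model/resnet%s_v1b" % depth)
    # -> pretrain_model/resnet50_v1b
    # -> pretrain_model/resnet101_v1b
    # -> pretrain_model/resnet152_v1b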

config/NASFPN/retina_r50v1b_nasfpn_640640_25epoch.py

Lines changed: 3 additions & 2 deletions
@@ -1,5 +1,5 @@
 from models.retinanet.builder import RetinaNet as Detector
-from models.NASFPN.builder import MSRAResNet50V1bFPN as Backbone
+from models.NASFPN.builder import ResNetV1bFPN as Backbone
 from models.NASFPN.builder import NASFPNNeck as Neck
 from models.NASFPN.builder import RetinaNetHeadWithBN as RpnHead
 from mxnext.complicate import normalizer_factory
@@ -27,6 +27,7 @@ class NormalizeParam:
 class BackboneParam:
     fp16 = General.fp16
     normalizer = NormalizeParam.normalizer
+    depth = 50


 class NeckParam:
@@ -108,7 +109,7 @@ class ModelParam:
     memonger_until = "S7_P6_7_bn"

     class pretrain:
-        prefix = "pretrain_model/resnet50_v1b"
+        prefix = "pretrain_model/resnet%s_v1b" % BackboneParam.depth
         epoch = 0
         fixed_param = ["conv0"]


models/NASFPN/builder.py

Lines changed: 48 additions & 55 deletions
@@ -51,6 +51,7 @@ def reluconvbn(data, num_filter, init, norm, name, prefix):
 class NASFPNNeck(Neck):
     def __init__(self, pNeck):
         super().__init__(pNeck)
+        self.neck = None

     @staticmethod
     def get_P0_features(c_features, p_names, dim_reduced, init, norm, kernel=1):
@@ -149,6 +150,9 @@ def get_fused_P_feature(p_features, stage, dim_reduced, init, norm):
                 'S{}_P7'.format(stage): P7}

     def get_nasfpn_neck(self, data):
+        if self.neck is not None:
+            return self.neck
+
         dim_reduced = self.p.dim_reduced
         norm = self.p.normalizer
         num_stage = self.p.num_stage
@@ -169,11 +173,16 @@ def get_nasfpn_neck(self, data):
         # stack stage
         for i in range(num_stage):
             p_features = self.get_fused_P_feature(p_features, i + 1, dim_reduced, xavier_init, norm)
-        return p_features['S{}_P3'.format(num_stage)], \
-               p_features['S{}_P4'.format(num_stage)], \
-               p_features['S{}_P5'.format(num_stage)], \
-               p_features['S{}_P6'.format(num_stage)], \
-               p_features['S{}_P7'.format(num_stage)]
+
+        self.neck = dict(
+            stride8=p_features['S{}_P3'.format(num_stage)],
+            stride16=p_features['S{}_P4'.format(num_stage)],
+            stride32=p_features['S{}_P5'.format(num_stage)],
+            stride64=p_features['S{}_P6'.format(num_stage)],
+            stride128=p_features['S{}_P7'.format(num_stage)]
+        )
+
+        return self.neck

     def get_rpn_feature(self, rpn_feat):
         return self.get_nasfpn_neck(rpn_feat)
@@ -318,8 +327,8 @@ def _bbox_subnet(self, conv_feat, conv_channel, num_base_anchor, num_class, stri
         return output

     def get_output(self, conv_feat):
-        if self._cls_logit_list is not None and self._bbox_delta_list is not None:
-            return self._cls_logit_list, self._bbox_delta_list
+        if self._cls_logit_dict is not None and self._bbox_delta_dict is not None:
+            return self._cls_logit_dict, self._bbox_delta_dict

         p = self.p
         stride = p.anchor_generate.stride
@@ -329,67 +338,43 @@ def get_output(self, conv_feat):
         num_base_anchor = len(p.anchor_generate.ratio) * len(p.anchor_generate.scale)
         num_class = p.num_class

-        prior_prob = 0.01
-        pi = -math.log((1-prior_prob) / prior_prob)
-
-        # shared classification weight and bias
-        self.cls_conv1_weight = X.var("cls_conv1_weight", init=X.gauss(std=0.01))
-        self.cls_conv1_bias = X.var("cls_conv1_bias", init=X.zero_init())
-        self.cls_conv2_weight = X.var("cls_conv2_weight", init=X.gauss(std=0.01))
-        self.cls_conv2_bias = X.var("cls_conv2_bias", init=X.zero_init())
-        self.cls_conv3_weight = X.var("cls_conv3_weight", init=X.gauss(std=0.01))
-        self.cls_conv3_bias = X.var("cls_conv3_bias", init=X.zero_init())
-        self.cls_conv4_weight = X.var("cls_conv4_weight", init=X.gauss(std=0.01))
-        self.cls_conv4_bias = X.var("cls_conv4_bias", init=X.zero_init())
-        self.cls_pred_weight = X.var("cls_pred_weight", init=X.gauss(std=0.01))
-        self.cls_pred_bias = X.var("cls_pred_bias", init=X.constant(pi))
-
-        # shared regression weight and bias
-        self.bbox_conv1_weight = X.var("bbox_conv1_weight", init=X.gauss(std=0.01))
-        self.bbox_conv1_bias = X.var("bbox_conv1_bias", init=X.zero_init())
-        self.bbox_conv2_weight = X.var("bbox_conv2_weight", init=X.gauss(std=0.01))
-        self.bbox_conv2_bias = X.var("bbox_conv2_bias", init=X.zero_init())
-        self.bbox_conv3_weight = X.var("bbox_conv3_weight", init=X.gauss(std=0.01))
-        self.bbox_conv3_bias = X.var("bbox_conv3_bias", init=X.zero_init())
-        self.bbox_conv4_weight = X.var("bbox_conv4_weight", init=X.gauss(std=0.01))
-        self.bbox_conv4_bias = X.var("bbox_conv4_bias", init=X.zero_init())
-        self.bbox_pred_weight = X.var("bbox_pred_weight", init=X.gauss(std=0.01))
-        self.bbox_pred_bias = X.var("bbox_pred_bias", init=X.zero_init())
-
-        cls_logit_list = []
-        bbox_delta_list = []
-
-        for i, s in enumerate(stride):
+        cls_logit_dict = dict()
+        bbox_delta_dict = dict()
+
+        for s in stride:
             cls_logit = self._cls_subnet(
-                conv_feat=conv_feat[i],
+                conv_feat=conv_feat["stride%s" % s],
                 conv_channel=conv_channel,
                 num_base_anchor=num_base_anchor,
                 num_class=num_class,
                 stride=s
             )

             bbox_delta = self._bbox_subnet(
-                conv_feat=conv_feat[i],
+                conv_feat=conv_feat["stride%s" % s],
                 conv_channel=conv_channel,
                 num_base_anchor=num_base_anchor,
                 num_class=num_class,
                 stride=s
             )

-            cls_logit_list.append(cls_logit)
-            bbox_delta_list.append(bbox_delta)
+            cls_logit_dict["stride%s" % s] = cls_logit
+            bbox_delta_dict["stride%s" % s] = bbox_delta

-        self._cls_logit_list = cls_logit_list
-        self._bbox_delta_list = bbox_delta_list
+        self._cls_logit_dict = cls_logit_dict
+        self._bbox_delta_dict = bbox_delta_dict

-        return self._cls_logit_list, self._bbox_delta_list
+        return self._cls_logit_dict, self._bbox_delta_dict


 class RetinaNetNeckWithBN(RetinaNetNeck):
     def __init__(self, pNeck):
         super().__init__(pNeck)

     def get_retinanet_neck(self, data):
+        if self.neck is not None:
+            return self.neck
+
         norm = self.p.normalizer
         c2, c3, c4, c5 = data

@@ -473,7 +458,7 @@ def get_retinanet_neck(self, data):
         )

         # P6
-        P6 = X.conv(
+        p6 = X.conv(
             data=c5,
             kernel=3,
             stride=2,
@@ -485,9 +470,9 @@ def get_retinanet_neck(self, data):
         )

         # P7
-        P6_relu = X.relu(data=P6, name="P6_relu")
-        P7 = X.conv(
-            data=P6_relu,
+        p6_relu = X.relu(data=p6, name="P6_relu")
+        p7 = X.conv(
+            data=p6_relu,
             kernel=3,
             stride=2,
             filter=256,
@@ -500,18 +485,26 @@ def get_retinanet_neck(self, data):
         p3_conv = norm(p3_conv, name="P3_conv_bn")
         p4_conv = norm(p4_conv, name="P4_conv_bn")
         p5_conv = norm(p5_conv, name="P5_conv_bn")
-        P6 = norm(P6, name="P6_conv_bn")
-        P7 = norm(P7, name="P7_conv_bn")
+        p6 = norm(p6, name="P6_conv_bn")
+        p7 = norm(p7, name="P7_conv_bn")
+
+        self.neck = dict(
+            stride8=p3_conv,
+            stride16=p4_conv,
+            stride32=p5_conv,
+            stride64=p6,
+            stride128=p7
+        )

-        return p3_conv, p4_conv, p5_conv, P6, P7
+        return self.neck


-class MSRAResNet50V1bFPN(Backbone):
+class ResNetV1bFPN(Backbone):
     def __init__(self, pBackbone):
-        super(MSRAResNet50V1bFPN, self).__init__(pBackbone)
+        super().__init__(pBackbone)
         from mxnext.backbone.resnet_v1b import Builder
         b = Builder()
-        self.symbol = b.get_backbone("msra", 50, "fpn", pBackbone.normalizer, pBackbone.fp16)
+        self.symbol = b.get_backbone("msra", pBackbone.depth, "fpn", pBackbone.normalizer, pBackbone.fp16)

     def get_rpn_feature(self):
         return self.symbol
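Taken together, the builder changes converge on one interface: a neck caches its output in self.neck and hands back a dict keyed by "strideN", and the head looks features up by the same stride key instead of by list position. Below is a minimal, self-contained sketch of that convention; it is not part of the commit, and ToyNeck plus the string values are hypothetical stand-ins for the MXNet symbols the real builders pass around.

    # Hypothetical sketch of the stride-keyed dict convention; values are plain
    # strings standing in for mx.sym feature maps.
    strides = (8, 16, 32, 64, 128)

    class ToyNeck:
        def __init__(self):
            self.neck = None  # cache, mirroring the added self.neck = None above

        def get_neck(self, data):
            if self.neck is not None:   # build the neck once, reuse on later calls
                return self.neck
            self.neck = {"stride%s" % s: "P%d(%s)" % (level, data)
                         for level, s in enumerate(strides, start=3)}
            return self.neck

    neck = ToyNeck()
    features = neck.get_neck("backbone_out")
    for s in strides:
        # head side: look up by stride key, replacing the old positional conv_feat[i]
        print(s, features["stride%s" % s])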
