@@ -51,6 +51,7 @@ def reluconvbn(data, num_filter, init, norm, name, prefix):
5151class NASFPNNeck (Neck ):
5252 def __init__ (self , pNeck ):
5353 super ().__init__ (pNeck )
54+ self .neck = None
5455
5556 @staticmethod
5657 def get_P0_features (c_features , p_names , dim_reduced , init , norm , kernel = 1 ):
@@ -149,6 +150,9 @@ def get_fused_P_feature(p_features, stage, dim_reduced, init, norm):
149150 'S{}_P7' .format (stage ): P7 }
150151
151152 def get_nasfpn_neck (self , data ):
153+ if self .neck is not None :
154+ return self .neck
155+
152156 dim_reduced = self .p .dim_reduced
153157 norm = self .p .normalizer
154158 num_stage = self .p .num_stage
@@ -169,11 +173,16 @@ def get_nasfpn_neck(self, data):
169173 # stack stage
170174 for i in range (num_stage ):
171175 p_features = self .get_fused_P_feature (p_features , i + 1 , dim_reduced , xavier_init , norm )
172- return p_features ['S{}_P3' .format (num_stage )], \
173- p_features ['S{}_P4' .format (num_stage )], \
174- p_features ['S{}_P5' .format (num_stage )], \
175- p_features ['S{}_P6' .format (num_stage )], \
176- p_features ['S{}_P7' .format (num_stage )]
176+
177+ self .neck = dict (
178+ stride8 = p_features ['S{}_P3' .format (num_stage )],
179+ stride16 = p_features ['S{}_P4' .format (num_stage )],
180+ stride32 = p_features ['S{}_P5' .format (num_stage )],
181+ stride64 = p_features ['S{}_P6' .format (num_stage )],
182+ stride128 = p_features ['S{}_P7' .format (num_stage )]
183+ )
184+
185+ return self .neck
177186
178187 def get_rpn_feature (self , rpn_feat ):
179188 return self .get_nasfpn_neck (rpn_feat )
@@ -318,8 +327,8 @@ def _bbox_subnet(self, conv_feat, conv_channel, num_base_anchor, num_class, stri
318327 return output
319328
320329 def get_output (self , conv_feat ):
321- if self ._cls_logit_list is not None and self ._bbox_delta_list is not None :
322- return self ._cls_logit_list , self ._bbox_delta_list
330+ if self ._cls_logit_dict is not None and self ._bbox_delta_dict is not None :
331+ return self ._cls_logit_dict , self ._bbox_delta_dict
323332
324333 p = self .p
325334 stride = p .anchor_generate .stride
@@ -329,67 +338,43 @@ def get_output(self, conv_feat):
329338 num_base_anchor = len (p .anchor_generate .ratio ) * len (p .anchor_generate .scale )
330339 num_class = p .num_class
331340
332- prior_prob = 0.01
333- pi = - math .log ((1 - prior_prob ) / prior_prob )
334-
335- # shared classification weight and bias
336- self .cls_conv1_weight = X .var ("cls_conv1_weight" , init = X .gauss (std = 0.01 ))
337- self .cls_conv1_bias = X .var ("cls_conv1_bias" , init = X .zero_init ())
338- self .cls_conv2_weight = X .var ("cls_conv2_weight" , init = X .gauss (std = 0.01 ))
339- self .cls_conv2_bias = X .var ("cls_conv2_bias" , init = X .zero_init ())
340- self .cls_conv3_weight = X .var ("cls_conv3_weight" , init = X .gauss (std = 0.01 ))
341- self .cls_conv3_bias = X .var ("cls_conv3_bias" , init = X .zero_init ())
342- self .cls_conv4_weight = X .var ("cls_conv4_weight" , init = X .gauss (std = 0.01 ))
343- self .cls_conv4_bias = X .var ("cls_conv4_bias" , init = X .zero_init ())
344- self .cls_pred_weight = X .var ("cls_pred_weight" , init = X .gauss (std = 0.01 ))
345- self .cls_pred_bias = X .var ("cls_pred_bias" , init = X .constant (pi ))
346-
347- # shared regression weight and bias
348- self .bbox_conv1_weight = X .var ("bbox_conv1_weight" , init = X .gauss (std = 0.01 ))
349- self .bbox_conv1_bias = X .var ("bbox_conv1_bias" , init = X .zero_init ())
350- self .bbox_conv2_weight = X .var ("bbox_conv2_weight" , init = X .gauss (std = 0.01 ))
351- self .bbox_conv2_bias = X .var ("bbox_conv2_bias" , init = X .zero_init ())
352- self .bbox_conv3_weight = X .var ("bbox_conv3_weight" , init = X .gauss (std = 0.01 ))
353- self .bbox_conv3_bias = X .var ("bbox_conv3_bias" , init = X .zero_init ())
354- self .bbox_conv4_weight = X .var ("bbox_conv4_weight" , init = X .gauss (std = 0.01 ))
355- self .bbox_conv4_bias = X .var ("bbox_conv4_bias" , init = X .zero_init ())
356- self .bbox_pred_weight = X .var ("bbox_pred_weight" , init = X .gauss (std = 0.01 ))
357- self .bbox_pred_bias = X .var ("bbox_pred_bias" , init = X .zero_init ())
358-
359- cls_logit_list = []
360- bbox_delta_list = []
361-
362- for i , s in enumerate (stride ):
341+ cls_logit_dict = dict ()
342+ bbox_delta_dict = dict ()
343+
344+ for s in stride :
363345 cls_logit = self ._cls_subnet (
364- conv_feat = conv_feat [i ],
346+ conv_feat = conv_feat ["stride%s" % s ],
365347 conv_channel = conv_channel ,
366348 num_base_anchor = num_base_anchor ,
367349 num_class = num_class ,
368350 stride = s
369351 )
370352
371353 bbox_delta = self ._bbox_subnet (
372- conv_feat = conv_feat [i ],
354+ conv_feat = conv_feat ["stride%s" % s ],
373355 conv_channel = conv_channel ,
374356 num_base_anchor = num_base_anchor ,
375357 num_class = num_class ,
376358 stride = s
377359 )
378360
379- cls_logit_list . append ( cls_logit )
380- bbox_delta_list . append ( bbox_delta )
361+ cls_logit_dict [ "stride%s" % s ] = cls_logit
362+ bbox_delta_dict [ "stride%s" % s ] = bbox_delta
381363
382- self ._cls_logit_list = cls_logit_list
383- self ._bbox_delta_list = bbox_delta_list
364+ self ._cls_logit_dict = cls_logit_dict
365+ self ._bbox_delta_dict = bbox_delta_dict
384366
385- return self ._cls_logit_list , self ._bbox_delta_list
367+ return self ._cls_logit_dict , self ._bbox_delta_dict
386368
387369
388370class RetinaNetNeckWithBN (RetinaNetNeck ):
389371 def __init__ (self , pNeck ):
390372 super ().__init__ (pNeck )
391373
392374 def get_retinanet_neck (self , data ):
375+ if self .neck is not None :
376+ return self .neck
377+
393378 norm = self .p .normalizer
394379 c2 , c3 , c4 , c5 = data
395380
@@ -473,7 +458,7 @@ def get_retinanet_neck(self, data):
473458 )
474459
475460 # P6
476- P6 = X .conv (
461+ p6 = X .conv (
477462 data = c5 ,
478463 kernel = 3 ,
479464 stride = 2 ,
@@ -485,9 +470,9 @@ def get_retinanet_neck(self, data):
485470 )
486471
487472 # P7
488- P6_relu = X .relu (data = P6 , name = "P6_relu" )
489- P7 = X .conv (
490- data = P6_relu ,
473+ p6_relu = X .relu (data = p6 , name = "P6_relu" )
474+ p7 = X .conv (
475+ data = p6_relu ,
491476 kernel = 3 ,
492477 stride = 2 ,
493478 filter = 256 ,
@@ -500,18 +485,26 @@ def get_retinanet_neck(self, data):
500485 p3_conv = norm (p3_conv , name = "P3_conv_bn" )
501486 p4_conv = norm (p4_conv , name = "P4_conv_bn" )
502487 p5_conv = norm (p5_conv , name = "P5_conv_bn" )
503- P6 = norm (P6 , name = "P6_conv_bn" )
504- P7 = norm (P7 , name = "P7_conv_bn" )
488+ p6 = norm (p6 , name = "P6_conv_bn" )
489+ p7 = norm (p7 , name = "P7_conv_bn" )
490+
491+ self .neck = dict (
492+ stride8 = p3_conv ,
493+ stride16 = p4_conv ,
494+ stride32 = p5_conv ,
495+ stride64 = p6 ,
496+ stride128 = p7
497+ )
505498
506- return p3_conv , p4_conv , p5_conv , P6 , P7
499+ return self . neck
507500
508501
509- class MSRAResNet50V1bFPN (Backbone ):
502+ class ResNetV1bFPN (Backbone ):
510503 def __init__ (self , pBackbone ):
511- super (MSRAResNet50V1bFPN , self ).__init__ (pBackbone )
504+ super ().__init__ (pBackbone )
512505 from mxnext .backbone .resnet_v1b import Builder
513506 b = Builder ()
514- self .symbol = b .get_backbone ("msra" , 50 , "fpn" , pBackbone .normalizer , pBackbone .fp16 )
507+ self .symbol = b .get_backbone ("msra" , pBackbone . depth , "fpn" , pBackbone .normalizer , pBackbone .fp16 )
515508
516509 def get_rpn_feature (self ):
517510 return self .symbol
0 commit comments