Skip to content

Commit 070fa02

Browse files
kfxwRogerChern
authored andcommitted
a fix on gt generation (tusen-ai#260)
* a fix on gt generation * refine the comments in all fcos related files
1 parent 1b80190 commit 070fa02

File tree

5 files changed

+135
-61
lines changed

5 files changed

+135
-61
lines changed

config/fcos_r50v1_fpn_1x.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ class FCOSFPNAssignParam:
101101
[512, INF],
102102
]
103103
stride = (8, 16, 32, 64, 128)
104-
num_classifier = 81 - 1 # COCO: 80 object + 1 background
104+
num_classifier = 81 - 1 # COCO: 80 object + 1 background
105105
ignore_label = RpnParam.loss_setting.ignore_label
106106
ignore_offset = RpnParam.loss_setting.ignore_offset
107107
data_size = [PadParam.short, PadParam.long]

models/FCOS/input.py

Lines changed: 105 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,11 @@
66
import mxnet as mx
77
import time
88

9-
class DetectionAugmentation(object):
10-
def __init__(self):
11-
pass
12-
13-
def apply(self, input_record):
14-
pass
15-
16-
9+
# Preparation to generate gt
10+
# output:
11+
# loc_x/loc_y: [int array] xy coordinates for sampling in each scale
12+
# stage_lower/upperbound: [int array] the scale range of each FPN stage, used in FPN stage assignment, defined in config file
13+
# nonignore_area: [boolean array] non-padding area under each scale
1714
class PreMakeFCOSgt(mx.operator.CustomOp):
1815

1916
def __init__(self, fcos_gt_setting):
@@ -22,7 +19,6 @@ def __init__(self, fcos_gt_setting):
2219
self.stride = self.p.stride
2320
self.stages = self.p.stages # type: FCOSFPNAssignParam
2421

25-
# make locations
2622
self.data_size = self.p.data_size
2723
h, w = self.data_size
2824
self.loc_x = []
@@ -32,6 +28,7 @@ def __init__(self, fcos_gt_setting):
3228
self.stage_lowerbound = [-1e-5, 64, 128, 256, 512]
3329
self.stage_upperbound = [64, 128, 256, 512, 1e5]
3430
for idx, stride in enumerate(self.stride):
31+
# make sampling coordinate maps
3532
x = np.array(range(0,w,stride), dtype=np.float32) + stride/2.
3633
y = np.array(range(0,h,stride), dtype=np.float32) + stride/2.
3734
x, y = np.meshgrid(x, y)
@@ -41,6 +38,7 @@ def __init__(self, fcos_gt_setting):
4138
self.loc_y.append(y.reshape(-1))
4239
self.loc_x_T.append(y.T.reshape(-1))
4340
self.loc_y_T.append(x.T.reshape(-1))
41+
# convert numpy/list to ndarray
4442
self.stage_lowerbound[idx] = mx.nd.full(self.loc_x[-1].shape, self.stage_lowerbound[idx])
4543
self.stage_upperbound[idx] = mx.nd.full(self.loc_x[-1].shape, self.stage_upperbound[idx])
4644
self.loc_x = mx.nd.concat(*(self.loc_x), dim=0)
@@ -52,24 +50,26 @@ def __init__(self, fcos_gt_setting):
5250

5351
def forward(self, is_train, req, in_data, out_data, aux):
5452
context = in_data[0].context
55-
if self.loc_x.context != context: # execute only once
53+
if self.loc_x.context != context: # execute only once, load arrays into gpu
5654
self.loc_x = self.loc_x.as_in_context(context)
5755
self.loc_y = self.loc_y.as_in_context(context)
5856
self.loc_x_T = self.loc_x_T.as_in_context(context)
5957
self.loc_y_T = self.loc_y_T.as_in_context(context)
6058
self.stage_lowerbound = self.stage_lowerbound.as_in_context(context)
6159
self.stage_upperbound = self.stage_upperbound.as_in_context(context)
6260

63-
ori_h = in_data[1][0,0] # aspect_ratio_grouping ensures all aspect ratios within a batch are same
61+
ori_h = in_data[1][0,0] # 'aspect_ratio_grouping' in 'detection_input.py' ensures all aspect ratios within a batch are the same
6462
ori_w = in_data[1][0,1]
6563

6664
if ori_h < ori_w:
6765
self.assign(out_data[0], req[0], self.loc_x)
6866
self.assign(out_data[1], req[1], self.loc_y)
67+
# filter out image padding area
6968
nonignore_area = mx.nd.logical_and(lhs=(self.loc_x<ori_w), rhs=(self.loc_y<ori_h))
7069
else:
7170
self.assign(out_data[0], req[0], self.loc_x_T)
7271
self.assign(out_data[1], req[1], self.loc_y_T)
72+
# filter out image padding area
7373
nonignore_area = mx.nd.logical_and(lhs=(self.loc_x_T<ori_w), rhs=(self.loc_y_T<ori_h))
7474

7575
self.assign(out_data[2], req[2], self.stage_lowerbound)
@@ -80,7 +80,7 @@ def forward(self, is_train, req, in_data, out_data, aux):
8080

8181
def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
8282
pass
83-
83+
8484
@mx.operator.register("make_fcos_gt_preparation")
8585
class PreMakeFCOSGTProp(mx.operator.CustomOpProp):
8686
def __init__(self):
@@ -112,50 +112,113 @@ def infer_type(self, in_type):
112112
def create_operator(self, ctx, shapes, dtypes):
113113
return PreMakeFCOSgt(self.p)
114114

115+
# --------------------------------------------------------------------------------
116+
# Preparation to generate classification gt
117+
# output:
118+
# bbox_cls: bbox's classification annotations, the last row of bbox annotation
119+
# cls_batch_idx: used for array indexing
120+
class PrepareFCOS_cls_gt(mx.operator.CustomOp):
121+
122+
def __init__(self, fcos_gt_setting):
123+
super(PrepareFCOS_cls_gt, self).__init__()
124+
self.p = fcos_gt_setting
125+
self.batch_idx = None
126+
self.spatial_idx = None
115127

128+
def forward(self, is_train, req, in_data, out_data, aux):
129+
bboxes = in_data[0]
130+
N = bboxes.shape[0]
131+
HW = in_data[1].shape[-1]
132+
133+
if self.batch_idx is None: # excute only once
134+
self.batch_idx = mx.nd.arange(N).reshape((-1,1)).tile((1,HW))
135+
136+
self.assign(out_data[0], req[0], bboxes[:,:,4])
137+
self.assign(out_data[1], req[1], self.batch_idx)
138+
return
139+
140+
def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
141+
pass
142+
143+
@mx.operator.register("prepare_fcos_cls_gt")
144+
class PrepareFCOS_CLS_GTProp(mx.operator.CustomOpProp):
145+
def __init__(self):
146+
super(PrepareFCOS_CLS_GTProp, self).__init__(need_top_grad=False)
147+
from config.fcos_r50v1_fpn_1x import throwout_param
148+
self.p = throwout_param
149+
self.stride = self.p.stride
150+
self.data_size = self.p.data_size
151+
152+
def list_arguments(self):
153+
return ['gt_bbox', 'smallest_box_id']
154+
155+
def list_outputs(self):
156+
return ['bbox_cls', 'cls_batch_idx']
157+
158+
def infer_shape(self, in_shape):
159+
n = in_shape[0][0]
160+
h, w = self.data_size
161+
hw = 0
162+
for stride in self.stride:
163+
width = len(range(0,w,stride))
164+
height = len(range(0,h,stride))
165+
hw += height * width
166+
return in_shape, [in_shape[0][:2], [n,hw]], []
167+
168+
def infer_type(self, in_type):
169+
return in_type, [in_type[0], in_type[0]], []
170+
171+
def create_operator(self, ctx, shapes, dtypes):
172+
return PrepareFCOS_cls_gt(self.p)
173+
174+
# --------------------------------------------------------------------------------
175+
# To describe variable's shapes,
176+
# N: Batch size
177+
# M: Number of bbox in gt
178+
# HW: Overall spatial size of concatenated gt from different scales, sum(H_8s*W_8s, H_16s*W_16s,..., H_128s*W_128s)
179+
# num_cls: Number of object categories, use 80 in coco dataset, defined in config file
116180
def make_fcos_gt(gt_bbox, im_info, ignore_offset, ignore_label, num_classifier):
181+
# Preparations before generating gt
117182
loc_x, loc_y, stage_lowerbound, stage_upperbound, nonignore_area = mx.sym.Custom(gt_bbox=gt_bbox, im_info=im_info, op_type='make_fcos_gt_preparation', name='pre_fcos_gt')
183+
bboxes = gt_bbox # (N, M, 4), [x,y,x,y]
118184

119-
bboxes = gt_bbox
120-
121-
# compute offset
122-
#bboxes_ = mx.sym.expand_dims(bboxes, axis=-1)
185+
# Compute offsets to bbox edges at each pixel
123186
l = mx.sym.broadcast_sub( lhs=loc_x, rhs=mx.sym.slice(bboxes, begin=(None,None,0), end=(None,None,1)) )
124187
t = mx.sym.broadcast_sub( lhs=loc_y, rhs=mx.sym.slice(bboxes, begin=(None,None,1), end=(None,None,2)) )
125188
r = mx.sym.broadcast_sub( lhs=mx.sym.slice(bboxes, begin=(None,None,2), end=(None,None,3)), rhs=loc_x )
126189
b = mx.sym.broadcast_sub( lhs=mx.sym.slice(bboxes, begin=(None,None,3), end=(None,None,4)), rhs=loc_y )
127190
offset_gt = mx.sym.stack(l,t,r,b, axis=1) # (N, 4, M, HW)
128-
# clean non-box area
191+
# Reset non-box area, negative offsets indicate out-of-bbox area
129192
in_box_area = mx.sym.min(offset_gt, axis=1, keepdims=True) >= 0 # (N, 1, M, HW)
130193
offset_gt = mx.sym.broadcast_add( lhs=mx.sym.broadcast_mul(lhs=offset_gt, rhs=in_box_area),
131194
rhs=(1 - in_box_area) * ignore_offset
132195
) # offset_gt[!in_box_area] = self.ignore_offset
133-
# assign stage
134-
longest_side = mx.sym.max(offset_gt, axis=1, keepdims=True) # (N, 1, M, HW)
135-
stage_assign_mask = mx.sym.broadcast_logical_and( lhs=mx.sym.broadcast_greater_equal(lhs=longest_side, rhs=stage_lowerbound),
136-
rhs=mx.sym.broadcast_lesser(lhs=longest_side, rhs=stage_upperbound)
137-
) # (N, 1, M, HW)
196+
# Assign FPN stage based on offset values
197+
greatest_offset = mx.sym.max(offset_gt, axis=1, keepdims=True) # (N, 1, M, HW)
198+
stage_assign_mask = mx.sym.broadcast_logical_and( lhs=mx.sym.broadcast_greater_equal(lhs=greatest_offset, rhs=stage_lowerbound),
199+
rhs=mx.sym.broadcast_lesser(lhs=greatest_offset, rhs=stage_upperbound)
200+
) # (N, 1, M, HW)
138201
offset_gt = mx.sym.broadcast_add( lhs=mx.sym.broadcast_mul(lhs=offset_gt, rhs=stage_assign_mask),
139202
rhs=(1 - stage_assign_mask) * ignore_offset
140203
) # offset[!stage_assign_mask] = self.ignore_offset
141-
# fuse box offsets based on box size
204+
# Fuse offsets based on bbox sizes through dim M
205+
# Smaller bboxes are on the top and cover the larger ones
142206
box_size = ( mx.sym.slice(offset_gt, begin=(None,0,None,None), end=(None,1,None,None)) + \
143207
mx.sym.slice(offset_gt, begin=(None,2,None,None), end=(None,3,None,None)) ) \
144208
* ( mx.sym.slice(offset_gt, begin=(None,1,None,None), end=(None,2,None,None)) + \
145209
mx.sym.slice(offset_gt, begin=(None,3,None,None), end=(None,4,None,None)) )
146210
# (offset_gt[:,0,:,:] + offset_gt[:,2,:,:]) * (offset_gt[:,1,:,:] + offset_gt[:,3,:,:])
147-
#box_size = mx.sym.expand_dims(box_size, axis=1)
211+
# Bbox sizes in out-of-bbox area are set to MAX_BBOX_SIZE so that a bbox can always cover background
148212
box_size = mx.sym.broadcast_add( lhs=mx.sym.broadcast_mul(lhs=box_size, rhs=stage_assign_mask),
149213
rhs=(1 - stage_assign_mask) * 1e10
150214
) # box[!stage_assign_mask] = MAX_BBOX_SIZE
151215
smallest_box_id = mx.sym.argmin(box_size, axis=2) # (N, 1, HW)
152216
smallest_box_ids = mx.sym.tile(smallest_box_id, reps=(1,4,1)) # (N, 4, HW)
153217
offset_gt = mx.sym.reshape_like( mx.sym.pick(offset_gt, smallest_box_ids, axis=2), smallest_box_ids )
154218
# (N, 4, HW)
155-
156-
in_box_area = offset_gt != ignore_offset
157-
158-
# centerness
219+
220+
# Calculate centerness values using the formula described in the paper
221+
in_box_area = offset_gt != ignore_offset # centerness is only compute inside bboxes
159222
l_r_sorted = mx.sym.sort(mx.sym.slice(offset_gt, begin=(None,0,None), end=(None,3,None), step=(None,2,None)), axis=1)
160223
# mx.nd.sort(offset_gt[:,[0,2],:], axis=1)
161224
term1_min = mx.sym.reshape(mx.sym.slice(l_r_sorted, begin=(None,0,None), end=(None,1,None)), shape=(0,-1))
@@ -172,20 +235,29 @@ def make_fcos_gt(gt_bbox, im_info, ignore_offset, ignore_label, num_classifier):
172235
centerness_gt = centerness_gt * mx.sym.reshape(mx.sym.slice(in_box_area, begin=(None,0,None), end=(None,1,None)), shape=(0,-1))
173236
# (N, HW), centerness_gt*in_box_area[:,0,:]
174237

175-
# cls
238+
# Classification gt
239+
# smallest_box_id: indicates which bbox is chosen at current position
176240
smallest_box_id = smallest_box_id.reshape((0,-1)) # (N, HW)
177-
cls_gt = mx.sym.one_hot(smallest_box_id, num_classifier) # (N, HW, num_cls)
241+
bbox_cls, cls_batch_idx = mx.sym.Custom(gt_bbox=gt_bbox, smallest_box_id=smallest_box_ids, op_type='prepare_fcos_cls_gt', name='fcos_cls_gt')
242+
# bbox_cls = gt_bbox[:,:,4], cls_batch_idx = [[0,0,...,0],[1,1,...,1],...,[N-1,N-1,...,N-1]]
243+
# bbox_cls: (N, M), cls_batch_idx: (N, HW)
244+
cls_id = mx.sym.stack(cls_batch_idx, smallest_box_id, axis=0)
245+
# Transform bbox id to bbox's class id, e.g. (0th,0th,1st,5th,...)->(person,person,car,bike,...)
246+
cls_gt = mx.sym.gather_nd(bbox_cls, cls_id) # cls_gt = bbox_cls[cls_id], (N, HW)
247+
cls_gt = cls_gt - 1 # 1~81 class id to 0~80 class id
248+
cls_gt = mx.sym.one_hot(cls_gt, num_classifier) # (N, HW, num_cls), one-hot matrix
178249
cls_gt = mx.sym.transpose(cls_gt, axes=(0,2,1)) # (N, num_cls, HW)
179250
cls_gt = mx.sym.broadcast_mul( lhs=cls_gt, rhs=mx.sym.slice(in_box_area, begin=(None,0,None), end=(None,1,None)) )
251+
# cls_gt[:,:,!in_box_area] = 0
180252

181-
# ignore label
253+
# Set ignore labels in the image padding area
182254
nonignore_area = mx.sym.reshape(nonignore_area, shape=(1,-1))
183255
centerness_gt = mx.sym.broadcast_add( lhs=mx.sym.broadcast_mul(lhs=centerness_gt, rhs=nonignore_area),
184256
rhs=(1 - nonignore_area) * ignore_label
185-
) # centerness_gt[!nonignore_area] = self.ignore_label
257+
) # centerness_gt[:,!nonignore_area] = self.ignore_label
186258
nonignore_area = mx.sym.reshape(nonignore_area, shape=(1,1,-1))
187259
cls_gt = mx.sym.broadcast_add( lhs=mx.sym.broadcast_mul(lhs=cls_gt, rhs=nonignore_area),
188260
rhs=(1 - nonignore_area) * ignore_label
189-
)
261+
) # cls_gt[:,:,!nonignore_area] = self.ignore_label
190262

191263
return centerness_gt, mx.sym.reshape(cls_gt, shape=(0,-1)), offset_gt

models/FCOS/loss.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
import mxnet as mx
22
import mxnext as X
33

4+
# Use symbol for internal variables computation for better parallelization
5+
# Use custom python op to control loss/gradient flow
6+
47
""" ---Sigmoid Focal Loss---
58
def forward(self, is_train, req, in_data, out_data, aux):
69
logits = in_data[0]
@@ -46,6 +49,8 @@ def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
4649
4750
self.assign(in_grad[0], req[0], grad)"""
4851

52+
53+
# the formula is better demonstrated above
4954
class ComputeSigmoidFocalLoss(mx.operator.CustomOp):
5055
def __init__(self):
5156
super(ComputeSigmoidFocalLoss, self).__init__()
@@ -79,12 +84,11 @@ def create_operator(self, ctx, shapes, dtypes):
7984
return ComputeSigmoidFocalLoss()
8085

8186
def make_sigmoid_focal_loss(gamma, alpha, logits, labels, nonignore_mask):
82-
# conduct most of calculations using symbol and control gradient flow with custom op
83-
p = 1 / (1 + mx.sym.exp(-logits)) # sigmoid
87+
p = 1 / (1 + mx.sym.exp(-logits)) # sigmoid
8488
mask_logits_GE_zero = mx.sym.broadcast_greater_equal(lhs=logits, rhs=mx.sym.zeros((1,1)))
85-
# logits>=0
86-
minus_logits_mask = -1. * logits * mask_logits_GE_zero # -1 * logits * [logits>=0]
87-
negative_abs_logits = logits - 2*logits*mask_logits_GE_zero # logtis - 2 * logits * [logits>=0]
89+
# logits>=0
90+
minus_logits_mask = -1. * logits * mask_logits_GE_zero # -1 * logits * [logits>=0]
91+
negative_abs_logits = logits - 2*logits*mask_logits_GE_zero # logtis - 2 * logits * [logits>=0]
8892
log_one_exp_minus_abs = mx.sym.log(1. + mx.sym.exp(negative_abs_logits))
8993
minus_log = minus_logits_mask - log_one_exp_minus_abs
9094

@@ -101,14 +105,13 @@ def make_sigmoid_focal_loss(gamma, alpha, logits, labels, nonignore_mask):
101105
backward_term2 = one_alpha_p_gamma_one_labels * (minus_log * (1 - p) * gamma - p)
102106
grad = mx.sym.broadcast_div( lhs=-1 * (backward_term1 + backward_term2) * nonignore_mask, rhs=norm.reshape((1,1)) )
103107

104-
loss = X.block_grad(loss)
105-
grad = X.block_grad(grad)
108+
loss = X.block_grad(loss) # symbols are only used for computation
109+
grad = X.block_grad(grad) # use custom op to control gradient flow instead
106110

107111
loss = mx.sym.Custom(logits=logits, loss=loss, grad=grad, op_type='compute_focal_loss', name='focal_loss')
108112
return loss
109113

110-
111-
114+
# -------------------------------------------------------
112115
class ComputeBCELoss(mx.operator.CustomOp):
113116
def __init__(self):
114117
super(ComputeBCELoss, self).__init__()
@@ -152,8 +155,7 @@ def make_binary_cross_entropy_loss(logits, labels, nonignore_mask):
152155

153156
return mx.sym.Custom(logits=logits, loss=loss, grad=grad, op_type='compute_bce_loss', name='sigmoid_bce_loss')
154157

155-
156-
158+
# -------------------------------------------------------
157159
def IoULoss(x_box, y_box, ignore_offset, centerness_label, name='iouloss'):
158160
centerness_label = mx.sym.reshape(centerness_label, shape=(0,1,-1))
159161
y_box = X.block_grad(y_box)
@@ -163,6 +165,7 @@ def IoULoss(x_box, y_box, ignore_offset, centerness_label, name='iouloss'):
163165
target_right = mx.sym.slice_axis(y_box, axis=1, begin=2, end=3)
164166
target_bottom = mx.sym.slice_axis(y_box, axis=1, begin=3, end=4)
165167

168+
# filter out out-of-bbox area, loss is only computed inside bboxes
166169
nonignore_mask = mx.sym.broadcast_logical_and(lhs = mx.sym.broadcast_not_equal(lhs=target_left, rhs=ignore_offset),
167170
rhs = mx.sym.broadcast_greater( lhs=centerness_label, rhs=mx.sym.full((1,1,1), 0) )
168171
)

models/FCOS/metric.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
import pdb
2-
31
import mxnet as mx
42
import numpy as np
53

0 commit comments

Comments
 (0)