@@ -49,6 +49,7 @@ def _init_decode_head(self, decode_head):
4949 self .decode_head = builder .build_head (decode_head )
5050 self .align_corners = self .decode_head .align_corners
5151 self .num_classes = self .decode_head .num_classes
52+ self .out_channels = self .decode_head .out_channels
5253
5354 def _init_auxiliary_head (self , auxiliary_head ):
5455 """Initialize ``auxiliary_head``"""
@@ -162,10 +163,10 @@ def slide_inference(self, img, img_meta, rescale):
162163 h_stride , w_stride = self .test_cfg .stride
163164 h_crop , w_crop = self .test_cfg .crop_size
164165 batch_size , _ , h_img , w_img = img .size ()
165- num_classes = self .num_classes
166+ out_channels = self .out_channels
166167 h_grids = max (h_img - h_crop + h_stride - 1 , 0 ) // h_stride + 1
167168 w_grids = max (w_img - w_crop + w_stride - 1 , 0 ) // w_stride + 1
168- preds = img .new_zeros ((batch_size , num_classes , h_img , w_img ))
169+ preds = img .new_zeros ((batch_size , out_channels , h_img , w_img ))
169170 count_mat = img .new_zeros ((batch_size , 1 , h_img , w_img ))
170171 for h_idx in range (h_grids ):
171172 for w_idx in range (w_grids ):
@@ -245,7 +246,10 @@ def inference(self, img, img_meta, rescale):
245246 seg_logit = self .slide_inference (img , img_meta , rescale )
246247 else :
247248 seg_logit = self .whole_inference (img , img_meta , rescale )
248- output = F .softmax (seg_logit , dim = 1 )
249+ if self .out_channels == 1 :
250+ output = F .sigmoid (seg_logit )
251+ else :
252+ output = F .softmax (seg_logit , dim = 1 )
249253 flip = img_meta [0 ]['flip' ]
250254 if flip :
251255 flip_direction = img_meta [0 ]['flip_direction' ]
@@ -260,7 +264,11 @@ def inference(self, img, img_meta, rescale):
260264 def simple_test (self , img , img_meta , rescale = True ):
261265 """Simple test with single image."""
262266 seg_logit = self .inference (img , img_meta , rescale )
263- seg_pred = seg_logit .argmax (dim = 1 )
267+ if self .out_channels == 1 :
268+ seg_pred = (seg_logit >
269+ self .decode_head .threshold ).to (seg_logit ).squeeze (1 )
270+ else :
271+ seg_pred = seg_logit .argmax (dim = 1 )
264272 if torch .onnx .is_in_onnx_export ():
265273 # our inference backend only support 4D output
266274 seg_pred = seg_pred .unsqueeze (0 )
@@ -283,7 +291,11 @@ def aug_test(self, imgs, img_metas, rescale=True):
283291 cur_seg_logit = self .inference (imgs [i ], img_metas [i ], rescale )
284292 seg_logit += cur_seg_logit
285293 seg_logit /= len (imgs )
286- seg_pred = seg_logit .argmax (dim = 1 )
294+ if self .out_channels == 1 :
295+ seg_pred = (seg_logit >
296+ self .decode_head .threshold ).to (seg_logit ).squeeze (1 )
297+ else :
298+ seg_pred = seg_logit .argmax (dim = 1 )
287299 seg_pred = seg_pred .cpu ().numpy ()
288300 # unravel batch dim
289301 seg_pred = list (seg_pred )
0 commit comments