Skip to content

Commit 4371ba5

Browse files
authored
[Fix] Fix bugs when out_channels==1 (#2911)
1 parent ced29fc commit 4371ba5

File tree

2 files changed

+3
-2
lines changed

2 files changed

+3
-2
lines changed

mmseg/models/segmentors/base.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,7 @@ def postprocess_result(self,
187187
if C > 1:
188188
i_seg_pred = i_seg_logits.argmax(dim=0, keepdim=True)
189189
else:
190+
i_seg_logits = i_seg_logits.sigmoid()
190191
i_seg_pred = (i_seg_logits >
191192
self.decode_head.threshold).to(i_seg_logits)
192193
data_samples[i].set_data({

mmseg/models/segmentors/encoder_decoder.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -260,10 +260,10 @@ def slide_inference(self, inputs: Tensor,
260260
h_stride, w_stride = self.test_cfg.stride
261261
h_crop, w_crop = self.test_cfg.crop_size
262262
batch_size, _, h_img, w_img = inputs.size()
263-
num_classes = self.num_classes
263+
out_channels = self.out_channels
264264
h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
265265
w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
266-
preds = inputs.new_zeros((batch_size, num_classes, h_img, w_img))
266+
preds = inputs.new_zeros((batch_size, out_channels, h_img, w_img))
267267
count_mat = inputs.new_zeros((batch_size, 1, h_img, w_img))
268268
for h_idx in range(h_grids):
269269
for w_idx in range(w_grids):

0 commit comments

Comments
 (0)