Skip to content

Commit 23ae1eb

Browse files
authored
[Fix] Input previous results for the last cascade_decode_head (open-mmlab#1450)
* [Fix] Input previous results for the latter cascade_decode_head * minors
1 parent 3f79707 commit 23ae1eb

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

mmseg/models/segmentors/cascade_encoder_decoder.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,12 @@ def _decode_head_forward_train(self, x, img_metas, gt_semantic_seg):
7575

7676
for i in range(1, self.num_stages):
7777
# forward test again, maybe unnecessary for most methods.
78-
prev_outputs = self.decode_head[i - 1].forward_test(
79-
x, img_metas, self.test_cfg)
78+
if i == 1:
79+
prev_outputs = self.decode_head[0].forward_test(
80+
x, img_metas, self.test_cfg)
81+
else:
82+
prev_outputs = self.decode_head[i - 1].forward_test(
83+
x, prev_outputs, img_metas, self.test_cfg)
8084
loss_decode = self.decode_head[i].forward_train(
8185
x, prev_outputs, img_metas, gt_semantic_seg, self.train_cfg)
8286
losses.update(add_prefix(loss_decode, f'decode_{i}'))

0 commit comments

Comments
 (0)