Skip to content

Commit 3cc7ae2

Browse files
authored
[Fix] Format shape check (open-mmlab#2753)
as title
1 parent dd47cef commit 3cc7ae2

File tree

2 files changed

+7
-3
lines changed

2 files changed

+7
-3
lines changed

mmseg/datasets/transforms/formatting.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,12 +74,12 @@ def transform(self, results: dict) -> dict:
7474

7575
data_sample = SegDataSample()
7676
if 'gt_seg_map' in results:
77-
if results['gt_seg_map'].shape == 2:
77+
if len(results['gt_seg_map'].shape) == 2:
7878
data = to_tensor(results['gt_seg_map'][None,
7979
...].astype(np.int64))
8080
else:
8181
warnings.warn('Please pay attention your ground truth '
82-
'segmentation map, usually the segentation '
82+
'segmentation map, usually the segmentation '
8383
'map is 2D, but got '
8484
f'{results["gt_seg_map"].shape}')
8585
data = to_tensor(results['gt_seg_map'].astype(np.int64))

tests/test_datasets/test_formatting.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import unittest
55

66
import numpy as np
7+
import pytest
78
from mmengine.structures import BaseDataElement
89

910
from mmseg.datasets.transforms import PackSegInputs
@@ -46,8 +47,11 @@ def test_transform(self):
4647
self.assertEqual(results['data_samples'].ori_shape,
4748
results['data_samples'].gt_sem_seg.shape)
4849
results = copy.deepcopy(self.results)
50+
# test dataset shape is not 2D
4951
results['gt_seg_map'] = np.random.rand(3, 300, 400)
50-
results = transform(results)
52+
msg = 'the segmentation map is 2D'
53+
with pytest.warns(UserWarning, match=msg):
54+
results = transform(results)
5155
self.assertEqual(results['data_samples'].ori_shape,
5256
results['data_samples'].gt_sem_seg.shape)
5357

0 commit comments

Comments
 (0)