File tree Expand file tree Collapse file tree 2 files changed +7
-3
lines changed
mmseg/datasets/transforms Expand file tree Collapse file tree 2 files changed +7
-3
lines changed Original file line number Diff line number Diff line change @@ -74,12 +74,12 @@ def transform(self, results: dict) -> dict:
7474
7575 data_sample = SegDataSample ()
7676 if 'gt_seg_map' in results :
77- if results ['gt_seg_map' ].shape == 2 :
77+ if len ( results ['gt_seg_map' ].shape ) == 2 :
7878 data = to_tensor (results ['gt_seg_map' ][None ,
7979 ...].astype (np .int64 ))
8080 else :
8181 warnings .warn ('Please pay attention your ground truth '
82- 'segmentation map, usually the segentation '
82+ 'segmentation map, usually the segmentation '
8383 'map is 2D, but got '
8484 f'{ results ["gt_seg_map" ].shape } ' )
8585 data = to_tensor (results ['gt_seg_map' ].astype (np .int64 ))
Original file line number Diff line number Diff line change 44import unittest
55
66import numpy as np
7+ import pytest
78from mmengine .structures import BaseDataElement
89
910from mmseg .datasets .transforms import PackSegInputs
@@ -46,8 +47,11 @@ def test_transform(self):
4647 self .assertEqual (results ['data_samples' ].ori_shape ,
4748 results ['data_samples' ].gt_sem_seg .shape )
4849 results = copy .deepcopy (self .results )
50+ # test dataset shape is not 2D
4951 results ['gt_seg_map' ] = np .random .rand (3 , 300 , 400 )
50- results = transform (results )
52+ msg = 'the segmentation map is 2D'
53+ with pytest .warns (UserWarning , match = msg ):
54+ results = transform (results )
5155 self .assertEqual (results ['data_samples' ].ori_shape ,
5256 results ['data_samples' ].gt_sem_seg .shape )
5357
You can’t perform that action at this time.
0 commit comments