Skip to content

Commit 77e8ce3

Browse files
authored
[Bug Fix] Fix TTA resize scale (open-mmlab#334)
* fix tta bug * modify as suggested * fix test_tta bug
1 parent 8f8e77d commit 77e8ce3

File tree

3 files changed

+7
-6
lines changed

3 files changed

+7
-6
lines changed

mmseg/datasets/pipelines/test_time_aug.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def __call__(self, results):
104104
aug_data = []
105105
if self.img_scale is None and mmcv.is_list_of(self.img_ratios, float):
106106
h, w = results['img'].shape[:2]
107-
img_scale = [(int(h * ratio), int(w * ratio))
107+
img_scale = [(int(w * ratio), int(h * ratio))
108108
for ratio in self.img_ratios]
109109
else:
110110
img_scale = self.img_scale

mmseg/datasets/pipelines/transforms.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -156,8 +156,9 @@ def _random_scale(self, results):
156156

157157
if self.ratio_range is not None:
158158
if self.img_scale is None:
159-
scale, scale_idx = self.random_sample_ratio(
160-
results['img'].shape[:2], self.ratio_range)
159+
h, w = results['img'].shape[:2]
160+
scale, scale_idx = self.random_sample_ratio((w, h),
161+
self.ratio_range)
161162
else:
162163
scale, scale_idx = self.random_sample_ratio(
163164
self.img_scale[0], self.ratio_range)

tests/test_data/test_tta.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ def test_multi_scale_flip_aug():
108108
)
109109
tta_module = build_from_cfg(tta_transform, PIPELINES)
110110
tta_results = tta_module(results.copy())
111-
assert tta_results['scale'] == [(144, 256), (288, 512), (576, 1024)]
111+
assert tta_results['scale'] == [(256, 144), (512, 288), (1024, 576)]
112112
assert tta_results['flip'] == [False, False, False]
113113

114114
tta_transform = dict(
@@ -120,8 +120,8 @@ def test_multi_scale_flip_aug():
120120
)
121121
tta_module = build_from_cfg(tta_transform, PIPELINES)
122122
tta_results = tta_module(results.copy())
123-
assert tta_results['scale'] == [(144, 256), (144, 256), (288, 512),
124-
(288, 512), (576, 1024), (576, 1024)]
123+
assert tta_results['scale'] == [(256, 144), (256, 144), (512, 288),
124+
(512, 288), (1024, 576), (1024, 576)]
125125
assert tta_results['flip'] == [False, True, False, True, False, True]
126126

127127
tta_transform = dict(

0 commit comments

Comments
 (0)