Skip to content

Commit 15fede2

Browse files
authored
Add check for input shape (qubvel-org#549)
1 parent e0d8320 commit 15fede2

File tree

3 files changed

+28
-0
lines changed

3 files changed

+28
-0
lines changed

segmentation_models_pytorch/base/model.py

+15
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,23 @@ def initialize(self):
99
if self.classification_head is not None:
1010
init.initialize_head(self.classification_head)
1111

12+
def check_input_shape(self, x):
    """Validate that the spatial size of ``x`` is compatible with the encoder.

    The encoder downsamples the input by ``self.encoder.output_stride`` in
    total, so both height and width must be divisible by that stride or the
    decoder's skip connections will misalign.

    Args:
        x: input batch; only the last two dimensions (height, width) are
            inspected, so any tensor with a ``.shape`` works.

    Raises:
        RuntimeError: if height or width is not divisible by the output
            stride. The message suggests the nearest valid (rounded-up) shape.
    """
    h, w = x.shape[-2:]
    output_stride = self.encoder.output_stride
    if h % output_stride != 0 or w % output_stride != 0:
        # Round each dimension up to the next multiple of the stride
        # (ceil-division; a no-op for an already-divisible dimension).
        new_h = ((h + output_stride - 1) // output_stride) * output_stride
        new_w = ((w + output_stride - 1) // output_stride) * output_stride
        raise RuntimeError(
            f"Wrong input shape height={h}, width={w}. Expected image height and width "
            f"divisible by {output_stride}. Consider padding your images to shape ({new_h}, {new_w})."
        )
23+
1224
def forward(self, x):
1325
"""Sequentially pass `x` trough model`s encoder, decoder and heads"""
26+
27+
self.check_input_shape(x)
28+
1429
features = self.encoder(x)
1530
decoder_output = self.decoder(*features)
1631

segmentation_models_pytorch/encoders/_base.py

+8
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,17 @@ class EncoderMixin:
1212
- patching first convolution for arbitrary input channels
1313
"""
1414

15+
_output_stride = 32
16+
1517
@property
def out_channels(self):
    """Channel counts for each tensor the encoder's forward pass emits,
    truncated to the configured depth (depth + 1 stages including input)."""
    n_visible_stages = self._depth + 1
    return self._out_channels[:n_visible_stages]
1921

22+
@property
def output_stride(self):
    """Effective total downsampling factor of the encoder.

    The configured stride cannot exceed what ``_depth`` stages of 2x
    downsampling can produce, so the result is capped at ``2 ** _depth``.
    """
    depth_cap = 2 ** self._depth
    if self._output_stride < depth_cap:
        return self._output_stride
    return depth_cap
25+
2026
def set_in_channels(self, in_channels, pretrained=True):
2127
"""Change first convolution channels"""
2228
if in_channels == 3:
@@ -49,6 +55,8 @@ def make_dilated(self, output_stride):
4955
else:
5056
raise ValueError("Output stride should be 16 or 8, got {}.".format(output_stride))
5157

58+
self._output_stride = output_stride
59+
5260
stages = self.get_stages()
5361
for stage_indx, dilation_rate in zip(stage_list, dilation_list):
5462
utils.replace_strides_with_dilation(

segmentation_models_pytorch/encoders/timm_universal.py

+5
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ def __init__(self, name, pretrained=True, in_channels=3, depth=5, output_stride=
2424
in_channels,
2525
] + self.model.feature_info.channels()
2626
self._depth = depth
27+
self._output_stride = output_stride
2728

2829
def forward(self, x):
2930
features = self.model(x)
@@ -35,3 +36,7 @@ def forward(self, x):
3536
@property
def out_channels(self):
    """Channel dimensions of every feature tensor returned by ``forward``."""
    return self._out_channels
39+
40+
@property
def output_stride(self):
    """Downsampling factor of the deepest feature map, limited by depth:
    the stride can never exceed ``2 ** _depth``."""
    return min(2 ** self._depth, self._output_stride)

0 commit comments

Comments
 (0)