Skip to content

Commit 0e3d1b8

Browse files
authored
[Fix] Add Pytorch HardSwish assertion in unit test (open-mmlab#1294)
* assert original HardSwish when PyTorch > 1.6 in unit test * assert original HardSwish when PyTorch > 1.6 in unit test * assert original HardSwish when PyTorch > 1.6 in unit test * assert original HardSwish when PyTorch > 1.6 in unit test * assert original HardSwish when PyTorch > 1.6 in unit test * assert original HardSwish when PyTorch > 1.6 in unit test
1 parent 92068b4 commit 0e3d1b8

File tree

1 file changed

+18
-2
lines changed

1 file changed

+18
-2
lines changed

tests/test_models/test_backbones/test_blocks.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import mmcv
33
import pytest
44
import torch
5+
from mmcv.utils import TORCH_VERSION, digit_version
56

67
from mmseg.models.utils import (InvertedResidual, InvertedResidualV3, SELayer,
78
make_divisible)
@@ -108,19 +109,34 @@ def test_inv_residualv3():
108109
assert inv_module.expand_conv.conv.kernel_size == (1, 1)
109110
assert inv_module.expand_conv.conv.stride == (1, 1)
110111
assert inv_module.expand_conv.conv.padding == (0, 0)
111-
assert isinstance(inv_module.expand_conv.activate, mmcv.cnn.HSwish)
112112

113113
assert isinstance(inv_module.depthwise_conv.conv,
114114
mmcv.cnn.bricks.Conv2dAdaptivePadding)
115115
assert inv_module.depthwise_conv.conv.kernel_size == (3, 3)
116116
assert inv_module.depthwise_conv.conv.stride == (2, 2)
117117
assert inv_module.depthwise_conv.conv.padding == (0, 0)
118118
assert isinstance(inv_module.depthwise_conv.bn, torch.nn.BatchNorm2d)
119-
assert isinstance(inv_module.depthwise_conv.activate, mmcv.cnn.HSwish)
119+
120120
assert inv_module.linear_conv.conv.kernel_size == (1, 1)
121121
assert inv_module.linear_conv.conv.stride == (1, 1)
122122
assert inv_module.linear_conv.conv.padding == (0, 0)
123123
assert isinstance(inv_module.linear_conv.bn, torch.nn.BatchNorm2d)
124+
125+
if (TORCH_VERSION == 'parrots'
126+
or digit_version(TORCH_VERSION) < digit_version('1.7')):
127+
# Note: Use PyTorch official HSwish
128+
# when torch>=1.7 after MMCV >= 1.4.5.
129+
# Hardswish is not supported when PyTorch version < 1.6.
130+
# And Hardswish in PyTorch 1.6 does not support inplace.
131+
# More details could be found from:
132+
# https://github.com/open-mmlab/mmcv/pull/1709
133+
assert isinstance(inv_module.expand_conv.activate, mmcv.cnn.HSwish)
134+
assert isinstance(inv_module.depthwise_conv.activate, mmcv.cnn.HSwish)
135+
else:
136+
assert isinstance(inv_module.expand_conv.activate, torch.nn.Hardswish)
137+
assert isinstance(inv_module.depthwise_conv.activate,
138+
torch.nn.Hardswish)
139+
124140
x = torch.rand(1, 32, 64, 64)
125141
output = inv_module(x)
126142
assert output.shape == (1, 40, 32, 32)

0 commit comments

Comments
 (0)