|
2 | 2 | import mmcv |
3 | 3 | import pytest |
4 | 4 | import torch |
| 5 | +from mmcv.utils import TORCH_VERSION, digit_version |
5 | 6 |
|
6 | 7 | from mmseg.models.utils import (InvertedResidual, InvertedResidualV3, SELayer, |
7 | 8 | make_divisible) |
@@ -108,19 +109,34 @@ def test_inv_residualv3(): |
108 | 109 | assert inv_module.expand_conv.conv.kernel_size == (1, 1) |
109 | 110 | assert inv_module.expand_conv.conv.stride == (1, 1) |
110 | 111 | assert inv_module.expand_conv.conv.padding == (0, 0) |
111 | | - assert isinstance(inv_module.expand_conv.activate, mmcv.cnn.HSwish) |
112 | 112 |
|
113 | 113 | assert isinstance(inv_module.depthwise_conv.conv, |
114 | 114 | mmcv.cnn.bricks.Conv2dAdaptivePadding) |
115 | 115 | assert inv_module.depthwise_conv.conv.kernel_size == (3, 3) |
116 | 116 | assert inv_module.depthwise_conv.conv.stride == (2, 2) |
117 | 117 | assert inv_module.depthwise_conv.conv.padding == (0, 0) |
118 | 118 | assert isinstance(inv_module.depthwise_conv.bn, torch.nn.BatchNorm2d) |
119 | | - assert isinstance(inv_module.depthwise_conv.activate, mmcv.cnn.HSwish) |
| 119 | + |
120 | 120 | assert inv_module.linear_conv.conv.kernel_size == (1, 1) |
121 | 121 | assert inv_module.linear_conv.conv.stride == (1, 1) |
122 | 122 | assert inv_module.linear_conv.conv.padding == (0, 0) |
123 | 123 | assert isinstance(inv_module.linear_conv.bn, torch.nn.BatchNorm2d) |
| 124 | + |
| 125 | + if (TORCH_VERSION == 'parrots' |
| 126 | + or digit_version(TORCH_VERSION) < digit_version('1.7')): |
| 127 | + # Note: Use PyTorch official HSwish |
| 128 | + # when torch>=1.7 after MMCV >= 1.4.5. |
| 129 | + # Hardswish is not supported when PyTorch version < 1.6. |
| 130 | + # And Hardswish in PyTorch 1.6 does not support inplace. |
| 131 | + # More details could be found from: |
| 132 | + # https://github.com/open-mmlab/mmcv/pull/1709 |
| 133 | + assert isinstance(inv_module.expand_conv.activate, mmcv.cnn.HSwish) |
| 134 | + assert isinstance(inv_module.depthwise_conv.activate, mmcv.cnn.HSwish) |
| 135 | + else: |
| 136 | + assert isinstance(inv_module.expand_conv.activate, torch.nn.Hardswish) |
| 137 | + assert isinstance(inv_module.depthwise_conv.activate, |
| 138 | + torch.nn.Hardswish) |
| 139 | + |
124 | 140 | x = torch.rand(1, 32, 64, 64) |
125 | 141 | output = inv_module(x) |
126 | 142 | assert output.shape == (1, 40, 32, 32) |
|
0 commit comments