Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mmseg/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from .version import __version__, version_info

MMCV_MIN = '1.1.4'
MMCV_MAX = '1.3.0'
MMCV_MAX = '1.4.0'


def digit_version(version_str):
Expand Down
3 changes: 2 additions & 1 deletion mmseg/models/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from .inverted_residual import InvertedResidual, InvertedResidualV3
from .make_divisible import make_divisible
from .res_layer import ResLayer
from .se_layer import SELayer
from .self_attention_block import SelfAttentionBlock
from .up_conv_block import UpConvBlock

__all__ = [
'ResLayer', 'SelfAttentionBlock', 'make_divisible', 'InvertedResidual',
'UpConvBlock', 'InvertedResidualV3'
'UpConvBlock', 'InvertedResidualV3', 'SELayer'
]
Empty file added tests/__init__.py
Empty file.
Empty file added tests/test_models/__init__.py
Empty file.
3 changes: 3 additions & 0 deletions tests/test_models/test_backbones/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .utils import all_zeros, check_norm_state, is_block, is_norm

__all__ = ['is_norm', 'is_block', 'all_zeros', 'check_norm_state']
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,20 @@
import pytest
import torch

from mmseg.models.utils import InvertedResidual, InvertedResidualV3
from mmseg.models.utils import (InvertedResidual, InvertedResidualV3, SELayer,
make_divisible)


def test_make_divisible():
# test with min_value = None
assert make_divisible(10, 4) == 12
assert make_divisible(9, 4) == 12
assert make_divisible(1, 4) == 4

# test with min_value = 8
assert make_divisible(10, 4, 8) == 12
assert make_divisible(9, 4, 8) == 12
assert make_divisible(1, 4, 8) == 8


def test_inv_residual():
Expand Down Expand Up @@ -118,3 +131,39 @@ def test_inv_residualv3():
x = torch.randn(2, 32, 64, 64, requires_grad=True)
output = inv_module(x)
assert output.shape == (2, 40, 32, 32)


def test_se_layer():
with pytest.raises(AssertionError):
# test act_cfg assertion.
SELayer(32, act_cfg=(dict(type='ReLU'), ))

# test config with channels = 16.
se_layer = SELayer(16)
assert se_layer.conv1.conv.kernel_size == (1, 1)
assert se_layer.conv1.conv.stride == (1, 1)
assert se_layer.conv1.conv.padding == (0, 0)
assert isinstance(se_layer.conv1.activate, torch.nn.ReLU)
assert se_layer.conv2.conv.kernel_size == (1, 1)
assert se_layer.conv2.conv.stride == (1, 1)
assert se_layer.conv2.conv.padding == (0, 0)
assert isinstance(se_layer.conv2.activate, mmcv.cnn.HSigmoid)

x = torch.rand(1, 16, 64, 64)
output = se_layer(x)
assert output.shape == (1, 16, 64, 64)

# test config with channels = 16, act_cfg = dict(type='ReLU').
se_layer = SELayer(16, act_cfg=dict(type='ReLU'))
assert se_layer.conv1.conv.kernel_size == (1, 1)
assert se_layer.conv1.conv.stride == (1, 1)
assert se_layer.conv1.conv.padding == (0, 0)
assert isinstance(se_layer.conv1.activate, torch.nn.ReLU)
assert se_layer.conv2.conv.kernel_size == (1, 1)
assert se_layer.conv2.conv.stride == (1, 1)
assert se_layer.conv2.conv.padding == (0, 0)
assert isinstance(se_layer.conv2.activate, torch.nn.ReLU)

x = torch.rand(1, 16, 64, 64)
output = se_layer(x)
assert output.shape == (1, 16, 64, 64)
150 changes: 150 additions & 0 deletions tests/test_models/test_backbones/test_cgnet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
import pytest
import torch

from mmseg.models.backbones import CGNet
from mmseg.models.backbones.cgnet import (ContextGuidedBlock,
GlobalContextExtractor)


def test_cgnet_GlobalContextExtractor():
block = GlobalContextExtractor(16, 16, with_cp=True)
x = torch.randn(2, 16, 64, 64, requires_grad=True)
x_out = block(x)
assert x_out.shape == torch.Size([2, 16, 64, 64])


def test_cgnet_context_guided_block():
with pytest.raises(AssertionError):
# cgnet ContextGuidedBlock GlobalContextExtractor channel and reduction
# constraints.
ContextGuidedBlock(8, 8)

# test cgnet ContextGuidedBlock with checkpoint forward
block = ContextGuidedBlock(
16, 16, act_cfg=dict(type='PReLU'), with_cp=True)
assert block.with_cp
x = torch.randn(2, 16, 64, 64, requires_grad=True)
x_out = block(x)
assert x_out.shape == torch.Size([2, 16, 64, 64])

# test cgnet ContextGuidedBlock without checkpoint forward
block = ContextGuidedBlock(32, 32)
assert not block.with_cp
x = torch.randn(3, 32, 32, 32)
x_out = block(x)
assert x_out.shape == torch.Size([3, 32, 32, 32])

# test cgnet ContextGuidedBlock with down sampling
block = ContextGuidedBlock(32, 32, downsample=True)
assert block.conv1x1.conv.in_channels == 32
assert block.conv1x1.conv.out_channels == 32
assert block.conv1x1.conv.kernel_size == (3, 3)
assert block.conv1x1.conv.stride == (2, 2)
assert block.conv1x1.conv.padding == (1, 1)

assert block.f_loc.in_channels == 32
assert block.f_loc.out_channels == 32
assert block.f_loc.kernel_size == (3, 3)
assert block.f_loc.stride == (1, 1)
assert block.f_loc.padding == (1, 1)
assert block.f_loc.groups == 32
assert block.f_loc.dilation == (1, 1)
assert block.f_loc.bias is None

assert block.f_sur.in_channels == 32
assert block.f_sur.out_channels == 32
assert block.f_sur.kernel_size == (3, 3)
assert block.f_sur.stride == (1, 1)
assert block.f_sur.padding == (2, 2)
assert block.f_sur.groups == 32
assert block.f_sur.dilation == (2, 2)
assert block.f_sur.bias is None

assert block.bottleneck.in_channels == 64
assert block.bottleneck.out_channels == 32
assert block.bottleneck.kernel_size == (1, 1)
assert block.bottleneck.stride == (1, 1)
assert block.bottleneck.bias is None

x = torch.randn(1, 32, 32, 32)
x_out = block(x)
assert x_out.shape == torch.Size([1, 32, 16, 16])

# test cgnet ContextGuidedBlock without down sampling
block = ContextGuidedBlock(32, 32, downsample=False)
assert block.conv1x1.conv.in_channels == 32
assert block.conv1x1.conv.out_channels == 16
assert block.conv1x1.conv.kernel_size == (1, 1)
assert block.conv1x1.conv.stride == (1, 1)
assert block.conv1x1.conv.padding == (0, 0)

assert block.f_loc.in_channels == 16
assert block.f_loc.out_channels == 16
assert block.f_loc.kernel_size == (3, 3)
assert block.f_loc.stride == (1, 1)
assert block.f_loc.padding == (1, 1)
assert block.f_loc.groups == 16
assert block.f_loc.dilation == (1, 1)
assert block.f_loc.bias is None

assert block.f_sur.in_channels == 16
assert block.f_sur.out_channels == 16
assert block.f_sur.kernel_size == (3, 3)
assert block.f_sur.stride == (1, 1)
assert block.f_sur.padding == (2, 2)
assert block.f_sur.groups == 16
assert block.f_sur.dilation == (2, 2)
assert block.f_sur.bias is None

x = torch.randn(1, 32, 32, 32)
x_out = block(x)
assert x_out.shape == torch.Size([1, 32, 32, 32])


def test_cgnet_backbone():
with pytest.raises(AssertionError):
# check invalid num_channels
CGNet(num_channels=(32, 64, 128, 256))

with pytest.raises(AssertionError):
# check invalid num_blocks
CGNet(num_blocks=(3, 21, 3))

with pytest.raises(AssertionError):
# check invalid dilation
CGNet(num_blocks=2)

with pytest.raises(AssertionError):
# check invalid reduction
CGNet(reductions=16)

with pytest.raises(AssertionError):
# check invalid num_channels and reduction
CGNet(num_channels=(32, 64, 128), reductions=(64, 129))

# Test CGNet with default settings
model = CGNet()
model.init_weights()
model.train()

imgs = torch.randn(2, 3, 224, 224)
feat = model(imgs)
assert len(feat) == 3
assert feat[0].shape == torch.Size([2, 35, 112, 112])
assert feat[1].shape == torch.Size([2, 131, 56, 56])
assert feat[2].shape == torch.Size([2, 256, 28, 28])

# Test CGNet with norm_eval True and with_cp True
model = CGNet(norm_eval=True, with_cp=True)
with pytest.raises(TypeError):
# check invalid pretrained
model.init_weights(pretrained=8)
model.init_weights()
model.train()

imgs = torch.randn(2, 3, 224, 224)
feat = model(imgs)
assert len(feat) == 3
assert feat[0].shape == torch.Size([2, 35, 112, 112])
assert feat[1].shape == torch.Size([2, 131, 56, 56])
assert feat[2].shape == torch.Size([2, 256, 28, 28])
31 changes: 31 additions & 0 deletions tests/test_models/test_backbones/test_fast_scnn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import pytest
import torch

from mmseg.models.backbones import FastSCNN


def test_fastscnn_backbone():
with pytest.raises(AssertionError):
# Fast-SCNN channel constraints.
FastSCNN(
3, (32, 48),
64, (64, 96, 128), (2, 2, 1),
global_out_channels=127,
higher_in_channels=64,
lower_in_channels=128)

# Test FastSCNN Standard Forward
model = FastSCNN()
model.init_weights()
model.train()
batch_size = 4
imgs = torch.randn(batch_size, 3, 512, 1024)
feat = model(imgs)

assert len(feat) == 3
# higher-res
assert feat[0].shape == torch.Size([batch_size, 64, 64, 128])
# lower-res
assert feat[1].shape == torch.Size([batch_size, 128, 16, 32])
# FFM output
assert feat[2].shape == torch.Size([batch_size, 128, 64, 128])
66 changes: 66 additions & 0 deletions tests/test_models/test_backbones/test_mobilenet_v3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import pytest
import torch

from mmseg.models.backbones import MobileNetV3


def test_mobilenet_v3():
with pytest.raises(AssertionError):
# check invalid arch
MobileNetV3('big')

with pytest.raises(AssertionError):
# check invalid reduction_factor
MobileNetV3(reduction_factor=0)

with pytest.raises(ValueError):
# check invalid out_indices
MobileNetV3(out_indices=(0, 1, 15))

with pytest.raises(ValueError):
# check invalid frozen_stages
MobileNetV3(frozen_stages=15)

with pytest.raises(TypeError):
# check invalid pretrained
model = MobileNetV3()
model.init_weights(pretrained=8)

# Test MobileNetV3 with default settings
model = MobileNetV3()
model.init_weights()
model.train()

imgs = torch.randn(2, 3, 224, 224)
feat = model(imgs)
assert len(feat) == 3
assert feat[0].shape == (2, 16, 112, 112)
assert feat[1].shape == (2, 16, 56, 56)
assert feat[2].shape == (2, 576, 28, 28)

# Test MobileNetV3 with arch = 'large'
model = MobileNetV3(arch='large', out_indices=(1, 3, 16))
model.init_weights()
model.train()

imgs = torch.randn(2, 3, 224, 224)
feat = model(imgs)
assert len(feat) == 3
assert feat[0].shape == (2, 16, 112, 112)
assert feat[1].shape == (2, 24, 56, 56)
assert feat[2].shape == (2, 960, 28, 28)

# Test MobileNetV3 with norm_eval True, with_cp True and frozen_stages=5
model = MobileNetV3(norm_eval=True, with_cp=True, frozen_stages=5)
with pytest.raises(TypeError):
# check invalid pretrained
model.init_weights(pretrained=8)
model.init_weights()
model.train()

imgs = torch.randn(2, 3, 224, 224)
feat = model(imgs)
assert len(feat) == 3
assert feat[0].shape == (2, 16, 112, 112)
assert feat[1].shape == (2, 16, 56, 56)
assert feat[2].shape == (2, 576, 28, 28)
43 changes: 43 additions & 0 deletions tests/test_models/test_backbones/test_resnest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import pytest
import torch

from mmseg.models.backbones import ResNeSt
from mmseg.models.backbones.resnest import Bottleneck as BottleneckS


def test_resnest_bottleneck():
with pytest.raises(AssertionError):
# Style must be in ['pytorch', 'caffe']
BottleneckS(64, 64, radix=2, reduction_factor=4, style='tensorflow')

# Test ResNeSt Bottleneck structure
block = BottleneckS(
64, 256, radix=2, reduction_factor=4, stride=2, style='pytorch')
assert block.avd_layer.stride == 2
assert block.conv2.channels == 256

# Test ResNeSt Bottleneck forward
block = BottleneckS(64, 16, radix=2, reduction_factor=4)
x = torch.randn(2, 64, 56, 56)
x_out = block(x)
assert x_out.shape == torch.Size([2, 64, 56, 56])


def test_resnest_backbone():
with pytest.raises(KeyError):
# ResNeSt depth should be in [50, 101, 152, 200]
ResNeSt(depth=18)

# Test ResNeSt with radix 2, reduction_factor 4
model = ResNeSt(
depth=50, radix=2, reduction_factor=4, out_indices=(0, 1, 2, 3))
model.init_weights()
model.train()

imgs = torch.randn(2, 3, 224, 224)
feat = model(imgs)
assert len(feat) == 4
assert feat[0].shape == torch.Size([2, 256, 56, 56])
assert feat[1].shape == torch.Size([2, 512, 28, 28])
assert feat[2].shape == torch.Size([2, 1024, 14, 14])
assert feat[3].shape == torch.Size([2, 2048, 7, 7])
Loading