| 
 | 1 | +# Copyright (c) OpenMMLab. All rights reserved.  | 
 | 2 | +import pytest  | 
 | 3 | +import torch  | 
 | 4 | + | 
 | 5 | +from mmseg.models.backbones import TIMMBackbone  | 
 | 6 | +from .utils import check_norm_state  | 
 | 7 | + | 
 | 8 | + | 
 | 9 | +def test_timm_backbone():  | 
 | 10 | +    with pytest.raises(TypeError):  | 
 | 11 | +        # pretrained must be a string path  | 
 | 12 | +        model = TIMMBackbone()  | 
 | 13 | +        model.init_weights(pretrained=0)  | 
 | 14 | + | 
 | 15 | +    # Test different norm_layer, can be: 'SyncBN', 'BN2d', 'GN', 'LN', 'IN'  | 
 | 16 | +    # Test resnet18 from timm, norm_layer='BN2d'  | 
 | 17 | +    model = TIMMBackbone(  | 
 | 18 | +        model_name='resnet18',  | 
 | 19 | +        features_only=True,  | 
 | 20 | +        pretrained=False,  | 
 | 21 | +        output_stride=32,  | 
 | 22 | +        norm_layer='BN2d')  | 
 | 23 | + | 
 | 24 | +    # Test resnet18 from timm, norm_layer='SyncBN'  | 
 | 25 | +    model = TIMMBackbone(  | 
 | 26 | +        model_name='resnet18',  | 
 | 27 | +        features_only=True,  | 
 | 28 | +        pretrained=False,  | 
 | 29 | +        output_stride=32,  | 
 | 30 | +        norm_layer='SyncBN')  | 
 | 31 | + | 
 | 32 | +    # Test resnet18 from timm, features_only=True, output_stride=32  | 
 | 33 | +    model = TIMMBackbone(  | 
 | 34 | +        model_name='resnet18',  | 
 | 35 | +        features_only=True,  | 
 | 36 | +        pretrained=False,  | 
 | 37 | +        output_stride=32)  | 
 | 38 | +    model.init_weights()  | 
 | 39 | +    model.train()  | 
 | 40 | +    assert check_norm_state(model.modules(), True)  | 
 | 41 | + | 
 | 42 | +    imgs = torch.randn(1, 3, 224, 224)  | 
 | 43 | +    feats = model(imgs)  | 
 | 44 | +    feats = [feat.shape for feat in feats]  | 
 | 45 | +    assert len(feats) == 5  | 
 | 46 | +    assert feats[0] == torch.Size((1, 64, 112, 112))  | 
 | 47 | +    assert feats[1] == torch.Size((1, 64, 56, 56))  | 
 | 48 | +    assert feats[2] == torch.Size((1, 128, 28, 28))  | 
 | 49 | +    assert feats[3] == torch.Size((1, 256, 14, 14))  | 
 | 50 | +    assert feats[4] == torch.Size((1, 512, 7, 7))  | 
 | 51 | + | 
 | 52 | +    # Test resnet18 from timm, features_only=True, output_stride=16  | 
 | 53 | +    model = TIMMBackbone(  | 
 | 54 | +        model_name='resnet18',  | 
 | 55 | +        features_only=True,  | 
 | 56 | +        pretrained=False,  | 
 | 57 | +        output_stride=16)  | 
 | 58 | +    imgs = torch.randn(1, 3, 224, 224)  | 
 | 59 | +    feats = model(imgs)  | 
 | 60 | +    feats = [feat.shape for feat in feats]  | 
 | 61 | +    assert len(feats) == 5  | 
 | 62 | +    assert feats[0] == torch.Size((1, 64, 112, 112))  | 
 | 63 | +    assert feats[1] == torch.Size((1, 64, 56, 56))  | 
 | 64 | +    assert feats[2] == torch.Size((1, 128, 28, 28))  | 
 | 65 | +    assert feats[3] == torch.Size((1, 256, 14, 14))  | 
 | 66 | +    assert feats[4] == torch.Size((1, 512, 14, 14))  | 
 | 67 | + | 
 | 68 | +    # Test resnet18 from timm, features_only=True, output_stride=8  | 
 | 69 | +    model = TIMMBackbone(  | 
 | 70 | +        model_name='resnet18',  | 
 | 71 | +        features_only=True,  | 
 | 72 | +        pretrained=False,  | 
 | 73 | +        output_stride=8)  | 
 | 74 | +    imgs = torch.randn(1, 3, 224, 224)  | 
 | 75 | +    feats = model(imgs)  | 
 | 76 | +    feats = [feat.shape for feat in feats]  | 
 | 77 | +    assert len(feats) == 5  | 
 | 78 | +    assert feats[0] == torch.Size((1, 64, 112, 112))  | 
 | 79 | +    assert feats[1] == torch.Size((1, 64, 56, 56))  | 
 | 80 | +    assert feats[2] == torch.Size((1, 128, 28, 28))  | 
 | 81 | +    assert feats[3] == torch.Size((1, 256, 28, 28))  | 
 | 82 | +    assert feats[4] == torch.Size((1, 512, 28, 28))  | 
 | 83 | + | 
 | 84 | +    # Test efficientnet_b1 with pretrained weights  | 
 | 85 | +    model = TIMMBackbone(model_name='efficientnet_b1', pretrained=True)  | 
 | 86 | + | 
 | 87 | +    # Test resnetv2_50x1_bitm from timm, features_only=True, output_stride=8  | 
 | 88 | +    model = TIMMBackbone(  | 
 | 89 | +        model_name='resnetv2_50x1_bitm',  | 
 | 90 | +        features_only=True,  | 
 | 91 | +        pretrained=False,  | 
 | 92 | +        output_stride=8)  | 
 | 93 | +    imgs = torch.randn(1, 3, 8, 8)  | 
 | 94 | +    feats = model(imgs)  | 
 | 95 | +    feats = [feat.shape for feat in feats]  | 
 | 96 | +    assert len(feats) == 5  | 
 | 97 | +    assert feats[0] == torch.Size((1, 64, 4, 4))  | 
 | 98 | +    assert feats[1] == torch.Size((1, 256, 2, 2))  | 
 | 99 | +    assert feats[2] == torch.Size((1, 512, 1, 1))  | 
 | 100 | +    assert feats[3] == torch.Size((1, 1024, 1, 1))  | 
 | 101 | +    assert feats[4] == torch.Size((1, 2048, 1, 1))  | 
 | 102 | + | 
 | 103 | +    # Test resnetv2_50x3_bitm from timm, features_only=True, output_stride=8  | 
 | 104 | +    model = TIMMBackbone(  | 
 | 105 | +        model_name='resnetv2_50x3_bitm',  | 
 | 106 | +        features_only=True,  | 
 | 107 | +        pretrained=False,  | 
 | 108 | +        output_stride=8)  | 
 | 109 | +    imgs = torch.randn(1, 3, 8, 8)  | 
 | 110 | +    feats = model(imgs)  | 
 | 111 | +    feats = [feat.shape for feat in feats]  | 
 | 112 | +    assert len(feats) == 5  | 
 | 113 | +    assert feats[0] == torch.Size((1, 192, 4, 4))  | 
 | 114 | +    assert feats[1] == torch.Size((1, 768, 2, 2))  | 
 | 115 | +    assert feats[2] == torch.Size((1, 1536, 1, 1))  | 
 | 116 | +    assert feats[3] == torch.Size((1, 3072, 1, 1))  | 
 | 117 | +    assert feats[4] == torch.Size((1, 6144, 1, 1))  | 
 | 118 | + | 
 | 119 | +    # Test resnetv2_101x1_bitm from timm, features_only=True, output_stride=8  | 
 | 120 | +    model = TIMMBackbone(  | 
 | 121 | +        model_name='resnetv2_101x1_bitm',  | 
 | 122 | +        features_only=True,  | 
 | 123 | +        pretrained=False,  | 
 | 124 | +        output_stride=8)  | 
 | 125 | +    imgs = torch.randn(1, 3, 8, 8)  | 
 | 126 | +    feats = model(imgs)  | 
 | 127 | +    feats = [feat.shape for feat in feats]  | 
 | 128 | +    assert len(feats) == 5  | 
 | 129 | +    assert feats[0] == torch.Size((1, 64, 4, 4))  | 
 | 130 | +    assert feats[1] == torch.Size((1, 256, 2, 2))  | 
 | 131 | +    assert feats[2] == torch.Size((1, 512, 1, 1))  | 
 | 132 | +    assert feats[3] == torch.Size((1, 1024, 1, 1))  | 
 | 133 | +    assert feats[4] == torch.Size((1, 2048, 1, 1))  | 
0 commit comments