@@ -25,20 +25,20 @@ def test_bisenetv1_backbone():
2525 model .init_weights ()
2626 model .train ()
2727 batch_size = 2
28- imgs = torch .randn (batch_size , 3 , 256 , 512 )
28+ imgs = torch .randn (batch_size , 3 , 64 , 128 )
2929 feat = model (imgs )
3030
3131 assert len (feat ) == 3
3232 # output for segment Head
33- assert feat [0 ].shape == torch .Size ([batch_size , 256 , 32 , 64 ])
33+ assert feat [0 ].shape == torch .Size ([batch_size , 256 , 8 , 16 ])
3434 # for auxiliary head 1
35- assert feat [1 ].shape == torch .Size ([batch_size , 128 , 32 , 64 ])
35+ assert feat [1 ].shape == torch .Size ([batch_size , 128 , 8 , 16 ])
3636 # for auxiliary head 2
37- assert feat [2 ].shape == torch .Size ([batch_size , 128 , 16 , 32 ])
37+ assert feat [2 ].shape == torch .Size ([batch_size , 128 , 4 , 8 ])
3838
3939 # Test input with rare shape
4040 batch_size = 2
41- imgs = torch .randn (batch_size , 3 , 527 , 279 )
41+ imgs = torch .randn (batch_size , 3 , 95 , 27 )
4242 feat = model (imgs )
4343 assert len (feat ) == 3
4444
@@ -47,20 +47,20 @@ def test_bisenetv1_backbone():
4747 BiSeNetV1 (
4848 backbone_cfg = backbone_cfg ,
4949 in_channels = 3 ,
50- spatial_channels = (64 , 64 , 64 ))
50+ spatial_channels = (16 , 16 , 16 ))
5151
5252 with pytest .raises (AssertionError ):
5353 # BiSeNetV1 context path constraints.
5454 BiSeNetV1 (
5555 backbone_cfg = backbone_cfg ,
5656 in_channels = 3 ,
57- context_channels = (128 , 256 , 512 , 1024 ))
57+ context_channels = (16 , 32 , 64 , 128 ))
5858
5959
6060def test_bisenetv1_spatial_path ():
6161 with pytest .raises (AssertionError ):
6262 # BiSeNetV1 spatial path channel constraints.
63- SpatialPath (num_channels = (64 , 64 , 64 ), in_channels = 3 )
63+ SpatialPath (num_channels = (16 , 16 , 16 ), in_channels = 3 )
6464
6565
6666def test_bisenetv1_context_path ():
@@ -79,31 +79,31 @@ def test_bisenetv1_context_path():
7979 with pytest .raises (AssertionError ):
8080 # BiSeNetV1 context path constraints.
8181 ContextPath (
82- backbone_cfg = backbone_cfg , context_channels = (128 , 256 , 512 , 1024 ))
82+ backbone_cfg = backbone_cfg , context_channels = (16 , 32 , 64 , 128 ))
8383
8484
8585def test_bisenetv1_attention_refinement_module ():
86- x_arm = AttentionRefinementModule (256 , 64 )
87- assert x_arm .conv_layer .in_channels == 256
88- assert x_arm .conv_layer .out_channels == 64
86+ x_arm = AttentionRefinementModule (32 , 8 )
87+ assert x_arm .conv_layer .in_channels == 32
88+ assert x_arm .conv_layer .out_channels == 8
8989 assert x_arm .conv_layer .kernel_size == (3 , 3 )
90- x = torch .randn (2 , 256 , 32 , 64 )
90+ x = torch .randn (2 , 32 , 8 , 16 )
9191 x_out = x_arm (x )
92- assert x_out .shape == torch .Size ([2 , 64 , 32 , 64 ])
92+ assert x_out .shape == torch .Size ([2 , 8 , 8 , 16 ])
9393
9494
9595def test_bisenetv1_feature_fusion_module ():
96- ffm = FeatureFusionModule (128 , 256 )
97- assert ffm .conv1 .in_channels == 128
98- assert ffm .conv1 .out_channels == 256
96+ ffm = FeatureFusionModule (16 , 32 )
97+ assert ffm .conv1 .in_channels == 16
98+ assert ffm .conv1 .out_channels == 32
9999 assert ffm .conv1 .kernel_size == (1 , 1 )
100100 assert ffm .gap .output_size == (1 , 1 )
101- assert ffm .conv_atten [0 ].in_channels == 256
102- assert ffm .conv_atten [0 ].out_channels == 256
101+ assert ffm .conv_atten [0 ].in_channels == 32
102+ assert ffm .conv_atten [0 ].out_channels == 32
103103 assert ffm .conv_atten [0 ].kernel_size == (1 , 1 )
104104
105- ffm = FeatureFusionModule (128 , 128 )
106- x1 = torch .randn (2 , 64 , 64 , 128 )
107- x2 = torch .randn (2 , 64 , 64 , 128 )
105+ ffm = FeatureFusionModule (16 , 16 )
106+ x1 = torch .randn (2 , 8 , 8 , 16 )
107+ x2 = torch .randn (2 , 8 , 8 , 16 )
108108 x_out = ffm (x1 , x2 )
109- assert x_out .shape == torch .Size ([2 , 128 , 64 , 128 ])
109+ assert x_out .shape == torch .Size ([2 , 16 , 8 , 16 ])
0 commit comments