@@ -19,15 +19,16 @@ def __init__(self,
1919 groups = 1 ,
2020 deformable_groups = 1 ,
2121 bias = False ):
22- assert not bias
2322 super (DeformConv , self ).__init__ ()
2423
24+ assert not bias
2525 assert in_channels % groups == 0 , \
2626 'in_channels {} cannot be divisible by groups {}' .format (
2727 in_channels , groups )
2828 assert out_channels % groups == 0 , \
2929 'out_channels {} cannot be divisible by groups {}' .format (
3030 out_channels , groups )
31+
3132 self .in_channels = in_channels
3233 self .out_channels = out_channels
3334 self .kernel_size = _pair (kernel_size )
@@ -50,10 +51,34 @@ def reset_parameters(self):
5051 stdv = 1. / math .sqrt (n )
5152 self .weight .data .uniform_ (- stdv , stdv )
5253
53- def forward (self , input , offset ):
54- return deform_conv (input , offset , self .weight , self .stride ,
55- self .padding , self .dilation , self .groups ,
56- self .deformable_groups )
54+ def forward (self , x , offset ):
55+ return deform_conv (x , offset , self .weight , self .stride , self .padding ,
56+ self .dilation , self .groups , self .deformable_groups )
57+
58+
59+ class DeformConvPack (DeformConv ):
60+
61+ def __init__ (self , * args , ** kwargs ):
62+ super (ModulatedDeformConvPack , self ).__init__ (* args , ** kwargs )
63+
64+ self .conv_offset = nn .Conv2d (
65+ self .in_channels ,
66+ self .deformable_groups * 2 * self .kernel_size [0 ] *
67+ self .kernel_size [1 ],
68+ kernel_size = self .kernel_size ,
69+ stride = _pair (self .stride ),
70+ padding = _pair (self .padding ),
71+ bias = True )
72+ self .init_offset ()
73+
74+ def init_offset (self ):
75+ self .conv_offset .weight .data .zero_ ()
76+ self .conv_offset .bias .data .zero_ ()
77+
78+ def forward (self , x ):
79+ offset = self .conv_offset (x )
80+ return deform_conv (x , offset , self .weight , self .stride , self .padding ,
81+ self .dilation , self .groups , self .deformable_groups )
5782
5883
5984class ModulatedDeformConv (nn .Module ):
@@ -97,30 +122,19 @@ def reset_parameters(self):
97122 if self .bias is not None :
98123 self .bias .data .zero_ ()
99124
100- def forward (self , input , offset , mask ):
101- return modulated_deform_conv (
102- input , offset , mask , self .weight , self .bias , self .stride ,
103- self . padding , self . dilation , self .groups , self .deformable_groups )
125+ def forward (self , x , offset , mask ):
126+ return modulated_deform_conv (x , offset , mask , self . weight , self . bias ,
127+ self .stride , self .padding , self .dilation ,
128+ self .groups , self .deformable_groups )
104129
105130
106131class ModulatedDeformConvPack (ModulatedDeformConv ):
107132
108- def __init__ (self ,
109- in_channels ,
110- out_channels ,
111- kernel_size ,
112- stride = 1 ,
113- padding = 0 ,
114- dilation = 1 ,
115- groups = 1 ,
116- deformable_groups = 1 ,
117- bias = True ):
118- super (ModulatedDeformConvPack , self ).__init__ (
119- in_channels , out_channels , kernel_size , stride , padding , dilation ,
120- groups , deformable_groups , bias )
133+ def __init__ (self , * args , ** kwargs ):
134+ super (ModulatedDeformConvPack , self ).__init__ (* args , ** kwargs )
121135
122136 self .conv_offset_mask = nn .Conv2d (
123- self .in_channels // self . groups ,
137+ self .in_channels ,
124138 self .deformable_groups * 3 * self .kernel_size [0 ] *
125139 self .kernel_size [1 ],
126140 kernel_size = self .kernel_size ,
@@ -133,11 +147,11 @@ def init_offset(self):
133147 self .conv_offset_mask .weight .data .zero_ ()
134148 self .conv_offset_mask .bias .data .zero_ ()
135149
136- def forward (self , input ):
137- out = self .conv_offset_mask (input )
150+ def forward (self , x ):
151+ out = self .conv_offset_mask (x )
138152 o1 , o2 , mask = torch .chunk (out , 3 , dim = 1 )
139153 offset = torch .cat ((o1 , o2 ), dim = 1 )
140154 mask = torch .sigmoid (mask )
141- return modulated_deform_conv (
142- input , offset , mask , self .weight , self .bias , self .stride ,
143- self . padding , self . dilation , self .groups , self .deformable_groups )
155+ return modulated_deform_conv (x , offset , mask , self . weight , self . bias ,
156+ self .stride , self .padding , self .dilation ,
157+ self .groups , self .deformable_groups )
0 commit comments