Skip to content

Commit 1b4583b

Browse files
committed
检测backbone添加out_channels属性,方便后续module获取
1 parent f3f84d7 commit 1b4583b

File tree

2 files changed

+6
-2
lines changed

2 files changed

+6
-2
lines changed

torchocr/networks/backbones/DetMobilenetV3.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,8 +184,10 @@ def __init__(self, in_channels, **kwargs):
184184
inplanes = self.make_divisible(inplanes * scale)
185185
self.stages = nn.ModuleList()
186186
block_list = []
187+
self.out_channels = []
187188
for layer_cfg in cfg:
188189
if layer_cfg[5] == 2 and i > 2:
190+
self.out_channels.append(inplanes)
189191
self.stages.append(nn.Sequential(*block_list))
190192
block_list = []
191193
block = ResidualUnit(num_in_filter=inplanes,
@@ -207,6 +209,7 @@ def __init__(self, in_channels, **kwargs):
207209
padding=0,
208210
groups=1,
209211
act='hard_swish')
212+
self.out_channels.append(self.make_divisible(scale * cls_ch_squeeze))
210213

211214
def make_divisible(self, v, divisor=8, min_value=None):
212215
if min_value is None:
@@ -235,4 +238,4 @@ def forward(self, x):
235238
x = stage(x)
236239
out.append(x)
237240
out[-1] = self.conv2(out[-1])
238-
return out
241+
return out

torchocr/networks/backbones/DetResNetvd.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,7 @@ def __init__(self, in_channels, layers, **kwargs):
209209
self.pool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
210210

211211
self.stages = nn.ModuleList()
212+
self.out_channels = []
212213
in_ch = 64
213214
for block_index in range(len(depth)):
214215
block_list = []
@@ -227,8 +228,8 @@ def __init__(self, in_channels, layers, **kwargs):
227228
stride=2 if i == 0 and block_index != 0 else 1,
228229
if_first=block_index == i == 0, name=conv_name))
229230
in_ch = block_list[-1].output_channels
231+
self.out_channels.append(in_ch)
230232
self.stages.append(nn.Sequential(*block_list))
231-
self.out_channels = in_ch
232233

233234
def load_3rd_state_dict(self, _3rd_name, _state):
234235
if _3rd_name == 'paddle':

0 commit comments

Comments
 (0)