
Commit 943f6d3

Author: swordli
Commit message: add dsfdv2
1 parent 95cc5f8, commit 943f6d3

17 files changed: +2468 -0 lines changed

DSFDv2_r18/bi_fpn.py

Lines changed: 187 additions & 0 deletions
@@ -0,0 +1,187 @@
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from operations import OPS, BN_OPS, NORMAL_OPS
# from genotypes import FPN_Genotype

from dataset.config import widerface_640 as cfg

# FPN_Genotype = namedtuple("FPN_Genotype", "Inter_Layer Out_Layer")

BiFPN_PRIMITIVES = [
    # "none",
    # "max_pool_3x3",
    # "avg_pool_3x3",
    "conv_1x1",
    "sep_conv_3x3",
    "sep_conv_5x5",
    "dil_conv_3x3",
    "dil_conv_3x3_3",
    "dil_conv_5x5",
    # "sep_conv_7x7",
    # "conv_1x3_3x1",
    # "conv_1x5_5x1",
    # "dconv_3x3",
    # "conv_1x3",
    # "conv_3x1",
    # "conv_1x5",
    # "conv_5x1",
]


# for retraining the network
class BiFPN_From_Genotype(nn.Module):
    """Build a BiFPN cell according to its genotype."""

    def __init__(self, genotype, feature_channel=256, weight_node=False, **kwargs):
        """
        :param genotype:
            The genotype is formatted as follows:
            [
                # for a node
                [
                    # for an operation
                    (prim, index of the preceding node)
                    # other ops
                    ...
                ]
                # other nodes
                ...
            ]
        :param feature_channel: number of channels of every pyramid feature map
        """
        super(BiFPN_From_Genotype, self).__init__()

        bn = True

        # ops = NORMAL_OPS
        if bn:
            ops = BN_OPS
            print("Retrain with BN - FPN.")
        else:
            ops = OPS
            print("Retrain without BN - FPN.")

        print(ops.keys())

        self.feature_channel = feature_channel

        self.genotype = genotype
        self.node_weights_enable = weight_node

        # one operation per (prim, input) entry in the genotype, unpacked in node order
        [
            self.conv1_td,
            self.conv1,
            self.conv2_td,
            self.conv2_du,
            self.conv2,
            self.conv3_td,
            self.conv3_du,
            self.conv3,
            self.conv4_td,
            self.conv4_du,
            self.conv4,
            self.conv5_td,
            self.conv5_du,
            self.conv5,
            self.conv6_du,
            self.conv6,
        ] = [ops[prim](feature_channel, 1, True) for node in self.genotype for prim, _ in node]

        # one learnable fusion weight per input branch of each output level
        [self.w1, self.w2, self.w3, self.w4, self.w5, self.w6] = [
            nn.Parameter(1e-3 * torch.randn(len(node))) for node in self.genotype
        ]

        # 1x1 output projection per pyramid level
        self.out_layers = nn.ModuleList([nn.Conv2d(feature_channel, feature_channel, 1, 1, 0) for _ in range(6)])

    def upsample_as(self, x, y):
        return F.interpolate(x, size=y.shape[2:], mode="bilinear", align_corners=True)

    def max_pool(self, x):
        return F.max_pool2d(x, kernel_size=3, stride=2, padding=1)

    def forward(self, sources):
        """Fuse the six input levels via top-down and bottom-up paths, then a weighted output fusion."""
        f1, f2, f3, f4, f5, f6 = sources

        # top-down path
        f6_td = f6
        f5_td = self.conv5_td(self.upsample_as(f6_td, f5) * f5)
        f4_td = self.conv4_td(self.upsample_as(f5_td, f4) * f4)
        f3_td = self.conv3_td(self.upsample_as(f4_td, f3) * f3)
        f2_td = self.conv2_td(self.upsample_as(f3_td, f2) * f2)
        f1_td = self.conv1_td(self.upsample_as(f2_td, f1) * f1)

        # bottom-up path
        f1_du = f1
        f2_du = self.conv2_du(self.max_pool(f1_du) * f2)
        f3_du = self.conv3_du(self.max_pool(f2_du) * f3)
        f4_du = self.conv4_du(self.max_pool(f3_du) * f4)
        f5_du = self.conv5_du(self.max_pool(f4_du) * f5)
        f6_du = self.conv6_du(self.max_pool(f5_du) * f6)

        # output (skip) branch
        f1_out = self.conv1(f1)
        f2_out = self.conv2(f2)
        f3_out = self.conv3(f3)
        f4_out = self.conv4(f4)
        f5_out = self.conv5(f5)
        f6_out = self.conv6(f6)

        # softmax-weighted sum of the branches feeding each output level
        return [
            self.out_layers[0](
                (torch.stack([f1_td, f1_out], dim=-1) * F.softmax(self.w1, dim=0)).sum(dim=-1, keepdim=False)
            ),
            self.out_layers[1](
                (torch.stack([f2_td, f2_du, f2_out], dim=-1) * F.softmax(self.w2, dim=0)).sum(dim=-1, keepdim=False)
            ),
            self.out_layers[2](
                (torch.stack([f3_td, f3_du, f3_out], dim=-1) * F.softmax(self.w3, dim=0)).sum(dim=-1, keepdim=False)
            ),
            self.out_layers[3](
                (torch.stack([f4_td, f4_du, f4_out], dim=-1) * F.softmax(self.w4, dim=0)).sum(dim=-1, keepdim=False)
            ),
            self.out_layers[4](
                (torch.stack([f5_td, f5_du, f5_out], dim=-1) * F.softmax(self.w5, dim=0)).sum(dim=-1, keepdim=False)
            ),
            self.out_layers[5](
                (torch.stack([f6_du, f6_out], dim=-1) * F.softmax(self.w6, dim=0)).sum(dim=-1, keepdim=False)
            ),
        ]


class BiFPN_Neck_From_Genotype(nn.Module):
    """FPN neck built from a genotype; stacks one or more BiFPN cells."""

    def __init__(
        self, genotype, in_channels=256, feature_size=256, weight_node=False, fpn_layers=1,
    ):
        super(BiFPN_Neck_From_Genotype, self).__init__()

        genotype = genotype.Inter_Layer

        self.fpn_layers = fpn_layers
        if fpn_layers == 1:
            self.layers = BiFPN_From_Genotype(genotype, feature_channel=feature_size, weight_node=weight_node)
        else:
            self.layers = nn.ModuleList()

            for i in range(fpn_layers):
                self.layers.append(BiFPN_From_Genotype(genotype, feature_channel=feature_size, weight_node=weight_node))

    def forward(self, source):
        """Apply the stacked BiFPN cells to the list of source feature maps."""
        if self.fpn_layers == 1:
            return self.layers(source)

        else:
            for layer in self.layers:
                source = layer(source)
            return source
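
For context, a minimal usage sketch of the module added above. The genotype below is hypothetical (no searched genotype is included in this commit), and it assumes that operations.BN_OPS registers the primitive names from BiFPN_PRIMITIVES with the (channels, stride, affine) call signature used in BiFPN_From_Genotype.__init__; the per-level op counts [2, 3, 3, 3, 3, 2] are chosen to match the sixteen convolutions unpacked there. Run from the DSFDv2_r18 directory so bi_fpn and its imports resolve.

# usage sketch (hypothetical genotype; op names taken from BiFPN_PRIMITIVES)
from collections import namedtuple
import torch
from bi_fpn import BiFPN_Neck_From_Genotype

# mirrors the commented-out namedtuple in bi_fpn.py
FPN_Genotype = namedtuple("FPN_Genotype", "Inter_Layer Out_Layer")

genotype = FPN_Genotype(
    Inter_Layer=[
        [("conv_1x1", 0), ("sep_conv_3x3", 0)],                       # level 1: td, out
        [("sep_conv_3x3", 0), ("dil_conv_3x3", 1), ("conv_1x1", 1)],  # level 2: td, du, out
        [("sep_conv_3x3", 1), ("dil_conv_3x3", 2), ("conv_1x1", 2)],  # level 3
        [("sep_conv_3x3", 2), ("dil_conv_3x3", 3), ("conv_1x1", 3)],  # level 4
        [("sep_conv_3x3", 3), ("dil_conv_5x5", 4), ("conv_1x1", 4)],  # level 5
        [("dil_conv_3x3", 4), ("conv_1x1", 5)],                       # level 6: du, out
    ],
    Out_Layer=None,
)

neck = BiFPN_Neck_From_Genotype(genotype, feature_size=256, fpn_layers=2)

# six feature maps with the grid sizes from widerface_640["feature_maps"] for a 640x640 input
sources = [torch.randn(1, 256, s, s) for s in (160, 80, 40, 20, 10, 5)]
outputs = neck(sources)
print([o.shape for o in outputs])  # expected: each level keeps its resolution and 256 channels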

DSFDv2_r18/dataset/__init__.py

Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
'''
@Author: aalenzhang
@Date: 2020-02-20 15:08:32
@LastEditors: Please set LastEditors
@LastEditTime: 2020-03-06 11:34:59
@Description:
@FilePath: \DSFDv2_r18\dataset\__init__.py
'''
from .config import *

DSFDv2_r18/dataset/config.py

Lines changed: 108 additions & 0 deletions
@@ -0,0 +1,108 @@
# config.py
import os.path

# gets home dir cross platform
HOME = os.path.expanduser("~")

# for making bounding boxes pretty
COLORS = (
    (255, 0, 0, 128),
    (0, 255, 0, 128),
    (0, 0, 255, 128),
    (0, 255, 255, 128),
    (255, 0, 255, 128),
    (255, 255, 0, 128),
)

MEANS = (104, 117, 123)

widerface_640 = {
    "num_classes": 2,
    # 'lr_steps': (80000, 100000, 120000),
    # 'max_iter': 120000,
    "feature_maps": [160, 80, 40, 20, 10, 5],
    "min_dim": 640,
    # "feature_maps": [256, 128, 64, 32, 16, 8],
    # "min_dim": 1024,
    "steps": [4, 8, 16, 32, 64, 128],  # stride per level
    "variance": [0.1, 0.2],
    "clip": True,  # clip default boxes to [0, 1]
    "name": "WIDERFace",
    "l2norm_scale": [10, 8, 5],
    "base": [64, 64, "M", 128, 128, "M", 256, 256, 256, "C", 512, 512, 512, "M", 512, 512, 512],
    "extras": [256, "S", 512, 128, "S", 256],
    "mbox": [1, 1, 1, 1, 1, 1],
    # 'mbox': [2, 2, 2, 2, 2, 2],
    # 'mbox': [4, 4, 4, 4, 4, 4],
    "min_sizes": [16, 32, 64, 128, 256, 512],
    "max_sizes": [],
    # 'max_sizes': [8, 16, 32, 64, 128, 256],
    # 'aspect_ratios': [[], [], [], [], [], []],  # [1, 2] default 1
    "aspect_ratios": [[1.5], [1.5], [1.5], [1.5], [1.5], [1.5]],  # [1, 2] default 1
    "backbone": "resnet18",  # vgg, resnet, detnet, resnet50
    # about PC-DARTS
    "edge_normalization": True,
    "groups": 4,
    "lr_steps": (50, 80, 100, 121),
    "max_epochs": 121,

    "STC_STR": False,
    "auxiliary_classify": False,
    "retrain_with_bn": True,
    "residual_learning": True,
    "syncBN": False,

    "GN": False,

    # BiFPN
    "bidirectional_feature_pyramid_network": True,
    # FPN
    "feature_pyramid_network": False,
    # whether to search the FPN
    "search_feature_pyramid_network": True,
    #
    "use_searched_feature_pyramid_network": False,
    # which layers are fed into each FPN cell
    "inter_input_nums": [2, 3, 3, 3, 3, 2],
    "out_skip_input_nums": [0, 0, 0, 0, 0, 0],
    "inter_start_layer": [0, 0, 1, 2, 3, 4],
    "out_skip_start_layer": [0, 1, 2, 3, 4, 5],
    "bottom_up_path": False,

    # CPM
    "cpm_simple": False,
    "cpm_simple_v2": True,
    "context_predict_module": False,
    "search_context_predict_module": False,

    "cross_stack": False,
    "fpn_cpm_channel": 128,  # 256
    "stack_convs": 1,  # 3

    "max_in_out": True,
    "improved_max_in_out": False,

    "FreeAnchor": False,
    "GHM": False,

    "margin_loss_type": "",  # arcface, cosface, arcface_scale, cosface_scale
    "margin_loss_s": 1,
    "margin_loss_m": 0.2,
    "focal_loss": False,
    "iou_loss": "",  # giou, diou, ciou
    "ATSS": False,
    "ATSS_topk": 9,
    "centerness": False,

    "pyramid_anchor": True,
    "refinedet": False,
    "max_out": False,
    "anchor_compensation": False,
    "data_anchor_sampling": False,
    "overlap_thresh": [0.4],
    "negpos_ratio": 3,
    # test
    "nms_thresh": 0.3,
    "conf_thresh": 0.05,  # 0.01
    "num_thresh": 2000,  # 5000
}
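
The grid sizes in "feature_maps" are tied to "min_dim" and the per-level strides in "steps"; an illustrative consistency check (run from the DSFDv2_r18 directory so dataset.config resolves):

# illustrative check: each feature-map size equals min_dim divided by its stride
from dataset.config import widerface_640 as cfg

for fmap, stride in zip(cfg["feature_maps"], cfg["steps"]):
    assert fmap == cfg["min_dim"] // stride  # 160 == 640 // 4, ..., 5 == 640 // 128
print(cfg["min_sizes"])  # one anchor scale per pyramid level: [16, 32, 64, 128, 256, 512]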

DSFDv2_r18/layers/__init__.py

Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
'''
@Author: aalenzhang
@Date: 2020-02-20 15:08:32
@LastEditors: Please set LastEditors
@LastEditTime: 2020-03-06 11:36:03
@Description:
@FilePath: \DSFDv2_r18\layers\__init__.py
'''
from .functions import *

DSFDv2_r18/layers/__init__.pyc

210 Bytes
Binary file not shown.
