
Commit 7f262e1

Author: 叶修强
Update model2.py

1 parent cea8969 · commit 7f262e1

File tree

1 file changed: +369 -0 lines changed

model2.py

Lines changed: 369 additions & 0 deletions
@@ -773,5 +773,374 @@ def forward(self, x):
        return output, p8


# HRNetV2
import numpy as np
import torch.nn.functional as F

class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=False)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)
        out = out + residual
        out = self.relu(out)
        return out

class BottleNeck(nn.Module):

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BottleNeck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)

        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)

        self.relu = nn.ReLU(inplace=False)
        self.downsample = downsample

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out = out + residual
        out = self.relu(out)
        return out

blocks_dict = {
    'BASIC': BasicBlock,
    'BOTTLENECK': BottleNeck
}

class HighResolutionModule(nn.Module):
    def __init__(self, num_branches, blocks, num_blocks, num_inchannels, num_channels, fuse_method, multi_scale_output=True):
        super(HighResolutionModule, self).__init__()
        self._check_branches(num_branches, blocks, num_blocks, num_inchannels, num_channels)

        self.num_inchannels = num_inchannels
        self.fuse_method = fuse_method
        self.num_branches = num_branches

        self.multi_scale_output = multi_scale_output

        self.branches = self._make_branches(num_branches, blocks, num_blocks, num_channels)
        self.fuse_layers = self._make_fuse_layers()
        self.relu = nn.ReLU(inplace=False)

    def _check_branches(self, num_branches, blocks, num_blocks, num_inchannels, num_channels):
        if num_branches != len(num_blocks):
            error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format(num_branches, len(num_blocks))
            raise ValueError(error_msg)

        if num_branches != len(num_channels):
            error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format(num_branches, len(num_channels))
            raise ValueError(error_msg)

        if num_branches != len(num_inchannels):
            error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format(num_branches, len(num_inchannels))
            raise ValueError(error_msg)

    def _make_one_branch(self, branch_index, block, num_blocks, num_channels, stride=1):
        downsample = None
        if stride != 1 or \
                self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.num_inchannels[branch_index], num_channels[branch_index] * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(num_channels[branch_index] * block.expansion),
            )

        layers = []
        layers.append(block(self.num_inchannels[branch_index], num_channels[branch_index], stride, downsample))
        self.num_inchannels[branch_index] = num_channels[branch_index] * block.expansion
        for i in range(1, num_blocks[branch_index]):
            layers.append(block(self.num_inchannels[branch_index], num_channels[branch_index]))

        return nn.Sequential(*layers)

    def _make_branches(self, num_branches, block, num_blocks, num_channels):
        branches = []
        for i in range(num_branches):
            branches.append(self._make_one_branch(i, block, num_blocks, num_channels))
        return nn.ModuleList(branches)

    def _make_fuse_layers(self):
        if self.num_branches == 1:
            return None

        num_branches = self.num_branches
        num_inchannels = self.num_inchannels
        fuse_layers = []
        for i in range(num_branches if self.multi_scale_output else 1):
            fuse_layer = []
            for j in range(num_branches):
                if j > i:
                    # Lower-resolution branch j: 1x1 conv to match branch i's channels;
                    # the bilinear upsampling itself happens in forward().
                    fuse_layer.append(nn.Sequential(
                        nn.Conv2d(num_inchannels[j], num_inchannels[i], 1, 1, 0, bias=False),
                        nn.BatchNorm2d(num_inchannels[i])))
                elif j == i:
                    # Same branch: identity, handled as None in forward().
                    fuse_layer.append(None)
                else:
                    # Higher-resolution branch j: downsample with (i - j) stride-2 3x3 convs;
                    # only the last conv changes the channel count.
                    conv3x3s = []
                    for k in range(i - j):
                        if k == i - j - 1:
                            num_outchannels_conv3x3 = num_inchannels[i]
                            conv3x3s.append(nn.Sequential(
                                nn.Conv2d(num_inchannels[j], num_outchannels_conv3x3, 3, 2, 1, bias=False),
                                nn.BatchNorm2d(num_outchannels_conv3x3)))
                        else:
                            num_outchannels_conv3x3 = num_inchannels[j]
                            conv3x3s.append(nn.Sequential(
                                nn.Conv2d(num_inchannels[j], num_outchannels_conv3x3, 3, 2, 1, bias=False),
                                nn.BatchNorm2d(num_outchannels_conv3x3),
                                nn.ReLU(inplace=False)))
                    fuse_layer.append(nn.Sequential(*conv3x3s))
            fuse_layers.append(nn.ModuleList(fuse_layer))

        return nn.ModuleList(fuse_layers)

    def get_num_inchannels(self):
        return self.num_inchannels

    def forward(self, x):
        if self.num_branches == 1:
            return [self.branches[0](x[0])]

        for i in range(self.num_branches):
            x[i] = self.branches[i](x[i])

        x_fuse = []
        for i in range(len(self.fuse_layers)):
            y = x[0] if i == 0 else self.fuse_layers[i][0](x[0])
            for j in range(1, self.num_branches):
                if i == j:
                    y = y + x[j]
                elif j > i:
                    width_output = x[i].shape[-1]
                    height_output = x[i].shape[-2]
                    y = y + F.interpolate(self.fuse_layers[i][j](x[j]), size=[height_output, width_output], mode='bilinear', align_corners=True)
                else:
                    y = y + self.fuse_layers[i][j](x[j])
            x_fuse.append(self.relu(y))

        return x_fuse

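# A minimal standalone check of the module above (an editor's sketch, not part of
# the commit; it assumes torch and nn are already imported near the top of model2.py):
#
#   m = HighResolutionModule(num_branches=2, blocks=BasicBlock, num_blocks=[4, 4],
#                            num_inchannels=[48, 96], num_channels=[48, 96],
#                            fuse_method='SUM')
#   ys = m([torch.randn(1, 48, 28, 28), torch.randn(1, 96, 14, 14)])
#   # ys[0]: (1, 48, 28, 28), ys[1]: (1, 96, 14, 14) -- each branch keeps its
#   # resolution and channel count; the fuse layers exchange information between them.
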
class HighResolutionNet(nn.Module):

    def __init__(self, nums_class=136):
        super(HighResolutionNet, self).__init__()

        self.num_branches1 = 2
        self.num_branches2 = 3
        self.num_branches3 = 4

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=False)

        self.layer1 = self._make_layer(BottleNeck, 64, 64, 4)
        layer1_out_channel = BottleNeck.expansion * 64

        num_channels1 = [48, 96]
        num_channels_expansion1 = [num_channels1[i] * BasicBlock.expansion for i in range(len(num_channels1))]
        self.transition1 = self._make_transition_layer([layer1_out_channel], num_channels_expansion1)
        # The positional arguments to _make_stage mirror the usual HRNet layer_config keys:
        # NUM_MODULES, NUM_BRANCHES, NUM_BLOCKS, NUM_CHANNELS, blocks_dict[BLOCK], FUSE_METHOD
        self.stage2, pre_stage_channels = self._make_stage(1, self.num_branches1, [4, 4], num_channels1, BasicBlock, 'SUM', num_channels_expansion1)

        num_channels2 = [48, 96, 192]
        num_channels_expansion2 = [num_channels2[i] * BasicBlock.expansion for i in range(len(num_channels2))]
        self.transition2 = self._make_transition_layer(pre_stage_channels, num_channels_expansion2)
        self.stage3, pre_stage_channels = self._make_stage(4, self.num_branches2, [4, 4, 4], num_channels2, BasicBlock, 'SUM', num_channels_expansion2)

        num_channels3 = [48, 96, 192, 384]
        num_channels_expansion3 = [num_channels3[i] * BasicBlock.expansion for i in range(len(num_channels3))]
        self.transition3 = self._make_transition_layer(pre_stage_channels, num_channels_expansion3)
        self.stage4, pre_stage_channels = self._make_stage(3, self.num_branches3, [4, 4, 4, 4], num_channels3, BasicBlock, 'SUM', num_channels_expansion3, multi_scale_output=True)

        # np.int was removed in NumPy 1.20+; the builtin int does the same job here.
        last_inp_channels = int(np.sum(pre_stage_channels))

        self.FINAL_CONV_KERNEL = 1
        self.last_layer = nn.Sequential(
            nn.Conv2d(in_channels=last_inp_channels, out_channels=last_inp_channels, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(last_inp_channels),
            nn.ReLU(inplace=False),
            # nn.Conv2d(in_channels=last_inp_channels, out_channels=nums_class, kernel_size=self.FINAL_CONV_KERNEL,
            #           stride=1, padding=1 if self.FINAL_CONV_KERNEL == 3 else 0)
        )

        # The classifier flattens a 28x28 fused map, so inputs must be 112x112.
        self.in_features = last_inp_channels * 28 * 28
        self.fc = nn.Linear(in_features=self.in_features, out_features=nums_class)

        self.init_weights()

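    # Stage layout built by the constructor above (modules x branches, channels):
    #   stage2: 1 module,  2 branches, channels [48, 96]
    #   stage3: 4 modules, 3 branches, channels [48, 96, 192]
    #   stage4: 3 modules, 4 branches, channels [48, 96, 192, 384]
    # After concatenation the fused map has 48 + 96 + 192 + 384 = 720 channels
    # at 1/4 of the input resolution.
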
    def _make_layer(self, block, inplanes, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(inplanes, planes, stride, downsample))
        inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(inplanes, planes))

        return nn.Sequential(*layers)

    def _make_transition_layer(self, num_channels_pre_layer, num_channels_cur_layer):
        num_branches_cur = len(num_channels_cur_layer)
        num_branches_pre = len(num_channels_pre_layer)

        transition_layers = []
        for i in range(num_branches_cur):
            if i < num_branches_pre:
                if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
                    transition_layers.append(nn.Sequential(
                        nn.Conv2d(num_channels_pre_layer[i], num_channels_cur_layer[i], 3, 1, 1, bias=False),
                        nn.BatchNorm2d(num_channels_cur_layer[i]),
                        nn.ReLU(inplace=False)))
                else:
                    transition_layers.append(None)
            else:
                # New, lower-resolution branch: derive it from the last existing branch
                # with a chain of stride-2 3x3 convs.
                conv3x3s = []
                for j in range(i + 1 - num_branches_pre):
                    inchannels = num_channels_pre_layer[-1]
                    outchannels = num_channels_cur_layer[i] \
                        if j == i - num_branches_pre else inchannels
                    conv3x3s.append(nn.Sequential(
                        nn.Conv2d(inchannels, outchannels, 3, 2, 1, bias=False),
                        nn.BatchNorm2d(outchannels),
                        nn.ReLU(inplace=False)))
                transition_layers.append(nn.Sequential(*conv3x3s))

        return nn.ModuleList(transition_layers)

    def _make_stage(self, num_modules, num_branches, num_blocks, num_channels, block, fuse_method, num_inchannels, multi_scale_output=True):
        modules = []
        for i in range(num_modules):
            # multi_scale_output only takes effect in the last module of the stage
            if not multi_scale_output and i == num_modules - 1:
                reset_multi_scale_output = False
            else:
                reset_multi_scale_output = True
            modules.append(
                HighResolutionModule(num_branches, block, num_blocks, num_inchannels, num_channels, fuse_method, reset_multi_scale_output)
            )
            num_inchannels = modules[-1].get_num_inchannels()

        return nn.Sequential(*modules), num_inchannels

    def init_weights(self):
        # Kaiming-normal init for convs, constant init for batch-norm affine
        # parameters, small normal init for the linear classifier.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        # Stem: two stride-2 convs bring the input down to 1/4 resolution.
        x = self.relu(self.bn1(self.conv1(x)))
        axn_input = self.relu(self.bn2(self.conv2(x)))
        out = self.layer1(axn_input)

        x_list = []
        for i in range(self.num_branches1):
            if self.transition1[i] is not None:
                x_list.append(self.transition1[i](out))
            else:
                x_list.append(out)
        y_list = self.stage2(x_list)

        x_list = []
        for i in range(self.num_branches2):
            if self.transition2[i] is not None:
                if i < self.num_branches1:
                    x_list.append(self.transition2[i](y_list[i]))
                else:
                    x_list.append(self.transition2[i](y_list[-1]))
            else:
                x_list.append(y_list[i])
        y_list = self.stage3(x_list)

        x_list = []
        for i in range(self.num_branches3):
            if self.transition3[i] is not None:
                if i < self.num_branches2:
                    x_list.append(self.transition3[i](y_list[i]))
                else:
                    x_list.append(self.transition3[i](y_list[-1]))
            else:
                x_list.append(y_list[i])
        out = self.stage4(x_list)

        # Upsample the three lower-resolution outputs to the highest resolution
        # and concatenate along the channel axis.
        x0_h, x0_w = out[0].size(2), out[0].size(3)
        x1 = F.interpolate(out[1], size=(x0_h, x0_w), mode='bilinear', align_corners=True)
        x2 = F.interpolate(out[2], size=(x0_h, x0_w), mode='bilinear', align_corners=True)
        x3 = F.interpolate(out[3], size=(x0_h, x0_w), mode='bilinear', align_corners=True)

        out = torch.cat([out[0], x1, x2, x3], 1)

        out = self.last_layer(out)
        out = out.view(out.size(0), -1)
        out = self.fc(out)

        return out, axn_input
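A minimal smoke test for the network added above (an editor's sketch, not part of the commit; it assumes torch and torch.nn are already imported at the top of model2.py, and the reading of nums_class=136 as 68 landmarks x 2 coordinates is an assumption):

import torch

net = HighResolutionNet(nums_class=136)
net.eval()
with torch.no_grad():
    # Input must be 112x112 so the fused map is 28x28, matching fc's in_features.
    out, axn_input = net(torch.randn(1, 3, 112, 112))
print(out.shape)        # torch.Size([1, 136])
print(axn_input.shape)  # torch.Size([1, 64, 28, 28]) -- stem features also returned by forward()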
