@@ -773,5 +773,374 @@ def forward(self, x):
        return output, p8


# HRNetV2
import numpy as np
# torch and nn are presumably imported earlier in this file; repeated here so
# the added block is self-contained.
import torch
import torch.nn as nn
import torch.nn.functional as F


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=False)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        # Project the identity branch when resolution or channel count changes.
        if self.downsample is not None:
            residual = self.downsample(x)
        out = out + residual
        out = self.relu(out)
        return out

class BottleNeck(nn.Module):

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BottleNeck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)

        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)

        self.relu = nn.ReLU(inplace=False)
        self.downsample = downsample

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out = out + residual
        out = self.relu(out)
        return out


blocks_dict = {
    'BASIC': BasicBlock,
    'BOTTLENECK': BottleNeck
}


class HighResolutionModule(nn.Module):
    def __init__(self, num_branches, blocks, num_blocks, num_inchannels, num_channels, fuse_method, multi_scale_output=True):
        super(HighResolutionModule, self).__init__()
        self._check_branches(num_branches, blocks, num_blocks, num_inchannels, num_channels)

        self.num_inchannels = num_inchannels
        self.fuse_method = fuse_method
        self.num_branches = num_branches

        self.multi_scale_output = multi_scale_output

        self.branches = self._make_branches(num_branches, blocks, num_blocks, num_channels)
        self.fuse_layers = self._make_fuse_layers()
        self.relu = nn.ReLU(inplace=False)

    def _check_branches(self, num_branches, blocks, num_blocks, num_inchannels, num_channels):
        if num_branches != len(num_blocks):
            error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format(num_branches, len(num_blocks))
            raise ValueError(error_msg)

        if num_branches != len(num_channels):
            error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format(num_branches, len(num_channels))
            raise ValueError(error_msg)

        if num_branches != len(num_inchannels):
            error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format(num_branches, len(num_inchannels))
            raise ValueError(error_msg)

    def _make_one_branch(self, branch_index, block, num_blocks, num_channels, stride=1):
        downsample = None
        if stride != 1 or \
                self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.num_inchannels[branch_index], num_channels[branch_index] * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(num_channels[branch_index] * block.expansion),
            )

        layers = []
        layers.append(block(self.num_inchannels[branch_index], num_channels[branch_index], stride, downsample))
        self.num_inchannels[branch_index] = num_channels[branch_index] * block.expansion
        for i in range(1, num_blocks[branch_index]):
            layers.append(block(self.num_inchannels[branch_index], num_channels[branch_index]))

        return nn.Sequential(*layers)

    def _make_branches(self, num_branches, block, num_blocks, num_channels):
        branches = []
        for i in range(num_branches):
            branches.append(self._make_one_branch(i, block, num_blocks, num_channels))
        return nn.ModuleList(branches)

    def _make_fuse_layers(self):
        if self.num_branches == 1:
            return None

        num_branches = self.num_branches
        num_inchannels = self.num_inchannels
        fuse_layers = []
        for i in range(num_branches if self.multi_scale_output else 1):
            fuse_layer = []
            for j in range(num_branches):
                if j > i:
                    # Lower-resolution branch j feeding higher-resolution output i:
                    # 1x1 conv to match channels; bilinear upsampling is done in forward().
                    fuse_layer.append(nn.Sequential(
                        nn.Conv2d(num_inchannels[j], num_inchannels[i], 1, 1, 0, bias=False),
                        nn.BatchNorm2d(num_inchannels[i])))
                elif j == i:
                    fuse_layer.append(None)
                else:
                    # Higher-resolution branch j feeding lower-resolution output i:
                    # a chain of stride-2 3x3 convs, one per halving of resolution.
                    conv3x3s = []
                    for k in range(i - j):
                        if k == i - j - 1:
                            num_outchannels_conv3x3 = num_inchannels[i]
                            conv3x3s.append(nn.Sequential(
                                nn.Conv2d(num_inchannels[j], num_outchannels_conv3x3, 3, 2, 1, bias=False),
                                nn.BatchNorm2d(num_outchannels_conv3x3)))
                        else:
                            num_outchannels_conv3x3 = num_inchannels[j]
                            conv3x3s.append(nn.Sequential(
                                nn.Conv2d(num_inchannels[j], num_outchannels_conv3x3, 3, 2, 1, bias=False),
                                nn.BatchNorm2d(num_outchannels_conv3x3),
                                nn.ReLU(inplace=False)))
                    fuse_layer.append(nn.Sequential(*conv3x3s))
            fuse_layers.append(nn.ModuleList(fuse_layer))

        return nn.ModuleList(fuse_layers)

    def get_num_inchannels(self):
        return self.num_inchannels

    def forward(self, x):
        if self.num_branches == 1:
            return [self.branches[0](x[0])]

        for i in range(self.num_branches):
            x[i] = self.branches[i](x[i])

        x_fuse = []
        for i in range(len(self.fuse_layers)):
            y = x[0] if i == 0 else self.fuse_layers[i][0](x[0])
            for j in range(1, self.num_branches):
                if i == j:
                    y = y + x[j]
                elif j > i:
                    width_output = x[i].shape[-1]
                    height_output = x[i].shape[-2]
                    y = y + F.interpolate(self.fuse_layers[i][j](x[j]), size=[height_output, width_output], mode='bilinear', align_corners=True)
                else:
                    y = y + self.fuse_layers[i][j](x[j])
            x_fuse.append(self.relu(y))

        return x_fuse
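
# A minimal usage sketch of HighResolutionModule; the shapes and channel
# counts below are illustrative assumptions, not values from this commit.
# The module takes one tensor per branch (each branch at half the resolution
# of the previous one) and returns fused tensors of the same shapes:
#
#   module = HighResolutionModule(
#       num_branches=2, blocks=BasicBlock, num_blocks=[4, 4],
#       num_inchannels=[48, 96], num_channels=[48, 96], fuse_method='SUM')
#   feats = [torch.randn(1, 48, 28, 28), torch.randn(1, 96, 14, 14)]
#   fused = module(feats)  # shapes: [1, 48, 28, 28] and [1, 96, 14, 14]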


class HighResolutionNet(nn.Module):

    def __init__(self, nums_class=136):
        super(HighResolutionNet, self).__init__()

        self.num_branches1 = 2
        self.num_branches2 = 3
        self.num_branches3 = 4

        # Stem: two stride-2 3x3 convs, so the first branch runs at 1/4 resolution.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=False)

        self.layer1 = self._make_layer(BottleNeck, 64, 64, 4)
        layer1_out_channel = BottleNeck.expansion * 64

        num_channels1 = [48, 96]
        num_channels_expansion1 = [num_channels1[i] * BasicBlock.expansion for i in range(len(num_channels1))]
        self.transition1 = self._make_transition_layer([layer1_out_channel], num_channels_expansion1)
        # Arguments mirror the usual HRNet config keys: NUM_MODULES, NUM_BRANCHES,
        # NUM_BLOCKS, NUM_CHANNELS, blocks_dict[BLOCK], FUSE_METHOD.
        self.stage2, pre_stage_channels = self._make_stage(1, self.num_branches1, [4, 4], num_channels1, BasicBlock, 'SUM', num_channels_expansion1)

        num_channels2 = [48, 96, 192]
        num_channels_expansion2 = [num_channels2[i] * BasicBlock.expansion for i in range(len(num_channels2))]
        self.transition2 = self._make_transition_layer(pre_stage_channels, num_channels_expansion2)
        self.stage3, pre_stage_channels = self._make_stage(4, self.num_branches2, [4, 4, 4], num_channels2, BasicBlock, 'SUM', num_channels_expansion2)

        num_channels3 = [48, 96, 192, 384]
        num_channels_expansion3 = [num_channels3[i] * BasicBlock.expansion for i in range(len(num_channels3))]
        self.transition3 = self._make_transition_layer(pre_stage_channels, num_channels_expansion3)
        self.stage4, pre_stage_channels = self._make_stage(3, self.num_branches3, [4, 4, 4, 4], num_channels3, BasicBlock, 'SUM', num_channels_expansion3, multi_scale_output=True)

        # np.int was removed in NumPy 1.24; plain int() does the same here.
        last_inp_channels = int(np.sum(pre_stage_channels))

        self.FINAL_CONV_KERNEL = 1
        self.last_layer = nn.Sequential(
            nn.Conv2d(in_channels=last_inp_channels, out_channels=last_inp_channels, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(last_inp_channels),
            nn.ReLU(inplace=False),
            # nn.Conv2d(in_channels=last_inp_channels, out_channels=nums_class, kernel_size=self.FINAL_CONV_KERNEL,
            #           stride=1, padding=1 if self.FINAL_CONV_KERNEL == 3 else 0)
        )

        # Assumes a 112x112 input, which leaves the fused feature map at 28x28.
        self.in_features = last_inp_channels * 28 * 28
        self.fc = nn.Linear(in_features=self.in_features, out_features=nums_class)

        self.init_weights()

    def _make_layer(self, block, inplanes, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(inplanes, planes, stride, downsample))
        inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(inplanes, planes))

        return nn.Sequential(*layers)

    def _make_transition_layer(self, num_channels_pre_layer, num_channels_cur_layer):
        num_branches_cur = len(num_channels_cur_layer)
        num_branches_pre = len(num_channels_pre_layer)

        transition_layers = []
        for i in range(num_branches_cur):
            if i < num_branches_pre:
                if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
                    transition_layers.append(nn.Sequential(
                        nn.Conv2d(num_channels_pre_layer[i], num_channels_cur_layer[i], 3, 1, 1, bias=False),
                        nn.BatchNorm2d(num_channels_cur_layer[i]),
                        nn.ReLU(inplace=False)))
                else:
                    transition_layers.append(None)
            else:
                # New, lower-resolution branches are spawned from the last existing
                # branch via stride-2 3x3 convs.
                conv3x3s = []
                for j in range(i + 1 - num_branches_pre):
                    inchannels = num_channels_pre_layer[-1]
                    outchannels = num_channels_cur_layer[i] \
                        if j == i - num_branches_pre else inchannels
                    conv3x3s.append(nn.Sequential(
                        nn.Conv2d(inchannels, outchannels, 3, 2, 1, bias=False),
                        nn.BatchNorm2d(outchannels),
                        nn.ReLU(inplace=False)))
                transition_layers.append(nn.Sequential(*conv3x3s))

        return nn.ModuleList(transition_layers)

    def _make_stage(self, num_modules, num_branches, num_blocks, num_channels, block, fuse_method, num_inchannels, multi_scale_output=True):
        modules = []
        for i in range(num_modules):
            # multi_scale_output is only relevant for the last module of the stage
            if not multi_scale_output and i == num_modules - 1:
                reset_multi_scale_output = False
            else:
                reset_multi_scale_output = True
            modules.append(
                HighResolutionModule(num_branches, block, num_blocks, num_inchannels, num_channels, fuse_method, reset_multi_scale_output)
            )
            num_inchannels = modules[-1].get_num_inchannels()

        return nn.Sequential(*modules), num_inchannels

    def init_weights(self):
        # print('=> init weights from normal distribution')
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.relu(self.bn1(self.conv1(x)))
        axn_input = self.relu(self.bn2(self.conv2(x)))
        out = self.layer1(axn_input)

        x_list = []
        for i in range(self.num_branches1):
            if self.transition1[i] is not None:
                x_list.append(self.transition1[i](out))
            else:
                x_list.append(out)
        y_list = self.stage2(x_list)

        x_list = []
        for i in range(self.num_branches2):
            if self.transition2[i] is not None:
                if i < self.num_branches1:
                    x_list.append(self.transition2[i](y_list[i]))
                else:
                    x_list.append(self.transition2[i](y_list[-1]))
            else:
                x_list.append(y_list[i])
        y_list = self.stage3(x_list)

        x_list = []
        for i in range(self.num_branches3):
            if self.transition3[i] is not None:
                if i < self.num_branches2:
                    x_list.append(self.transition3[i](y_list[i]))
                else:
                    x_list.append(self.transition3[i](y_list[-1]))
            else:
                x_list.append(y_list[i])
        out = self.stage4(x_list)
        # print(out[0].size(), out[1].size(), out[2].size(), out[3].size())

        # Upsampling: bring every branch to the highest resolution and concatenate
        x0_h, x0_w = out[0].size(2), out[0].size(3)
        x1 = F.interpolate(out[1], size=(x0_h, x0_w), mode='bilinear', align_corners=True)
        x2 = F.interpolate(out[2], size=(x0_h, x0_w), mode='bilinear', align_corners=True)
        x3 = F.interpolate(out[3], size=(x0_h, x0_w), mode='bilinear', align_corners=True)

        out = torch.cat([out[0], x1, x2, x3], 1)

        out = self.last_layer(out)
        out = out.view(out.size(0), -1)
        out = self.fc(out)

        # print(out.size(), axn_input.size())
        return out, axn_input
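
# A minimal smoke-test sketch (an assumption for illustration, not part of the
# original commit). It assumes a 112x112 input, which the two stride-2 stem
# convs reduce to a 28x28 highest-resolution branch, matching self.in_features.
if __name__ == '__main__':
    net = HighResolutionNet(nums_class=136)
    dummy = torch.randn(1, 3, 112, 112)
    out, stem = net(dummy)
    print(out.shape)   # expected: torch.Size([1, 136])
    print(stem.shape)  # expected: torch.Size([1, 64, 28, 28])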