Skip to content

Commit 8538145

Browse files
committed
Simplify resnet definition
1 parent 44c07f0 commit 8538145

File tree

1 file changed

+13
-127
lines changed

1 file changed

+13
-127
lines changed

lib/nets/resnet_v1.py

Lines changed: 13 additions & 127 deletions
Original file line numberDiff line numberDiff line change
@@ -19,137 +19,23 @@
1919
import math
2020
import torch.utils.model_zoo as model_zoo
2121

22+
import torchvision
23+
from torchvision.models.resnet import BasicBlock, Bottleneck
2224

23-
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
24-
'resnet152']
25-
26-
27-
model_urls = {
28-
'resnet18': 'https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth',
29-
'resnet34': 'https://s3.amazonaws.com/pytorch/models/resnet34-333f7ec4.pth',
30-
'resnet50': 'https://s3.amazonaws.com/pytorch/models/resnet50-19c8e357.pth',
31-
'resnet101': 'https://s3.amazonaws.com/pytorch/models/resnet101-5d3b4d8f.pth',
32-
'resnet152': 'https://s3.amazonaws.com/pytorch/models/resnet152-b121ed2d.pth',
33-
}
34-
35-
36-
def conv3x3(in_planes, out_planes, stride=1):
37-
"3x3 convolution with padding"
38-
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
39-
padding=1, bias=False)
40-
41-
42-
class BasicBlock(nn.Module):
43-
expansion = 1
44-
45-
def __init__(self, inplanes, planes, stride=1, downsample=None):
46-
super(BasicBlock, self).__init__()
47-
self.conv1 = conv3x3(inplanes, planes, stride)
48-
self.bn1 = nn.BatchNorm2d(planes)
49-
self.relu = nn.ReLU(inplace=True)
50-
self.conv2 = conv3x3(planes, planes)
51-
self.bn2 = nn.BatchNorm2d(planes)
52-
self.downsample = downsample
53-
self.stride = stride
54-
55-
def forward(self, x):
56-
residual = x
57-
58-
out = self.conv1(x)
59-
out = self.bn1(out)
60-
out = self.relu(out)
61-
62-
out = self.conv2(out)
63-
out = self.bn2(out)
64-
65-
if self.downsample is not None:
66-
residual = self.downsample(x)
67-
68-
out += residual
69-
out = self.relu(out)
70-
71-
return out
72-
73-
74-
class Bottleneck(nn.Module):
75-
expansion = 4
76-
77-
def __init__(self, inplanes, planes, stride=1, downsample=None):
78-
super(Bottleneck, self).__init__()
79-
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False) # change
80-
self.bn1 = nn.BatchNorm2d(planes)
81-
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, # change
82-
padding=1, bias=False)
83-
self.bn2 = nn.BatchNorm2d(planes)
84-
self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
85-
self.bn3 = nn.BatchNorm2d(planes * 4)
86-
self.relu = nn.ReLU(inplace=True)
87-
self.downsample = downsample
88-
self.stride = stride
89-
90-
def forward(self, x):
91-
residual = x
92-
93-
out = self.conv1(x)
94-
out = self.bn1(out)
95-
out = self.relu(out)
96-
97-
out = self.conv2(out)
98-
out = self.bn2(out)
99-
out = self.relu(out)
100-
101-
out = self.conv3(out)
102-
out = self.bn3(out)
103-
104-
if self.downsample is not None:
105-
residual = self.downsample(x)
106-
107-
out += residual
108-
out = self.relu(out)
109-
110-
return out
111-
112-
113-
class ResNet(nn.Module):
25+
class ResNet(torchvision.models.resnet.ResNet):
11426
def __init__(self, block, layers, num_classes=1000):
11527
self.inplanes = 64
116-
super(ResNet, self).__init__()
117-
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
118-
bias=False)
119-
self.bn1 = nn.BatchNorm2d(64)
120-
self.relu = nn.ReLU(inplace=True)
121-
# maxpool different from pytorch-resnet, to match tf-faster-rcnn
122-
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
123-
self.layer1 = self._make_layer(block, 64, layers[0])
124-
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
125-
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
28+
super(ResNet, self).__init__(block, layers, num_classes)
29+
# change to match the caffe resnet
30+
for i in range(2, 4):
31+
getattr(self, 'layer%d'%i)[0].conv1.stride = (2,2)
32+
getattr(self, 'layer%d'%i)[0].conv2.stride = (1,1)
12633
# use stride 1 for the last conv4 layer (same as tf-faster-rcnn)
127-
self.layer4 = self._make_layer(block, 512, layers[3], stride=1)
128-
129-
for m in self.modules():
130-
if isinstance(m, nn.Conv2d):
131-
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
132-
m.weight.data.normal_(0, math.sqrt(2. / n))
133-
elif isinstance(m, nn.BatchNorm2d):
134-
m.weight.data.fill_(1)
135-
m.bias.data.zero_()
136-
137-
def _make_layer(self, block, planes, blocks, stride=1):
138-
downsample = None
139-
if stride != 1 or self.inplanes != planes * block.expansion:
140-
downsample = nn.Sequential(
141-
nn.Conv2d(self.inplanes, planes * block.expansion,
142-
kernel_size=1, stride=stride, bias=False),
143-
nn.BatchNorm2d(planes * block.expansion),
144-
)
145-
146-
layers = []
147-
layers.append(block(self.inplanes, planes, stride, downsample))
148-
self.inplanes = planes * block.expansion
149-
for i in range(1, blocks):
150-
layers.append(block(self.inplanes, planes))
151-
152-
return nn.Sequential(*layers)
34+
self.layer4[0].conv2.stride = (1,1)
35+
self.layer4[0].downsample[0].stride = (1,1)
36+
37+
del self.avgpool, self.fc
38+
15339

15440
def resnet18(pretrained=False):
15541
"""Constructs a ResNet-18 model.

0 commit comments

Comments
 (0)