Skip to content

Commit 2472320

Browse files
committed
m
1 parent f7ab283 commit 2472320

File tree

2 files changed

+16
-15
lines changed

2 files changed

+16
-15
lines changed

filter_pruning.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
from pruning.methods import filter_prune
1111
from pruning.utils import to_var, train, test, prune_rate
12-
from models import LeNet5
12+
from models import ConvNet
1313

1414

1515
# Hyper Parameters
@@ -36,19 +36,19 @@
3636

3737

3838
# Load the pretrained model
39-
net = LeNet5()
40-
net.load_state_dict(torch.load('lenet5_pretrained.pkl'))
39+
net = ConvNet()
40+
# net.load_state_dict(torch.load('convnet_pretrained.pkl'))
4141
if torch.cuda.is_available():
4242
print('CUDA ensabled.')
4343
net.cuda()
4444
print("--- Pretrained network loaded ---")
45-
test(net, loader_test)
45+
# test(net, loader_test)
4646

4747
# prune the weights
48-
masks = filter_prune(net, param['pruning_perc'])
49-
net.set_masks(masks)
50-
print("--- {}% parameters pruned ---".format(param['pruning_perc']))
51-
test(net, loader_test)
48+
# masks = filter_prune(net, param['pruning_perc'])
49+
# net.set_masks(masks)
50+
# print("--- {}% parameters pruned ---".format(param['pruning_perc']))
51+
# test(net, loader_test)
5252

5353

5454
# Retraining
@@ -66,4 +66,4 @@
6666

6767

6868
# Save and load the entire model
69-
torch.save(net.state_dict(), 'lenet5_pruned.pkl')
69+
torch.save(net.state_dict(), 'convnet_pretrained.pkl')

models.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,9 @@ def set_masks(self, masks):
2727
self.linear3.set_mask(masks[2])
2828

2929

30-
class LeNet5(nn.Module):
30+
class ConvNet(nn.Module):
3131
def __init__(self):
32-
super(LeNet5, self).__init__()
32+
super(ConvNet, self).__init__()
3333

3434
self.conv1 = MaskedConv2d(1, 32, kernel_size=3, padding=1, stride=1)
3535
self.relu1 = nn.ReLU(inplace=True)
@@ -39,21 +39,22 @@ def __init__(self):
3939
self.relu2 = nn.ReLU(inplace=True)
4040
self.maxpool2 = nn.MaxPool2d(2)
4141

42-
self.linear1 = nn.Linear(7*7*64, 200)
42+
self.conv3 = MaskedConv2d(64, 64, kernel_size=3, padding=1, stride=1)
4343
self.relu3 = nn.ReLU(inplace=True)
4444

45-
self.linear2 = nn.Linear(200, 10)
45+
self.linear1 = nn.Linear(7*7*64, 10)
4646

4747
def forward(self, x):
4848
out = self.maxpool1(self.relu1(self.conv1(x)))
4949
out = self.maxpool2(self.relu2(self.conv2(out)))
50+
out = self.relu3(self.conv3(out))
5051
out = out.view(out.size(0), -1)
51-
out = self.relu3(self.linear1(out))
52-
out = self.linear2(out)
52+
out = self.linear1(out)
5353
return out
5454

5555
def set_masks(self, masks):
5656
# Should be a less manual way to set masks
5757
# Leave it for the future
5858
self.conv1.set_mask(torch.from_numpy(masks[0]))
5959
self.conv2.set_mask(torch.from_numpy(masks[1]))
60+
self.conv3.set_mask(torch.from_numpy(masks[2]))

0 commit comments

Comments
 (0)