Skip to content

Commit b3879e8

Browse files
author
Luyu Wang
committed
bug fixed
1 parent d8f7491 commit b3879e8

File tree

4 files changed

+5
-5
lines changed

4 files changed

+5
-5
lines changed

mlp_pretrained.pkl

779 KB
Binary file not shown.

models.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
class MLP(nn.Module):
66
def __init__(self):
7-
super(LeNet5, self).__init__()
7+
super(MLP, self).__init__()
88
self.linear1 = MaskedLinear(28*28, 200)
99
self.relu1 = nn.ReLU(inplace=True)
1010
self.linear2 = MaskedLinear(200, 200)
@@ -55,4 +55,4 @@ def set_masks(self, masks):
5555
# Should be a less manual way to set masks
5656
# Leave it for the future
5757
self.conv1.set_mask(masks[0])
58-
self.conv2.set_mask(masks[1])
58+
self.conv2.set_mask(masks[1])

pruning/methods.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def weight_prune(model, pruning_perc):
2525
return masks
2626

2727

28-
def filter_prune(model masks):
28+
def filter_prune(model, masks):
2929
'''
3030
Pruning a single feature map by the l1 norm of kernel weights
3131
arXiv: 1608.08710
@@ -66,4 +66,4 @@ def filter_prune(model masks):
6666
to_prune_filter_ind,
6767
to_prune_layer_ind))
6868

69-
return masks
69+
return masks

weight_pruning.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import torchvision.transforms as transforms
99

1010
from pruning.methods import weight_prune
11-
from pruning.utils import to_var, train, test, check_nonzero
11+
from pruning.utils import to_var, train, test, prune_rate
1212
from models import MLP
1313

1414

0 commit comments

Comments
 (0)