Skip to content

Commit 2aa4da0

Browse files
committed
small bugs fixed
1 parent 401e9ff commit 2aa4da0

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

LeNet5.py renamed to main.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030

3131
# Load the pretrained model
3232
net = LeNet5()
33-
net.load_state_dict(torch.load('data/lenet5_pretrained.pkl'))
33+
net.load_state_dict(torch.load('models/lenet5_pretrained.pkl'))
3434
if torch.cuda.is_available():
3535
print('CUDA ensabled.')
3636
net.cuda()
@@ -39,7 +39,7 @@
3939

4040
# prune the weights
4141
masks = class_blinded_prune(net, param)
42-
net.set_mask(masks)
42+
net.set_masks(masks)
4343

4444

4545
# Retraining

0 commit comments

Comments
 (0)