Skip to content

Commit 4d66785

Browse files
committed
filter pruning normalization
1 parent df7ae09 commit 4d66785

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

pruning/methods.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@ def prune_one_filter(model, masks):
3232
kernel weights
3333
arXiv:1611.06440
3434
'''
35-
3635
NO_MASKS = False
3736
# construct masks if there is not yet
3837
if not masks:
@@ -52,6 +51,9 @@ def prune_one_filter(model, masks):
5251
# find the scaled l2 norm for each filter this layer
5352
value_this_layer = np.square(p_np).sum(axis=1).sum(axis=1)\
5453
.sum(axis=1)/(p_np.shape[1]*p_np.shape[2]*p_np.shape[3])
54+
# normalization (important)
55+
value_this_layer = value_this_layer / \
56+
np.sqrt(np.square(value_this_layer).sum())
5557
min_value, min_ind = arg_nonzero_min(list(value_this_layer))
5658
values.append([min_value, min_ind])
5759

0 commit comments

Comments
 (0)