aspharjw/pytorch-weights_pruning

Weights Pruning PyTorch Implementation

Luyu Wang & Gavin Ding, Borealis AI


Weights pruning

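Han et al. propose compressing deep learning models via weights pruning (Han et al., NIPS 2015). This repo is a PyTorch implementation of that idea. The pruning method is replaced by the "class-blind" method described in See et al., CoNLL 2016: all weights in the network are ranked together by absolute value, and the given percentage with the smallest magnitudes is pruned, regardless of which layer (weight class) they belong to. This is much easier to implement than per-layer thresholds, and See et al. report that it performs better as well.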
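A minimal sketch of class-blind mask computation, assuming the weights to prune are passed in as a list of tensors and the pruning amount is given as a percentage. The function name `class_blind_masks` is hypothetical and not taken from this repo.

```python
import torch


def class_blind_masks(weights, prune_percent):
    """Compute binary masks for a list of weight tensors with class-blind pruning.

    All weights are ranked together by absolute value and the `prune_percent`
    percent with the smallest magnitude are masked out, regardless of which
    tensor (layer) they come from.
    """
    # Flatten and concatenate the magnitudes of every weight tensor.
    all_magnitudes = torch.cat([w.detach().abs().flatten() for w in weights])
    # Number of weights to prune across the whole network.
    k = int(all_magnitudes.numel() * prune_percent / 100)
    if k == 0:
        return [torch.ones_like(w) for w in weights]
    # The k-th smallest magnitude (1-based) is the single global threshold.
    threshold = torch.kthvalue(all_magnitudes, k).values
    # Keep weights strictly above the threshold; ties at the threshold are pruned too.
    return [(w.detach().abs() > threshold).float() for w in weights]


# Usage: one layer of small weights, one layer of large weights.
w1 = torch.tensor([[0.1, 0.2], [0.3, 0.4]])
w2 = torch.tensor([1.0, 2.0, 3.0, 4.0])
m1, m2 = class_blind_masks([w1, w2], 50)
```

Because the threshold is global, every weight of `w1` falls below it and that layer is pruned entirely, while `w2` is kept intact; a per-layer (class-uniform) scheme would instead prune 50% of each tensor.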


High-level idea

  1. We write wrappers around PyTorch's Linear and Conv2d layers.
  2. For each layer, once a binary mask tensor is computed, it is multiplied element-wise with the actual weight tensor on the forward pass.
  3. Multiplying by the mask is a differentiable operation, so the backward pass is handled by automatic differentiation (no explicit code is needed); weights whose mask entry is 0 receive zero gradient.
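The three steps above can be sketched as follows. `MaskedLinear`, `MaskedConv2d` and `set_mask` are illustrative names, not necessarily the ones used in this repo, and the convolution wrapper assumes the default `padding_mode='zeros'`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MaskedLinear(nn.Linear):
    """nn.Linear whose weight is multiplied by a binary mask on every forward pass."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Buffer: saved in state_dict and moved by .to(), but not trained.
        self.register_buffer("mask", torch.ones_like(self.weight))

    def set_mask(self, mask):
        self.mask.copy_(mask)

    def forward(self, x):
        # Autograd differentiates through weight * mask, so masked-out
        # weights get exactly zero gradient; no custom backward is needed.
        return F.linear(x, self.weight * self.mask, self.bias)


class MaskedConv2d(nn.Conv2d):
    """nn.Conv2d with the same masking scheme (assumes padding_mode='zeros')."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.register_buffer("mask", torch.ones_like(self.weight))

    def set_mask(self, mask):
        self.mask.copy_(mask)

    def forward(self, x):
        return F.conv2d(x, self.weight * self.mask, self.bias,
                        self.stride, self.padding, self.dilation, self.groups)


# Usage: prune half of a small linear layer and run a forward/backward pass.
torch.manual_seed(0)
layer = MaskedLinear(3, 2)
mask = torch.tensor([[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]])
layer.set_mask(mask)
x = torch.randn(4, 3)
out = layer(x)
out.sum().backward()
```

Registering the mask as a buffer rather than a parameter keeps it out of the optimizer while still saving it with the model. The raw weights under a zero mask are not overwritten; they are only hidden from the forward pass.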

Notes

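This implementation does not aim for computational efficiency; its goal is to make it convenient to study the properties of pruned networks. Because the masks are applied to dense weight tensors, pruned weights still take up memory and are still included in every matrix multiplication and convolution. Discussions on how to build an efficient implementation are welcome. Thanks!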
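As an example of studying a pruned network, the following sketch reports the fraction of zero weights in each `Linear` and `Conv2d` layer. The helper `weight_sparsity` and the convention that a pruned layer keeps its mask in a buffer named `mask` are assumptions made for illustration.

```python
import torch
import torch.nn as nn


def weight_sparsity(model):
    """Return {layer name: fraction of zero entries in its effective weight}.

    Hypothetical convention: a pruned layer stores its binary mask in a buffer
    named `mask`; layers without one are measured on their raw weight.
    """
    report = {}
    for name, module in model.named_modules():
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            weight = module.weight.detach()
            mask = getattr(module, "mask", None)
            effective = weight * mask if mask is not None else weight
            report[name] = (effective == 0).float().mean().item()
    return report


# Usage: zero half of the first layer directly, fully mask the last layer.
model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))
with torch.no_grad():
    model[0].weight[:2] = 0
model[2].register_buffer("mask", torch.zeros_like(model[2].weight))
report = weight_sparsity(model)
```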
