thiswinex/nn-compression-simple
nn-compression-simple

A simple NN compression tool using ADMM.

Supports weight pruning, weight quantization, and custom compression operators.

Usage

Import the compression tool you need:

from admm import YOUR_COMPRESSION_TOOL

Weight Pruning Example

Import:

from admm import ADMM_pruning

Instantiating the class:

admm = ADMM_pruning(model, update_interval=args.admm_update_interval, l1=args.admm_l1)

After loss.backward(), call:

admm.loss_update(loss)

To mask gradients while fine-tuning, use:

admm.grad_mask()

To evaluate the pruned model, call admm.apply_projW() at the beginning of each model evaluation and admm.restoreW() at the end, for example:

admm.apply_projW()
# Evaluate your model here
admm.restoreW()

When the pruning iterations are finished, or before you start fine-tuning, call admm.apply_projW() without a following admm.restoreW() so that the model stays pruned:

admm.apply_projW()

Custom compression operator

Implement a class that inherits from the ADMM class and define your compression operator in its update() function. In short, update() must project the weight parameters (or any other parameters you want to compress) into your constraint space.

For example, to do pruning and quantization at the same time, you can simply call both update functions one after the other, which projects the weights into the intersection of the two constraint spaces.
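
To illustrate what "projecting into a constraint space" can look like, here is a minimal sketch on plain Python lists. It does not use the library's update() signature, which is not shown here. The function names, the `keep` and `step` parameters, and the uniform quantization grid are all assumptions made for illustration. Chaining works in this case because the grid contains 0, so quantizing leaves pruned entries at zero.

```python
def project_sparse(weights, keep):
    """Keep the `keep` largest-magnitude entries and zero out the rest."""
    if keep <= 0:
        return [0.0] * len(weights)
    # Indices sorted by descending magnitude; the first `keep` survive.
    order = sorted(range(len(weights)), key=lambda i: abs(weights[i]), reverse=True)
    kept = set(order[:keep])
    return [w if i in kept else 0.0 for i, w in enumerate(weights)]


def project_quantized(weights, step):
    """Round each entry to the nearest multiple of `step` (the grid includes 0)."""
    return [round(w / step) * step for w in weights]


def project_sparse_quantized(weights, keep, step):
    """Prune first, then quantize; zeros stay zero because 0 is on the grid."""
    return project_quantized(project_sparse(weights, keep), step)


w = [0.9, -0.05, 0.31, -0.62, 0.02]
z = project_sparse_quantized(w, keep=3, step=0.25)
print(z)  # [1.0, 0.0, 0.25, -0.5, 0.0]
```

In a real subclass the same idea would be applied per layer to tensors inside update(). A kept weight may also round to zero, which only makes the result sparser.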

How it works
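
As a rough sketch of the general technique (the standard scaled-form ADMM splitting; this repository's exact update rules and hyperparameters may differ), the compression constraint is moved onto an auxiliary variable $Z$ that must equal the weights $W$. Here $f$ is the training loss, $S$ is the constraint set (for example, sparse or quantized weights), $\Pi_S$ is the projection onto $S$, $U$ is the scaled dual variable, and $\rho > 0$ is a penalty coefficient. All of these symbols are assumptions introduced for illustration.

```latex
\begin{aligned}
&\min_{W,\,Z}\; f(W) + \mathbb{1}_S(Z) \quad \text{subject to } W = Z, \\[4pt]
&W^{k+1} = \arg\min_{W}\; f(W) + \frac{\rho}{2}\,\bigl\lVert W - Z^{k} + U^{k} \bigr\rVert_2^2, \\
&Z^{k+1} = \Pi_S\!\bigl(W^{k+1} + U^{k}\bigr), \\
&U^{k+1} = U^{k} + W^{k+1} - Z^{k+1}.
\end{aligned}
```

In neural-network training the $W$-step is typically approximated by ordinary gradient steps on the loss plus the quadratic penalty, while the $Z$- and $U$-steps run only periodically. The $Z$-step is where the projection defined by a compression operator is applied.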
