A simple NN compression tool using ADMM.
Supports weight pruning, weight quantization, and custom compression operators.
Getting started is as simple as:

```python
from admm import YOUR_COMPRESSION_TOOL
```
Import:

```python
from admm import ADMM_pruning
```

Instantiate the class:

```python
admm = ADMM_pruning(model, update_interval=args.admm_update_interval, l1=args.admm_l1)
```
After `loss.backward()`, call:

```python
admm.loss_update(loss)
```

If you want to mask gradients while finetuning, use:

```python
admm.grad_mask()
```
Call `admm.apply_projW()` and `admm.restoreW()` at the beginning and the end of each model evaluation, respectively, to evaluate the pruned model:

```python
admm.apply_projW()
# Evaluate your model here
admm.restoreW()
```
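Since the projection must always be undone after evaluation, one convenient way to make this pattern error-safe is a small context manager. This helper is illustrative only (not part of the package); it assumes just the two documented methods, `apply_projW()` and `restoreW()`:

```python
from contextlib import contextmanager

@contextmanager
def projected(admm):
    """Apply the weight projection for the duration of the block, then
    restore the dense weights, even if evaluation raises an exception."""
    admm.apply_projW()
    try:
        yield
    finally:
        admm.restoreW()

# Usage sketch:
# with projected(admm):
#     evaluate(model)
```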
When you want to finish the pruning iterations and start finetuning, call:

```python
admm.apply_projW()
```

to prune the model permanently.
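Putting the steps above together, here is a minimal sketch of where each documented call fits in one training epoch. `RecordingADMM` is a stand-in that only logs the call order so the sketch is self-contained; in real use `admm` is the `ADMM_pruning` instance, and the elided forward/backward/optimizer lines are your usual PyTorch code:

```python
class RecordingADMM:
    """Stand-in for ADMM_pruning that records the sequence of calls."""
    def __init__(self):
        self.calls = []
    def loss_update(self, loss):
        self.calls.append("loss_update")
    def grad_mask(self):
        self.calls.append("grad_mask")
    def apply_projW(self):
        self.calls.append("apply_projW")
    def restoreW(self):
        self.calls.append("restoreW")

def train_one_epoch(admm, batches):
    for batch in batches:
        # ... forward pass, compute loss, loss.backward() ...
        admm.loss_update(loss=None)  # ADMM update right after backward()
        admm.grad_mask()             # optional: mask gradients while finetuning
        # ... optimizer.step() ...
    admm.apply_projW()               # evaluate with the projected (pruned) weights
    # ... run validation here ...
    admm.restoreW()                  # restore dense weights before the next epoch
```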
To define your own compression operator, implement a class that inherits from the ADMM class and override its update() function. In short, update() should project the weight parameters (or whatever other parameters you want to compress) onto your constraint set.

For example, if you want to do pruning and quantization at the same time, you can simply call both update functions one after the other, which projects the weights toward the intersection of the two constraint sets.
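As a sketch of what such projections can look like, the two functions below implement a magnitude-pruning projection and a uniform-quantization projection with NumPy. The function names, signatures, and default values are illustrative, not the package's API; your update() would apply logic like this to the model's weight tensors:

```python
import numpy as np

def prune_project(w, sparsity=0.5):
    """Project w onto the set of tensors with at most a (1 - sparsity)
    fraction of nonzeros, by keeping the largest-magnitude entries."""
    k = int(round(w.size * (1.0 - sparsity)))  # number of weights to keep
    if k == 0:
        return np.zeros_like(w)
    flat = np.abs(w).ravel()
    threshold = np.partition(flat, -k)[-k]     # k-th largest magnitude
    return np.where(np.abs(w) >= threshold, w, 0.0)

def quantize_project(w, bits=4):
    """Project each entry of w onto the nearest of 2**bits uniformly
    spaced levels spanning [w.min(), w.max()]."""
    lo, hi = w.min(), w.max()
    if hi == lo:
        return w.copy()
    step = (hi - lo) / (2 ** bits - 1)
    return lo + np.round((w - lo) / step) * step

# Chaining both projections, as described above, pushes the weights toward
# the intersection of the sparsity and quantization constraint sets:
w = np.random.randn(8, 8)
w_proj = quantize_project(prune_project(w, sparsity=0.75), bits=4)
```

Note that a single pass of chained projections is only an approximation of the exact projection onto the intersection; alternating the two projections over the ADMM iterations is what drives the weights into both constraint sets.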