Skip to content

Commit

Permalink
Initial Commit
Browse files Browse the repository at this point in the history
  • Loading branch information
timlee0212 committed Jun 24, 2020
1 parent c9e2e5f commit 55c4f62
Show file tree
Hide file tree
Showing 12 changed files with 2,459 additions and 1 deletion.
462 changes: 462 additions & 0 deletions PENNI_workflow.ipynb

Large diffs are not rendered by default.

20 changes: 19 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,2 +1,20 @@
# PENNI
PENNI: Pruned Kernel Sharing for Efficient CNN Inference
This repo provides the code of [PENNI: Pruned Kernel Sharing for Efficient CNN Inference](https://arxiv.org/abs/2005.07133).

Install Requirement Packages:
pip install -r requirements.txt

If you find this work helpful, please cite:
@inproceedings{li2020penni,
title={PENNI: Pruned Kernel Sharing for Efficient CNN Inference},
author={Li, Shiyu and Hanson, Edward and Li, Hai and Chen, Yiran},
booktitle={International Conference on Machine Learning},
year={2020}
}

## Acknowledgement
The resnet-56 implementation is from: [pytorch_resnet_cifar10](https://github.com/akamaster/pytorch_resnet_cifar10)

We count the FLOPs number with the modified version of [pytorch-OpCounter](https://github.com/Lyken17/pytorch-OpCounter)

The ImageNet training script is derived from [apex](https://github.com/NVIDIA/apex)
1 change: 1 addition & 0 deletions ckpt/model_checkpoint.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
https://drive.google.com/open?id=1GBB0noqDugn50GV2fx9iXuvw3KvbZAp-
73 changes: 73 additions & 0 deletions decompose/decomConv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
###################################################################
# Defined Decomposed Convolution layer
#
# For ICML 2020 Submission
# UNFINISHED RESEARCH CODE
# DO NOT DISTRIBUTE
#
# Author: XXXXXXXXXXXX
# Date: XXXXXXXXXXXX
##################################################################


import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from collections import namedtuple

class DecomposedConv2D(nn.Module):
    """Conv2d whose (out*in) 2-D kernels are linear combinations of a small
    shared basis: weight[o, i] = coefs[o*in_channels + i] @ basis.

    Only the coefficients (and bias) are trainable; the basis is frozen.
    Initialize `basis`/`coefs` from a PCA via init_decompose_with_pca().
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding=0,
                 dilation=1, device=None, bias=True, init='rand', num_basis=2,):
        super(DecomposedConv2D, self).__init__()
        if isinstance(kernel_size, int):  # allow a single int kernel size
            kernel_size = [kernel_size, kernel_size]
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.num_basis = num_basis
        self.device = device

        # `bias` may be a bool (create / skip a fresh bias) or an existing
        # tensor copied from the conv being decomposed. Testing a multi-element
        # Parameter with `if bias:` would raise, so handle tensors explicitly.
        if isinstance(bias, torch.Tensor):
            self.bias = nn.Parameter(bias.detach().clone().view(out_channels), requires_grad=True)
        elif bias:
            self.bias = nn.Parameter(torch.randn((out_channels, )), requires_grad=True)
        else:
            self.bias = None

        # basis: (num_basis, k*k), frozen; coefs: (out*in, num_basis), trained.
        self.basis = nn.Parameter(torch.randn((num_basis, kernel_size[0] * kernel_size[1])), requires_grad=False)
        self.coefs = nn.Parameter(torch.randn((out_channels * in_channels, num_basis)), requires_grad=True)

    def init_decompose_with_pca(self, basis, coefs):
        """Load PCA results (numpy arrays) as the basis/coefficient tensors.

        NOTE(review): the PCA mean (`decomposer.mean_`) is not stored, so the
        reconstructed weight omits it — confirm this is intended.
        """
        self.basis = nn.Parameter(torch.tensor(basis.reshape(self.num_basis, self.kernel_size[0] * self.kernel_size[1])), requires_grad=False)
        self.coefs = nn.Parameter(torch.tensor(coefs.reshape(self.out_channels * self.in_channels, self.num_basis)), requires_grad=True)

    def forward(self, x):
        """Reconstruct the full conv weight from coefs @ basis, then convolve."""
        true_weight = torch.mm(self.coefs, self.basis).view((self.out_channels, self.in_channels, self.kernel_size[0], self.kernel_size[1]))
        out = F.conv2d(x, true_weight, self.bias, self.stride, self.padding, self.dilation)

        return out

    def extra_repr(self):
        return 'in_channels={}, out_channels={}, num_basis = {}, bias={}'.format(
            self.in_channels, self.out_channels, self.num_basis, self.bias is not None)

    def forward_test(self, x):
        """Alternative forward: convolve with the basis kernels first, then mix
        the intermediate maps with the coefficients.

        Mathematically equivalent to forward(); kept for validation only — this
        pure-PyTorch implementation gives no speedup.
        """
        def _pair(v):
            # stride/padding/dilation may be an int or a 2-tuple
            return (v, v) if isinstance(v, int) else (v[0], v[1])

        kh, kw = self.kernel_size
        sh, sw = _pair(self.stride)
        ph, pw = _pair(self.padding)
        dh, dw = _pair(self.dilation)
        n = x.shape[0]

        # One depthwise kernel per (basis, input-channel) pair; channel layout
        # of the intermediate maps is [b * in_channels + c].
        basis_kernel = self.basis.repeat((1, self.in_channels)).view(
            (self.num_basis * self.in_channels, 1, kh, kw))

        # Standard conv output-size formula.
        h = (x.shape[2] + 2 * ph - dh * (kh - 1) - 1) // sh + 1
        w = (x.shape[3] + 2 * pw - dw * (kw - 1) - 1) // sw + 1

        # `groups` (not `num_groups`) makes this a depthwise conv; no bias here —
        # the bias (size out_channels) only applies after mixing.
        mid_fm = F.conv2d(x.repeat((1, self.num_basis, 1, 1)), basis_kernel, None,
                          (sh, sw), (ph, pw), (dh, dw),
                          groups=self.num_basis * self.in_channels)

        # Reorder coefs so column (b * in_channels + c) matches mid_fm's layout.
        mix = self.coefs.view(self.out_channels, self.in_channels, self.num_basis) \
                        .permute(0, 2, 1).reshape(self.out_channels, self.num_basis * self.in_channels)

        out = torch.matmul(mix, mid_fm.view(n, self.num_basis * self.in_channels, -1))
        if self.bias is not None:
            out = out + self.bias.view(1, self.out_channels, 1)

        return out.view(n, self.out_channels, h, w)
190 changes: 190 additions & 0 deletions decompose/params_resolver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
import numpy as np
import torch.nn as nn
import copy

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE, MDS
from decompose import decomConv

class param_resolver:
    """Collects a model's KxK (K > 1) conv weights and decomposes them with PCA.

    Keeps every qualifying conv weight as a numpy array, can normalize and
    visualize the 2-D kernels, and can swap the original nn.Conv2d modules
    for DecomposedConv2D layers initialized from a per-layer PCA.
    """

    def __init__(self, model, quant=False):
        self.model = model
        self.quant = quant          # kept for the (currently disabled) int-quantization path
        self.layer_index = {}       # layer index -> parameter name
        self.num_layers = 0
        #self.current_gpu = config.current_gpu
        self.params = []            # raw conv weights as numpy arrays
        self.params_normed = []     # optionally normalized copies
        self._store_params(model.named_parameters())

    def _store_params(self, named_parameters):
        """Cache every 4-D weight with kernel size > 1 and index it by layer."""
        layer_idx = 0
        flt_3d = 0      # total 3-D filters (sum of output channels)
        flt_2d = 0      # total 2-D kernels (sum of out*in channels)

        for name, param in named_parameters:

            if 'weight' in name and len(param.shape) == 4:
                # 1x1 convs have nothing to share across a spatial basis
                if param.shape[2] == 1:
                    print("Skipping 1x1 Conv...")
                    continue

                self.layer_index[layer_idx] = name

                param = param.cpu().detach()
                # if self.quant:
                #     param = param.astype(dtype=int) #Using int to evaluate accurately

                self.params.append(param.numpy())
                layer_idx += 1
                flt_3d += param.shape[0]
                flt_2d += param.shape[0] * param.shape[1]

                print("Layer:%s ---- \t %d x %d Filters with shape (%d, %d)"
                      % (name, param.shape[0], param.shape[1], param.shape[2], param.shape[3]))

        # layer_idx was already incremented once per stored layer, so it IS the count
        print("%d Conv Layers Loaded, have %d 3D filters and %d 2D kernels in total." % (layer_idx, flt_3d, flt_2d))
        self.num_layers = layer_idx
        # dtype=object: layers have different shapes, so this is a ragged array
        self.params_normed = np.array(self.params, dtype=object)
        #self._normilize_weight()

    #Normalized each layer
    def _normilize_weight(self, norm='l2'):
        """Normalize every 2-D kernel in place and remember the coefficients.

        norm: 'l2', 'l1' or 'l0'. Zero-norm kernels are left untouched.
        """
        self.coef = []
        self.params_normed = copy.deepcopy(self.params)
        for lidx in range(self.num_layers):
            #Skip 1x1 Conv
            if self.params_normed[lidx].shape[2] == 1:
                print("Skipping 1x1 Conv...")
                self.coef.append(1)
                continue
            #Store Normalization Coefficient
            self.coef.append(np.zeros(self.params_normed[lidx].shape[:2]))
            for i in range(self.params_normed[lidx].shape[1]):
                for o in range(self.params_normed[lidx].shape[0]):
                    kern = self.params_normed[lidx][o, i, :, :]
                    if norm == 'l2':
                        coef = np.sqrt(np.sum(kern ** 2))
                    elif norm == 'l1':
                        coef = np.sum(np.abs(kern))
                    elif norm == 'l0':
                        coef = np.sum(kern != 0)
                    else:
                        raise NotImplementedError("Not Supported Norm.")
                    #Fix Sparse Situation: an all-zero kernel keeps coef 0
                    self.params_normed[lidx][o, i, :, :] = (kern / coef) if coef != 0 else kern
                    self.coef[lidx][o, i] = coef

    def PCA_decomposing(self, basis=2, layers=None):
        """PCA-decompose the selected layers and swap in DecomposedConv2D modules.

        basis: an int (same for all layers) or a list aligned with `layers`.
        layers: iterable of layer indices; defaults to all stored layers.
        Returns the (mutated) model.
        """
        if layers is None:
            layers = np.arange(self.num_layers)

        if not isinstance(basis, list):
            basis = [basis] * len(layers)

        error_list = []

        # enumerate: `basis` is aligned with `layers`, which may be a subset,
        # so it must be indexed by position rather than by the layer id.
        for pos, lidx in enumerate(layers):

            num_basis = basis[pos]
            layer_name = self.layer_index[lidx].split('.')

            print("Decomposing Layer:", self.layer_index[lidx], " with", num_basis, "Basis Filters")
            in_channel = self.params_normed[lidx].shape[1]
            out_channel = self.params_normed[lidx].shape[0]
            filter_size = self.params_normed[lidx].shape[2]

            decomposer = PCA(n_components=num_basis)
            weight = self.params_normed[lidx].reshape(in_channel * out_channel, filter_size ** 2)
            decom_coef = decomposer.fit_transform(weight)
            decom_basis = decomposer.components_
            # NOTE(review): the PCA mean is used in the error check below but is
            # NOT transferred into DecomposedConv2D — confirm this is intended.
            decom_bias = decomposer.mean_

            c = np.matmul(decom_coef, decom_basis) + decom_bias
            c = c.reshape(out_channel, in_channel, filter_size, filter_size)

            error_list.append((c - self.params_normed[lidx]).flatten())

            # Mean per-kernel RMSE over all out*in kernels
            error = 0
            for o in range(out_channel):
                for i in range(in_channel):
                    error += np.sqrt(np.average((c[o, i, :, :] - self.params_normed[lidx][o, i, :, :]) ** 2))
            error = error / (out_channel * in_channel)
            print("Decomposing Error:", error)

            #NEW VERSION - Recursively Replace Conv Models
            parent = self.model
            for mkey in layer_name:
                n_parent = parent._modules[mkey]
                if len(n_parent._modules) == 0 and isinstance(n_parent, nn.Conv2d):  #Is a basic operation
                    print(mkey)
                    ori_conv = n_parent
                    # Pass a bool: a multi-element bias Parameter has no truth value.
                    # NOTE(review): bias values are re-initialized, not copied over.
                    parent._modules[mkey] = decomConv.DecomposedConv2D(
                        ori_conv.in_channels, ori_conv.out_channels,
                        ori_conv.kernel_size, ori_conv.stride, num_basis=num_basis,
                        padding=ori_conv.padding, dilation=ori_conv.dilation,
                        bias=ori_conv.bias is not None, device='cuda')
                    parent._modules[mkey].init_decompose_with_pca(decom_basis, decom_coef)
                    break
                else:
                    parent = n_parent

        return self.model

    def plot_params_dist(self, dim=2, method='pca', layer_idx=None):
        """Project each layer's 2-D kernels to `dim` dimensions and scatter-plot
        them (one figure per layer, saved as ./<layer>.jpg).

        method: 'pca', 'tsne' or 'mds'.
        """
        assert dim in [2, 3], "Must be 2D or 3D."
        total_dist = np.empty((0, dim))
        total_y = np.array([])
        if layer_idx is None:
            layer_idx = np.arange(self.num_layers)

        for lidx in layer_idx:
            if self.params[lidx].shape[2] == 1:
                print("Drawing scheme: Layer: %s, Skipping 1x1 Conv..." % self.layer_index[lidx])
                continue

            filters = self.params[lidx].reshape(-1, self.params[lidx].shape[2] * self.params[lidx].shape[3])
            if method == 'pca':
                decomposer = PCA(n_components=dim)
            elif method == 'tsne':
                decomposer = TSNE(n_components=dim, perplexity=10)
            elif method == 'mds':
                decomposer = MDS(n_components=dim)
            else:
                # previously fell through and crashed on an undefined `decomposer`
                raise NotImplementedError("Unsupported projection method: %s" % method)
            dist = decomposer.fit_transform(filters)
            total_dist = np.vstack((total_dist, dist))
            y = np.repeat(lidx, dist.shape[0])
            total_y = np.hstack((total_y, y))

            if dim == 2:
                plt.scatter(dist[:, 0], dist[:, 1])
            else:
                fig = plt.figure()
                ax = Axes3D(fig)
                ax.scatter(dist[:, 0], dist[:, 1], dist[:, 2])
            plt.title(self.layer_index[lidx])
            plt.savefig("./%s.jpg" % lidx)
            plt.show()

    def _cal_mse(self, basis, ori_filters, coefs):
        """Return RMS errors between `basis` and each original filter.

        Returns (error, error_elt, error_ori, error_elt_ori): scalar and
        element-wise RMSE, unweighted and weighted by `coefs`.
        """
        error = 0
        error_ori = 0
        error_elt = np.zeros(basis.shape[1:])
        error_elt_ori = np.zeros(basis.shape)
        for idx in range(ori_filters.shape[0]):
            error_item = np.square(np.abs(ori_filters[idx] - basis))
            error_item_ori = np.square(np.abs(ori_filters[idx] * coefs[idx] - basis * coefs[idx]))
            error_elt += error_item
            error += np.sum(error_item)
            error_elt_ori += error_item_ori
            error_ori += np.sum(error_item_ori)

        error_elt = np.sqrt(error_elt / ori_filters.shape[0])
        error = np.sqrt(error / ori_filters.shape[0])

        error_elt_ori = np.sqrt(error_elt_ori / ori_filters.shape[0])
        error_ori = np.sqrt(error_ori / ori_filters.shape[0])

        return error, error_elt, error_ori, error_elt_ori


Loading

0 comments on commit 55c4f62

Please sign in to comment.