-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcount_operations.py
122 lines (103 loc) · 3.62 KB
/
count_operations.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import torch
import torchvision.models as models
import torch.nn as nn
import pretrainedmodels
from torchsummary import summary
from torch.autograd import Variable
from collections import OrderedDict
import pandas as pd
import networkx as nx
import matplotlib.pyplot as plt
from networkx.drawing.nx_agraph import graphviz_layout, to_agraph
import pygraphviz as pgv
import json
import copy
import numpy as np
import pickle
import os
from modelsummary import summaryX
from utils import Utils
from calc_tam_layer import *
import data_loader_cifar10
from operations import * # Pytorch operations (layers)
import time
from simplecnn import SimpleCNN
# Target device for every model: the GPU when PyTorch can see one, else the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
def generateSummary(model, img_shape=(3, 244, 244), automatedModel=True, input_image=0):
    """Return the summary that ``summaryX`` builds for *model*, or None on failure.

    The arguments are forwarded positionally to
    ``summaryX(img_shape, model, automatedModel, input_image)``.
    Any exception raised while summarising is printed and swallowed, so callers
    can simply skip models whose summary is None.

    NOTE(review): the 244 in the default shape may be a typo for the common
    224 input resolution -- confirm before relying on it.
    """
    try:
        summary_result = summaryX(img_shape, model, automatedModel, input_image)
    except Exception as err:
        # Best effort: report and signal failure through a None result.
        print(f'Exception generating model summary: {err}')
        summary_result = None
    return summary_result
def generateModels(img_shape=(3, 244, 244)):
    """Instantiate each candidate torchvision model and collect its summary.

    Candidates are the lowercase, non-dunder callables in ``models.__dict__``.
    Each one is called with no arguments and moved to the module-level
    ``device``.

    A candidate whose construction raises is reported and skipped instead of
    aborting the whole run. Recent torchvision releases also expose
    lowercase helper functions (e.g. ``get_model``, ``list_models``) in that
    namespace, and these are not model constructors.

    Args:
        img_shape: input shape forwarded to ``generateSummary``. The default
            is the value the script originally hard-coded.
            NOTE(review): 244 may be a typo for 224 -- confirm.

    Returns:
        dict mapping model name -> ``OrderedDict`` with a single ``'summary'``
        key holding the object returned by ``generateSummary``. Models whose
        summary came back as None are omitted.
    """
    model_names = sorted(
        name for name, obj in models.__dict__.items()
        if name.islower() and not name.startswith("__") and callable(obj)
    )
    torchModels = {}
    for modelName in model_names:
        print(f"\n---- {modelName} [...] ----")
        try:
            model = models.__dict__[modelName]().to(device)
        except Exception as e:
            # Not buildable with no arguments (or not a model at all): skip it.
            print(f"Skipping {modelName}: could not instantiate it ({e})")
            continue
        try:
            print("[ Generating Summary ... ]")
            # Generate the state-transition summary of the model.
            graph = generateSummary(model, img_shape=img_shape, automatedModel=False)
        finally:
            # Free the model on every path, including a failed summary.
            # Otherwise it would still be alive when the next model is allocated.
            del model
            torch.cuda.empty_cache()
        print("[ Storing Summary ... ]")
        if graph is not None:
            torchModels[modelName] = OrderedDict(summary=graph)
    return torchModels
def getParamDependingLayer(layerName):
    """Return the constructor argument that distinguishes variants of a layer type.

    Matching is a case-insensitive substring test on *layerName*, applied in
    this order:

    * contains "conv"                  -> "kernel_size"
    * contains "adaptive"              -> "output_size"
    * contains "pool" (non-adaptive)   -> "kernel_size"
    * contains "dropout"               -> "p"
    * contains "linear"                -> "out_features"

    Any other name yields None.
    """
    name = layerName.lower()
    # "conv" takes precedence over every other rule.
    if "conv" in name:
        return "kernel_size"
    # Must run before the generic "pool" rule, so that adaptive pooling layers
    # map to output_size rather than kernel_size.
    if "adaptive" in name:
        return "output_size"
    if "pool" in name:
        return "kernel_size"
    if "dropout" in name:
        return "p"
    if "linear" in name:
        return "out_features"
    return None
# Count operations: for every layer type seen across the generated models,
# collect the distinct values of the parameter that distinguishes its variants.


def _hashable(value):
    """Return *value* in a hashable form (lists become tuples, recursively)."""
    # Summary entries may hold list-valued parameters (e.g. a kernel size).
    # Lists cannot go into a set, so normalise them to tuples.
    if isinstance(value, list):
        return tuple(_hashable(v) for v in value)
    return value


def _collect_operations(population):
    """Map each layer type to the set of distinct values of its defining parameter.

    *population* is the dict returned by ``generateModels``. Layer keys are
    presumably of the form "<LayerType>-<index>"; everything after the first
    "-" is dropped. Layer types with no defining parameter still get an
    (empty) entry.
    """
    operations = {}
    for entry in population.values():
        for layerKey, values in entry['summary'].items():
            layerName = layerKey.split("-")[0]
            param = getParamDependingLayer(layerName)
            variants = operations.setdefault(layerName, set())
            # NOTE(review): assumes each summary entry is a dict-like mapping.
            # Entries missing the parameter are skipped rather than raising KeyError.
            if param is not None and param in values:
                variants.add(_hashable(values[param]))
    return operations


def _count_operations(operations):
    """Count operations: each distinct variant counts once; parameterless layers (relu, etc.) count as one."""
    return sum(max(len(variants), 1) for variants in operations.values())


def main():
    """Generate all model summaries, then print the distinct operations and their count."""
    population = generateModels()
    operations = _collect_operations(population)
    count = _count_operations(operations)
    print(operations)
    print(f'{count} operations!')


if __name__ == "__main__":
    main()