Make profiling return pre- and post-optimization plots (up to 4 plots in total) #323

Merged: 11 commits, Apr 23, 2021
165 changes: 142 additions & 23 deletions hls4ml/model/profiling.py
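
For orientation before the diff: with this change, numerical() returns four figures instead of two. A minimal sketch of consuming the new return value, assuming a trained Keras model keras_model, its converted HLSModel hls_model, and test data X_test (all illustrative names), and assuming the returned objects are matplotlib figures as the docstring's "figures" suggests:

    from hls4ml.model.profiling import numerical

    # Four figures: weights and activations, each before and after
    # the hls4ml optimizers are applied to the HLSModel.
    wp, wph, ap, aph = numerical(model=keras_model, hls_model=hls_model, X=X_test)

    for fig, name in zip((wp, wph, ap, aph),
                         ('weights_before', 'weights_after',
                          'activations_before', 'activations_after')):
        if fig is not None:  # entries can be None if profiling was skipped
            fig.savefig('{}.png'.format(name))  # illustrative filenames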
@@ -4,8 +4,13 @@
import numpy as np
import pandas
import seaborn as sb
import uuid
import os
import shutil
from collections import defaultdict

from hls4ml.model.hls_model import HLSModel
from hls4ml.converters import convert_from_config

try:
from tensorflow import keras
@@ -21,6 +26,19 @@
__torch_profiling_enabled__ = False


def get_unoptimized_hlsmodel(model):
new_config = model.config.config.copy()
new_output_dir = uuid.uuid4().hex

while os.path.exists(new_output_dir):
new_output_dir = uuid.uuid4().hex

new_config['HLSConfig']['Optimizers'] = []
new_config['OutputDir'] = new_output_dir

return convert_from_config(new_config), new_output_dir
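
A note on the helper above: the fresh uuid4 directory keeps the throwaway unoptimized copy from clobbering the original project's output directory, and numerical() below removes it with shutil.rmtree once profiling is done. A sketch of that round trip (the try/finally is a suggested hardening, not part of this diff):

    unopt_model, tmp_dir = get_unoptimized_hlsmodel(hls_model)
    try:
        data = weights_hlsmodel(unopt_model, fmt='summary', plot='boxplot')
    finally:
        if os.path.exists(tmp_dir):
            shutil.rmtree(tmp_dir)  # remove the temporary project directory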


def array_to_summary(x, fmt='boxplot'):
if fmt == 'boxplot':
y = {'med' : np.median(x),
@@ -186,6 +204,7 @@ def weights_hlsmodel(model, fmt='longform', plot='boxplot'):
data = {'x' : [], 'layer' : [], 'weight' : []}
elif fmt == 'summary':
data = []

for layer in model.get_layers():
name = layer.name
for iw, weight in enumerate(layer.get_weights()):
@@ -208,15 +227,66 @@ def weights_hlsmodel(model, fmt='longform', plot='boxplot'):
data = pandas.DataFrame(data)
return data


def _keras_batchnorm(layer):
weights = layer.get_weights()
epsilon = layer.epsilon

gamma = weights[0]
beta = weights[1]
mean = weights[2]
var = weights[3]

scale = gamma / np.sqrt(var + epsilon)
bias = beta - gamma * mean / np.sqrt(var + epsilon)

return [scale, bias], ['s', 'b']
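
The scale and bias computed above are the standard batch-normalization folding. Writing the layer's mean as \mu and its variance as \sigma^2 (the code's mean and var):

    y = \gamma \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta
      = s\,x + b,
    \qquad s = \frac{\gamma}{\sqrt{\sigma^2 + \epsilon}},
    \qquad b = \beta - \frac{\gamma \mu}{\sqrt{\sigma^2 + \epsilon}}

so profiling s and b shows the weights the HLS implementation actually multiplies and adds, rather than the four raw parameters.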


def _keras_layer(layer):
return layer.get_weights(), ['w', 'b']


keras_process_layer_map = defaultdict(lambda: _keras_layer,
{
'BatchNormalization': _keras_batchnorm,
'QBatchNormalization': _keras_batchnorm
})
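
Because keras_process_layer_map is a defaultdict, any layer class not listed falls back to _keras_layer; only the two batch-normalization classes get the folding treatment. An illustration of the dispatch, given some Keras layer object layer (the 'Dense' key is just an example of an unlisted class):

    fn = keras_process_layer_map['BatchNormalization']  # -> _keras_batchnorm
    fn = keras_process_layer_map['Dense']               # -> _keras_layer (default)
    weights, suffixes = fn(layer)  # e.g. ([scale, bias], ['s', 'b']) for batchnorm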


def activations_hlsmodel(model, X, fmt='summary', plot='boxplot'):
if fmt == 'longform':
raise NotImplementedError
elif fmt == 'summary':
data = []

_, trace = model.trace(np.ascontiguousarray(X))

if len(trace) == 0:
raise RuntimeError("HLSModel must have tracing on for at least 1 layer (this can be set in its config)")

for layer in trace.keys():
print(" {}".format(layer))

if fmt == 'summary':
y = trace[layer].flatten()
y = abs(y[y != 0])

data.append(array_to_summary(y, fmt=plot))
data[-1]['weight'] = layer

return data
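
Since activations_hlsmodel() raises a RuntimeError when no layer is traced, tracing has to be enabled before the HLSModel is built. A hedged sketch of doing that through the hls4ml config dictionary (the LayerName/Trace keys follow the usual hls4ml config layout; treat the exact keys as an assumption for your version):

    import hls4ml

    config = hls4ml.utils.config_from_keras_model(model, granularity='name')
    for layer_name in config['LayerName']:
        config['LayerName'][layer_name]['Trace'] = True  # trace every layer

    hls_model = hls4ml.converters.convert_from_keras_model(model, hls_config=config)
    data = activations_hlsmodel(hls_model, X_test, fmt='summary', plot='boxplot')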


def weights_keras(model, fmt='longform', plot='boxplot'):
suffix = ['w', 'b']
if fmt == 'longform':
data = {'x' : [], 'layer' : [], 'weight' : []}
elif fmt == 'summary':
data = []
for layer in model.layers:
name = layer.name
weights = layer.get_weights()
weights, suffix = keras_process_layer_map[type(layer).__name__](layer)

for i, w in enumerate(weights):
l = '{}/{}'.format(name, suffix[i])
w = w.flatten()
@@ -346,52 +416,101 @@ def numerical(model=None, hls_model=None, X=None, plot='boxplot'):
Returns
-------
tuple
The pair of produced figures. First weights and biases,
then activations
The quadruple of produced figures. First weights and biases
for the pre- and post-optimization models respectively,
then activations for the pre- and post-optimization models
respectively. (Optimizations are applied to an HLSModel by hls4ml;
the post-optimization HLSModel is the final model.)
"""
wp, ap = None, None
wp, wph, ap, aph = None, None, None, None

hls_model_present = hls_model is not None and isinstance(hls_model, HLSModel)
model_present = model is not None

if hls_model_present:
before = " (before optimization)"
after = " (final / after optimization)"
hls_model_unoptimized, tmp_output_dir = get_unoptimized_hlsmodel(hls_model)
else:
before = ""
after = ""
hls_model_unoptimized, tmp_output_dir = None, None

print("Profiling weights")
print("Profiling weights" + before)
data = None
if hls_model is not None and isinstance(hls_model, HLSModel):
data = weights_hlsmodel(hls_model, fmt='summary', plot=plot)
elif model is not None:

if hls_model_present:
data = weights_hlsmodel(hls_model_unoptimized, fmt='summary', plot=plot)
elif model_present:
if __tf_profiling_enabled__ and isinstance(model, keras.Model):
data = weights_keras(model, fmt='summary', plot=plot)
elif __torch_profiling_enabled__ and \
isinstance(model, torch.nn.Sequential):
data = weights_torch(model, fmt='summary', plot=plot)

if data is None:
print("Only keras, PyTorch (Sequential) and HLSModel models " +
"can currently be profiled")
return wp, ap

if hls_model_present and os.path.exists(tmp_output_dir):
shutil.rmtree(tmp_output_dir)

return wp, wph, ap, aph

wp = plots[plot](data, fmt='summary') # weight plot
if isinstance(hls_model, HLSModel) and plot in types_plots:
t_data = types_hlsmodel(hls_model)

if hls_model_present and plot in types_plots:
t_data = types_hlsmodel(hls_model_unoptimized)
types_plots[plot](t_data, fmt='summary')

plt.title("Distribution of (non-zero) weights")
plt.title("Distribution of (non-zero) weights" + before)
plt.tight_layout()

print("Profiling activations")
data = None
if hls_model_present:
print("Profiling weights" + after)

data = weights_hlsmodel(hls_model, fmt='summary', plot=plot)
wph = plots[plot](data, fmt='summary') # weight plot

if plot in types_plots:
t_data = types_hlsmodel(hls_model)
types_plots[plot](t_data, fmt='summary')

plt.title("Distribution of (non-zero) weights" + after)
plt.tight_layout()

if X is not None:
print("Profiling activations" + before)
data = None
if __tf_profiling_enabled__ and isinstance(model, keras.Model):
data = activations_keras(model, X, fmt='summary', plot=plot)
elif __torch_profiling_enabled__ and \
isinstance(model, torch.nn.Sequential):
data = activations_torch(model, X, fmt='summary', plot=plot)
if data is not None:
ap = plots[plot](data, fmt='summary') # activation plot
plt.title("Distribution of (non-zero) activations")
plt.tight_layout()

if X is not None and isinstance(hls_model, HLSModel):
t_data = activation_types_hlsmodel(hls_model)
types_plots[plot](t_data, fmt='summary')
if data is not None:
ap = plots[plot](data, fmt='summary') # activation plot
if hls_model_present and plot in types_plots:
t_data = activation_types_hlsmodel(hls_model_unoptimized)
types_plots[plot](t_data, fmt='summary')
plt.title("Distribution of (non-zero) activations" + before)
plt.tight_layout()

if hls_model_present:
print("Profiling activations" + after)
data = activations_hlsmodel(hls_model, X, fmt='summary', plot=plot)
aph = plots[plot](data, fmt='summary')

t_data = activation_types_hlsmodel(hls_model)
types_plots[plot](t_data, fmt='summary')

plt.title("Distribution of (non-zero) activations (final / after optimization)")
plt.tight_layout()

if hls_model_present and os.path.exists(tmp_output_dir):
shutil.rmtree(tmp_output_dir)

return wp, ap
return wp, wph, ap, aph
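
Because the return value grows from a pair to a quadruple, existing callers that unpack two values will now raise a ValueError. A defensive pattern for code that must run against both old and new versions (illustrative, not part of this PR):

    result = numerical(model=keras_model, hls_model=hls_model, X=X_test)
    if len(result) == 4:
        wp, wph, ap, aph = result  # new: pre- and post-optimization figures
    else:
        wp, ap = result            # old: a single pair of figures
        wph, aph = None, None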


########COMPARE OUTPUT IMPLEMENTATION########