From d508a63465b57f5b12ab2674a91ebc717ebfc3d8 Mon Sep 17 00:00:00 2001 From: Maksymilian Graczyk Date: Wed, 7 Apr 2021 18:03:37 +0200 Subject: [PATCH 01/11] Produce separate profiling plots for original and HLS models except for activations --- hls4ml/model/profiling.py | 43 ++++++++++++++++++++++++--------------- 1 file changed, 27 insertions(+), 16 deletions(-) diff --git a/hls4ml/model/profiling.py b/hls4ml/model/profiling.py index 1cb5888ed..b8148d49f 100644 --- a/hls4ml/model/profiling.py +++ b/hls4ml/model/profiling.py @@ -349,32 +349,43 @@ def numerical(model=None, hls_model=None, X=None, plot='boxplot'): The pair of produced figures. First weights and biases, then activations """ - wp, ap = None, None + wp, wph, ap = None, None, None - print("Profiling weights") + print("Profiling weights (the original model)") data = None - if hls_model is not None and isinstance(hls_model, HLSModel): - data = weights_hlsmodel(hls_model, fmt='summary', plot=plot) - elif model is not None: + + if model is not None: if __tf_profiling_enabled__ and isinstance(model, keras.Model): data = weights_keras(model, fmt='summary', plot=plot) elif __torch_profiling_enabled__ and \ isinstance(model, torch.nn.Sequential): data = weights_torch(model, fmt='summary', plot=plot) + if data is None: print("Only keras, PyTorch (Sequential) and HLSModel models " + "can currently be profiled") - return wp, ap + return wp, wph, ap wp = plots[plot](data, fmt='summary') # weight plot - if isinstance(hls_model, HLSModel) and plot in types_plots: - t_data = types_hlsmodel(hls_model) - types_plots[plot](t_data, fmt='summary') - plt.title("Distribution of (non-zero) weights") + plt.title("Distribution of (non-zero) weights (the original model)") plt.tight_layout() - print("Profiling activations") + data = None + if hls_model is not None and isinstance(hls_model, HLSModel): + data = weights_hlsmodel(hls_model, fmt='summary', plot=plot) + + if data is not None: + print("Profiling weights (the HLS model)") + wph = plots[plot](data, fmt='summary') # weight plot + if isinstance(hls_model, HLSModel) and plot in types_plots: + t_data = types_hlsmodel(hls_model) + types_plots[plot](t_data, fmt='summary') + + plt.title("Distribution of (non-zero) weights (the HLS model)") + plt.tight_layout() + + print("Profiling activations (the original model)") data = None if X is not None: if __tf_profiling_enabled__ and isinstance(model, keras.Model): @@ -384,14 +395,14 @@ def numerical(model=None, hls_model=None, X=None, plot='boxplot'): data = activations_torch(model, X, fmt='summary', plot=plot) if data is not None: ap = plots[plot](data, fmt='summary') # activation plot - plt.title("Distribution of (non-zero) activations") + plt.title("Distribution of (non-zero) activations (the original model)") plt.tight_layout() - if X is not None and isinstance(hls_model, HLSModel): - t_data = activation_types_hlsmodel(hls_model) - types_plots[plot](t_data, fmt='summary') + # if X is not None and isinstance(hls_model, HLSModel): + # t_data = activation_types_hlsmodel(hls_model) + # types_plots[plot](t_data, fmt='summary') - return wp, ap + return wp, wph, ap ########COMPARE OUTPUT IMPLEMENTATION######## From a69f9b8ecac93cd6037efd7c2365c28c52a575f4 Mon Sep 17 00:00:00 2001 From: Maksymilian Graczyk Date: Thu, 8 Apr 2021 17:35:38 +0200 Subject: [PATCH 02/11] Add activations_hlsmodel() to profiling.py --- hls4ml/model/profiling.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/hls4ml/model/profiling.py b/hls4ml/model/profiling.py index b8148d49f..94ea35aab 100644 --- a/hls4ml/model/profiling.py +++ b/hls4ml/model/profiling.py @@ -208,6 +208,31 @@ def weights_hlsmodel(model, fmt='longform', plot='boxplot'): data = pandas.DataFrame(data) return data + +def activations_hlsmodel(model, X, fmt='summary', plot='boxplot'): + if fmt == 'longform': + raise NotImplemented + elif fmt == 'summary': + data = [] + + _, trace = model.trace(np.ascontiguousarray(X)) + + if len(trace) == 0: + raise RuntimeError("HLSModel must have tracing on for at least 1 layer (this can be set in its config)") + + for layer in trace.keys(): + print(" {}".format(layer)) + + if fmt == 'summary': + y = trace[layer].flatten() + y = abs(y[y != 0]) + + data.append(array_to_summary(y, fmt=plot)) + data[-1]['weight'] = layer + + return data + + def weights_keras(model, fmt='longform', plot='boxplot'): suffix = ['w', 'b'] if fmt == 'longform': From ab838ad911a8d40989e5dbb5e43c6e8bf865aece Mon Sep 17 00:00:00 2001 From: Maksymilian Graczyk Date: Thu, 8 Apr 2021 17:39:07 +0200 Subject: [PATCH 03/11] Add producing a HLS model activations plot to profiling --- hls4ml/model/profiling.py | 29 ++++++++++++++++++++--------- 1 file changed, 20 insertions(+), 9 deletions(-) diff --git a/hls4ml/model/profiling.py b/hls4ml/model/profiling.py index 94ea35aab..f4758922a 100644 --- a/hls4ml/model/profiling.py +++ b/hls4ml/model/profiling.py @@ -371,10 +371,12 @@ def numerical(model=None, hls_model=None, X=None, plot='boxplot'): Returns ------- tuple - The pair of produced figures. First weights and biases, - then activations + The quadruple of produced figures. First weights and biases + for the original model and HLSModel respectively, + then activations for the original model and HLSModel + respectively. """ - wp, wph, ap = None, None, None + wp, wph, ap, aph = None, None, None, None print("Profiling weights (the original model)") data = None @@ -389,19 +391,19 @@ def numerical(model=None, hls_model=None, X=None, plot='boxplot'): if data is None: print("Only keras, PyTorch (Sequential) and HLSModel models " + "can currently be profiled") - return wp, wph, ap + return wp, wph, ap, aph wp = plots[plot](data, fmt='summary') # weight plot plt.title("Distribution of (non-zero) weights (the original model)") plt.tight_layout() + print("Profiling weights (the HLS model)") data = None if hls_model is not None and isinstance(hls_model, HLSModel): data = weights_hlsmodel(hls_model, fmt='summary', plot=plot) if data is not None: - print("Profiling weights (the HLS model)") wph = plots[plot](data, fmt='summary') # weight plot if isinstance(hls_model, HLSModel) and plot in types_plots: t_data = types_hlsmodel(hls_model) @@ -423,11 +425,20 @@ def numerical(model=None, hls_model=None, X=None, plot='boxplot'): plt.title("Distribution of (non-zero) activations (the original model)") plt.tight_layout() - # if X is not None and isinstance(hls_model, HLSModel): - # t_data = activation_types_hlsmodel(hls_model) - # types_plots[plot](t_data, fmt='summary') + print("Profiling activations (the HLS model)") + data = None + if X is not None and hls_model is not None and isinstance(hls_model, HLSModel): + data = activations_hlsmodel(hls_model, X, fmt='summary', plot=plot) + + if data is not None: + aph = plots[plot](data, fmt='summary') + if X is not None and isinstance(hls_model, HLSModel): + t_data = activation_types_hlsmodel(hls_model) + types_plots[plot](t_data, fmt='summary') + plt.title("Distribution of (non-zero) activations (the HLS model)") + plt.tight_layout() - return wp, wph, ap + return wp, wph, ap, aph ########COMPARE OUTPUT IMPLEMENTATION######## From 51286d46ebe861576b1067d68cd11d9e36f1bdc5 Mon Sep 17 00:00:00 2001 From: Maksymilian Graczyk Date: Tue, 13 Apr 2021 14:20:58 +0200 Subject: [PATCH 04/11] Make weights_* print different suffices based on weight types in a layer --- hls4ml/model/profiling.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/hls4ml/model/profiling.py b/hls4ml/model/profiling.py index f4758922a..4915d597c 100644 --- a/hls4ml/model/profiling.py +++ b/hls4ml/model/profiling.py @@ -243,7 +243,12 @@ def weights_keras(model, fmt='longform', plot='boxplot'): name = layer.name weights = layer.get_weights() for i, w in enumerate(weights): - l = '{}/{}'.format(name, suffix[i]) + if len(weights) != 2: + suf = i + else: + suf = suffix[i] + + l = '{}/{}'.format(name, suf) w = w.flatten() w = abs(w[w != 0]) n = len(w) @@ -301,7 +306,12 @@ def weights_torch(model, fmt='longform', plot='boxplot'): name = layer.__class__.__name__ weights = list(layer.parameters()) for i, w in enumerate(weights): - l = '{}/{}'.format(name, suffix[i]) + if len(weights) != 2: + suf = i + else: + suf = suffix[i] + + l = '{}/{}'.format(name, suf) w = weights[i].detach().numpy() w = w.flatten() w = abs(w[w != 0]) From d71f98942389ef6a3d9a895f8e282adb40937346 Mon Sep 17 00:00:00 2001 From: Maksymilian Graczyk Date: Fri, 16 Apr 2021 16:39:53 +0200 Subject: [PATCH 05/11] Add pre-optimization graph copy to HLSModel --- hls4ml/model/hls_model.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/hls4ml/model/hls_model.py b/hls4ml/model/hls_model.py index 5e65ff44e..fdf14a078 100644 --- a/hls4ml/model/hls_model.py +++ b/hls4ml/model/hls_model.py @@ -263,6 +263,7 @@ def __init__(self, config, data_reader, layer_list, inputs=None, outputs=None): self._top_function_lib = None self._make_graph(layer_list) + self.original_graph = self.graph.copy() self._optimize_model(self.config.optimizers) @@ -418,8 +419,11 @@ def next_layer(self): self.index += 1 return self.index - def get_layers(self): - return self.graph.values() + def get_layers(self, before_optimization=False): + if before_optimization: + return self.original_graph.values() + else: + return self.graph.values() def get_input_variables(self): variables = [] From 33ebd758dcf9eeebfd3444d83370e83878b75f0e Mon Sep 17 00:00:00 2001 From: Maksymilian Graczyk Date: Fri, 16 Apr 2021 17:25:28 +0200 Subject: [PATCH 06/11] Make profiling return pre- and post-optimization plots --- hls4ml/model/profiling.py | 87 ++++++++++++++++++++++++--------------- 1 file changed, 53 insertions(+), 34 deletions(-) diff --git a/hls4ml/model/profiling.py b/hls4ml/model/profiling.py index 4915d597c..251f6f8a6 100644 --- a/hls4ml/model/profiling.py +++ b/hls4ml/model/profiling.py @@ -140,7 +140,7 @@ def ap_fixed_WIF(dtype): W, I, F = dtype.width, dtype.integer, dtype.fractional return W, I, F -def types_hlsmodel(model): +def types_hlsmodel(model, before_optimization=False): suffix = ['w', 'b'] data = {'layer' : [], 'low' : [], 'high' : []} # Plot the default precision @@ -151,7 +151,7 @@ def types_hlsmodel(model): data['low'].append(-F) data['high'].append(I-1) - for layer in model.get_layers(): + for layer in model.get_layers(before_optimization=before_optimization): for iw, weight in enumerate(layer.get_weights()): wname = '{}/{}'.format(layer.name, suffix[iw]) T = weight.type @@ -163,7 +163,7 @@ def types_hlsmodel(model): data = pandas.DataFrame(data) return data -def activation_types_hlsmodel(model): +def activation_types_hlsmodel(model, before_optimization=False): data = {'layer' : [], 'low' : [], 'high' : []} # Get the default precision default_precision = model.config.model_precision['default'] @@ -171,7 +171,7 @@ def activation_types_hlsmodel(model): data['layer'].append('model') data['low'].append(-F) data['high'].append(I-1) - for layer in model.get_layers(): + for layer in model.get_layers(before_optimization=before_optimization): T = layer.get_output_variable().type.precision W, I, F = ap_fixed_WIF(T) data['layer'].append(layer.name) @@ -180,13 +180,14 @@ def activation_types_hlsmodel(model): data = pandas.DataFrame(data) return data -def weights_hlsmodel(model, fmt='longform', plot='boxplot'): +def weights_hlsmodel(model, before_optimization=False, fmt='longform', plot='boxplot'): suffix = ['w', 'b'] if fmt == 'longform': data = {'x' : [], 'layer' : [], 'weight' : []} elif fmt == 'summary': data = [] - for layer in model.get_layers(): + + for layer in model.get_layers(before_optimization=before_optimization): name = layer.name for iw, weight in enumerate(layer.get_weights()): l = '{}/{}'.format(name, suffix[iw]) @@ -382,16 +383,29 @@ def numerical(model=None, hls_model=None, X=None, plot='boxplot'): ------- tuple The quadruple of produced figures. First weights and biases - for the original model and HLSModel respectively, - then activations for the original model and HLSModel - respectively. + for the pre- and post-optimization models respectively, + then activations for the pre- and post-optimization models + respectively. (Optimizations are applied to an HLSModel by hls4ml, + a post-optimization HLSModel is a final model) """ wp, wph, ap, aph = None, None, None, None - print("Profiling weights (the original model)") + hls_model_present = hls_model is not None and isinstance(hls_model, HLSModel) + model_present = model is not None + + if hls_model_present: + before = " (before optimization)" + after = " (final / after optimization)" + else: + before = "" + after = "" + + print("Profiling weights" + before) data = None - if model is not None: + if hls_model_present: + data = weights_hlsmodel(hls_model, before_optimization=True, fmt='summary', plot=plot) + elif model_present: if __tf_profiling_enabled__ and isinstance(model, keras.Model): data = weights_keras(model, fmt='summary', plot=plot) elif __torch_profiling_enabled__ and \ @@ -405,48 +419,53 @@ def numerical(model=None, hls_model=None, X=None, plot='boxplot'): wp = plots[plot](data, fmt='summary') # weight plot - plt.title("Distribution of (non-zero) weights (the original model)") + if hls_model_present and plot in types_plots: + t_data = types_hlsmodel(hls_model, before_optimization=True) + types_plots[plot](t_data, fmt='summary') + + plt.title("Distribution of (non-zero) weights" + before) plt.tight_layout() - print("Profiling weights (the HLS model)") - data = None - if hls_model is not None and isinstance(hls_model, HLSModel): - data = weights_hlsmodel(hls_model, fmt='summary', plot=plot) + if hls_model_present: + print("Profiling weights" + after) - if data is not None: + data = weights_hlsmodel(hls_model, fmt='summary', plot=plot) wph = plots[plot](data, fmt='summary') # weight plot - if isinstance(hls_model, HLSModel) and plot in types_plots: + + if plot in types_plots: t_data = types_hlsmodel(hls_model) types_plots[plot](t_data, fmt='summary') - plt.title("Distribution of (non-zero) weights (the HLS model)") + plt.title("Distribution of (non-zero) weights" + after) plt.tight_layout() - print("Profiling activations (the original model)") - data = None if X is not None: + print("Profiling activations" + before) + data = None if __tf_profiling_enabled__ and isinstance(model, keras.Model): data = activations_keras(model, X, fmt='summary', plot=plot) elif __torch_profiling_enabled__ and \ isinstance(model, torch.nn.Sequential): data = activations_torch(model, X, fmt='summary', plot=plot) - if data is not None: - ap = plots[plot](data, fmt='summary') # activation plot - plt.title("Distribution of (non-zero) activations (the original model)") - plt.tight_layout() - print("Profiling activations (the HLS model)") - data = None - if X is not None and hls_model is not None and isinstance(hls_model, HLSModel): - data = activations_hlsmodel(hls_model, X, fmt='summary', plot=plot) + if data is not None: + ap = plots[plot](data, fmt='summary') # activation plot + if hls_model_present and plot in types_plots: + t_data = activation_types_hlsmodel(hls_model, before_optimization=True) + types_plots[plot](t_data, fmt='summary') + plt.title("Distribution of (non-zero) activations" + before) + plt.tight_layout() + + if hls_model_present: + print("Profiling activations" + after) + data = activations_hlsmodel(hls_model, X, fmt='summary', plot=plot) + aph = plots[plot](data, fmt='summary') - if data is not None: - aph = plots[plot](data, fmt='summary') - if X is not None and isinstance(hls_model, HLSModel): t_data = activation_types_hlsmodel(hls_model) types_plots[plot](t_data, fmt='summary') - plt.title("Distribution of (non-zero) activations (the HLS model)") - plt.tight_layout() + + plt.title("Distribution of (non-zero) activations (final / after optimization)") + plt.tight_layout() return wp, wph, ap, aph From 3cb38e7d3fb5fb7319b26eb01793e6f7138586f6 Mon Sep 17 00:00:00 2001 From: Maksymilian Graczyk Date: Tue, 20 Apr 2021 14:09:47 +0200 Subject: [PATCH 07/11] Revert "Add pre-optimization graph copy to HLSModel" This reverts commit d71f9894 --- hls4ml/model/hls_model.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/hls4ml/model/hls_model.py b/hls4ml/model/hls_model.py index fdf14a078..5e65ff44e 100644 --- a/hls4ml/model/hls_model.py +++ b/hls4ml/model/hls_model.py @@ -263,7 +263,6 @@ def __init__(self, config, data_reader, layer_list, inputs=None, outputs=None): self._top_function_lib = None self._make_graph(layer_list) - self.original_graph = self.graph.copy() self._optimize_model(self.config.optimizers) @@ -419,11 +418,8 @@ def next_layer(self): self.index += 1 return self.index - def get_layers(self, before_optimization=False): - if before_optimization: - return self.original_graph.values() - else: - return self.graph.values() + def get_layers(self): + return self.graph.values() def get_input_variables(self): variables = [] From 2c86a8066b62e95729cb8b1bfc8e788f03b5e1ac Mon Sep 17 00:00:00 2001 From: Maksymilian Graczyk Date: Tue, 20 Apr 2021 14:24:29 +0200 Subject: [PATCH 08/11] Make pre-optimisation HLSModel be generated in profiling instead of being obtained from a pre-optimisation graph copy --- hls4ml/model/profiling.py | 27 ++++++++++++++++++--------- 1 file changed, 18 insertions(+), 9 deletions(-) diff --git a/hls4ml/model/profiling.py b/hls4ml/model/profiling.py index 251f6f8a6..d19db6f4f 100644 --- a/hls4ml/model/profiling.py +++ b/hls4ml/model/profiling.py @@ -6,6 +6,7 @@ import seaborn as sb from hls4ml.model.hls_model import HLSModel +from hls4ml.converters import convert_from_config try: from tensorflow import keras @@ -21,6 +22,12 @@ __torch_profiling_enabled__ = False +def get_unoptimized_hlsmodel(model): + new_config = model.config.config.copy() + new_config['HLSConfig']['Optimizers'] = [] + return convert_from_config(new_config) + + def array_to_summary(x, fmt='boxplot'): if fmt == 'boxplot': y = {'med' : np.median(x), @@ -140,7 +147,7 @@ def ap_fixed_WIF(dtype): W, I, F = dtype.width, dtype.integer, dtype.fractional return W, I, F -def types_hlsmodel(model, before_optimization=False): +def types_hlsmodel(model): suffix = ['w', 'b'] data = {'layer' : [], 'low' : [], 'high' : []} # Plot the default precision @@ -151,7 +158,7 @@ def types_hlsmodel(model, before_optimization=False): data['low'].append(-F) data['high'].append(I-1) - for layer in model.get_layers(before_optimization=before_optimization): + for layer in model.get_layers(): for iw, weight in enumerate(layer.get_weights()): wname = '{}/{}'.format(layer.name, suffix[iw]) T = weight.type @@ -163,7 +170,7 @@ def types_hlsmodel(model, before_optimization=False): data = pandas.DataFrame(data) return data -def activation_types_hlsmodel(model, before_optimization=False): +def activation_types_hlsmodel(model): data = {'layer' : [], 'low' : [], 'high' : []} # Get the default precision default_precision = model.config.model_precision['default'] @@ -171,7 +178,7 @@ def activation_types_hlsmodel(model, before_optimization=False): data['layer'].append('model') data['low'].append(-F) data['high'].append(I-1) - for layer in model.get_layers(before_optimization=before_optimization): + for layer in model.get_layers(): T = layer.get_output_variable().type.precision W, I, F = ap_fixed_WIF(T) data['layer'].append(layer.name) @@ -180,14 +187,14 @@ def activation_types_hlsmodel(model, before_optimization=False): data = pandas.DataFrame(data) return data -def weights_hlsmodel(model, before_optimization=False, fmt='longform', plot='boxplot'): +def weights_hlsmodel(model, fmt='longform', plot='boxplot'): suffix = ['w', 'b'] if fmt == 'longform': data = {'x' : [], 'layer' : [], 'weight' : []} elif fmt == 'summary': data = [] - for layer in model.get_layers(before_optimization=before_optimization): + for layer in model.get_layers(): name = layer.name for iw, weight in enumerate(layer.get_weights()): l = '{}/{}'.format(name, suffix[iw]) @@ -396,15 +403,17 @@ def numerical(model=None, hls_model=None, X=None, plot='boxplot'): if hls_model_present: before = " (before optimization)" after = " (final / after optimization)" + hls_model_unoptimized = get_unoptimized_hlsmodel(hls_model) else: before = "" after = "" + hls_model_unoptimized = None print("Profiling weights" + before) data = None if hls_model_present: - data = weights_hlsmodel(hls_model, before_optimization=True, fmt='summary', plot=plot) + data = weights_hlsmodel(hls_model_unoptimized, fmt='summary', plot=plot) elif model_present: if __tf_profiling_enabled__ and isinstance(model, keras.Model): data = weights_keras(model, fmt='summary', plot=plot) @@ -420,7 +429,7 @@ def numerical(model=None, hls_model=None, X=None, plot='boxplot'): wp = plots[plot](data, fmt='summary') # weight plot if hls_model_present and plot in types_plots: - t_data = types_hlsmodel(hls_model, before_optimization=True) + t_data = types_hlsmodel(hls_model_unoptimized) types_plots[plot](t_data, fmt='summary') plt.title("Distribution of (non-zero) weights" + before) @@ -451,7 +460,7 @@ def numerical(model=None, hls_model=None, X=None, plot='boxplot'): if data is not None: ap = plots[plot](data, fmt='summary') # activation plot if hls_model_present and plot in types_plots: - t_data = activation_types_hlsmodel(hls_model, before_optimization=True) + t_data = activation_types_hlsmodel(hls_model_unoptimized) types_plots[plot](t_data, fmt='summary') plt.title("Distribution of (non-zero) activations" + before) plt.tight_layout() From 16b76d141a72b4f18205ac9a2daa82588450f6e9 Mon Sep 17 00:00:00 2001 From: Maksymilian Graczyk Date: Sat, 17 Apr 2021 12:06:04 +0200 Subject: [PATCH 09/11] Base weights_keras()'s weight and suffix processing on a map --- hls4ml/model/profiling.py | 45 +++++++++++++++++++++++++++------------ 1 file changed, 31 insertions(+), 14 deletions(-) diff --git a/hls4ml/model/profiling.py b/hls4ml/model/profiling.py index d19db6f4f..69506cdd5 100644 --- a/hls4ml/model/profiling.py +++ b/hls4ml/model/profiling.py @@ -4,6 +4,7 @@ import numpy as np import pandas import seaborn as sb +from collections import defaultdict from hls4ml.model.hls_model import HLSModel from hls4ml.converters import convert_from_config @@ -217,6 +218,32 @@ def weights_hlsmodel(model, fmt='longform', plot='boxplot'): return data +def _keras_batchnorm(layer): + weights = layer.get_weights() + epsilon = layer.epsilon + + gamma = weights[0] + beta = weights[1] + mean = weights[2] + var = weights[3] + + scale = gamma / np.sqrt(var + epsilon) + bias = beta - gamma * mean / np.sqrt(var + epsilon) + + return [scale, bias], ['s', 'b'] + + +def _keras_layer(layer): + return layer.get_weights(), ['w', 'b'] + + +keras_process_layer_map = defaultdict(lambda: _keras_layer, + { + 'BatchNormalization': _keras_batchnorm, + 'QBatchNormalization': _keras_batchnorm + }) + + def activations_hlsmodel(model, X, fmt='summary', plot='boxplot'): if fmt == 'longform': raise NotImplemented @@ -242,21 +269,16 @@ def activations_hlsmodel(model, X, fmt='summary', plot='boxplot'): def weights_keras(model, fmt='longform', plot='boxplot'): - suffix = ['w', 'b'] if fmt == 'longform': data = {'x' : [], 'layer' : [], 'weight' : []} elif fmt == 'summary': data = [] for layer in model.layers: name = layer.name - weights = layer.get_weights() - for i, w in enumerate(weights): - if len(weights) != 2: - suf = i - else: - suf = suffix[i] + weights, suffix = keras_process_layer_map[type(layer).__name__](layer) - l = '{}/{}'.format(name, suf) + for i, w in enumerate(weights): + l = '{}/{}'.format(name, suffix[i]) w = w.flatten() w = abs(w[w != 0]) n = len(w) @@ -314,12 +336,7 @@ def weights_torch(model, fmt='longform', plot='boxplot'): name = layer.__class__.__name__ weights = list(layer.parameters()) for i, w in enumerate(weights): - if len(weights) != 2: - suf = i - else: - suf = suffix[i] - - l = '{}/{}'.format(name, suf) + l = '{}/{}'.format(name, suffix[i]) w = weights[i].detach().numpy() w = w.flatten() w = abs(w[w != 0]) From 9d85b9e1601520957df6ecacb2c3a3e7e77c7ed7 Mon Sep 17 00:00:00 2001 From: Maksymilian Graczyk Date: Tue, 20 Apr 2021 16:46:01 +0200 Subject: [PATCH 10/11] Make output dir of generated unoptimized HLSModel random --- hls4ml/model/profiling.py | 23 ++++++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/hls4ml/model/profiling.py b/hls4ml/model/profiling.py index 69506cdd5..c4c41e352 100644 --- a/hls4ml/model/profiling.py +++ b/hls4ml/model/profiling.py @@ -4,6 +4,9 @@ import numpy as np import pandas import seaborn as sb +import uuid +import os +import shutil from collections import defaultdict from hls4ml.model.hls_model import HLSModel @@ -25,8 +28,15 @@ def get_unoptimized_hlsmodel(model): new_config = model.config.config.copy() + new_output_dir = uuid.uuid4().hex + + while os.path.exists(new_output_dir): + new_output_dir = uuid.uuid4().hex + new_config['HLSConfig']['Optimizers'] = [] - return convert_from_config(new_config) + new_config['OutputDir'] = new_output_dir + + return convert_from_config(new_config), new_output_dir def array_to_summary(x, fmt='boxplot'): @@ -420,11 +430,11 @@ def numerical(model=None, hls_model=None, X=None, plot='boxplot'): if hls_model_present: before = " (before optimization)" after = " (final / after optimization)" - hls_model_unoptimized = get_unoptimized_hlsmodel(hls_model) + hls_model_unoptimized, tmp_output_dir = get_unoptimized_hlsmodel(hls_model) else: before = "" after = "" - hls_model_unoptimized = None + hls_model_unoptimized, tmp_output_dir = None, None print("Profiling weights" + before) data = None @@ -441,6 +451,10 @@ def numerical(model=None, hls_model=None, X=None, plot='boxplot'): if data is None: print("Only keras, PyTorch (Sequential) and HLSModel models " + "can currently be profiled") + + if hls_model_present and os.path.exists(tmp_output_dir): + shutil.rmtree(tmp_output_dir) + return wp, wph, ap, aph wp = plots[plot](data, fmt='summary') # weight plot @@ -493,6 +507,9 @@ def numerical(model=None, hls_model=None, X=None, plot='boxplot'): plt.title("Distribution of (non-zero) activations (final / after optimization)") plt.tight_layout() + if hls_model_present and os.path.exists(tmp_output_dir): + shutil.rmtree(tmp_output_dir) + return wp, wph, ap, aph From ac7e72cdda055464c653a5720ec92347c2b45cb5 Mon Sep 17 00:00:00 2001 From: Maksymilian Graczyk Date: Fri, 23 Apr 2021 12:14:20 +0200 Subject: [PATCH 11/11] Move convert_from_config import to inside get_unoptimized_hlsmodel() --- hls4ml/model/profiling.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/hls4ml/model/profiling.py b/hls4ml/model/profiling.py index c4c41e352..3a37dbf62 100644 --- a/hls4ml/model/profiling.py +++ b/hls4ml/model/profiling.py @@ -10,7 +10,6 @@ from collections import defaultdict from hls4ml.model.hls_model import HLSModel -from hls4ml.converters import convert_from_config try: from tensorflow import keras @@ -27,6 +26,8 @@ def get_unoptimized_hlsmodel(model): + from hls4ml.converters import convert_from_config + new_config = model.config.config.copy() new_output_dir = uuid.uuid4().hex