From 2e5df8ca8276c380dae0de2f7ada2d4bbbb3000b Mon Sep 17 00:00:00 2001 From: Vladimir Loncar Date: Tue, 5 Jul 2022 21:22:26 +0200 Subject: [PATCH 01/10] Support generated source code as an attribute --- hls4ml/model/attributes.py | 10 +++++++++- hls4ml/model/layers.py | 3 ++- hls4ml/model/types.py | 7 +++++++ 3 files changed, 18 insertions(+), 2 deletions(-) diff --git a/hls4ml/model/attributes.py b/hls4ml/model/attributes.py index 40bdec338..926ff7884 100644 --- a/hls4ml/model/attributes.py +++ b/hls4ml/model/attributes.py @@ -1,6 +1,6 @@ from collections.abc import MutableMapping -from hls4ml.model.types import InplaceVariable, NamedType, TensorVariable, WeightVariable +from hls4ml.model.types import InplaceVariable, NamedType, TensorVariable, WeightVariable, Source class Attribute(object): def __init__(self, name, value_type=int, default=None, configurable=False): @@ -40,6 +40,10 @@ class WeightAttribute(Attribute): def __init__(self, name): super(WeightAttribute, self).__init__(name, value_type=WeightVariable, default=None, configurable=False) +class CodeAttrubute(Attribute): + def __init__(self, name): + super(WeightAttribute, self).__init__(name, value_type=Source, default=None, configurable=False) + class AttributeDict(MutableMapping): def __init__(self, layer): self.layer = layer @@ -118,3 +122,7 @@ def __iter__(self): class TypeMapping(AttributeMapping): def __init__(self, attributes): super().__init__(attributes, NamedType) + +class CodeMapping(AttributeMapping): + def __init__(self, attributes): + super().__init__(attributes, Source) diff --git a/hls4ml/model/layers.py b/hls4ml/model/layers.py index f821d08e9..b050368b4 100644 --- a/hls4ml/model/layers.py +++ b/hls4ml/model/layers.py @@ -6,7 +6,7 @@ from hls4ml.model.types import IntegerPrecisionType, FixedPrecisionType, ExponentPrecisionType from hls4ml.model.types import find_minimum_width -from hls4ml.model.attributes import Attribute, TypeMapping, VariableMapping, WeightAttribute, TypeAttribute, ChoiceAttribute, WeightMapping +from hls4ml.model.attributes import Attribute, CodeMapping, TypeMapping, VariableMapping, WeightAttribute, TypeAttribute, ChoiceAttribute, WeightMapping from hls4ml.model.attributes import AttributeDict, AttributeMapping # TODO move this to some utility module @@ -57,6 +57,7 @@ def __init__(self, model, name, attributes, inputs, outputs=None): self.weights = WeightMapping(self.attributes) self.variables = VariableMapping(self.attributes) self.types = TypeMapping(self.attributes) + self.code = CodeMapping(self.attributes) accum_t = NamedType(*reversed(self.model.config.get_precision(self, 'accum'))) self.set_attr('accum_t', accum_t) diff --git a/hls4ml/model/types.py b/hls4ml/model/types.py index afa579a65..3e982524b 100644 --- a/hls4ml/model/types.py +++ b/hls4ml/model/types.py @@ -351,3 +351,10 @@ def __next__(self): return '{%d, %s}' % (value[0], value_fmt) next = __next__ + +class Source(object): + def __init__(self, code): + self.code = code + + def __str__(self): + return str(self.code) From d3d41fd077b40c6413aab4664b5efbdb92735a87 Mon Sep 17 00:00:00 2001 From: Vladimir Loncar Date: Tue, 5 Jul 2022 21:23:32 +0200 Subject: [PATCH 02/10] Code-generated im2col 1D/2D cnn implementation --- hls4ml/backends/fpga/fpga_backend.py | 170 ++++++++++++ hls4ml/backends/fpga/passes/codegen.py | 45 ++++ .../vivado/passes/convolution_templates.py | 16 ++ hls4ml/backends/vivado/vivado_backend.py | 28 +- hls4ml/templates/vivado/firmware/parameters.h | 1 + .../vivado/nnet_utils/nnet_code_gen.h | 37 +++ .../templates/vivado/nnet_utils/nnet_conv1d.h | 4 + .../vivado/nnet_utils/nnet_conv1d_resource.h | 232 +++++----------- .../templates/vivado/nnet_utils/nnet_conv2d.h | 18 +- .../vivado/nnet_utils/nnet_conv2d_latency.h | 131 +++------ .../vivado/nnet_utils/nnet_conv2d_resource.h | 253 +++++------------- hls4ml/writer/vivado_writer.py | 24 ++ 12 files changed, 506 insertions(+), 453 deletions(-) create mode 100644 hls4ml/backends/fpga/passes/codegen.py create mode 100644 hls4ml/templates/vivado/nnet_utils/nnet_code_gen.h diff --git a/hls4ml/backends/fpga/fpga_backend.py b/hls4ml/backends/fpga/fpga_backend.py index 4cb38888a..cca1a2af3 100644 --- a/hls4ml/backends/fpga/fpga_backend.py +++ b/hls4ml/backends/fpga/fpga_backend.py @@ -181,6 +181,27 @@ def set_target_reuse_factor(self, layer): layer.set_attr('reuse_factor', float(rf) / kernel_multiplies) + def get_valid_conv_partition_splits(self, out_height, out_width): + """Generate valid partition splits of a Conv1D/2D layer. + + Essentially a list of divisors of the number of pixels of the output image. + + Args: + out_height (int): The height of the output image + out_width (int): The width of the output image + + Returns: + list: List of valid partition splits + """ + n_pixels = out_height * out_width + valid_n_partitions = [] + for i in range(1, int(n_pixels / 2) + 1): + if n_pixels % i == 0: + valid_n_partitions.append(i) + valid_n_partitions.append(n_pixels) + + return valid_n_partitions + @classmethod def convert_precision_string(cls, precision): if isinstance(precision, IntegerPrecisionType) or isinstance(precision, FixedPrecisionType): @@ -384,6 +405,155 @@ def compute_conv2d_instructions(self, in_H, in_W, in_C, kernel_size=3, stride=1, return (min_H, min_W, windows_int) + def _compute_conv1d_im2col(self, input_shape, kernel=3, stride=1, pad=(0,0)): + W, C = input_shape + pad_l, pad_r = pad + + out_w = (W + pad_l + pad_r - kernel) // stride + 1 + + input_img = np.arange(1, W * C + 1).reshape(W, C) + + img = np.pad(input_img, [(pad_l, pad_r), (0,0)], 'constant') + col = np.zeros((out_w, kernel, C)) + + for x in range(kernel): + x_max = x + stride * out_w + col[:, x, :] = img[x:x_max:stride, :] + + col = col.reshape(out_w, -1) + return col + + def generate_conv1d_line_buffer_fn(self, layer_idx, n_partitions, in_W, in_C, kernel=3, stride=1, pad=0): + if isinstance(pad, Iterable): + pad_left = pad[0] + pad_right = pad[1] + else: + pad_left = pad + pad_right = pad + + im2col_matrix = self._compute_conv1d_im2col( + (in_W, in_C), + kernel, + stride, + (pad_left, pad_right) + ) + + generated_code = ( + "template\n" + "class fill_buffer_{index} : public FillConv1DBuffer {{\n" + " public:\n" + " static void fill_buffer(\n" + " data_T data[CONFIG_T::in_width * CONFIG_T::n_chan],\n" + " data_T buffer[CONFIG_T::n_pixels][CONFIG_T::filt_width * CONFIG_T::n_chan],\n" + " const unsigned partition\n" + " ) {{\n" + ).format(index=layer_idx) + indent = ' ' + + for partition_idx, partition in enumerate(np.split(im2col_matrix, n_partitions)): + generated_code += indent * 2 + 'if (partition == {:>3}) {{\n'.format(partition_idx) + for pixel_idx, arr in enumerate(partition): + buffer_stmts = [] + for j, v in enumerate(arr): + if v == 0: + val = '0' + else: + val = 'data[{}]'.format(int(v-1)) + buffer_stmts.append('buffer[{}][{}] = {:>10};'.format(pixel_idx, j, val)) + generated_code += indent * 3 + ' '.join(buffer_stmts) + '\n' + generated_code += '\n' + indent * 2 + '}\n' + + generated_code += indent + '}\n' + generated_code += '};\n' + + return generated_code + + def _compute_conv2d_im2col(self, input_shape, kernel=(3,3), stride=(1,1), pad=(0,0,0,0)): + H, W, C = input_shape + kernel_h, kernel_w = kernel + stride_h, stride_w = stride + pad_t, pad_b, pad_l, pad_r = pad + + out_h = (H + pad_t + pad_b - kernel_h) // stride_h + 1 + out_w = (W + pad_l + pad_r - kernel_w) // stride_w + 1 + + input_img = np.arange(1, C * H * W + 1).reshape(C, H, W) + + img = np.pad(input_img, [(0,0), (pad_t, pad_b), (pad_l, pad_r)], 'constant') + col = np.zeros((C, kernel_h, kernel_w, out_h, out_w)) + + for y in range(kernel_h): + y_max = y + stride_h * out_h + for x in range(kernel_w): + x_max = x + stride_w * out_w + col[:, y, x, :, :] = img[:, y:y_max:stride_h, x:x_max:stride_w] + + col = col.transpose(3, 4, 0, 1, 2).reshape(out_h * out_w, -1) + return col + + def generate_conv2d_line_buffer_fn(self, layer_idx, n_partitions, in_H, in_W, in_C, kernel=(3, 3), stride=(1, 1), pad=(0, 0, 0, 0)): + if isinstance(kernel, Iterable): + kernel_height = kernel[0] + kernel_width = kernel[1] + else: + kernel_height = kernel + kernel_width = kernel + + if isinstance(stride, Iterable): + stride_height = stride[0] + stride_width = stride[1] + else: + stride_height = stride + stride_width = stride + + if isinstance(pad, Iterable): + pad_top = pad[0] + pad_bottom = pad[1] + pad_left = pad[2] + pad_right = pad[3] + else: + pad_top = pad + pad_bottom = pad + pad_left = pad + pad_right = pad + + im2col_matrix = self._compute_conv2d_im2col( + (in_H, in_W, in_C), + (kernel_height, kernel_width), + (stride_height, stride_width), + (pad_top, pad_bottom, pad_left, pad_right) + ) + + generated_code = ( + "template\n" + "class fill_buffer_{index} : public FillConv2DBuffer {{\n" + " public:\n" + " static void fill_buffer(\n" + " data_T data[CONFIG_T::in_height * CONFIG_T::in_width * CONFIG_T::n_chan],\n" + " data_T buffer[CONFIG_T::n_pixels][CONFIG_T::filt_height * CONFIG_T::filt_width * CONFIG_T::n_chan],\n" + " const unsigned partition\n" + " ) {{\n" + ).format(index=layer_idx) + indent = ' ' + + for partition_idx, partition in enumerate(np.split(im2col_matrix, n_partitions)): + generated_code += indent * 2 + 'if (partition == {:>3}) {{\n'.format(partition_idx) + for pixel_idx, arr in enumerate(partition): + buffer_stmts = [] + for j, v in enumerate(arr): + if v == 0: + val = '0' + else: + val = 'data[{}]'.format(int(v-1)) + buffer_stmts.append('buffer[{}][{}] = {:>10};'.format(pixel_idx, j, val)) + generated_code += indent * 3 + ' '.join(buffer_stmts) + '\n' + generated_code += '\n' + indent * 2 + '}\n' + + generated_code += indent + '}\n' + generated_code += '};\n' + + return generated_code + @model_optimizer() def write_hls(self, model): self.writer.write_hls(model) diff --git a/hls4ml/backends/fpga/passes/codegen.py b/hls4ml/backends/fpga/passes/codegen.py new file mode 100644 index 000000000..f03164509 --- /dev/null +++ b/hls4ml/backends/fpga/passes/codegen.py @@ -0,0 +1,45 @@ +from hls4ml.model.optimizer import OptimizerPass +from hls4ml.model.layers import Conv1D, Conv2D +from hls4ml.model.types import Source + +class GenerateConvIm2col(OptimizerPass): + ''' Generates tcode for im2col step of 1D/2d convolution ''' + def match(self, node): + return isinstance(node, (Conv1D, Conv2D)) and \ + node.model.config.get_config_value('IOType') == 'io_parallel' + + def transform(self, model, node): + node_class = node.__class__.__name__ + if '1D' in node_class: + self._generate_im2col_1d(node) + elif '2D' in node_class: + self._generate_im2col_2d(node) + else: + raise Exception('Cannot generate instructions for node {} ({})'.format(node.name, node_class)) + + def _generate_im2col_1d(self, node): + code_str = node.model.config.backend.generate_conv1d_line_buffer_fn( + node.get_attr('index'), + node.get_attr('n_partitions'), + node.get_input_variable().shape[0], + node.get_input_variable().shape[1], + kernel=node.get_attr('filt_width'), + stride=node.get_attr('stride_width'), + pad=(node.get_attr('pad_left'), node.get_attr('pad_right')) + ) + + node.set_attr('line_buffer_codegen', Source(code_str)) + + def _generate_im2col_2d(self, node): + code_str = node.model.config.backend.generate_conv2d_line_buffer_fn( + node.get_attr('index'), + node.get_attr('n_partitions'), + node.get_input_variable().shape[0], + node.get_input_variable().shape[1], + node.get_input_variable().shape[2], + kernel=(node.get_attr('filt_height'), node.get_attr('filt_width')), + stride=(node.get_attr('stride_height'), node.get_attr('stride_width')), + pad=(node.get_attr('pad_top'), node.get_attr('pad_bottom'), node.get_attr('pad_left'), node.get_attr('pad_right')) + ) + + node.set_attr('line_buffer_codegen', Source(code_str)) diff --git a/hls4ml/backends/vivado/passes/convolution_templates.py b/hls4ml/backends/vivado/passes/convolution_templates.py index d4ac2d5b0..46477c528 100644 --- a/hls4ml/backends/vivado/passes/convolution_templates.py +++ b/hls4ml/backends/vivado/passes/convolution_templates.py @@ -37,6 +37,10 @@ static const nnet::conv_implementation implementation = nnet::conv_implementation::{implementation}; static const unsigned min_width = {min_width}; static const ap_uint pixels[min_width]; + static const unsigned n_partitions = {n_partitions}; + static const unsigned n_pixels = out_width / n_partitions; + template + using fill_buffer = nnet::{fill_fn}; typedef {accum_t.name} accum_t; typedef {bias_t.name} bias_t; typedef {weight_t.name} weight_t; @@ -60,6 +64,10 @@ def format(self, node): params['nzeros'] = node.get_weights('weight').nzeros params['config_t'] = 'config{}_mult'.format(node.index) + if node.model.config.get_config_value('IOType') == 'io_parallel': + params['fill_fn'] = 'fill_buffer_{}'.format(node.index) + else: + params['fill_fn'] = 'FillConv1DBuffer' conv_config = self.template.format(**params) mult_params = self._default_config_params(node) @@ -109,6 +117,10 @@ def format(self, node): static const unsigned min_height = {min_height}; static const unsigned min_width = {min_width}; static const ap_uint pixels[min_height * min_width]; + static const unsigned n_partitions = {n_partitions}; + static const unsigned n_pixels = out_height * out_width / n_partitions; + template + using fill_buffer = nnet::{fill_fn}; typedef {accum_t.name} accum_t; typedef {bias_t.name} bias_t; typedef {weight_t.name} weight_t; @@ -133,6 +145,10 @@ def format(self, node): params['nzeros'] = node.get_weights('weight').nzeros params['config_t'] = 'config{}_mult'.format(node.index) + if node.model.config.get_config_value('IOType') == 'io_parallel': + params['fill_fn'] = 'fill_buffer_{}'.format(node.index) + else: + params['fill_fn'] = 'FillConv2DBuffer' conv_config = self.template.format(**params) mult_params = self._default_config_params(node) diff --git a/hls4ml/backends/vivado/vivado_backend.py b/hls4ml/backends/vivado/vivado_backend.py index 4c02f6f37..bb8a6a521 100644 --- a/hls4ml/backends/vivado/vivado_backend.py +++ b/hls4ml/backends/vivado/vivado_backend.py @@ -60,6 +60,7 @@ def _register_flows(self): 'vivado:transform_types', 'vivado:generate_conv_streaming_instructions', 'vivado:apply_resource_strategy', + 'vivado:generate_conv_im2col', ] vivado_types_flow = register_flow('specific_types', vivado_types, requires=[init_flow], backend=self.name) @@ -157,7 +158,18 @@ def init_conv1d(self, layer): self.set_closest_reuse_factor(layer, n_in, n_out) else: layer.set_attr('strategy', 'latency') - + + out_width = layer.get_output_variable().shape[0] + chosen_pf = layer.model.config.get_layer_config_value(layer, 'ParallelizationFactor', 1) + valid_pf = self.get_valid_conv_partition_splits(1, out_width) + if chosen_pf not in valid_pf: + closest_pf = self.get_closest_reuse_factor(valid_pf, chosen_pf) + print('WARNING: Invalid ParallelizationFactor={} in layer "{}". Using ParallelizationFactor={} instead. Valid ParallelizationFactor(s): {}.' + .format(chosen_pf, layer.name, closest_pf, ','.join(map(str, valid_pf)))) + else: + closest_pf = chosen_pf + layer.set_attr('n_partitions', out_width // closest_pf) + layer.set_attr('implementation', layer.model.config.get_conv_implementation(layer).lower()) @layer_optimizer(SeparableConv1D) @@ -183,7 +195,19 @@ def init_conv2d(self, layer): self.set_closest_reuse_factor(layer, n_in, n_out) else: layer.set_attr('strategy', 'latency') - + + out_height = layer.get_output_variable().shape[0] + out_width = layer.get_output_variable().shape[1] + chosen_pf = layer.model.config.get_layer_config_value(layer, 'ParallelizationFactor', 1) + valid_pf = self.get_valid_conv_partition_splits(out_height, out_width) + if chosen_pf not in valid_pf: + closest_pf = self.get_closest_reuse_factor(valid_pf, chosen_pf) + print('WARNING: Invalid ParallelizationFactor={} in layer "{}". Using ParallelizationFactor={} instead. Valid ParallelizationFactor(s): {}.' + .format(chosen_pf, layer.name, closest_pf, ','.join(map(str, valid_pf)))) + else: + closest_pf = chosen_pf + layer.set_attr('n_partitions', out_height * out_width // closest_pf) + layer.set_attr('implementation', layer.model.config.get_conv_implementation(layer).lower()) @layer_optimizer(SeparableConv2D) diff --git a/hls4ml/templates/vivado/firmware/parameters.h b/hls4ml/templates/vivado/firmware/parameters.h index bb6413535..addee4ef2 100644 --- a/hls4ml/templates/vivado/firmware/parameters.h +++ b/hls4ml/templates/vivado/firmware/parameters.h @@ -5,6 +5,7 @@ #include "ap_fixed.h" #include "nnet_utils/nnet_helpers.h" +#include "nnet_utils/nnet_code_gen.h" //hls-fpga-machine-learning insert includes //hls-fpga-machine-learning insert weights diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_code_gen.h b/hls4ml/templates/vivado/nnet_utils/nnet_code_gen.h new file mode 100644 index 000000000..d170eb667 --- /dev/null +++ b/hls4ml/templates/vivado/nnet_utils/nnet_code_gen.h @@ -0,0 +1,37 @@ +#ifndef NNET_INSTR_GEN_H_ +#define NNET_INSTR_GEN_H_ + +#include +#include "nnet_helpers.h" + +namespace nnet { + +template +class FillConv1DBuffer{ + public: + static void fill_buffer( + data_T data[CONFIG_T::in_width * CONFIG_T::n_chan], + data_T buffer[CONFIG_T::n_pixels][CONFIG_T::filt_width * CONFIG_T::n_chan], + const unsigned partition + ) { + // To be implemented in subclasses + } +}; + +template +class FillConv2DBuffer{ + public: + static void fill_buffer( + data_T data[CONFIG_T::in_height * CONFIG_T::in_width * CONFIG_T::n_chan], + data_T buffer[CONFIG_T::n_pixels][CONFIG_T::filt_height * CONFIG_T::filt_width * CONFIG_T::n_chan], + const unsigned partition + ) { + // To be implemented in subclasses + } +}; + +//hls4ml insert code + +} + +#endif \ No newline at end of file diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_conv1d.h b/hls4ml/templates/vivado/nnet_utils/nnet_conv1d.h index 9846bd216..bde62161e 100644 --- a/hls4ml/templates/vivado/nnet_utils/nnet_conv1d.h +++ b/hls4ml/templates/vivado/nnet_utils/nnet_conv1d.h @@ -58,6 +58,8 @@ void conv_1d_cl( typename CONFIG_T::weight_t weights[CONFIG_T::filt_width * CONFIG_T::n_chan * CONFIG_T::n_filt], typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) { + #pragma HLS INLINE region + if (CONFIG_T::strategy == nnet::latency) { conv_1d_latency_cl(data, res, weights, biases); } else { @@ -74,6 +76,8 @@ void pointwise_conv_1d_cl( { assert(CONFIG_T::filt_width == 1); + #pragma HLS INLINE region + if (CONFIG_T::strategy == nnet::latency) { pointwise_conv_1d_latency_cl(data, res, weights, biases); } else { diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_conv1d_resource.h b/hls4ml/templates/vivado/nnet_utils/nnet_conv1d_resource.h index 142c7973a..2ce906be1 100644 --- a/hls4ml/templates/vivado/nnet_utils/nnet_conv1d_resource.h +++ b/hls4ml/templates/vivado/nnet_utils/nnet_conv1d_resource.h @@ -6,156 +6,94 @@ namespace nnet { -template -void im2col_1d(data_T data[CONFIG_T::in_width * CONFIG_T::n_chan], data_T data_col[CONFIG_T::filt_width * CONFIG_T::n_chan * CONFIG_T::out_width]) { - //int index = 0; - for (int channel = CONFIG_T::n_chan; channel--; data += CONFIG_T::in_width) { - #pragma HLS PIPELINE II=1 rewind - for (int kernel_col = 0; kernel_col < CONFIG_T::filt_width; kernel_col++) { - #pragma HLS UNROLL - int input_col = -CONFIG_T::pad_left + kernel_col * CONFIG_T::dilation; - for (int output_col = CONFIG_T::out_width; output_col; output_col--) { - #pragma HLS UNROLL - if (input_col >= 0 && input_col < CONFIG_T::in_width) { - *(data_col++) = data[input_col]; - //data_col[index] = data[input_col]; - } else { - *(data_col++) = 0; - //data_col[index] = 0; - } - //index++; - input_col += CONFIG_T::stride_width; - } - } - } -} - template -void conv_1d_full( +void conv_1d_resource_cl( data_T data[CONFIG_T::in_width * CONFIG_T::n_chan], res_T res[CONFIG_T::out_width * CONFIG_T::n_filt], typename CONFIG_T::weight_t weights[CONFIG_T::filt_width * CONFIG_T::n_chan * CONFIG_T::n_filt], - typename CONFIG_T::bias_t biases[CONFIG_T::n_filt] -) + typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) { - data_T data_conv[CONFIG_T::filt_width * CONFIG_T::n_chan * CONFIG_T::out_width]; - data_T data_col[CONFIG_T::filt_width * CONFIG_T::n_chan]; - res_T res_col[CONFIG_T::n_filt]; + constexpr unsigned mult_n_in = CONFIG_T::filt_width * CONFIG_T::n_chan; + constexpr unsigned mult_n_out = CONFIG_T::n_filt; + constexpr unsigned block_factor = DIV_ROUNDUP(mult_n_in * mult_n_out, CONFIG_T::reuse_factor); + constexpr unsigned multscale = block_factor / mult_n_out; - //#pragma HLS ARRAY_PARTITION variable=data_conv complete - #pragma HLS ARRAY_PARTITION variable=data_col complete - #pragma HLS ARRAY_PARTITION variable=res_col complete + assert((block_factor % mult_n_out == 0 || CONFIG_T::reuse_factor >= mult_n_in) && "The current Reuse Factor is not allowed"); + assert((CONFIG_T::reuse_factor <= CONFIG_T::filt_width * CONFIG_T::n_chan) && "This function is correct only for RF <= FILT_WIDTH * N_CHAN"); - im2col_1d(data, data_conv); + data_T data_buf[CONFIG_T::n_pixels][mult_n_in]; + #pragma HLS ARRAY_PARTITION variable=data_buf complete dim=0 - for (int i = 0; i < CONFIG_T::out_width; i++) { - #pragma HLS UNROLL - for (int j = 0; j < CONFIG_T::filt_width * CONFIG_T::n_chan; j++) { - data_col[j] = data_conv[j * CONFIG_T::out_width + i]; - } - dense_resource(data_col, res_col, weights, biases); - for (int j = 0; j < CONFIG_T::n_filt; j++) { - //res[i * CONFIG_T::n_filt + j] = res_col[j]; - res[j * CONFIG_T::out_width + i] = res_col[j]; // Transposed order - } - } -} + #pragma HLS ARRAY_RESHAPE variable=weights block factor=block_factor + #pragma HLS ARRAY_PARTITION variable=biases complete -template -void im2col_1d_cf_idx(data_T data[CONFIG_T::in_width * CONFIG_T::n_chan], data_T data_col[CONFIG_T::filt_width * CONFIG_T::n_chan], const int col) { - ChannelLoop: - for (int channel = 0; channel < CONFIG_T::n_chan; channel++) { - //#pragma HLS UNROLL - #pragma HLS PIPELINE II=1 rewind - KernelLoop: - for (int kernel_col = 0; kernel_col < CONFIG_T::filt_width; kernel_col++) { + typename CONFIG_T::accum_t acc[CONFIG_T::n_pixels][mult_n_out]; + #pragma HLS ARRAY_PARTITION variable=acc complete dim=0 + + PartitionLoop: + for (unsigned i_part = 0; i_part < CONFIG_T::n_partitions; i_part++) { + //#pragma HLS UNROLL // We don't want this loop unrolled + + CONFIG_T::template fill_buffer::fill_buffer(data, data_buf, i_part); + + PixelInitAccumLoop: + for (unsigned i_pxl = 0; i_pxl < CONFIG_T::n_pixels; i_pxl++) { #pragma HLS UNROLL - int input_col = -CONFIG_T::pad_left + kernel_col * CONFIG_T::dilation + col * CONFIG_T::stride_width; - if (input_col >= 0 && input_col < CONFIG_T::in_width) { - //*(data_col++) = data[input_col]; - data_col[channel * CONFIG_T::filt_width + kernel_col] = data[channel * CONFIG_T::in_width + input_col]; - } else { - //*(data_col++) = 0; - data_col[channel * CONFIG_T::filt_width + kernel_col] = 0; - } - } - } -} -template -void im2col_1d_cf(data_T data[CONFIG_T::in_width * CONFIG_T::n_chan], data_T data_col[CONFIG_T::n_chan * CONFIG_T::filt_width], const int col) { - int index = 0; - ChannelLoop: - for (int channel = CONFIG_T::n_chan; channel--; data += CONFIG_T::in_width) { - #pragma HLS UNROLL - KernelLoop: - for (int kernel_col = 0; kernel_col < CONFIG_T::filt_width; kernel_col++) { - int input_col = -CONFIG_T::pad_left + kernel_col * CONFIG_T::dilation + col * CONFIG_T::stride_width; - if (input_col >= 0 && input_col < CONFIG_T::in_width) { - //*(data_col++) = data[input_col]; - data_col[index] = data[input_col]; - } else { - //*(data_col++) = 0; - data_col[index] = 0; + InitAccumLoop: + for (unsigned i_acc = 0; i_acc < mult_n_out; i_acc++) { + #pragma HLS UNROLL + acc[i_pxl][i_acc] = (typename CONFIG_T::accum_t) biases[i_acc]; } - index++; } - } -} -template -void conv_1d_resource_cf( - data_T data[CONFIG_T::n_chan * CONFIG_T::in_width], - res_T res[CONFIG_T::out_width * CONFIG_T::n_filt], - typename CONFIG_T::weight_t weights[CONFIG_T::filt_width * CONFIG_T::n_chan * CONFIG_T::n_filt], - typename CONFIG_T::bias_t biases[CONFIG_T::n_filt] -) -{ - const int nin = CONFIG_T::n_chan * CONFIG_T::filt_width; - const int nout = CONFIG_T::n_filt; - const int rufactor = CONFIG_T::reuse_factor; - const int block_factor = DIV_ROUNDUP(nin*nout, rufactor); + ReuseLoop: + for (unsigned i_rf = 0; i_rf < CONFIG_T::reuse_factor; i_rf++) { + #pragma HLS PIPELINE II=1 rewind - //#pragma HLS function_instantiate variable=weights,biases - //#pragma HLS RESOURCE variable=weights core=RAM_2P_BRAM Commenting out the deisgnation HLS seems to choose correctly - //#pragma HLS ARRAY_RESHAPE variable=weights block factor=block_factor - //#pragma HLS ARRAY_PARTITION variable=biases complete - - data_T data_col[CONFIG_T::filt_width * CONFIG_T::n_chan]; - res_T res_col[CONFIG_T::n_filt]; + unsigned i_w = i_rf; + unsigned i_in = i_rf; + unsigned i_out = 0; + unsigned i_acc = 0; - #pragma HLS ARRAY_PARTITION variable=data_col complete - #pragma HLS ARRAY_PARTITION variable=res_col complete + MultLoop: + for (unsigned i_blk = 0; i_blk < block_factor; i_blk++) { + #pragma HLS UNROLL - ColLoop: - for (int i = 0; i < CONFIG_T::out_width; i++) { - #pragma HLS PIPELINE - im2col_1d_cf(data, data_col, i); - dense_resource(data_col, res_col, weights, biases); - for (int j = 0; j < CONFIG_T::n_filt; j++) { - //res[i * CONFIG_T::n_filt + j] = res_col[j]; - res[j * CONFIG_T::out_width + i] = res_col[j]; // Transposed order - } - } -} + PixelMultLoop: + for (unsigned i_pxl = 0; i_pxl < CONFIG_T::n_pixels; i_pxl++) { + #pragma HLS UNROLL -template -void im2col_1d_cl(data_T data[CONFIG_T::in_width * CONFIG_T::n_chan], data_T data_col[CONFIG_T::filt_width * CONFIG_T::n_chan], const int col) { - int index = 0; - KernelLoop: - for (int kernel_col = 0; kernel_col < CONFIG_T::filt_width; kernel_col++) { - #pragma HLS UNROLL + acc[i_pxl][i_out] += static_cast( + CONFIG_T::mult_config::template product::product(data_buf[i_pxl][i_in], weights[i_w])); + } - ChannelLoop: - for (int channel = 0; channel < CONFIG_T::n_chan; channel++) { - int index_data = (col*CONFIG_T::stride_width+kernel_col-CONFIG_T::pad_left) * CONFIG_T::n_chan + channel; + // Increment i_w + i_w += CONFIG_T::reuse_factor; + // Increment i_in + i_in += CONFIG_T::reuse_factor; + if (i_in >= mult_n_in) { + i_in = i_rf; + } + // Increment i_out + if (i_acc + 1 >= multscale) { + i_acc = 0; + i_out++; + } else { + i_acc++; + } + } + } - if (index_data >= 0 && index_data < CONFIG_T::in_width*CONFIG_T::n_chan) { - data_col[index] = data[index_data]; - } else { - data_col[index] = 0; + PixelResultLoop: + for (unsigned i_pxl = 0; i_pxl < CONFIG_T::n_pixels; i_pxl++) { + #pragma HLS UNROLL + // Cast to "res_t" type + ResultLoop: + for (unsigned i_res = 0; i_res < mult_n_out; i_res++) { + #pragma HLS UNROLL + *(res++) = cast(acc[i_pxl][i_res]); } - index++; } } } @@ -178,42 +116,6 @@ void im2col_1d_pointwise_cl(data_T data[CONFIG_T::in_width * CONFIG_T::n_chan], } } -template -void conv_1d_resource_cl( - data_T data[CONFIG_T::in_width * CONFIG_T::n_chan], - res_T res[CONFIG_T::out_width * CONFIG_T::n_filt], - typename CONFIG_T::weight_t weights[CONFIG_T::filt_width * CONFIG_T::n_chan * CONFIG_T::n_filt], - typename CONFIG_T::bias_t biases[CONFIG_T::n_filt] -) -{ - const int nin = CONFIG_T::n_chan * CONFIG_T::filt_width; - const int nout = CONFIG_T::n_filt; - const int rufactor = CONFIG_T::reuse_factor; - const int block_factor = DIV_ROUNDUP(nin*nout, rufactor); - - //#pragma HLS function_instantiate variable=weights,biases - //#pragma HLS RESOURCE variable=weights core=RAM_2P_BRAM Commenting out the deisgnation HLS seems to choose correctly - //#pragma HLS ARRAY_RESHAPE variable=weights block factor=block_factor - //#pragma HLS ARRAY_PARTITION variable=biases complete - - data_T data_col[CONFIG_T::filt_width * CONFIG_T::n_chan]; - res_T res_col[CONFIG_T::n_filt]; - - #pragma HLS ARRAY_PARTITION variable=data_col complete - #pragma HLS ARRAY_PARTITION variable=res_col complete - - ColLoop: - for (int i = 0; i < CONFIG_T::out_width; i++) { - #pragma HLS PIPELINE - im2col_1d_cl(data, data_col, i); - dense_resource(data_col, res_col, weights, biases); - for (int j = 0; j < CONFIG_T::n_filt; j++) { - res[i * CONFIG_T::n_filt + j] = res_col[j]; - } - } -} - - template void pointwise_conv_1d_resource_cl( data_T data[CONFIG_T::in_width * CONFIG_T::n_chan], diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_conv2d.h b/hls4ml/templates/vivado/nnet_utils/nnet_conv2d.h index 0dc75e9e1..e27d45545 100644 --- a/hls4ml/templates/vivado/nnet_utils/nnet_conv2d.h +++ b/hls4ml/templates/vivado/nnet_utils/nnet_conv2d.h @@ -58,20 +58,6 @@ struct conv2d_config static const unsigned n_zeros = 0; // not used yet }; -template -void conv_2d_cf( - data_T data[CONFIG_T::in_height * CONFIG_T::in_width * CONFIG_T::n_chan], - res_T res[CONFIG_T::out_height * CONFIG_T::out_width * CONFIG_T::n_filt], - typename CONFIG_T::weight_t weights[CONFIG_T::filt_height * CONFIG_T::filt_width * CONFIG_T::n_chan * CONFIG_T::n_filt], - typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) -{ - if (CONFIG_T::strategy == nnet::latency) { - conv_2d_latency_cf(data, res, weights, biases); - } else { - conv_2d_resource_cf(data, res, weights, biases); - } -} - template void conv_2d_cl( data_T data[CONFIG_T::in_height * CONFIG_T::in_width * CONFIG_T::n_chan], @@ -79,6 +65,8 @@ void conv_2d_cl( typename CONFIG_T::weight_t weights[CONFIG_T::filt_height * CONFIG_T::filt_width * CONFIG_T::n_chan * CONFIG_T::n_filt], typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) { + #pragma HLS INLINE region + if (CONFIG_T::strategy == nnet::latency) { conv_2d_latency_cl(data, res, weights, biases); } else { @@ -95,6 +83,8 @@ void pointwise_conv_2d_cl( { assert(CONFIG_T::filt_width == 1); + #pragma HLS INLINE region + if (CONFIG_T::strategy == nnet::latency) { pointwise_conv_2d_latency_cl(data, res, weights, biases); } else { diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_conv2d_latency.h b/hls4ml/templates/vivado/nnet_utils/nnet_conv2d_latency.h index 24132e5c6..7ad3dc821 100644 --- a/hls4ml/templates/vivado/nnet_utils/nnet_conv2d_latency.h +++ b/hls4ml/templates/vivado/nnet_utils/nnet_conv2d_latency.h @@ -2,6 +2,7 @@ #define NNET_CONV2D_LATENCY_H_ #include "nnet_common.h" +#include "nnet_mult.h" #include namespace nnet { @@ -173,113 +174,65 @@ void conv_2d_latency_cl( typename CONFIG_T::weight_t weights[CONFIG_T::filt_height * CONFIG_T::filt_width * CONFIG_T::n_chan * CONFIG_T::n_filt], typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) { + constexpr unsigned mult_n_in = CONFIG_T::filt_height * CONFIG_T::filt_width * CONFIG_T::n_chan; + constexpr unsigned mult_n_out = CONFIG_T::n_filt; - typename CONFIG_T::accum_t mult[CONFIG_T::out_height * CONFIG_T::out_width * CONFIG_T::n_filt * CONFIG_T::n_chan * CONFIG_T::filt_height * CONFIG_T::filt_width]; - typename CONFIG_T::accum_t acc[CONFIG_T::out_height * CONFIG_T::out_width * CONFIG_T::n_filt]; + data_T data_buf[CONFIG_T::n_pixels][mult_n_in]; + #pragma HLS ARRAY_PARTITION variable=data_buf complete dim=0 - #pragma HLS ARRAY_PARTITION variable=mult complete dim=0 - #pragma HLS ARRAY_PARTITION variable=acc complete dim=0 + typename CONFIG_T::accum_t mult[mult_n_in * mult_n_out]; + #pragma HLS ARRAY_PARTITION variable=mult complete - // Use a function_instantiate in case it helps to explicitly optimize unchanging weights/biases - #pragma HLS function_instantiate variable=weights,biases + typename CONFIG_T::accum_t acc[mult_n_out]; + #pragma HLS ARRAY_PARTITION variable=acc complete - // Parallel mode - #pragma HLS PIPELINE - #pragma HLS ARRAY_PARTITION variable=biases complete dim=0 + #pragma HLS ARRAY_PARTITION variable=weights complete + #pragma HLS ARRAY_PARTITION variable=biases complete - // Limit multipliers to control parallelization - const int multiplier_limit = compute_multiplier_limit_conv2d(weights); - #pragma HLS ALLOCATION instances=mul limit=multiplier_limit operation - - // Convolve, saving all multiplication results to accumulate later - ConvOutHeight: for(int oh = 0; oh < CONFIG_T::out_height; oh++) { - ConvOutWidth: for(int ow = 0; ow < CONFIG_T::out_width; ow++) { - ConvFilt: for(int ff = 0; ff < CONFIG_T::n_filt; ff++){ - ConvChan: for(int cc = 0; cc < CONFIG_T::n_chan; cc++){ - ConvFiltHeight: for(int fh = 0; fh < CONFIG_T::filt_height; fh++){ - ConvFiltWidth: for(int fw = 0; fw < CONFIG_T::filt_width; fw++){ - - int index_mult = oh*CONFIG_T::out_width*CONFIG_T::n_filt*CONFIG_T::n_chan*CONFIG_T::filt_height*CONFIG_T::filt_width - + ow*CONFIG_T::n_filt*CONFIG_T::n_chan*CONFIG_T::filt_height*CONFIG_T::filt_width - + ff*CONFIG_T::n_chan*CONFIG_T::filt_height*CONFIG_T::filt_width - + cc*CONFIG_T::filt_height*CONFIG_T::filt_width - + fh*CONFIG_T::filt_width - + fw; + int multiplier_limit = CONFIG_T::n_pixels * (float(mult_n_in * mult_n_out) - float(CONFIG_T::mult_config::n_zeros)); + CONFIG_T::mult_config::template product::limit(multiplier_limit); - int index_weight = fh*CONFIG_T::filt_width*CONFIG_T::n_chan*CONFIG_T::n_filt - + fw*CONFIG_T::n_chan*CONFIG_T::n_filt - + cc*CONFIG_T::n_filt - + ff; + PartitionLoop: + for (int i_part = 0; i_part < CONFIG_T::n_partitions; i_part++) { + #pragma HLS PIPELINE II=CONFIG_T::reuse_factor - if ((oh*CONFIG_T::stride_height+fh) < CONFIG_T::pad_top - || (oh*CONFIG_T::stride_height+fh) >= (CONFIG_T::pad_top+CONFIG_T::in_height) - || (ow*CONFIG_T::stride_width+fw) < CONFIG_T::pad_left - || (ow*CONFIG_T::stride_width+fw) >= (CONFIG_T::pad_left+CONFIG_T::in_width)) { - mult[index_mult] = 0; - } else { - int index_data = (oh*CONFIG_T::stride_height+fh-CONFIG_T::pad_top)*CONFIG_T::in_width*CONFIG_T::n_chan - + (ow*CONFIG_T::stride_width+fw-CONFIG_T::pad_left)*CONFIG_T::n_chan - + cc; - mult[index_mult] = data[index_data] * weights[index_weight]; - } + CONFIG_T::template fill_buffer::fill_buffer(data, data_buf, i_part); - }//end mult loop - }//end channel loop - }//end filter width loop - }//end filter height loop - }//end output width loop - }//end output height loop + PixelLoop: + for (unsigned i_pxl = 0; i_pxl < CONFIG_T::n_pixels; i_pxl++) { + #pragma HLS UNROLL + data_T cache; - // Initialize accumulator with input biases - for(int oh = 0; oh < CONFIG_T::out_height; oh++) { - for(int ow = 0; ow < CONFIG_T::out_width; ow++) { - for(int ff = 0; ff < CONFIG_T::n_filt; ff++) { - acc[oh*CONFIG_T::out_width*CONFIG_T::n_filt + ow*CONFIG_T::n_filt + ff]=biases[ff]; + // Do the matrix-multiply + Product1: for(int i_in = 0; i_in < mult_n_in; i_in++) { + cache = data_buf[i_pxl][i_in]; + Product2: for(int i_out = 0; i_out < mult_n_out; i_out++) { + mult[i_in * mult_n_out + i_out] = CONFIG_T::mult_config::template product::product(cache, weights[i_out * mult_n_in + i_in/*i_in * mult_n_out + i_out*/]); + } } - } - } - - - // Accumulate multiplication result - AccumOutHeight: for(int oh = 0; oh < CONFIG_T::out_height; oh++) { - AccumOutWidth: for(int ow = 0; ow < CONFIG_T::out_width; ow++) { - AccumFilt: for(int ff = 0; ff < CONFIG_T::n_filt; ff++) { - //Do "dot product" sum within filter and sum over channels - AccumChan: for(int cc = 0; cc < CONFIG_T::n_chan; cc++){ - AccumDotHeight: for(int fh = 0; fh < CONFIG_T::filt_height; fh++){ - AccumDotWidth: for(int fw = 0; fw < CONFIG_T::filt_width; fw++){ - - int index_mult = oh*CONFIG_T::out_width*CONFIG_T::n_filt*CONFIG_T::n_chan*CONFIG_T::filt_height*CONFIG_T::filt_width - + ow*CONFIG_T::n_filt*CONFIG_T::n_chan*CONFIG_T::filt_height*CONFIG_T::filt_width - + ff*CONFIG_T::n_chan*CONFIG_T::filt_height*CONFIG_T::filt_width - + cc*CONFIG_T::filt_height*CONFIG_T::filt_width - + fh*CONFIG_T::filt_width - + fw; - int index_acc = oh*CONFIG_T::out_width*CONFIG_T::n_filt - + ow*CONFIG_T::n_filt - + ff; - acc[index_acc] += mult[index_mult]; + // Initialize accumulator with input biases + ResetAccum: for(int i_acc = 0; i_acc < mult_n_out; i_acc++) { + acc[i_acc] = (typename CONFIG_T::accum_t) biases[i_acc]; + } - }//end dot product filter width loop - }//end dot product filter height loop - }//end n channel loop - }//end n filter loop - }//end output width loop - }//end output height loop + // Accumulate multiplication result + Accum1: for(int i_in = 0; i_in < mult_n_in; i_in++) { + Accum2: for(int i_out = 0; i_out < mult_n_out; i_out++) { + acc[i_out] += mult[i_in * mult_n_out + i_out]; + } + } - // Cast to "res_t" type - for(int oh = 0; oh < CONFIG_T::out_height; oh++) { - for(int ow = 0; ow < CONFIG_T::out_width; ow++) { - for(int ff = 0; ff < CONFIG_T::n_filt; ff++) { - int index = oh*CONFIG_T::out_width*CONFIG_T::n_filt + ow*CONFIG_T::n_filt + ff; - res[index] = (res_T)(acc[index]); + // Cast to "res_t" type + Result: for(int i_res = 0; i_res < mult_n_out; i_res++){ + *(res++) = cast(acc[i_res]); } + } } -}//end conv2d +} template void pointwise_conv_2d_latency_cl( diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_conv2d_resource.h b/hls4ml/templates/vivado/nnet_utils/nnet_conv2d_resource.h index 6f6cc0d62..fa7a30b6b 100644 --- a/hls4ml/templates/vivado/nnet_utils/nnet_conv2d_resource.h +++ b/hls4ml/templates/vivado/nnet_utils/nnet_conv2d_resource.h @@ -6,171 +6,97 @@ namespace nnet { -template -void im2col_2d( - data_T data[CONFIG_T::in_height * CONFIG_T::in_width * CONFIG_T::n_chan], - data_T data_col[CONFIG_T::filt_height * CONFIG_T::filt_width * CONFIG_T::n_chan * CONFIG_T::out_height * CONFIG_T::out_width]) -{ - const int output_h = (CONFIG_T::in_height + CONFIG_T::pad_top + CONFIG_T::pad_bottom - - (CONFIG_T::dilation_height * (CONFIG_T::filt_height - 1) + 1)) / CONFIG_T::stride_height + 1; - const int output_w = (CONFIG_T::in_width + CONFIG_T::pad_left + CONFIG_T::pad_right - - (CONFIG_T::dilation_width * (CONFIG_T::filt_width - 1) + 1)) / CONFIG_T::stride_width + 1; - const int channel_size = CONFIG_T::in_height * CONFIG_T::in_width; - - for (int channel = CONFIG_T::n_chan; channel--; data += channel_size) { - for (int kernel_row = 0; kernel_row < CONFIG_T::filt_height; kernel_row++) { - for (int kernel_col = 0; kernel_col < CONFIG_T::filt_width; kernel_col++) { - int input_row = -CONFIG_T::pad_top + kernel_row * CONFIG_T::dilation_height; - for (int output_rows = output_h; output_rows; output_rows--) { - if (input_row < 0 || input_row > CONFIG_T::in_height) { - for (int output_cols = output_w; output_cols; output_cols--) { - *(data_col++) = 0; - } - } else { - int input_col = -CONFIG_T::pad_left + kernel_col * CONFIG_T::dilation_width; - for (int output_col = output_w; output_col; output_col--) { - if (input_col >= 0 && input_col < CONFIG_T::in_width) { - *(data_col++) = data[input_row * CONFIG_T::in_width + input_col]; - } else { - *(data_col++) = 0; - } - input_col += CONFIG_T::stride_width; - } - } - input_row += CONFIG_T::stride_height; - } - } - } - } -} - - template -void conv_2d_full( +void conv_2d_resource_cl( data_T data[CONFIG_T::in_height * CONFIG_T::in_width * CONFIG_T::n_chan], res_T res[CONFIG_T::out_height * CONFIG_T::out_width * CONFIG_T::n_filt], typename CONFIG_T::weight_t weights[CONFIG_T::filt_height * CONFIG_T::filt_width * CONFIG_T::n_chan * CONFIG_T::n_filt], typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) { - data_T data_conv[CONFIG_T::filt_height * CONFIG_T::filt_width * CONFIG_T::n_chan * CONFIG_T::out_height * CONFIG_T::out_width]; - data_T data_col[CONFIG_T::filt_height * CONFIG_T::filt_width * CONFIG_T::n_chan]; - res_T res_col[CONFIG_T::n_filt]; + constexpr unsigned mult_n_in = CONFIG_T::filt_height * CONFIG_T::filt_width * CONFIG_T::n_chan; + constexpr unsigned mult_n_out = CONFIG_T::n_filt; + constexpr unsigned block_factor = DIV_ROUNDUP(mult_n_in * mult_n_out, CONFIG_T::reuse_factor); - //#pragma HLS ARRAY_PARTITION variable=data_conv complete - #pragma HLS ARRAY_PARTITION variable=data_col complete - #pragma HLS ARRAY_PARTITION variable=res_col complete + constexpr unsigned multiplier_limit = DIV_ROUNDUP(mult_n_in * mult_n_out, CONFIG_T::reuse_factor); + constexpr unsigned multscale = multiplier_limit / mult_n_out; - im2col_2d(data, data_conv); + assert((multiplier_limit % mult_n_out == 0 || CONFIG_T::reuse_factor >= mult_n_in) && "The current Reuse Factor is not allowed"); + assert((multiplier_limit == block_factor) && "This function is correct only for RF <= FILT_HEIGHT * FILT_WIDTH * N_CHAN"); - for (int i = 0; i < CONFIG_T::out_height * CONFIG_T::out_width; i++) { - for (int j = 0; j < CONFIG_T::filt_height * CONFIG_T::filt_width * CONFIG_T::n_chan; j++) { - data_col[j] = data[j * CONFIG_T::out_height * CONFIG_T::out_width + i]; - } - dense(data_col, res_col, weights, biases); - for (int j = 0; j < CONFIG_T::n_filt; j++) { - //res[i * CONFIG_T::n_filt + j] = res_col[j]; - res[j * CONFIG_T::out_height * CONFIG_T::out_width + i] = res_col[j]; // Transposed order - } - } -} + data_T data_buf[CONFIG_T::n_pixels][mult_n_in]; + #pragma HLS ARRAY_PARTITION variable=data_buf complete dim=0 -template -void im2col_2d_cf( - data_T data[CONFIG_T::n_chan * CONFIG_T::in_height * CONFIG_T::in_width], - data_T data_col[CONFIG_T::n_chan * CONFIG_T::filt_height * CONFIG_T::filt_width], - const int row, - const int col) -{ - const int channel_size = CONFIG_T::in_height * CONFIG_T::in_width; - int index = 0; - for (int channel = CONFIG_T::n_chan; channel--; data += channel_size) { - #pragma HLS UNROLL - for (int kernel_row = 0; kernel_row < CONFIG_T::filt_height; kernel_row++) { - int input_row = -CONFIG_T::pad_top + kernel_row * CONFIG_T::dilation_height + row * CONFIG_T::stride_height; - for (int kernel_col = 0; kernel_col < CONFIG_T::filt_width; kernel_col++) { - if (input_row < 0 || input_row > CONFIG_T::in_height) { - data_col[index++] = 0; - } else { - int input_col = -CONFIG_T::pad_left + kernel_col * CONFIG_T::dilation_width + col * CONFIG_T::stride_width; - if (input_col >= 0 && input_col < CONFIG_T::in_width) { - //*(data_col++) = data[input_row * CONFIG_T::in_width + input_col]; - data_col[index++] = data[input_row * CONFIG_T::in_width + input_col]; - } else { - //*(data_col++) = 0; - data_col[index++] = 0; - } - input_col += CONFIG_T::stride_width; - } - } - input_row += CONFIG_T::stride_height; - } - } -} + #pragma HLS ARRAY_RESHAPE variable=weights block factor=block_factor + #pragma HLS ARRAY_PARTITION variable=biases complete -template -void conv_2d_resource_cf( - data_T data[CONFIG_T::n_chan * CONFIG_T::in_height * CONFIG_T::in_width], - res_T res[CONFIG_T::out_height * CONFIG_T::out_width * CONFIG_T::n_filt], - typename CONFIG_T::weight_t weights[CONFIG_T::filt_height * CONFIG_T::filt_width * CONFIG_T::n_chan * CONFIG_T::n_filt], - typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) -{ - const int nin = CONFIG_T::n_chan * CONFIG_T::filt_width; - const int nout = CONFIG_T::n_filt; - const int rufactor = CONFIG_T::reuse_factor; - const int block_factor = DIV_ROUNDUP(nin*nout, rufactor); + typename CONFIG_T::accum_t acc[CONFIG_T::n_pixels][mult_n_out]; + #pragma HLS ARRAY_PARTITION variable=acc complete dim=0 - //#pragma HLS function_instantiate variable=weights,biases - //#pragma HLS RESOURCE variable=weights core=RAM_2P_BRAM Commenting out the deisgnation HLS seems to choose correctly - //#pragma HLS ARRAY_RESHAPE variable=weights block factor=block_factor - //#pragma HLS ARRAY_PARTITION variable=biases complete + PartitionLoop: + for (unsigned i_part = 0; i_part < CONFIG_T::n_partitions; i_part++) { + //#pragma HLS UNROLL // We don't want this loop unrolled - data_T data_col[CONFIG_T::filt_height * CONFIG_T::filt_width * CONFIG_T::n_chan]; - res_T res_col[CONFIG_T::n_filt]; + CONFIG_T::template fill_buffer::fill_buffer(data, data_buf, i_part); - #pragma HLS ARRAY_PARTITION variable=data_col complete - #pragma HLS ARRAY_PARTITION variable=res_col complete - - HeightLoop: - for (int i = 0; i < CONFIG_T::out_height; i++) { - WidthLoop: - for (int j = 0; j < CONFIG_T::out_width; j++) { - #pragma HLS PIPELINE - im2col_2d_cf(data, data_col, i, j); - dense(data_col, res_col, weights, biases); - FiltLoop: - for (int k = 0; k < CONFIG_T::n_filt; k++) { - //res[i * CONFIG_T::out_width * CONFIG_T::n_filt + j * CONFIG_T::n_filt + k] = res_col[k]; - res[k * CONFIG_T::out_height * CONFIG_T::out_width + i * CONFIG_T::out_width + j] = res_col[k]; // Transposed order + PixelInitAccumLoop: + for (unsigned i_pxl = 0; i_pxl < CONFIG_T::n_pixels; i_pxl++) { + #pragma HLS UNROLL + + InitAccumLoop: + for (unsigned i_acc = 0; i_acc < mult_n_out; i_acc++) { + #pragma HLS UNROLL + acc[i_pxl][i_acc] = (typename CONFIG_T::accum_t) biases[i_acc]; } } - } -} -template -void im2col_2d_cl( - data_T data[CONFIG_T::in_height * CONFIG_T::in_width * CONFIG_T::n_chan], - data_T data_col[CONFIG_T::filt_height * CONFIG_T::filt_width * CONFIG_T::n_chan], - const int row, - const int col) -{ - int index = 0; - for (int kernel_row = 0; kernel_row < CONFIG_T::filt_height; kernel_row++) { - #pragma HLS UNROLL - int input_row = -CONFIG_T::pad_top + kernel_row * CONFIG_T::dilation_height + row * CONFIG_T::stride_height; - for (int kernel_col = 0; kernel_col < CONFIG_T::filt_width; kernel_col++) { - for (int channel = 0; channel < CONFIG_T::n_chan; channel++) { - if (input_row < 0 || input_row >= CONFIG_T::in_height) { - data_col[index++] = 0; + ReuseLoop: + for (unsigned i_rf = 0; i_rf < CONFIG_T::reuse_factor; i_rf++) { + #pragma HLS PIPELINE II=1 rewind + + unsigned i_w = i_rf; + unsigned i_in = i_rf; + unsigned i_out = 0; + unsigned i_acc = 0; + + MultLoop: + for (unsigned i_blk = 0; i_blk < block_factor; i_blk++) { + #pragma HLS UNROLL + + PixelMultLoop: + for (unsigned i_pxl = 0; i_pxl < CONFIG_T::n_pixels; i_pxl++) { + #pragma HLS UNROLL + + acc[i_pxl][i_out] += static_cast( + CONFIG_T::mult_config::template product::product(data_buf[i_pxl][i_in], weights[i_w])); + } + + // Increment i_w + i_w += CONFIG_T::reuse_factor; + // Increment i_in + i_in += CONFIG_T::reuse_factor; + if (i_in >= mult_n_in) { + i_in = i_rf; + } + // Increment i_out + if (i_acc + 1 >= multscale) { + i_acc = 0; + i_out++; } else { - int input_col = -CONFIG_T::pad_left + kernel_col * CONFIG_T::dilation_width + col * CONFIG_T::stride_width; - if (input_col >= 0 && input_col < CONFIG_T::in_width) { - data_col[index++] = data[input_row * CONFIG_T::in_width * CONFIG_T::n_chan + input_col * CONFIG_T::n_chan + channel]; - } else { - data_col[index++] = 0; - } + i_acc++; } } } + + PixelResultLoop: + for (unsigned i_pxl = 0; i_pxl < CONFIG_T::n_pixels; i_pxl++) { + #pragma HLS UNROLL + // Cast to "res_t" type + ResultLoop: + for (unsigned i_res = 0; i_res < mult_n_out; i_res++) { + #pragma HLS UNROLL + *(res++) = cast(acc[i_pxl][i_res]); + } + } } } @@ -200,45 +126,6 @@ void im2col_2d_pointwise_cl( } } -template -void conv_2d_resource_cl( - data_T data[CONFIG_T::in_height * CONFIG_T::in_width * CONFIG_T::n_chan], - res_T res[CONFIG_T::out_height * CONFIG_T::out_width * CONFIG_T::n_filt], - typename CONFIG_T::weight_t weights[CONFIG_T::filt_height * CONFIG_T::filt_width * CONFIG_T::n_chan * CONFIG_T::n_filt], - typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) -{ - const int nin = CONFIG_T::n_chan * CONFIG_T::filt_width; - const int nout = CONFIG_T::n_filt; - const int rufactor = CONFIG_T::reuse_factor; - const int block_factor = DIV_ROUNDUP(nin*nout, rufactor); - - //#pragma HLS function_instantiate variable=weights,biases - //#pragma HLS RESOURCE variable=weights core=RAM_2P_BRAM Commenting out the deisgnation HLS seems to choose correctly - //#pragma HLS ARRAY_RESHAPE variable=weights block factor=block_factor - //#pragma HLS ARRAY_PARTITION variable=biases complete - - data_T data_col[CONFIG_T::filt_height * CONFIG_T::filt_width * CONFIG_T::n_chan]; - res_T res_col[CONFIG_T::n_filt]; - - #pragma HLS ARRAY_PARTITION variable=data_col complete - #pragma HLS ARRAY_PARTITION variable=res_col complete - - HeightLoop: - for (int i = 0; i < CONFIG_T::out_height; i++) { - WidthLoop: - for (int j = 0; j < CONFIG_T::out_width; j++) { - #pragma HLS PIPELINE - im2col_2d_cl(data, data_col, i, j); - dense(data_col, res_col, weights, biases); - FiltLoop: - for (int k = 0; k < CONFIG_T::n_filt; k++) { - res[i * CONFIG_T::out_width * CONFIG_T::n_filt + j * CONFIG_T::n_filt + k] = res_col[k]; - } - } - } -} - - template void pointwise_conv_2d_resource_cl( data_T data[CONFIG_T::in_height * CONFIG_T::in_width * CONFIG_T::n_chan], diff --git a/hls4ml/writer/vivado_writer.py b/hls4ml/writer/vivado_writer.py index 202f0c5b0..24418a993 100644 --- a/hls4ml/writer/vivado_writer.py +++ b/hls4ml/writer/vivado_writer.py @@ -596,6 +596,29 @@ def write_nnet_utils(self, model): dstpath = '{}/firmware/{}'.format(model.config.get_output_dir(), dst) copyfile(srcpath, dstpath) + def write_generated_code(self, model): + ################### + ## nnet_code_gen.h + ################### + + path = '{}/firmware/nnet_utils/nnet_code_gen.h'.format(model.config.get_output_dir()) + f = open(path,'r') + contents = f.readlines() + f.close() + f = open(path,'w') + + for line in contents: + if '//hls4ml insert code' in line: + newline = line + for layer in model.get_layers(): + for generated_code in layer.code.values(): + newline += str(generated_code) + else: + newline = line + f.write(newline) + f.close() + + def write_yml(self, model): ################### # YAML config file @@ -635,6 +658,7 @@ def write_hls(self, model): self.write_bridge(model) self.write_build_script(model) self.write_nnet_utils(model) + self.write_generated_code(model) self.write_yml(model) self.write_tar(model) print('Done') From 1894e00d9fea383880515126ad58969cc80b6e62 Mon Sep 17 00:00:00 2001 From: Vladimir Loncar Date: Fri, 8 Jul 2022 19:19:25 +0200 Subject: [PATCH 03/10] Use the same implementaion of im2col in python and C++ --- hls4ml/backends/fpga/fpga_backend.py | 90 ++++++++++++++++++---------- 1 file changed, 58 insertions(+), 32 deletions(-) diff --git a/hls4ml/backends/fpga/fpga_backend.py b/hls4ml/backends/fpga/fpga_backend.py index cca1a2af3..cf7e3c6a0 100644 --- a/hls4ml/backends/fpga/fpga_backend.py +++ b/hls4ml/backends/fpga/fpga_backend.py @@ -405,25 +405,31 @@ def compute_conv2d_instructions(self, in_H, in_W, in_C, kernel_size=3, stride=1, return (min_H, min_W, windows_int) - def _compute_conv1d_im2col(self, input_shape, kernel=3, stride=1, pad=(0,0)): + def _compute_conv1d_im2col(self, input_shape, kernel=3, stride=1, pad=(0,0), dilation=1): W, C = input_shape pad_l, pad_r = pad - out_w = (W + pad_l + pad_r - kernel) // stride + 1 + out_w = (W + pad_l + pad_r - (dilation * (kernel - 1) + 1)) // stride + 1 - input_img = np.arange(1, W * C + 1).reshape(W, C) + input_img = np.arange(1, W * C + 1) + im_matrix = np.zeros((kernel * C * out_w, )) - img = np.pad(input_img, [(pad_l, pad_r), (0,0)], 'constant') - col = np.zeros((out_w, kernel, C)) - - for x in range(kernel): - x_max = x + stride * out_w - col[:, x, :] = img[x:x_max:stride, :] + index = 0 + for i_ow in range(out_w): + for i_kw in range(kernel): + for i_c in range(C): + input_col = -pad_l + i_kw * dilation + i_ow * stride + if (input_col >= 0 and input_col < W): + im_matrix[index] = input_img[input_col * C + i_c] + else: + im_matrix[index] = 0 + index += 1 + + im_matrix = im_matrix.reshape(out_w, -1) + return im_matrix - col = col.reshape(out_w, -1) - return col - def generate_conv1d_line_buffer_fn(self, layer_idx, n_partitions, in_W, in_C, kernel=3, stride=1, pad=0): + def generate_conv1d_line_buffer_fn(self, layer_idx, n_partitions, in_W, in_C, kernel=3, stride=1, pad=0, dilation=1): if isinstance(pad, Iterable): pad_left = pad[0] pad_right = pad[1] @@ -435,7 +441,8 @@ def generate_conv1d_line_buffer_fn(self, layer_idx, n_partitions, in_W, in_C, ke (in_W, in_C), kernel, stride, - (pad_left, pad_right) + (pad_left, pad_right), + dilation ) generated_code = ( @@ -468,30 +475,41 @@ def generate_conv1d_line_buffer_fn(self, layer_idx, n_partitions, in_W, in_C, ke return generated_code - def _compute_conv2d_im2col(self, input_shape, kernel=(3,3), stride=(1,1), pad=(0,0,0,0)): + def _compute_conv2d_im2col(self, input_shape, kernel=(3, 3), stride=(1, 1), pad=(0, 0, 0, 0), dilation=(1,1)): H, W, C = input_shape kernel_h, kernel_w = kernel stride_h, stride_w = stride pad_t, pad_b, pad_l, pad_r = pad + dilation_h, dilation_w = dilation + + out_h = (H + pad_t + pad_b - (dilation_h * (kernel_h - 1) + 1)) // stride_h + 1 + out_w = (W + pad_l + pad_r - (dilation_w * (kernel_w - 1) + 1)) // stride_w + 1 + + input_img = np.arange(1, H * W * C + 1) + im_matrix = np.zeros((kernel_h * kernel_w * C * out_h * out_w, )) + + index = 0 + for i_oh in range(out_h): + for i_ow in range(out_w): + for i_kh in range(kernel_h): + input_row = -pad_t + i_kh * dilation_h + i_oh * stride_h + for i_kw in range(kernel_w): + for i_c in range(C): + if (input_row < 0 or input_row >= H): + im_matrix[index] = 0 + else: + input_col = -pad_l + i_kw * dilation_w + i_ow * stride_w + if (input_col >= 0 and input_col < W): + im_matrix[index] = input_img[input_row * W * C + input_col * C + i_c] + else: + im_matrix[index] = 0 + index += 1 + + im_matrix = im_matrix.reshape(out_h * out_w, -1) + return im_matrix - out_h = (H + pad_t + pad_b - kernel_h) // stride_h + 1 - out_w = (W + pad_l + pad_r - kernel_w) // stride_w + 1 - - input_img = np.arange(1, C * H * W + 1).reshape(C, H, W) - - img = np.pad(input_img, [(0,0), (pad_t, pad_b), (pad_l, pad_r)], 'constant') - col = np.zeros((C, kernel_h, kernel_w, out_h, out_w)) - - for y in range(kernel_h): - y_max = y + stride_h * out_h - for x in range(kernel_w): - x_max = x + stride_w * out_w - col[:, y, x, :, :] = img[:, y:y_max:stride_h, x:x_max:stride_w] - - col = col.transpose(3, 4, 0, 1, 2).reshape(out_h * out_w, -1) - return col - def generate_conv2d_line_buffer_fn(self, layer_idx, n_partitions, in_H, in_W, in_C, kernel=(3, 3), stride=(1, 1), pad=(0, 0, 0, 0)): + def generate_conv2d_line_buffer_fn(self, layer_idx, n_partitions, in_H, in_W, in_C, kernel=(3, 3), stride=(1, 1), pad=(0, 0, 0, 0), dilation=(1, 1)): if isinstance(kernel, Iterable): kernel_height = kernel[0] kernel_width = kernel[1] @@ -517,11 +535,19 @@ def generate_conv2d_line_buffer_fn(self, layer_idx, n_partitions, in_H, in_W, in pad_left = pad pad_right = pad + if isinstance(dilation, Iterable): + dilation_height = dilation[0] + dilation_width = dilation[1] + else: + dilation_height = dilation + dilation_width = dilation + im2col_matrix = self._compute_conv2d_im2col( (in_H, in_W, in_C), (kernel_height, kernel_width), (stride_height, stride_width), - (pad_top, pad_bottom, pad_left, pad_right) + (pad_top, pad_bottom, pad_left, pad_right), + (dilation_height, dilation_width) ) generated_code = ( From bb9d9af921ade7bdd32ad38c6962037f236b4614 Mon Sep 17 00:00:00 2001 From: Vladimir Loncar Date: Wed, 13 Jul 2022 19:08:54 +0200 Subject: [PATCH 04/10] Ensure 'Resource' strategy is used for Conv1D/2D --- hls4ml/backends/vivado/vivado_backend.py | 9 +++++++++ hls4ml/templates/vivado/nnet_utils/nnet_conv2d_latency.h | 8 ++++---- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/hls4ml/backends/vivado/vivado_backend.py b/hls4ml/backends/vivado/vivado_backend.py index 1d2312554..eb0b89f27 100644 --- a/hls4ml/backends/vivado/vivado_backend.py +++ b/hls4ml/backends/vivado/vivado_backend.py @@ -120,6 +120,11 @@ def build(self, model, reset=False, csim=True, synth=True, cosim=False, validati return parse_vivado_report(model.config.get_output_dir()) + def _validate_conv_strategy(self, layer): + if layer.model.config.model_strategy.lower() != 'resource': + print('WARNING: Cannot use "Latency" model strategy for {} layer. Switching to "Resource" strategy.') + layer.model.config.model_strategy = 'Resource' + @layer_optimizer(Layer) def init_base_layer(self, layer): reuse_factor = layer.model.config.get_reuse_factor(layer) @@ -172,6 +177,8 @@ def init_conv1d(self, layer): layer.set_attr('implementation', layer.model.config.get_conv_implementation(layer).lower()) + self._validate_conv_strategy(layer) + @layer_optimizer(SeparableConv1D) def init_sepconv1d(self, layer): if layer.model.config.is_resource_strategy(layer): @@ -210,6 +217,8 @@ def init_conv2d(self, layer): layer.set_attr('implementation', layer.model.config.get_conv_implementation(layer).lower()) + self._validate_conv_strategy(layer) + @layer_optimizer(SeparableConv2D) def init_sepconv2d(self, layer): if layer.model.config.is_resource_strategy(layer): diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_conv2d_latency.h b/hls4ml/templates/vivado/nnet_utils/nnet_conv2d_latency.h index 7ad3dc821..bf177928a 100644 --- a/hls4ml/templates/vivado/nnet_utils/nnet_conv2d_latency.h +++ b/hls4ml/templates/vivado/nnet_utils/nnet_conv2d_latency.h @@ -169,8 +169,8 @@ void conv_2d_latency_cf( template void conv_2d_latency_cl( - data_T data[CONFIG_T::in_height*CONFIG_T::in_width*CONFIG_T::n_chan], - res_T res[CONFIG_T::out_height*CONFIG_T::out_width*CONFIG_T::n_filt], + data_T data[CONFIG_T::in_height * CONFIG_T::in_width * CONFIG_T::n_chan], + res_T res[CONFIG_T::out_height * CONFIG_T::out_width * CONFIG_T::n_filt], typename CONFIG_T::weight_t weights[CONFIG_T::filt_height * CONFIG_T::filt_width * CONFIG_T::n_chan * CONFIG_T::n_filt], typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) { @@ -189,7 +189,7 @@ void conv_2d_latency_cl( #pragma HLS ARRAY_PARTITION variable=weights complete #pragma HLS ARRAY_PARTITION variable=biases complete - int multiplier_limit = CONFIG_T::n_pixels * (float(mult_n_in * mult_n_out) - float(CONFIG_T::mult_config::n_zeros)); + int multiplier_limit = CONFIG_T::n_pixels * (ceil(float(mult_n_in * mult_n_out) / float(CONFIG_T::reuse_factor)) - floor(float(CONFIG_T::mult_config::n_zeros) / float(CONFIG_T::reuse_factor))); CONFIG_T::mult_config::template product::limit(multiplier_limit); PartitionLoop: @@ -208,7 +208,7 @@ void conv_2d_latency_cl( Product1: for(int i_in = 0; i_in < mult_n_in; i_in++) { cache = data_buf[i_pxl][i_in]; Product2: for(int i_out = 0; i_out < mult_n_out; i_out++) { - mult[i_in * mult_n_out + i_out] = CONFIG_T::mult_config::template product::product(cache, weights[i_out * mult_n_in + i_in/*i_in * mult_n_out + i_out*/]); + mult[i_in * mult_n_out + i_out] = CONFIG_T::mult_config::template product::product(cache, weights[i_in * mult_n_out + i_out]); } } From 41d9cd68608cd23229bd2a7dfb943c9a950d9382 Mon Sep 17 00:00:00 2001 From: Vladimir Loncar Date: Wed, 13 Jul 2022 19:18:40 +0200 Subject: [PATCH 05/10] Add 'Latency' implementation of Conv1D --- .../vivado/nnet_utils/nnet_conv1d_latency.h | 147 ++++++++---------- 1 file changed, 68 insertions(+), 79 deletions(-) diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_conv1d_latency.h b/hls4ml/templates/vivado/nnet_utils/nnet_conv1d_latency.h index f79903ee2..c48565bf5 100644 --- a/hls4ml/templates/vivado/nnet_utils/nnet_conv1d_latency.h +++ b/hls4ml/templates/vivado/nnet_utils/nnet_conv1d_latency.h @@ -2,10 +2,78 @@ #define NNET_CONV1D_LATENCY_H_ #include "nnet_common.h" +#include "nnet_mult.h" #include namespace nnet { +template +void conv_1d_latency_cl( + data_T data[CONFIG_T::in_width * CONFIG_T::n_chan], + res_T res[CONFIG_T::out_width * CONFIG_T::n_filt], + typename CONFIG_T::weight_t weights[CONFIG_T::filt_width * CONFIG_T::n_chan * CONFIG_T::n_filt], + typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) +{ + constexpr unsigned mult_n_in = CONFIG_T::filt_width * CONFIG_T::n_chan; + constexpr unsigned mult_n_out = CONFIG_T::n_filt; + + data_T data_buf[CONFIG_T::n_pixels][mult_n_in]; + #pragma HLS ARRAY_PARTITION variable=data_buf complete dim=0 + + typename CONFIG_T::accum_t mult[mult_n_in * mult_n_out]; + #pragma HLS ARRAY_PARTITION variable=mult complete + + typename CONFIG_T::accum_t acc[mult_n_out]; + #pragma HLS ARRAY_PARTITION variable=acc complete + + #pragma HLS ARRAY_PARTITION variable=weights complete + #pragma HLS ARRAY_PARTITION variable=biases complete + + int multiplier_limit = CONFIG_T::n_pixels * (ceil(float(mult_n_in * mult_n_out) / float(CONFIG_T::reuse_factor)) - floor(float(CONFIG_T::mult_config::n_zeros) / float(CONFIG_T::reuse_factor))); + CONFIG_T::mult_config::template product::limit(multiplier_limit); + + PartitionLoop: + for (int i_part = 0; i_part < CONFIG_T::n_partitions; i_part++) { + #pragma HLS PIPELINE II=CONFIG_T::reuse_factor + + CONFIG_T::template fill_buffer::fill_buffer(data, data_buf, i_part); + + PixelLoop: + for (unsigned i_pxl = 0; i_pxl < CONFIG_T::n_pixels; i_pxl++) { + #pragma HLS UNROLL + + data_T cache; + + // Do the matrix-multiply + Product1: for(int i_in = 0; i_in < mult_n_in; i_in++) { + cache = data_buf[i_pxl][i_in]; + Product2: for(int i_out = 0; i_out < mult_n_out; i_out++) { + mult[i_in * mult_n_out + i_out] = CONFIG_T::mult_config::template product::product(cache, weights[i_in * mult_n_out + i_out]); + } + } + + // Initialize accumulator with input biases + ResetAccum: for(int i_acc = 0; i_acc < mult_n_out; i_acc++) { + acc[i_acc] = (typename CONFIG_T::accum_t) biases[i_acc]; + } + + // Accumulate multiplication result + Accum1: for(int i_in = 0; i_in < mult_n_in; i_in++) { + Accum2: for(int i_out = 0; i_out < mult_n_out; i_out++) { + acc[i_out] += mult[i_in * mult_n_out + i_out]; + } + } + + // Cast to "res_t" type + Result: for(int i_res = 0; i_res < mult_n_out; i_res++){ + *(res++) = cast(acc[i_res]); + } + + } + } + +} + //Computes multiplier limit //This function should not be synthesized into firmware template @@ -39,85 +107,6 @@ int compute_multiplier_limit( }//end compute_n_mult - -template -void conv_1d_latency_cl( - data_T data[CONFIG_T::in_width * CONFIG_T::n_chan], - res_T res[CONFIG_T::out_width * CONFIG_T::n_filt], - typename CONFIG_T::weight_t weights[CONFIG_T::filt_width * CONFIG_T::n_chan * CONFIG_T::n_filt], - typename CONFIG_T::bias_t biases[CONFIG_T::n_filt] -) -{ - - typename CONFIG_T::accum_t mult[CONFIG_T::out_width * CONFIG_T::n_filt * CONFIG_T::n_chan * CONFIG_T::filt_width]; - typename CONFIG_T::accum_t acc[CONFIG_T::out_width][CONFIG_T::n_filt]; - - #pragma HLS ARRAY_PARTITION variable=mult complete dim=0 - #pragma HLS ARRAY_PARTITION variable=acc complete dim=0 - - // Use a function_instantiate in case it helps to explicitly optimize unchanging weights/biases - #pragma HLS function_instantiate variable=weights,biases - - // Parallel mode - #pragma HLS PIPELINE - #pragma HLS ARRAY_PARTITION variable=biases complete dim=0 - - // Limit multipliers to control parallelization - const int multiplier_limit = compute_multiplier_limit(weights); - #pragma HLS ALLOCATION instances=mul limit=multiplier_limit operation - - // Convolve, saving all multiplication results to accumulate later - ConvOut: for(int ii = 0; ii < CONFIG_T::out_width; ii++) { - ConvFilt: for(int ff = 0; ff < CONFIG_T::n_filt; ff++){ - ConvChan: for(int cc = 0; cc < CONFIG_T::n_chan; cc++){ - ConvMult: for(int jj = 0; jj < CONFIG_T::filt_width; jj++){ - - int index_mult = ii*CONFIG_T::n_filt*CONFIG_T::n_chan*CONFIG_T::filt_width + ff*CONFIG_T::n_chan*CONFIG_T::filt_width + cc*CONFIG_T::filt_width + jj; - int index_weight = jj*CONFIG_T::n_chan*CONFIG_T::n_filt + cc*CONFIG_T::n_filt + ff; - int index_data = (ii*CONFIG_T::stride_width+jj-CONFIG_T::pad_left) * CONFIG_T::n_chan + cc; - - if((ii*CONFIG_T::stride_width+jj) < CONFIG_T::pad_left || (ii*CONFIG_T::stride_width+jj) >= (CONFIG_T::pad_left + CONFIG_T::in_width)){ - mult[index_mult] = 0; - } - else { - mult[index_mult] = data[index_data] * weights[index_weight]; - } - } - }//end channel loop - }//end filter loop - }//end output loop - - - // Initialize accumulator with input biases - for(int ii = 0; ii < CONFIG_T::out_width; ii++) { - for(int ff = 0; ff < CONFIG_T::n_filt; ff++) { - acc[ii][ff]=biases[ff]; - } - } - - - // Accumulate multiplication result - AccumOut: for(int ii = 0; ii < CONFIG_T::out_width; ii++) { - AccumFilt: for(int ff = 0; ff < CONFIG_T::n_filt; ff++) { - //Do "dot product" sum within filter and sum over channels - AccumChan: for(int cc = 0; cc < CONFIG_T::n_chan; cc++){ - AccumDot: for(int jj = 0; jj < CONFIG_T::filt_width; jj++){ - int index_mult = ii*CONFIG_T::n_filt*CONFIG_T::n_chan*CONFIG_T::filt_width + ff*CONFIG_T::n_chan*CONFIG_T::filt_width + cc*CONFIG_T::filt_width + jj; - acc[ii][ff] += mult[index_mult]; - }//end dot product loop - }//end channel loop - }//end filter loop - }//end output loop - - - // Cast to "res_t" type - for(int ii = 0; ii < CONFIG_T::out_width; ii++) { - for(int ff = 0; ff < CONFIG_T::n_filt; ff++) { - res[ii * CONFIG_T::n_filt + ff] = (res_T)(acc[ii][ff]); - } - } -} - template void pointwise_conv_1d_latency_cl( data_T data[CONFIG_T::in_width * CONFIG_T::n_chan], From 8e539fc85cacdbbeec53ac9da6e0f2fd6d0fc300 Mon Sep 17 00:00:00 2001 From: Vladimir Loncar Date: Fri, 5 Aug 2022 18:27:50 +0200 Subject: [PATCH 06/10] Explicitly partition the pool array --- hls4ml/templates/vivado/nnet_utils/nnet_pooling.h | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_pooling.h b/hls4ml/templates/vivado/nnet_utils/nnet_pooling.h index 5267a58fc..a2887c5df 100644 --- a/hls4ml/templates/vivado/nnet_utils/nnet_pooling.h +++ b/hls4ml/templates/vivado/nnet_utils/nnet_pooling.h @@ -120,6 +120,7 @@ void pooling1d_cl( // Loop over input image x in steps of stride for(int ii = 0; ii < padded_width; ii += CONFIG_T::stride_width) { data_T pool[CONFIG_T::pool_width]; + #pragma HLS ARRAY_PARTITION variable=pool complete dim=0 // Keep track of number of pixels in image vs padding region unsigned img_overlap = 0; // Loop over pool window x @@ -162,6 +163,7 @@ void global_pooling1d_cl( for(int ff = 0; ff < CONFIG_T::n_filt; ff++) { data_T pool[CONFIG_T::n_in]; + #pragma HLS ARRAY_PARTITION variable=pool complete dim=0 for(int jj = 0; jj < CONFIG_T::n_in; jj++) { pool[jj] = data[jj * CONFIG_T::n_filt + ff]; } @@ -224,6 +226,7 @@ void pooling2d_cl( // Loop over input image x in steps of stride for(int jj = 0; jj < padded_width; jj += CONFIG_T::stride_width){ data_T pool[CONFIG_T::pool_height * CONFIG_T::pool_width]; + #pragma HLS ARRAY_PARTITION variable=pool complete dim=0 // Keep track of number of pixels in image vs padding region unsigned img_overlap = 0; // Loop over pool window y @@ -278,6 +281,7 @@ void pooling2d_cf( // Loop over input image x in steps of stride for(int jj = 0; jj < padded_width; jj += CONFIG_T::stride_width){ data_T pool[CONFIG_T::pool_height * CONFIG_T::pool_width]; + #pragma HLS ARRAY_PARTITION variable=pool complete dim=0 // Keep track of number of pixels in image vs padding region unsigned img_overlap = 0; // Loop over pool window y From 96d059c90db37152790336d606e774af7c357f08 Mon Sep 17 00:00:00 2001 From: Vladimir Loncar Date: Tue, 4 Oct 2022 03:00:35 +0200 Subject: [PATCH 07/10] Remove old pointwise implementations --- .../templates/vivado/nnet_utils/nnet_conv1d.h | 5 +- .../vivado/nnet_utils/nnet_conv1d_latency.h | 106 ------------------ .../vivado/nnet_utils/nnet_conv1d_resource.h | 55 --------- .../templates/vivado/nnet_utils/nnet_conv2d.h | 5 +- .../vivado/nnet_utils/nnet_conv2d_latency.h | 100 ----------------- .../vivado/nnet_utils/nnet_conv2d_resource.h | 67 ----------- 6 files changed, 6 insertions(+), 332 deletions(-) diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_conv1d.h b/hls4ml/templates/vivado/nnet_utils/nnet_conv1d.h index bde62161e..323cb5b81 100644 --- a/hls4ml/templates/vivado/nnet_utils/nnet_conv1d.h +++ b/hls4ml/templates/vivado/nnet_utils/nnet_conv1d.h @@ -78,10 +78,11 @@ void pointwise_conv_1d_cl( #pragma HLS INLINE region + // Nothing special to be done for io_parallel implementation if (CONFIG_T::strategy == nnet::latency) { - pointwise_conv_1d_latency_cl(data, res, weights, biases); + conv_1d_latency_cl(data, res, weights, biases); } else { - pointwise_conv_1d_resource_cl(data, res, weights, biases); + conv_1d_resource_cl(data, res, weights, biases); } } diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_conv1d_latency.h b/hls4ml/templates/vivado/nnet_utils/nnet_conv1d_latency.h index c48565bf5..65f71d080 100644 --- a/hls4ml/templates/vivado/nnet_utils/nnet_conv1d_latency.h +++ b/hls4ml/templates/vivado/nnet_utils/nnet_conv1d_latency.h @@ -74,111 +74,5 @@ void conv_1d_latency_cl( } -//Computes multiplier limit -//This function should not be synthesized into firmware -template -int compute_multiplier_limit( - typename CONFIG_T::weight_t weights[CONFIG_T::filt_width * CONFIG_T::n_chan * CONFIG_T::n_filt] -) -{ - int n_mult = 0; - for(int ii = 0; ii < CONFIG_T::out_width; ii++) { - for(int ff = 0; ff < CONFIG_T::n_filt; ff++){ - for(int cc = 0; cc < CONFIG_T::n_chan; cc++){ - for(int jj = 0; jj < CONFIG_T::filt_width; jj++){ - - int index_weight = jj*CONFIG_T::n_chan*CONFIG_T::n_filt + cc*CONFIG_T::n_filt + ff; - - if((ii*CONFIG_T::stride_width+jj) < CONFIG_T::pad_left || (ii*CONFIG_T::stride_width+jj) >= (CONFIG_T::pad_left + CONFIG_T::in_width)){ - //padded -- do nothing - continue; - } else { - //need to tune this cut? - if( weights[index_weight] > 1e-20 || weights[index_weight] < -1e-20 ){ - n_mult++; - }//end if nonzero weight - }//end not padding - }//end loop accross filter - }//end channel loop - }//end filter loop - }//end output loop - - return ceil( float(n_mult) / float(CONFIG_T::reuse_factor) ); - -}//end compute_n_mult - -template -void pointwise_conv_1d_latency_cl( - data_T data[CONFIG_T::in_width * CONFIG_T::n_chan], - res_T res[CONFIG_T::out_width * CONFIG_T::n_filt], - typename CONFIG_T::weight_t weights[CONFIG_T::n_chan * CONFIG_T::n_filt], - typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) -{ - assert(CONFIG_T::filt_width == 1); - - typename CONFIG_T::accum_t mult[CONFIG_T::out_width * CONFIG_T::n_filt * CONFIG_T::n_chan]; - typename CONFIG_T::accum_t acc[CONFIG_T::out_width][CONFIG_T::n_filt]; - - #pragma HLS ARRAY_PARTITION variable=mult complete dim=0 - #pragma HLS ARRAY_PARTITION variable=acc complete dim=0 - - // Use a function_instantiate in case it helps to explicitly optimize unchanging weights/biases - #pragma HLS function_instantiate variable=weights,biases - - // Parallel mode - #pragma HLS PIPELINE - #pragma HLS ARRAY_PARTITION variable=biases complete dim=0 - - // Limit multipliers to control parallelization - const int multiplier_limit = compute_multiplier_limit(weights); - #pragma HLS ALLOCATION instances=mul limit=multiplier_limit operation - - // Convolve, saving all multiplication results to accumulate later - ConvOut: for(int ii = 0; ii < CONFIG_T::out_width; ii++) { - ConvFilt: for(int ff = 0; ff < CONFIG_T::n_filt; ff++) { - ConvChan: for(int cc = 0; cc < CONFIG_T::n_chan; cc++) { - int index_mult = ii*CONFIG_T::n_filt*CONFIG_T::n_chan + ff*CONFIG_T::n_chan + cc; - int index_weight = cc*CONFIG_T::n_filt + ff; - int index_data = (ii*CONFIG_T::stride_width-CONFIG_T::pad_left) * CONFIG_T::n_chan + cc; - - if((ii*CONFIG_T::stride_width) < CONFIG_T::pad_left || (ii*CONFIG_T::stride_width) >= (CONFIG_T::pad_left + CONFIG_T::in_width)){ - mult[index_mult] = 0; - } - else { - mult[index_mult] = data[index_data] * weights[index_weight]; - } - }//end channel loop - }//end filter loop - }//end output loop - - - // Initialize accumulator with input biases - for(int ii = 0; ii < CONFIG_T::out_width; ii++) { - for(int ff = 0; ff < CONFIG_T::n_filt; ff++) { - acc[ii][ff]=biases[ff]; - } - } - - - // Accumulate multiplication result - AccumOut: for(int ii = 0; ii < CONFIG_T::out_width; ii++) { - AccumFilt: for(int ff = 0; ff < CONFIG_T::n_filt; ff++) { - //Do "dot product" sum within filter and sum over channels - AccumChan: for(int cc = 0; cc < CONFIG_T::n_chan; cc++) { - int index_mult = ii*CONFIG_T::n_filt*CONFIG_T::n_chan + ff*CONFIG_T::n_chan + cc; - acc[ii][ff] += mult[index_mult]; - }//end channel loop - }//end filter loop - }//end output loop - - - // Cast to "res_t" type - for(int ii = 0; ii < CONFIG_T::out_width; ii++) { - for(int ff = 0; ff < CONFIG_T::n_filt; ff++) { - res[ii * CONFIG_T::n_filt + ff] = (res_T)(acc[ii][ff]); - } - } -} - } #endif diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_conv1d_resource.h b/hls4ml/templates/vivado/nnet_utils/nnet_conv1d_resource.h index 2ce906be1..0467c7111 100644 --- a/hls4ml/templates/vivado/nnet_utils/nnet_conv1d_resource.h +++ b/hls4ml/templates/vivado/nnet_utils/nnet_conv1d_resource.h @@ -98,60 +98,5 @@ void conv_1d_resource_cl( } } -template -void im2col_1d_pointwise_cl(data_T data[CONFIG_T::in_width * CONFIG_T::n_chan], data_T data_col[CONFIG_T::n_chan], const int col) { - int index = 0; - ChannelLoop: - for (int channel = 0; channel < CONFIG_T::n_chan; channel++) { - #pragma HLS UNROLL - - int index_data = (col*CONFIG_T::stride_width-CONFIG_T::pad_left) * CONFIG_T::n_chan + channel; - - if (index_data >= 0 && index_data < CONFIG_T::in_width*CONFIG_T::n_chan) { - data_col[index] = data[index_data]; - } else { - data_col[index] = 0; - } - index++; - } -} - -template -void pointwise_conv_1d_resource_cl( - data_T data[CONFIG_T::in_width * CONFIG_T::n_chan], - res_T res[CONFIG_T::out_width * CONFIG_T::n_filt], - typename CONFIG_T::weight_t weights[CONFIG_T::n_chan * CONFIG_T::n_filt], - typename CONFIG_T::bias_t biases[CONFIG_T::n_filt] -) -{ - assert(CONFIG_T::filt_width == 1); - - const int nin = CONFIG_T::n_chan; - const int nout = CONFIG_T::n_filt; - const int rufactor = CONFIG_T::reuse_factor; - const int block_factor = DIV_ROUNDUP(nin*nout, rufactor); - - //#pragma HLS function_instantiate variable=weights,biases - //#pragma HLS RESOURCE variable=weights core=RAM_2P_BRAM Commenting out the deisgnation HLS seems to choose correctly - //#pragma HLS ARRAY_RESHAPE variable=weights block factor=block_factor - //#pragma HLS ARRAY_PARTITION variable=biases complete - - data_T data_col[CONFIG_T::n_chan]; - res_T res_col[CONFIG_T::n_filt]; - - #pragma HLS ARRAY_PARTITION variable=data_col complete - #pragma HLS ARRAY_PARTITION variable=res_col complete - - ColLoop: - for (int i = 0; i < CONFIG_T::out_width; i++) { - #pragma HLS PIPELINE - im2col_1d_pointwise_cl(data, data_col, i); - dense_resource(data_col, res_col, weights, biases); - for (int j = 0; j < CONFIG_T::n_filt; j++) { - res[i * CONFIG_T::n_filt + j] = res_col[j]; - } - } -} - } #endif diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_conv2d.h b/hls4ml/templates/vivado/nnet_utils/nnet_conv2d.h index e27d45545..babf4a940 100644 --- a/hls4ml/templates/vivado/nnet_utils/nnet_conv2d.h +++ b/hls4ml/templates/vivado/nnet_utils/nnet_conv2d.h @@ -85,10 +85,11 @@ void pointwise_conv_2d_cl( #pragma HLS INLINE region + // Nothing special to be done for io_parallel implementation if (CONFIG_T::strategy == nnet::latency) { - pointwise_conv_2d_latency_cl(data, res, weights, biases); + conv_2d_latency_cl(data, res, weights, biases); } else { - pointwise_conv_2d_resource_cl(data, res, weights, biases); + conv_2d_resource_cl(data, res, weights, biases); } } diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_conv2d_latency.h b/hls4ml/templates/vivado/nnet_utils/nnet_conv2d_latency.h index bf177928a..ff2fb181c 100644 --- a/hls4ml/templates/vivado/nnet_utils/nnet_conv2d_latency.h +++ b/hls4ml/templates/vivado/nnet_utils/nnet_conv2d_latency.h @@ -234,105 +234,5 @@ void conv_2d_latency_cl( } -template -void pointwise_conv_2d_latency_cl( - data_T data[CONFIG_T::in_height*CONFIG_T::in_width*CONFIG_T::n_chan], - res_T res[CONFIG_T::out_height*CONFIG_T::out_width*CONFIG_T::n_filt], - typename CONFIG_T::weight_t weights[CONFIG_T::n_chan * CONFIG_T::n_filt], - typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) -{ - - typename CONFIG_T::accum_t mult[CONFIG_T::out_height * CONFIG_T::out_width * CONFIG_T::n_filt * CONFIG_T::n_chan]; - typename CONFIG_T::accum_t acc[CONFIG_T::out_height * CONFIG_T::out_width * CONFIG_T::n_filt]; - - #pragma HLS ARRAY_PARTITION variable=mult complete dim=0 - #pragma HLS ARRAY_PARTITION variable=acc complete dim=0 - - // Use a function_instantiate in case it helps to explicitly optimize unchanging weights/biases - #pragma HLS function_instantiate variable=weights,biases - - // Parallel mode - #pragma HLS PIPELINE - #pragma HLS ARRAY_PARTITION variable=biases complete dim=0 - - // Limit multipliers to control parallelization - const int multiplier_limit = compute_multiplier_limit_conv2d(weights); - #pragma HLS ALLOCATION instances=mul limit=multiplier_limit operation - - // Convolve, saving all multiplication results to accumulate later - ConvOutHeight: for(int oh = 0; oh < CONFIG_T::out_height; oh++) { - ConvOutWidth: for(int ow = 0; ow < CONFIG_T::out_width; ow++) { - ConvFilt: for(int ff = 0; ff < CONFIG_T::n_filt; ff++) { - ConvChan: for(int cc = 0; cc < CONFIG_T::n_chan; cc++) { - - int index_mult = oh*CONFIG_T::out_width*CONFIG_T::n_filt*CONFIG_T::n_chan - + ow*CONFIG_T::n_filt*CONFIG_T::n_chan - + ff*CONFIG_T::n_chan - + cc; - - int index_weight = cc*CONFIG_T::n_filt + ff; - - if ((oh*CONFIG_T::stride_height) < CONFIG_T::pad_top - || (oh*CONFIG_T::stride_height) >= (CONFIG_T::pad_top+CONFIG_T::in_height) - || (ow*CONFIG_T::stride_width) < CONFIG_T::pad_left - || (ow*CONFIG_T::stride_width) >= (CONFIG_T::pad_left+CONFIG_T::in_width)) { - mult[index_mult] = 0; - } else { - int index_data = (oh*CONFIG_T::stride_height-CONFIG_T::pad_top)*CONFIG_T::in_width*CONFIG_T::n_chan - + (ow*CONFIG_T::stride_width-CONFIG_T::pad_left)*CONFIG_T::n_chan - + cc; - mult[index_mult] = data[index_data] * weights[index_weight]; - } - - } - } - } - } - - - // Initialize accumulator with input biases - for(int oh = 0; oh < CONFIG_T::out_height; oh++) { - for(int ow = 0; ow < CONFIG_T::out_width; ow++) { - for(int ff = 0; ff < CONFIG_T::n_filt; ff++) { - acc[oh*CONFIG_T::out_width*CONFIG_T::n_filt + ow*CONFIG_T::n_filt + ff]=biases[ff]; - } - } - } - - - // Accumulate multiplication result - AccumOutHeight: for(int oh = 0; oh < CONFIG_T::out_height; oh++) { - AccumOutWidth: for(int ow = 0; ow < CONFIG_T::out_width; ow++) { - AccumFilt: for(int ff = 0; ff < CONFIG_T::n_filt; ff++) { - //Do "dot product" sum within filter and sum over channels - AccumChan: for(int cc = 0; cc < CONFIG_T::n_chan; cc++) { - - int index_mult = oh*CONFIG_T::out_width*CONFIG_T::n_filt*CONFIG_T::n_chan - + ow*CONFIG_T::n_filt*CONFIG_T::n_chan - + ff*CONFIG_T::n_chan - + cc; - int index_acc = oh*CONFIG_T::out_width*CONFIG_T::n_filt - + ow*CONFIG_T::n_filt - + ff; - - acc[index_acc] += mult[index_mult]; - - } - } - } - } - - // Cast to "res_t" type - for(int oh = 0; oh < CONFIG_T::out_height; oh++) { - for(int ow = 0; ow < CONFIG_T::out_width; ow++) { - for(int ff = 0; ff < CONFIG_T::n_filt; ff++) { - int index = oh*CONFIG_T::out_width*CONFIG_T::n_filt + ow*CONFIG_T::n_filt + ff; - res[index] = (res_T)(acc[index]); - } - } - } - -}//end conv2d - } #endif \ No newline at end of file diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_conv2d_resource.h b/hls4ml/templates/vivado/nnet_utils/nnet_conv2d_resource.h index fa7a30b6b..1db618c61 100644 --- a/hls4ml/templates/vivado/nnet_utils/nnet_conv2d_resource.h +++ b/hls4ml/templates/vivado/nnet_utils/nnet_conv2d_resource.h @@ -100,72 +100,5 @@ void conv_2d_resource_cl( } } -template -void im2col_2d_pointwise_cl( - data_T data[CONFIG_T::in_height * CONFIG_T::in_width * CONFIG_T::n_chan], - data_T data_col[CONFIG_T::n_chan], - const int row, - const int col) -{ - int index = 0; - int input_row = -CONFIG_T::pad_top + row * CONFIG_T::stride_height; - - ChannelLoop: - for (int channel = 0; channel < CONFIG_T::n_chan; channel++) { - #pragma HLS UNROLL - if (input_row < 0 || input_row >= CONFIG_T::in_height) { - data_col[index++] = 0; - } else { - int input_col = -CONFIG_T::pad_left + col * CONFIG_T::stride_width; - if (input_col >= 0 && input_col < CONFIG_T::in_width) { - data_col[index++] = data[input_row * CONFIG_T::in_width * CONFIG_T::n_chan + input_col * CONFIG_T::n_chan + channel]; - } else { - data_col[index++] = 0; - } - } - } -} - -template -void pointwise_conv_2d_resource_cl( - data_T data[CONFIG_T::in_height * CONFIG_T::in_width * CONFIG_T::n_chan], - res_T res[CONFIG_T::out_height * CONFIG_T::out_width * CONFIG_T::n_filt], - typename CONFIG_T::weight_t weights[CONFIG_T::n_chan * CONFIG_T::n_filt], - typename CONFIG_T::bias_t biases[CONFIG_T::n_filt] -) -{ - assert(CONFIG_T::filt_height == 1 && CONFIG_T::filt_width == 1); - - const int nin = CONFIG_T::n_chan; - const int nout = CONFIG_T::n_filt; - const int rufactor = CONFIG_T::reuse_factor; - const int block_factor = DIV_ROUNDUP(nin*nout, rufactor); - - //#pragma HLS function_instantiate variable=weights,biases - //#pragma HLS RESOURCE variable=weights core=RAM_2P_BRAM Commenting out the deisgnation HLS seems to choose correctly - //#pragma HLS ARRAY_RESHAPE variable=weights block factor=block_factor - //#pragma HLS ARRAY_PARTITION variable=biases complete - - data_T data_col[CONFIG_T::n_chan]; - res_T res_col[CONFIG_T::n_filt]; - - #pragma HLS ARRAY_PARTITION variable=data_col complete - #pragma HLS ARRAY_PARTITION variable=res_col complete - - HeightLoop: - for (int i = 0; i < CONFIG_T::out_height; i++) { - WidthLoop: - for (int j = 0; j < CONFIG_T::out_width; j++) { - #pragma HLS PIPELINE - im2col_2d_pointwise_cl(data, data_col, i, j); - dense(data_col, res_col, weights, biases); - FiltLoop: - for (int k = 0; k < CONFIG_T::n_filt; k++) { - res[i * CONFIG_T::out_width * CONFIG_T::n_filt + j * CONFIG_T::n_filt + k] = res_col[k]; - } - } - } -} - } #endif From c7f87f01ce4f90bc6a2d6868ef5b35320f92e48e Mon Sep 17 00:00:00 2001 From: Vladimir Loncar Date: Tue, 4 Oct 2022 03:34:55 +0200 Subject: [PATCH 08/10] Fix separable conv failing due to missing info about partitions --- hls4ml/backends/vivado/passes/convolution_templates.py | 4 ++++ hls4ml/backends/vivado/vivado_backend.py | 3 +++ 2 files changed, 7 insertions(+) diff --git a/hls4ml/backends/vivado/passes/convolution_templates.py b/hls4ml/backends/vivado/passes/convolution_templates.py index 46477c528..45a7c46bf 100644 --- a/hls4ml/backends/vivado/passes/convolution_templates.py +++ b/hls4ml/backends/vivado/passes/convolution_templates.py @@ -214,6 +214,7 @@ def format(self, node): params['nzeros'] = node.get_weights('depthwise').nzeros params['index'] = str(node.index) + '_depthwise' params['weight_t'] = node.get_weights('depthwise').type + params['fill_fn'] = 'FillConv1DBuffer' params['config_t'] = 'config{}_depthwise_mult'.format(node.index) depthwise_config = self.depthwise_template.format(**params) @@ -245,6 +246,7 @@ def format(self, node): params['weight_t'] = node.get_weights('pointwise').type params['min_width'] = params['in_width'] params['instructions'] = '0' + params['fill_fn'] = 'FillConv1DBuffer' params['config_t'] = 'config{}_pointwise_mult'.format(node.index) pointwise_config = self.pointwise_template.format(**params) @@ -299,6 +301,7 @@ def format(self, node): params['nzeros'] = node.get_weights('depthwise').nzeros params['index'] = str(node.index) + '_depthwise' params['weight_t'] = node.get_weights('depthwise').type + params['fill_fn'] = 'FillConv2DBuffer' params['config_t'] = 'config{}_depthwise_mult'.format(node.index) depthwise_config = self.depthwise_template.format(**params) @@ -330,6 +333,7 @@ def format(self, node): params['min_height'] = params['in_height'] params['min_width'] = params['in_width'] params['instructions'] = '0' + params['fill_fn'] = 'FillConv2DBuffer' params['config_t'] = 'config{}_pointwise_mult'.format(node.index) pointwise_config = self.pointwise_template.format(**params) diff --git a/hls4ml/backends/vivado/vivado_backend.py b/hls4ml/backends/vivado/vivado_backend.py index 4ebf3dc9b..f3b05519d 100644 --- a/hls4ml/backends/vivado/vivado_backend.py +++ b/hls4ml/backends/vivado/vivado_backend.py @@ -194,6 +194,7 @@ def init_sepconv1d(self, layer): else: layer.set_attr('strategy', 'latency') + layer.set_attr('n_partitions', 1) #TODO Once we have SeparableConv implementation for io_parallel this should be set properly layer.set_attr('implementation', layer.model.config.get_conv_implementation(layer).lower()) @layer_optimizer(Conv2D) @@ -234,6 +235,7 @@ def init_sepconv2d(self, layer): else: layer.set_attr('strategy', 'latency') + layer.set_attr('n_partitions', 1) #TODO Once we have SeparableConv implementation for io_parallel this should be set properly layer.set_attr('implementation', layer.model.config.get_conv_implementation(layer).lower()) @layer_optimizer(DepthwiseConv2D) @@ -245,6 +247,7 @@ def init_depconv2d(self, layer): else: layer.set_attr('strategy', 'latency') + layer.set_attr('n_partitions', 1) #TODO Once we have SeparableConv implementation for io_parallel this should be set properly layer.set_attr('implementation', layer.model.config.get_conv_implementation(layer).lower()) @layer_optimizer(Activation) From e9bb7ffc56aca537afab6ca2a37fed5963fc5604 Mon Sep 17 00:00:00 2001 From: Vladimir Loncar Date: Tue, 4 Oct 2022 03:35:13 +0200 Subject: [PATCH 09/10] Docstrings for codegeneration functions --- hls4ml/backends/fpga/fpga_backend.py | 42 ++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/hls4ml/backends/fpga/fpga_backend.py b/hls4ml/backends/fpga/fpga_backend.py index cf7e3c6a0..59ba22d5a 100644 --- a/hls4ml/backends/fpga/fpga_backend.py +++ b/hls4ml/backends/fpga/fpga_backend.py @@ -430,6 +430,26 @@ def _compute_conv1d_im2col(self, input_shape, kernel=3, stride=1, pad=(0,0), dil def generate_conv1d_line_buffer_fn(self, layer_idx, n_partitions, in_W, in_C, kernel=3, stride=1, pad=0, dilation=1): + """Generate a C++ function that mimics the im2col algorithm. This function works for 1D convolution. + + The HLS compiler produces suboptimal designs for a im2col algorithm implementation, so a trick we use is + to generate a resulting a result of im2col transformation explicitly, instead of relying on loops. Since + the result depends on the paraleters of the convolution layer (the input size, the kernel size, stride etc), + we need to do this for every convolution layer. + + Args: + layer_idx (int): Index of layer ('index' attribute). + n_partitions (int): Number of partitions to divide the input into. The pixels in each partition will be processed in parallel. + in_W (int): Width of input. + in_C (int): Number of channels. + kernel (int, optional): Size of the kernel. Defaults to 3. + stride (int, optional): Stride length. Defaults to 1. + pad (int or Iterable, optional): Padding to apply. Specified as either a number or a list [left_pad, right_pad]. Defaults to 0. + dilation (int, optional): Dilation rate. Defaults to 1. + + Returns: + str: Generated C++ function + """ if isinstance(pad, Iterable): pad_left = pad[0] pad_right = pad[1] @@ -510,6 +530,28 @@ def _compute_conv2d_im2col(self, input_shape, kernel=(3, 3), stride=(1, 1), pad= def generate_conv2d_line_buffer_fn(self, layer_idx, n_partitions, in_H, in_W, in_C, kernel=(3, 3), stride=(1, 1), pad=(0, 0, 0, 0), dilation=(1, 1)): + """Generate a C++ function that mimics the im2col algorithm. This function works for 2D convolution. + + The HLS compiler produces suboptimal designs for a im2col algorithm implementation, so a trick we use is + to generate a resulting a result of im2col transformation explicitly, instead of relying on loops. Since + the result depends on the paraleters of the convolution layer (the input size, the kernel size, stride etc), + we need to do this for every convolution layer. + + Args: + layer_idx (int): Index of layer ('index' attribute). + n_partitions (int): Number of partitions to divide the input into. The pixels in each partition will be processed in parallel. + in_H (int): Height of input. + in_W (int): Width of input. + in_C (int): Number of channels. + kernel (int or Iterable, optional): Size of the kernel. Defaults to (3,3). + stride (int or Iterable, optional): Stride length. Defaults to (1,1). + pad (int or Iterable, optional): Padding to apply. Specified as either a number or a list [top_pad, bottom_pad, left_pad, right_pad]. Defaults to 0. + dilation (int or Iterable, optional): Dilation rate. Defaults to (1,1). + + Returns: + str: Generated C++ function + """ + if isinstance(kernel, Iterable): kernel_height = kernel[0] kernel_width = kernel[1] From cd915eb021d73d11336c26d869dc161808a4b8b5 Mon Sep 17 00:00:00 2001 From: Vladimir Loncar Date: Tue, 4 Oct 2022 16:58:08 +0200 Subject: [PATCH 10/10] Use smaller model in Conv1D test --- hls4ml/backends/quartus/quartus_backend.py | 4 ++++ test/pytest/test_conv1d.py | 6 +++--- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/hls4ml/backends/quartus/quartus_backend.py b/hls4ml/backends/quartus/quartus_backend.py index a73c95606..5ac415272 100644 --- a/hls4ml/backends/quartus/quartus_backend.py +++ b/hls4ml/backends/quartus/quartus_backend.py @@ -254,6 +254,8 @@ def init_conv1d(self, layer): # - Winograd - use Winograd, if possible layer.set_attr('implementation', layer.model.config.get_layer_config_value(layer, 'Implementation', 'combination')) + layer.set_attr('n_partitions', 1) #TODO Not used yet as there is no codegen implementation of CNNs for Quartus backend + @layer_optimizer(Conv2D) def init_conv2d(self, layer): # This can happen if we assign weights of Dense layer to 1x1 Conv2D @@ -280,6 +282,8 @@ def init_conv2d(self, layer): # - im2col - specifically use im2col # - Winograd - use Winograd, if possible layer.set_attr('implementation', layer.model.config.get_layer_config_value(layer, 'Implementation', 'combination')) + + layer.set_attr('n_partitions', 1) #TODO Not used yet as there is no codegen implementation of CNNs for Quartus backend @layer_optimizer(LSTM) def init_lstm(self, layer): diff --git a/test/pytest/test_conv1d.py b/test/pytest/test_conv1d.py index 291d1b447..bef486cda 100644 --- a/test/pytest/test_conv1d.py +++ b/test/pytest/test_conv1d.py @@ -10,16 +10,16 @@ @pytest.fixture(scope='module') def data(): - X = np.random.rand(100,100,7) + X = np.random.rand(100,10,4) return X @pytest.fixture(scope='module') def keras_model(): - model_path = example_model_path / 'keras/KERAS_conv1d.json' + model_path = example_model_path / 'keras/KERAS_conv1d_small.json' with model_path.open('r') as f: jsons = f.read() model = model_from_json(jsons) - model.load_weights(example_model_path / 'keras/KERAS_conv1d_weights.h5') + model.load_weights(example_model_path / 'keras/KERAS_conv1d_small_weights.h5') return model @pytest.fixture