diff --git a/hls4ml/converters/pytorch/reshape.py b/hls4ml/converters/pytorch/reshape.py index 50b15dc47..15efd7a4f 100644 --- a/hls4ml/converters/pytorch/reshape.py +++ b/hls4ml/converters/pytorch/reshape.py @@ -27,3 +27,27 @@ def parse_reshape_layer(operation, layer_name, input_names, input_shapes, node, output_shape = input_shapes[0][:1] + layer['target_shape'] return layer, output_shape + + +@pytorch_handler('Flatten') +def parse_flatten_layer(operation, layer_name, input_names, input_shapes, node, class_object, data_reader, config): + assert operation == 'Flatten' + + layer = {} + layer['class_name'] = 'Reshape' + layer['name'] = layer_name + layer['inputs'] = input_names + + start_dim = class_object.start_dim + end_dim = class_object.end_dim + if end_dim + 1 == 0 or end_dim + 1 > len(input_shapes[0]): + end_dim = len(input_shapes[0]) + else: + end_dim = end_dim + 1 + + layer['target_shape'] = ( + input_shapes[0][0:start_dim] + [np.prod(input_shapes[0][start_dim:end_dim])] + input_shapes[0][end_dim:] + ) + output_shape = layer['target_shape'] + + return layer, output_shape diff --git a/hls4ml/converters/pytorch_to_hls.py b/hls4ml/converters/pytorch_to_hls.py index 961fb735a..e51e9ffd8 100644 --- a/hls4ml/converters/pytorch_to_hls.py +++ b/hls4ml/converters/pytorch_to_hls.py @@ -1,4 +1,3 @@ -import numpy as np import torch from hls4ml.model import ModelGraph @@ -105,6 +104,7 @@ def decorator(function): 'max_pool2d': 'MaxPool2d', 'avg_pool1d': 'AvgPool1d', 'avg_pool2d': 'AvgPool2d', + 'flatten': 'Flatten', } @@ -144,7 +144,7 @@ def pytorch_to_hls(config): traced_model = symbolic_trace(model) # Define layers to skip for conversion to HLS - skip_layers = ['Dropout', 'Flatten', 'Sequential'] + skip_layers = ['Dropout', 'Sequential'] # All supported layers supported_layers = get_supported_pytorch_layers() + skip_layers @@ -189,10 +189,8 @@ def pytorch_to_hls(config): if pytorch_class == 'Sequential': # Ignore the mother module's class name continue - if pytorch_class == 'Flatten': - output_shapes[layer_name] = [input_shapes[0][0], np.prod(input_shapes[0][1:])] - else: - output_shapes[layer_name] = input_shapes[0] + output_shapes[layer_name] = input_shapes[0] + continue # Increment the layer counter after initial screenings diff --git a/hls4ml/model/optimizer/passes/convert_to_channels_last.py b/hls4ml/model/optimizer/passes/convert_to_channels_last.py index cef4d947d..9c1971156 100644 --- a/hls4ml/model/optimizer/passes/convert_to_channels_last.py +++ b/hls4ml/model/optimizer/passes/convert_to_channels_last.py @@ -2,7 +2,7 @@ # Based on https://github.com/fastmachinelearning/qonnx/blob/ # 12c96a3ded06beacab08e0f554e4ed014476c0aa/src/qonnx/transformation/channels_last.py -from hls4ml.model.layers import Concatenate, Input +from hls4ml.model.layers import Concatenate, Input, Reshape from hls4ml.model.optimizer import OptimizerPass @@ -92,15 +92,33 @@ def transform(self, model, node): dims = [outdims[1], outdims[2], outdims[0]] node.add_output_variable(shape, dims) + # Have to transpose back before flattening to get correct order of elements in the flattened tensor + if isinstance(node, Reshape) and len(node.attributes['target_shape']) == 1: + previous_node = node.get_input_node(node.inputs[0]) + input = previous_node.name + outshape = previous_node.get_output_variable().shape + + if len(outshape) == 2: + attributes = {'perm': [1, 0]} + else: + attributes = {'perm': [2, 0, 1]} + + transpose_node = model.make_node( + 'Transpose', f'transpose_input_for_{node.get_attr("name")}', attributes, [input] + ) + transpose_node.channels_last_converted = True + + model.insert_node(transpose_node) + # Add transpose for output layer - if ( - node.get_attr("name") in model.outputs + elif ( + node.get_attr('name') in model.outputs and len(outshape) > 1 and model.config.config['HLSConfig']['Model']['TransposeOutputs'] ): input = node.name outshape = node.get_output_variable().shape - print(outshape) + if len(outshape) == 2: attributes = {'perm': [1, 0]} else: diff --git a/test/pytest/test_pytorch_api.py b/test/pytest/test_pytorch_api.py index ff2bae2a4..485f40eb3 100644 --- a/test/pytest/test_pytorch_api.py +++ b/test/pytest/test_pytorch_api.py @@ -183,7 +183,7 @@ def test_activation_functionals(activation_function, backend, io_type): hls_prediction = hls_model.predict(X_input) - np.testing.assert_allclose(hls_prediction, pytorch_prediction, rtol=1e-2, atol=0.01) + np.testing.assert_allclose(hls_prediction, pytorch_prediction, rtol=0, atol=0.05) from torch.fx import symbolic_trace @@ -572,3 +572,21 @@ def test_pooling(pooling, padds, backend): assert hls_pool.attributes['pool_width'] == class_object_pool.kernel_size[0] assert hls_pool.attributes['stride_width'] == class_object_pool.stride[0] assert hls_pool.attributes['padding'] == 'same' if class_object_pool.padding == 0 else 'valid' + + +@pytest.mark.parametrize('backend', ['Vivado', 'Quartus']) +def test_flatten(backend): + input = torch.randn(1, 1, 5, 5) + model = nn.Sequential(nn.Conv2d(1, 32, 5, 1, 1), nn.Flatten(), nn.ReLU()) + pytorch_prediction = model(input).detach().numpy() + input_shape = (None, 1, 5, 5) + + config = config_from_pytorch_model(model) + output_dir = str(test_root_path / f'hls4mlprj_pytorch_api_flatten_backend_{backend}') + hls_model = convert_from_pytorch_model(model, input_shape, hls_config=config, output_dir=output_dir, backend=backend) + hls_model.compile() + + pred = hls_model.predict(input.detach().numpy()) + hls_prediction = np.reshape(pred, (1, 288)) + + np.testing.assert_allclose(hls_prediction, pytorch_prediction, rtol=0, atol=5e-2)