Add support for flattening to the pytorch parser #852

Merged
merged 5 commits on Aug 22, 2023

24 changes: 24 additions & 0 deletions hls4ml/converters/pytorch/reshape.py
@@ -27,3 +27,27 @@ def parse_reshape_layer(operation, layer_name, input_names, input_shapes, node,
    output_shape = input_shapes[0][:1] + layer['target_shape']

    return layer, output_shape


@pytorch_handler('Flatten')
def parse_flatten_layer(operation, layer_name, input_names, input_shapes, node, class_object, data_reader, config):
    assert operation == 'Flatten'

    layer = {}
    layer['class_name'] = 'Reshape'
    layer['name'] = layer_name
    layer['inputs'] = input_names

    start_dim = class_object.start_dim
    end_dim = class_object.end_dim
    if end_dim + 1 == 0 or end_dim + 1 > len(input_shapes[0]):
        end_dim = len(input_shapes[0])
    else:
        end_dim = end_dim + 1

    layer['target_shape'] = (
        input_shapes[0][0:start_dim] + [np.prod(input_shapes[0][start_dim:end_dim])] + input_shapes[0][end_dim:]
    )
    output_shape = layer['target_shape']

    return layer, output_shape
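
A minimal worked example, separate from the diff above, of how the target_shape arithmetic in parse_flatten_layer plays out for the default torch.nn.Flatten(start_dim=1, end_dim=-1) applied to an input of shape [None, 16, 3, 3]:

import numpy as np

input_shapes = [[None, 16, 3, 3]]
start_dim, end_dim = 1, -1  # defaults of torch.nn.Flatten

# end_dim = -1 (or any out-of-range value) means "flatten through the last dimension",
# so it is converted to an exclusive upper bound for slicing
if end_dim + 1 == 0 or end_dim + 1 > len(input_shapes[0]):
    end_dim = len(input_shapes[0])
else:
    end_dim = end_dim + 1

target_shape = (
    input_shapes[0][0:start_dim] + [np.prod(input_shapes[0][start_dim:end_dim])] + input_shapes[0][end_dim:]
)
# target_shape is [None, 144]: dimensions 1..3 collapse to 16 * 3 * 3 = 144,
# while everything in front of start_dim (here the batch dimension) is kept.
print(target_shape)

Because the handler emits a plain Reshape layer with this target_shape, no dedicated Flatten layer type is needed in the hls4ml model graph.
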
10 changes: 4 additions & 6 deletions hls4ml/converters/pytorch_to_hls.py
@@ -1,4 +1,3 @@
import numpy as np
import torch

from hls4ml.model import ModelGraph
@@ -105,6 +104,7 @@ def decorator(function):
'max_pool2d': 'MaxPool2d',
'avg_pool1d': 'AvgPool1d',
'avg_pool2d': 'AvgPool2d',
'flatten': 'Flatten',
}


@@ -144,7 +144,7 @@ def pytorch_to_hls(config):

traced_model = symbolic_trace(model)
# Define layers to skip for conversion to HLS
skip_layers = ['Dropout', 'Flatten', 'Sequential']
skip_layers = ['Dropout', 'Sequential']

# All supported layers
supported_layers = get_supported_pytorch_layers() + skip_layers
@@ -189,10 +189,8 @@ def pytorch_to_hls(config):
if pytorch_class == 'Sequential': # Ignore the mother module's class name
continue

if pytorch_class == 'Flatten':
output_shapes[layer_name] = [input_shapes[0][0], np.prod(input_shapes[0][1:])]
else:
output_shapes[layer_name] = input_shapes[0]
output_shapes[layer_name] = input_shapes[0]

continue

# Increment the layer counter after initial screenings
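
A small usage sketch, separate from the diff and assuming the standard hls4ml import paths, of what the skip_layers change means in practice: a torch.nn.Flatten module is now parsed into the converted graph instead of being silently dropped with only its output shape tracked.

import torch.nn as nn

from hls4ml.converters import convert_from_pytorch_model
from hls4ml.utils.config import config_from_pytorch_model

model = nn.Sequential(nn.Conv2d(1, 4, kernel_size=3), nn.Flatten(), nn.ReLU())
config = config_from_pytorch_model(model)
hls_model = convert_from_pytorch_model(model, (None, 1, 5, 5), hls_config=config)

# The graph now contains a node for the Flatten (parsed as a Reshape),
# whereas before this PR the layer was skipped entirely.
print([type(layer).__name__ for layer in hls_model.get_layers()])
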
26 changes: 22 additions & 4 deletions hls4ml/model/optimizer/passes/convert_to_channels_last.py
@@ -2,7 +2,7 @@
# Based on https://github.com/fastmachinelearning/qonnx/blob/
# 12c96a3ded06beacab08e0f554e4ed014476c0aa/src/qonnx/transformation/channels_last.py

from hls4ml.model.layers import Concatenate, Input
from hls4ml.model.layers import Concatenate, Input, Reshape
from hls4ml.model.optimizer import OptimizerPass


@@ -92,15 +92,33 @@ def transform(self, model, node):
dims = [outdims[1], outdims[2], outdims[0]]
node.add_output_variable(shape, dims)

# Have to transpose back before flattening to get correct order of elements in the flattened tensor
if isinstance(node, Reshape) and len(node.attributes['target_shape']) == 1:
    previous_node = node.get_input_node(node.inputs[0])
    input = previous_node.name
    outshape = previous_node.get_output_variable().shape

    if len(outshape) == 2:
        attributes = {'perm': [1, 0]}
    else:
        attributes = {'perm': [2, 0, 1]}

    transpose_node = model.make_node(
        'Transpose', f'transpose_input_for_{node.get_attr("name")}', attributes, [input]
    )
    transpose_node.channels_last_converted = True

    model.insert_node(transpose_node)

# Add transpose for output layer
if (
    node.get_attr("name") in model.outputs
elif (
    node.get_attr('name') in model.outputs
    and len(outshape) > 1
    and model.config.config['HLSConfig']['Model']['TransposeOutputs']
):
    input = node.name
    outshape = node.get_output_variable().shape
    print(outshape)

    if len(outshape) == 2:
        attributes = {'perm': [1, 0]}
    else:
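
A short numpy illustration, separate from the PR, of why the pass above inserts a Transpose (perm [2, 0, 1] for 3-D data, [1, 0] for 2-D) in front of a flattening Reshape: flattening a channels-last tensor directly yields a different element order than flattening the original channels-first tensor, so the data has to be transposed back first.

import numpy as np

x = np.arange(2 * 3 * 3).reshape(2, 3, 3)  # channels-first (C, H, W), as PyTorch produces it
x_last = np.transpose(x, (1, 2, 0))        # channels-last (H, W, C), as used inside hls4ml

print(np.array_equal(x.flatten(), x_last.flatten()))                           # False: wrong element order
print(np.array_equal(x.flatten(), np.transpose(x_last, (2, 0, 1)).flatten()))  # True: order restored
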
20 changes: 19 additions & 1 deletion test/pytest/test_pytorch_api.py
@@ -183,7 +183,7 @@ def test_activation_functionals(activation_function, backend, io_type):

hls_prediction = hls_model.predict(X_input)

np.testing.assert_allclose(hls_prediction, pytorch_prediction, rtol=1e-2, atol=0.01)
np.testing.assert_allclose(hls_prediction, pytorch_prediction, rtol=0, atol=0.05)

from torch.fx import symbolic_trace

@@ -572,3 +572,21 @@ def test_pooling(pooling, padds, backend):
assert hls_pool.attributes['pool_width'] == class_object_pool.kernel_size[0]
assert hls_pool.attributes['stride_width'] == class_object_pool.stride[0]
assert hls_pool.attributes['padding'] == 'same' if class_object_pool.padding == 0 else 'valid'


@pytest.mark.parametrize('backend', ['Vivado', 'Quartus'])
def test_flatten(backend):
    input = torch.randn(1, 1, 5, 5)
    model = nn.Sequential(nn.Conv2d(1, 32, 5, 1, 1), nn.Flatten(), nn.ReLU())
    pytorch_prediction = model(input).detach().numpy()
    input_shape = (None, 1, 5, 5)

    config = config_from_pytorch_model(model)
    output_dir = str(test_root_path / f'hls4mlprj_pytorch_api_flatten_backend_{backend}')
    hls_model = convert_from_pytorch_model(model, input_shape, hls_config=config, output_dir=output_dir, backend=backend)
    hls_model.compile()

    pred = hls_model.predict(input.detach().numpy())
    hls_prediction = np.reshape(pred, (1, 288))

    np.testing.assert_allclose(hls_prediction, pytorch_prediction, rtol=0, atol=5e-2)
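
The new test is parametrized over the Vivado and Quartus backends; it can be run on its own with pytest test/pytest/test_pytorch_api.py::test_flatten once the usual hls4ml pytest requirements are installed.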