Skip to content

Commit

Permalink
[Relay][Frontend][Onnx] SequenceAt and SplitToSequence Operators (apa…
Browse files Browse the repository at this point in the history
…che#13602)

* Add support for SequenceAt and SplitToSequence to onnx importer

* Formatting

* Change keepdims comparison

* Only unify non-tuples in If
  • Loading branch information
Josh Fromm authored and fzi-peccia committed Mar 27, 2023
1 parent 0d4a2cd commit c48c063
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 14 deletions.
79 changes: 79 additions & 0 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -4008,6 +4008,23 @@ def _impl_v1(cls, inputs, attr, params):
for var in else_free_vars:
graph_scope._nodes.update({var.name_hint: var})

# Sometimes pytorch to onnx will insert silly if statements that produce dynamic ranks.
# Often these dont contribute anything. If we see a dynamic rank output, try to unify
# them so we can continue without breaking.
if not isinstance(then_expr, _expr.Tuple) and not isinstance(else_expr, _expr.Tuple):
then_shape = infer_shape(then_expr)
else_shape = infer_shape(else_expr)
if len(then_shape) != len(else_shape):
warning_msg = (
"If statement produced outputs with different rank. "
"Attempting to unify ranks but this may produce incorrect results."
)
warnings.warn(warning_msg)
if len(then_shape) < len(else_shape):
then_expr = _op.broadcast_to_like(then_expr, else_expr)
else:
else_expr = _op.broadcast_to_like(else_expr, then_expr)

# Now we can construct the relay if statement and return.
ret = _expr.If(cond, then_expr, else_expr)
if len(then_branch.output) > 1:
Expand Down Expand Up @@ -5565,6 +5582,66 @@ def _impl_v11(cls, inputs, attr, params):
return _op.concatenate(inputs[0], axis=axis)


class SplitToSequence(OnnxOpConverter):
    """Operator converter for split to sequence op."""

    @classmethod
    def _impl_v11(cls, inputs, attr, params):
        # Axis to split along and whether the split axis is retained in
        # each output slice (see keepdims note below).
        axis = attr.get("axis", 0)
        keepdims = attr.get("keepdims", 1)

        data = inputs[0]
        data_shape = infer_shape(data)
        split = inputs[1]

        if split is None:
            # No split input given: produce one slice per element along axis.
            chunks = _op.split(data, data_shape[axis], axis=axis)
        else:
            # For now we only support constant valued split.
            assert isinstance(
                split, _expr.Constant
            ), "Only constant split supported for SplitToSequence"
            split_np = split.data.numpy()
            if len(split_np.shape) == 1 and split_np.shape[0] > 1:
                # A 1D split tensor holds section sizes; relay wants cut
                # indices, so take the cumulative sum and drop the final
                # (out-of-range) index.
                sections = np.cumsum(split_np)[:-1]
            else:
                # Scalar (or single-element) split: use it as an integer
                # number of equal sections.
                sections = int(split_np)
            chunks = _op.split(data, sections, axis=axis)

        # If keepdims is set to 0 remove split axis. Note that this is
        # an inconsistency with the onnx spec but is needed for pytorch
        # compatibility.
        if not keepdims:
            chunks = [_op.squeeze(piece, axis=[axis]) for piece in chunks]
        return _expr.Tuple(list(chunks))


class SequenceAt(OnnxOpConverter):
    """Operator converter for sequence at op."""

    @classmethod
    def _impl_v11(cls, inputs, attr, params):
        # inputs[0] is the relay tuple representing the sequence,
        # inputs[1] the index to extract.
        sequence = inputs[0]
        position = inputs[1]
        assert isinstance(
            position, _expr.Constant
        ), "Only constant position supported for SequenceAt"
        # Extract the constant index and select the matching tuple field.
        index = int(position.data.numpy())
        return sequence[index]


# compatible operators that do NOT require any conversion.
_identity_list = []

Expand Down Expand Up @@ -5793,6 +5870,8 @@ def _get_convert_map(opset):
"SequenceConstruct": SequenceConstruct.get_converter(opset),
"SequenceInsert": SequenceInsert.get_converter(opset),
"ConcatFromSequence": ConcatFromSequence.get_converter(opset),
"SplitToSequence": SplitToSequence.get_converter(opset),
"SequenceAt": SequenceAt.get_converter(opset),
}


Expand Down
2 changes: 0 additions & 2 deletions python/tvm/relay/op/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,8 +432,6 @@ def _concatenate_shape_func(inputs, axis):
for i in const_range(ndim):
if i != axis:
out[i] = inputs[0][i]
for j in const_range(1, len(inputs)):
assert out[i] == inputs[j][i], "Dims mismatch in the inputs of concatenate."
else:
out[i] = int64(0)
for j in const_range(len(inputs)):
Expand Down
31 changes: 19 additions & 12 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -7043,7 +7043,7 @@ def verify_linear_regressor(a_shape, c_shape, i_shape, targets=1, batch=1):
def test_sequence(target, dev):
"""test_sequence"""

def verify_sequence_ops(tensor_shape, num_tensors, axis=0, position=None, new_axis=None):
def verify_sequence_ops(tensor_shape, num_tensors, axis=0, position=0, new_axis=None):
tensor_shape = list(tensor_shape)
tensor_values = []
for i in range(num_tensors):
Expand All @@ -7062,20 +7062,30 @@ def verify_sequence_ops(tensor_shape, num_tensors, axis=0, position=None, new_ax
outputs=["sequence"],
)

insert_inputs = ["sequence", input_tensor_names[0]]
position_node = None
if position is not None:
insert_inputs.append("position")
position_node = make_constant_node("position", TensorProto.INT32, (), [position])
position_node = make_constant_node("position", TensorProto.INT32, (), [position])

# Test sequence insertion.
insert_node = helper.make_node(
"SequenceInsert", inputs=insert_inputs, outputs=["inserted_sequence"]
"SequenceInsert",
inputs=["sequence", input_tensor_names[0], "position"],
outputs=["inserted_sequence"],
)

# Test sequence concatenation.
concat_node = helper.make_node(
"ConcatFromSequence", inputs=["inserted_sequence"], outputs=["output"], axis=axis
"ConcatFromSequence",
inputs=["inserted_sequence"],
outputs=["concat_sequence"],
axis=axis,
)

# Test splitting a tensor into a sequence.
split_node = helper.make_node(
"SplitToSequence", inputs=["concat_sequence"], outputs=["split_sequence"], axis=axis
)

at_node = helper.make_node(
"SequenceAt", inputs=["split_sequence", "position"], outputs=["output"]
)

if new_axis is not None:
Expand All @@ -7097,10 +7107,7 @@ def verify_sequence_ops(tensor_shape, num_tensors, axis=0, position=None, new_ax
output_shape[axis] = (num_tensors + 1) * output_shape[axis]
graph_outputs = [helper.make_tensor_value_info("output", TensorProto.FLOAT, output_shape)]

graph_nodes = []
if position_node is not None:
graph_nodes.append(position_node)
graph_nodes += [construct_node, insert_node, concat_node]
graph_nodes = [position_node, construct_node, insert_node, concat_node, split_node, at_node]

graph = helper.make_graph(
graph_nodes,
Expand Down

0 comments on commit c48c063

Please sign in to comment.