[Relay][Frontend][Onnx] SequenceAt and SplitToSequence Operators #13602

Merged (4 commits) on Dec 13, 2022
79 changes: 79 additions & 0 deletions python/tvm/relay/frontend/onnx.py
@@ -4008,6 +4008,23 @@ def _impl_v1(cls, inputs, attr, params):
        for var in else_free_vars:
            graph_scope._nodes.update({var.name_hint: var})

        # The PyTorch-to-ONNX exporter sometimes inserts spurious If statements
        # whose branches produce outputs of different rank. Often these don't
        # contribute anything. If we see a dynamic-rank output, try to unify the
        # ranks so conversion can continue without breaking.
        if not isinstance(then_expr, _expr.Tuple) and not isinstance(else_expr, _expr.Tuple):
            then_shape = infer_shape(then_expr)
            else_shape = infer_shape(else_expr)
            if len(then_shape) != len(else_shape):
                warning_msg = (
                    "If statement produced outputs with different rank. "
                    "Attempting to unify ranks but this may produce incorrect results."
                )
                warnings.warn(warning_msg)
                if len(then_shape) < len(else_shape):
                    then_expr = _op.broadcast_to_like(then_expr, else_expr)
                else:
                    else_expr = _op.broadcast_to_like(else_expr, then_expr)

        # Now we can construct the relay If statement and return.
        ret = _expr.If(cond, then_expr, else_expr)
        if len(then_branch.output) > 1:
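As a standalone illustration of this rank unification, here is a minimal sketch; the shapes, values, and variable names are hypothetical, not taken from the PR:

import numpy as np
from tvm import relay

# Hypothetical branch outputs with mismatched ranks: (3,) vs. (2, 3).
then_out = relay.const(np.zeros((3,), dtype="float32"))
else_out = relay.const(np.ones((2, 3), dtype="float32"))

# Broadcast the lower-rank branch to the shape of the higher-rank one,
# mirroring the converter's use of broadcast_to_like, then build the If.
then_out = relay.broadcast_to_like(then_out, else_out)
ret = relay.If(relay.const(True), then_out, else_out)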
@@ -5565,6 +5582,66 @@ def _impl_v11(cls, inputs, attr, params):
        return _op.concatenate(inputs[0], axis=axis)


class SplitToSequence(OnnxOpConverter):
    """Operator converter for split to sequence op."""

    @classmethod
    def _impl_v11(cls, inputs, attr, params):
        axis = attr.get("axis", 0)
        keepdims = attr.get("keepdims", 1)

        input_tensor = inputs[0]
        input_shape = infer_shape(input_tensor)
        split = inputs[1]

        # If split is not provided, we split all values along axis.
        if split is None:
            output = _op.split(input_tensor, input_shape[axis], axis=axis)
            # If keepdims is 0, then we need to squeeze off the axis.
            if not keepdims:
                output = [_op.squeeze(tensor_slice, axis=[axis]) for tensor_slice in output]
            return _expr.Tuple(list(output))

        # Otherwise, split based on the provided split value.
        else:
            # For now we only support constant-valued split.
            assert isinstance(
                split, _expr.Constant
            ), "Only constant split supported for SplitToSequence"
            split = split.data.numpy()
            if len(split.shape) == 1 and split.shape[0] > 1:
                # If split is a 1D tensor of chunk sizes, convert it to the cut
                # indices relay expects by taking a cumulative sum.
                split = np.cumsum(split)
                # Drop the final index, which points past the end of the axis.
                split = split[:-1]
            else:
                # Otherwise get split as an integer.
                split = int(split)

            output = _op.split(input_tensor, split, axis=axis)

            # If keepdims is set to 0, remove the split axis. Note that this is
            # inconsistent with the ONNX spec but is needed for PyTorch compatibility.
            if not keepdims:
                output = [_op.squeeze(tensor_slice, axis=[axis]) for tensor_slice in output]
            return _expr.Tuple(list(output))
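The size-to-index conversion above is easiest to see numerically. A small sketch with made-up sizes:

import numpy as np

# ONNX SplitToSequence takes per-chunk sizes; relay's split takes the indices
# at which to cut. Sizes [2, 3, 5] on an axis of length 10 become cut points
# [2, 5] after a cumulative sum; the final index (10) is dropped because it
# points past the end of the axis.
sizes = np.array([2, 3, 5])
indices = np.cumsum(sizes)[:-1]  # array([2, 5])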


class SequenceAt(OnnxOpConverter):
    """Operator converter for sequence at op."""

    @classmethod
    def _impl_v11(cls, inputs, attr, params):
        input_sequence = inputs[0]
        position = inputs[1]
        assert isinstance(
            position, _expr.Constant
        ), "Only constant position supported for SequenceAt"
        # Convert position to integer.
        position = int(position.data.numpy())
        return input_sequence[position]
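Since this frontend models an ONNX sequence as a Relay tuple, a constant-position SequenceAt reduces to plain tuple indexing. A minimal sketch with hypothetical variables:

from tvm import relay

# A "sequence" here is just a Relay tuple of tensors.
seq = relay.Tuple([relay.var("a"), relay.var("b"), relay.var("c")])
# Indexing the tuple with a Python int yields the corresponding field,
# which is exactly what the converter returns.
elem = seq[1]  # Relay expression for "b"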


# compatible operators that do NOT require any conversion.
_identity_list = []

@@ -5793,6 +5870,8 @@ def _get_convert_map(opset):
"SequenceConstruct": SequenceConstruct.get_converter(opset),
"SequenceInsert": SequenceInsert.get_converter(opset),
"ConcatFromSequence": ConcatFromSequence.get_converter(opset),
"SplitToSequence": SplitToSequence.get_converter(opset),
"SequenceAt": SequenceAt.get_converter(opset),
}
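To exercise the two new converters end to end, something like the following sketch should work. The shapes, node names, and the demo graph are made up for illustration; initializers are imported as constants under freeze_params=True, which satisfies the constant-input assertions above:

import onnx
from onnx import helper, TensorProto
from tvm import relay

# Split a (6, 4) input into three (2, 4) chunks via a 1D split tensor
# [2, 2, 2], then select chunk 1 with SequenceAt.
nodes = [
    helper.make_node("SplitToSequence", ["x", "split"], ["seq"], axis=0),
    helper.make_node("SequenceAt", ["seq", "pos"], ["y"]),
]
graph = helper.make_graph(
    nodes,
    "sequence_demo",
    inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, (6, 4))],
    outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, (2, 4))],
    initializer=[
        helper.make_tensor("split", TensorProto.INT64, (3,), [2, 2, 2]),
        helper.make_tensor("pos", TensorProto.INT32, (), [1]),
    ],
)
mod, params = relay.frontend.from_onnx(helper.make_model(graph), freeze_params=True)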


2 changes: 0 additions & 2 deletions python/tvm/relay/op/_transform.py
@@ -432,8 +432,6 @@ def _concatenate_shape_func(inputs, axis):
     for i in const_range(ndim):
         if i != axis:
             out[i] = inputs[0][i]
-            for j in const_range(1, len(inputs)):
-                assert out[i] == inputs[j][i], "Dims mismatch in the inputs of concatenate."
         else:
             out[i] = int64(0)
             for j in const_range(len(inputs)):
31 changes: 19 additions & 12 deletions tests/python/frontend/onnx/test_forward.py
@@ -7043,7 +7043,7 @@ def verify_linear_regressor(a_shape, c_shape, i_shape, targets=1, batch=1):
 def test_sequence(target, dev):
     """test_sequence"""
 
-    def verify_sequence_ops(tensor_shape, num_tensors, axis=0, position=None, new_axis=None):
+    def verify_sequence_ops(tensor_shape, num_tensors, axis=0, position=0, new_axis=None):
         tensor_shape = list(tensor_shape)
         tensor_values = []
         for i in range(num_tensors):
@@ -7062,20 +7062,30 @@ def verify_sequence_ops(tensor_shape, num_tensors, axis=0, position=None, new_axis=None):
             outputs=["sequence"],
         )
 
-        insert_inputs = ["sequence", input_tensor_names[0]]
-        position_node = None
-        if position is not None:
-            insert_inputs.append("position")
-            position_node = make_constant_node("position", TensorProto.INT32, (), [position])
+        position_node = make_constant_node("position", TensorProto.INT32, (), [position])
 
         # Test sequence insertion.
         insert_node = helper.make_node(
-            "SequenceInsert", inputs=insert_inputs, outputs=["inserted_sequence"]
+            "SequenceInsert",
+            inputs=["sequence", input_tensor_names[0], "position"],
+            outputs=["inserted_sequence"],
         )
 
         # Test sequence concatenation.
         concat_node = helper.make_node(
-            "ConcatFromSequence", inputs=["inserted_sequence"], outputs=["output"], axis=axis
+            "ConcatFromSequence",
+            inputs=["inserted_sequence"],
+            outputs=["concat_sequence"],
+            axis=axis,
         )
 
+        # Test splitting a tensor into a sequence.
+        split_node = helper.make_node(
+            "SplitToSequence", inputs=["concat_sequence"], outputs=["split_sequence"], axis=axis
+        )
+
+        at_node = helper.make_node(
+            "SequenceAt", inputs=["split_sequence", "position"], outputs=["output"]
+        )
+
         if new_axis is not None:
@@ -7097,10 +7107,7 @@ def verify_sequence_ops(tensor_shape, num_tensors, axis=0, position=None, new_axis=None):
             output_shape[axis] = (num_tensors + 1) * output_shape[axis]
         graph_outputs = [helper.make_tensor_value_info("output", TensorProto.FLOAT, output_shape)]
 
-        graph_nodes = []
-        if position_node is not None:
-            graph_nodes.append(position_node)
-        graph_nodes += [construct_node, insert_node, concat_node]
+        graph_nodes = [position_node, construct_node, insert_node, concat_node, split_node, at_node]
 
         graph = helper.make_graph(
             graph_nodes,