diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 62f0f4b2dd25..3470099100d4 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -4008,6 +4008,23 @@ def _impl_v1(cls, inputs, attr, params): for var in else_free_vars: graph_scope._nodes.update({var.name_hint: var}) + # Sometimes pytorch to onnx will insert silly if statements that produce dynamic ranks. + # Often these don't contribute anything. If we see a dynamic rank output, try to unify + # them so we can continue without breaking. + if not isinstance(then_expr, _expr.Tuple) and not isinstance(else_expr, _expr.Tuple): + then_shape = infer_shape(then_expr) + else_shape = infer_shape(else_expr) + if len(then_shape) != len(else_shape): + warning_msg = ( + "If statement produced outputs with different rank. " + "Attempting to unify ranks but this may produce incorrect results." + ) + warnings.warn(warning_msg) + if len(then_shape) < len(else_shape): + then_expr = _op.broadcast_to_like(then_expr, else_expr) + else: + else_expr = _op.broadcast_to_like(else_expr, then_expr) + # Now we can construct the relay if statement and return. ret = _expr.If(cond, then_expr, else_expr) if len(then_branch.output) > 1: @@ -5565,6 +5582,66 @@ def _impl_v11(cls, inputs, attr, params): return _op.concatenate(inputs[0], axis=axis) +class SplitToSequence(OnnxOpConverter): + """Operator converter for split to sequence op.""" + + @classmethod + def _impl_v11(cls, inputs, attr, params): + axis = attr.get("axis", 0) + keepdims = attr.get("keepdims", 1) + + input_tensor = inputs[0] + input_shape = infer_shape(input_tensor) + split = inputs[1] + + # If split is not provided, we split all values along axis. + if split is None: + output = _op.split(input_tensor, input_shape[axis], axis=axis) + # If keepdims is 0, then we need to squeeze off the axis. 
+ if not keepdims: + output = [_op.squeeze(tensor_slice, axis=[axis]) for tensor_slice in output] + return _expr.Tuple(list(output)) + + # Otherwise, split based on provided split value. + else: + # For now we only support constant valued split. + assert isinstance( + split, _expr.Constant + ), "Only constant split supported for SplitToSequence" + split = split.data.numpy() + if len(split.shape) == 1 and split.shape[0] > 1: + # If split is a 1D tensor, it must be converted to indices for relay compatibility. + split = np.cumsum(split) + # Remove final invalid index. + split = split[:-1] + else: + # Otherwise get split as an integer. + split = int(split) + + output = _op.split(input_tensor, split, axis=axis) + + # If keepdims is set to 0 remove split axis. Note that this is + # an inconsistency with the onnx spec but is needed for pytorch compatibility. + if not keepdims: + output = [_op.squeeze(tensor_slice, axis=[axis]) for tensor_slice in output] + return _expr.Tuple(list(output)) + + +class SequenceAt(OnnxOpConverter): + """Operator converter for sequence at op.""" + + @classmethod + def _impl_v11(cls, inputs, attr, params): + input_sequence = inputs[0] + position = inputs[1] + assert isinstance( + position, _expr.Constant + ), "Only constant position supported for SequenceAt" + # Convert position to integer. + position = int(position.data.numpy()) + return input_sequence[position] + + # compatible operators that do NOT require any conversion. 
_identity_list = [] @@ -5793,6 +5870,8 @@ def _get_convert_map(opset): "SequenceConstruct": SequenceConstruct.get_converter(opset), "SequenceInsert": SequenceInsert.get_converter(opset), "ConcatFromSequence": ConcatFromSequence.get_converter(opset), + "SplitToSequence": SplitToSequence.get_converter(opset), + "SequenceAt": SequenceAt.get_converter(opset), } diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index 5b7e342c4b4e..d4e4a527835a 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -432,8 +432,6 @@ def _concatenate_shape_func(inputs, axis): for i in const_range(ndim): if i != axis: out[i] = inputs[0][i] - for j in const_range(1, len(inputs)): - assert out[i] == inputs[j][i], "Dims mismatch in the inputs of concatenate." else: out[i] = int64(0) for j in const_range(len(inputs)): diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 211d7f798aba..dcd4f2defbe8 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -7043,7 +7043,7 @@ def verify_linear_regressor(a_shape, c_shape, i_shape, targets=1, batch=1): def test_sequence(target, dev): """test_sequence""" - def verify_sequence_ops(tensor_shape, num_tensors, axis=0, position=None, new_axis=None): + def verify_sequence_ops(tensor_shape, num_tensors, axis=0, position=0, new_axis=None): tensor_shape = list(tensor_shape) tensor_values = [] for i in range(num_tensors): @@ -7062,20 +7062,30 @@ def verify_sequence_ops(tensor_shape, num_tensors, axis=0, position=None, new_ax outputs=["sequence"], ) - insert_inputs = ["sequence", input_tensor_names[0]] - position_node = None - if position is not None: - insert_inputs.append("position") - position_node = make_constant_node("position", TensorProto.INT32, (), [position]) + position_node = make_constant_node("position", TensorProto.INT32, (), [position]) # Test sequence insertion. 
insert_node = helper.make_node( - "SequenceInsert", inputs=insert_inputs, outputs=["inserted_sequence"] + "SequenceInsert", + inputs=["sequence", input_tensor_names[0], "position"], + outputs=["inserted_sequence"], ) # Test sequence concatenation. concat_node = helper.make_node( - "ConcatFromSequence", inputs=["inserted_sequence"], outputs=["output"], axis=axis + "ConcatFromSequence", + inputs=["inserted_sequence"], + outputs=["concat_sequence"], + axis=axis, + ) + + # Test splitting a tensor into a sequence. + split_node = helper.make_node( + "SplitToSequence", inputs=["concat_sequence"], outputs=["split_sequence"], axis=axis + ) + + at_node = helper.make_node( + "SequenceAt", inputs=["split_sequence", "position"], outputs=["output"] ) if new_axis is not None: @@ -7097,10 +7107,7 @@ def verify_sequence_ops(tensor_shape, num_tensors, axis=0, position=None, new_ax output_shape[axis] = (num_tensors + 1) * output_shape[axis] graph_outputs = [helper.make_tensor_value_info("output", TensorProto.FLOAT, output_shape)] - graph_nodes = [] - if position_node is not None: - graph_nodes.append(position_node) - graph_nodes += [construct_node, insert_node, concat_node] + graph_nodes = [position_node, construct_node, insert_node, concat_node, split_node, at_node] graph = helper.make_graph( graph_nodes,