Commit

Add Constant, Squeeze & Sub (apache#10)
* Add squeeze.

* Add Constant.

* Add sub.
zxybazh authored Jan 31, 2023
1 parent 2e9941b commit c8de028
Showing 2 changed files with 160 additions and 7 deletions.
109 changes: 102 additions & 7 deletions python/tvm/relax/frontend/onnx_frontend.py
@@ -17,17 +17,18 @@
 # pylint: disable=invalid-name, import-self, len-as-condition, unused-argument, too-many-lines
 # pylint: disable=import-outside-toplevel
 """ONNX: Open Neural Network Exchange frontend for Relax."""
 import copy
 import math
 import warnings
-from typing import Optional
+from typing import Union, Optional
 
-import numpy as np
+import numpy as _np
 
 import tvm
 from tvm import relax, topi
 from tvm.ir import IRModule
-from tvm.relax import testing
+from tvm._ffi import base as _base
+from tvm.runtime import ndarray as _nd
 
 
 def new_var(var_name, shape, dtype="float32"):
@@ -284,7 +285,7 @@ def _impl_v13(cls, bb, inputs, attr):
         if -1 in new_shape:
             breakpoint()
             data_shape = [dim.value for dim in data.shape.values]
-            total_elements = np.prod(data_shape)
+            total_elements = _np.prod(data_shape)
             new_product = 1
             for dim in new_shape:
                 if dim > 0:
@@ -452,17 +453,108 @@ class CumSum(OnnxOpConverter):
 
     @classmethod
     def _impl_v13(cls, bb, inputs, attr):
-        assert getattr(attr, "reverse", 0) == 0, "reverse is not supported yet"
         data = inputs[0]
         if len(inputs) > 1:
             axis = int(inputs[1].data.numpy())
         else:
             axis = None
-        return bb.emit_te(
+        if getattr(attr, "reverse", 0) != 0:
+            data = bb.emit_te(topi.flip, data, axis=axis if axis else 0)
+        data = bb.emit_te(
             topi.cumsum,
-            data=inputs[0],
+            data=data,
             axis=axis,
             exclusive=attr.get("exclusive", None),
         )
+        if getattr(attr, "reverse", 0) != 0:
+            data = bb.emit_te(topi.flip, data, axis=axis if axis else 0)
+        return data
 
 
+class Squeeze(OnnxOpConverter):
+    """Converts an onnx Squeeze node into an equivalent Relax expression."""
+
+    @classmethod
+    def _impl_v13(cls, bb, inputs, attr):
+        if len(inputs) > 1:
+            axis = [int(x) for x in inputs[1].data.numpy()]
+        else:
+            axis = None
+        return bb.emit_te(topi.squeeze, inputs[0], axis=axis)
+
+
+class Constant(OnnxOpConverter):
+    """Converts an onnx Constant node into an equivalent Relax expression."""
+
+    @classmethod
+    def _impl_v13(cls, bb, inputs, attr):
+        def const(
+            value: Union[bool, int, float, _np.ndarray, tvm.nd.NDArray],
+            dtype: Optional[str] = None,
+            span: Optional[relax.Span] = None,
+        ):
+            """Create a constant value.
+
+            Parameters
+            ----------
+            value: Union[bool, int, float, numpy.ndarray, tvm.nd.NDArray]
+                The constant value.
+
+            dtype: str, optional
+                The data type of the resulting constant.
+
+            span: Optional[relax.Span]
+                Span that points to original source code.
+
+            Note
+            ----
+            When dtype is None, we use the following rule:
+
+            - int maps to "int32"
+            - float maps to "float32"
+            - bool maps to "bool"
+            - other using the same default rule as numpy.
+            """
+            if isinstance(value, (_base.numeric_types, (bool, list))):
+                value = _np.array(value, dtype=dtype)
+
+            if not dtype:
+                # when dtype is None: int maps to "int32", float maps to "float32"
+                dtype = {_np.dtype("int64"): _np.int32, _np.dtype("float64"): _np.float32}.get(
+                    value.dtype, None
+                )
+
+            if isinstance(value, (_np.ndarray, _np.generic)):
+                if dtype is not None:
+                    value = value.astype(dtype)
+                value = _nd.array(value)
+
+            if not isinstance(value, _nd.NDArray):
+                raise ValueError("value has to be scalar or NDArray")
+
+            return relax.Constant(value, span)
+
+        if "value" not in attr:
+            raise ValueError("no value in Constant")
+        value = attr.pop("value")
+        # Constants may rarely have string types. These are likely exported
+        # from other frameworks and not actually used in TVM. We'll just use
+        # a zero valued constant for compatibility.
+        if isinstance(value, bytes):
+            np_value = _np.asarray([0]).astype("int64")
+        else:
+            np_value = get_numpy(value)
+        dtype = np_value.dtype.name
+        value = const(np_value, dtype)
+        return value
+
+
+class Sub(OnnxOpConverter):
+    """Converts an onnx Sub node into an equivalent Relax expression."""
+
+    @classmethod
+    def _impl_v13(cls, bb, inputs, attr):
+        return bb.emit_te(topi.subtract, inputs[0], inputs[1])
+
+
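The reverse-mode CumSum support added in the hunk above leans on an identity rather than a dedicated reverse kernel: flip the input along the scan axis, take the ordinary cumulative sum, then flip the result back. A minimal NumPy sketch of that identity (illustrative only, not code from this commit):

import numpy as np

x = np.array([1, 2, 3, 4])

# Reverse cumulative sum, computed directly: element i is the sum of x[i:].
direct = np.array([x[i:].sum() for i in range(len(x))])  # [10, 9, 7, 4]

# The flip -> cumsum -> flip pattern the converter emits via topi.
pattern = np.flip(np.cumsum(np.flip(x, axis=0), axis=0), axis=0)

assert np.array_equal(direct, pattern)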
@@ -494,6 +586,9 @@ def _get_convert_map(opset):
         "Pow": Pow.get_converter(opset),
         "Erf": Erf.get_converter(opset),
         "CumSum": CumSum.get_converter(opset),
+        "Squeeze": Squeeze.get_converter(opset),
+        "Constant": Constant.get_converter(opset),
+        "Sub": Sub.get_converter(opset),
     }


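One detail of the new Constant converter worth noting: when no dtype is supplied, the nested const helper narrows NumPy's 64-bit defaults to 32-bit (int64 becomes "int32", float64 becomes "float32"), per the docstring's note. A standalone sketch of that narrowing rule (illustrative; narrow_default is a hypothetical helper, not part of this commit):

import numpy as np

def narrow_default(value):
    # Mimic the const helper's rule: int64 -> int32, float64 -> float32,
    # anything else keeps its dtype.
    arr = np.asarray(value)
    target = {np.dtype("int64"): np.int32, np.dtype("float64"): np.float32}.get(arr.dtype)
    return arr.astype(target) if target is not None else arr

print(narrow_default(7).dtype)                # int32 (Python ints default to int64 on most platforms)
print(narrow_default(1.5).dtype)              # float32 (Python floats default to float64)
print(narrow_default(np.float32(2.0)).dtype)  # float32, left unchanged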
58 changes: 58 additions & 0 deletions tests/python/relax/frontend/test_onnx_frontend.py
@@ -568,6 +568,61 @@ def test_cumsum():
     check_correctness(model)
 
 
+def test_squeeze():
+    squeeze_node = helper.make_node("Squeeze", ["x", "axis"], ["y"])
+    shape = [1, 32, 1, 32]
+    graph = helper.make_graph(
+        [squeeze_node],
+        "squeeze_test",
+        inputs=[
+            helper.make_tensor_value_info("x", TensorProto.FLOAT, shape),
+        ],
+        initializer=[helper.make_tensor("axis", TensorProto.INT64, [2], [0, 2])],
+        outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [32, 32])],
+    )
+
+    model = helper.make_model(graph, producer_name="squeeze_test")
+    check_correctness(model)
+
+
+def test_const():
+    shape = [32, 32]
+    const_node = helper.make_node(
+        "Constant",
+        [],
+        ["y"],
+        value=helper.make_tensor(
+            "value", TensorProto.FLOAT, shape, np.random.rand(*shape).astype(np.float32).flatten()
+        ),
+    )
+    graph = helper.make_graph(
+        [const_node],
+        "const_test",
+        inputs=[],
+        outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, shape)],
+    )
+
+    model = helper.make_model(graph, producer_name="const_test")
+    check_correctness(model)
+
+
+def test_sub():
+    sub_node = helper.make_node("Sub", ["x", "y"], ["z"])
+    shape = [32, 16]
+    graph = helper.make_graph(
+        [sub_node],
+        "sub_test",
+        inputs=[
+            helper.make_tensor_value_info("x", TensorProto.FLOAT, shape),
+            helper.make_tensor_value_info("y", TensorProto.FLOAT, shape),
+        ],
+        outputs=[helper.make_tensor_value_info("z", TensorProto.FLOAT, shape)],
+    )
+
+    model = helper.make_model(graph, producer_name="sub_test")
+    check_correctness(model)
+
+
 if __name__ == "__main__":
     test_matmul()
     test_concat()
@@ -586,6 +641,9 @@ def test_cumsum():
     test_pow()
     test_erf()
     test_cumsum()
+    test_squeeze()
+    test_const()
+    test_sub()
 
     # TODO, still has issues
     # test_reshape()
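Each new test hands its assembled model to check_correctness, a helper defined earlier in this test file and not shown in this diff. A minimal sketch of the pattern such a helper typically follows, assuming onnxruntime is available as the reference runtime (the helper name and tolerances here are illustrative):

import numpy as np
import onnxruntime

def reference_run(model, input_dict):
    # Execute the ONNX model with ONNX Runtime to obtain reference outputs.
    sess = onnxruntime.InferenceSession(
        model.SerializeToString(), providers=["CPUExecutionProvider"]
    )
    return sess.run(None, input_dict)

# For the sub_test model above, the reference output is simply x - y:
x = np.random.rand(32, 16).astype("float32")
y = np.random.rand(32, 16).astype("float32")
# outputs = reference_run(model, {"x": x, "y": y})
# np.testing.assert_allclose(outputs[0], x - y, rtol=1e-5, atol=1e-5)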
