From c987c75bcd8dce1cc2698f67fbd61967ef854b6e Mon Sep 17 00:00:00 2001 From: eguchi1904 Date: Tue, 18 Jul 2023 15:29:42 +0900 Subject: [PATCH 1/9] add convTranspose --- runtime/onnion_runtime/__init__.py | 1 + runtime/onnion_runtime/convtranspose.py | 86 +++++++++++++++++++++++++ runtime/onnion_runtime/error.py | 6 +- 3 files changed, 92 insertions(+), 1 deletion(-) create mode 100644 runtime/onnion_runtime/convtranspose.py diff --git a/runtime/onnion_runtime/__init__.py b/runtime/onnion_runtime/__init__.py index 8be21a6..b674622 100644 --- a/runtime/onnion_runtime/__init__.py +++ b/runtime/onnion_runtime/__init__.py @@ -19,6 +19,7 @@ from .concatfromsequence import ConcatFromSequence # noqa: F401 from .constant import Constant # noqa: F401 from .constantofshape import ConstantOfShape # noqa: F401 +from .convtranspose import ConvTranspose # noqa: F401 from .cos import Cos # noqa: F401 from .cosh import Cosh # noqa: F401 from .depthtospace import DepthToSpace # noqa: F401 diff --git a/runtime/onnion_runtime/convtranspose.py b/runtime/onnion_runtime/convtranspose.py new file mode 100644 index 0000000..e23daa9 --- /dev/null +++ b/runtime/onnion_runtime/convtranspose.py @@ -0,0 +1,86 @@ +from typing import Any, List, Optional + +import numpy as np + +from .error import RunError + + +# https://github.com/onnx/onnx/blob/main/docs/Operators.md#ConvTranspose +class ConvTranspose: + auto_pad: str + group: Optional[int] + dilations: Optional[List[int]] + strides: Optional[List[int]] + kernel_shape: Optional[List[int]] + output_shape: Optional[List[int]] + output_padding: Optional[List[int]] + pads: Optional[List[int]] + + def __init__(self, opset_version: int, **kwargs: Any): + self.version = opset_version + self.auto_pad = kwargs.get("auto_pad", "NOTSET") + self.dilations = kwargs.get("dilations", None) + self.group = kwargs.get("group", None) + self.kernel_shape = kwargs.get("kernel_shape", None) + self.output_padding = kwargs.get("output_padding", None) + self.output_shape = kwargs.get("output_shape", None) + self.pads = kwargs.get("pads", None) + self.strides = kwargs.get("strides", None) + + def run(self, x: np.ndarray, W: np.ndarray, b: Optional[np.ndarray] = None) -> List[np.ndarray]: + # x: [batch, in_ch, in_h, in_w] + # W: [in_ch, out_ch/group, kernel_h, kernel_w] + # b: [out_ch] + + # fix parameters + dim = len(x.shape) - 2 + group = self.group or 1 + batch = x.shape[0] + in_ch = x.shape[1] + out_ch = W.shape[1] + dilations = self.dilations or [1] * dim + strides = self.strides or [1] * dim + pads = self.pads or [0] * (dim * 2) + output_padding = self.output_padding or [0] * dim + kernel_shape = self.kernel_shape or W.shape[2:] + input_shape = x.shape[2:] + + # check parameters + if dim != 2: + raise RunError("ConvTranspose", self.version, "support 2d only") + + if group != 1: + raise RunError("ConvTranspose", self.version, "support group=1 only") + + if self.output_shape is not None: + raise RunError("ConvTranspose", self.version, "do not support ouput_shape") + + output_shape = [ + strides[i] * (input_shape[i] - 1) + + output_padding[i] + + ((kernel_shape[i] - 1) * dilations[i] + 1) + - pads[i] + - pads[i + dim] + for i in range(dim) + ] + + result = np.zeros([batch, out_ch, *output_shape], dtype=x.dtype) + + for n in range(batch): + for och in range(out_ch): + if b is not None: + result[n, och, :, :] += b[och] + for ih in range(input_shape[0]): + for iw in range(input_shape[1]): + for kh in range(kernel_shape[0]): + for kw in range(kernel_shape[1]): + oh = strides[0] * ih + kh * dilations[0] - pads[0] + ow = strides[1] * iw + kw * dilations[1] - pads[1] + if oh < 0 or ow < 0 or oh >= output_shape[0] or ow >= output_shape[1]: + continue + v = np.float32(0) + for ich in range(in_ch): + v += x[n, ich, ih, iw] * W[ich, och, kh, kw] + result[n, och, oh, ow] += v + + return [result] diff --git a/runtime/onnion_runtime/error.py b/runtime/onnion_runtime/error.py index 4bc274b..a639f49 100644 --- a/runtime/onnion_runtime/error.py +++ b/runtime/onnion_runtime/error.py @@ -1,4 +1,8 @@ +from typing import Optional + + class RunError(Exception): - def __init__(self, op, version): + def __init__(self, op: str, version: int, reason: Optional[str] = None): self.op = op self.version = version + self.reason = reason From 93e720b4463fff0939ca12a54908a4bef822d20e Mon Sep 17 00:00:00 2001 From: eguchi1904 Date: Tue, 18 Jul 2023 15:31:25 +0900 Subject: [PATCH 2/9] add test_convtranspose.py --- runtime/tests/test_convtranspose.py | 59 +++++++++++++++++++++++++++++ 1 file changed, 59 insertions(+) create mode 100644 runtime/tests/test_convtranspose.py diff --git a/runtime/tests/test_convtranspose.py b/runtime/tests/test_convtranspose.py new file mode 100644 index 0000000..762d9ba --- /dev/null +++ b/runtime/tests/test_convtranspose.py @@ -0,0 +1,59 @@ +import numpy as np +from onnion_runtime import ConvTranspose + +from .utils import check + + +def test_convtranspose_00() -> None: + opset_version = 13 + attrs = dict() + x = np.array([[[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]]]).astype(np.float32) # (1, 1, 3, 3) + + W = np.array( + [ + [ + [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], # (1, 2, 3, 3) + [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], + ] + ] + ).astype(np.float32) + + check(ConvTranspose, opset_version, attrs, [x, W]) + + +def test_convtranspose_01() -> None: + opset_version = 13 + attrs = {"strides": [3, 2], "output_padding": [1, 1]} + + x = np.array([[[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]]]).astype(np.float32) # (1, 1, 3, 3) + + W = np.array( + [ + [ + [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], # (1, 2, 3, 3) + [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], + ] + ] + ).astype(np.float32) + + check(ConvTranspose, opset_version, attrs, [x, W]) + + +# test dillation +def test_convtranspose_02() -> None: + opset_version = 13 + attrs = {"dilations": [2, 2]} + x = np.random.randn(1, 1, 3, 3).astype(np.float32) + W = np.random.randn(1, 1, 2, 2).astype(np.float32) + + check(ConvTranspose, opset_version, attrs, [x, W]) + + +# test pads +def test_convtranspose_03() -> None: + opset_version = 13 + attrs = {"strides": [3, 2], "pads": [1, 2, 1, 2]} + x = np.random.randn(1, 1, 3, 3).astype(np.float32) + W = np.random.randn(1, 2, 3, 3).astype(np.float32) + + check(ConvTranspose, opset_version, attrs, [x, W]) From c0d38d822becd480b101d112f7b1132995c25e38 Mon Sep 17 00:00:00 2001 From: eguchi1904 Date: Tue, 18 Jul 2023 15:32:41 +0900 Subject: [PATCH 3/9] fix runtime error of `poetry run pytest` on arm64 env --- runtime/tests/utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/runtime/tests/utils.py b/runtime/tests/utils.py index 906837a..399dc1d 100644 --- a/runtime/tests/utils.py +++ b/runtime/tests/utils.py @@ -10,7 +10,7 @@ try: import onnx import onnxruntime - from onnx import checker, helper, mapping, numpy_helper + from onnx import checker, helper, numpy_helper WITHOUT_ONNXRUNTIME = False except Exception: @@ -24,7 +24,7 @@ def on_arm32(): result = bool(int(os.environ["ONNION_TEST_ON_ARM32"])) except Exception: arch = platform.machine() - if arch == "x86_64": + if arch == "x86_64" or arch == "arm64": result = False elif arch == "armv7l": result = True @@ -81,7 +81,7 @@ def check_by_data(expected, result, max_error=1e-4): def _convert_type(dtype): assert not WITHOUT_ONNXRUNTIME - return mapping.NP_TYPE_TO_TENSOR_TYPE[dtype] + return helper.np_dtype_to_tensor_dtype(dtype) def _run_onnx(model, inputs, output_names): From ceac0dd88bcc634896eed57cceff5f5c04d1a87f Mon Sep 17 00:00:00 2001 From: eguchi1904 Date: Tue, 18 Jul 2023 16:45:58 +0900 Subject: [PATCH 4/9] convTranspose: support `output_shape` attribute --- runtime/onnion_runtime/convtranspose.py | 52 +++++++++++++++++-------- runtime/tests/test_convtranspose.py | 40 +++++++++++++++++++ 2 files changed, 76 insertions(+), 16 deletions(-) diff --git a/runtime/onnion_runtime/convtranspose.py b/runtime/onnion_runtime/convtranspose.py index e23daa9..8835284 100644 --- a/runtime/onnion_runtime/convtranspose.py +++ b/runtime/onnion_runtime/convtranspose.py @@ -28,11 +28,17 @@ def __init__(self, opset_version: int, **kwargs: Any): self.strides = kwargs.get("strides", None) def run(self, x: np.ndarray, W: np.ndarray, b: Optional[np.ndarray] = None) -> List[np.ndarray]: - # x: [batch, in_ch, in_h, in_w] - # W: [in_ch, out_ch/group, kernel_h, kernel_w] - # b: [out_ch] + """ + 2D Convolution Transpose + input shapes: + x: [batch, in_ch, in_h, in_w] + W: [in_ch, out_ch/group, kernel_h, kernel_w] + b: [out_ch] + output shape: + [batch, out_ch, out_h, out_w] + """ - # fix parameters + # define parameters dim = len(x.shape) - 2 group = self.group or 1 batch = x.shape[0] @@ -40,30 +46,44 @@ def run(self, x: np.ndarray, W: np.ndarray, b: Optional[np.ndarray] = None) -> L out_ch = W.shape[1] dilations = self.dilations or [1] * dim strides = self.strides or [1] * dim - pads = self.pads or [0] * (dim * 2) output_padding = self.output_padding or [0] * dim kernel_shape = self.kernel_shape or W.shape[2:] input_shape = x.shape[2:] + pads = self.pads or [0] * (dim * 2) - # check parameters if dim != 2: raise RunError("ConvTranspose", self.version, "support 2d only") if group != 1: raise RunError("ConvTranspose", self.version, "support group=1 only") - if self.output_shape is not None: - raise RunError("ConvTranspose", self.version, "do not support ouput_shape") + if self.auto_pad != "NOTSET": + raise RunError("ConvTranspose", self.version, "support auto_pad=NOTSET only") - output_shape = [ - strides[i] * (input_shape[i] - 1) - + output_padding[i] - + ((kernel_shape[i] - 1) * dilations[i] + 1) - - pads[i] - - pads[i + dim] - for i in range(dim) - ] + # calculate pads and output_shape + if self.output_shape is not None: + output_shape = self.output_shape + total_padding = [ + strides[i] * (input_shape[i] - 1) + + output_padding[i] + + ((kernel_shape[i] - 1) * dilations[i] + 1) + - output_shape[i] + for i in range(len(input_shape)) + ] + for i in range(len(input_shape)): + pads[i] = total_padding[i] - (total_padding[i] // 2) + pads[i + dim] = total_padding[i] // 2 + else: + output_shape = [ + strides[i] * (input_shape[i] - 1) + + output_padding[i] + + ((kernel_shape[i] - 1) * dilations[i] + 1) + - pads[i] + - pads[i + dim] + for i in range(dim) + ] + # calculate output result = np.zeros([batch, out_ch, *output_shape], dtype=x.dtype) for n in range(batch): diff --git a/runtime/tests/test_convtranspose.py b/runtime/tests/test_convtranspose.py index 762d9ba..31454ee 100644 --- a/runtime/tests/test_convtranspose.py +++ b/runtime/tests/test_convtranspose.py @@ -57,3 +57,43 @@ def test_convtranspose_03() -> None: W = np.random.randn(1, 2, 3, 3).astype(np.float32) check(ConvTranspose, opset_version, attrs, [x, W]) + + +# specify output shape +def test_convtranspose_04() -> None: + opset_version = 13 + attrs = {"strides": [3, 2], "output_shape": [10, 8]} + x = np.random.randn(1, 1, 3, 3).astype(np.float32) + W = np.random.randn(1, 2, 3, 3).astype(np.float32) + + check(ConvTranspose, opset_version, attrs, [x, W]) + + +# specify output shape and output padding +def test_convtranspose_05() -> None: + opset_version = 13 + attrs = {"strides": [3, 2], "output_shape": [10, 8], "kernel_shape": [3, 3], "output_padding": [1, 1]} + x = np.random.randn(1, 1, 3, 3).astype(np.float32) + W = np.random.randn(1, 2, 3, 3).astype(np.float32) + + check(ConvTranspose, opset_version, attrs, [x, W]) + + +# larger channel number +def test_convtranspose_06() -> None: + opset_version = 13 + attrs = {"strides": [2, 2], "kernel_shape": [2, 2], "pads": [0, 0, 0, 0]} + x = np.random.randn(2, 24, 12, 12).astype(np.float32) + W = np.random.randn(24, 24, 2, 2).astype(np.float32) + + check(ConvTranspose, opset_version, attrs, [x, W]) + + +# opset 1 +def test_convtranspose_07() -> None: + opset_version = 1 + attrs = {"strides": [3, 2], "output_shape": [10, 8], "kernel_shape": [3, 3], "output_padding": [1, 1]} + x = np.random.randn(1, 1, 3, 3).astype(np.float32) + W = np.random.randn(1, 2, 3, 3).astype(np.float32) + + check(ConvTranspose, opset_version, attrs, [x, W]) From 2437d7e57f631f64872029701ef28f4b507e0ec3 Mon Sep 17 00:00:00 2001 From: eguchi1904 Date: Tue, 18 Jul 2023 16:54:17 +0900 Subject: [PATCH 5/9] update readme --- runtime/README.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/runtime/README.md b/runtime/README.md index 8adedde..a3ca3da 100644 --- a/runtime/README.md +++ b/runtime/README.md @@ -61,6 +61,10 @@ This runtime supports only below operators. - ConcatFromSequence - Constant - ConstantOfShape +- ConvTranspose + - support 2d only + - `group` should be 1 + - `auto_pad` should be `NONE` or `"NOTSET"` - Cos - Cosh - DepthToSpace From ed4dae5f2dd27e9d17b507983a55f55bcae491e4 Mon Sep 17 00:00:00 2001 From: eguchi1904 Date: Wed, 19 Jul 2023 16:05:43 +0900 Subject: [PATCH 6/9] Change default value ConvTranspose.group to 1 --- runtime/onnion_runtime/convtranspose.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/runtime/onnion_runtime/convtranspose.py b/runtime/onnion_runtime/convtranspose.py index 8835284..a6e826c 100644 --- a/runtime/onnion_runtime/convtranspose.py +++ b/runtime/onnion_runtime/convtranspose.py @@ -8,7 +8,7 @@ # https://github.com/onnx/onnx/blob/main/docs/Operators.md#ConvTranspose class ConvTranspose: auto_pad: str - group: Optional[int] + group: int dilations: Optional[List[int]] strides: Optional[List[int]] kernel_shape: Optional[List[int]] @@ -20,7 +20,7 @@ def __init__(self, opset_version: int, **kwargs: Any): self.version = opset_version self.auto_pad = kwargs.get("auto_pad", "NOTSET") self.dilations = kwargs.get("dilations", None) - self.group = kwargs.get("group", None) + self.group = kwargs.get("group", 1) self.kernel_shape = kwargs.get("kernel_shape", None) self.output_padding = kwargs.get("output_padding", None) self.output_shape = kwargs.get("output_shape", None) From 02399db15e8217a015e725783d7129e1d45ebae0 Mon Sep 17 00:00:00 2001 From: eguchi1904 Date: Wed, 19 Jul 2023 16:06:51 +0900 Subject: [PATCH 7/9] Update runtime/README.md Co-authored-by: ishiy1993 --- runtime/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/runtime/README.md b/runtime/README.md index a3ca3da..7e2e66a 100644 --- a/runtime/README.md +++ b/runtime/README.md @@ -64,7 +64,7 @@ This runtime supports only below operators. - ConvTranspose - support 2d only - `group` should be 1 - - `auto_pad` should be `NONE` or `"NOTSET"` + - `auto_pad` should be `"NOTSET"` (default value) - Cos - Cosh - DepthToSpace From c86d52f7645d03c73bbd5bff2856b4f6d203a648 Mon Sep 17 00:00:00 2001 From: eguchi1904 Date: Wed, 19 Jul 2023 16:15:53 +0900 Subject: [PATCH 8/9] add test case for ConvTranspose --- runtime/tests/test_convtranspose.py | 62 ++++++++++++++++++++++++----- 1 file changed, 53 insertions(+), 9 deletions(-) diff --git a/runtime/tests/test_convtranspose.py b/runtime/tests/test_convtranspose.py index 31454ee..ab0879a 100644 --- a/runtime/tests/test_convtranspose.py +++ b/runtime/tests/test_convtranspose.py @@ -18,7 +18,9 @@ def test_convtranspose_00() -> None: ] ).astype(np.float32) - check(ConvTranspose, opset_version, attrs, [x, W]) + inputs = [x, W] + + check(ConvTranspose, opset_version, attrs, inputs) def test_convtranspose_01() -> None: @@ -36,7 +38,9 @@ def test_convtranspose_01() -> None: ] ).astype(np.float32) - check(ConvTranspose, opset_version, attrs, [x, W]) + inputs = [x, W] + + check(ConvTranspose, opset_version, attrs, inputs) # test dillation @@ -46,7 +50,9 @@ def test_convtranspose_02() -> None: x = np.random.randn(1, 1, 3, 3).astype(np.float32) W = np.random.randn(1, 1, 2, 2).astype(np.float32) - check(ConvTranspose, opset_version, attrs, [x, W]) + inputs = [x, W] + + check(ConvTranspose, opset_version, attrs, inputs) # test pads @@ -55,8 +61,11 @@ def test_convtranspose_03() -> None: attrs = {"strides": [3, 2], "pads": [1, 2, 1, 2]} x = np.random.randn(1, 1, 3, 3).astype(np.float32) W = np.random.randn(1, 2, 3, 3).astype(np.float32) + b = np.random.randn(2).astype(np.float32) + + inputs = [x, W, b] - check(ConvTranspose, opset_version, attrs, [x, W]) + check(ConvTranspose, opset_version, attrs, inputs) # specify output shape @@ -66,7 +75,9 @@ def test_convtranspose_04() -> None: x = np.random.randn(1, 1, 3, 3).astype(np.float32) W = np.random.randn(1, 2, 3, 3).astype(np.float32) - check(ConvTranspose, opset_version, attrs, [x, W]) + inputs = [x, W] + + check(ConvTranspose, opset_version, attrs, inputs) # specify output shape and output padding @@ -75,8 +86,11 @@ def test_convtranspose_05() -> None: attrs = {"strides": [3, 2], "output_shape": [10, 8], "kernel_shape": [3, 3], "output_padding": [1, 1]} x = np.random.randn(1, 1, 3, 3).astype(np.float32) W = np.random.randn(1, 2, 3, 3).astype(np.float32) + b = np.random.randn(2).astype(np.float32) + + inputs = [x, W, b] - check(ConvTranspose, opset_version, attrs, [x, W]) + check(ConvTranspose, opset_version, attrs, inputs) # larger channel number @@ -86,14 +100,44 @@ def test_convtranspose_06() -> None: x = np.random.randn(2, 24, 12, 12).astype(np.float32) W = np.random.randn(24, 24, 2, 2).astype(np.float32) - check(ConvTranspose, opset_version, attrs, [x, W]) + inputs = [x, W] + check(ConvTranspose, opset_version, attrs, inputs) -# opset 1 + +# larger channel number (with bias) def test_convtranspose_07() -> None: + opset_version = 13 + attrs = {"strides": [2, 2], "kernel_shape": [2, 2], "pads": [0, 0, 0, 0]} + x = np.random.randn(2, 24, 12, 12).astype(np.float32) + W = np.random.randn(24, 24, 2, 2).astype(np.float32) + b = np.random.randn(24).astype(np.float32) + + inputs = [x, W, b] + + check(ConvTranspose, opset_version, attrs, inputs) + + +# opset 1 +def test_convtranspose_08() -> None: opset_version = 1 attrs = {"strides": [3, 2], "output_shape": [10, 8], "kernel_shape": [3, 3], "output_padding": [1, 1]} x = np.random.randn(1, 1, 3, 3).astype(np.float32) W = np.random.randn(1, 2, 3, 3).astype(np.float32) - check(ConvTranspose, opset_version, attrs, [x, W]) + inputs = [x, W] + + check(ConvTranspose, opset_version, attrs, inputs) + + +# opset 1 (larager channel number) +def test_convtranspose_09() -> None: + opset_version = 1 + attrs = {"strides": [2, 2], "kernel_shape": [2, 2], "pads": [0, 0, 0, 0]} + x = np.random.randn(2, 24, 12, 12).astype(np.float32) + W = np.random.randn(24, 24, 2, 2).astype(np.float32) + b = np.random.randn(24).astype(np.float32) + + inputs = [x, W, b] + + check(ConvTranspose, opset_version, attrs, inputs) From 6174110f92c4fcc3de86e9a3ffc6941a610d9c74 Mon Sep 17 00:00:00 2001 From: eguchi1904 Date: Wed, 19 Jul 2023 16:21:15 +0900 Subject: [PATCH 9/9] remove redundant `or ` --- runtime/onnion_runtime/convtranspose.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/runtime/onnion_runtime/convtranspose.py b/runtime/onnion_runtime/convtranspose.py index a6e826c..2b71c4e 100644 --- a/runtime/onnion_runtime/convtranspose.py +++ b/runtime/onnion_runtime/convtranspose.py @@ -40,7 +40,7 @@ def run(self, x: np.ndarray, W: np.ndarray, b: Optional[np.ndarray] = None) -> L # define parameters dim = len(x.shape) - 2 - group = self.group or 1 + group = self.group batch = x.shape[0] in_ch = x.shape[1] out_ch = W.shape[1]