From 410b65ed613ad6016bf4f213f9da8a2fe4e40d43 Mon Sep 17 00:00:00 2001 From: Roman Kazantsev Date: Sat, 1 May 2021 20:52:13 +0300 Subject: [PATCH 1/3] Extend nGraph Python API and test IE IR reader for Einsum Signed-off-by: Roman Kazantsev --- .../ngraph_reader/einsum_tests.cpp | 272 ++++++++++++++++++ ngraph/python/src/ngraph/__init__.py | 1 + ngraph/python/src/ngraph/opset7/__init__.py | 1 + ngraph/python/src/ngraph/opset7/ops.py | 22 +- ngraph/python/tests/__init__.py | 2 + .../python/tests/test_ngraph/test_einsum.py | 100 +++++++ 6 files changed, 396 insertions(+), 2 deletions(-) create mode 100644 inference-engine/tests/functional/inference_engine/ngraph_reader/einsum_tests.cpp create mode 100644 ngraph/python/tests/test_ngraph/test_einsum.py diff --git a/inference-engine/tests/functional/inference_engine/ngraph_reader/einsum_tests.cpp b/inference-engine/tests/functional/inference_engine/ngraph_reader/einsum_tests.cpp new file mode 100644 index 00000000000000..753aff586cbf78 --- /dev/null +++ b/inference-engine/tests/functional/inference_engine/ngraph_reader/einsum_tests.cpp @@ -0,0 +1,272 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include "ngraph_reader_tests.hpp" +#include "common_test_utils/xml_net_builder/ir_net.hpp" + +TEST_F(NGraphReaderTests, ReadEinsumNetwork) { + std::string model = R"V0G0N( + + + + + + + 2 + 3 + 4 + + + + + + + + 5 + 3 + 4 + + + + + + + + 2 + 3 + 4 + + + 5 + 3 + 4 + + + + + 2 + 5 + + + + + + + 2 + 5 + + + + + + + + + + +)V0G0N"; + std::string modelV7 = R"V0G0N( + + + + + + 2 + 3 + 4 + + + + + + + 5 + 3 + 4 + + + + + + + + 2 + 3 + 4 + + + 5 + 3 + 4 + + + + + 2 + 5 + + + + + + + + + +)V0G0N"; + compareIRs(model, modelV7); +} + +TEST_F(NGraphReaderTests, ReadEinsumNetwork2) { + std::string model = R"V0G0N( + + + + + + + 2 + 3 + 4 + 5 + + + + + + + + 4 + 5 + 6 + + + + + + + + 7 + 4 + 5 + + + + + + + + 2 + 3 + 4 + 5 + + + 4 + 5 + 6 + + + 7 + 4 + 5 + + + + + 2 + 3 + 6 + + + + + + + 2 + 3 + 6 + + + + + + + + + + + +)V0G0N"; + std::string modelV7 = R"V0G0N( + + + + + + 2 + 3 + 4 + 5 + + + + + + + 4 + 5 + 6 + + + + + + + 7 + 4 + 5 + + + + + + + + 2 + 3 + 4 + 5 + + + 4 + 5 + 6 + + + 7 + 4 + 5 + + + + + 2 + 3 + 6 + + + + + + + + + + +)V0G0N"; + compareIRs(model, modelV7); +} + diff --git a/ngraph/python/src/ngraph/__init__.py b/ngraph/python/src/ngraph/__init__.py index 441392d4f0a7ad..c66e3ee81e08c1 100644 --- a/ngraph/python/src/ngraph/__init__.py +++ b/ngraph/python/src/ngraph/__init__.py @@ -53,6 +53,7 @@ from ngraph.opset7 import depth_to_space from ngraph.opset7 import detection_output from ngraph.opset7 import divide +from ngraph.opset7 import einsum from ngraph.opset7 import elu from ngraph.opset7 import embedding_bag_offsets_sum from ngraph.opset7 import embedding_bag_packed_sum diff --git a/ngraph/python/src/ngraph/opset7/__init__.py b/ngraph/python/src/ngraph/opset7/__init__.py index a7d12fb6f025a2..c1ded5f9ad424c 100644 --- a/ngraph/python/src/ngraph/opset7/__init__.py +++ b/ngraph/python/src/ngraph/opset7/__init__.py @@ -38,6 +38,7 @@ from ngraph.opset1.ops import depth_to_space from ngraph.opset1.ops import detection_output from ngraph.opset1.ops import divide +from ngraph.opset7.ops import einsum from ngraph.opset1.ops import elu from ngraph.opset3.ops import embedding_bag_offsets_sum from ngraph.opset3.ops import embedding_bag_packed_sum diff --git a/ngraph/python/src/ngraph/opset7/ops.py b/ngraph/python/src/ngraph/opset7/ops.py index dee2c5d319275f..419ac419fa68e2 100644 --- a/ngraph/python/src/ngraph/opset7/ops.py +++ b/ngraph/python/src/ngraph/opset7/ops.py @@ -2,11 +2,10 @@ # SPDX-License-Identifier: Apache-2.0 """Factory functions for all ngraph ops.""" +from functools import partial from typing import Callable, Iterable, List, Optional, Set, Union import numpy as np -from functools import partial - from ngraph.impl import Node, Shape from ngraph.impl.op import Constant, Parameter from ngraph.opset_utils import _get_node_factory @@ -42,9 +41,28 @@ _get_node_factory_opset7 = partial(_get_node_factory, "opset7") + # -------------------------------------------- ops ------------------------------------------------ +@nameable_op +def einsum( + inputs: List[Node], + equation: str +) -> Node: + """Return a node which performs Einsum operation. + + @param inputs: The list of input nodes + @param equation: Einsum equation + @return: The new node performing Einsum operation on the inputs + """ + attributes = { + "equation": equation + } + + return _get_node_factory_opset7().create("Einsum", as_nodes(*inputs), attributes) + + @nameable_op def gelu( data: Node, diff --git a/ngraph/python/tests/__init__.py b/ngraph/python/tests/__init__.py index 1a5925e56bf1c0..e579d02dfa094b 100644 --- a/ngraph/python/tests/__init__.py +++ b/ngraph/python/tests/__init__.py @@ -159,3 +159,5 @@ def xfail_test(reason="Mark the test as expected to fail", strict=True): "Not equal to tolerance") xfail_issue_49391 = xfail_test(reason="Roll is not implemented in CPU plugin.") + +xfail_issue_45432 = xfail_test(reason="Einsum is not implemented in CPU plugin.") diff --git a/ngraph/python/tests/test_ngraph/test_einsum.py b/ngraph/python/tests/test_ngraph/test_einsum.py new file mode 100644 index 00000000000000..3b33b273f87956 --- /dev/null +++ b/ngraph/python/tests/test_ngraph/test_einsum.py @@ -0,0 +1,100 @@ +import ngraph as ng +import numpy as np +import pytest + +from ngraph.utils.types import get_element_type +from tests import xfail_issue_45432 +from tests.runtime import get_runtime + + +def einsum_op_exec(input_shapes: list, equation: str, data_type: np.dtype, + with_value=False, seed=202104): + """ + Test Einsum operation for given input shapes, equation, and data type. + It generates input data of given shapes and type, receives reference results using numpy, + and tests IE implementation by matching with reference numpy results. + + :param input_shapes: a list of tuples with shapes + :param equation: Einsum equation + :param data_type: a type of input data + :param with_value: if True - tests output data shape and type along with its value, + otherwise, tests only the output shape and type + :param seed: a seed for random generation of input data + """ + np.random.seed(seed) + num_inputs = len(input_shapes) + runtime = get_runtime() + + # set absolute tolerance based on the data type + atol = 0.0 if np.issubdtype(data_type, np.integer) else 1e-04 + + # generate input tensors + ng_inputs = [] + np_inputs = [] + for i in range(num_inputs): + input_i = np.random.random_integers(10, size=input_shapes[i]).astype(data_type) + print("input_i = ", input_i) + np_inputs.append(input_i) + ng_inputs.append(ng.parameter(input_i.shape, dtype=data_type)) + + expected_result = np.einsum(equation, *np_inputs) + einsum_model = ng.einsum(ng_inputs, equation) + + # check the output shape and type + assert einsum_model.get_type_name() == "Einsum" + assert einsum_model.get_output_size() == 1 + assert list(einsum_model.get_output_shape(0)) == list(expected_result.shape) + assert einsum_model.get_output_element_type(0) == get_element_type(data_type) + + # check inference result + if with_value: + computation = runtime.computation(einsum_model, *ng_inputs) + actual_result = computation(*np_inputs) + np.allclose(actual_result, expected_result, atol=atol) + + +@pytest.mark.parametrize("data_type", [np.float32, np.int32]) +def test_dot_product(data_type): + einsum_op_exec([5, 5], "i,i->", data_type) + + +@pytest.mark.parametrize("data_type", [np.float32, np.int32]) +def test_matrix_multiplication(data_type): + einsum_op_exec([(2, 3), (3, 4)], "ab,bc->ac", data_type) + + +@pytest.mark.parametrize("data_type", [np.float32, np.int32]) +def test_batch_trace(data_type): + einsum_op_exec([(2, 3, 3)], "kii->k", data_type) + + +@pytest.mark.parametrize("data_type", [np.float32, np.int32]) +def test_diagonal_extraction(data_type): + einsum_op_exec([(6, 5, 5)], "kii->ki", data_type) + + +@pytest.mark.parametrize("data_type", [np.float32, np.int32]) +def test_transpose(data_type): + einsum_op_exec([(1, 2, 3)], "ijk->kij", data_type) + + +@pytest.mark.parametrize("data_type", [np.float32, np.int32]) +def test_multiple_multiplication(data_type): + einsum_op_exec([(2, 5), (5, 3, 6), (5, 3)], "ab,bcd,bc->ca", data_type) + + +@pytest.mark.parametrize("data_type", [np.float32, np.int32]) +def test_simple_ellipsis(data_type): + einsum_op_exec([(5, 3, 4)], "a...->...", data_type) + + +@xfail_issue_45432 +@pytest.mark.parametrize("data_type", [np.float32, np.int32]) +def test_multiple_ellipsis(data_type): + einsum_op_exec([(3, 5), 1], "a...,...->a...", data_type, with_value=True) + + +@xfail_issue_45432 +@pytest.mark.parametrize("data_type", [np.float32, np.int32]) +def test_broadcasting_ellipsis(data_type): + einsum_op_exec([(9, 1, 4, 3), (3, 11, 7, 1)], "a...b,b...->a...", data_type, with_value=True) From e43fdd2c621816987ad0ac141beca234151942b1 Mon Sep 17 00:00:00 2001 From: Roman Kazantsev Date: Tue, 4 May 2021 07:18:59 +0300 Subject: [PATCH 2/3] Format description for test auxiliary function Signed-off-by: Roman Kazantsev --- ngraph/python/tests/test_ngraph/test_einsum.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/ngraph/python/tests/test_ngraph/test_einsum.py b/ngraph/python/tests/test_ngraph/test_einsum.py index 3b33b273f87956..e38130bab23b7d 100644 --- a/ngraph/python/tests/test_ngraph/test_einsum.py +++ b/ngraph/python/tests/test_ngraph/test_einsum.py @@ -9,11 +9,10 @@ def einsum_op_exec(input_shapes: list, equation: str, data_type: np.dtype, with_value=False, seed=202104): - """ - Test Einsum operation for given input shapes, equation, and data type. + """Test Einsum operation for given input shapes, equation, and data type. + It generates input data of given shapes and type, receives reference results using numpy, and tests IE implementation by matching with reference numpy results. - :param input_shapes: a list of tuples with shapes :param equation: Einsum equation :param data_type: a type of input data From 63dbcc1d70f6a29e01ec7011ea12a23a98d9d825 Mon Sep 17 00:00:00 2001 From: Roman Kazantsev Date: Tue, 4 May 2021 10:50:52 +0300 Subject: [PATCH 3/3] Remove print from the python test Signed-off-by: Roman Kazantsev --- ngraph/python/tests/test_ngraph/test_einsum.py | 1 - 1 file changed, 1 deletion(-) diff --git a/ngraph/python/tests/test_ngraph/test_einsum.py b/ngraph/python/tests/test_ngraph/test_einsum.py index e38130bab23b7d..a89b6c3ff35d89 100644 --- a/ngraph/python/tests/test_ngraph/test_einsum.py +++ b/ngraph/python/tests/test_ngraph/test_einsum.py @@ -32,7 +32,6 @@ def einsum_op_exec(input_shapes: list, equation: str, data_type: np.dtype, np_inputs = [] for i in range(num_inputs): input_i = np.random.random_integers(10, size=input_shapes[i]).astype(data_type) - print("input_i = ", input_i) np_inputs.append(input_i) ng_inputs.append(ng.parameter(input_i.shape, dtype=data_type))