diff --git a/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/execution_graph_tests/remove_parameter.cpp b/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/execution_graph_tests/remove_parameter.cpp new file mode 100644 index 00000000000000..c11d876ee38e01 --- /dev/null +++ b/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/execution_graph_tests/remove_parameter.cpp @@ -0,0 +1,16 @@ +// Copyright (C) 2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "execution_graph_tests/remove_parameter.hpp" +#include "common_test_utils/test_constants.hpp" + +using namespace ExecutionGraphTests; + +namespace { + +INSTANTIATE_TEST_CASE_P(smoke_removeParameter, ExecGraphRemoveParameterNode, + ::testing::Values(CommonTestUtils::DEVICE_CPU), + ExecGraphRemoveParameterNode::getTestCaseName); + +} // namespace diff --git a/inference-engine/tests/functional/plugin/shared/include/execution_graph_tests/remove_parameter.hpp b/inference-engine/tests/functional/plugin/shared/include/execution_graph_tests/remove_parameter.hpp new file mode 100644 index 00000000000000..e2b19c5ef07611 --- /dev/null +++ b/inference-engine/tests/functional/plugin/shared/include/execution_graph_tests/remove_parameter.hpp @@ -0,0 +1,15 @@ +// Copyright (C) 2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "gtest/gtest.h" + +namespace ExecutionGraphTests { + +class ExecGraphRemoveParameterNode + : public testing::TestWithParam { +public: + static std::string getTestCaseName(testing::TestParamInfo obj); +}; + +} // namespace ExecutionGraphTests diff --git a/inference-engine/tests/functional/plugin/shared/src/execution_graph_tests/remove_parameter.cpp b/inference-engine/tests/functional/plugin/shared/src/execution_graph_tests/remove_parameter.cpp new file mode 100644 index 00000000000000..3f23df27a1833e --- /dev/null +++ b/inference-engine/tests/functional/plugin/shared/src/execution_graph_tests/remove_parameter.cpp @@ -0,0 +1,110 @@ +// Copyright (C) 2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "execution_graph_tests/remove_parameter.hpp" +#include "functional_test_utils/skip_tests_config.hpp" + +#include +#include + +namespace ExecutionGraphTests { + +std::string ExecGraphRemoveParameterNode::getTestCaseName( + testing::TestParamInfo obj) { + std::string targetDevice = obj.param; + return "Dev=" + targetDevice; +} + +/** + * Replacing parameter by another node change indexing for other parameters, + * check that we still can correctly process changed network. + */ +TEST_P(ExecGraphRemoveParameterNode, RemoveParameterNode) { + SKIP_IF_CURRENT_TEST_IS_DISABLED() + + auto device_name = this->GetParam(); + ngraph::Shape shape = {3, 2}; + float in_data_2[6] = {1.0, 1.0, 1.0, 1.0, 1.0, 1.0}; + float in_data[6] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0}; + ngraph::element::Type type = ngraph::element::f32; + + using std::make_shared; + using namespace ngraph::op; + + // Some simple graph with 2 Parameters + // in2 in1 // + // \ / | // + // mul | // + // \ | // + // sum // + // | // + // out // + auto input = make_shared(type, shape); + auto input2 = make_shared(type, shape); + auto mul = make_shared(input2, input); + auto sum = make_shared(mul, input); + + auto function = std::make_shared( + ngraph::NodeVector{sum}, ngraph::ParameterVector{input2, input}, + "SimpleNet"); + + // Load into plugin and get exec graph + auto ie = InferenceEngine::Core(); + auto net = InferenceEngine::CNNNetwork(function); + auto exec_net = ie.LoadNetwork(net, device_name); + auto exec_graph = exec_net.GetExecGraphInfo(); + auto infer_req = exec_net.CreateInferRequest(); + InferenceEngine::TensorDesc tDesc(InferenceEngine::Precision::FP32, shape, + InferenceEngine::Layout::NC); + InferenceEngine::Blob::Ptr inBlob2 = + InferenceEngine::make_shared_blob(tDesc, in_data_2); + infer_req.SetBlob(input2->get_name(), inBlob2); + + InferenceEngine::Blob::Ptr inBlob = + InferenceEngine::make_shared_blob(tDesc, in_data); + infer_req.SetBlob(input->get_name(), inBlob); + + infer_req.Infer(); + + auto outBlob = infer_req.GetBlob(sum->get_name()); + InferenceEngine::MemoryBlob::CPtr output = + InferenceEngine::as(outBlob); + auto outputHolder = output->rmap(); + const auto ref_result = outputHolder.as(); + + ASSERT_EQ(function->get_parameter_index(input2), 0); + ASSERT_EQ(function->get_parameter_index(input), 1); + + // Replace input2 by constant + auto const_in = + make_shared(type, shape, std::vector(6, 1.0)); + mul->input(0).replace_source_output(const_in->output(0)); + function->remove_parameter(input2); + + ASSERT_EQ(function->get_parameters().size(), 1); + ASSERT_EQ(function->get_parameter_index(input), 0); + + // Load new function into plugin and get exec graph + auto new_net = InferenceEngine::CNNNetwork(function); + auto new_exec_net = ie.LoadNetwork(new_net, device_name); + auto new_exec_graph = new_exec_net.GetExecGraphInfo(); + + // infer new graph + auto new_infer_req = new_exec_net.CreateInferRequest(); + new_infer_req.SetBlob(input->get_name(), inBlob); + + new_infer_req.Infer(); + + auto new_outBlob = new_infer_req.GetBlob(sum->get_name()); + InferenceEngine::MemoryBlob::CPtr new_output = + InferenceEngine::as(new_outBlob); + auto new_outputHolder = new_output->rmap(); + const auto result = new_outputHolder.as(); + + for (int i = 0; i < 6; i++) { + ASSERT_NEAR(result[i], ref_result[i], 1e-5); + } +} + +} // namespace ExecutionGraphTests diff --git a/ngraph/core/include/ngraph/function.hpp b/ngraph/core/include/ngraph/function.hpp index 922cfcd7966162..937cafc3913ad8 100644 --- a/ngraph/core/include/ngraph/function.hpp +++ b/ngraph/core/include/ngraph/function.hpp @@ -170,6 +170,33 @@ namespace ngraph /// \param result Result node to delete void remove_result(const std::shared_ptr& result); + /// \brief Add new Parameter nodes to the list. + /// + /// Method doesn't change or validate graph, it should be done manually. + /// For example, if you want to replace `ReadValue` node by `Parameter`, you should do the + /// following steps: + /// * replace node `ReadValue` by `Parameter` in graph + /// * call add_parameter() to add new input to the list + /// * call graph validation to check correctness of changes + /// + /// \param params new Parameter nodes + void add_parameters(const ParameterVector& params); + + /// \brief Delete Parameter node from the list of parameters. Method will not delete node + /// from graph. You need to replace Parameter with other operation manually. + /// Attention: Indexing of parameters can be changed. + /// + /// Possible use of method is to replace input by variable. For it the following steps + /// should be done: + /// * `Parameter` node should be replaced by `ReadValue` + /// * call remove_parameter(param) to remove input from the list + /// * check if any parameter indexes are saved/used somewhere, update it for all inputs + /// because indexes can be changed + /// * call graph validation to check all changes + /// + /// \param param Parameter node to delete + void remove_parameter(const std::shared_ptr& param); + private: Function(const Function&) = delete; Function(const Function&&) = delete; @@ -203,4 +230,4 @@ namespace ngraph 0}; const DiscreteTypeInfo& get_type_info() const override { return type_info; } }; -} +} // namespace ngraph diff --git a/ngraph/core/src/function.cpp b/ngraph/core/src/function.cpp index ea696fa30499b2..19595e3ecb809c 100644 --- a/ngraph/core/src/function.cpp +++ b/ngraph/core/src/function.cpp @@ -405,4 +405,29 @@ void Function::remove_result(const std::shared_ptr& result) m_results.end()); } +void Function::add_parameters(const ParameterVector& params) +{ + for (int i = 0; i < params.size(); i++) + { + for (int j = 0; j < m_parameters.size(); j++) + { + NGRAPH_CHECK(params[i] != m_parameters[j], + "add_parameters(): Tried to add parameter (index in array ", + i, + ") but function already have the same parameter with index ", + j); + } + } + m_parameters.insert(m_parameters.end(), params.begin(), params.end()); +} + +void Function::remove_parameter(const std::shared_ptr& param) +{ + m_parameters.erase( + std::remove_if(m_parameters.begin(), + m_parameters.end(), + [¶m](std::shared_ptr& r) { return r == param; }), + m_parameters.end()); +} + constexpr DiscreteTypeInfo AttributeAdapter>::type_info; diff --git a/ngraph/test/build_graph.cpp b/ngraph/test/build_graph.cpp index 1279735c806d63..a33230d4b2c14e 100644 --- a/ngraph/test/build_graph.cpp +++ b/ngraph/test/build_graph.cpp @@ -363,3 +363,91 @@ TEST(build_graph, build_graph_with_remove_result) nodes = f->get_ops(); EXPECT_EQ(nodes.size(), 5); } + +TEST(build_graph, build_graph_with_add_parameter) +{ + auto arg = make_shared(element::f32, Shape{2, 4}); + auto arg2 = make_shared(element::f32, Shape{2, 2}); + auto init_const = op::Constant::create(element::f32, Shape{2, 2}, {0, 0, 0, 0}); + auto read = make_shared(init_const, "v0"); + std::vector> args = {arg, read}; + auto pattern = make_shared(args, 1); + auto res = make_shared(pattern); + const auto axis = op::Constant::create(element::i64, Shape{}, {1}); + auto crop = make_shared(pattern, axis, 3); + auto res2 = make_shared(crop, "v0"); + + auto f = make_shared(ResultVector({res, res2}), ParameterVector{arg}); + + NodeVector nodes = f->get_ops(); + EXPECT_EQ(nodes.size(), 8); + ParameterVector params = f->get_parameters(); + EXPECT_EQ(params.size(), 1); + + pattern->input(1).replace_source_output(arg2->output(0)); + + f->add_parameters(ParameterVector({arg2})); + params = f->get_parameters(); + EXPECT_EQ(params.size(), 2); + EXPECT_EQ(params[1], arg2); + nodes = f->get_ops(); + EXPECT_EQ(nodes.size(), 7); +} + +TEST(build_graph, build_graph_with_remove_parameter) +{ + auto arg = make_shared(element::f32, Shape{2, 4}); + auto arg2 = make_shared(element::f32, Shape{2, 2}); + auto init_const = op::Constant::create(element::f32, Shape{2, 2}, {0, 0, 0, 0}); + auto read = make_shared(init_const, "v0"); + std::vector> args = {arg, arg2}; + auto pattern = make_shared(args, 1); + auto res = make_shared(pattern); + const auto axis = op::Constant::create(element::i64, Shape{}, {1}); + auto crop = make_shared(pattern, axis, 3); + auto res2 = make_shared(crop, "v0"); + + auto f = make_shared(ResultVector({res, res2}), ParameterVector{arg, arg2}); + + NodeVector nodes = f->get_ops(); + EXPECT_EQ(nodes.size(), 7); + ParameterVector params = f->get_parameters(); + EXPECT_EQ(params.size(), 2); + + pattern->input(1).replace_source_output(read->output(0)); + f->remove_parameter(arg2); + params = f->get_parameters(); + EXPECT_EQ(params.size(), 1); + nodes = f->get_ops(); + EXPECT_EQ(nodes.size(), 8); +} + +TEST(build_graph, build_graph_with_remove_parameter_indexing) +{ + auto arg = make_shared(element::f32, Shape{2, 4}); + auto arg2 = make_shared(element::f32, Shape{2, 2}); + auto init_const = op::Constant::create(element::f32, Shape{2, 2}, {0, 0, 0, 0}); + auto read = make_shared(init_const, "v0"); + std::vector> args = {arg2, arg}; + auto pattern = make_shared(args, 1); + auto res = make_shared(pattern); + const auto axis = op::Constant::create(element::i64, Shape{}, {1}); + auto crop = make_shared(pattern, axis, 3); + auto res2 = make_shared(crop, "v0"); + + auto f = make_shared(ResultVector({res, res2}), ParameterVector{arg2, arg}); + + NodeVector nodes = f->get_ops(); + EXPECT_EQ(nodes.size(), 7); + ParameterVector params = f->get_parameters(); + EXPECT_EQ(params.size(), 2); + + pattern->input(0).replace_source_output(read->output(0)); + f->remove_parameter(arg2); + params = f->get_parameters(); + EXPECT_EQ(params.size(), 1); + nodes = f->get_ops(); + EXPECT_EQ(nodes.size(), 8); + + f->validate_nodes_and_infer_types(); +} \ No newline at end of file