Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ngraph: add methods for removing parameters from Function #3854

Merged
merged 13 commits into from
Feb 4, 2021
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "execution_graph_tests/remove_parameter.hpp"
#include "common_test_utils/test_constants.hpp"

using namespace ExecutionGraphTests;

namespace {

INSTANTIATE_TEST_CASE_P(smoke_removeParameter, ExecGraphRemoveParameterNode,
::testing::Values(CommonTestUtils::DEVICE_CPU),
ExecGraphRemoveParameterNode::getTestCaseName);

} // namespace
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "gtest/gtest.h"

namespace ExecutionGraphTests {

class ExecGraphRemoveParameterNode
: public testing::TestWithParam<std::string> {
public:
static std::string getTestCaseName(testing::TestParamInfo<std::string> obj);
};

} // namespace ExecutionGraphTests
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "execution_graph_tests/remove_parameter.hpp"
#include "functional_test_utils/skip_tests_config.hpp"

#include <inference_engine.hpp>
#include <ngraph/ngraph.hpp>

namespace ExecutionGraphTests {

std::string ExecGraphRemoveParameterNode::getTestCaseName(
testing::TestParamInfo<std::string> obj) {
std::string targetDevice = obj.param;
return "Dev=" + targetDevice;
}

/**
* Replacing parameter by another node change indexing for other parameters,
* check that we still can correctly process changed network.
*/
TEST_P(ExecGraphRemoveParameterNode, RemoveParameterNode) {
SKIP_IF_CURRENT_TEST_IS_DISABLED()

auto device_name = this->GetParam();
ngraph::Shape shape = {3, 2};
float in_data_2[6] = {1.0, 1.0, 1.0, 1.0, 1.0, 1.0};
float in_data[6] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0};
ngraph::element::Type type = ngraph::element::f32;

using std::make_shared;
using namespace ngraph::op;

// Some simple graph with 2 Parameters
// in2 in1 //
// \ / | //
// mul | //
// \ | //
// sum //
// | //
// out //
auto input = make_shared<Parameter>(type, shape);
auto input2 = make_shared<Parameter>(type, shape);
auto mul = make_shared<ngraph::op::v1::Multiply>(input2, input);
auto sum = make_shared<ngraph::op::v1::Add>(mul, input);

auto function = std::make_shared<ngraph::Function>(
ngraph::NodeVector{sum}, ngraph::ParameterVector{input2, input},
"SimpleNet");

// Load into plugin and get exec graph
auto ie = InferenceEngine::Core();
auto net = InferenceEngine::CNNNetwork(function);
auto exec_net = ie.LoadNetwork(net, device_name);
auto exec_graph = exec_net.GetExecGraphInfo();
auto infer_req = exec_net.CreateInferRequest();
InferenceEngine::TensorDesc tDesc(InferenceEngine::Precision::FP32, shape,
InferenceEngine::Layout::NC);
InferenceEngine::Blob::Ptr inBlob2 =
InferenceEngine::make_shared_blob<float>(tDesc, in_data_2);
infer_req.SetBlob(input2->get_name(), inBlob2);

InferenceEngine::Blob::Ptr inBlob =
InferenceEngine::make_shared_blob<float>(tDesc, in_data);
infer_req.SetBlob(input->get_name(), inBlob);

infer_req.Infer();

auto outBlob = infer_req.GetBlob(sum->get_name());
InferenceEngine::MemoryBlob::CPtr output =
InferenceEngine::as<InferenceEngine::MemoryBlob>(outBlob);
auto outputHolder = output->rmap();
const auto ref_result = outputHolder.as<float *>();

ASSERT_EQ(function->get_parameter_index(input2), 0);
ASSERT_EQ(function->get_parameter_index(input), 1);

// Replace input2 by constant
auto const_in =
make_shared<Constant>(type, shape, std::vector<float>(6, 1.0));
mul->input(0).replace_source_output(const_in->output(0));
function->remove_parameter(input2);

ASSERT_EQ(function->get_parameters().size(), 1);
ASSERT_EQ(function->get_parameter_index(input), 0);

// Load new function into plugin and get exec graph
auto new_net = InferenceEngine::CNNNetwork(function);
auto new_exec_net = ie.LoadNetwork(new_net, device_name);
auto new_exec_graph = new_exec_net.GetExecGraphInfo();

// infer new graph
auto new_infer_req = new_exec_net.CreateInferRequest();
new_infer_req.SetBlob(input->get_name(), inBlob);

new_infer_req.Infer();

auto new_outBlob = new_infer_req.GetBlob(sum->get_name());
InferenceEngine::MemoryBlob::CPtr new_output =
InferenceEngine::as<InferenceEngine::MemoryBlob>(new_outBlob);
auto new_outputHolder = new_output->rmap();
const auto result = new_outputHolder.as<float *>();

for (int i = 0; i < 6; i++) {
ASSERT_NEAR(result[i], ref_result[i], 1e-5);
}
}

} // namespace ExecutionGraphTests
29 changes: 28 additions & 1 deletion ngraph/core/include/ngraph/function.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,33 @@ namespace ngraph
/// \param result Result node to delete
void remove_result(const std::shared_ptr<op::Result>& result);

/// \brief Add new Parameter nodes to the list.
///
/// Method doesn't change or validate graph, it should be done manually.
/// For example, if you want to replace `ReadValue` node by `Parameter`, you should do the
/// following steps:
/// * replace node `ReadValue` by `Parameter` in graph
/// * call add_parameter() to add new input to the list
/// * call graph validation to check correctness of changes
///
/// \param params new Parameter nodes
void add_parameters(const ParameterVector& params);
ilyachur marked this conversation as resolved.
Show resolved Hide resolved

/// \brief Delete Parameter node from the list of parameters. Method will not delete node
/// from graph. You need to replace Parameter with other operation manually.
/// Attention: Indexing of parameters can be changed.
///
/// Possible use of method is to replace input by variable. For it the following steps
/// should be done:
/// * `Parameter` node should be replaced by `ReadValue`
/// * call remove_parameter(param) to remove input from the list
/// * check if any parameter indexes are saved/used somewhere, update it for all inputs
/// because indexes can be changed
/// * call graph validation to check all changes
///
/// \param param Parameter node to delete
void remove_parameter(const std::shared_ptr<op::Parameter>& param);

private:
Function(const Function&) = delete;
Function(const Function&&) = delete;
Expand Down Expand Up @@ -203,4 +230,4 @@ namespace ngraph
0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
};
}
} // namespace ngraph
25 changes: 25 additions & 0 deletions ngraph/core/src/function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -405,4 +405,29 @@ void Function::remove_result(const std::shared_ptr<op::Result>& result)
m_results.end());
}

void Function::add_parameters(const ParameterVector& params)
{
for (int i = 0; i < params.size(); i++)
{
for (int j = 0; j < m_parameters.size(); j++)
{
NGRAPH_CHECK(params[i] != m_parameters[j],
"add_parameters(): Tried to add parameter (index in array ",
i,
") but function already have the same parameter with index ",
j);
}
}
m_parameters.insert(m_parameters.end(), params.begin(), params.end());
sadolini marked this conversation as resolved.
Show resolved Hide resolved
}

void Function::remove_parameter(const std::shared_ptr<op::Parameter>& param)
sadolini marked this conversation as resolved.
Show resolved Hide resolved
{
m_parameters.erase(
std::remove_if(m_parameters.begin(),
m_parameters.end(),
[&param](std::shared_ptr<op::v0::Parameter>& r) { return r == param; }),
GlebKazantaev marked this conversation as resolved.
Show resolved Hide resolved
m_parameters.end());
}

constexpr DiscreteTypeInfo AttributeAdapter<shared_ptr<Function>>::type_info;
88 changes: 88 additions & 0 deletions ngraph/test/build_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -363,3 +363,91 @@ TEST(build_graph, build_graph_with_remove_result)
nodes = f->get_ops();
EXPECT_EQ(nodes.size(), 5);
}

TEST(build_graph, build_graph_with_add_parameter)
{
auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 4});
auto arg2 = make_shared<op::Parameter>(element::f32, Shape{2, 2});
auto init_const = op::Constant::create(element::f32, Shape{2, 2}, {0, 0, 0, 0});
auto read = make_shared<op::ReadValue>(init_const, "v0");
std::vector<shared_ptr<Node>> args = {arg, read};
auto pattern = make_shared<op::Concat>(args, 1);
auto res = make_shared<op::Result>(pattern);
const auto axis = op::Constant::create(element::i64, Shape{}, {1});
auto crop = make_shared<op::v1::Split>(pattern, axis, 3);
auto res2 = make_shared<op::Result>(crop, "v0");

auto f = make_shared<Function>(ResultVector({res, res2}), ParameterVector{arg});

NodeVector nodes = f->get_ops();
EXPECT_EQ(nodes.size(), 8);
ParameterVector params = f->get_parameters();
EXPECT_EQ(params.size(), 1);

pattern->input(1).replace_source_output(arg2->output(0));

f->add_parameters(ParameterVector({arg2}));
params = f->get_parameters();
sadolini marked this conversation as resolved.
Show resolved Hide resolved
EXPECT_EQ(params.size(), 2);
EXPECT_EQ(params[1], arg2);
nodes = f->get_ops();
EXPECT_EQ(nodes.size(), 7);
}

TEST(build_graph, build_graph_with_remove_parameter)
{
auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 4});
auto arg2 = make_shared<op::Parameter>(element::f32, Shape{2, 2});
auto init_const = op::Constant::create(element::f32, Shape{2, 2}, {0, 0, 0, 0});
auto read = make_shared<op::ReadValue>(init_const, "v0");
std::vector<shared_ptr<Node>> args = {arg, arg2};
auto pattern = make_shared<op::Concat>(args, 1);
auto res = make_shared<op::Result>(pattern);
const auto axis = op::Constant::create(element::i64, Shape{}, {1});
auto crop = make_shared<op::v1::Split>(pattern, axis, 3);
auto res2 = make_shared<op::Result>(crop, "v0");

auto f = make_shared<Function>(ResultVector({res, res2}), ParameterVector{arg, arg2});

NodeVector nodes = f->get_ops();
EXPECT_EQ(nodes.size(), 7);
ParameterVector params = f->get_parameters();
EXPECT_EQ(params.size(), 2);

pattern->input(1).replace_source_output(read->output(0));
f->remove_parameter(arg2);
params = f->get_parameters();
EXPECT_EQ(params.size(), 1);
nodes = f->get_ops();
EXPECT_EQ(nodes.size(), 8);
}

TEST(build_graph, build_graph_with_remove_parameter_indexing)
{
auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 4});
auto arg2 = make_shared<op::Parameter>(element::f32, Shape{2, 2});
auto init_const = op::Constant::create(element::f32, Shape{2, 2}, {0, 0, 0, 0});
auto read = make_shared<op::ReadValue>(init_const, "v0");
std::vector<shared_ptr<Node>> args = {arg2, arg};
auto pattern = make_shared<op::Concat>(args, 1);
auto res = make_shared<op::Result>(pattern);
const auto axis = op::Constant::create(element::i64, Shape{}, {1});
auto crop = make_shared<op::v1::Split>(pattern, axis, 3);
auto res2 = make_shared<op::Result>(crop, "v0");

auto f = make_shared<Function>(ResultVector({res, res2}), ParameterVector{arg2, arg});

NodeVector nodes = f->get_ops();
EXPECT_EQ(nodes.size(), 7);
ParameterVector params = f->get_parameters();
EXPECT_EQ(params.size(), 2);

pattern->input(0).replace_source_output(read->output(0));
f->remove_parameter(arg2);
params = f->get_parameters();
EXPECT_EQ(params.size(), 1);
nodes = f->get_ops();
EXPECT_EQ(nodes.size(), 8);

f->validate_nodes_and_infer_types();
}