Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ngraph: add methods for removing parameters from Function #3854

Merged
merged 13 commits into from
Feb 4, 2021
27 changes: 25 additions & 2 deletions ngraph/core/include/ngraph/function.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
// Copyright 2017-2021 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -170,6 +170,29 @@ namespace ngraph
/// \param result Result node to delete
void remove_result(const std::shared_ptr<op::Result>& result);

/// \brief Add new Parameter nodes to the list.
///
/// Method doesn't change or validate graph, it should be done manually.
/// For example, if you want to replace `ReadValue` node by `Parameter`, you should do the following steps:
/// * replace node `ReadValue` by `Parameter` in graph
/// * call add_parameter() to add new input to the list
/// * call graph validation to check correctness of changes
///
/// \param params new Parameter nodes
void add_parameters(const ParameterVector& params);
ilyachur marked this conversation as resolved.
Show resolved Hide resolved

/// \brief Delete Parameter node from the list of parameters. Method will not delete node
/// from graph. Attention: Indexing of parameters can be changed.
sadolini marked this conversation as resolved.
Show resolved Hide resolved
///
/// Possible use of method is to replace input by variable. For it the following steps should be done:
/// * `Parameter` node should be replaced by `ReadValue`
/// * call remove_parameter(param) to remove input from the list
/// * check if any parameter indexes are saved/used somewhere, update it for all inputs because indexes can be changed
/// * call graph validation to check all changes
///
/// \param param Parameter node to delete
void remove_parameter(const std::shared_ptr<op::Parameter>& param);

private:
Function(const Function&) = delete;
Function(const Function&&) = delete;
Expand Down Expand Up @@ -203,4 +226,4 @@ namespace ngraph
0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
};
}
} // namespace ngraph
27 changes: 26 additions & 1 deletion ngraph/core/src/function.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
// Copyright 2017-2021 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -405,4 +405,29 @@ void Function::remove_result(const std::shared_ptr<op::Result>& result)
m_results.end());
}

void Function::add_parameters(const ParameterVector& params)
{
for (int i = 0; i < params.size(); i++)
{
for (int j = 0; j < m_parameters.size(); j++)
{
NGRAPH_CHECK(params[i] != m_parameters[j],
"add_parameters(): Tried to add parameter (index in array ",
i,
") but function already have the same parameter with index ",
j);
}
}
m_parameters.insert(m_parameters.end(), params.begin(), params.end());
sadolini marked this conversation as resolved.
Show resolved Hide resolved
}

void Function::remove_parameter(const std::shared_ptr<op::Parameter>& param)
sadolini marked this conversation as resolved.
Show resolved Hide resolved
{
m_parameters.erase(
std::remove_if(m_parameters.begin(),
m_parameters.end(),
[&param](std::shared_ptr<op::v0::Parameter>& r) { return r == param; }),
GlebKazantaev marked this conversation as resolved.
Show resolved Hide resolved
m_parameters.end());
}

constexpr DiscreteTypeInfo AttributeAdapter<shared_ptr<Function>>::type_info;
55 changes: 55 additions & 0 deletions ngraph/test/build_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -363,3 +363,58 @@ TEST(build_graph, build_graph_with_remove_result)
nodes = f->get_ops();
EXPECT_EQ(nodes.size(), 5);
}

TEST(build_graph, build_graph_with_add_parameter)
{
auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 4});
auto arg2 = make_shared<op::Parameter>(element::f32, Shape{2, 8});
auto init_const = op::Constant::create(element::f32, Shape{2, 2}, {0, 0, 0, 0});
auto read = make_shared<op::ReadValue>(init_const, "v0");
std::vector<shared_ptr<Node>> args = {arg, read};
auto pattern = make_shared<op::Concat>(args, 1);
auto res = make_shared<op::Result>(pattern);
const auto axis = op::Constant::create(element::i64, Shape{}, {1});
auto crop = make_shared<op::v1::Split>(pattern, axis, 3);
auto res2 = make_shared<op::Result>(crop, "v0");

auto f = make_shared<Function>(ResultVector({res, res2}), ParameterVector{arg});

NodeVector nodes = f->get_ops();
EXPECT_EQ(nodes.size(), 8);
ParameterVector params = f->get_parameters();
EXPECT_EQ(params.size(), 1);

f->add_parameters(ParameterVector({arg2}));
params = f->get_parameters();
sadolini marked this conversation as resolved.
Show resolved Hide resolved
EXPECT_EQ(params.size(), 2);
EXPECT_EQ(params[1], arg2);
nodes = f->get_ops();
EXPECT_EQ(nodes.size(), 9);
}

TEST(build_graph, build_graph_with_remove_parameter)
{
auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 4});
auto arg2 = make_shared<op::Parameter>(element::f32, Shape{2, 8});
auto init_const = op::Constant::create(element::f32, Shape{2, 2}, {0, 0, 0, 0});
auto read = make_shared<op::ReadValue>(init_const, "v0");
std::vector<shared_ptr<Node>> args = {arg, read};
auto pattern = make_shared<op::Concat>(args, 1);
auto res = make_shared<op::Result>(pattern);
const auto axis = op::Constant::create(element::i64, Shape{}, {1});
auto crop = make_shared<op::v1::Split>(pattern, axis, 3);
auto res2 = make_shared<op::Result>(crop, "v0");

auto f = make_shared<Function>(ResultVector({res, res2}), ParameterVector{arg, arg2});

NodeVector nodes = f->get_ops();
EXPECT_EQ(nodes.size(), 9);
ParameterVector params = f->get_parameters();
EXPECT_EQ(params.size(), 2);

f->remove_parameter(arg2);
params = f->get_parameters();
EXPECT_EQ(params.size(), 1);
nodes = f->get_ops();
EXPECT_EQ(nodes.size(), 8);
}