Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ngraph: add methods for removing parameters from Function #3854

Merged
merged 13 commits into from
Feb 4, 2021
12 changes: 11 additions & 1 deletion ngraph/core/include/ngraph/function.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,16 @@ namespace ngraph
/// \param result Result node to delete
void remove_result(const std::shared_ptr<op::Result>& result);

/// \brief Add new Parameter nodes to the list. Method doesn't validate graph, it should be
/// done manually after all changes.
/// \param params new Parameter nodes
void add_parameters(const ParameterVector& params);
ilyachur marked this conversation as resolved.
Show resolved Hide resolved

/// \brief Delete Parameter node from the list of parameters. Method will not delete node
/// from graph. Attention: Indexing of parameters will be changed. \param param Parameter
/// node to delete
void remove_parameter(const std::shared_ptr<op::Parameter>& param);

private:
Function(const Function&) = delete;
Function(const Function&&) = delete;
Expand Down Expand Up @@ -203,4 +213,4 @@ namespace ngraph
0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
};
}
} // namespace ngraph
14 changes: 14 additions & 0 deletions ngraph/core/src/function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -405,4 +405,18 @@ void Function::remove_result(const std::shared_ptr<op::Result>& result)
m_results.end());
}

void Function::add_parameters(const ParameterVector& params)
{
m_parameters.insert(m_parameters.end(), params.begin(), params.end());
sadolini marked this conversation as resolved.
Show resolved Hide resolved
}

void Function::remove_parameter(const std::shared_ptr<op::Parameter>& param)
sadolini marked this conversation as resolved.
Show resolved Hide resolved
{
m_parameters.erase(
std::remove_if(m_parameters.begin(),
m_parameters.end(),
[&param](std::shared_ptr<op::v0::Parameter>& r) { return r == param; }),
GlebKazantaev marked this conversation as resolved.
Show resolved Hide resolved
m_parameters.end());
}

constexpr DiscreteTypeInfo AttributeAdapter<shared_ptr<Function>>::type_info;
55 changes: 55 additions & 0 deletions ngraph/test/build_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -363,3 +363,58 @@ TEST(build_graph, build_graph_with_remove_result)
nodes = f->get_ops();
EXPECT_EQ(nodes.size(), 5);
}

TEST(build_graph, build_graph_with_add_parameter)
{
auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 4});
auto arg2 = make_shared<op::Parameter>(element::f32, Shape{2, 8});
auto init_const = op::Constant::create(element::f32, Shape{2, 2}, {0, 0, 0, 0});
auto read = make_shared<op::ReadValue>(init_const, "v0");
std::vector<shared_ptr<Node>> args = {arg, read};
auto pattern = make_shared<op::Concat>(args, 1);
auto res = make_shared<op::Result>(pattern);
const auto axis = op::Constant::create(element::i64, Shape{}, {1});
auto crop = make_shared<op::v1::Split>(pattern, axis, 3);
auto res2 = make_shared<op::Result>(crop, "v0");

auto f = make_shared<Function>(ResultVector({res, res2}), ParameterVector{arg});

NodeVector nodes = f->get_ops();
EXPECT_EQ(nodes.size(), 8);
ParameterVector params = f->get_parameters();
EXPECT_EQ(params.size(), 1);

f->add_parameters(ParameterVector({arg2}));
params = f->get_parameters();
sadolini marked this conversation as resolved.
Show resolved Hide resolved
EXPECT_EQ(params.size(), 2);
EXPECT_EQ(params[1], arg2);
nodes = f->get_ops();
EXPECT_EQ(nodes.size(), 9);
}

TEST(build_graph, build_graph_with_remove_parameter)
{
auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 4});
auto arg2 = make_shared<op::Parameter>(element::f32, Shape{2, 8});
auto init_const = op::Constant::create(element::f32, Shape{2, 2}, {0, 0, 0, 0});
auto read = make_shared<op::ReadValue>(init_const, "v0");
std::vector<shared_ptr<Node>> args = {arg, read};
auto pattern = make_shared<op::Concat>(args, 1);
auto res = make_shared<op::Result>(pattern);
const auto axis = op::Constant::create(element::i64, Shape{}, {1});
auto crop = make_shared<op::v1::Split>(pattern, axis, 3);
auto res2 = make_shared<op::Result>(crop, "v0");

auto f = make_shared<Function>(ResultVector({res, res2}), ParameterVector{arg, arg2});

NodeVector nodes = f->get_ops();
EXPECT_EQ(nodes.size(), 9);
ParameterVector params = f->get_parameters();
EXPECT_EQ(params.size(), 2);

f->remove_parameter(arg2);
params = f->get_parameters();
EXPECT_EQ(params.size(), 1);
nodes = f->get_ops();
EXPECT_EQ(nodes.size(), 8);
}