Skip to content

Commit

Permalink
added tests
Browse files Browse the repository at this point in the history
  • Loading branch information
sadolini committed Jan 12, 2021
1 parent 2c43725 commit 64fea17
Showing 1 changed file with 55 additions and 0 deletions.
55 changes: 55 additions & 0 deletions ngraph/test/build_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -363,3 +363,58 @@ TEST(build_graph, build_graph_with_remove_result)
nodes = f->get_ops();
EXPECT_EQ(nodes.size(), 5);
}

TEST(build_graph, build_graph_with_add_parameter)
{
auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 4});
auto arg2 = make_shared<op::Parameter>(element::f32, Shape{2, 8});
auto init_const = op::Constant::create(element::f32, Shape{2, 2}, {0, 0, 0, 0});
auto read = make_shared<op::ReadValue>(init_const, "v0");
std::vector<shared_ptr<Node>> args = {arg, read};
auto pattern = make_shared<op::Concat>(args, 1);
auto res = make_shared<op::Result>(pattern);
const auto axis = op::Constant::create(element::i64, Shape{}, {1});
auto crop = make_shared<op::v1::Split>(pattern, axis, 3);
auto res2 = make_shared<op::Result>(crop, "v0");

auto f = make_shared<Function>(ResultVector({res, res2}), ParameterVector{arg});

NodeVector nodes = f->get_ops();
EXPECT_EQ(nodes.size(), 8);
ParameterVector params = f->get_parameters();
EXPECT_EQ(params.size(), 1);

f->add_parameters(ParameterVector({arg2}));
params = f->get_parameters();
EXPECT_EQ(params.size(), 2);
EXPECT_EQ(params[1], arg2);
nodes = f->get_ops();
EXPECT_EQ(nodes.size(), 9);
}

TEST(build_graph, build_graph_with_remove_parameter)
{
auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 4});
auto arg2 = make_shared<op::Parameter>(element::f32, Shape{2, 8});
auto init_const = op::Constant::create(element::f32, Shape{2, 2}, {0, 0, 0, 0});
auto read = make_shared<op::ReadValue>(init_const, "v0");
std::vector<shared_ptr<Node>> args = {arg, read};
auto pattern = make_shared<op::Concat>(args, 1);
auto res = make_shared<op::Result>(pattern);
const auto axis = op::Constant::create(element::i64, Shape{}, {1});
auto crop = make_shared<op::v1::Split>(pattern, axis, 3);
auto res2 = make_shared<op::Result>(crop, "v0");

auto f = make_shared<Function>(ResultVector({res, res2}), ParameterVector{arg, arg2});

NodeVector nodes = f->get_ops();
EXPECT_EQ(nodes.size(), 9);
ParameterVector params = f->get_parameters();
EXPECT_EQ(params.size(), 2);

f->remove_parameter(arg2);
params = f->get_parameters();
EXPECT_EQ(params.size(), 1);
nodes = f->get_ops();
EXPECT_EQ(nodes.size(), 8);
}

0 comments on commit 64fea17

Please sign in to comment.