diff --git a/ngraph/test/build_graph.cpp b/ngraph/test/build_graph.cpp index f43fa79bd12423..e3401d3125c48a 100644 --- a/ngraph/test/build_graph.cpp +++ b/ngraph/test/build_graph.cpp @@ -363,3 +363,58 @@ TEST(build_graph, build_graph_with_remove_result) nodes = f->get_ops(); EXPECT_EQ(nodes.size(), 5); } + +TEST(build_graph, build_graph_with_add_parameter) +{ + auto arg = make_shared(element::f32, Shape{2, 4}); + auto arg2 = make_shared(element::f32, Shape{2, 8}); + auto init_const = op::Constant::create(element::f32, Shape{2, 2}, {0, 0, 0, 0}); + auto read = make_shared(init_const, "v0"); + std::vector> args = {arg, read}; + auto pattern = make_shared(args, 1); + auto res = make_shared(pattern); + const auto axis = op::Constant::create(element::i64, Shape{}, {1}); + auto crop = make_shared(pattern, axis, 3); + auto res2 = make_shared(crop, "v0"); + + auto f = make_shared(ResultVector({res, res2}), ParameterVector{arg}); + + NodeVector nodes = f->get_ops(); + EXPECT_EQ(nodes.size(), 8); + ParameterVector params = f->get_parameters(); + EXPECT_EQ(params.size(), 1); + + f->add_parameters(ParameterVector({arg2})); + params = f->get_parameters(); + EXPECT_EQ(params.size(), 2); + EXPECT_EQ(params[1], arg2); + nodes = f->get_ops(); + EXPECT_EQ(nodes.size(), 9); +} + +TEST(build_graph, build_graph_with_remove_parameter) +{ + auto arg = make_shared(element::f32, Shape{2, 4}); + auto arg2 = make_shared(element::f32, Shape{2, 8}); + auto init_const = op::Constant::create(element::f32, Shape{2, 2}, {0, 0, 0, 0}); + auto read = make_shared(init_const, "v0"); + std::vector> args = {arg, read}; + auto pattern = make_shared(args, 1); + auto res = make_shared(pattern); + const auto axis = op::Constant::create(element::i64, Shape{}, {1}); + auto crop = make_shared(pattern, axis, 3); + auto res2 = make_shared(crop, "v0"); + + auto f = make_shared(ResultVector({res, res2}), ParameterVector{arg, arg2}); + + NodeVector nodes = f->get_ops(); + EXPECT_EQ(nodes.size(), 9); + ParameterVector params = f->get_parameters(); + EXPECT_EQ(params.size(), 2); + + f->remove_parameter(arg2); + params = f->get_parameters(); + EXPECT_EQ(params.size(), 1); + nodes = f->get_ops(); + EXPECT_EQ(nodes.size(), 8); +}