diff --git a/ngraph/core/include/ngraph/graph_util.hpp b/ngraph/core/include/ngraph/graph_util.hpp index fda65edffffd6f..cfc5c295984ea2 100644 --- a/ngraph/core/include/ngraph/graph_util.hpp +++ b/ngraph/core/include/ngraph/graph_util.hpp @@ -365,7 +365,8 @@ namespace ngraph bool is_post_dominated(Node* X, Node* Y); NGRAPH_API - bool is_equal_to_const_value(std::string const_value, const Output& reduce_constant); + bool is_equal_to_const_value(const std::string& const_value, + const Output& reduce_constant); // input nodes are cloned and returned // NodeMap input may contain default node mapping i.e. pre-cloned nodes diff --git a/ngraph/core/src/graph_util.cpp b/ngraph/core/src/graph_util.cpp index 901672723a1799..921f54f25da4c7 100644 --- a/ngraph/core/src/graph_util.cpp +++ b/ngraph/core/src/graph_util.cpp @@ -20,6 +20,7 @@ #include "ngraph/op/tensor_iterator.hpp" #include "ngraph/op/util/op_types.hpp" #include "ngraph/opsets/opset5.hpp" +#include "ngraph/opsets/opset8.hpp" #include "ngraph/pass/manager.hpp" #include "ngraph/pass/visualize_tree.hpp" #include "ngraph/provenance.hpp" @@ -405,6 +406,32 @@ std::shared_ptr ngraph::clone_function(const ngraph::Function& // clone function operations clone_nodes(func.get_ops(), node_map); + // clone variables + auto variables = func.get_variables(); + VariableVector cloned_vars; + std::map> var_map; + for (const auto& var : variables) + { + auto cloned_var = std::make_shared( + VariableInfo{PartialShape::dynamic(), element::dynamic, var->get_info().variable_id}); + cloned_vars.push_back(cloned_var); + var_map[cloned_var->get_info().variable_id] = cloned_var; + } + if (!variables.empty()) + { + for (const auto& op : node_map) + { + if (auto read_val = std::dynamic_pointer_cast(op.second)) + { + read_val->set_variable(var_map.at(read_val->get_variable_id())); + } + else if (auto assign = std::dynamic_pointer_cast(op.second)) + { + assign->set_variable(var_map.at(assign->get_variable_id())); + } + } + } + // get cloned function results and sinks and parameters ResultVector cloned_results; for (shared_ptr node : func.get_results()) @@ -417,25 +444,25 @@ std::shared_ptr ngraph::clone_function(const ngraph::Function& cloned_results.push_back(result); } SinkVector cloned_sinks; - for (auto node : func.get_sinks()) + for (const auto& node : func.get_sinks()) { cloned_sinks.push_back(static_pointer_cast(node_map.at(node.get()))); } std::vector> cloned_params; - for (auto param : func.get_parameters()) + for (const auto& param : func.get_parameters()) { cloned_params.push_back(as_type_ptr(node_map.at(param.get()))); } // create and return cloned function - auto result = std::make_shared(cloned_results, cloned_params); - result->set_friendly_name(func.get_friendly_name()); - result->add_sinks(cloned_sinks); + auto result = std::make_shared( + cloned_results, cloned_sinks, cloned_params, cloned_vars, func.get_friendly_name()); return result; } -bool ngraph::is_equal_to_const_value(std::string const_value, const Output& reduce_constant) +bool ngraph::is_equal_to_const_value(const std::string& const_value, + const Output& reduce_constant) { if (auto rc = as_type_ptr(reduce_constant.get_node_shared_ptr())) { diff --git a/ngraph/test/util.cpp b/ngraph/test/util.cpp index f1384d010722c2..7135b8c46ddfec 100644 --- a/ngraph/test/util.cpp +++ b/ngraph/test/util.cpp @@ -15,6 +15,7 @@ #include "ngraph/ngraph.hpp" #include "ngraph/op/util/op_annotations.hpp" #include "ngraph/opsets/opset6.hpp" +#include "ngraph/opsets/opset8.hpp" #include "ngraph/pass/manager.hpp" #include "ngraph/pass/visualize_tree.hpp" #include "util/all_close.hpp" @@ -251,6 +252,24 @@ TEST(graph_util, clone_multiple_results) auto copy = clone_function(*f); } +TEST(graph_util, clone_function_variables) +{ + auto c_fp16 = make_shared(element::f16, Shape{3}, std::vector{0}); + auto variable = make_shared(VariableInfo{PartialShape::dynamic(), element::dynamic, "var_1"}); + auto read_value = make_shared(c_fp16, variable); + auto assign = make_shared(read_value, variable); + auto f = make_shared(OutputVector{assign}, ParameterVector{}, VariableVector{variable}); + auto copy = clone_function(*f); + auto c_fp32 = make_shared(element::f32, Shape{3}, std::vector{0}); + for (const auto& op : copy->get_ops()) { + if (auto constant = std::dynamic_pointer_cast(op)) { + ngraph::replace_node(constant, c_fp32); + } + } + copy->validate_nodes_and_infer_types(); + copy = clone_function(*f); +} + TEST(graph_util, clone_rt_info) { const std::string testAffinity = "CPU";