Skip to content

Commit

Permalink
Fix clone_function: clone ngraph::Variables (#6804)
Browse files Browse the repository at this point in the history
* fix clone function

* ngraph codestyle

* fix copy function for assign/read value v3

* add unit test
  • Loading branch information
itikhono authored Jul 30, 2021
1 parent 5920cf8 commit bc06279
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 7 deletions.
3 changes: 2 additions & 1 deletion ngraph/core/include/ngraph/graph_util.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,8 @@ namespace ngraph
bool is_post_dominated(Node* X, Node* Y);

NGRAPH_API
bool is_equal_to_const_value(std::string const_value, const Output<Node>& reduce_constant);
bool is_equal_to_const_value(const std::string& const_value,
const Output<Node>& reduce_constant);

// input nodes are cloned and returned
// NodeMap input may contain default node mapping i.e. pre-cloned nodes
Expand Down
39 changes: 33 additions & 6 deletions ngraph/core/src/graph_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "ngraph/op/tensor_iterator.hpp"
#include "ngraph/op/util/op_types.hpp"
#include "ngraph/opsets/opset5.hpp"
#include "ngraph/opsets/opset8.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/visualize_tree.hpp"
#include "ngraph/provenance.hpp"
Expand Down Expand Up @@ -405,6 +406,32 @@ std::shared_ptr<ngraph::Function> ngraph::clone_function(const ngraph::Function&
// clone function operations
clone_nodes(func.get_ops(), node_map);

// clone variables
auto variables = func.get_variables();
VariableVector cloned_vars;
std::map<std::string, std::shared_ptr<Variable>> var_map;
for (const auto& var : variables)
{
auto cloned_var = std::make_shared<Variable>(
VariableInfo{PartialShape::dynamic(), element::dynamic, var->get_info().variable_id});
cloned_vars.push_back(cloned_var);
var_map[cloned_var->get_info().variable_id] = cloned_var;
}
if (!variables.empty())
{
for (const auto& op : node_map)
{
if (auto read_val = std::dynamic_pointer_cast<VariableExtension>(op.second))
{
read_val->set_variable(var_map.at(read_val->get_variable_id()));
}
else if (auto assign = std::dynamic_pointer_cast<VariableExtension>(op.second))
{
assign->set_variable(var_map.at(assign->get_variable_id()));
}
}
}

// get cloned function results and sinks and parameters
ResultVector cloned_results;
for (shared_ptr<Node> node : func.get_results())
Expand All @@ -417,25 +444,25 @@ std::shared_ptr<ngraph::Function> ngraph::clone_function(const ngraph::Function&
cloned_results.push_back(result);
}
SinkVector cloned_sinks;
for (auto node : func.get_sinks())
for (const auto& node : func.get_sinks())
{
cloned_sinks.push_back(static_pointer_cast<op::Sink>(node_map.at(node.get())));
}

std::vector<std::shared_ptr<op::Parameter>> cloned_params;
for (auto param : func.get_parameters())
for (const auto& param : func.get_parameters())
{
cloned_params.push_back(as_type_ptr<op::Parameter>(node_map.at(param.get())));
}

// create and return cloned function
auto result = std::make_shared<ngraph::Function>(cloned_results, cloned_params);
result->set_friendly_name(func.get_friendly_name());
result->add_sinks(cloned_sinks);
auto result = std::make_shared<ngraph::Function>(
cloned_results, cloned_sinks, cloned_params, cloned_vars, func.get_friendly_name());
return result;
}

bool ngraph::is_equal_to_const_value(std::string const_value, const Output<Node>& reduce_constant)
bool ngraph::is_equal_to_const_value(const std::string& const_value,
const Output<Node>& reduce_constant)
{
if (auto rc = as_type_ptr<ngraph::op::Constant>(reduce_constant.get_node_shared_ptr()))
{
Expand Down
19 changes: 19 additions & 0 deletions ngraph/test/util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "ngraph/ngraph.hpp"
#include "ngraph/op/util/op_annotations.hpp"
#include "ngraph/opsets/opset6.hpp"
#include "ngraph/opsets/opset8.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/visualize_tree.hpp"
#include "util/all_close.hpp"
Expand Down Expand Up @@ -251,6 +252,24 @@ TEST(graph_util, clone_multiple_results)
auto copy = clone_function(*f);
}

TEST(graph_util, clone_function_variables)
{
auto c_fp16 = make_shared<opset8::Constant>(element::f16, Shape{3}, std::vector<float>{0});
auto variable = make_shared<Variable>(VariableInfo{PartialShape::dynamic(), element::dynamic, "var_1"});
auto read_value = make_shared<opset8::ReadValue>(c_fp16, variable);
auto assign = make_shared<opset8::Assign>(read_value, variable);
auto f = make_shared<Function>(OutputVector{assign}, ParameterVector{}, VariableVector{variable});
auto copy = clone_function(*f);
auto c_fp32 = make_shared<opset8::Constant>(element::f32, Shape{3}, std::vector<float>{0});
for (const auto& op : copy->get_ops()) {
if (auto constant = std::dynamic_pointer_cast<opset8::Constant>(op)) {
ngraph::replace_node(constant, c_fp32);
}
}
copy->validate_nodes_and_infer_types();
copy = clone_function(*f);
}

TEST(graph_util, clone_rt_info)
{
const std::string testAffinity = "CPU";
Expand Down

0 comments on commit bc06279

Please sign in to comment.