Avoid Constant data copy inside Reshape constant folding
Gleb Kazantaev committed Jun 28, 2021
1 parent 7762505 commit e2e6347
Showing 4 changed files with 38 additions and 13 deletions.
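For orientation, here is a minimal usage sketch (not part of the commit) of the shape-sharing constructor this change introduces. The API names are taken from the diff below; the include path, the wrapper function, and the example shapes are illustrative assumptions.

// Sketch only, assuming the nGraph API shown in the diff below.
// The include path and wrapper function are illustrative, not from the commit.
#include "ngraph/op/constant.hpp"

#include <cassert>
#include <memory>

using namespace ngraph;

void reshape_constant_without_copy()
{
    // 1 x 64 float constant filled with 1.0f (same data as the new unit test below).
    auto original = op::Constant::create(element::f32, Shape{1, 64}, {1});

    // The new Constant(other, new_shape) constructor reuses the original
    // AlignedBuffer; only the shape metadata changes, no element data is copied.
    auto reshaped = std::make_shared<op::Constant>(*original, Shape{2, 32});

    // Both constants point at the same underlying storage.
    assert(reshaped->get_data_ptr() == original->get_data_ptr());
}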
3 changes: 3 additions & 0 deletions ngraph/core/include/ngraph/op/constant.hpp
@@ -155,6 +155,7 @@ namespace ngraph
     }
 
     Constant(const Constant& other);
+    Constant(const Constant& other, const Shape & new_shape);
     Constant& operator=(const Constant&) = delete;
 
     virtual ~Constant() override;
@@ -307,6 +308,8 @@ namespace ngraph
         return rc;
     }
 
+    std::shared_ptr<runtime::AlignedBuffer> get_aligned_buffer() const { return m_data; }
+
     const void* get_data_ptr() const { return (m_data ? m_data->get_ptr() : nullptr); }
     template <typename T>
     const T* get_data_ptr() const
11 changes: 11 additions & 0 deletions ngraph/core/src/op/constant.cpp
@@ -148,6 +148,17 @@ op::Constant::Constant(const Constant& other)
     constructor_validate_and_infer_types();
 }
 
+op::Constant::Constant(const Constant& other, const Shape & new_shape)
+{
+    NGRAPH_CHECK(shape_size(other.m_shape) == shape_size(new_shape),
+                 "Shape size " + std::to_string(shape_size(new_shape)) + " is not equal to " + std::to_string(shape_size(other.m_shape)));
+    m_element_type = other.m_element_type;
+    m_shape = new_shape;
+    m_data = other.m_data;
+    m_all_elements_bitwise_identical = other.m_all_elements_bitwise_identical;
+    constructor_validate_and_infer_types();
+}
+
 op::Constant::~Constant() {}
 
 string op::Constant::convert_value_to_string(size_t index) const
14 changes: 1 addition & 13 deletions ngraph/core/src/op/reshape.cpp
@@ -241,19 +241,7 @@ bool op::v1::Reshape::constant_fold(OutputVector& output_values, const OutputVec
     if (auto data_const =
             std::dynamic_pointer_cast<op::Constant>(inputs_values[0].get_node_shared_ptr()))
     {
-        // In case if data constant has single consumer we can change it shape without making a copy
-        // Otherwise we create Constant copy with shape from reshape node
-        if (data_const->output(0).get_target_inputs().size() == 1)
-        {
-            data_const->set_data_shape(shape);
-            data_const->validate_and_infer_types();
-            output_values[0] = data_const;
-        }
-        else
-        {
-            output_values[0] = std::make_shared<op::Constant>(
-                data_const->get_element_type(), shape, data_const->get_data_ptr());
-        }
+        output_values[0] = std::make_shared<op::Constant>(*data_const, shape);
         return true;
     }
     return false;
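A short sketch (not part of the commit) of how the simplified fold path above behaves: the folded output is a new Constant that aliases the source buffer, and the source Constant keeps its original shape regardless of how many consumers it has. The constant_fold signature is taken from the hunk header; the shapes, variable names, and include paths are assumptions.

// Sketch only, assuming the signature shown in the hunk header above.
#include "ngraph/op/constant.hpp"
#include "ngraph/op/reshape.hpp"

#include <cassert>
#include <memory>

using namespace ngraph;

void fold_reshape_over_constant()
{
    auto data = op::Constant::create(element::f32, Shape{1, 64}, {1});
    auto pattern = op::Constant::create(element::i64, Shape{2}, {2, 32});
    auto reshape = std::make_shared<op::v1::Reshape>(data, pattern, false);

    OutputVector folded(1);
    if (reshape->constant_fold(folded, OutputVector{data->output(0), pattern->output(0)}))
    {
        auto folded_const =
            std::dynamic_pointer_cast<op::Constant>(folded[0].get_node_shared_ptr());
        // The folded Constant takes the target shape but shares the source buffer;
        // `data` itself keeps its original Shape{1, 64}.
        assert(folded_const->get_shape() == (Shape{2, 32}));
        assert(folded_const->get_data_ptr() == data->get_data_ptr());
    }
}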
23 changes: 23 additions & 0 deletions ngraph/test/constant_folding.cpp
@@ -2301,6 +2301,29 @@ TEST(constant_folding, constant_dyn_reshape_shape_not_originally_constant)
     ASSERT_TRUE(test::all_close_f(values_in, values_out, MIN_FLOAT_TOLERANCE_BITS));
 }
 
+TEST(constant_folding, const_reshape_no_data_copy)
+{
+    auto const_data = op::Constant::create(element::f32, Shape{1, 64}, {1});
+    auto const_reshape = op::Constant::create(element::i64, Shape{2}, {2, 32});
+    auto reshape = std::make_shared<op::v1::Reshape>(const_data, const_reshape, false);
+    auto consumer1 = std::make_shared<op::Relu>(reshape);
+    auto consumer2 = std::make_shared<op::Relu>(reshape);
+
+    auto f = std::make_shared<Function>(NodeVector{consumer1, consumer2}, ParameterVector{});
+
+    pass::Manager pass_manager;
+    pass_manager.register_pass<pass::ConstantFolding>();
+    pass_manager.run_passes(f);
+
+    auto const1 = std::dynamic_pointer_cast<op::Constant>(consumer1->input_value(0).get_node_shared_ptr());
+    auto const2 = std::dynamic_pointer_cast<op::Constant>(consumer2->input_value(0).get_node_shared_ptr());
+
+    ASSERT_TRUE(const1);
+    ASSERT_TRUE(const2);
+    ASSERT_EQ(const1, const2);
+    ASSERT_EQ(const1->get_aligned_buffer(), const2->get_aligned_buffer());
+}
+
 TEST(constant_folding, constant_transpose)
 {
     Shape shape_in{2, 4};
