Speed up validation time for Reshape and Transpose
Stepyreva, Evgenya committed Jul 27, 2021
1 parent 01b8ff8 commit 92dfd5c
Showing 4 changed files with 196 additions and 143 deletions.
4 changes: 2 additions & 2 deletions ngraph/core/include/ngraph/op/reshape.hpp
@@ -65,10 +65,10 @@ namespace ngraph
const HostTensorVector& inputs) const;

private:
void calculate_output_shape(std::vector<Dimension>& reshape_pattern,
void calculate_output_shape(const std::vector<Dimension>& reshape_pattern,
const int64_t& minus_one_idx,
const PartialShape& input_pshape,
std::vector<Dimension>& output_shape) const;
PartialShape& output_shape) const;
};
} // namespace v1
} // namespace op
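Note: the signature change above passes the pattern by const reference and fills a PartialShape out-parameter instead of a std::vector<Dimension>, so callers no longer convert the result afterwards. A minimal standalone sketch (my illustration, not part of the commit) of the PartialShape::dynamic(rank) out-parameter the new code starts from:

#include <iostream>
#include <ngraph/partial_shape.hpp>

int main()
{
    // Every dimension starts fully dynamic; calculate_output_shape then
    // overwrites the entries it is able to resolve.
    auto shape = ngraph::PartialShape::dynamic(ngraph::Rank(3));
    shape[0] = 2;                     // resolve a single entry
    std::cout << shape << std::endl;  // prints {2,?,?}
    return 0;
}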
292 changes: 168 additions & 124 deletions ngraph/core/src/op/reshape.cpp
@@ -4,6 +4,7 @@

#include <algorithm>
#include <ngraph/validation_util.hpp>
#include <numeric>

#include "itt.hpp"
#include "ngraph/op/constant.hpp"
@@ -61,10 +62,9 @@ bool op::v1::Reshape::visit_attributes(AttributeVisitor& visitor)
void op::v1::Reshape::validate_and_infer_types()
{
NGRAPH_OP_SCOPE(v1_Reshape_validate_and_infer_types);
auto shape_pattern_et = get_input_element_type(1);
// check data types
NODE_VALIDATION_CHECK(
this, shape_pattern_et.is_integral_number(), "Shape pattern must be an integral number.");
this, get_input_element_type(1).is_integral_number(), "Shape pattern must be an integral number.");

// check shapes
const PartialShape& input_pshape = get_input_partial_shape(0);
@@ -76,7 +76,7 @@ void op::v1::Reshape::validate_and_infer_types()
"Pattern shape must have rank 1 or be empty, got ",
shape_pattern_shape.rank(),
".");
Rank output_rank =
const Rank output_rank =
shape_pattern_shape.rank().is_dynamic()
? Rank::dynamic()
: shape_pattern_shape.rank().get_length() == 0 ? 0 : shape_pattern_shape[0];
@@ -87,43 +87,73 @@
bool shape_can_be_calculated = false;
int64_t minus_one_idx = -1;

HostTensorPtr lb, ub;
std::tie(lb, ub) = evaluate_both_bounds(get_input_source_output(1));
if (lb && ub)
if (const auto& constant = get_constant_from_source(input_value(1)))
{
const auto lower_bound = std::make_shared<op::Constant>(lb)->cast_vector<int64_t>();
const auto upper_bound = std::make_shared<op::Constant>(ub)->cast_vector<int64_t>();
shape_can_be_calculated = true;
NGRAPH_CHECK(lower_bound.size() == upper_bound.size());
for (size_t i = 0; i < lower_bound.size(); ++i)
reshape_pattern = constant->cast_vector<Dimension>();
for (size_t i = 0; i < reshape_pattern.size(); ++i)
{
NODE_VALIDATION_CHECK(this,
lower_bound[i] >= -1 && upper_bound[i] >= -1,
reshape_pattern[i].is_dynamic() || reshape_pattern[i].get_length() >= 0,
"Dim size cannot be less than -1");

if (lower_bound[i] == -1 && upper_bound[i] == -1)
if (reshape_pattern[i].is_dynamic())
{ // the ctor of Dimension(-1) would turn the input into Dimension(0, max_int)
NODE_VALIDATION_CHECK(
this, minus_one_idx == -1, "More than one dimension has size of -1");
this, minus_one_idx == -1, "More than one dimension has size of -1");
minus_one_idx = static_cast<int64_t>(i);
}
reshape_pattern.emplace_back(lower_bound[i], upper_bound[i]);
}
// For scalar case reshape_patter should be empty but scalar reshape pattern should be empty
// For the scalar case, reshape_pattern should be empty or contain a single element equal to 1
if (output_rank.is_static() && output_rank.get_length() == 0 && !lower_bound.empty())
if (output_rank == 0 && !reshape_pattern.empty())
{
reshape_pattern.clear();
NGRAPH_CHECK(lower_bound.size() == 1);
NGRAPH_CHECK(reshape_pattern.size() == 1);
NODE_VALIDATION_CHECK(this,
lower_bound[0] == 1 && upper_bound[0] == 1,
reshape_pattern[0] == 1,
"The value of scalar shape pattern should be equal to 1!");
reshape_pattern.clear();
}
}
else
{
HostTensorPtr lb, ub;
std::tie(lb, ub) = evaluate_both_bounds(get_input_source_output(1));
if (lb && ub)
{
const auto& lower_bound = std::make_shared<op::Constant>(lb)->cast_vector<int64_t>();
const auto& upper_bound = std::make_shared<op::Constant>(ub)->cast_vector<int64_t>();
shape_can_be_calculated = true;
NGRAPH_CHECK(lower_bound.size() == upper_bound.size());
for (size_t i = 0; i < lower_bound.size(); ++i)
{
NODE_VALIDATION_CHECK(this,
lower_bound[i] >= -1 && upper_bound[i] >= -1,
"Dim size cannot be less than -1");

if (lower_bound[i] == -1 && upper_bound[i] == -1)
{ // the ctor of Dimension(-1) would turn the input into Dimension(0, max_int)
NODE_VALIDATION_CHECK(
this, minus_one_idx == -1, "More than one dimension has size of -1");
minus_one_idx = static_cast<int64_t>(i);
}
reshape_pattern.emplace_back(lower_bound[i], upper_bound[i]);
}
// For the scalar case, reshape_pattern should be empty or contain a single element equal to 1
if (output_rank.is_static() && output_rank.get_length() == 0 && !lower_bound.empty())
{
reshape_pattern.clear();
NGRAPH_CHECK(lower_bound.size() == 1);
NODE_VALIDATION_CHECK(this,
lower_bound[0] == 1 && upper_bound[0] == 1,
"The value of scalar shape pattern should be equal to 1!");
}
}
}

if (shape_can_be_calculated)
{
std::vector<Dimension> output_shape(output_rank.get_length());
auto output_shape = PartialShape::dynamic(output_rank);
calculate_output_shape(reshape_pattern, minus_one_idx, input_pshape, output_shape);
set_output_type(0, get_input_element_type(0), output_shape);
}
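In the hunk above, the common case is now handled by get_constant_from_source plus a single cast_vector<Dimension>, and the slower evaluate_both_bounds path is kept only as a fallback. A standalone sketch (my illustration) of the Dimension behavior both paths rely on:

#include <cassert>
#include <ngraph/dimension.hpp>

int main()
{
    // A -1 in the pattern becomes a fully dynamic Dimension(0, max_int),
    // which is what the loops above key on to locate the masked index.
    ngraph::Dimension masked(-1);
    assert(masked.is_dynamic());

    // Bound pairs from evaluate_both_bounds become interval dimensions.
    ngraph::Dimension bounded(2, 8);
    assert(bounded.is_dynamic());
    assert(bounded.get_min_length() == 2 && bounded.get_max_length() == 8);
    return 0;
}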
@@ -178,11 +208,11 @@ bool op::v1::Reshape::evaluate_reshape(const HostTensorVector& outputs,
reshape_pattern.emplace_back(out_shape_val[i]);
}

std::vector<Dimension> output_shape(out_shape_val.size());
auto output_shape = PartialShape::dynamic(out_shape_val.size());
calculate_output_shape(
reshape_pattern, minus_one_idx, inputs[0]->get_partial_shape(), output_shape);
NGRAPH_CHECK(PartialShape(output_shape).is_static());
outputs[0]->set_shape(PartialShape(output_shape).to_shape());
NGRAPH_CHECK(output_shape.is_static());
outputs[0]->set_shape(output_shape.to_shape());

const AxisVector order = get_default_order(inputs[0]->get_shape());
return reshapeop::evaluate_reshape(inputs[0], outputs[0], order);
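The evaluate path above only accepts patterns that resolve to a fully static shape. A tiny reference model (my sketch, not the library routine) of the semantics it enforces, where 0 copies the input dimension at the same index (special_zero) and -1 absorbs the remaining element count:

#include <cstdint>
#include <iostream>
#include <vector>

std::vector<int64_t> reshape_infer(const std::vector<int64_t>& in,
                                   const std::vector<int64_t>& pattern)
{
    std::vector<int64_t> out(pattern.size());
    int64_t known = 1, minus_one = -1, total = 1;
    for (auto d : in)
        total *= d;
    for (size_t i = 0; i < pattern.size(); ++i)
    {
        if (pattern[i] == -1)
        {
            minus_one = static_cast<int64_t>(i);
            continue;
        }
        out[i] = (pattern[i] == 0) ? in[i] : pattern[i]; // special_zero copy
        known *= out[i];
    }
    if (minus_one >= 0)
        out[minus_one] = total / known; // assumes known evenly divides total
    return out;
}

int main()
{
    for (auto d : reshape_infer({2, 3, 4}, {0, -1}))
        std::cout << d << ' '; // prints: 2 12
    std::cout << std::endl;
    return 0;
}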
@@ -259,133 +289,147 @@ bool op::v1::Reshape::constant_fold(OutputVector& output_values, const OutputVec
return false;
}

void op::v1::Reshape::calculate_output_shape(vector<Dimension>& reshape_pattern,
void op::v1::Reshape::calculate_output_shape(const vector<Dimension>& reshape_pattern,
const int64_t& minus_one_idx,
const PartialShape& input_pshape,
vector<Dimension>& output_shape) const
PartialShape& output_shape) const
{
Dimension output_product(1);
for (int64_t i = 0; i < static_cast<int64_t>(reshape_pattern.size()); ++i)
if (minus_one_idx == -1)
{
if (i == minus_one_idx) // resolving everything except -1
continue;

auto pattern_dim = reshape_pattern[i];
if (pattern_dim.get_min_length() == 0 && pattern_dim.get_max_length() == 0 &&
get_special_zero())
for (int64_t i = 0; i < static_cast<int64_t>(reshape_pattern.size()); ++i)
{
if (input_pshape.rank().is_dynamic())
const auto& pattern_dim = reshape_pattern[i];
if (pattern_dim == 0 && get_special_zero())
{
output_shape[i] = Dimension::dynamic();
output_product *= Dimension::dynamic();
if (input_pshape.rank().is_dynamic())
{
output_shape[i] = Dimension::dynamic();
}
else
{
NODE_VALIDATION_CHECK(
this, i < input_pshape.rank().get_length(), "'0' dimension is out of range");
output_shape[i] = input_pshape[i];
}
}
else
{
NODE_VALIDATION_CHECK(
this, i < input_pshape.rank().get_length(), "'0' dimension is out of range");
output_shape[i] = input_pshape[i];
// we do not include dimension to output product here and won't include in input
// product later because we will divide output_product by input_product. This
// dimension contributes to both products equally, but in case this dimension
// is dynamic and others are not we could fully define output dimension that
// is masked by -1
output_shape[i] = pattern_dim;
}
}
else
{
output_shape[i] = pattern_dim;
output_product *= pattern_dim;
}
}
Dimension input_product(1);
if (input_pshape.rank().is_static())
for (int64_t i = 0; i < input_pshape.rank().get_length(); ++i)
{
if (i < static_cast<int64_t>(reshape_pattern.size()) &&
reshape_pattern[i].get_min_length() == 0 &&
reshape_pattern[i].get_max_length() == 0)
continue;
input_product *= input_pshape[i];
}
else if (reshape_pattern.size() == 1 && reshape_pattern[0] == -1 && minus_one_idx == 0)
{
output_shape = {std::accumulate(input_pshape.begin(), input_pshape.end(), Dimension(1), std::multiplies<Dimension>())};
return;
}
else
input_product = Dimension::dynamic();

if (minus_one_idx != -1) // resolving -1 masked dimension
{
if (output_product.get_min_length() == 0 && output_product.get_max_length() == 0)
{
// TODO: Decide if this is desired behavior here. (NumPy seems
// to fail.)
NODE_VALIDATION_CHECK(this,
input_product.get_min_length() == 0 &&
input_product.get_max_length() == 0,
"Cannot infer '-1' dimension with zero-size output "
"dimension unless at least one input dimension is "
"also zero-size");
output_shape[minus_one_idx] = Dimension(0);
}
else
{
if (input_product.is_static() && output_product.is_static())
{
NODE_VALIDATION_CHECK(
this,
input_product.get_length() % output_product.get_length() == 0,
"Non-'-1' output dimensions do not evenly divide the input dimensions");
Dimension output_product(1);
for (int64_t i = 0; i < static_cast<int64_t>(reshape_pattern.size()); ++i) {
if (i == minus_one_idx) // resolving everything except -1
continue;

const auto &pattern_dim = reshape_pattern[i];
if (pattern_dim == 0 && get_special_zero()) {
if (input_pshape.rank().is_dynamic()) {
output_shape[i] = Dimension::dynamic();
output_product *= Dimension::dynamic();
} else {
NODE_VALIDATION_CHECK(
this, i < input_pshape.rank().get_length(), "'0' dimension is out of range");
output_shape[i] = input_pshape[i];
// we do not include this dimension in the output product here, and will not include it
// in the input product later, because we divide input_product by output_product. The
// dimension contributes to both products equally, but in case this dimension is dynamic
// and the others are not, we can still fully define the output dimension that is
// masked by -1
}
} else {
output_shape[i] = pattern_dim;
output_product *= pattern_dim;
}
if (output_product.get_min_length() == 0 || output_product == Dimension() ||
input_product == Dimension())
{
output_shape[minus_one_idx] = Dimension::dynamic();
}
const auto &pattern_length = static_cast<int64_t>(reshape_pattern.size());
Dimension input_product(1);
if (input_pshape.rank().is_static())
for (int64_t i = 0; i < input_pshape.rank().get_length(); ++i) {
if (i < pattern_length && reshape_pattern[i] == 0)
continue;
input_product *= input_pshape[i];
}
else
{
Dimension::value_type lower;
if (input_product.get_min_length() == 0)
lower = 0;
else if (input_product.get_min_length() == -1 ||
output_product.get_max_length() == 0 ||
output_product.get_max_length() == -1)
lower = -1; // dynamic
else
lower = static_cast<Dimension::value_type>(
ceil(static_cast<double>(input_product.get_min_length()) /
output_product.get_max_length()));

Dimension::value_type upper;
if (input_product.get_max_length() == 0)
upper = 0;
else if (input_product.get_max_length() == -1 ||
output_product.get_min_length() == 0 ||
output_product.get_min_length() == -1)
upper = -1; // dynamic
else
upper = static_cast<Dimension::value_type>(
floor(static_cast<double>(input_product.get_max_length()) /
output_product.get_min_length()));
else
input_product = Dimension::dynamic();

if (lower == -1)
output_shape[minus_one_idx] = Dimension::dynamic();
else if (upper == -1)
output_shape[minus_one_idx] = Dimension(lower, upper);
else if (lower > upper) // empty intersection
if (minus_one_idx != -1) // resolving -1 masked dimension
{
if (output_product == 0) {
// TODO: Decide if this is desired behavior here. (NumPy seems
// to fail.)
NODE_VALIDATION_CHECK(this,
input_product.get_min_length() == 0 &&
input_product.get_max_length() == 0,
"Cannot infer '-1' dimension with zero-size output "
"dimension unless at least one input dimension is "
"also zero-size");
output_shape[minus_one_idx] = Dimension(0);
} else {
if (input_product.is_static() && output_product.is_static()) {
NODE_VALIDATION_CHECK(
this,
input_product.get_length() % output_product.get_length() == 0,
"Non-'-1' output dimensions do not evenly divide the input dimensions");
}
if (output_product.get_min_length() == 0 || output_product == Dimension() ||
input_product == Dimension()) {
output_shape[minus_one_idx] = Dimension::dynamic();
else
output_shape[minus_one_idx] = Dimension(lower, upper);
} else {
Dimension::value_type lower;
if (input_product.get_min_length() == 0)
lower = 0;
else if (input_product.get_min_length() == -1 ||
output_product.get_max_length() == 0 ||
output_product.get_max_length() == -1)
lower = -1; // dynamic
else
lower = static_cast<Dimension::value_type>(
ceil(static_cast<double>(input_product.get_min_length()) /
output_product.get_max_length()));

Dimension::value_type upper;
if (input_product.get_max_length() == 0)
upper = 0;
else if (input_product.get_max_length() == -1 ||
output_product.get_min_length() == 0 ||
output_product.get_min_length() == -1)
upper = -1; // dynamic
else
upper = static_cast<Dimension::value_type>(
floor(static_cast<double>(input_product.get_max_length()) /
output_product.get_min_length()));

if (lower == -1)
output_shape[minus_one_idx] = Dimension::dynamic();
else if (upper == -1)
output_shape[minus_one_idx] = Dimension(lower, upper);
else if (lower > upper) // empty intersection
output_shape[minus_one_idx] = Dimension::dynamic();
else
output_shape[minus_one_idx] = Dimension(lower, upper);
}
}
}
}
PartialShape output_pshape(output_shape);
if (input_pshape.is_static() && output_pshape.is_static())
if (input_pshape.is_static() && output_shape.is_static())
{
size_t zero_dims =
const auto& zero_dims =
std::count_if(reshape_pattern.begin(), reshape_pattern.end(), [](Dimension dim) {
return dim.get_max_length() == 0 && dim.get_min_length() == 0;
return dim == 0;
});

bool backward_compatible_check = (zero_dims && get_special_zero()) || minus_one_idx != -1;
bool in_out_elements_equal =
shape_size(get_input_shape(0)) == shape_size(output_pshape.to_shape());
shape_size(get_input_shape(0)) == shape_size(output_shape.to_shape());

NODE_VALIDATION_CHECK(this,
backward_compatible_check || in_out_elements_equal,
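For the general '-1' branch above, the masked dimension is resolved as the interval [ceil(in_min / out_max), floor(in_max / out_min)], while the new single-element {-1} fast path simply accumulates the whole input product. A worked example (my arithmetic, mirroring the lower/upper computation above):

#include <cmath>
#include <cstdint>
#include <iostream>

int main()
{
    // Example: input shape {Dimension(3, 6), 4} gives an element-count
    // interval of [12, 24]; pattern {4, -1} fixes the known output
    // product at exactly 4.
    const int64_t in_min = 12, in_max = 24;
    const int64_t out_min = 4, out_max = 4;
    const auto lower = static_cast<int64_t>(
        std::ceil(static_cast<double>(in_min) / out_max));
    const auto upper = static_cast<int64_t>(
        std::floor(static_cast<double>(in_max) / out_min));
    std::cout << "-1 resolves to Dimension(" << lower << ", " << upper << ")"
              << std::endl; // Dimension(3, 6)
    return 0;
}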
