Skip to content

Commit

Permalink
Validate speedup (openvinotoolkit#6779)
Browse files Browse the repository at this point in the history
* Add minor speedup changes.

* inline clip

* reduce clip calls

* more Interval::size - move to header

* terminate instead of throwing exception

* back to throw exception when element type was not found

* rename variable
  • Loading branch information
pelszkow authored and akuporos committed Sep 29, 2021
1 parent d5d101e commit 983a85b
Show file tree
Hide file tree
Showing 7 changed files with 119 additions and 119 deletions.
4 changes: 2 additions & 2 deletions ngraph/core/include/ngraph/enum_names.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ namespace ngraph
});
return rc;
};
for (auto p : get().m_string_enums)
for (const auto& p : get().m_string_enums)
{
if (to_lower(p.first) == to_lower(name))
{
Expand All @@ -41,7 +41,7 @@ namespace ngraph
/// Converts enum values to strings
static const std::string& as_string(EnumType e)
{
for (auto& p : get().m_string_enums)
for (const auto& p : get().m_string_enums)
{
if (p.second == e)
{
Expand Down
17 changes: 10 additions & 7 deletions ngraph/core/include/ngraph/interval.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,16 @@ namespace ngraph
Interval& operator=(const Interval& interval) = default;

/// \brief The number of elements in the interval. Zero if max < min.
size_type size() const;
size_type size() const
{
if (m_max_val == s_max)
{
return m_min_val == s_max ? 0 : s_max;
}
return m_max_val - m_min_val + 1;
}
/// \brief Returns true if the interval has no elements
bool empty() const;
bool empty() const { return m_min_val == s_max; }
/// \brief the inclusive lower bound of the interval
value_type get_min_val() const { return m_min_val; }
/// \brief Set the inclusive lower bound of the interval
Expand Down Expand Up @@ -84,7 +91,7 @@ namespace ngraph
Interval& operator&=(const Interval& interval);

/// \brief True if this interval includes value
bool contains(value_type value) const;
bool contains(value_type value) const { return m_min_val <= value && value <= m_max_val; }
/// \brief True if this interval includes all the values in interval
bool contains(const Interval& interval) const;

Expand All @@ -93,10 +100,6 @@ namespace ngraph

protected:
void canonicalize();
static value_type clip(value_type value);
static value_type clip_times(value_type a, value_type b);
static value_type clip_add(value_type a, value_type b);
static value_type clip_minus(value_type a, value_type b);

value_type m_min_val{0};
value_type m_max_val{s_max};
Expand Down
4 changes: 2 additions & 2 deletions ngraph/core/include/ngraph/partial_shape.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ namespace ngraph

/// \brief Constructs a PartialShape with static rank from a vector of Dimension.
/// \param dimensions The Dimension values for the constructed shape.
PartialShape(const std::vector<Dimension>& dimensions);
PartialShape(std::vector<Dimension> dimensions);

/// \brief Constructs a PartialShape with static rank from a vector of dimensions values.
/// \param dimensions The Dimension values for the constructed shape.
Expand Down Expand Up @@ -269,7 +269,7 @@ namespace ngraph

private:
// Private constructor for PartialShape::dynamic().
PartialShape(bool rank_is_static, const std::vector<Dimension>& dimensions);
PartialShape(bool rank_is_static, std::vector<Dimension> dimensions);

// True if the shape's rank is static.
bool m_rank_is_static;
Expand Down
4 changes: 2 additions & 2 deletions ngraph/core/src/dimension.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,12 +99,12 @@ bool Dimension::merge(Dimension& dst, const Dimension d1, const Dimension d2)

bool Dimension::broadcast_merge(Dimension& dst, const Dimension d1, const Dimension d2)
{
if (d1.m_dimension.size() == 1 && d1.m_dimension.get_min_val() == 1)
if (d1.m_dimension.get_min_val() == 1 && d1.m_dimension.size() == 1)
{
dst = d2;
return true;
}
if (d2.m_dimension.size() == 1 && d2.m_dimension.get_min_val() == 1)
if (d2.m_dimension.get_min_val() == 1 && d2.m_dimension.size() == 1)
{
dst = d1;
return true;
Expand Down
101 changes: 42 additions & 59 deletions ngraph/core/src/interval.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,46 @@

using namespace ngraph;

namespace
{
Interval::value_type clip(Interval::value_type value)
{
return std::max(Interval::value_type(0), std::min(Interval::s_max, value));
}

Interval::value_type clip_times(Interval::value_type a, Interval::value_type b)
{
if (a == 0 || b == 0)
{
return 0;
}
else if (a == Interval::s_max || b == Interval::s_max)
{
return Interval::s_max;
}
else
{
return a * b;
}
}
Interval::value_type clip_add(Interval::value_type a, Interval::value_type b)
{
return (a == Interval::s_max || b == Interval::s_max) ? Interval::s_max : a + b;
}
Interval::value_type clip_minus(Interval::value_type a, Interval::value_type b)
{
if (a <= b)
{
return 0;
}
if (a == Interval::s_max)
{
return Interval::s_max;
}
return a - b;
}
} // namespace

void Interval::canonicalize()
{
if (m_max_val < m_min_val)
Expand All @@ -28,22 +68,9 @@ Interval::Interval(value_type min_val, value_type max_val)
}

Interval::Interval(value_type val)
: Interval(val, val)
{
}

Interval::size_type Interval::size() const
{
if (m_max_val == s_max)
{
return m_min_val == s_max ? 0 : s_max;
}
return m_max_val - m_min_val + 1;
}

bool Interval::empty() const
{
return m_min_val == s_max;
m_min_val = clip(val);
m_max_val = m_min_val;
}

bool Interval::operator==(const Interval& interval) const
Expand Down Expand Up @@ -116,55 +143,11 @@ Interval& Interval::operator&=(const Interval& interval)
return *this = *this & interval;
}

bool Interval::contains(value_type value) const
{
return m_min_val <= value && value <= m_max_val;
}

bool Interval::contains(const Interval& interval) const
{
return contains(interval.m_min_val) && contains(interval.m_max_val);
}

Interval::value_type Interval::clip(value_type value)
{
return std::max(value_type(0), std::min(s_max, value));
}

Interval::value_type Interval::clip_add(value_type a, value_type b)
{
return (a == s_max || b == s_max) ? s_max : a + b;
}

Interval::value_type Interval::clip_minus(value_type a, value_type b)
{
if (a <= b)
{
return 0;
}
if (a == s_max)
{
return s_max;
}
return a - b;
}

Interval::value_type Interval::clip_times(value_type a, value_type b)
{
if (a == 0 || b == 0)
{
return 0;
}
else if (a == s_max || b == s_max)
{
return s_max;
}
else
{
return a * b;
}
}

constexpr Interval::value_type Interval::s_max;

namespace ngraph
Expand Down
10 changes: 5 additions & 5 deletions ngraph/core/src/partial_shape.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,15 @@ PartialShape::PartialShape(const Shape& shape)
{
}

PartialShape::PartialShape(bool rank_is_static, const std::vector<Dimension>& dimensions)
PartialShape::PartialShape(bool rank_is_static, std::vector<Dimension> dimensions)
: m_rank_is_static(rank_is_static)
, m_dimensions(dimensions)
, m_dimensions(std::move(dimensions))
{
}

PartialShape::PartialShape(const std::vector<Dimension>& dimensions)
PartialShape::PartialShape(std::vector<Dimension> dimensions)
: m_rank_is_static(true)
, m_dimensions(dimensions)
, m_dimensions(std::move(dimensions))
{
}

Expand Down Expand Up @@ -387,7 +387,7 @@ bool PartialShape::broadcast_merge_into(PartialShape& dst,
i < (new_rank - src_rank) ? Dimension(1) : src[i - (new_rank - src_rank)];
success &= Dimension::broadcast_merge(dims[i], dsti, srci);
}
dst = PartialShape(dims);
dst = PartialShape(std::move(dims));
return success;
}
}
Expand Down
Loading

0 comments on commit 983a85b

Please sign in to comment.