Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Validate speedup #6779

Merged
merged 7 commits into from
Jul 30, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions ngraph/core/include/ngraph/enum_names.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ namespace ngraph
});
return rc;
};
for (auto p : get().m_string_enums)
for (const auto& p : get().m_string_enums)
{
if (to_lower(p.first) == to_lower(name))
{
Expand All @@ -41,7 +41,7 @@ namespace ngraph
/// Converts enum values to strings
static const std::string& as_string(EnumType e)
{
for (auto& p : get().m_string_enums)
for (const auto& p : get().m_string_enums)
{
if (p.second == e)
{
Expand Down
17 changes: 10 additions & 7 deletions ngraph/core/include/ngraph/interval.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,16 @@ namespace ngraph
Interval& operator=(const Interval& interval) = default;

/// \brief The number of elements in the interval. Zero if max < min.
size_type size() const;
size_type size() const
{
if (m_max_val == s_max)
{
return m_min_val == s_max ? 0 : s_max;
}
return m_max_val - m_min_val + 1;
}
/// \brief Returns true if the interval has no elements
bool empty() const;
bool empty() const { return m_min_val == s_max; }
/// \brief the inclusive lower bound of the interval
value_type get_min_val() const { return m_min_val; }
/// \brief Set the inclusive lower bound of the interval
Expand Down Expand Up @@ -84,7 +91,7 @@ namespace ngraph
Interval& operator&=(const Interval& interval);

/// \brief True if this interval includes value
bool contains(value_type value) const;
bool contains(value_type value) const { return m_min_val <= value && value <= m_max_val; }
/// \brief True if this interval includes all the values in interval
bool contains(const Interval& interval) const;

Expand All @@ -93,10 +100,6 @@ namespace ngraph

protected:
void canonicalize();
static value_type clip(value_type value);
static value_type clip_times(value_type a, value_type b);
static value_type clip_add(value_type a, value_type b);
static value_type clip_minus(value_type a, value_type b);

value_type m_min_val{0};
value_type m_max_val{s_max};
Expand Down
4 changes: 2 additions & 2 deletions ngraph/core/include/ngraph/partial_shape.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ namespace ngraph

/// \brief Constructs a PartialShape with static rank from a vector of Dimension.
/// \param dimensions The Dimension values for the constructed shape.
PartialShape(const std::vector<Dimension>& dimensions);
PartialShape(std::vector<Dimension> dimensions);

/// \brief Constructs a PartialShape with static rank from a vector of dimensions values.
/// \param dimensions The Dimension values for the constructed shape.
Expand Down Expand Up @@ -269,7 +269,7 @@ namespace ngraph

private:
// Private constructor for PartialShape::dynamic().
PartialShape(bool rank_is_static, const std::vector<Dimension>& dimensions);
PartialShape(bool rank_is_static, std::vector<Dimension> dimensions);

// True if the shape's rank is static.
bool m_rank_is_static;
Expand Down
4 changes: 2 additions & 2 deletions ngraph/core/src/dimension.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,12 +99,12 @@ bool Dimension::merge(Dimension& dst, const Dimension d1, const Dimension d2)

bool Dimension::broadcast_merge(Dimension& dst, const Dimension d1, const Dimension d2)
{
if (d1.m_dimension.size() == 1 && d1.m_dimension.get_min_val() == 1)
if (d1.m_dimension.get_min_val() == 1 && d1.m_dimension.size() == 1)
{
dst = d2;
return true;
}
if (d2.m_dimension.size() == 1 && d2.m_dimension.get_min_val() == 1)
if (d2.m_dimension.get_min_val() == 1 && d2.m_dimension.size() == 1)
{
dst = d1;
return true;
Expand Down
101 changes: 42 additions & 59 deletions ngraph/core/src/interval.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,46 @@

using namespace ngraph;

namespace
{
Interval::value_type clip(Interval::value_type value)
{
return std::max(Interval::value_type(0), std::min(Interval::s_max, value));
}

Interval::value_type clip_times(Interval::value_type a, Interval::value_type b)
{
if (a == 0 || b == 0)
{
return 0;
}
else if (a == Interval::s_max || b == Interval::s_max)
{
return Interval::s_max;
}
else
{
return a * b;
}
}
Interval::value_type clip_add(Interval::value_type a, Interval::value_type b)
{
return (a == Interval::s_max || b == Interval::s_max) ? Interval::s_max : a + b;
}
Interval::value_type clip_minus(Interval::value_type a, Interval::value_type b)
{
if (a <= b)
{
return 0;
}
if (a == Interval::s_max)
{
return Interval::s_max;
}
return a - b;
}
} // namespace

void Interval::canonicalize()
{
if (m_max_val < m_min_val)
Expand All @@ -28,22 +68,9 @@ Interval::Interval(value_type min_val, value_type max_val)
}

Interval::Interval(value_type val)
: Interval(val, val)
{
}

Interval::size_type Interval::size() const
{
if (m_max_val == s_max)
{
return m_min_val == s_max ? 0 : s_max;
}
return m_max_val - m_min_val + 1;
}

bool Interval::empty() const
{
return m_min_val == s_max;
m_min_val = clip(val);
m_max_val = m_min_val;
}

bool Interval::operator==(const Interval& interval) const
Expand Down Expand Up @@ -116,55 +143,11 @@ Interval& Interval::operator&=(const Interval& interval)
return *this = *this & interval;
}

bool Interval::contains(value_type value) const
{
return m_min_val <= value && value <= m_max_val;
}

bool Interval::contains(const Interval& interval) const
{
return contains(interval.m_min_val) && contains(interval.m_max_val);
}

Interval::value_type Interval::clip(value_type value)
{
return std::max(value_type(0), std::min(s_max, value));
}

Interval::value_type Interval::clip_add(value_type a, value_type b)
{
return (a == s_max || b == s_max) ? s_max : a + b;
}

Interval::value_type Interval::clip_minus(value_type a, value_type b)
{
if (a <= b)
{
return 0;
}
if (a == s_max)
{
return s_max;
}
return a - b;
}

Interval::value_type Interval::clip_times(value_type a, value_type b)
{
if (a == 0 || b == 0)
{
return 0;
}
else if (a == s_max || b == s_max)
{
return s_max;
}
else
{
return a * b;
}
}

constexpr Interval::value_type Interval::s_max;

namespace ngraph
Expand Down
10 changes: 5 additions & 5 deletions ngraph/core/src/partial_shape.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,15 @@ PartialShape::PartialShape(const Shape& shape)
{
}

PartialShape::PartialShape(bool rank_is_static, const std::vector<Dimension>& dimensions)
PartialShape::PartialShape(bool rank_is_static, std::vector<Dimension> dimensions)
: m_rank_is_static(rank_is_static)
, m_dimensions(dimensions)
, m_dimensions(std::move(dimensions))
tomdol marked this conversation as resolved.
Show resolved Hide resolved
{
}

PartialShape::PartialShape(const std::vector<Dimension>& dimensions)
PartialShape::PartialShape(std::vector<Dimension> dimensions)
: m_rank_is_static(true)
, m_dimensions(dimensions)
, m_dimensions(std::move(dimensions))
{
}

Expand Down Expand Up @@ -387,7 +387,7 @@ bool PartialShape::broadcast_merge_into(PartialShape& dst,
i < (new_rank - src_rank) ? Dimension(1) : src[i - (new_rank - src_rank)];
success &= Dimension::broadcast_merge(dims[i], dsti, srci);
}
dst = PartialShape(dims);
dst = PartialShape(std::move(dims));
return success;
}
}
Expand Down
Loading