diff --git a/ngraph/core/include/ngraph/enum_names.hpp b/ngraph/core/include/ngraph/enum_names.hpp index 213613d8064883..8dbdf6f5cff607 100644 --- a/ngraph/core/include/ngraph/enum_names.hpp +++ b/ngraph/core/include/ngraph/enum_names.hpp @@ -28,7 +28,7 @@ namespace ngraph }); return rc; }; - for (auto p : get().m_string_enums) + for (const auto& p : get().m_string_enums) { if (to_lower(p.first) == to_lower(name)) { @@ -41,7 +41,7 @@ namespace ngraph /// Converts enum values to strings static const std::string& as_string(EnumType e) { - for (auto& p : get().m_string_enums) + for (const auto& p : get().m_string_enums) { if (p.second == e) { diff --git a/ngraph/core/include/ngraph/interval.hpp b/ngraph/core/include/ngraph/interval.hpp index 08302289f99018..c5cb5453d02a17 100644 --- a/ngraph/core/include/ngraph/interval.hpp +++ b/ngraph/core/include/ngraph/interval.hpp @@ -41,9 +41,16 @@ namespace ngraph Interval& operator=(const Interval& interval) = default; /// \brief The number of elements in the interval. Zero if max < min. - size_type size() const; + size_type size() const + { + if (m_max_val == s_max) + { + return m_min_val == s_max ? 0 : s_max; + } + return m_max_val - m_min_val + 1; + } /// \brief Returns true if the interval has no elements - bool empty() const; + bool empty() const { return m_min_val == s_max; } /// \brief the inclusive lower bound of the interval value_type get_min_val() const { return m_min_val; } /// \brief Set the inclusive lower bound of the interval @@ -84,7 +91,7 @@ namespace ngraph Interval& operator&=(const Interval& interval); /// \brief True if this interval includes value - bool contains(value_type value) const; + bool contains(value_type value) const { return m_min_val <= value && value <= m_max_val; } /// \brief True if this interval includes all the values in interval bool contains(const Interval& interval) const; @@ -93,10 +100,6 @@ namespace ngraph protected: void canonicalize(); - static value_type clip(value_type value); - static value_type clip_times(value_type a, value_type b); - static value_type clip_add(value_type a, value_type b); - static value_type clip_minus(value_type a, value_type b); value_type m_min_val{0}; value_type m_max_val{s_max}; diff --git a/ngraph/core/include/ngraph/partial_shape.hpp b/ngraph/core/include/ngraph/partial_shape.hpp index c100273d765a05..5f4bddf689482d 100644 --- a/ngraph/core/include/ngraph/partial_shape.hpp +++ b/ngraph/core/include/ngraph/partial_shape.hpp @@ -54,7 +54,7 @@ namespace ngraph /// \brief Constructs a PartialShape with static rank from a vector of Dimension. /// \param dimensions The Dimension values for the constructed shape. - PartialShape(const std::vector& dimensions); + PartialShape(std::vector dimensions); /// \brief Constructs a PartialShape with static rank from a vector of dimensions values. /// \param dimensions The Dimension values for the constructed shape. @@ -269,7 +269,7 @@ namespace ngraph private: // Private constructor for PartialShape::dynamic(). - PartialShape(bool rank_is_static, const std::vector& dimensions); + PartialShape(bool rank_is_static, std::vector dimensions); // True if the shape's rank is static. bool m_rank_is_static; diff --git a/ngraph/core/src/dimension.cpp b/ngraph/core/src/dimension.cpp index 2941d1ff083300..6b86316c74080a 100644 --- a/ngraph/core/src/dimension.cpp +++ b/ngraph/core/src/dimension.cpp @@ -99,12 +99,12 @@ bool Dimension::merge(Dimension& dst, const Dimension d1, const Dimension d2) bool Dimension::broadcast_merge(Dimension& dst, const Dimension d1, const Dimension d2) { - if (d1.m_dimension.size() == 1 && d1.m_dimension.get_min_val() == 1) + if (d1.m_dimension.get_min_val() == 1 && d1.m_dimension.size() == 1) { dst = d2; return true; } - if (d2.m_dimension.size() == 1 && d2.m_dimension.get_min_val() == 1) + if (d2.m_dimension.get_min_val() == 1 && d2.m_dimension.size() == 1) { dst = d1; return true; diff --git a/ngraph/core/src/interval.cpp b/ngraph/core/src/interval.cpp index f02ad332885e0c..ef8a466fa75502 100644 --- a/ngraph/core/src/interval.cpp +++ b/ngraph/core/src/interval.cpp @@ -6,6 +6,46 @@ using namespace ngraph; +namespace +{ + Interval::value_type clip(Interval::value_type value) + { + return std::max(Interval::value_type(0), std::min(Interval::s_max, value)); + } + + Interval::value_type clip_times(Interval::value_type a, Interval::value_type b) + { + if (a == 0 || b == 0) + { + return 0; + } + else if (a == Interval::s_max || b == Interval::s_max) + { + return Interval::s_max; + } + else + { + return a * b; + } + } + Interval::value_type clip_add(Interval::value_type a, Interval::value_type b) + { + return (a == Interval::s_max || b == Interval::s_max) ? Interval::s_max : a + b; + } + Interval::value_type clip_minus(Interval::value_type a, Interval::value_type b) + { + if (a <= b) + { + return 0; + } + if (a == Interval::s_max) + { + return Interval::s_max; + } + return a - b; + } +} // namespace + void Interval::canonicalize() { if (m_max_val < m_min_val) @@ -28,22 +68,9 @@ Interval::Interval(value_type min_val, value_type max_val) } Interval::Interval(value_type val) - : Interval(val, val) -{ -} - -Interval::size_type Interval::size() const -{ - if (m_max_val == s_max) - { - return m_min_val == s_max ? 0 : s_max; - } - return m_max_val - m_min_val + 1; -} - -bool Interval::empty() const { - return m_min_val == s_max; + m_min_val = clip(val); + m_max_val = m_min_val; } bool Interval::operator==(const Interval& interval) const @@ -116,55 +143,11 @@ Interval& Interval::operator&=(const Interval& interval) return *this = *this & interval; } -bool Interval::contains(value_type value) const -{ - return m_min_val <= value && value <= m_max_val; -} - bool Interval::contains(const Interval& interval) const { return contains(interval.m_min_val) && contains(interval.m_max_val); } -Interval::value_type Interval::clip(value_type value) -{ - return std::max(value_type(0), std::min(s_max, value)); -} - -Interval::value_type Interval::clip_add(value_type a, value_type b) -{ - return (a == s_max || b == s_max) ? s_max : a + b; -} - -Interval::value_type Interval::clip_minus(value_type a, value_type b) -{ - if (a <= b) - { - return 0; - } - if (a == s_max) - { - return s_max; - } - return a - b; -} - -Interval::value_type Interval::clip_times(value_type a, value_type b) -{ - if (a == 0 || b == 0) - { - return 0; - } - else if (a == s_max || b == s_max) - { - return s_max; - } - else - { - return a * b; - } -} - constexpr Interval::value_type Interval::s_max; namespace ngraph diff --git a/ngraph/core/src/partial_shape.cpp b/ngraph/core/src/partial_shape.cpp index e02425c4daa7dc..c5222863a1bc1e 100644 --- a/ngraph/core/src/partial_shape.cpp +++ b/ngraph/core/src/partial_shape.cpp @@ -34,15 +34,15 @@ PartialShape::PartialShape(const Shape& shape) { } -PartialShape::PartialShape(bool rank_is_static, const std::vector& dimensions) +PartialShape::PartialShape(bool rank_is_static, std::vector dimensions) : m_rank_is_static(rank_is_static) - , m_dimensions(dimensions) + , m_dimensions(std::move(dimensions)) { } -PartialShape::PartialShape(const std::vector& dimensions) +PartialShape::PartialShape(std::vector dimensions) : m_rank_is_static(true) - , m_dimensions(dimensions) + , m_dimensions(std::move(dimensions)) { } @@ -387,7 +387,7 @@ bool PartialShape::broadcast_merge_into(PartialShape& dst, i < (new_rank - src_rank) ? Dimension(1) : src[i - (new_rank - src_rank)]; success &= Dimension::broadcast_merge(dims[i], dsti, srci); } - dst = PartialShape(dims); + dst = PartialShape(std::move(dims)); return success; } } diff --git a/ngraph/core/src/type/element_type.cpp b/ngraph/core/src/type/element_type.cpp index 8d688fbf995104..fd91450cbc722f 100644 --- a/ngraph/core/src/type/element_type.cpp +++ b/ngraph/core/src/type/element_type.cpp @@ -12,45 +12,47 @@ #include "ngraph/type/element_type_traits.hpp" using namespace ngraph; -using namespace std; constexpr DiscreteTypeInfo AttributeAdapter::type_info; - -class TypeInfo +namespace { -public: - TypeInfo(size_t bitwidth, - bool is_real, - bool is_signed, - bool is_quantized, - const std::string& cname, - const std::string& type_name) - : m_bitwidth{bitwidth} - , m_is_real{is_real} - , m_is_signed{is_signed} - , m_is_quantized{is_quantized} - , m_cname{cname} - , m_type_name{type_name} + class TypeInfo { - } - size_t m_bitwidth; - bool m_is_real; - bool m_is_signed; - bool m_is_quantized; - std::string m_cname; - std::string m_type_name; -}; + public: + TypeInfo(size_t bitwidth, + bool is_real, + bool is_signed, + bool is_quantized, + const std::string& cname, + const std::string& type_name) + : m_bitwidth{bitwidth} + , m_is_real{is_real} + , m_is_signed{is_signed} + , m_is_quantized{is_quantized} + , m_cname{cname} + , m_type_name{type_name} + { + } + size_t m_bitwidth; + bool m_is_real; + bool m_is_signed; + bool m_is_quantized; + std::string m_cname; + std::string m_type_name; + }; -struct element_type_hash -{ - size_t operator()(element::Type_t t) const { return static_cast(t); } -}; + struct ElementTypes + { + struct TypeHash + { + size_t operator()(element::Type_t t) const { return static_cast(t); } + }; -typedef unordered_map element_types_map_t; + using ElementsMap = std::unordered_map; + static const ElementsMap elements_map; + }; -static const element_types_map_t& get_type_info_map() -{ - static element_types_map_t s_type_info_map{ + const ElementTypes::ElementsMap ElementTypes::elements_map{ {element::Type_t::undefined, TypeInfo( std::numeric_limits::max(), false, false, false, "undefined", "undefined")}, @@ -72,8 +74,20 @@ static const element_types_map_t& get_type_info_map() {element::Type_t::u32, TypeInfo(32, false, false, false, "uint32_t", "u32")}, {element::Type_t::u64, TypeInfo(64, false, false, false, "uint64_t", "u64")}, }; - return s_type_info_map; -}; + + const ElementTypes::ElementsMap& get_type_info_map() { return ElementTypes::elements_map; }; + + const TypeInfo& get_type_info(element::Type_t type) + { + const auto& tim = get_type_info_map(); + const auto& found = tim.find(type); + if (found == tim.end()) + { + throw std::out_of_range{"element::Type_t not supported"}; + } + return found->second; + }; +} // namespace std::vector element::Type::get_known_types() { @@ -103,7 +117,7 @@ element::Type::Type(size_t bitwidth, bool is_quantized, const std::string& /* cname */) { - for (auto& t : get_type_info_map()) + for (const auto& t : get_type_info_map()) { const TypeInfo& info = t.second; if (bitwidth == info.m_bitwidth && is_real == info.m_is_real && @@ -117,7 +131,7 @@ element::Type::Type(size_t bitwidth, const std::string& element::Type::c_type_string() const { - return get_type_info_map().at(m_type).m_cname; + return get_type_info(m_type).m_cname; } size_t element::Type::size() const @@ -132,7 +146,7 @@ size_t element::Type::hash() const const std::string& element::Type::get_type_name() const { - return get_type_info_map().at(m_type).m_type_name; + return get_type_info(m_type).m_type_name; } namespace ngraph @@ -247,12 +261,12 @@ bool element::Type::merge(element::Type& dst, const element::Type& t1, const ele bool element::Type::is_static() const { - return get_type_info_map().at(m_type).m_bitwidth != 0; + return get_type_info(m_type).m_bitwidth != 0; } bool element::Type::is_real() const { - return get_type_info_map().at(m_type).m_is_real; + return get_type_info(m_type).m_is_real; } bool element::Type::is_integral_number() const @@ -262,17 +276,17 @@ bool element::Type::is_integral_number() const bool element::Type::is_signed() const { - return get_type_info_map().at(m_type).m_is_signed; + return get_type_info(m_type).m_is_signed; } bool element::Type::is_quantized() const { - return get_type_info_map().at(m_type).m_is_quantized; + return get_type_info(m_type).m_is_quantized; } size_t element::Type::bitwidth() const { - return get_type_info_map().at(m_type).m_bitwidth; + return get_type_info(m_type).m_bitwidth; } size_t ngraph::compiler_byte_size(element::Type_t et)