diff --git a/ngraph/core/include/ngraph/descriptor/tensor.hpp b/ngraph/core/include/ngraph/descriptor/tensor.hpp index 381e528e531858..4dc14e57068e38 100644 --- a/ngraph/core/include/ngraph/descriptor/tensor.hpp +++ b/ngraph/core/include/ngraph/descriptor/tensor.hpp @@ -4,7 +4,9 @@ #pragma once +#include #include +#include #include #include @@ -74,16 +76,24 @@ namespace ngraph protected: element::Type m_element_type; - // TODO(amprocte): For now we are maintaining both m_shape and m_partial_shape fields, - // with m_shape possibly being invalid (get_shape will throw an exception if it - // is). This is because get_shape() returns a const reference. I think ideally we - // should refactor so that get_shape returns by value. - Shape m_shape; + // TODO: remove along with get_shape + // Initially there was ngraph::Shape m_shape only available to keep shape information. + // Support for dynamic shapes required transition to ngraph::PartialShape. + // To smoothly transition to ngraph::PartialShape we introduced m_partial_shape + // and kept m_shape in sync with m_partial_shape. Synchronization point was placed + // in set_partial_shape which dramatically affected performance of ngraph::Function + // validation. Since we have started the transition to ngraph::PartialShape and reduced + // ngraph::Shape usage the only user of m_shape was get_shape method with signature: + // const Shape& descriptor::Tensor::get_shape() const + // It was decided to move m_shape and m_partial_shape synchronization point there and + // to keep methods signature backward compatible. + mutable std::mutex shape_mutex; + mutable std::atomic_bool m_shape_changed; + mutable Shape m_shape; + // TODO: end + PartialShape m_partial_shape; - Node* m_node{nullptr}; HostTensorPtr m_lower_value, m_upper_value; - size_t m_node_output_number{0}; - std::string m_name; std::unordered_set m_names; }; diff --git a/ngraph/core/src/descriptor/tensor.cpp b/ngraph/core/src/descriptor/tensor.cpp index 1d8335fee080dc..f1da2fbdd52c79 100644 --- a/ngraph/core/src/descriptor/tensor.cpp +++ b/ngraph/core/src/descriptor/tensor.cpp @@ -4,7 +4,6 @@ #include "ngraph/descriptor/tensor.hpp" #include "ngraph/node.hpp" -#include "ngraph/runtime/host_tensor.hpp" using namespace ngraph; using namespace std; @@ -13,9 +12,9 @@ descriptor::Tensor::Tensor(const element::Type& element_type, const PartialShape& pshape, const std::string& name) : m_element_type(element_type) - , m_shape(pshape.is_static() ? pshape.to_shape() : Shape{}) , m_partial_shape(pshape) , m_name(name) + , m_shape_changed(true) { } @@ -24,10 +23,8 @@ descriptor::Tensor::Tensor(const element::Type& element_type, Node* node, size_t node_output_number) : m_element_type(element_type) - , m_shape(pshape.is_static() ? pshape.to_shape() : Shape{}) , m_partial_shape(pshape) - , m_node(node) - , m_node_output_number(node_output_number) + , m_shape_changed(true) { } @@ -46,14 +43,7 @@ void descriptor::Tensor::set_element_type(const element::Type& element_type) void descriptor::Tensor::set_partial_shape(const PartialShape& partial_shape) { m_partial_shape = partial_shape; - if (m_partial_shape.is_static()) - { - m_shape = m_partial_shape.to_shape(); - } - else - { - m_shape = Shape{}; - } + m_shape_changed = true; } void descriptor::Tensor::invalidate_values() @@ -82,6 +72,15 @@ const Shape& descriptor::Tensor::get_shape() const { if (m_partial_shape.is_static()) { + if (m_shape_changed.load(std::memory_order_relaxed)) + { + std::lock_guard guard(shape_mutex); + if (m_shape_changed) // double check after mutex lock + { + m_shape = m_partial_shape.to_shape(); + m_shape_changed = false; + } + } return m_shape; } else diff --git a/ngraph/core/src/node.cpp b/ngraph/core/src/node.cpp index 8d23c8f65bbb7e..d600333e900c71 100644 --- a/ngraph/core/src/node.cpp +++ b/ngraph/core/src/node.cpp @@ -210,7 +210,7 @@ descriptor::Output& Node::get_output_descriptor(size_t position) make_shared(element::dynamic, PartialShape::dynamic(), this, i); m_outputs.emplace_back(this, i, tensor_descriptor); } - return m_outputs.at(position); + return m_outputs[position]; } void Node::set_argument(size_t position, const Output& argument)