Skip to content

Commit

Permalink
set_output_type speedup (#6754)
Browse files Browse the repository at this point in the history
* set_output_type speedup

* style

* Final optimization

* Removed extra include, removed unnecessary lock_guard

* Typo

* Apply suggestions from code review

Co-authored-by: Mikhail Nosov <[email protected]>

* Update ngraph/core/include/ngraph/descriptor/tensor.hpp

Co-authored-by: Mikhail Nosov <[email protected]>

Co-authored-by: Mikhail Nosov <[email protected]>
  • Loading branch information
Evgenya Stepyreva and nosovmik authored Jul 26, 2021
1 parent feb1eae commit b907c3b
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 22 deletions.
26 changes: 18 additions & 8 deletions ngraph/core/include/ngraph/descriptor/tensor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@

#pragma once

#include <atomic>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_set>

Expand Down Expand Up @@ -74,16 +76,24 @@ namespace ngraph
protected:
element::Type m_element_type;

// TODO(amprocte): For now we are maintaining both m_shape and m_partial_shape fields,
// with m_shape possibly being invalid (get_shape will throw an exception if it
// is). This is because get_shape() returns a const reference. I think ideally we
// should refactor so that get_shape returns by value.
Shape m_shape;
// TODO: remove along with get_shape
// Initially there was ngraph::Shape m_shape only available to keep shape information.
// Support for dynamic shapes required transition to ngraph::PartialShape.
// To smoothly transition to ngraph::PartialShape we introduced m_partial_shape
// and kept m_shape in sync with m_partial_shape. Synchronization point was placed
// in set_partial_shape which dramatically affected performance of ngraph::Function
// validation. Since we have started the transition to ngraph::PartialShape and reduced
// ngraph::Shape usage the only user of m_shape was get_shape method with signature:
// const Shape& descriptor::Tensor::get_shape() const
// It was decided to move m_shape and m_partial_shape synchronization point there and
// to keep methods signature backward compatible.
mutable std::mutex shape_mutex;
mutable std::atomic_bool m_shape_changed;
mutable Shape m_shape;
// TODO: end

PartialShape m_partial_shape;
Node* m_node{nullptr};
HostTensorPtr m_lower_value, m_upper_value;
size_t m_node_output_number{0};

std::string m_name;
std::unordered_set<std::string> m_names;
};
Expand Down
25 changes: 12 additions & 13 deletions ngraph/core/src/descriptor/tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

#include "ngraph/descriptor/tensor.hpp"
#include "ngraph/node.hpp"
#include "ngraph/runtime/host_tensor.hpp"

using namespace ngraph;
using namespace std;
Expand All @@ -13,9 +12,9 @@ descriptor::Tensor::Tensor(const element::Type& element_type,
const PartialShape& pshape,
const std::string& name)
: m_element_type(element_type)
, m_shape(pshape.is_static() ? pshape.to_shape() : Shape{})
, m_partial_shape(pshape)
, m_name(name)
, m_shape_changed(true)
{
}

Expand All @@ -24,10 +23,8 @@ descriptor::Tensor::Tensor(const element::Type& element_type,
Node* node,
size_t node_output_number)
: m_element_type(element_type)
, m_shape(pshape.is_static() ? pshape.to_shape() : Shape{})
, m_partial_shape(pshape)
, m_node(node)
, m_node_output_number(node_output_number)
, m_shape_changed(true)
{
}

Expand All @@ -46,14 +43,7 @@ void descriptor::Tensor::set_element_type(const element::Type& element_type)
void descriptor::Tensor::set_partial_shape(const PartialShape& partial_shape)
{
m_partial_shape = partial_shape;
if (m_partial_shape.is_static())
{
m_shape = m_partial_shape.to_shape();
}
else
{
m_shape = Shape{};
}
m_shape_changed = true;
}

void descriptor::Tensor::invalidate_values()
Expand Down Expand Up @@ -82,6 +72,15 @@ const Shape& descriptor::Tensor::get_shape() const
{
if (m_partial_shape.is_static())
{
if (m_shape_changed.load(std::memory_order_relaxed))
{
std::lock_guard<std::mutex> guard(shape_mutex);
if (m_shape_changed) // double check after mutex lock
{
m_shape = m_partial_shape.to_shape();
m_shape_changed = false;
}
}
return m_shape;
}
else
Expand Down
2 changes: 1 addition & 1 deletion ngraph/core/src/node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ descriptor::Output& Node::get_output_descriptor(size_t position)
make_shared<descriptor::Tensor>(element::dynamic, PartialShape::dynamic(), this, i);
m_outputs.emplace_back(this, i, tensor_descriptor);
}
return m_outputs.at(position);
return m_outputs[position];
}

void Node::set_argument(size_t position, const Output<Node>& argument)
Expand Down

0 comments on commit b907c3b

Please sign in to comment.