Skip to content

Commit

Permalink
set_output_type speedup
Browse files Browse the repository at this point in the history
  • Loading branch information
Stepyreva, Evgenya committed Jul 23, 2021
1 parent 12fb83d commit f1614b9
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 23 deletions.
14 changes: 5 additions & 9 deletions ngraph/core/include/ngraph/descriptor/tensor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@

#pragma once

#include <mutex>
#include <memory>
#include <string>
#include <unordered_set>
#include <atomic>

#include "ngraph/partial_shape.hpp"
#include "ngraph/shape.hpp"
Expand Down Expand Up @@ -73,17 +75,11 @@ namespace ngraph

protected:
element::Type m_element_type;

// TODO(amprocte): For now we are maintaining both m_shape and m_partial_shape fields,
// with m_shape possibly being invalid (get_shape will throw an exception if it
// is). This is because get_shape() returns a const reference. I think ideally we
// should refactor so that get_shape returns by value.
Shape m_shape;
mutable std::atomic<bool> shape_changed;
mutable std::mutex shape_mutex;
mutable Shape m_shape; // TODO: remove along with get_shape
PartialShape m_partial_shape;
Node* m_node{nullptr};
HostTensorPtr m_lower_value, m_upper_value;
size_t m_node_output_number{0};

std::string m_name;
std::unordered_set<std::string> m_names;
};
Expand Down
22 changes: 9 additions & 13 deletions ngraph/core/src/descriptor/tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

#include "ngraph/descriptor/tensor.hpp"
#include "ngraph/node.hpp"
#include "ngraph/runtime/host_tensor.hpp"

using namespace ngraph;
using namespace std;
Expand All @@ -13,9 +12,9 @@ descriptor::Tensor::Tensor(const element::Type& element_type,
const PartialShape& pshape,
const std::string& name)
: m_element_type(element_type)
, m_shape(pshape.is_static() ? pshape.to_shape() : Shape{})
, m_partial_shape(pshape)
, m_name(name)
, shape_changed(true)
{
}

Expand All @@ -24,10 +23,8 @@ descriptor::Tensor::Tensor(const element::Type& element_type,
Node* node,
size_t node_output_number)
: m_element_type(element_type)
, m_shape(pshape.is_static() ? pshape.to_shape() : Shape{})
, m_partial_shape(pshape)
, m_node(node)
, m_node_output_number(node_output_number)
, shape_changed(true)
{
}

Expand All @@ -46,14 +43,7 @@ void descriptor::Tensor::set_element_type(const element::Type& element_type)
void descriptor::Tensor::set_partial_shape(const PartialShape& partial_shape)
{
m_partial_shape = partial_shape;
if (m_partial_shape.is_static())
{
m_shape = m_partial_shape.to_shape();
}
else
{
m_shape = Shape{};
}
shape_changed = true;
}

void descriptor::Tensor::invalidate_values()
Expand Down Expand Up @@ -82,6 +72,12 @@ const Shape& descriptor::Tensor::get_shape() const
{
if (m_partial_shape.is_static())
{
if (shape_changed) {
shape_mutex.lock();
m_shape = m_partial_shape.to_shape();
shape_mutex.unlock();
shape_changed = false;
}
return m_shape;
}
else
Expand Down
2 changes: 1 addition & 1 deletion ngraph/core/src/node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ descriptor::Output& Node::get_output_descriptor(size_t position)
make_shared<descriptor::Tensor>(element::dynamic, PartialShape::dynamic(), this, i);
m_outputs.emplace_back(this, i, tensor_descriptor);
}
return m_outputs.at(position);
return m_outputs[position];
}

void Node::set_argument(size_t position, const Output<Node>& argument)
Expand Down

0 comments on commit f1614b9

Please sign in to comment.