Skip to content

Commit

Permalink
Add support for SparseTensor in Constant op (openvinotoolkit#5871)
Browse files Browse the repository at this point in the history
  • Loading branch information
Ewa Tusień authored Jun 2, 2021
1 parent 34b02a5 commit c77a13a
Show file tree
Hide file tree
Showing 25 changed files with 1,643 additions and 0 deletions.
17 changes: 17 additions & 0 deletions ngraph/frontend/onnx_import/include/onnx_import/core/node.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ namespace ngraph
class Graph;
class Subgraph;
class Tensor;
class SparseTensor;
class Attribute;

class ONNX_IMPORTER_API Node
Expand Down Expand Up @@ -113,6 +114,10 @@ namespace ngraph
ONNX_IMPORTER_API Tensor Node::get_attribute_value(const std::string& name,
Tensor default_value) const;

template <>
ONNX_IMPORTER_API SparseTensor Node::get_attribute_value(const std::string& name,
SparseTensor default_value) const;

template <>
ONNX_IMPORTER_API Graph Node::get_attribute_value(const std::string& name,
Graph default_value) const;
Expand Down Expand Up @@ -147,6 +152,11 @@ namespace ngraph
Node::get_attribute_value(const std::string& name,
std::vector<Tensor> default_value) const;

template <>
ONNX_IMPORTER_API std::vector<SparseTensor>
Node::get_attribute_value(const std::string& name,
std::vector<SparseTensor> default_value) const;

template <>
ONNX_IMPORTER_API std::vector<Graph>
Node::get_attribute_value(const std::string& name,
Expand All @@ -170,6 +180,9 @@ namespace ngraph
template <>
ONNX_IMPORTER_API Tensor Node::get_attribute_value(const std::string& name) const;

template <>
ONNX_IMPORTER_API SparseTensor Node::get_attribute_value(const std::string& name) const;

template <>
ONNX_IMPORTER_API Subgraph Node::get_attribute_value(const std::string& name) const;

Expand Down Expand Up @@ -197,6 +210,10 @@ namespace ngraph
ONNX_IMPORTER_API std::vector<Tensor>
Node::get_attribute_value(const std::string& name) const;

template <>
ONNX_IMPORTER_API std::vector<SparseTensor>
Node::get_attribute_value(const std::string& name) const;

template <>
ONNX_IMPORTER_API std::vector<Graph>
Node::get_attribute_value(const std::string& name) const;
Expand Down
41 changes: 41 additions & 0 deletions ngraph/frontend/onnx_import/src/core/attribute.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#include <onnx/onnx_pb.h>

#include "core/sparse_tensor.hpp"
#include "core/tensor.hpp"
#include "ngraph/except.hpp"

Expand Down Expand Up @@ -231,6 +232,32 @@ namespace ngraph
}
}

template <>
inline SparseTensor get_value(const ONNX_NAMESPACE::AttributeProto& attribute)
{
if (attribute.type() !=
ONNX_NAMESPACE::AttributeProto_AttributeType_SPARSE_TENSOR)
{
throw error::attribute::InvalidData{attribute.type()};
}
return SparseTensor{attribute.sparse_tensor()};
}

template <>
inline std::vector<SparseTensor>
get_value(const ONNX_NAMESPACE::AttributeProto& attribute)
{
switch (attribute.type())
{
case ONNX_NAMESPACE::AttributeProto_AttributeType_SPARSE_TENSOR:
return {SparseTensor{attribute.sparse_tensor()}};
case ONNX_NAMESPACE::AttributeProto_AttributeType_SPARSE_TENSORS:
return {std::begin(attribute.sparse_tensors()),
std::end(attribute.sparse_tensors())};
default: throw error::attribute::InvalidData{attribute.type()};
}
}

} // namespace attribute

} // namespace detail
Expand All @@ -246,10 +273,12 @@ namespace ngraph
string = ONNX_NAMESPACE::AttributeProto_AttributeType_STRING,
tensor = ONNX_NAMESPACE::AttributeProto_AttributeType_TENSOR,
graph = ONNX_NAMESPACE::AttributeProto_AttributeType_GRAPH,
sparse_tensor = ONNX_NAMESPACE::AttributeProto_AttributeType_SPARSE_TENSOR,
float_point_array = ONNX_NAMESPACE::AttributeProto_AttributeType_FLOATS,
integer_array = ONNX_NAMESPACE::AttributeProto_AttributeType_INTS,
string_array = ONNX_NAMESPACE::AttributeProto_AttributeType_STRINGS,
tensor_array = ONNX_NAMESPACE::AttributeProto_AttributeType_TENSORS,
sparse_tensor_array = ONNX_NAMESPACE::AttributeProto_AttributeType_SPARSE_TENSORS,
graph_array = ONNX_NAMESPACE::AttributeProto_AttributeType_GRAPHS
};

Expand All @@ -269,6 +298,8 @@ namespace ngraph
Type get_type() const { return static_cast<Type>(m_attribute_proto->type()); }
bool is_tensor() const { return get_type() == Type::tensor; }
bool is_tensor_array() const { return get_type() == Type::tensor_array; }
bool is_sparse_tensor() const { return get_type() == Type::sparse_tensor; }
bool is_sparse_tensor_array() const { return get_type() == Type::sparse_tensor_array; }
bool is_float() const { return get_type() == Type::float_point; }
bool is_float_array() const { return get_type() == Type::float_point_array; }
bool is_integer() const { return get_type() == Type::integer; }
Expand All @@ -278,6 +309,10 @@ namespace ngraph
bool is_graph() const { return get_type() == Type::graph; }
bool is_graph_array() const { return get_type() == Type::graph_array; }
Tensor get_tensor() const { return Tensor{m_attribute_proto->t()}; }
SparseTensor get_sparse_tensor() const
{
return SparseTensor{m_attribute_proto->sparse_tensor()};
}
float get_float() const { return m_attribute_proto->f(); }
int64_t get_integer() const { return m_attribute_proto->i(); }
const std::string& get_string() const { return m_attribute_proto->s(); }
Expand All @@ -291,6 +326,12 @@ namespace ngraph
std::end(m_attribute_proto->tensors())};
}

std::vector<SparseTensor> get_sparse_tensor_array() const
{
return {std::begin(m_attribute_proto->sparse_tensors()),
std::end(m_attribute_proto->sparse_tensors())};
}

std::vector<float> get_float_array() const
{
return {std::begin(m_attribute_proto->floats()),
Expand Down
29 changes: 29 additions & 0 deletions ngraph/frontend/onnx_import/src/core/node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,14 @@ namespace ngraph
return m_pimpl->template get_attribute_value<Tensor>(name, std::move(default_value));
}

template <>
SparseTensor Node::get_attribute_value(const std::string& name,
SparseTensor default_value) const
{
return m_pimpl->template get_attribute_value<SparseTensor>(name,
std::move(default_value));
}

template <>
Graph Node::get_attribute_value(const std::string& name, Graph default_value) const
{
Expand Down Expand Up @@ -342,6 +350,15 @@ namespace ngraph
name, std::move(default_value));
}

template <>
std::vector<SparseTensor>
Node::get_attribute_value(const std::string& name,
std::vector<SparseTensor> default_value) const
{
return m_pimpl->template get_attribute_value<std::vector<SparseTensor>>(
name, std::move(default_value));
}

template <>
std::vector<Graph> Node::get_attribute_value(const std::string& name,
std::vector<Graph> default_value) const
Expand Down Expand Up @@ -386,6 +403,12 @@ namespace ngraph
return m_pimpl->template get_attribute_value<Tensor>(name);
}

template <>
SparseTensor Node::get_attribute_value(const std::string& name) const
{
return m_pimpl->template get_attribute_value<SparseTensor>(name);
}

template <>
Subgraph Node::get_attribute_value(const std::string& name) const
{
Expand Down Expand Up @@ -428,6 +451,12 @@ namespace ngraph
return m_pimpl->template get_attribute_value<std::vector<Tensor>>(name);
}

template <>
std::vector<SparseTensor> Node::get_attribute_value(const std::string& name) const
{
return m_pimpl->template get_attribute_value<std::vector<SparseTensor>>(name);
}

template <>
std::vector<Graph> Node::get_attribute_value(const std::string& name) const
{
Expand Down
65 changes: 65 additions & 0 deletions ngraph/frontend/onnx_import/src/core/sparse_tensor.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <onnx/onnx_pb.h>
#include <vector>

#include "ngraph/shape.hpp"
#include "ngraph/type/element_type.hpp"
#include "tensor.hpp"

namespace ngraph
{
namespace onnx_import
{
class SparseTensor
{
public:
SparseTensor() = delete;
explicit SparseTensor(const ONNX_NAMESPACE::SparseTensorProto& sparse_tensor)
: m_sparse_tensor_proto{&sparse_tensor}
, m_values{sparse_tensor.values()}
, m_indices{sparse_tensor.indices()}
, m_shape{std::begin(sparse_tensor.dims()), std::end(sparse_tensor.dims())}
{
if (m_shape == Shape{0})
{
// It's possible to construct a sparse tensor in ONNX with "dims: 0" property
// Such tensor contains a scalar. This results in a Shape{0} stored in m_shape.
// In nGraph a scalar is represented with Shape{} and thus this replacement.
m_shape = Shape{};
}
}

SparseTensor(const SparseTensor&) = default;
SparseTensor(SparseTensor&&) = default;

SparseTensor& operator=(const SparseTensor&) = delete;
SparseTensor& operator=(SparseTensor&&) = delete;

const Shape& get_shape() const { return m_shape; }

const std::string& get_name() const { return m_values.get_name(); }

const Tensor& get_values() const { return m_values; }

const Tensor& get_indices() const { return m_indices; }

const element::Type& get_ng_type() const { return m_values.get_ng_type(); }

private:
const ONNX_NAMESPACE::SparseTensorProto* m_sparse_tensor_proto;
Tensor m_values;
Tensor m_indices;
Shape m_shape;
};

inline std::ostream& operator<<(std::ostream& outs, const SparseTensor& tensor)
{
return (outs << "<Sparse Tensor>");
}
} // namespace onnx_import
} // namespace ngraph
Loading

0 comments on commit c77a13a

Please sign in to comment.