Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MultiSubgraph in nGraph #6621

Merged
merged 34 commits into from
Jul 23, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
56b755e
Merge pull request #1 from openvinotoolkit/master
evolosen Dec 3, 2020
fbfe44c
Merge pull request #2 from openvinotoolkit/master
evolosen Jan 18, 2021
0cf84d4
Merge pull request #3 from openvinotoolkit/master
evolosen Jan 28, 2021
0df8350
Merge pull request #4 from openvinotoolkit/master
evolosen Feb 5, 2021
c299e1d
Merge pull request #5 from openvinotoolkit/master
evolosen Feb 15, 2021
f4f54ac
Merge pull request #6 from openvinotoolkit/master
evolosen Feb 26, 2021
8feeba4
Update forked branch
evolosen Mar 1, 2021
fb93329
Merge pull request #8 from openvinotoolkit/master
evolosen Mar 29, 2021
2a19318
Merge pull request #9 from openvinotoolkit/master
evolosen Apr 13, 2021
c09041d
Update forked branch
evolosen Apr 26, 2021
e9f256e
Merge remote-tracking branch 'upstream/master'
evolosen Jul 13, 2021
c0f3a78
Add multisubgraph
evolosen Jul 13, 2021
050f232
Fix format
evolosen Jul 13, 2021
2606b0a
Fix clang format
evolosen Jul 13, 2021
81a0c98
Fix TensorIterator RTT
evolosen Jul 13, 2021
fe358ed
Fix subgraph
evolosen Jul 13, 2021
fcec7c4
Fix codestyle
evolosen Jul 13, 2021
8af90a9
Fix comments
evolosen Jul 16, 2021
aedbcff
Fix comments
evolosen Jul 16, 2021
7fa0e14
Fix coments
evolosen Jul 19, 2021
b6989b6
Fix comments
evolosen Jul 19, 2021
de03f26
delete get function
evolosen Jul 20, 2021
dbac5ed
fix methods
evolosen Jul 20, 2021
12cca12
fix ci
evolosen Jul 21, 2021
1cfbdcf
Merge remote-tracking branch 'upstream/master' into ngraph/myltisubgraph
evolosen Jul 21, 2021
d22ed76
Fix ci
evolosen Jul 21, 2021
4538737
fix bugs
evolosen Jul 21, 2021
cc29349
Merge remote-tracking branch 'upstream/master' into ngraph/myltisubgraph
evolosen Jul 21, 2021
572d3eb
Fix cmake
evolosen Jul 21, 2021
728f8d9
Fix ci
evolosen Jul 21, 2021
847ed45
delete virtual function
evolosen Jul 22, 2021
f5d5ad4
delete virtual function
evolosen Jul 22, 2021
23a59e1
fix ci
evolosen Jul 22, 2021
3d26e89
Fix ci
evolosen Jul 22, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions ngraph/core/include/ngraph/op/tensor_iterator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,11 @@ namespace ngraph
std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
/// \return the body of the iteration
std::shared_ptr<Function> get_body() const { return m_body; }
std::shared_ptr<Function> get_body() const { return m_bodies[0]; }
evolosen marked this conversation as resolved.
Show resolved Hide resolved
evolosen marked this conversation as resolved.
Show resolved Hide resolved
/// \param body set the body of the iteration
void set_body(const std::shared_ptr<Function>& body) { m_body = body; }
void set_body(const std::shared_ptr<Function>& body) { set_function(body); }
void validate_and_infer_types() override;
void revalidate_and_infer_types_for_body_ops();
/// \return the body of the iteration
std::shared_ptr<Function> get_function() override;

private:
void try_to_set_num_iterations_if_no_slice_inputs();
Expand Down
366 changes: 366 additions & 0 deletions ngraph/core/include/ngraph/op/util/multi_subgraph_base.hpp

Large diffs are not rendered by default.

259 changes: 17 additions & 242 deletions ngraph/core/include/ngraph/op/util/sub_graph_base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,234 +5,54 @@
#pragma once

#include <ngraph/op/parameter.hpp>
#include "ngraph/op/op.hpp"
#include "ngraph/op/util/multi_subgraph_base.hpp"

namespace ngraph
{
namespace op
{
namespace util
{
/// \brief Abstract base class for sub-graph based ops, i.e ops that have sub-graph
/// \brief Abstract base class for sub-graph based ops, i.e ops that have only one
/// sub-graph
///
class NGRAPH_API SubGraphOp : public Op
class NGRAPH_API SubGraphOp : public MultiSubGraphOp
{
public:
NGRAPH_RTTI_DECLARATION;
/// \brief Describes a connection between a SubGraphOp input and the body.
class InputDescription
{
protected:
///
/// \brief Constructs a new instance.
///
/// \param input_index Position of the SubGraphOp input
/// \param body_parameter_index Body parameter to receive input
///
InputDescription(uint64_t input_index, uint64_t body_parameter_index);
InputDescription() = default;

public:
using type_info_t = DiscreteTypeInfo;
virtual ~InputDescription() = default;
virtual std::shared_ptr<InputDescription> copy() const = 0;

virtual const type_info_t& get_type_info() const = 0;

uint64_t m_input_index{0};
uint64_t m_body_parameter_index{0};
};

///
/// \brief Describes a body input formed from slices of an input to
/// SubGraphOp.
///
class NGRAPH_API SliceInputDescription : public InputDescription
virtual const std::shared_ptr<Function>& get_function() const
{
public:
static constexpr type_info_t type_info{"SliceInputDescription", 0};
const type_info_t& get_type_info() const override { return type_info; }
///
/// \brief Constructs a new instance.
///
/// \param input_index Position of the SubGraphOp input
/// \param body_parameter_index Body parameter position to receive input
/// \param start First index for slices
/// \param stride Step amount for slices
/// \param part_size Width of slices
/// \param end Last index for slices
/// \param axis Axis being sliced
///
SliceInputDescription(uint64_t input_index,
uint64_t body_parameter_index,
int64_t start,
int64_t stride,
int64_t part_size,
int64_t end,
int64_t axis);
SliceInputDescription() = default;
std::shared_ptr<InputDescription> copy() const override;
int64_t m_start{0};
int64_t m_stride{0};
int64_t m_part_size{0};
int64_t m_end{0};
int64_t m_axis{0};
return m_bodies[0];
};

///
/// \brief Describes a body input initialized from a SubGraphOp input on
/// the first iteration, and then a body output thereafter.
///
class NGRAPH_API MergedInputDescription : public InputDescription
virtual void set_function(const std::shared_ptr<Function>& func)
{
public:
static constexpr type_info_t type_info{"MergedInputDescription", 0};
const type_info_t& get_type_info() const override { return type_info; }
///
/// \brief Constructs a new instance.
///
/// \param input_index Position of the SubGraphOp input
/// supplying a value to body_parameter for
/// the initial iteration.
/// \param body_parameter_index Body parameter position to receive input.
/// \param body_value_index Body value to supply body_parameter for
/// successive
/// iterations.
///
MergedInputDescription(uint64_t input_index,
uint64_t body_parameter_index,
uint64_t body_value_index);
MergedInputDescription() = default;
std::shared_ptr<InputDescription> copy() const override;
uint64_t m_body_value_index{0};
m_bodies[0] = func;
};

///
/// \brief Describes a body input initialized from a SubGraphOp input on
/// the first iteration, and invariant thereafter.
///
class NGRAPH_API InvariantInputDescription : public InputDescription
{
public:
static constexpr type_info_t type_info{"InvariantInputDescription", 0};
const type_info_t& get_type_info() const override { return type_info; }
///
/// \brief Constructs a new instance.
///
/// \param input_index Position of the SubGraphOp input
/// \param body_parameter_index Body parameter to receive input
///
InvariantInputDescription(uint64_t input_index, uint64_t body_parameter_index);
InvariantInputDescription() = default;
std::shared_ptr<InputDescription> copy() const override;
};

/// \brief Describes how a SubGraphOp output is produced from the body.
class OutputDescription
{
protected:
///
/// \brief Constructs a new instance.
///
/// \param body_value_index A body value that produces the output
/// \param output_index The SubGraphOp output index
///
OutputDescription(uint64_t body_value_index, uint64_t output_index);
OutputDescription() = default;

public:
using type_info_t = DiscreteTypeInfo;
virtual ~OutputDescription() = default;
virtual std::shared_ptr<OutputDescription> copy() const = 0;
virtual const type_info_t& get_type_info() const = 0;

uint64_t m_body_value_index{0};
uint64_t m_output_index{0};
};

/// \brief Produces an output by concatenating an output from each iteration
class NGRAPH_API ConcatOutputDescription : public OutputDescription
{
public:
static constexpr type_info_t type_info{"ConcatOutputDescription", 0};
const type_info_t& get_type_info() const override { return type_info; }
///
/// \brief Constructs a new instance.
///
/// \param body_value_index A body value that produces the output
/// \param output_index The SubGraphOp output index
/// \param start First index for slices
/// \param stride Step amount for slices
/// \param part_size Width of slices
/// \param end Last index for slices
/// \param axis Axis being sliced
///
ConcatOutputDescription(uint64_t body_value_index,
uint64_t output_index,
int64_t start,
int64_t stride,
int64_t part_size,
int64_t end,
int64_t axis);
ConcatOutputDescription() = default;

std::shared_ptr<OutputDescription> copy() const override;
int64_t m_start{0};
int64_t m_stride{0};
int64_t m_part_size{0};
int64_t m_end{0};
int64_t m_axis{0};
};

/// \brief Produces an output from a specific iteration
class NGRAPH_API BodyOutputDescription : public OutputDescription
{
public:
static constexpr type_info_t type_info{"BodyOutputDescription", 0};
const type_info_t& get_type_info() const override { return type_info; }
///
/// \brief Constructs a new instance.
///
/// \param body_value_index A body value that produces the output
/// \param output_index The SubGraphOp output index
/// \param iteration which iteration (typically -1, final) will
/// supply the value
///
BodyOutputDescription(uint64_t body_value_index,
uint64_t output_index,
int64_t iteration);
BodyOutputDescription() = default;
std::shared_ptr<OutputDescription> copy() const override;
int64_t m_iteration{0};
};

virtual std::shared_ptr<Function> get_function() { return m_body; };
virtual std::shared_ptr<const Function> get_function() const { return m_body; };
virtual void set_function(const std::shared_ptr<Function>& func) { m_body = func; };
/// \return a reference to the input descriptions.
const std::vector<std::shared_ptr<InputDescription>>& get_input_descriptions() const
{
return m_input_descriptions;
return m_input_descriptions[0];
}
/// \return a reference to the input descriptions. Can add input descriptions
/// before
/// validation.
std::vector<std::shared_ptr<InputDescription>>& get_input_descriptions()
{
return m_input_descriptions;
return m_input_descriptions[0];
}
/// \return a reference to the output descriptions.
const std::vector<std::shared_ptr<OutputDescription>>&
get_output_descriptions() const
{
return m_output_descriptions;
return m_output_descriptions[0];
}
/// \return a reference to the output descriptions. Can add output descriptions
/// before
/// validation.
std::vector<std::shared_ptr<OutputDescription>>& get_output_descriptions()
{
return m_output_descriptions;
return m_output_descriptions[0];
}

///
Expand Down Expand Up @@ -324,15 +144,13 @@ namespace ngraph
// Find an input corresponding to value, adding one if necessary.
Input<Node> input_for_value(const Output<Node>& value);

SubGraphOp() = default;

SubGraphOp();
explicit SubGraphOp(const OutputVector& args);

std::shared_ptr<Function> m_body;
std::vector<std::shared_ptr<op::util::SubGraphOp::InputDescription>>
m_input_descriptions;
std::vector<std::shared_ptr<op::util::SubGraphOp::OutputDescription>>
m_output_descriptions;
private:
using MultiSubGraphOp::get_function;

using MultiSubGraphOp::set_function;
};
using InputDescriptionPtr = std::shared_ptr<util::SubGraphOp::InputDescription>;
using OutputDescriptionPtr = std::shared_ptr<util::SubGraphOp::OutputDescription>;
Expand All @@ -341,47 +159,4 @@ namespace ngraph
} // namespace util
} // namespace op

template <>
class NGRAPH_API AttributeAdapter<
std::vector<std::shared_ptr<ngraph::op::util::SubGraphOp::InputDescription>>>
: public DirectValueAccessor<
std::vector<std::shared_ptr<ngraph::op::util::SubGraphOp::InputDescription>>>
{
public:
AttributeAdapter(
std::vector<std::shared_ptr<ngraph::op::util::SubGraphOp::InputDescription>>& value)
: DirectValueAccessor<
std::vector<std::shared_ptr<ngraph::op::util::SubGraphOp::InputDescription>>>(
value)
{
}

static constexpr DiscreteTypeInfo type_info{
"AttributeAdapter<std::vector<std::shared_ptr<ngraph::op::util::SubGraphOp::"
"InputDescription>>>",
0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
};

template <>
class NGRAPH_API AttributeAdapter<
std::vector<std::shared_ptr<ngraph::op::util::SubGraphOp::OutputDescription>>>
: public DirectValueAccessor<
std::vector<std::shared_ptr<ngraph::op::util::SubGraphOp::OutputDescription>>>
{
public:
AttributeAdapter(
std::vector<std::shared_ptr<ngraph::op::util::SubGraphOp::OutputDescription>>& value)
: DirectValueAccessor<
std::vector<std::shared_ptr<ngraph::op::util::SubGraphOp::OutputDescription>>>(
value)
{
}

static constexpr DiscreteTypeInfo type_info{
"AttributeAdapter<std::vector<std::shared_ptr<ngraph::op::util::SubGraphOp::"
"OutputDescription>>>",
0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
};
} // namespace ngraph
Loading