Skip to content

Commit

Permalink
MultiSubgraph in nGraph (#6621)
Browse files Browse the repository at this point in the history
* Add multisubgraph

* Fix format

* Fix clang format

* Fix TensorIterator RTT

* Fix subgraph

* Fix codestyle

* Fix comments

* Fix comments

* Fix coments

* Fix comments

* delete get function

* fix methods

* fix ci

* Fix ci

* fix bugs

* Fix cmake

* Fix ci

* delete virtual function

* delete virtual function

* fix ci

* Fix ci
  • Loading branch information
evolosen authored Jul 23, 2021
1 parent 22b9431 commit f328eec
Show file tree
Hide file tree
Showing 7 changed files with 701 additions and 429 deletions.
6 changes: 2 additions & 4 deletions ngraph/core/include/ngraph/op/tensor_iterator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,11 @@ namespace ngraph
std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
/// \return the body of the iteration
std::shared_ptr<Function> get_body() const { return m_body; }
std::shared_ptr<Function> get_body() const { return m_bodies[0]; }
/// \param body set the body of the iteration
void set_body(const std::shared_ptr<Function>& body) { m_body = body; }
void set_body(const std::shared_ptr<Function>& body) { set_function(body); }
void validate_and_infer_types() override;
void revalidate_and_infer_types_for_body_ops();
/// \return the body of the iteration
std::shared_ptr<Function> get_function() override;

private:
void try_to_set_num_iterations_if_no_slice_inputs();
Expand Down
366 changes: 366 additions & 0 deletions ngraph/core/include/ngraph/op/util/multi_subgraph_base.hpp

Large diffs are not rendered by default.

259 changes: 17 additions & 242 deletions ngraph/core/include/ngraph/op/util/sub_graph_base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,234 +5,54 @@
#pragma once

#include <ngraph/op/parameter.hpp>
#include "ngraph/op/op.hpp"
#include "ngraph/op/util/multi_subgraph_base.hpp"

namespace ngraph
{
namespace op
{
namespace util
{
/// \brief Abstract base class for sub-graph based ops, i.e ops that have sub-graph
/// \brief Abstract base class for sub-graph based ops, i.e ops that have only one
/// sub-graph
///
class NGRAPH_API SubGraphOp : public Op
class NGRAPH_API SubGraphOp : public MultiSubGraphOp
{
public:
NGRAPH_RTTI_DECLARATION;
/// \brief Describes a connection between a SubGraphOp input and the body.
class InputDescription
{
protected:
///
/// \brief Constructs a new instance.
///
/// \param input_index Position of the SubGraphOp input
/// \param body_parameter_index Body parameter to receive input
///
InputDescription(uint64_t input_index, uint64_t body_parameter_index);
InputDescription() = default;

public:
using type_info_t = DiscreteTypeInfo;
virtual ~InputDescription() = default;
virtual std::shared_ptr<InputDescription> copy() const = 0;

virtual const type_info_t& get_type_info() const = 0;

uint64_t m_input_index{0};
uint64_t m_body_parameter_index{0};
};

///
/// \brief Describes a body input formed from slices of an input to
/// SubGraphOp.
///
class NGRAPH_API SliceInputDescription : public InputDescription
virtual const std::shared_ptr<Function>& get_function() const
{
public:
static constexpr type_info_t type_info{"SliceInputDescription", 0};
const type_info_t& get_type_info() const override { return type_info; }
///
/// \brief Constructs a new instance.
///
/// \param input_index Position of the SubGraphOp input
/// \param body_parameter_index Body parameter position to receive input
/// \param start First index for slices
/// \param stride Step amount for slices
/// \param part_size Width of slices
/// \param end Last index for slices
/// \param axis Axis being sliced
///
SliceInputDescription(uint64_t input_index,
uint64_t body_parameter_index,
int64_t start,
int64_t stride,
int64_t part_size,
int64_t end,
int64_t axis);
SliceInputDescription() = default;
std::shared_ptr<InputDescription> copy() const override;
int64_t m_start{0};
int64_t m_stride{0};
int64_t m_part_size{0};
int64_t m_end{0};
int64_t m_axis{0};
return m_bodies[0];
};

///
/// \brief Describes a body input initialized from a SubGraphOp input on
/// the first iteration, and then a body output thereafter.
///
class NGRAPH_API MergedInputDescription : public InputDescription
virtual void set_function(const std::shared_ptr<Function>& func)
{
public:
static constexpr type_info_t type_info{"MergedInputDescription", 0};
const type_info_t& get_type_info() const override { return type_info; }
///
/// \brief Constructs a new instance.
///
/// \param input_index Position of the SubGraphOp input
/// supplying a value to body_parameter for
/// the initial iteration.
/// \param body_parameter_index Body parameter position to receive input.
/// \param body_value_index Body value to supply body_parameter for
/// successive
/// iterations.
///
MergedInputDescription(uint64_t input_index,
uint64_t body_parameter_index,
uint64_t body_value_index);
MergedInputDescription() = default;
std::shared_ptr<InputDescription> copy() const override;
uint64_t m_body_value_index{0};
m_bodies[0] = func;
};

///
/// \brief Describes a body input initialized from a SubGraphOp input on
/// the first iteration, and invariant thereafter.
///
class NGRAPH_API InvariantInputDescription : public InputDescription
{
public:
static constexpr type_info_t type_info{"InvariantInputDescription", 0};
const type_info_t& get_type_info() const override { return type_info; }
///
/// \brief Constructs a new instance.
///
/// \param input_index Position of the SubGraphOp input
/// \param body_parameter_index Body parameter to receive input
///
InvariantInputDescription(uint64_t input_index, uint64_t body_parameter_index);
InvariantInputDescription() = default;
std::shared_ptr<InputDescription> copy() const override;
};

/// \brief Describes how a SubGraphOp output is produced from the body.
class OutputDescription
{
protected:
///
/// \brief Constructs a new instance.
///
/// \param body_value_index A body value that produces the output
/// \param output_index The SubGraphOp output index
///
OutputDescription(uint64_t body_value_index, uint64_t output_index);
OutputDescription() = default;

public:
using type_info_t = DiscreteTypeInfo;
virtual ~OutputDescription() = default;
virtual std::shared_ptr<OutputDescription> copy() const = 0;
virtual const type_info_t& get_type_info() const = 0;

uint64_t m_body_value_index{0};
uint64_t m_output_index{0};
};

/// \brief Produces an output by concatenating an output from each iteration
class NGRAPH_API ConcatOutputDescription : public OutputDescription
{
public:
static constexpr type_info_t type_info{"ConcatOutputDescription", 0};
const type_info_t& get_type_info() const override { return type_info; }
///
/// \brief Constructs a new instance.
///
/// \param body_value_index A body value that produces the output
/// \param output_index The SubGraphOp output index
/// \param start First index for slices
/// \param stride Step amount for slices
/// \param part_size Width of slices
/// \param end Last index for slices
/// \param axis Axis being sliced
///
ConcatOutputDescription(uint64_t body_value_index,
uint64_t output_index,
int64_t start,
int64_t stride,
int64_t part_size,
int64_t end,
int64_t axis);
ConcatOutputDescription() = default;

std::shared_ptr<OutputDescription> copy() const override;
int64_t m_start{0};
int64_t m_stride{0};
int64_t m_part_size{0};
int64_t m_end{0};
int64_t m_axis{0};
};

/// \brief Produces an output from a specific iteration
class NGRAPH_API BodyOutputDescription : public OutputDescription
{
public:
static constexpr type_info_t type_info{"BodyOutputDescription", 0};
const type_info_t& get_type_info() const override { return type_info; }
///
/// \brief Constructs a new instance.
///
/// \param body_value_index A body value that produces the output
/// \param output_index The SubGraphOp output index
/// \param iteration which iteration (typically -1, final) will
/// supply the value
///
BodyOutputDescription(uint64_t body_value_index,
uint64_t output_index,
int64_t iteration);
BodyOutputDescription() = default;
std::shared_ptr<OutputDescription> copy() const override;
int64_t m_iteration{0};
};

virtual std::shared_ptr<Function> get_function() { return m_body; };
virtual std::shared_ptr<const Function> get_function() const { return m_body; };
virtual void set_function(const std::shared_ptr<Function>& func) { m_body = func; };
/// \return a reference to the input descriptions.
const std::vector<std::shared_ptr<InputDescription>>& get_input_descriptions() const
{
return m_input_descriptions;
return m_input_descriptions[0];
}
/// \return a reference to the input descriptions. Can add input descriptions
/// before
/// validation.
std::vector<std::shared_ptr<InputDescription>>& get_input_descriptions()
{
return m_input_descriptions;
return m_input_descriptions[0];
}
/// \return a reference to the output descriptions.
const std::vector<std::shared_ptr<OutputDescription>>&
get_output_descriptions() const
{
return m_output_descriptions;
return m_output_descriptions[0];
}
/// \return a reference to the output descriptions. Can add output descriptions
/// before
/// validation.
std::vector<std::shared_ptr<OutputDescription>>& get_output_descriptions()
{
return m_output_descriptions;
return m_output_descriptions[0];
}

///
Expand Down Expand Up @@ -324,15 +144,13 @@ namespace ngraph
// Find an input corresponding to value, adding one if necessary.
Input<Node> input_for_value(const Output<Node>& value);

SubGraphOp() = default;

SubGraphOp();
explicit SubGraphOp(const OutputVector& args);

std::shared_ptr<Function> m_body;
std::vector<std::shared_ptr<op::util::SubGraphOp::InputDescription>>
m_input_descriptions;
std::vector<std::shared_ptr<op::util::SubGraphOp::OutputDescription>>
m_output_descriptions;
private:
using MultiSubGraphOp::get_function;

using MultiSubGraphOp::set_function;
};
using InputDescriptionPtr = std::shared_ptr<util::SubGraphOp::InputDescription>;
using OutputDescriptionPtr = std::shared_ptr<util::SubGraphOp::OutputDescription>;
Expand All @@ -341,47 +159,4 @@ namespace ngraph
} // namespace util
} // namespace op

template <>
class NGRAPH_API AttributeAdapter<
std::vector<std::shared_ptr<ngraph::op::util::SubGraphOp::InputDescription>>>
: public DirectValueAccessor<
std::vector<std::shared_ptr<ngraph::op::util::SubGraphOp::InputDescription>>>
{
public:
AttributeAdapter(
std::vector<std::shared_ptr<ngraph::op::util::SubGraphOp::InputDescription>>& value)
: DirectValueAccessor<
std::vector<std::shared_ptr<ngraph::op::util::SubGraphOp::InputDescription>>>(
value)
{
}

static constexpr DiscreteTypeInfo type_info{
"AttributeAdapter<std::vector<std::shared_ptr<ngraph::op::util::SubGraphOp::"
"InputDescription>>>",
0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
};

template <>
class NGRAPH_API AttributeAdapter<
std::vector<std::shared_ptr<ngraph::op::util::SubGraphOp::OutputDescription>>>
: public DirectValueAccessor<
std::vector<std::shared_ptr<ngraph::op::util::SubGraphOp::OutputDescription>>>
{
public:
AttributeAdapter(
std::vector<std::shared_ptr<ngraph::op::util::SubGraphOp::OutputDescription>>& value)
: DirectValueAccessor<
std::vector<std::shared_ptr<ngraph::op::util::SubGraphOp::OutputDescription>>>(
value)
{
}

static constexpr DiscreteTypeInfo type_info{
"AttributeAdapter<std::vector<std::shared_ptr<ngraph::op::util::SubGraphOp::"
"OutputDescription>>>",
0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
};
} // namespace ngraph
Loading

0 comments on commit f328eec

Please sign in to comment.