From f328eeced62b13533ec55d110ef784b7e147726e Mon Sep 17 00:00:00 2001 From: Eugeny Volosenkov Date: Fri, 23 Jul 2021 07:03:25 +0300 Subject: [PATCH] MultiSubgraph in nGraph (#6621) * Add multisubgraph * Fix format * Fix clang format * Fix TensorIterator RTT * Fix subgraph * Fix codestyle * Fix comments * Fix comments * Fix coments * Fix comments * delete get function * fix methods * fix ci * Fix ci * fix bugs * Fix cmake * Fix ci * delete virtual function * delete virtual function * fix ci * Fix ci --- .../include/ngraph/op/tensor_iterator.hpp | 6 +- .../ngraph/op/util/multi_subgraph_base.hpp | 366 ++++++++++++++++++ .../include/ngraph/op/util/sub_graph_base.hpp | 259 +------------ ngraph/core/src/op/loop.cpp | 72 ++-- ngraph/core/src/op/tensor_iterator.cpp | 71 ++-- .../core/src/op/util/multi_subgraph_base.cpp | 210 ++++++++++ ngraph/core/src/op/util/sub_graph_base.cpp | 146 +------ 7 files changed, 701 insertions(+), 429 deletions(-) create mode 100644 ngraph/core/include/ngraph/op/util/multi_subgraph_base.hpp create mode 100644 ngraph/core/src/op/util/multi_subgraph_base.cpp diff --git a/ngraph/core/include/ngraph/op/tensor_iterator.hpp b/ngraph/core/include/ngraph/op/tensor_iterator.hpp index de44f4b638edeb..41e02d43ca94c9 100644 --- a/ngraph/core/include/ngraph/op/tensor_iterator.hpp +++ b/ngraph/core/include/ngraph/op/tensor_iterator.hpp @@ -30,13 +30,11 @@ namespace ngraph std::shared_ptr clone_with_new_inputs(const OutputVector& new_args) const override; /// \return the body of the iteration - std::shared_ptr get_body() const { return m_body; } + std::shared_ptr get_body() const { return m_bodies[0]; } /// \param body set the body of the iteration - void set_body(const std::shared_ptr& body) { m_body = body; } + void set_body(const std::shared_ptr& body) { set_function(body); } void validate_and_infer_types() override; void revalidate_and_infer_types_for_body_ops(); - /// \return the body of the iteration - std::shared_ptr get_function() override; private: void try_to_set_num_iterations_if_no_slice_inputs(); diff --git a/ngraph/core/include/ngraph/op/util/multi_subgraph_base.hpp b/ngraph/core/include/ngraph/op/util/multi_subgraph_base.hpp new file mode 100644 index 00000000000000..c50b98aa4c757a --- /dev/null +++ b/ngraph/core/include/ngraph/op/util/multi_subgraph_base.hpp @@ -0,0 +1,366 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include "ngraph/op/op.hpp" + +namespace ngraph +{ + namespace op + { + namespace util + { + /// \brief Abstract base class for sub-graph based ops, i.e ops that have some + /// sub-graphs + /// + class NGRAPH_API MultiSubGraphOp : public Op + { + public: + NGRAPH_RTTI_DECLARATION; + /// \brief Abstract class describes a connection between a MultiSubGraphOp input and + /// the body. + class InputDescription + { + protected: + /// + /// \brief Constructs a new instance. + /// + /// \param input_index Position of the MultiSubGraphOp input + /// \param body_parameter_index Body parameter to receive input + /// + InputDescription(uint64_t input_index, uint64_t body_parameter_index); + InputDescription() = default; + + public: + using type_info_t = DiscreteTypeInfo; + virtual ~InputDescription() = default; + virtual std::shared_ptr copy() const = 0; + + virtual const type_info_t& get_type_info() const = 0; + + uint64_t m_input_index{0}; + uint64_t m_body_parameter_index{0}; + }; + + /// \brief Abstract class describes how a MultiSubGraphOp output is produced from + /// the body. + class OutputDescription + { + protected: + /// + /// \brief Constructs a new instance. + /// + /// \param body_value_index A body value that produces the output + /// \param output_index The MultiSubGraphOp output index + /// + OutputDescription(uint64_t body_value_index, uint64_t output_index); + OutputDescription() = default; + + public: + using type_info_t = DiscreteTypeInfo; + virtual ~OutputDescription() = default; + virtual std::shared_ptr copy() const = 0; + virtual const type_info_t& get_type_info() const = 0; + + uint64_t m_body_value_index{0}; + uint64_t m_output_index{0}; + }; + + /// + /// \brief Describes a body input formed from slices of an input to + /// MultiSubGraphOp. + /// + class NGRAPH_API SliceInputDescription : public InputDescription + { + public: + NGRAPH_RTTI_DECLARATION; + /// + /// \brief Constructs a new instance. + /// + /// \param input_index Position of the MultiSubGraphOp input + /// \param body_parameter_index Body parameter position to receive input + /// \param start First index for slices + /// \param stride Step amount for slices + /// \param part_size Width of slices + /// \param end Last index for slices + /// \param axis Axis being sliced + /// + SliceInputDescription(uint64_t input_index, + uint64_t body_parameter_index, + int64_t start, + int64_t stride, + int64_t part_size, + int64_t end, + int64_t axis); + SliceInputDescription() = default; + std::shared_ptr copy() const override; + int64_t m_start{0}; + int64_t m_stride{0}; + int64_t m_part_size{0}; + int64_t m_end{0}; + int64_t m_axis{0}; + }; + + /// + /// \brief Describes a body input initialized from a MultiSubGraphOp input + /// on the first iteration, and then a body output thereafter. + /// + class NGRAPH_API MergedInputDescription : public InputDescription + { + public: + NGRAPH_RTTI_DECLARATION; + /// + /// \brief Constructs a new instance. + /// + /// \param input_index Position of the MultiSubGraphOp input + /// supplying a value to body_parameter for + /// the initial iteration. + /// \param body_parameter_index Body parameter position to receive input. + /// \param body_value_index Body value to supply body_parameter for + /// successive + /// iterations. + /// + MergedInputDescription(uint64_t input_index, + uint64_t body_parameter_index, + uint64_t body_value_index); + MergedInputDescription() = default; + std::shared_ptr copy() const override; + uint64_t m_body_value_index{0}; + }; + + /// \brief Produces an output by concatenating an output from each iteration + class NGRAPH_API ConcatOutputDescription : public OutputDescription + { + public: + NGRAPH_RTTI_DECLARATION; + /// + /// \brief Constructs a new instance. + /// + /// \param body_value_index A body value that produces the output + /// \param output_index The MultiSubGraphOp output index + /// \param start First index for slices + /// \param stride Step amount for slices + /// \param part_size Width of slices + /// \param end Last index for slices + /// \param axis Axis being sliced + /// + ConcatOutputDescription(uint64_t body_value_index, + uint64_t output_index, + int64_t start, + int64_t stride, + int64_t part_size, + int64_t end, + int64_t axis); + ConcatOutputDescription() = default; + + std::shared_ptr copy() const override; + int64_t m_start{0}; + int64_t m_stride{0}; + int64_t m_part_size{0}; + int64_t m_end{0}; + int64_t m_axis{0}; + }; + + /// \brief Produces an input + class NGRAPH_API InvariantInputDescription : public InputDescription + { + public: + NGRAPH_RTTI_DECLARATION; + /// + /// \brief Constructs a new instance. + /// + /// \param input_index Position of the MultiSubGraphOp input + /// \param body_parameter_index Body parameter to receive input + /// + InvariantInputDescription(uint64_t input_index, uint64_t body_parameter_index); + InvariantInputDescription() = default; + std::shared_ptr copy() const override; + }; + + /// \brief Produces an output from a specific iteration + class NGRAPH_API BodyOutputDescription : public MultiSubGraphOp::OutputDescription + { + public: + NGRAPH_RTTI_DECLARATION; + /// + /// \brief Constructs a new instance. + /// + /// \param body_value_index A body value that produces the output + /// \param output_index The SubGraphOp output index + /// \param iteration which iteration (typically -1, final) will + /// supply the value + /// + BodyOutputDescription(uint64_t body_value_index, + uint64_t output_index, + int64_t iteration = -1); + BodyOutputDescription() = default; + std::shared_ptr copy() const override; + int64_t m_iteration{0}; + }; + using MultiSubgraphInputDescriptionPtr = + std::shared_ptr; + using MultiSubgraphOutputDescriptionPtr = + std::shared_ptr; + using MultiSubgraphInputDescriptionVector = + std::vector; + using MultiSubgraphOutputDescriptionVector = + std::vector; + + /// \brief Gets internal sub-graph by index in MultiSubGraphOp + /// + /// \param index sub-graph's index in op + /// \return pointer to ngraph::Function with sub-graph + virtual const std::shared_ptr& get_function(int index) const + { + return m_bodies[index]; + }; + /// \brief Adds sub-graph to MultiSubGraphOp + /// + /// \param index index of new sub-graph + /// \param func func new sub_graph as ngraph::Function + virtual void set_function(int index, const std::shared_ptr& func) + { + m_bodies[index] = func; + } + /// \brief Gets vector with connections beewtwen operation inputs + /// and internal sub-graph parameters + /// + /// \param index index of internal sub-graph + /// \return vector of input descriptions + const MultiSubgraphInputDescriptionVector& get_input_descriptions(int index) const + { + return m_input_descriptions[index]; + } + /// \brief Gets vector with connections beewtwen operation inputs + /// and internal sub-graph parameters + /// + /// \param index index of internal sub-graph + /// \return vector of input descriptions + MultiSubgraphInputDescriptionVector& get_input_descriptions(int index) + { + return m_input_descriptions[index]; + } + /// \brief Gets vector with connections beewtwen operation outputs + /// and internal sub-graph results + /// + /// \param index index of internal sub-graph + /// \return vector of output descriptions + const MultiSubgraphOutputDescriptionVector& get_output_descriptions(int index) const + { + return m_output_descriptions[index]; + } + /// \brief Gets vector with connections beewtwen operation outputs + /// and internal sub-graph results + /// + /// \param index index of internal sub-graph + /// \return vector of output descriptions + MultiSubgraphOutputDescriptionVector& get_output_descriptions(int index) + { + return m_output_descriptions[index]; + } + /// \brief Sets vector with connections beewtwen operation inputs + /// and internal sub-graph parameters + /// + /// \param index index of internal sub-graph + /// \param inputs vector of input descriptions + void set_input_descriptions(int index, + const MultiSubgraphInputDescriptionVector& inputs) + { + m_input_descriptions[index] = inputs; + } + + /// \brief Sets vector with connections beewtwen operation outputs + /// and internal sub-graph results + /// + /// \param index index of internal sub-graph + /// \param outputs vector of input descriptions + void set_output_descriptions(int index, + const MultiSubgraphOutputDescriptionVector& outputs) + { + m_output_descriptions[index] = outputs; + } + + /// + /// \brief Set input decriptions for MultiSubGraphOp input. + /// + /// \param value The value supplied as an input to the block. + /// \param bodies_parameters vector of bodies parameters. + virtual void set_invariant_inputs(const Output& value, + const ParameterVector& bodies_parameters); + /// + /// \brief Set output decriptions for MultiSubGraphOp output. + /// + /// \param bodies_results vector of bodies results for one output. + /// \return value Output node for bodies_results. + virtual Output set_body_outputs(const ResultVector& bodies_results); + + MultiSubGraphOp(const MultiSubGraphOp&) = delete; + MultiSubGraphOp(MultiSubGraphOp&&) = default; + + MultiSubGraphOp& operator=(const MultiSubGraphOp&) = delete; + MultiSubGraphOp& operator=(MultiSubGraphOp&&) = default; + + protected: + // Find an input corresponding to value, adding one if necessary. + Input input_for_value(const Output& value); + + MultiSubGraphOp(size_t number_of_bodies); + MultiSubGraphOp() = default; + MultiSubGraphOp(const OutputVector& args, size_t number_of_bodies); + explicit MultiSubGraphOp(const OutputVector& args); + + std::vector> m_bodies; + std::vector m_input_descriptions; + std::vector m_output_descriptions; + }; + using MultiSubgraphInputDescriptionPtr = + util::MultiSubGraphOp::MultiSubgraphInputDescriptionPtr; + using MultiSubgraphOutputDescriptionPtr = + util::MultiSubGraphOp::MultiSubgraphOutputDescriptionPtr; + using MultiSubgraphInputDescriptionVector = + util::MultiSubGraphOp::MultiSubgraphInputDescriptionVector; + using MultiSubgraphOutputDescriptionVector = + util::MultiSubGraphOp::MultiSubgraphOutputDescriptionVector; + + } // namespace util + } // namespace op + + template <> + class NGRAPH_API AttributeAdapter< + std::vector>> + : public DirectValueAccessor< + std::vector>> + { + public: + AttributeAdapter( + std::vector>& + value) + : DirectValueAccessor>>(value) + { + } + + NGRAPH_RTTI_DECLARATION; + }; + + template <> + class NGRAPH_API AttributeAdapter< + std::vector>> + : public DirectValueAccessor< + std::vector>> + { + public: + AttributeAdapter( + std::vector>& + value) + : DirectValueAccessor>>(value) + { + } + + NGRAPH_RTTI_DECLARATION; + }; +} // namespace ngraph diff --git a/ngraph/core/include/ngraph/op/util/sub_graph_base.hpp b/ngraph/core/include/ngraph/op/util/sub_graph_base.hpp index a44c830b78c6c9..44701e382e5ba3 100644 --- a/ngraph/core/include/ngraph/op/util/sub_graph_base.hpp +++ b/ngraph/core/include/ngraph/op/util/sub_graph_base.hpp @@ -5,7 +5,7 @@ #pragma once #include -#include "ngraph/op/op.hpp" +#include "ngraph/op/util/multi_subgraph_base.hpp" namespace ngraph { @@ -13,226 +13,46 @@ namespace ngraph { namespace util { - /// \brief Abstract base class for sub-graph based ops, i.e ops that have sub-graph + /// \brief Abstract base class for sub-graph based ops, i.e ops that have only one + /// sub-graph /// - class NGRAPH_API SubGraphOp : public Op + class NGRAPH_API SubGraphOp : public MultiSubGraphOp { public: NGRAPH_RTTI_DECLARATION; - /// \brief Describes a connection between a SubGraphOp input and the body. - class InputDescription - { - protected: - /// - /// \brief Constructs a new instance. - /// - /// \param input_index Position of the SubGraphOp input - /// \param body_parameter_index Body parameter to receive input - /// - InputDescription(uint64_t input_index, uint64_t body_parameter_index); - InputDescription() = default; - - public: - using type_info_t = DiscreteTypeInfo; - virtual ~InputDescription() = default; - virtual std::shared_ptr copy() const = 0; - - virtual const type_info_t& get_type_info() const = 0; - - uint64_t m_input_index{0}; - uint64_t m_body_parameter_index{0}; - }; - /// - /// \brief Describes a body input formed from slices of an input to - /// SubGraphOp. - /// - class NGRAPH_API SliceInputDescription : public InputDescription + virtual const std::shared_ptr& get_function() const { - public: - static constexpr type_info_t type_info{"SliceInputDescription", 0}; - const type_info_t& get_type_info() const override { return type_info; } - /// - /// \brief Constructs a new instance. - /// - /// \param input_index Position of the SubGraphOp input - /// \param body_parameter_index Body parameter position to receive input - /// \param start First index for slices - /// \param stride Step amount for slices - /// \param part_size Width of slices - /// \param end Last index for slices - /// \param axis Axis being sliced - /// - SliceInputDescription(uint64_t input_index, - uint64_t body_parameter_index, - int64_t start, - int64_t stride, - int64_t part_size, - int64_t end, - int64_t axis); - SliceInputDescription() = default; - std::shared_ptr copy() const override; - int64_t m_start{0}; - int64_t m_stride{0}; - int64_t m_part_size{0}; - int64_t m_end{0}; - int64_t m_axis{0}; + return m_bodies[0]; }; - - /// - /// \brief Describes a body input initialized from a SubGraphOp input on - /// the first iteration, and then a body output thereafter. - /// - class NGRAPH_API MergedInputDescription : public InputDescription + virtual void set_function(const std::shared_ptr& func) { - public: - static constexpr type_info_t type_info{"MergedInputDescription", 0}; - const type_info_t& get_type_info() const override { return type_info; } - /// - /// \brief Constructs a new instance. - /// - /// \param input_index Position of the SubGraphOp input - /// supplying a value to body_parameter for - /// the initial iteration. - /// \param body_parameter_index Body parameter position to receive input. - /// \param body_value_index Body value to supply body_parameter for - /// successive - /// iterations. - /// - MergedInputDescription(uint64_t input_index, - uint64_t body_parameter_index, - uint64_t body_value_index); - MergedInputDescription() = default; - std::shared_ptr copy() const override; - uint64_t m_body_value_index{0}; + m_bodies[0] = func; }; - - /// - /// \brief Describes a body input initialized from a SubGraphOp input on - /// the first iteration, and invariant thereafter. - /// - class NGRAPH_API InvariantInputDescription : public InputDescription - { - public: - static constexpr type_info_t type_info{"InvariantInputDescription", 0}; - const type_info_t& get_type_info() const override { return type_info; } - /// - /// \brief Constructs a new instance. - /// - /// \param input_index Position of the SubGraphOp input - /// \param body_parameter_index Body parameter to receive input - /// - InvariantInputDescription(uint64_t input_index, uint64_t body_parameter_index); - InvariantInputDescription() = default; - std::shared_ptr copy() const override; - }; - - /// \brief Describes how a SubGraphOp output is produced from the body. - class OutputDescription - { - protected: - /// - /// \brief Constructs a new instance. - /// - /// \param body_value_index A body value that produces the output - /// \param output_index The SubGraphOp output index - /// - OutputDescription(uint64_t body_value_index, uint64_t output_index); - OutputDescription() = default; - - public: - using type_info_t = DiscreteTypeInfo; - virtual ~OutputDescription() = default; - virtual std::shared_ptr copy() const = 0; - virtual const type_info_t& get_type_info() const = 0; - - uint64_t m_body_value_index{0}; - uint64_t m_output_index{0}; - }; - - /// \brief Produces an output by concatenating an output from each iteration - class NGRAPH_API ConcatOutputDescription : public OutputDescription - { - public: - static constexpr type_info_t type_info{"ConcatOutputDescription", 0}; - const type_info_t& get_type_info() const override { return type_info; } - /// - /// \brief Constructs a new instance. - /// - /// \param body_value_index A body value that produces the output - /// \param output_index The SubGraphOp output index - /// \param start First index for slices - /// \param stride Step amount for slices - /// \param part_size Width of slices - /// \param end Last index for slices - /// \param axis Axis being sliced - /// - ConcatOutputDescription(uint64_t body_value_index, - uint64_t output_index, - int64_t start, - int64_t stride, - int64_t part_size, - int64_t end, - int64_t axis); - ConcatOutputDescription() = default; - - std::shared_ptr copy() const override; - int64_t m_start{0}; - int64_t m_stride{0}; - int64_t m_part_size{0}; - int64_t m_end{0}; - int64_t m_axis{0}; - }; - - /// \brief Produces an output from a specific iteration - class NGRAPH_API BodyOutputDescription : public OutputDescription - { - public: - static constexpr type_info_t type_info{"BodyOutputDescription", 0}; - const type_info_t& get_type_info() const override { return type_info; } - /// - /// \brief Constructs a new instance. - /// - /// \param body_value_index A body value that produces the output - /// \param output_index The SubGraphOp output index - /// \param iteration which iteration (typically -1, final) will - /// supply the value - /// - BodyOutputDescription(uint64_t body_value_index, - uint64_t output_index, - int64_t iteration); - BodyOutputDescription() = default; - std::shared_ptr copy() const override; - int64_t m_iteration{0}; - }; - - virtual std::shared_ptr get_function() { return m_body; }; - virtual std::shared_ptr get_function() const { return m_body; }; - virtual void set_function(const std::shared_ptr& func) { m_body = func; }; /// \return a reference to the input descriptions. const std::vector>& get_input_descriptions() const { - return m_input_descriptions; + return m_input_descriptions[0]; } /// \return a reference to the input descriptions. Can add input descriptions /// before /// validation. std::vector>& get_input_descriptions() { - return m_input_descriptions; + return m_input_descriptions[0]; } /// \return a reference to the output descriptions. const std::vector>& get_output_descriptions() const { - return m_output_descriptions; + return m_output_descriptions[0]; } /// \return a reference to the output descriptions. Can add output descriptions /// before /// validation. std::vector>& get_output_descriptions() { - return m_output_descriptions; + return m_output_descriptions[0]; } /// @@ -324,15 +144,13 @@ namespace ngraph // Find an input corresponding to value, adding one if necessary. Input input_for_value(const Output& value); - SubGraphOp() = default; - + SubGraphOp(); explicit SubGraphOp(const OutputVector& args); - std::shared_ptr m_body; - std::vector> - m_input_descriptions; - std::vector> - m_output_descriptions; + private: + using MultiSubGraphOp::get_function; + + using MultiSubGraphOp::set_function; }; using InputDescriptionPtr = std::shared_ptr; using OutputDescriptionPtr = std::shared_ptr; @@ -341,47 +159,4 @@ namespace ngraph } // namespace util } // namespace op - template <> - class NGRAPH_API AttributeAdapter< - std::vector>> - : public DirectValueAccessor< - std::vector>> - { - public: - AttributeAdapter( - std::vector>& value) - : DirectValueAccessor< - std::vector>>( - value) - { - } - - static constexpr DiscreteTypeInfo type_info{ - "AttributeAdapter>>", - 0}; - const DiscreteTypeInfo& get_type_info() const override { return type_info; } - }; - - template <> - class NGRAPH_API AttributeAdapter< - std::vector>> - : public DirectValueAccessor< - std::vector>> - { - public: - AttributeAdapter( - std::vector>& value) - : DirectValueAccessor< - std::vector>>( - value) - { - } - - static constexpr DiscreteTypeInfo type_info{ - "AttributeAdapter>>", - 0}; - const DiscreteTypeInfo& get_type_info() const override { return type_info; } - }; } // namespace ngraph diff --git a/ngraph/core/src/op/loop.cpp b/ngraph/core/src/op/loop.cpp index 2941d46d2e2345..b7cc41288100d9 100644 --- a/ngraph/core/src/op/loop.cpp +++ b/ngraph/core/src/op/loop.cpp @@ -18,6 +18,7 @@ using namespace ngraph; NGRAPH_RTTI_DEFINITION(op::v5::Loop, "Loop", 5); op::v5::Loop::Loop(const Output& trip_count, const Output& execution_condition) + : SubGraphOp() { set_argument(0, trip_count); set_argument(1, execution_condition); @@ -26,9 +27,9 @@ op::v5::Loop::Loop(const Output& trip_count, const Output& execution bool op::v5::Loop::visit_attributes(AttributeVisitor& visitor) { NGRAPH_OP_SCOPE(v5_Loop_visit_attributes); - visitor.on_attribute("body", m_body); - visitor.on_attribute("input_descriptions", m_input_descriptions); - visitor.on_attribute("output_descriptions", m_output_descriptions); + visitor.on_attribute("body", m_bodies[0]); + visitor.on_attribute("input_descriptions", m_input_descriptions[0]); + visitor.on_attribute("output_descriptions", m_output_descriptions[0]); visitor.on_attribute("special_body_ports", m_special_body_ports); return true; @@ -37,9 +38,21 @@ bool op::v5::Loop::visit_attributes(AttributeVisitor& visitor) void op::v5::Loop::validate_and_infer_types() { NGRAPH_OP_SCOPE(v5_Loop_validate_and_infer_types); + + NODE_VALIDATION_CHECK( + this, m_bodies.size() == 1, "Number of bodies for loop is greater than 1"); + + NODE_VALIDATION_CHECK(this, + m_input_descriptions.size() == 1, + "Loop contains input descriptions for other bodies"); + NODE_VALIDATION_CHECK(this, + m_output_descriptions.size() == 1, + "Loop contains output descriptions for other bodies"); + if (m_special_body_ports.current_iteration_input_idx >= 0) { - const auto& cur_iter_rank = m_body->get_parameters() + const auto& cur_iter_rank = m_bodies[0] + ->get_parameters() .at(m_special_body_ports.current_iteration_input_idx) ->get_partial_shape() .rank(); @@ -78,8 +91,10 @@ void op::v5::Loop::validate_and_infer_types() // special body ports were not set yet, so we can't calculate output shape return; - const auto& body_execution_condition = - m_body->get_results().at(m_special_body_ports.body_condition_output_idx)->input_value(0); + const auto& body_execution_condition = m_bodies[0] + ->get_results() + .at(m_special_body_ports.body_condition_output_idx) + ->input_value(0); const auto& body_condition_rank = body_execution_condition.get_partial_shape().rank(); if (body_condition_rank.is_static()) { @@ -110,7 +125,7 @@ void op::v5::Loop::validate_and_infer_types() // Const(true or false) -> Loop (body: Parameter -> execution_condition output) for (const auto& desc : get_input_descriptions()) { - if (m_body->get_parameters().at(desc->m_body_parameter_index) == cond_param) + if (m_bodies[0]->get_parameters().at(desc->m_body_parameter_index) == cond_param) { if (const auto& cond_value = get_constant_from_source(input_value(desc->m_input_index))) @@ -156,7 +171,7 @@ void op::v5::Loop::validate_and_infer_types() // the inputs. // When using visit_attributes() no duplication occurs, input_offset shall be decremented. size_t input_offset = 2; - for (const auto& in_desc : m_input_descriptions) + for (const auto& in_desc : m_input_descriptions[0]) { if (in_desc->m_input_index == 0 || in_desc->m_input_index == 1) { @@ -169,18 +184,18 @@ void op::v5::Loop::validate_and_infer_types() NODE_VALIDATION_CHECK(this, input_offset >= 0, "External port id 0 or 1 is duplicated."); NODE_VALIDATION_CHECK(this, - get_input_size() == m_input_descriptions.size() + input_offset, + get_input_size() == m_input_descriptions[0].size() + input_offset, "Number of inputs must be the same as number of input descriptions"); // Input - for (const auto& input_description : m_input_descriptions) + for (const auto& input_description : m_input_descriptions[0]) { auto index = input_description->m_input_index; if (auto slice_input_description = as_type_ptr(input_description)) { auto body_parameter = - m_body->get_parameters().at(slice_input_description->m_body_parameter_index); + m_bodies[0]->get_parameters().at(slice_input_description->m_body_parameter_index); const auto& input_partial_shape = inputs().at(index).get_source_output().get_partial_shape(); if (input_partial_shape.rank().is_dynamic()) @@ -200,10 +215,10 @@ void op::v5::Loop::validate_and_infer_types() as_type_ptr(input_description)) { auto body_value = - m_body->get_results().at(merged_input_description->m_body_value_index); + m_bodies[0]->get_results().at(merged_input_description->m_body_value_index); auto body_parameter = - m_body->get_parameters().at(merged_input_description->m_body_parameter_index); + m_bodies[0]->get_parameters().at(merged_input_description->m_body_parameter_index); auto body_param_partial_shape = body_parameter->get_partial_shape(); auto input_partial_shape = input(index).get_partial_shape(); @@ -213,8 +228,8 @@ void op::v5::Loop::validate_and_infer_types() else if (auto invariant_input_description = as_type_ptr(input_description)) { - auto body_parameter = - m_body->get_parameters().at(invariant_input_description->m_body_parameter_index); + auto body_parameter = m_bodies[0]->get_parameters().at( + invariant_input_description->m_body_parameter_index); auto body_param_partial_shape = body_parameter->get_partial_shape(); auto input_partial_shape = input(index).get_partial_shape(); @@ -224,15 +239,15 @@ void op::v5::Loop::validate_and_infer_types() } // Body - m_body->validate_nodes_and_infer_types(); + m_bodies[0]->validate_nodes_and_infer_types(); // Output - for (const auto& output_description : m_output_descriptions) + for (const auto& output_description : m_output_descriptions[0]) { auto index = output_description->m_output_index; auto body_value = - m_body->get_results().at(output_description->m_body_value_index)->input_value(0); + m_bodies[0]->get_results().at(output_description->m_body_value_index)->input_value(0); if (auto concat_output_description = as_type_ptr(output_description)) @@ -286,7 +301,7 @@ void op::v5::Loop::validate_and_infer_types() } NODE_VALIDATION_CHECK(this, - get_output_size() == m_output_descriptions.size(), + get_output_size() == m_output_descriptions[0].size(), "Number of outputs must be the same as number of output descriptions"); } @@ -322,8 +337,12 @@ Output op::v5::Loop::get_concatenated_slices(const Output& value, bool op::v5::Loop::evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const { NGRAPH_OP_SCOPE(v5_Loop_evaluate); - runtime::reference::loop( - m_body, m_output_descriptions, m_input_descriptions, m_special_body_ports, outputs, inputs); + runtime::reference::loop(m_bodies[0], + m_output_descriptions[0], + m_input_descriptions[0], + m_special_body_ports, + outputs, + inputs); return true; } @@ -347,20 +366,21 @@ void op::v5::Loop::clone_to(op::v5::Loop& dst, const OutputVector& new_args) con dst.m_num_iterations = m_num_iterations; dst.m_special_body_ports = m_special_body_ports; - dst.m_body = clone_function(*get_function()); + dst.m_bodies[0] = clone_function(*get_function()); - for (auto& input_description : m_input_descriptions) + for (auto& input_description : m_input_descriptions[0]) { - dst.m_input_descriptions.push_back(input_description->copy()); + dst.m_input_descriptions[0].push_back(input_description->copy()); } - for (auto& output_description : m_output_descriptions) + for (auto& output_description : m_output_descriptions[0]) { - dst.m_output_descriptions.push_back(output_description->copy()); + dst.m_output_descriptions[0].push_back(output_description->copy()); } dst.validate_and_infer_types(); } op::v5::Loop::Loop(const op::v5::Loop& other) + : SubGraphOp() { other.clone_to(*this, other.input_values()); } diff --git a/ngraph/core/src/op/tensor_iterator.cpp b/ngraph/core/src/op/tensor_iterator.cpp index 6dffdaa77fe96b..0ae86f6052a9c2 100644 --- a/ngraph/core/src/op/tensor_iterator.cpp +++ b/ngraph/core/src/op/tensor_iterator.cpp @@ -21,9 +21,9 @@ op::v0::TensorIterator::TensorIterator(const OutputVector& values) bool op::v0::TensorIterator::visit_attributes(AttributeVisitor& visitor) { NGRAPH_OP_SCOPE(v0_TensorIterator_visit_attributes); - visitor.on_attribute("body", m_body); - visitor.on_attribute("input_descriptions", m_input_descriptions); - visitor.on_attribute("output_descriptions", m_output_descriptions); + visitor.on_attribute("body", m_bodies[0]); + visitor.on_attribute("input_descriptions", m_input_descriptions[0]); + visitor.on_attribute("output_descriptions", m_output_descriptions[0]); return true; } @@ -33,7 +33,7 @@ void op::v0::TensorIterator::revalidate_and_infer_types_for_body_ops() std::stack, std::vector>> nodes_to_do; std::unordered_set> nodes_done; - for (const auto& r : m_body->get_results()) + for (const auto& r : m_bodies[0]->get_results()) { nodes_to_do.push(r); } @@ -75,8 +75,19 @@ void op::v0::TensorIterator::revalidate_and_infer_types_for_body_ops() void op::v0::TensorIterator::validate_and_infer_types() { NGRAPH_OP_SCOPE(v0_TensorIterator_validate_and_infer_types); + + NODE_VALIDATION_CHECK( + this, m_bodies.size() == 1, "Number of bodies for loop is greater than 1"); + + NODE_VALIDATION_CHECK(this, + m_input_descriptions.size() == 1, + "Loop contains input descriptions for other bodies"); + NODE_VALIDATION_CHECK(this, + m_output_descriptions.size() == 1, + "Loop contains output descriptions for other bodies"); + NODE_VALIDATION_CHECK(this, - get_input_size() == m_input_descriptions.size(), + get_input_size() == m_input_descriptions[0].size(), "Number of inputs must be the same as number of input descriptions"); std::vector> ends; @@ -89,15 +100,16 @@ void op::v0::TensorIterator::validate_and_infer_types() return value; }; + auto body = get_function(); // Input - for (const auto& input_description : m_input_descriptions) + for (const auto& input_description : m_input_descriptions[0]) { auto index = input_description->m_input_index; if (auto slice_input_description = as_type_ptr(input_description)) { auto body_parameter = - m_body->get_parameters().at(slice_input_description->m_body_parameter_index); + body->get_parameters().at(slice_input_description->m_body_parameter_index); auto input_partial_shape = inputs().at(index).get_source_output().get_partial_shape(); if (input_partial_shape.is_static()) { @@ -125,12 +137,14 @@ void op::v0::TensorIterator::validate_and_infer_types() else if (auto merged_input_description = as_type_ptr(input_description)) { - auto body_value = - m_body->get_results().at(merged_input_description->m_body_value_index)->input(0); + auto body_value = m_bodies[0] + ->get_results() + .at(merged_input_description->m_body_value_index) + ->input(0); ends.push_back(body_value.get_node()->shared_from_this()); auto body_parameter = - m_body->get_parameters().at(merged_input_description->m_body_parameter_index); + m_bodies[0]->get_parameters().at(merged_input_description->m_body_parameter_index); auto body_param_partial_shape = body_parameter->get_partial_shape(); auto input_partial_shape = inputs().at(index).get_source_output().get_partial_shape(); @@ -139,8 +153,8 @@ void op::v0::TensorIterator::validate_and_infer_types() else if (auto invariant_input_description = as_type_ptr(input_description)) { - auto body_parameter = - m_body->get_parameters().at(invariant_input_description->m_body_parameter_index); + auto body_parameter = m_bodies[0]->get_parameters().at( + invariant_input_description->m_body_parameter_index); auto body_param_partial_shape = body_parameter->get_partial_shape(); auto input_partial_shape = inputs().at(index).get_source_output().get_partial_shape(); @@ -154,12 +168,12 @@ void op::v0::TensorIterator::validate_and_infer_types() // Output try_to_set_num_iterations_if_no_slice_inputs(); - for (const auto& output_description : m_output_descriptions) + for (const auto& output_description : m_output_descriptions[0]) { auto index = output_description->m_output_index; auto body_value = - m_body->get_results().at(output_description->m_body_value_index)->input_value(0); + m_bodies[0]->get_results().at(output_description->m_body_value_index)->input_value(0); if (auto concat_output_description = as_type_ptr(output_description)) @@ -207,15 +221,10 @@ void op::v0::TensorIterator::validate_and_infer_types() } NODE_VALIDATION_CHECK(this, - get_output_size() == m_output_descriptions.size(), + get_output_size() == m_output_descriptions[0].size(), "Number of outputs must be the same as number of output descriptions"); } -std::shared_ptr op::v0::TensorIterator::get_function() -{ - return get_body(); -} - namespace { template @@ -235,7 +244,7 @@ void op::v0::TensorIterator::try_to_set_num_iterations_if_no_slice_inputs() return; } - for (const auto& output_description : m_output_descriptions) + for (const auto& output_description : m_output_descriptions[0]) { if (auto concat = as_type_ptr(output_description)) { @@ -256,14 +265,14 @@ std::shared_ptr description(), " operation with name ", get_friendly_name()); - op->set_output_size(m_output_descriptions.size()); + op->set_output_size(m_output_descriptions[0].size()); - std::vector<::ngraph::element::Type> types(m_body->get_parameters().size()); - std::vector<::ngraph::PartialShape> new_shapes(m_body->get_parameters().size()); + std::vector<::ngraph::element::Type> types(m_bodies[0]->get_parameters().size()); + std::vector<::ngraph::PartialShape> new_shapes(m_bodies[0]->get_parameters().size()); for (size_t input_index = 0; input_index < new_args.size(); ++input_index) { - for (auto& input_description : m_input_descriptions) + for (auto& input_description : m_input_descriptions[0]) { if (input_description->m_input_index == input_index) { @@ -288,19 +297,19 @@ std::shared_ptr op->m_num_iterations = m_num_iterations; auto func = std::make_shared( - m_body->get_results(), m_body->get_sinks(), m_body->get_parameters()); + m_bodies[0]->get_results(), m_bodies[0]->get_sinks(), m_bodies[0]->get_parameters()); auto spec_func = specialize_function(func, types, new_shapes, std::vector(new_args.size(), nullptr)); - op->m_body = std::make_shared( + op->m_bodies[0] = std::make_shared( spec_func->get_results(), spec_func->get_sinks(), spec_func->get_parameters()); - for (auto& input_description : m_input_descriptions) + for (auto& input_description : m_input_descriptions[0]) { - op->m_input_descriptions.push_back(input_description->copy()); + op->m_input_descriptions[0].push_back(input_description->copy()); } - for (auto& output_description : m_output_descriptions) + for (auto& output_description : m_output_descriptions[0]) { - op->m_output_descriptions.push_back(output_description->copy()); + op->m_output_descriptions[0].push_back(output_description->copy()); } op->validate_and_infer_types(); return op; diff --git a/ngraph/core/src/op/util/multi_subgraph_base.cpp b/ngraph/core/src/op/util/multi_subgraph_base.cpp new file mode 100644 index 00000000000000..0b14f2e20fb8e0 --- /dev/null +++ b/ngraph/core/src/op/util/multi_subgraph_base.cpp @@ -0,0 +1,210 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "ngraph/op/util/multi_subgraph_base.hpp" +#include "ngraph/opsets/opset5.hpp" + +#include "ngraph/graph_util.hpp" + +using namespace ngraph; + +NGRAPH_RTTI_DEFINITION(op::util::MultiSubGraphOp, "MultiSubGraphOp", 0); +NGRAPH_RTTI_DEFINITION(op::util::MultiSubGraphOp::SliceInputDescription, + "SliceInputDescription", + 0); +NGRAPH_RTTI_DEFINITION(op::util::MultiSubGraphOp::MergedInputDescription, + "MergedInputDescription", + 0); +NGRAPH_RTTI_DEFINITION(op::util::MultiSubGraphOp::InvariantInputDescription, + "InvariantInputDescription", + 0); +NGRAPH_RTTI_DEFINITION(op::util::MultiSubGraphOp::BodyOutputDescription, + "BodyOutputDescription", + 0); +NGRAPH_RTTI_DEFINITION(op::util::MultiSubGraphOp::ConcatOutputDescription, + "ConcatOutputDescription", + 0); + +op::util::MultiSubGraphOp::InputDescription::InputDescription(uint64_t input_index, + uint64_t body_parameter_index) + : m_input_index(input_index) + , m_body_parameter_index(body_parameter_index) +{ +} + +op::util::MultiSubGraphOp::OutputDescription::OutputDescription(uint64_t body_value_index, + uint64_t output_index) + : m_body_value_index(body_value_index) + , m_output_index(output_index) +{ +} + +op::util::MultiSubGraphOp::SliceInputDescription::SliceInputDescription( + uint64_t input_index, + uint64_t body_parameter_index, + int64_t start, + int64_t stride, + int64_t part_size, + int64_t end, + int64_t axis) + : InputDescription(input_index, body_parameter_index) + , m_start(start) + , m_stride(stride) + , m_part_size(part_size) + , m_end(end) + , m_axis(axis) +{ +} + +std::shared_ptr + op::util::MultiSubGraphOp::SliceInputDescription::copy() const +{ + return std::make_shared( + m_input_index, m_body_parameter_index, m_start, m_stride, m_part_size, m_end, m_axis); +} + +op::util::MultiSubGraphOp::MergedInputDescription::MergedInputDescription( + uint64_t input_index, uint64_t body_parameter_index, uint64_t body_value_index) + : InputDescription(input_index, body_parameter_index) + , m_body_value_index(body_value_index) +{ +} + +std::shared_ptr + op::util::MultiSubGraphOp::MergedInputDescription::copy() const +{ + return std::make_shared( + m_input_index, m_body_parameter_index, m_body_value_index); +} + +op::util::MultiSubGraphOp::ConcatOutputDescription::ConcatOutputDescription( + uint64_t body_value_index, + uint64_t output_index, + int64_t start, + int64_t stride, + int64_t part_size, + int64_t end, + int64_t axis) + : OutputDescription(body_value_index, output_index) + , m_start(start) + , m_stride(stride) + , m_part_size(part_size) + , m_end(end) + , m_axis(axis) +{ +} + +std::shared_ptr + op::util::MultiSubGraphOp::ConcatOutputDescription::copy() const +{ + return std::make_shared( + m_body_value_index, m_output_index, m_start, m_stride, m_part_size, m_end, m_axis); +} +op::util::MultiSubGraphOp::InvariantInputDescription::InvariantInputDescription( + uint64_t input_index, uint64_t body_parameter_index) + : InputDescription(input_index, body_parameter_index) +{ +} + +std::shared_ptr + op::util::MultiSubGraphOp::InvariantInputDescription::copy() const +{ + return std::make_shared(m_input_index, + m_body_parameter_index); +} + +op::util::MultiSubGraphOp::BodyOutputDescription::BodyOutputDescription(uint64_t body_value_index, + uint64_t output_index, + int64_t iteration) + : OutputDescription(body_value_index, output_index) + , m_iteration(iteration) +{ +} + +std::shared_ptr + op::util::MultiSubGraphOp::BodyOutputDescription::copy() const +{ + return std::make_shared(m_body_value_index, m_output_index, m_iteration); +} + +op::util::MultiSubGraphOp::MultiSubGraphOp(const OutputVector& args) + : Op(args) +{ +} + +op::util::MultiSubGraphOp::MultiSubGraphOp(size_t number_of_bodies) +{ + m_bodies.resize(number_of_bodies); + m_input_descriptions.resize(number_of_bodies); + m_output_descriptions.resize(number_of_bodies); +} + +op::util::MultiSubGraphOp::MultiSubGraphOp(const OutputVector& args, size_t number_of_bodies) + : MultiSubGraphOp(args) +{ + m_bodies.resize(number_of_bodies); + m_input_descriptions.resize(number_of_bodies); + m_output_descriptions.resize(number_of_bodies); +} + +Input op::util::MultiSubGraphOp::input_for_value(const Output& value) +{ + auto input_index = get_input_size(); + set_argument(input_index, value); + return Input(this, input_index); +} + +void op::util::MultiSubGraphOp::set_invariant_inputs(const Output& value, + const ParameterVector& bodies_parameters) +{ + auto input_index = input_for_value(value).get_index(); + for (auto& param : bodies_parameters) + { + for (size_t body_index = 0; body_index < m_bodies.size(); ++body_index) + { + auto param_index = m_bodies[body_index]->get_parameter_index(param); + if (param_index != -1) + { + m_input_descriptions[body_index].push_back( + std::make_shared(input_index, + param_index)); + } + } + } +} + +Output op::util::MultiSubGraphOp::set_body_outputs(const ResultVector& bodies_results) +{ + auto output_index = get_output_size(); + for (auto& body_result : bodies_results) + { + for (size_t body_index = 0; body_index < m_bodies.size(); body_index++) + { + auto body_result_index = m_bodies[body_index]->get_result_index(body_result); + if (body_result_index != -1) + { + m_output_descriptions[body_index].push_back( + std::make_shared(body_result_index, output_index)); + } + } + } + set_output_size(output_index + 1); + return Output(shared_from_this(), output_index); +} + +namespace ngraph +{ + NGRAPH_RTTI_DEFINITION( + AttributeAdapter>>, + "AttributeAdapter>>", + 0); + + NGRAPH_RTTI_DEFINITION( + AttributeAdapter< + std::vector>>, + "AttributeAdapter>>", + 0); +} // namespace ngraph diff --git a/ngraph/core/src/op/util/sub_graph_base.cpp b/ngraph/core/src/op/util/sub_graph_base.cpp index 916b7cc7c5bafc..7a42ed27daca3b 100644 --- a/ngraph/core/src/op/util/sub_graph_base.cpp +++ b/ngraph/core/src/op/util/sub_graph_base.cpp @@ -11,116 +11,13 @@ using namespace ngraph; NGRAPH_RTTI_DEFINITION(op::util::SubGraphOp, "SubGraphOp", 0); -constexpr DiscreteTypeInfo op::util::SubGraphOp::SliceInputDescription::type_info; -constexpr DiscreteTypeInfo op::util::SubGraphOp::MergedInputDescription::type_info; -constexpr DiscreteTypeInfo op::util::SubGraphOp::InvariantInputDescription::type_info; - -constexpr DiscreteTypeInfo op::util::SubGraphOp::BodyOutputDescription::type_info; -constexpr DiscreteTypeInfo op::util::SubGraphOp::ConcatOutputDescription::type_info; - -op::util::SubGraphOp::InputDescription::InputDescription(uint64_t input_index, - uint64_t body_parameter_index) - : m_input_index(input_index) - , m_body_parameter_index(body_parameter_index) -{ -} - -op::util::SubGraphOp::SliceInputDescription::SliceInputDescription(uint64_t input_index, - uint64_t body_parameter_index, - int64_t start, - int64_t stride, - int64_t part_size, - int64_t end, - int64_t axis) - : InputDescription(input_index, body_parameter_index) - , m_start(start) - , m_stride(stride) - , m_part_size(part_size) - , m_end(end) - , m_axis(axis) -{ -} - -std::shared_ptr - op::util::SubGraphOp::SliceInputDescription::copy() const -{ - return std::make_shared( - m_input_index, m_body_parameter_index, m_start, m_stride, m_part_size, m_end, m_axis); -} - -op::util::SubGraphOp::MergedInputDescription::MergedInputDescription(uint64_t input_index, - uint64_t body_parameter_index, - uint64_t body_value_index) - : InputDescription(input_index, body_parameter_index) - , m_body_value_index(body_value_index) -{ -} - -std::shared_ptr - op::util::SubGraphOp::MergedInputDescription::copy() const -{ - return std::make_shared( - m_input_index, m_body_parameter_index, m_body_value_index); -} - -op::util::SubGraphOp::InvariantInputDescription::InvariantInputDescription( - uint64_t input_index, uint64_t body_parameter_index) - : InputDescription(input_index, body_parameter_index) -{ -} - -std::shared_ptr - op::util::SubGraphOp::InvariantInputDescription::copy() const +op::util::SubGraphOp::SubGraphOp() + : MultiSubGraphOp(1) { - return std::make_shared(m_input_index, m_body_parameter_index); -} - -op::util::SubGraphOp::OutputDescription::OutputDescription(uint64_t body_value_index, - uint64_t output_index) - : m_body_value_index(body_value_index) - , m_output_index(output_index) -{ -} - -op::util::SubGraphOp::ConcatOutputDescription::ConcatOutputDescription(uint64_t body_value_index, - uint64_t output_index, - int64_t start, - int64_t stride, - int64_t part_size, - int64_t end, - int64_t axis) - : OutputDescription(body_value_index, output_index) - , m_start(start) - , m_stride(stride) - , m_part_size(part_size) - , m_end(end) - , m_axis(axis) -{ -} - -std::shared_ptr - op::util::SubGraphOp::ConcatOutputDescription::copy() const -{ - return std::make_shared( - m_body_value_index, m_output_index, m_start, m_stride, m_part_size, m_end, m_axis); -} - -op::util::SubGraphOp::BodyOutputDescription::BodyOutputDescription(uint64_t body_value_index, - uint64_t output_index, - int64_t iteration) - : OutputDescription(body_value_index, output_index) - , m_iteration(iteration) -{ -} - -std::shared_ptr - op::util::SubGraphOp::BodyOutputDescription::copy() const -{ - return std::make_shared(m_body_value_index, m_output_index, m_iteration); } op::util::SubGraphOp::SubGraphOp(const OutputVector& args) - : Op(args) + : MultiSubGraphOp(args, 1) { } @@ -128,26 +25,30 @@ void op::util::SubGraphOp::set_merged_input(const std::shared_ptr& bo const Output& initial_value, const Output& successive_value) { - m_input_descriptions.push_back(std::make_shared( + auto body = get_function(); + + m_input_descriptions[0].push_back(std::make_shared( input_for_value(initial_value).get_index(), - m_body->get_parameter_index(body_parameter), - m_body->get_result_index(successive_value))); + body->get_parameter_index(body_parameter), + body->get_result_index(successive_value))); validate_and_infer_types(); } void op::util::SubGraphOp::set_invariant_input(const std::shared_ptr& body_parameter, const Output& value) { - m_input_descriptions.push_back(std::make_shared( - input_for_value(value).get_index(), m_body->get_parameter_index(body_parameter))); + auto body = get_function(); + m_input_descriptions[0].push_back(std::make_shared( + input_for_value(value).get_index(), body->get_parameter_index(body_parameter))); validate_and_infer_types(); } Output op::util::SubGraphOp::get_iter_value(const Output& body_value, int64_t iteration) { auto output_index = get_output_size(); - m_output_descriptions.push_back(std::make_shared( - m_body->get_result_index(body_value), output_index, iteration)); + auto body = get_function(); + m_output_descriptions[0].push_back(std::make_shared( + body->get_result_index(body_value), output_index, iteration)); set_output_size(output_index + 1); validate_and_infer_types(); return Output(shared_from_this(), output_index); @@ -161,8 +62,9 @@ Output op::util::SubGraphOp::get_concatenated_slices(const Output& b int64_t axis) { auto output_index = get_output_size(); - m_output_descriptions.push_back(std::make_shared( - m_body->get_result_index(body_value), output_index, start, stride, part_size, end, axis)); + auto body = get_function(); + m_output_descriptions[0].push_back(std::make_shared( + body->get_result_index(body_value), output_index, start, stride, part_size, end, axis)); set_output_size(output_index + 1); validate_and_infer_types(); return Output(shared_from_this(), output_index); @@ -176,9 +78,10 @@ void op::util::SubGraphOp::set_sliced_input(const std::shared_ptr& pa int64_t end, int64_t axis) { - m_input_descriptions.push_back( + auto body = get_function(); + m_input_descriptions[0].push_back( std::make_shared(input_for_value(value).get_index(), - m_body->get_parameter_index(parameter), + body->get_parameter_index(parameter), start, stride, part_size, @@ -193,12 +96,3 @@ Input op::util::SubGraphOp::input_for_value(const Output& value) set_argument(input_index, value); return Input(this, input_index); } - -namespace ngraph -{ - constexpr DiscreteTypeInfo AttributeAdapter< - std::vector>>::type_info; - - constexpr DiscreteTypeInfo AttributeAdapter< - std::vector>>::type_info; -} // namespace ngraph