From bda8afc9d4489f0290dedf37f616222927cf6864 Mon Sep 17 00:00:00 2001 From: Vladislav Golubev Date: Thu, 30 Nov 2023 16:31:01 +0100 Subject: [PATCH] SerializationNode: added SerializationMode --- .../snippets/op/serialization_node.hpp | 6 +++++- .../src/lowered/pass/serialize_data_flow.cpp | 5 +++-- .../snippets/src/op/serialization_node.cpp | 19 +++++++++++++++---- 3 files changed, 23 insertions(+), 7 deletions(-) diff --git a/src/common/snippets/include/snippets/op/serialization_node.hpp b/src/common/snippets/include/snippets/op/serialization_node.hpp index d8b26aa41a508a..8910f98bd0a570 100644 --- a/src/common/snippets/include/snippets/op/serialization_node.hpp +++ b/src/common/snippets/include/snippets/op/serialization_node.hpp @@ -19,8 +19,11 @@ namespace op { */ class SerializationNode : public ov::op::Op { public: + enum SerializationMode { DATA_FLOW, CONTROL_FLOW }; SerializationNode() = default; - SerializationNode(const ov::OutputVector& args, const std::shared_ptr& expr); + SerializationNode(const ov::OutputVector& args, + const std::shared_ptr& expr, + SerializationMode mode = SerializationMode::CONTROL_FLOW); void validate_and_infer_types() override; std::shared_ptr clone_with_new_inputs(const OutputVector &new_args) const override; @@ -37,6 +40,7 @@ class SerializationNode : public ov::op::Op { private: std::shared_ptr m_expr; + SerializationMode m_mode; }; } // namespace op diff --git a/src/common/snippets/src/lowered/pass/serialize_data_flow.cpp b/src/common/snippets/src/lowered/pass/serialize_data_flow.cpp index 1d5919f0b3064d..7ae3e7ce15e8af 100644 --- a/src/common/snippets/src/lowered/pass/serialize_data_flow.cpp +++ b/src/common/snippets/src/lowered/pass/serialize_data_flow.cpp @@ -23,13 +23,14 @@ bool SerializeDataFlow::run(LinearIR& linear_ir) { ov::ResultVector results; ov::ParameterVector parameters; std::map> ops_map; + const auto serialization_mode = op::SerializationNode::SerializationMode::DATA_FLOW; for (const auto& expr : linear_ir) { const auto node = expr->get_node(); ov::OutputVector inputs(expr->get_input_count()); for (size_t i = 0; i < expr->get_input_count(); ++i) { const auto& input_expr = expr->get_input_port_connector(i)->get_source().get_expr(); OPENVINO_ASSERT(ops_map.count(input_expr), "input node wasn't found during serialization"); - inputs[i] = ops_map[input_expr]->output(0); + inputs[i] = ops_map[input_expr]->output(expr->get_input_port_connector(i)->get_source().get_index()); } if (auto ioexpr = std::dynamic_pointer_cast(expr)) { if (ioexpr->get_type() == IOExpression::io_type::INPUT) { @@ -42,7 +43,7 @@ bool SerializeDataFlow::run(LinearIR& linear_ir) { results.push_back(result); } } else { - const auto serialization_node = std::make_shared(inputs, expr); + const auto serialization_node = std::make_shared(inputs, expr, serialization_mode); ops_map[expr] = serialization_node; } } diff --git a/src/common/snippets/src/op/serialization_node.cpp b/src/common/snippets/src/op/serialization_node.cpp index eea2118d9f0993..a91c63beb9402b 100644 --- a/src/common/snippets/src/op/serialization_node.cpp +++ b/src/common/snippets/src/op/serialization_node.cpp @@ -9,8 +9,12 @@ namespace ov { namespace snippets { namespace op { -SerializationNode::SerializationNode(const ov::OutputVector& args, const std::shared_ptr& expr) - : Op(args), m_expr(expr) { +SerializationNode::SerializationNode(const ov::OutputVector& args, + const std::shared_ptr& expr, + SerializationMode mode) + : Op(args), + m_expr(expr), + m_mode(mode) { OPENVINO_ASSERT(m_expr && m_expr->get_node(), "SerializationNode requires a valid expression with non-null node pointer"); const auto& node = expr->get_node(); set_friendly_name(node->get_friendly_name()); @@ -20,12 +24,19 @@ SerializationNode::SerializationNode(const ov::OutputVector& args, const std::sh } void SerializationNode::validate_and_infer_types() { - set_output_type(0, element::f32, {}); + // If SerializationNode is used for control flow serialization, it always has one output + // (since it represents a linear execution order) + if (m_mode == SerializationMode::CONTROL_FLOW) { + set_output_type(0, element::f32, {}); + } else if (m_mode == SerializationMode::DATA_FLOW) { + for (size_t i = 0; i < m_expr->get_output_count(); ++i) + set_output_type(i, element::f32, {}); + } } std::shared_ptr SerializationNode::clone_with_new_inputs(const OutputVector &new_args) const { check_new_args_count(this, new_args); - return std::make_shared(new_args, m_expr); + return std::make_shared(new_args, m_expr, m_mode); } bool SerializationNode::visit_attributes(AttributeVisitor &visitor) {