From db6305327216448e68b4c2897b0286ede93e6208 Mon Sep 17 00:00:00 2001 From: Alexandra Sidorova Date: Thu, 11 May 2023 17:05:28 +0400 Subject: [PATCH] Updated Buffer Identification logic --- .../lowered/pass/identify_buffers.hpp | 4 +- .../src/lowered/pass/identify_buffers.cpp | 134 ++++++++---------- 2 files changed, 65 insertions(+), 73 deletions(-) diff --git a/src/common/snippets/include/snippets/lowered/pass/identify_buffers.hpp b/src/common/snippets/include/snippets/lowered/pass/identify_buffers.hpp index 9c97ded91cf471..05bedba6f72453 100644 --- a/src/common/snippets/include/snippets/lowered/pass/identify_buffers.hpp +++ b/src/common/snippets/include/snippets/lowered/pass/identify_buffers.hpp @@ -6,6 +6,8 @@ #include "transformation.hpp" +#include "snippets/op/buffer.hpp" + namespace ngraph { namespace snippets { namespace lowered { @@ -34,7 +36,7 @@ class IdentifyBuffers: public Transformation { bool run(LinearIR& linear_ir) override; private: - using BufferSet = std::vector; + using BufferSet = std::vector>; std::vector create_adjacency_matrix(const LinearIR& linear_ir, const BufferSet& buffers) const; std::map coloring(BufferSet& buffers, std::vector& adj); diff --git a/src/common/snippets/src/lowered/pass/identify_buffers.cpp b/src/common/snippets/src/lowered/pass/identify_buffers.cpp index a3fd9157b92056..699f201bba36d6 100644 --- a/src/common/snippets/src/lowered/pass/identify_buffers.cpp +++ b/src/common/snippets/src/lowered/pass/identify_buffers.cpp @@ -14,11 +14,6 @@ namespace lowered { namespace pass { namespace { -auto is_intermediate_buffer(const std::shared_ptr& op) -> std::shared_ptr { - const auto buffer = ov::as_type_ptr(op); - return buffer && buffer->is_intermediate_memory() ? buffer : nullptr; -} - inline size_t index(size_t col_num, size_t row, size_t col) { return row * col_num + col; } @@ -34,73 +29,70 @@ std::vector IdentifyBuffers::create_adjacency_matrix(const LinearIR& linea for (size_t i = 0; i < size; ++i) adj[index(size, i, i)] = true; - auto update_adj_matrix = [&](const std::shared_ptr& buffer, size_t buffer_index, - const std::shared_ptr& neighbour_buffer, - size_t buffer_loop_port, size_t neighbour_buffer_loop_port, - const std::vector& ptr_increments, - const std::vector& io_data_sizes) { - if (neighbour_buffer) { - // TODO: What's about finalization offsets? It's needed? - if (ptr_increments[buffer_loop_port] != ptr_increments[neighbour_buffer_loop_port] || - io_data_sizes[buffer_loop_port] != io_data_sizes[neighbour_buffer_loop_port]) { - const auto iter = std::find(buffers.cbegin(), buffers.cend(), linear_ir.get_expr_by_node(neighbour_buffer)); - NGRAPH_CHECK(iter != buffers.cend(), "Buffer wasn't find in Buffer system of Subgraph"); - - const size_t adj_idx = std::distance(buffers.cbegin(), iter); - adj[index(size, adj_idx, buffer_index)] = adj[index(size, buffer_index, adj_idx)] = true; - } + // < ptr_increment, finalization_offset > + using ShiftPtrParams = std::pair; + + auto get_buffer_idx = [&](const std::shared_ptr& buffer) { + const auto iter = std::find(buffers.cbegin(), buffers.cend(), buffer); + NGRAPH_CHECK(iter != buffers.cend(), "Buffer wasn't find in Buffer system of Subgraph"); + return std::distance(buffers.cbegin(), iter); + }; + + auto update_adj_matrix = [&](const std::pair, ShiftPtrParams>& buffer, + const std::pair, ShiftPtrParams>& neighbour_buffer) { + const bool equal_ptr_params_shifting = buffer.second == neighbour_buffer.second; + const bool equal_element_type_sizes = buffer.first->get_element_type().size() == neighbour_buffer.first->get_element_type().size(); + if (!equal_ptr_params_shifting || ((buffer.second.first != 0 || buffer.second.second != 0) && !equal_element_type_sizes)) { + const auto buffer_idx = get_buffer_idx(buffer.first); + const auto neighbour_idx = get_buffer_idx(neighbour_buffer.first); + adj[index(size, neighbour_idx, buffer_idx)] = adj[index(size, buffer_idx, neighbour_idx)] = true; } }; - for (size_t buffer_idx = 0; buffer_idx < buffers.size(); ++buffer_idx) { - // Here intermediate Buffer - const auto& buffer_expr = buffers[buffer_idx]; - const auto buffer = ov::as_type_ptr(buffer_expr->get_node()); - const auto& buffer_tensor = buffer_expr->get_input_tensor(0); - const auto buffer_siblings = buffer_tensor->get_consumers(); - for (const auto& buffer_sibling : buffer_siblings) { - const auto& sibling_expr = buffer_sibling.get_expr(); - // Skip myself - if (sibling_expr == buffer_expr) { - continue; - } else if (const auto loop_end = ov::as_type_ptr(sibling_expr->get_node())) { - const auto loop_tds = sibling_expr->get_input_tensors(); - const auto input_count = loop_end->get_input_num(); - const auto output_count = loop_end->get_output_num(); - const auto& ptr_increments = loop_end->get_ptr_increments(); - const auto& io_data_sizes = loop_end->get_element_type_sizes(); - const auto buffer_loop_port = std::distance(loop_tds.begin(), std::find(loop_tds.begin(), loop_tds.end(), buffer_tensor)); - - // Verify Buffers on Loop inputs: - for (size_t input_idx = 0; input_idx < input_count; ++input_idx) { - const auto& loop_in = loop_tds[input_idx]->get_source().get_expr(); - if (const auto& neighbour_buffer = is_intermediate_buffer(loop_in->get_node())) { - const auto neighbour_buffer_loop_port = input_idx; - update_adj_matrix(buffer, buffer_idx, neighbour_buffer, - buffer_loop_port, neighbour_buffer_loop_port, - ptr_increments, io_data_sizes); - } - } + for (auto expr_it = linear_ir.cbegin(); expr_it != linear_ir.cend(); expr_it++) { + const auto &expr = *expr_it; + const auto& loop_end = ov::as_type_ptr(expr->get_node()); + if (!loop_end) + continue; + + const auto input_count = loop_end->get_input_num(); + const auto output_count = loop_end->get_output_num(); - // Verify Buffers on Loop outputs - for (size_t output_idx = 0; output_idx < output_count; ++output_idx) { - // Skip the current Buffer - if (buffer_tensor == loop_tds[input_count + output_idx]) - continue; - - const auto consumer_inputs = loop_tds[input_count + output_idx]->get_consumers(); - for (const auto& consumer_input : consumer_inputs) { - const auto& child_node = consumer_input.get_expr()->get_node(); - if (const auto& neighbour_buffer = is_intermediate_buffer(child_node)) { - const auto neighbour_buffer_loop_port = input_count + output_idx; - update_adj_matrix(buffer, buffer_idx, neighbour_buffer, - buffer_loop_port, neighbour_buffer_loop_port, - ptr_increments, io_data_sizes); - } - } + const auto ptr_increments = loop_end->get_ptr_increments(); + const auto finalization_offsets = loop_end->get_finalization_offsets(); + + // Buffer -> + std::map, ShiftPtrParams> buffer_neighbours; + + for (size_t i = 0; i < input_count; ++i) { + const auto& parent_output = expr->get_input_tensor(i)->get_source().get_expr(); + if (const auto buffer = ov::as_type_ptr(parent_output->get_node())) { + buffer_neighbours[buffer] = { ptr_increments[i], finalization_offsets[i] }; + } + } + for (size_t i = 0; i < output_count; ++i) { + // The consumers of the corresponding Store ops + const auto index = input_count + i; + const auto consumer_inputs = expr->get_input_tensor(index)->get_consumers(); + size_t buffer_count = 0; + size_t loop_count = 0; + for (const auto& consumer_input : consumer_inputs) { + const auto& child_node = consumer_input.get_expr()->get_node(); + if (const auto buffer = ov::as_type_ptr(child_node)) { + buffer_neighbours[buffer] = { ptr_increments[index], finalization_offsets[index] }; + } else if (ov::is_type(child_node)) { + loop_count++; } - } else { - OPENVINO_THROW("Buffer has incorrect siblings! There can be only LoopEnds"); + } + if (buffer_count > 0) { + OPENVINO_ASSERT((buffer_count == 1) && (buffer_count + loop_count == consumer_inputs.size()), + "Loop output must have not more than 1 Buffer"); + } + } + + for (auto buffer_it = buffer_neighbours.begin(); buffer_it != buffer_neighbours.end(); ++buffer_it) { + for (auto neighbour_it = std::next(buffer_it); neighbour_it != buffer_neighbours.end(); ++neighbour_it) { + update_adj_matrix(*buffer_it, *neighbour_it); } } } @@ -161,9 +153,8 @@ bool IdentifyBuffers::run(LinearIR& linear_ir) { BufferSet buffer_exprs; for (const auto& expr : linear_ir) { - const auto& op = expr->get_node(); - if (const auto buffer = is_intermediate_buffer(op)) { - buffer_exprs.push_back(expr); + if (const auto buffer = ov::as_type_ptr(expr->get_node())) { + buffer_exprs.push_back(buffer); } } @@ -176,8 +167,7 @@ bool IdentifyBuffers::run(LinearIR& linear_ir) { for (const auto& pair : color_groups) { const auto color = pair.first; const auto& united_buffers = pair.second; - for (const auto& buffer_expr : united_buffers) { - const auto buffer = ov::as_type_ptr(buffer_expr->get_node()); + for (const auto& buffer : united_buffers) { buffer->set_id(color); } }