From 3a0d437e29416819067614a3c43b67813df324c7 Mon Sep 17 00:00:00 2001 From: Alexandra Sidorova Date: Tue, 13 Jun 2023 10:22:30 +0400 Subject: [PATCH] [Snippets] Fixed Buffer insertion position --- .../snippets/src/lowered/pass/insert_buffers.cpp | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/src/common/snippets/src/lowered/pass/insert_buffers.cpp b/src/common/snippets/src/lowered/pass/insert_buffers.cpp index 0eb75a33749abf..1b1d65a504cb86 100644 --- a/src/common/snippets/src/lowered/pass/insert_buffers.cpp +++ b/src/common/snippets/src/lowered/pass/insert_buffers.cpp @@ -225,12 +225,26 @@ void InsertBuffers::insertion(LinearIR& linear_ir, const LinearIR::LoopManagerPt } } + // potential_consumers is unsorted by linear IR set. + // We have to find first expr in Linear IR from the set to insert Buffer before *all* consumers + OPENVINO_ASSERT(!potential_consumers.empty(), "Buffer should have one consumer at least"); + auto consumer_expr = potential_consumers.begin()->get_expr(); + if (potential_consumers.size() > 1) { + std::set consumers; + for (const auto& port : potential_consumers) + consumers.insert(port.get_expr()); + const auto it = std::find_if(linear_ir.cbegin(), linear_ir.cend(), + [&consumers](const ExpressionPtr& expr) { return consumers.count(expr) > 0; }); + OPENVINO_ASSERT(it != linear_ir.cend(), "Consumer of Buffer has not been found in Linear IR"); + consumer_expr = *it; + } + // We should insert Buffer between first different Loops. // Example: Current expr Loop identifies: 3, 2, 1 // Target consumers Loop identifies: 3, 4, 6 // Need to insert after 2nd Loops // Note: All potential consumers must have the same count of first equal Loop identifies and the same count of different last identifies - const auto pos = insertion_position(linear_ir, loop_manager, expr, (*potential_consumers.begin()).get_expr()); + const auto pos = insertion_position(linear_ir, loop_manager, expr, consumer_expr); const auto allocation_shape = compute_allocation_shape(loop_manager, buffer_loop_ids,