From 3f554bedc87a258440f88e24c98018a66ce535f4 Mon Sep 17 00:00:00 2001 From: Vladislav Golubev Date: Tue, 12 Sep 2023 19:09:12 +0200 Subject: [PATCH] [LIR] CleanupLoopOffsets fix --- .../src/lowered/pass/cleanup_loop_offsets.cpp | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/src/common/snippets/src/lowered/pass/cleanup_loop_offsets.cpp b/src/common/snippets/src/lowered/pass/cleanup_loop_offsets.cpp index 8d4a529c1667ca..79c9a115718c1f 100644 --- a/src/common/snippets/src/lowered/pass/cleanup_loop_offsets.cpp +++ b/src/common/snippets/src/lowered/pass/cleanup_loop_offsets.cpp @@ -48,16 +48,23 @@ bool CleanupLoopOffsets::run(LinearIR& linear_ir) { const auto& found = per_port_connector_offset.find(managed_connector); if (found != per_port_connector_offset.end()) { // Since data ptr is incremented on [ptr_increment x increment], - // we should guarantee proportionality of ptr shifts + // we should guarantee proportionality of ptr shifts. + // If the data ptr can't be proportionally shifted, the optimization is not applied // For example, // Inner Loop: WA = 32, Inc = 1, ptr_increment[0] = 20, final_offset[0] = -640 // Outer Loop: WA = 70, Inc = 32, ptr_increment[0] = 20, final_offset[0] = -1400 // To save data ptr shift proportionality, we have to calculate so: // outer_ptr_increment[0] = (inner_final_offset[0] + outer_ptr_increment[0] * outer_Inc) / outer_Inc // outer_ptr_increment[0] = (-640 + 20 x 32) / 32 = 0 - outer_ptr_increments[i] = (fin_offsets[found->second] + outer_ptr_increments[i] * outer_increment) / outer_increment; - fin_offsets[found->second] = 0; - is_modified = true; + + const auto full_outer_increment = outer_ptr_increments[i] * outer_increment; + const auto new_final_outer_increment = full_outer_increment + fin_offsets[found->second]; + + if (new_final_outer_increment % outer_increment == 0) { + outer_ptr_increments[i] = new_final_outer_increment / outer_increment; + fin_offsets[found->second] = 0; + is_modified = true; + } } } outer_loop_end->set_ptr_increments(outer_ptr_increments);