Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
v-Golubev committed Dec 21, 2023
1 parent 6d88e63 commit ba609fa
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 10 deletions.
1 change: 1 addition & 0 deletions src/common/snippets/src/generator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "snippets/lowered/pass/cleanup_loop_offsets.hpp"
#include "snippets/lowered/pass/insert_specific_iterations.hpp"
#include "snippets/lowered/pass/optimize_loop_single_evaluation.hpp"
#include "snippets/lowered/pass/serialize_control_flow.hpp"
#include "snippets/lowered/pass/pass.hpp"
#include "snippets/op/kernel.hpp"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,9 @@ bool InsertSpecificIterations::run(LinearIR& linear_ir) {
}
handlers[LoopInfo::LAST_ITER].run(linear_ir, main_body_begin_it, main_body_end_it);
update_loop_params(loop_end, tail_size, tail_size, false);
} else if (specific_first_iteration) {
handlers[LoopInfo::MAIN_BODY].run(linear_ir, main_body_begin_it, main_body_end_it);
update_loop_params(loop_end, work_amount - increment, increment, false);
}
}
modified = true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,12 @@ bool BrgemmBlocking::run(LinearIR& linear_ir) {

auto apply_m_blocking = [&]() {
const auto& m = *(out_preordered_dims.rbegin() + 1);
auto block_size_m = brgemm->get_m_block_size();
if (block_size_m > m)
block_size_m = m;
const auto block_size_m = brgemm->get_m_block_size();
if (block_size_m >= m) {
*(in_0_subtensor.rbegin() + 1) = m;
*(out_subtensor.rbegin() + 1) = m;
return;
}

*(in_0_subtensor.rbegin() + 1) = block_size_m;
*(out_subtensor.rbegin() + 1) = block_size_m;
Expand All @@ -102,9 +105,12 @@ bool BrgemmBlocking::run(LinearIR& linear_ir) {

auto apply_n_blocking = [&]() {
const auto& n = *out_preordered_dims.rbegin();
auto block_size_n = brgemm->get_n_block_size();
if (block_size_n > n)
block_size_n = n;
const auto block_size_n = brgemm->get_n_block_size();
if (block_size_n >= n) {
*in_1_subtensor.rbegin() = n;
*out_subtensor.rbegin() = n;
return;
}

*in_1_subtensor.rbegin() = block_size_n;
*out_subtensor.rbegin() = block_size_n;
Expand All @@ -124,15 +130,19 @@ bool BrgemmBlocking::run(LinearIR& linear_ir) {
auto apply_k_blocking = [&]() {
const auto& k = *in_0_planar_dims.rbegin();
OPENVINO_ASSERT(k == *(in_1_planar_dims.rbegin() + 1), "Brgemm input descriptors have different K dimension value.");
auto block_size_k = brgemm->get_k_block_size();
if (block_size_k > k)
block_size_k = k;
const auto block_size_k = brgemm->get_k_block_size();
if (block_size_k >= k) {
*in_0_subtensor.rbegin() = k;
*(in_1_subtensor.rbegin() + 1) = k;
brgemm->set_beta(0.f);
return;
}

*in_0_subtensor.rbegin() = block_size_k;
*(in_1_subtensor.rbegin() + 1) = block_size_k;
auto loop_begin_it = expr_it, loop_end_it = std::next(expr_it);
std::vector<LoopPort> entries{LoopPort(brgemm_expr->get_input_port(0), true, 0),
LoopPort(brgemm_expr->get_input_port(1), true, 1)};
LoopPort(brgemm_expr->get_input_port(1), true, 1)};
if (brgemm->is_with_compensations()) {
entries.emplace_back(brgemm_expr->get_input_port(2), false, 1);
} else if (brgemm->is_amx()) {
Expand Down

0 comments on commit ba609fa

Please sign in to comment.