Skip to content

Commit

Permalink
[Snippets][CPU] Fixed FULL_DIM in subtensor and updated BrgemmBlocking
Browse files Browse the repository at this point in the history
  • Loading branch information
a-sidorova committed Jul 3, 2024
1 parent 62eb95e commit 033617b
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 17 deletions.
3 changes: 2 additions & 1 deletion src/common/snippets/src/lowered/port_descriptor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ namespace ov {
namespace snippets {
namespace lowered {

size_t PortDescriptor::ServiceDimensions::FULL_DIM = SIZE_MAX;
// To avoid SIZE_MAX - is dynamic value
size_t PortDescriptor::ServiceDimensions::FULL_DIM = SIZE_MAX - 1;

PortDescriptor::PortDescriptor(const ov::Input<ov::Node>& in, VectorDims subtensor_shape, std::vector<size_t> layout)
: PortDescriptor(ov::Input<const Node>(in.get_node(), in.get_index()), std::move(subtensor_shape), std::move(layout)) {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "snippets/lowered/linear_ir.hpp"
#include "snippets/lowered/loop_manager.hpp"
#include "snippets/lowered/pass/pass.hpp"
#include "snippets/lowered/pass/propagate_subtensors.hpp"
#include "snippets/snippets_isa.hpp"
#include "snippets/utils.hpp"
#include "transformations/snippets/x64/op/brgemm_cpu.hpp"
Expand Down Expand Up @@ -107,12 +108,20 @@ bool BrgemmBlocking::run(LinearIR& linear_ir, LinearIR::constExprIt begin, Linea
const auto block_size_n = snippets::utils::is_dynamic_value(n) ? brgemm->get_n_block_size() : std::min(brgemm->get_n_block_size(), n);
const auto block_size_k = snippets::utils::is_dynamic_value(k) ? brgemm->get_k_block_size() : std::min(brgemm->get_k_block_size(), k);

*++in_0_subtensor.rbegin() = block_size_m;
*++out_subtensor.rbegin() = block_size_m;
*in_1_subtensor.rbegin() = block_size_n;
*out_subtensor.rbegin() = block_size_n;
*in_0_subtensor.rbegin() = block_size_k;
*++in_1_subtensor.rbegin() = block_size_k;
// If block_size is dynamic, it means that Brgemm will process full tensor:
// subtensor[i] = FULL_DIM as by default
if (!snippets::utils::is_dynamic_value(block_size_m)) {
*++in_0_subtensor.rbegin() = block_size_m;
*++out_subtensor.rbegin() = block_size_m;
}
if (!snippets::utils::is_dynamic_value(block_size_n)) {
*in_1_subtensor.rbegin() = block_size_n;
*out_subtensor.rbegin() = block_size_n;
}
if (!snippets::utils::is_dynamic_value(block_size_k)) {
*in_0_subtensor.rbegin() = block_size_k;
*++in_1_subtensor.rbegin() = block_size_k;
}

brgemm_expr->get_input_port_descriptor(0)->set_subtensor(in_0_subtensor);
brgemm_expr->get_input_port_descriptor(1)->set_subtensor(in_1_subtensor);
Expand Down Expand Up @@ -141,6 +150,15 @@ bool BrgemmBlocking::run(LinearIR& linear_ir, LinearIR::constExprIt begin, Linea
}
}

auto get_default_handlers = [](size_t work_amount, size_t block_size) {
SpecificIterationHandlers handlers;
const auto tail_size = snippets::utils::is_dynamic_value(work_amount) ? snippets::utils::get_dynamic_value<size_t>() : work_amount % block_size;
if (tail_size != 0)
handlers.register_pass<snippets::lowered::SpecificLoopIterType::LAST_ITER, snippets::lowered::pass::UpdateSubtensors>(tail_size);
handlers.register_pass<snippets::lowered::SpecificLoopIterType::LAST_ITER, SetEvaluanceOnce>(true);
return handlers;
};

auto mark_m_blocking = [&](bool include_repacking) {
const auto loop_begin_it = get_loop_begin_pos(linear_ir, expr_it, include_repacking);
const auto loop_end_it = std::next(expr_it);
Expand All @@ -154,9 +172,8 @@ bool BrgemmBlocking::run(LinearIR& linear_ir, LinearIR::constExprIt begin, Linea
entries.emplace_back(brgemm_expr->get_input_port(2), false);
const std::vector<LoopPort> exits{LoopPort(brgemm_expr->get_output_port(0), true)};

const auto id = loop_manager->mark_loop(loop_begin_it, loop_end_it, m, block_size_m, 1, entries, exits);
const auto& loop_info = loop_manager->get_loop_info<ov::snippets::lowered::UnifiedLoopInfo>(id);
loop_info->register_pass_to_handler<ov::snippets::lowered::SpecificLoopIterType::LAST_ITER, SetEvaluanceOnce>(true);
const auto id = loop_manager->mark_loop(loop_begin_it, loop_end_it, m, block_size_m, 1, entries, exits, false);
loop_manager->get_loop_info<ov::snippets::lowered::UnifiedLoopInfo>(id)->set_handlers(get_default_handlers(m, block_size_m));
};

auto mark_n_blocking = [&]() {
Expand All @@ -168,9 +185,8 @@ bool BrgemmBlocking::run(LinearIR& linear_ir, LinearIR::constExprIt begin, Linea
LoopPort(brgemm_cpu && brgemm_cpu->is_with_data_repacking() ? copy_b_expr->get_input_port(0) : brgemm_expr->get_input_port(1), true)};
const std::vector<LoopPort> exits{LoopPort(brgemm_expr->get_output_port(0), true)};

const auto id = loop_manager->mark_loop(loop_begin_it, loop_end_it, n, block_size_n, 0, entries, exits);
const auto& loop_info = loop_manager->get_loop_info<ov::snippets::lowered::UnifiedLoopInfo>(id);
loop_info->register_pass_to_handler<ov::snippets::lowered::SpecificLoopIterType::LAST_ITER, SetEvaluanceOnce>(true);
const auto id = loop_manager->mark_loop(loop_begin_it, loop_end_it, n, block_size_n, 0, entries, exits, false);
loop_manager->get_loop_info<ov::snippets::lowered::UnifiedLoopInfo>(id)->set_handlers(get_default_handlers(n, block_size_n));
};

auto mark_k_blocking = [&]() {
Expand All @@ -182,10 +198,11 @@ bool BrgemmBlocking::run(LinearIR& linear_ir, LinearIR::constExprIt begin, Linea
LoopPort(brgemm_cpu && brgemm_cpu->is_with_data_repacking() ? copy_b_expr->get_input_port(0) : brgemm_expr->get_input_port(1), true, 1)};
const std::vector<LoopPort> exits{LoopPort(brgemm_expr->get_output_port(0), false)};

const auto id = loop_manager->mark_loop(loop_begin_it, loop_end_it, k, block_size_k, entries, exits);
const auto& loop_info = loop_manager->get_loop_info<ov::snippets::lowered::UnifiedLoopInfo>(id);
loop_info->register_pass_to_handler<ov::snippets::lowered::SpecificLoopIterType::FIRST_ITER, SetBrgemmBeta>(0.f);
loop_info->register_pass_to_handler<ov::snippets::lowered::SpecificLoopIterType::LAST_ITER, SetEvaluanceOnce>(true);
auto handlers = get_default_handlers(k, block_size_k);
handlers.register_pass<ov::snippets::lowered::SpecificLoopIterType::FIRST_ITER, SetBrgemmBeta>(0.f);

const auto id = loop_manager->mark_loop(loop_begin_it, loop_end_it, k, block_size_k, entries, exits, false);
loop_manager->get_loop_info<ov::snippets::lowered::UnifiedLoopInfo>(id)->set_handlers(handlers);
};

const bool k_blocking = block_size_k != k;
Expand Down

0 comments on commit 033617b

Please sign in to comment.