Skip to content

Commit

Permalink
AMX scratchpad buffer is moved inside blocking loops
Browse files Browse the repository at this point in the history
  • Loading branch information
v-Golubev committed Nov 21, 2023
1 parent 7ba840d commit e1fc633
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,18 @@ using ExpressionPtr = ov::snippets::lowered::ExpressionPtr;

BrgemmBlocking::BrgemmBlocking() : Pass() {}

/**
 * @brief Places the scratchpad Buffer expression of a Brgemm directly in front of that Brgemm in the Linear IR.
 *
 * The expression producing input port 2 of the Brgemm is looked up and must be a
 * snippets Buffer for which `is_new_memory()` holds; otherwise an assertion fires.
 * If that expression is not already the immediate predecessor of the Brgemm,
 * it is relocated to the position right before `brgemm_it`.
 *
 * @param linear_ir Linear IR that contains both the Brgemm and its scratchpad Buffer expressions.
 * @param brgemm_it Iterator pointing to the Brgemm expression; it must not be the first
 *                  element of the IR because its predecessor is inspected.
 */
void BrgemmBlocking::move_amx_scratchpad_buffer(snippets::lowered::LinearIR& linear_ir, const snippets::lowered::LinearIR::constExprIt& brgemm_it) {
    // Resolve the producer connected to input port 2 of the Brgemm expression.
    const auto brgemm = brgemm_it->get();
    const auto scratchpad_expr = brgemm->get_input_port_connector(2)->get_source().get_expr();
    const auto scratchpad = ov::as_type_ptr<ov::snippets::op::Buffer>(scratchpad_expr->get_node());
    OPENVINO_ASSERT(scratchpad && scratchpad->is_new_memory(), "Incorrect Scratchpad buffer for Brgemm AMX");

    // Nothing to relocate when the Buffer already sits immediately before the Brgemm.
    const auto preceding_it = std::prev(brgemm_it);
    if (*preceding_it == scratchpad_expr)
        return;

    // Otherwise pull the Buffer expression to the position right in front of the Brgemm.
    linear_ir.move(linear_ir.find(scratchpad_expr), brgemm_it);
}

bool BrgemmBlocking::run(LinearIR& linear_ir) {
OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::BrgemmBlocking")
if (linear_ir.empty())
Expand Down Expand Up @@ -75,12 +87,17 @@ bool BrgemmBlocking::run(LinearIR& linear_ir) {
*(in_0_subtensor.rbegin() + 1) = block_size_m;
*(out_subtensor.rbegin() + 1) = block_size_m;

auto loop_begin_it = expr_it, loop_end_it = std::next(expr_it);
std::vector<LoopPort> entries{LoopPort(brgemm_expr->get_input_port(0), true),
LoopPort(brgemm_expr->get_input_port(1), false)};
if (brgemm->is_with_scratchpad())
if (brgemm->is_with_compensations()) {
entries.emplace_back(brgemm_expr->get_input_port(2), false);
} else if (brgemm->is_amx()) {
move_amx_scratchpad_buffer(linear_ir, expr_it);
loop_begin_it = std::prev(expr_it);
}
std::vector<LoopPort> exits{LoopPort(brgemm_expr->get_output_port(0), true)};
loop_manager->mark_loop(expr_it, std::next(expr_it), m, block_size_m, 1, entries, exits);
loop_manager->mark_loop(loop_begin_it, loop_end_it, m, block_size_m, 1, entries, exits);
}
};

Expand All @@ -94,15 +111,17 @@ bool BrgemmBlocking::run(LinearIR& linear_ir) {
*in_1_subtensor.rbegin() = block_size_n;
*out_subtensor.rbegin() = block_size_n;

auto loop_begin_it = expr_it, loop_end_it = std::next(expr_it);
std::vector<LoopPort> entries{LoopPort(brgemm_expr->get_input_port(0), false),
LoopPort(brgemm_expr->get_input_port(1), true)};
if (brgemm->is_with_scratchpad()) {
// The second input of Brgemm for AMX case is scratch buffer so it mustn't be incremented
const bool is_incremented = brgemm->is_with_compensations() ? true : false;
entries.emplace_back(brgemm_expr->get_input_port(2), is_incremented);
if (brgemm->is_with_compensations()) {
entries.emplace_back(brgemm_expr->get_input_port(2), true);
} else if (brgemm->is_amx()) {
move_amx_scratchpad_buffer(linear_ir, expr_it);
loop_begin_it = std::prev(expr_it);
}
std::vector<LoopPort> exits{LoopPort(brgemm_expr->get_output_port(0), true)};
loop_manager->mark_loop(expr_it, std::next(expr_it), n, block_size_n, 0, entries, exits);
loop_manager->mark_loop(loop_begin_it, loop_end_it, n, block_size_n, 0, entries, exits);
}
};

Expand All @@ -117,12 +136,17 @@ bool BrgemmBlocking::run(LinearIR& linear_ir) {
*in_0_subtensor.rbegin() = block_size_k;
*(in_1_subtensor.rbegin() + 1) = block_size_k;

auto loop_begin_it = expr_it, loop_end_it = std::next(expr_it);
std::vector<LoopPort> entries{LoopPort(brgemm_expr->get_input_port(0), true, 0),
LoopPort(brgemm_expr->get_input_port(1), true, 1)};
if (brgemm->is_with_scratchpad())
if (brgemm->is_with_compensations()) {
entries.emplace_back(brgemm_expr->get_input_port(2), false, 1);
} else if (brgemm->is_amx()) {
move_amx_scratchpad_buffer(linear_ir, expr_it);
loop_begin_it = std::prev(expr_it);
}
std::vector<LoopPort> exits{LoopPort(brgemm_expr->get_output_port(0), false)};
auto loop_id = loop_manager->mark_loop(expr_it, std::next(expr_it), k, block_size_k, entries, exits);
auto loop_id = loop_manager->mark_loop(loop_begin_it, loop_end_it, k, block_size_k, entries, exits);
const auto loop_info = loop_manager->get_loop_info(loop_id);

auto first_iter_handler = [](LinearIR& linear_ir, LinearIR::constExprIt loop_end_it) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ class BrgemmBlocking : public snippets::lowered::pass::Pass {
OPENVINO_RTTI("BrgemmBlocking", "Pass")
BrgemmBlocking();
bool run(snippets::lowered::LinearIR& linear_ir) override;

private:
static void move_amx_scratchpad_buffer(snippets::lowered::LinearIR& linear_ir, const snippets::lowered::LinearIR::constExprIt& brgemm_it);
};

} // namespace pass
Expand Down

0 comments on commit e1fc633

Please sign in to comment.