Skip to content

Commit

Permalink
[LIR] Added custom first iteration handler for K blocking
Browse files Browse the repository at this point in the history
  • Loading branch information
v-Golubev committed Sep 11, 2023
1 parent 3112f7d commit 11a2a01
Show file tree
Hide file tree
Showing 7 changed files with 137 additions and 51 deletions.
20 changes: 20 additions & 0 deletions src/common/snippets/include/snippets/lowered/loop_manager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ class LinearIR::LoopManager {
// Returns dimension index if dimension indices for all entry and exit points are equal, and SIZE_MAX otherwise
size_t get_dim_idx() const;

// Optional per-loop callback. InsertTailLoop::run calls it with the LinearIR and the iterator of
// this loop's LoopEnd expression before the regular tail handling, and ORs the returned flag into
// the pass' "modified" result.
using FirstIterHandler = std::function<bool(LinearIR&, LinearIR::constExprIt)>;
// Stores the given callback in fst_iter_handler (the argument is taken by value and moved in).
void set_first_iter_handler(FirstIterHandler handler);
// Empty (nullptr) by default; InsertTailLoop::run skips the call when no handler is set.
FirstIterHandler fst_iter_handler = nullptr;

size_t work_amount = 0;
size_t increment = 0;
// The order of entry and exit expressions is important:
Expand Down Expand Up @@ -115,6 +119,22 @@ class LinearIR::LoopManager {
return loop_id;
}

// Registers a new loop described by (work_amount, increment, entries, exits) and, for every
// expression in [loop_begin_pos, loop_end_pos), substitutes the freshly assigned loop id for old_id.
// Returns the id of the newly registered loop.
template <typename T>
size_t mark_loop_with_old_loop_replacement(LinearIR::constExprIt loop_begin_pos,
                                           LinearIR::constExprIt loop_end_pos,
                                           size_t work_amount,
                                           size_t increment,
                                           const std::vector<T>& entries,
                                           const std::vector<T>& exits,
                                           const size_t old_id) {
    // Create and register the loop description in one step; the manager hands back its id.
    const auto new_id =
        this->add_loop_info(std::make_shared<LoopManager::LoopInfo>(work_amount, increment, entries, exits));

    // Walk the half-open range and swap the old id for the new one on each expression.
    auto cursor = loop_begin_pos;
    while (cursor != loop_end_pos) {
        replace_loop_id(*cursor, old_id, new_id);
        ++cursor;
    }
    return new_id;
}

void fuse_loops(const LinearIR& linear_ir, size_t loop_id_upper, size_t loop_id_lower, bool fuse_into_upper = true);
void fuse_loops(LinearIR::constExprIt loop_begin_target, LinearIR::constExprIt loop_end_target,
size_t loop_id_upper, size_t loop_id_lower, bool fuse_into_upper = true);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ class InsertTailLoop : public Pass {
public:
OPENVINO_RTTI("InsertTailLoop", "Pass")
bool run(LinearIR& linear_ir) override;
static bool optimize_single_evaluation(const std::shared_ptr<op::LoopEnd>& loop);
static LinearIR::container copy_loop(const LinearIR& linear_ir, const size_t loop_id);

private:
static std::shared_ptr<op::LoopEnd> create_tail_loop(LinearIR& linear_ir,
Expand All @@ -37,7 +39,6 @@ class InsertTailLoop : public Pass {
LinearIR::constExprIt tail_begin,
LinearIR::constExprIt tail_end,
size_t tail_size);
static bool optimize_single_evaluation(const std::shared_ptr<op::LoopEnd>& loop);
};

} // namespace pass
Expand Down
4 changes: 4 additions & 0 deletions src/common/snippets/src/lowered/loop_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ size_t LinearIR::LoopManager::LoopInfo::get_dim_idx() const {
}
}

// Installs `handler` as this loop's first-iteration callback.
void LinearIR::LoopManager::LoopInfo::set_first_iter_handler(FirstIterHandler handler) {
    // Exchange the stored callback with the by-value argument: the member ends up holding the new
    // handler, and the previously stored one is released together with the parameter.
    fst_iter_handler.swap(handler);
}

bool operator==(const LinearIR::LoopManager::LoopPort& lhs, const LinearIR::LoopManager::LoopPort& rhs) {
if (&lhs == &rhs)
return true;
Expand Down
124 changes: 75 additions & 49 deletions src/common/snippets/src/lowered/pass/insert_tail_loop.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,60 @@ namespace ov {
namespace snippets {
namespace lowered {
namespace pass {
LinearIR::container InsertTailLoop::copy_loop(const LinearIR& linear_ir, const size_t loop_id) {
const auto& loop_manager = linear_ir.get_loop_manager();
const auto original_loop_info = loop_manager->get_loop_info(loop_id);
auto new_entry_points = original_loop_info->entry_points;
auto new_exit_points = original_loop_info->exit_points;

namespace {
void replace_ports_with_tail_ports(const ExpressionPtr& expr,
const ExpressionPtr& tail_expr,
std::vector<LinearIR::LoopManager::LoopPort>& ports) {
auto find_if_predicate = [&](const LinearIR::LoopManager::LoopPort& port) {
return port.expr_port->get_expr()->get_node() == expr->get_node();
// Re-targets every loop port in `ports` whose expression refers to the same node as `expr`
// so that it points to `tail_expr` instead; the port type and index are preserved.
// The list is traversed exactly once, so the helper always terminates, even if `tail_expr`
// happens to share its node with `expr` (the previous re-search from the just-replaced
// element relied on the replacement no longer matching).
auto update_loop_ports = [](const ExpressionPtr& expr,
                            const ExpressionPtr& tail_expr,
                            std::vector<LinearIR::LoopManager::LoopPort>& ports) {
    // The node being matched does not change during the traversal, so fetch it once.
    const auto target_node = expr->get_node();
    for (auto& port : ports) {
        if (port.expr_port->get_expr()->get_node() == target_node) {
            port.expr_port = std::make_shared<ExpressionPort>(tail_expr,
                                                              port.expr_port->get_type(),
                                                              port.expr_port->get_index());
        }
    }
};
auto pos = std::find_if(ports.begin(), ports.end(), find_if_predicate);
while (pos != ports.end()) {
pos->expr_port = std::make_shared<ExpressionPort>(tail_expr, pos->expr_port->get_type(), pos->expr_port->get_index());
pos = std::find_if(pos, ports.end(), find_if_predicate);
}

// Callback handed to LinearIR::deep_copy_range: for each (original, copy) expression pair it
// re-targets the cloned entry/exit points to the copy and registers the copy's ports in the
// loops that precede `loop_id` in the expression's loop id list.
auto update_loop_info = [&](const ExpressionPtr& expr, const ExpressionPtr& new_expr) {
    // LoopBegin/LoopEnd ops are never loop ports, so there is nothing to update for them.
    if (!ov::is_type<op::LoopBase>(expr->get_node())) {
        // Point the cloned entry/exit ports at the copied expression.
        update_loop_ports(expr, new_expr, new_entry_points);
        update_loop_ports(expr, new_expr, new_exit_points);

        // Collect the ids located before `loop_id` in this expression's loop id list
        // (all of them if `loop_id` is absent).
        // NOTE(review): the original code calls these "inner" loops - confirm against the
        // ordering convention of get_loop_ids().
        const auto all_ids = expr->get_loop_ids();
        const auto own_id_pos = std::find(all_ids.begin(), all_ids.end(), loop_id);
        std::vector<size_t> preceding_ids(all_ids.begin(), own_id_pos);

        // In those loops, each port of the original is replaced by the pair {original, copy}.
        for (size_t in_idx = 0; in_idx < expr->get_input_count(); ++in_idx) {
            const auto& src_port = expr->get_input_port(in_idx);
            loop_manager->update_loops_port(preceding_ids, src_port, {src_port, new_expr->get_input_port(in_idx)}, true);
        }
        for (size_t out_idx = 0; out_idx < expr->get_output_count(); ++out_idx) {
            const auto& src_port = expr->get_output_port(out_idx);
            loop_manager->update_loops_port(preceding_ids, src_port, {src_port, new_expr->get_output_port(out_idx)}, false);
        }
    }
};

LinearIR::constExprIt loop_begin_pos, loop_end_pos;
loop_manager->get_loop_bounds(linear_ir, loop_id, loop_begin_pos, loop_end_pos, true);
const auto loop_copy_range = LinearIR::deep_copy_range(loop_begin_pos, std::next(loop_end_pos), update_loop_info);
const auto new_loop_begin_pos = loop_copy_range.begin();
const auto new_loop_end_pos = loop_copy_range.end();
const auto new_id = loop_manager->mark_loop_with_old_loop_replacement(std::next(new_loop_begin_pos),
std::prev(new_loop_end_pos),
original_loop_info->work_amount,
original_loop_info->increment,
new_entry_points,
new_exit_points,
loop_id);
const auto loop_end = ov::as_type_ptr<op::LoopEnd>(std::prev(new_loop_end_pos)->get()->get_node());
loop_end->set_id(new_id);
return loop_copy_range;
}
} // namespace

std::shared_ptr<op::LoopEnd> InsertTailLoop::create_tail_loop(LinearIR& linear_ir,
LinearIR::constExprIt vector_begin,
Expand All @@ -47,41 +86,16 @@ std::shared_ptr<op::LoopEnd> InsertTailLoop::create_tail_loop(LinearIR& linear_i
const auto& original_loop_info = loop_manager->get_loop_info(original_loop_id);
auto tail_loop_info = original_loop_info;
if (need_vector_loop) {
auto tail_entry_points = original_loop_info->entry_points;
auto tail_exit_points = original_loop_info->exit_points;

auto update_loop_info = [&](const ExpressionPtr& expr, const ExpressionPtr& new_expr) {
const auto node = expr->get_node();
// Loop begin/end ops can't be loop ports
if (ov::is_type<op::LoopBase>(node))
return;
// Clone loop ports from original loop info to tail loop info
replace_ports_with_tail_ports(expr, new_expr, tail_entry_points);
replace_ports_with_tail_ports(expr, new_expr, tail_exit_points);
const auto new_loop_range = copy_loop(linear_ir, original_loop_id);
const auto loop_end = ov::as_type_ptr<op::LoopEnd>(std::prev(new_loop_range.end())->get()->get_node());
loop_end->set_work_amount(tail_size);
loop_end->set_increment(tail_size);
tail_loop_info = loop_manager->get_loop_info(loop_end->get_id());
tail_loop_info->work_amount = tail_size;
tail_loop_info->increment = tail_size;

// Update loop info of all inner loops with new loop ports
const auto loop_ids = expr->get_loop_ids();
auto cur_id_pos = std::find(loop_ids.begin(), loop_ids.end(), original_loop_id);
std::vector<size_t> inner_loop_ids(loop_ids.begin(), cur_id_pos);
for (size_t i = 0; i < expr->get_input_count(); ++i)
loop_manager->update_loops_port(inner_loop_ids, expr->get_input_port(i), {expr->get_input_port(i), new_expr->get_input_port(i)}, true);
for (size_t i = 0; i < expr->get_output_count(); ++i)
loop_manager->update_loops_port(inner_loop_ids, expr->get_output_port(i), {expr->get_output_port(i), new_expr->get_output_port(i)}, false);
};
auto vector_loop_deep_copy = LinearIR::deep_copy_range(vector_begin, vector_end, update_loop_info);
tail_begin = linear_ir.insert(vector_end, vector_loop_deep_copy.begin(), vector_loop_deep_copy.end());
tail_begin = linear_ir.insert(vector_end, new_loop_range.begin(), new_loop_range.end());
tail_end = vector_end;

const auto new_id = loop_manager->mark_loop(std::next(tail_begin),
std::prev(tail_end),
tail_size,
tail_size,
tail_entry_points,
tail_exit_points);
const auto loop_begin = ov::as_type_ptr<op::LoopBegin>(tail_begin->get()->get_node());
const auto loop_end = loop_begin->get_loop_end();
loop_end->set_id(new_id);
tail_loop_info = loop_manager->get_loop_info(new_id);
} else {
tail_begin = vector_begin;
tail_end = vector_end;
Expand Down Expand Up @@ -140,6 +154,9 @@ std::shared_ptr<op::LoopEnd> InsertTailLoop::create_tail_loop(LinearIR& linear_i
tail_loop_end->set_work_amount(tail_size);
tail_loop_end->set_finalization_offsets(tail_finalization_offsets);
tail_loop_end->has_outer_loop = vector_loop_end->has_outer_loop;
const auto new_vector_loop_wa = original_loop_info->work_amount - tail_size;
original_loop_info->work_amount = new_vector_loop_wa;
vector_loop_end->set_work_amount(new_vector_loop_wa);
return tail_loop_end;
}

Expand Down Expand Up @@ -171,21 +188,23 @@ void InsertTailLoop::tail_transformations(LinearIR& linear_ir,
// correct math calculations for ReduceMax and ReduceSum in scalar case.
// Note: We find Maximum and Add ops because HorizonMax and HorizonSum are outside Loop,
// so they are missed in <tail>
auto op = (*expr_it)->get_node();
const auto& expr = *expr_it;
const auto op = expr->get_node();
if (config.m_need_fill_tail_register &&
(ov::is_type<ov::op::v1::Maximum>(op) ||
ov::is_type<ov::op::v1::Add>(op))) {
for (size_t i = 0; i < op->inputs().size(); ++i) {
if (auto fill = insertFill(op->input(i))) {
const auto& input = expr_it->get()->get_input_port_connector(i);
const auto& input = expr->get_input_port_connector(i);
const auto consumers = input->get_consumers();
auto fill_expr = linear_ir.create_expression(fill, {input});
linear_ir.insert(expr_it, fill_expr);
linear_ir.replace_input(consumers, fill_expr->get_output_port_connector(0));
// in_reg == out_reg since we want to modify vector reg inplace
const auto reg = expr_it->get()->get_input_port_descriptor(0)->get_reg();
const auto reg = expr->get_input_port_descriptor(0)->get_reg();
fill_expr->get_input_port_descriptor(0)->set_reg(reg);
fill_expr->get_output_port_descriptor(0)->set_reg(reg);
fill_expr->set_loop_ids(expr->get_loop_ids());
}
}
} else if (const auto memory_access = std::dynamic_pointer_cast<ov::snippets::op::MemoryAccess>(op)) {
Expand Down Expand Up @@ -239,9 +258,16 @@ bool InsertTailLoop::run(LinearIR& linear_ir) {
if (!loop_end)
continue;

const auto loop_info = loop_manager->get_loop_info(loop_end->get_id());
if (loop_info->fst_iter_handler != nullptr) {
modified |= loop_info->fst_iter_handler(linear_ir, expr_it);
}

if (loop_end->get_evaluate_once() == true)
continue;

const auto work_amount = loop_end->get_work_amount();
const auto increment = loop_end->get_increment();
const auto loop_info = loop_manager->get_loop_info(loop_end->get_id());
const auto tail_size = work_amount % increment;
const auto need_tail = tail_size != 0;
const auto need_vector_loop = work_amount >= increment;
Expand Down
1 change: 1 addition & 0 deletions src/common/snippets/src/op/loop.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ bool LoopEnd::visit_attributes(AttributeVisitor &visitor) {
visitor.on_attribute("input_num", m_input_num);
visitor.on_attribute("output_num", m_output_num);
visitor.on_attribute("id", m_id);
visitor.on_attribute("evaluate_once", m_evaluate_once);
return true;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "snippets/itt.hpp"
#include "snippets/lowered/linear_ir.hpp"
#include "snippets/lowered/loop_manager.hpp"
#include "snippets/lowered/pass/insert_tail_loop.hpp"
#include "snippets/snippets_isa.hpp"
#include "transformations/snippets/x64/op/brgemm_cpu.hpp"

Expand Down Expand Up @@ -124,7 +125,39 @@ bool BrgemmBlocking::run(LinearIR& linear_ir) {
entries.emplace_back(brgemm_expr->get_input_port(2), true, 1);
std::vector<LoopPort> exits{LoopPort(brgemm_expr->get_output_port(0), false)};
auto loop_id = loop_manager->mark_loop(expr_it, std::next(expr_it), k, block_size_k, entries, exits);
loop_manager->get_loop_info(loop_id)->brgemm_k_blocking_loop = true;
const auto loop_info = loop_manager->get_loop_info(loop_id);
loop_info->brgemm_k_blocking_loop = true;

// First-iteration handler for the K-blocking loop.
// `expr_it` must point to the LoopEnd expression of the loop being processed.
// When the loop is a brgemm K-blocking loop with more than one iteration of work, the handler:
//   1. copies the loop and makes the copy process the remaining `work_amount - increment`,
//      setting beta = 1 on every BrgemmCPU inside the copy
//      (presumably so that the copy accumulates onto the first iteration's result - confirm);
//   2. inserts the copy right after the original LoopEnd;
//   3. shrinks the original loop to a single iteration (`work_amount == increment`),
//      zeroes its finalization offsets and applies optimize_single_evaluation to it.
// Returns true if the LinearIR was modified, false otherwise.
auto first_iter_handler = [](LinearIR& linear_ir, LinearIR::constExprIt expr_it) {
    const auto loop_end = ov::as_type_ptr<snippets::op::LoopEnd>(expr_it->get()->get_node());
    const auto loop_id = loop_end->get_id();
    const auto& loop_manager = linear_ir.get_loop_manager();
    const auto& loop_info = loop_manager->get_loop_info(loop_id);
    const auto work_amount = loop_info->work_amount;
    const auto increment = loop_info->increment;
    // Nothing to split off: not a K-blocking loop, or the loop has at most one iteration of work.
    if (!loop_info->brgemm_k_blocking_loop || work_amount <= increment)
        return false;

    // The copy handles everything except the first iteration.
    auto new_loop_range = snippets::lowered::pass::InsertTailLoop::copy_loop(linear_ir, loop_id);
    const auto new_loop_end = ov::as_type_ptr<snippets::op::LoopEnd>(std::prev(new_loop_range.end())->get()->get_node());
    new_loop_end->set_work_amount(work_amount - increment);
    auto new_loop_info = loop_manager->get_loop_info(new_loop_end->get_id());
    new_loop_info->work_amount = work_amount - increment;
    // Bind by reference: copying the shared pointer on every iteration would only add
    // needless reference-count traffic.
    for (const auto& expr : new_loop_range) {
        if (const auto brgemm = ov::as_type_ptr<ov::intel_cpu::BrgemmCPU>(expr->get_node())) {
            brgemm->set_beta(1.f);
        }
    }

    linear_ir.insert(std::next(expr_it), new_loop_range.begin(), new_loop_range.end());

    // The original loop now covers only the first iteration.
    loop_info->work_amount = increment;
    loop_end->set_work_amount(increment);
    loop_end->set_finalization_offsets(std::vector<int64_t>(loop_end->get_finalization_offsets().size(), 0));
    snippets::lowered::pass::InsertTailLoop::optimize_single_evaluation(loop_end);
    return true;
};
loop_info->set_first_iter_handler(first_iter_handler);
}
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ std::vector<std::vector<ov::PartialShape>> input_shapes{
{{1, 1, 32, 23}, {1, 1, 23, 68}},
{{1, 16, 384, 64}, {1, 16, 64, 384}},
{{1, 1, 100, 700}, {1, 1, 700, 100}},
{{1, 1, 100, 2500}, {1, 1, 2500, 100}},
};

static inline std::vector<std::vector<element::Type>> quantized_precisions() {
Expand Down

0 comments on commit 11a2a01

Please sign in to comment.