Skip to content

Commit

Permalink
Review comments 4
Browse files Browse the repository at this point in the history
  • Loading branch information
IvanNovoselov committed Jan 3, 2025
1 parent 3a97f39 commit ca8c239
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 17 deletions.
17 changes: 10 additions & 7 deletions src/common/snippets/src/lowered/pass/insert_reg_spills.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,16 +66,19 @@ bool InsertRegSpills::run(LinearIR& linear_ir) {
const auto end = std::make_shared<op::RegSpillEnd>(begin);
const auto loop_ids = start_it->get()->get_loop_ids();
OPENVINO_ASSERT(loop_ids == std::prev(stop_it)->get()->get_loop_ids(), "Inconsistent loop ids for RegSpill expressions");
const auto spill_begin_expr = *linear_ir.insert_node(begin, std::vector<PortConnectorPtr>{}, loop_ids,
false, start_it, std::vector<std::set<ExpressionPort>>{});
const auto spill_begin_it = linear_ir.insert_node(begin, std::vector<PortConnectorPtr>{}, loop_ids,
false, start_it, std::vector<std::set<ExpressionPort>>{});
std::vector<Reg> vregs{regs_to_spill.begin(), regs_to_spill.end()};
spill_begin_expr->set_reg_info({{}, vregs});
spill_begin_expr->set_live_regs(std::prev(start_it, 2)->get()->get_live_regs());
spill_begin_it->get()->set_reg_info({{}, vregs});
// Note: spill_begin and spill_end do not use any registers, so:
// - the regs that are live on entry of spill_begin are the same as for its predecessor (since no regs consumed)
// - similarly, live regs for spill_end are the same as for its successor (since no regs produced)
spill_begin_it->get()->set_live_regs(std::prev(spill_begin_it)->get()->get_live_regs());

const auto spill_end_expr = *linear_ir.insert_node(end, spill_begin_expr->get_output_port_connectors(), loop_ids,
const auto spill_end_it = linear_ir.insert_node(end, spill_begin_it->get()->get_output_port_connectors(), loop_ids,
false, stop_it, std::vector<std::set<ExpressionPort>>{});
spill_end_expr->set_reg_info({vregs, {}});
spill_begin_expr->set_live_regs(stop_it->get()->get_live_regs());
spill_end_it->get()->set_reg_info({vregs, {}});
spill_end_it->get()->set_live_regs(std::next(spill_end_it)->get()->get_live_regs());
modified = true;
}
return modified;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,8 @@ void jit_kernel_emitter::emit_impl(const std::vector<size_t>& in, const std::vec
auto expected_in_type = snippets::RegType::undefined;
auto expected_out_type = snippets::RegType::undefined;
const auto& node = expression->get_node();
// Note: currently only a few operations are allowed to have mixed in/out register types => skip validation here
// Note: A few operations are allowed to have mixed register types on their inputs (or outputs) => skip
// validation here
if (!ov::is_type<snippets::op::LoopEnd>(node) && !ov::is_type<snippets::op::RegSpillBase>(node) &&
!std::dynamic_pointer_cast<jit_nop_emitter>(emitter))
std::tie(expected_in_type, expected_out_type) = get_expected_reg_types(emitter);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,19 @@ namespace intel_cpu {
jit_reg_spill_begin_emitter::jit_reg_spill_begin_emitter(dnnl::impl::cpu::x64::jit_generator* h,
dnnl::impl::cpu::x64::cpu_isa_t isa,
const ov::snippets::lowered::ExpressionPtr& expr)
: jit_emitter(h, isa),
m_reg_spill_begin_expr(expr) {
const auto& reg_spill_node = ov::as_type_ptr<snippets::op::RegSpillBegin>(m_reg_spill_begin_expr->get_node());
: jit_emitter(h, isa) {
const auto& reg_spill_node = ov::as_type_ptr<snippets::op::RegSpillBegin>(expr->get_node());
OV_CPU_JIT_EMITTER_ASSERT(reg_spill_node, "expects RegSpillBegin expression");
m_num_spilled = reg_spill_node->get_regs_to_spill().size();
m_regs_to_spill = std::set<snippets::Reg>(expr->get_reg_info().second.begin(), expr->get_reg_info().second.end());
m_abi_reg_spiller = std::make_shared<EmitABIRegSpills>(h);
in_out_type_ = emitter_in_out_map::gpr_to_gpr;
}

void jit_reg_spill_begin_emitter::validate_arguments(const std::vector<size_t>& in,
const std::vector<size_t>& out) const {
OV_CPU_JIT_EMITTER_ASSERT(in.empty(), "In regs should be empty for reg_spill_begin emitter");
OV_CPU_JIT_EMITTER_ASSERT(out.size() == m_num_spilled, "Invalid number of out regs for reg_spill_begin emitter");
OV_CPU_JIT_EMITTER_ASSERT(out.size() == m_regs_to_spill.size(),
"Invalid number of out regs for reg_spill_begin emitter");
}

void jit_reg_spill_begin_emitter::emit_code(const std::vector<size_t>& in,
Expand All @@ -42,8 +42,7 @@ void jit_reg_spill_begin_emitter::emit_code(const std::vector<size_t>& in,
}

void jit_reg_spill_begin_emitter::emit_impl(const std::vector<size_t>& in, const std::vector<size_t>& out) const {
const auto& reg_info = m_reg_spill_begin_expr->get_reg_info();
m_abi_reg_spiller->preamble({reg_info.second.begin(), reg_info.second.end()});
m_abi_reg_spiller->preamble(m_regs_to_spill);
}

/* ============================================================== */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,8 @@ class jit_reg_spill_begin_emitter : public jit_emitter {
protected:
void validate_arguments(const std::vector<size_t>& in, const std::vector<size_t>& out) const override;
void emit_impl(const std::vector<size_t>& in, const std::vector<size_t>& out) const override;
const ov::snippets::lowered::ExpressionPtr m_reg_spill_begin_expr;
std::set<snippets::Reg> m_regs_to_spill;
std::shared_ptr<EmitABIRegSpills> m_abi_reg_spiller;
size_t m_num_spilled = SIZE_MAX;
};

/* ============================================================== */
Expand Down

0 comments on commit ca8c239

Please sign in to comment.