[Snippets] Fixed aux GPR usage in Loop emitters (#26152)
### Details:
- *If Loop emitters need an aux GPR and the `aux_gpr_pool` is empty, we now take any GPR unused by the emitter, push it on the stack before use, and explicitly pop it from the stack afterwards. A sketch of the idea is shown below.*
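
A minimal sketch of this fallback, with hypothetical helper names (`pick_gpr_not_in`, `emit_body_with`); the real implementation is the RAII `jit_aux_gpr_holder` class in the diff below:

```cpp
// Sketch only: borrow a GPR when the pool is empty, preserving its value on the stack.
if (aux_gpr_pool.empty()) {
    Xbyak::Reg64 aux = pick_gpr_not_in(used_gpr_idxs);  // any GPR this emitter does not touch
    h->push(aux);         // preserve the caller's value
    emit_body_with(aux);  // use it as a scratch register
    h->pop(aux);          // restore the original value
} else {
    Xbyak::Reg64 aux(static_cast<int>(aux_gpr_pool.back()));
    aux_gpr_pool.pop_back();
    emit_body_with(aux);
    aux_gpr_pool.push_back(static_cast<size_t>(aux.getIdx()));  // return it to the pool
}
```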

### Tickets:
 - *N/A*

### Prerequisites:
- [x] #25500
a-sidorova authored Aug 22, 2024
1 parent b850b8e commit ce92c13
Showing 8 changed files with 303 additions and 105 deletions.
4 changes: 2 additions & 2 deletions src/common/snippets/src/pass/mha_tokenization.cpp
@@ -387,8 +387,8 @@ ov::snippets::pass::TokenizeMHASnippets::TokenizeMHASnippets(const SnippetsToken
parent = parent->get_input_node_shared_ptr(0);
has_matmul0_has_ops_on_input = true;
}
// If there are ops on the second input of MatMul0 -> there will always be a unique Buffer
if (has_matmul0_has_ops_on_input) {
// If there are ops on the second input of MatMul0 and only one unique Buffer between the MatMuls -> there must be one more unique Buffer
if (has_matmul0_has_ops_on_input && unique_buffer_reg_group_count < 2) {
unique_buffer_reg_group_count++;
}

102 changes: 71 additions & 31 deletions src/plugins/intel_cpu/src/emitters/snippets/x64/jit_loop_emitters.cpp
@@ -14,6 +14,52 @@ using namespace dnnl::impl::cpu::x64;
namespace ov {
namespace intel_cpu {

namespace {
class jit_aux_gpr_holder {
public:
jit_aux_gpr_holder(dnnl::impl::cpu::x64::jit_generator* host, std::vector<size_t>& pool_gpr_idxs, const std::vector<size_t>& used_gpr_idxs)
: m_h(host), m_pool_gpr_idxs(pool_gpr_idxs) {
// If the pool is empty, manually allocate a GPR and push its original value on the stack
if (m_pool_gpr_idxs.empty()) {
m_aux_gpr_idx = Reg64(static_cast<int>(allocate_aux_gpr(used_gpr_idxs)));
m_is_preserved = true;
m_h->push(m_aux_gpr_idx);
} else {
m_aux_gpr_idx = Reg64(static_cast<int>(m_pool_gpr_idxs.back()));
m_pool_gpr_idxs.pop_back();
}
}

~jit_aux_gpr_holder() {
if (m_is_preserved) {
m_h->pop(m_aux_gpr_idx);
} else {
m_pool_gpr_idxs.push_back(m_aux_gpr_idx.getIdx());
}
}

const Reg64& get_reg() const { return m_aux_gpr_idx; }

private:
size_t allocate_aux_gpr(const std::vector<size_t>& used_gpr_idxs) const {
// RSP and RBP are stack-related registers; abi_param1 is the runtime-parameter register in the kernel
static std::set<size_t> blacklist_gpr_idxs = { Operand::RSP, Operand::RBP, static_cast<size_t>(abi_param1.getIdx()) };
for (size_t gpr_idx = 0; gpr_idx <= Operand::R15; ++gpr_idx) {
size_t _idx = Operand::R15 - gpr_idx; // we allocate from the end
if (std::find(used_gpr_idxs.cbegin(), used_gpr_idxs.cend(), _idx) != used_gpr_idxs.cend()) continue;
if (std::find(blacklist_gpr_idxs.cbegin(), blacklist_gpr_idxs.cend(), _idx) != blacklist_gpr_idxs.cend()) continue;
return _idx;
}
OV_CPU_JIT_EMITTER_THROW("Failed to allocate aux GPR");
}

dnnl::impl::cpu::x64::jit_generator* m_h;
std::vector<size_t>& m_pool_gpr_idxs;
Reg64 m_aux_gpr_idx {};
bool m_is_preserved = false;
};
} // namespace
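
Typical usage, as a hedged sketch (`some_emitter` is a hypothetical class): the holder's scope brackets exactly the instructions that need the scratch register, so the `push`/`pop` pair emitted when the pool is empty encloses only those instructions.

```cpp
void some_emitter::emit_impl(const std::vector<size_t>& in, const std::vector<size_t>& out) const {
    // Pass the registers this emitter already uses so the holder never picks one of them
    jit_aux_gpr_holder gpr_holder(h, aux_gpr_idxs, in);
    Reg64 reg_tmp = gpr_holder.get_reg();
    h->mov(reg_tmp, h->ptr[abi_param1 + GET_OFF(loop_args)]);
    // ... address the loop arguments through reg_tmp ...
}   // holder destructor: pop the preserved value, or return the register to the pool
```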

/* ================== jit_loop_begin_emitter ====================== */

jit_loop_begin_emitter::jit_loop_begin_emitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl::cpu::x64::cpu_isa_t isa,
@@ -30,12 +76,6 @@ jit_loop_begin_emitter::jit_loop_begin_emitter(dnnl::impl::cpu::x64::jit_generat
in_out_type_ = emitter_in_out_map::gpr_to_gpr;
}

size_t jit_loop_begin_emitter::aux_gprs_count() const {
// We should have aux GPR to store Loop arguments from `runtime_args`
// where we will take all needed information about the current loop: work amount
return is_work_amount_dynamic ? 1 : 0;
}

void jit_loop_begin_emitter::validate_arguments(const std::vector<size_t> &in, const std::vector<size_t> &out) const {
OV_CPU_JIT_EMITTER_ASSERT(in.empty(), "Invalid inputs size: expected 0 got " + std::to_string(in.size()));
// Note: the only expected output is work amount register (communicated to jit_loop_end_emitter)
@@ -59,10 +99,10 @@ void jit_loop_begin_emitter::emit_impl(const std::vector<size_t>& in, const std:

Reg64 reg_work_amount = Reg64(static_cast<int>(out.back()));
if (is_work_amount_dynamic) {
Reg64 reg_runtime_params = abi_param1; // defined by jit_kernel_emitter
Reg64 reg_loop_args_ptr = Reg64(static_cast<int>(aux_gpr_idxs[0]));
jit_aux_gpr_holder gpr_holder(h, aux_gpr_idxs, out); // loop_begin has only output registers
Reg64 reg_loop_args_ptr = gpr_holder.get_reg();
const auto id_offset = loop_id * sizeof(jit_snippets_call_args::loop_args_t);
h->mov(reg_loop_args_ptr, h->ptr[reg_runtime_params + GET_OFF(loop_args)]);
h->mov(reg_loop_args_ptr, h->ptr[abi_param1 + GET_OFF(loop_args)]);
h->mov(reg_work_amount, h->ptr[reg_loop_args_ptr + id_offset + GET_OFF_LOOP_ARGS(m_work_amount)]);
} else {
h->mov(reg_work_amount, work_amount);
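
A C-level reading of the two `mov`s in the dynamic branch above, assuming (not shown in this diff) that `GET_OFF`/`GET_OFF_LOOP_ARGS` are `offsetof()` wrappers and that `jit_snippets_call_args::loop_args` points to an array of `loop_args_t`:

```cpp
// Sketch only: equivalent pointer arithmetic in plain C++.
auto read_work_amount(const jit_snippets_call_args* runtime_args, size_t loop_id) {
    const auto* loop_args = runtime_args->loop_args;  // mov reg_loop_args_ptr, [abi_param1 + GET_OFF(loop_args)]
    // id_offset == loop_id * sizeof(jit_snippets_call_args::loop_args_t)
    return loop_args[loop_id].m_work_amount;          // mov reg_work_amount, [reg_loop_args_ptr + id_offset + ...]
}
```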
@@ -141,37 +181,37 @@ void jit_loop_end_emitter::emit_code(const std::vector<size_t> &in, const std::v
jit_emitter::emit_code(in, out, pool_vec_idxs, pool_gpr_idxs);
}

size_t jit_loop_end_emitter::aux_gprs_count() const {
// We should have aux GPR to store Loop arguments from `runtime_args`
// where we will take all needed information about the current loop: data pointer shifts
return are_ptr_shifts_dynamic ? 1 : 0;
}

void jit_loop_end_emitter::emit_impl(const std::vector<size_t>& in, const std::vector<size_t>& out) const {
std::vector<size_t> data_ptr_reg_idxs;
// the last input is actually a work_amount reg
data_ptr_reg_idxs.reserve(num_inputs + num_outputs);
std::copy(in.begin(), in.end() - 1, std::back_inserter(data_ptr_reg_idxs));

const auto id_offset = loop_id * sizeof(jit_snippets_call_args::loop_args_t);
Reg64 reg_increments = are_ptr_shifts_dynamic ? Reg64(static_cast<int>(aux_gpr_idxs[0])) : Reg64();

auto apply_increments = [&](bool use_runtime_args, size_t field_offset, const std::vector<int64_t>& increments, size_t scale) {
if (use_runtime_args) {
Reg64 reg_runtime_params = abi_param1; /* defined by jit_kernel_emitter */
h->mov(reg_increments, h->ptr[reg_runtime_params + GET_OFF(loop_args)]);
h->mov(reg_increments, h->ptr[reg_increments + id_offset + field_offset]);
}
for (size_t idx = 0; idx < data_ptr_reg_idxs.size(); idx++) {
const auto& increment = increments[idx];
if (is_incremented[idx] && increment != 0) {
if (ov::snippets::utils::is_dynamic_value(increment)) {
OV_CPU_JIT_EMITTER_ASSERT(use_runtime_args, "Loop argument structure cannot be pushed to aux GPR");
h->add(Reg64(static_cast<int>(data_ptr_reg_idxs[idx])), h->ptr[reg_increments + idx * sizeof(int64_t)]);
} else {
h->add(Reg64(static_cast<int>(data_ptr_reg_idxs[idx])), increment * scale * data_sizes[idx]);
Reg64 reg_increments;
auto add_increments = [&]() {
for (size_t idx = 0; idx < data_ptr_reg_idxs.size(); idx++) {
const auto& increment = increments[idx];
if (is_incremented[idx] && increment != 0) {
if (ov::snippets::utils::is_dynamic_value(increment)) {
OV_CPU_JIT_EMITTER_ASSERT(use_runtime_args, "Loop argument structure cannot be pushed to aux GPR");
h->add(Reg64(static_cast<int>(data_ptr_reg_idxs[idx])), h->ptr[reg_increments + idx * sizeof(int64_t)]);
} else {
h->add(Reg64(static_cast<int>(data_ptr_reg_idxs[idx])), increment * scale * data_sizes[idx]);
}
}
}
};

const auto id_offset = loop_id * sizeof(jit_snippets_call_args::loop_args_t);
if (use_runtime_args) {
jit_aux_gpr_holder gpr_holder(h, aux_gpr_idxs, in); // loop_end has only input registers
reg_increments = gpr_holder.get_reg();
h->mov(reg_increments, h->ptr[abi_param1 + GET_OFF(loop_args)]);
h->mov(reg_increments, h->ptr[reg_increments + id_offset + field_offset]);
add_increments();
} else {
add_increments();
}
};
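
Since the holder is constructed inside the `use_runtime_args` branch, its destructor runs right after `add_increments()`, so the `pop` (when the register had to be preserved) lands immediately after the increments. With an empty pool the emitted stream looks roughly like this (hypothetical register choice; a sketch, not the exact output):

```cpp
// push r15                                    ; holder ctor preserves the borrowed GPR
// mov  r15, [abi_param1 + GET_OFF(loop_args)]
// mov  r15, [r15 + id_offset + field_offset]
// add  reg_data0, [r15 + 0 * sizeof(int64_t)] ; dynamic increment, read at runtime
// add  reg_data1, increment * scale * size    ; static increment, folded at compile time
// pop  r15                                    ; holder dtor restores the original value
```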

@@ -31,7 +31,8 @@ class jit_loop_begin_emitter: public jit_emitter {
void validate_arguments(const std::vector<size_t> &in, const std::vector<size_t> &out) const override;
void emit_impl(const std::vector<size_t>& in, const std::vector<size_t>& out) const override;

size_t aux_gprs_count() const override;
// `jit_loop_begin_emitter` manually handles aux GPR allocation using `jit_aux_gpr_holder`
size_t aux_gprs_count() const override { return 0; }

std::shared_ptr<Xbyak::Label> loop_begin_label = nullptr;
std::shared_ptr<const Xbyak::Label> loop_end_label = nullptr;
@@ -61,7 +62,8 @@ class jit_loop_end_emitter: public jit_emitter {
void validate_arguments(const std::vector<size_t> &in, const std::vector<size_t> &out) const override;
void emit_impl(const std::vector<size_t>& in, const std::vector<size_t>& out) const override;

size_t aux_gprs_count() const override;
// `jit_loop_end_emitter` manually handles aux GPR allocation using `jit_aux_gpr_holder`
size_t aux_gprs_count() const override { return 0; }

static ov::snippets::lowered::ExpressionPtr get_loop_begin_expr(const ov::snippets::lowered::ExpressionPtr& expr);

@@ -527,6 +527,35 @@ INSTANTIATE_TEST_SUITE_P(
::testing::Values(CPUTestUtils::empty_plugin_config)),
MHA::getTestCaseName);

std::vector<std::vector<ov::test::InputShape>> inputShapes_4D_WithMul_dynamic{
{
{PartialShape{-1, -1, -1, -1}, {{1, 128, 3, 64}, {1, 70, 3, 19}, {1, 128, 3, 64}, {1, 68, 6, 87}}},
{PartialShape{-1, -1, -1, -1}, {{1, 128, 1, 64}, {2, 49, 1, 19}, {1, 128, 1, 64}, {2, 13, 6, 87}}},
{PartialShape{1}, {{1}, {1}, {1}, {1} }},
{PartialShape{-1, -1, -1, -1}, {{2, 1, 128, 128}, {1, 1, 70, 49}, {2, 1, 128, 128}, {1, 1, 68, 13}}},
{PartialShape{-1, -1, -1, -1}, {{1, 128, 3, 64}, {1, 49, 3, 19}, {1, 128, 3, 64}, {2, 13, 6, 87}}},
},
{
{PartialShape{-1, -1, 12, 64}, {{1, 70, 12, 64}, {1, 20, 12, 64}, {1, 20, 12, 64}, {1, 20, 12, 64}, {1, 70, 12, 64}}},
{PartialShape{-1, -1, 12, 64}, {{1, 35, 12, 64}, {2, 10, 12, 64}, {2, 1, 12, 64}, {2, 10, 12, 64}, {1, 35, 12, 64}}},
{PartialShape{-1, 12, 64, -1}, {{1, 12, 64, 35}, {1, 12, 64, 10}, {1, 12, 64, 10}, {1, 12, 64, 1}, {1, 12, 64, 35}}},
{PartialShape{-1, 12, -1, -1}, {{2, 12, 70, 35}, {1, 12, 20, 10}, {1, 12, 20, 10}, {1, 12, 20, 1}, {2, 12, 70, 35}}},
{PartialShape{-1, -1, 12, 64}, {{1, 35, 12, 64}, {1, 10, 12, 64}, {1, 10, 12, 64}, {1, 10, 12, 64}, {1, 35, 12, 64}}},
}
};

INSTANTIATE_TEST_SUITE_P(smoke_Snippets_DynMHA_4D_WithMul,
MHAWithDynamicMul,
::testing::Combine(::testing::ValuesIn(inputShapes_4D_WithMul_dynamic),
::testing::ValuesIn(precision_f32(5)),
::testing::Values(ov::element::f32),
::testing::Values(MHA::default_thread_count),
::testing::Values(1),
::testing::Values(1),
::testing::Values(ov::test::utils::DEVICE_CPU),
::testing::Values(CPUTestUtils::empty_plugin_config)),
MHAWithDynamicMul::getTestCaseName);

} // namespace
} // namespace snippets
} // namespace test
65 changes: 47 additions & 18 deletions src/tests/functional/plugin/shared/include/snippets/mha.hpp
@@ -11,7 +11,7 @@ namespace ov {
namespace test {
namespace snippets {

typedef std::tuple<std::vector<InputShape>, // Input shapes
std::vector<ov::element::Type>, // Input Element types
ov::element::Type, // Inference precision
bool, // With Multiply
@@ -23,72 +23,101 @@ typedef std::tuple<std::vector<InputShape>, // Input shapes
>
MHAParams;

class MHA : public testing::WithParamInterface<ov::test::snippets::MHAParams>,
virtual public ov::test::SnippetsTestsCommon {
public:
static std::string getTestCaseName(testing::TestParamInfo<ov::test::snippets::MHAParams> obj);
typedef std::tuple<std::vector<InputShape>, // Input shapes
std::vector<ov::element::Type>, // Input Element types
ov::element::Type, // Inference precision
size_t, // Thread count
size_t, // Expected num nodes
size_t, // Expected num subgraphs
std::string, // Target Device
ov::AnyMap // Config
>
MHAWithDynamicMulParams;

class MHABase : virtual public ov::test::SnippetsTestsCommon {
public:
constexpr static size_t default_thread_count = 0;

protected:
void SetUp() override;

void compile_model() override;
void generate_inputs(const std::vector<ov::Shape>& targetInputStaticShapes) override;
virtual std::shared_ptr<SnippetsFunctionBase> get_subgraph();
virtual std::shared_ptr<SnippetsFunctionBase> get_subgraph() const = 0;
virtual void init_params(std::vector<InputShape>& input_shapes, ov::element::Type& prc, ov::AnyMap& additional_config) = 0;

bool m_with_mul = false;
size_t m_thread_count;
std::vector<ov::element::Type> m_input_types;
};

class MHA : public testing::WithParamInterface<ov::test::snippets::MHAParams>,
virtual public MHABase {
public:
static std::string getTestCaseName(testing::TestParamInfo<ov::test::snippets::MHAParams> obj);

protected:
std::shared_ptr<SnippetsFunctionBase> get_subgraph() const override;
void init_params(std::vector<InputShape>& input_shapes, ov::element::Type& prc, ov::AnyMap& additional_config) override;

bool m_with_mul = false;
};

class MHASelect : public MHA {
protected:
void generate_inputs(const std::vector<ov::Shape>& targetInputStaticShapes) override;
std::shared_ptr<SnippetsFunctionBase> get_subgraph() override;
std::shared_ptr<SnippetsFunctionBase> get_subgraph() const override;
};

class MHAWOTransposeOnInputs : public MHA {
protected:
std::shared_ptr<SnippetsFunctionBase> get_subgraph() override;
std::shared_ptr<SnippetsFunctionBase> get_subgraph() const override;
};

class MHAWOTranspose : public MHA {
protected:
std::shared_ptr<SnippetsFunctionBase> get_subgraph() override;
std::shared_ptr<SnippetsFunctionBase> get_subgraph() const override;
};

class MHAMulAdd : public MHA {
std::shared_ptr<SnippetsFunctionBase> get_subgraph() override;
std::shared_ptr<SnippetsFunctionBase> get_subgraph() const override;
};

class MHATransposedB : public MHA {
std::shared_ptr<SnippetsFunctionBase> get_subgraph() override;
std::shared_ptr<SnippetsFunctionBase> get_subgraph() const override;
};

class MHAINT8MatMul : public MHA {
protected:
std::shared_ptr<SnippetsFunctionBase> get_subgraph() override;
std::shared_ptr<SnippetsFunctionBase> get_subgraph() const override;
};

class MHAQuantMatMul0 : public MHA {
protected:
std::shared_ptr<SnippetsFunctionBase> get_subgraph() override;
std::shared_ptr<SnippetsFunctionBase> get_subgraph() const override;
};

class MHAFQAfterMatMul : public MHA {
protected:
std::shared_ptr<SnippetsFunctionBase> get_subgraph() override;
std::shared_ptr<SnippetsFunctionBase> get_subgraph() const override;
};

class MHAFQ : public MHA {
protected:
std::shared_ptr<SnippetsFunctionBase> get_subgraph() override;
std::shared_ptr<SnippetsFunctionBase> get_subgraph() const override;
};

class MHAWithExtractedReshape : public MHA {
protected:
std::shared_ptr<SnippetsFunctionBase> get_subgraph() override;
std::shared_ptr<SnippetsFunctionBase> get_subgraph() const override;
};

class MHAWithDynamicMul : public testing::WithParamInterface<ov::test::snippets::MHAWithDynamicMulParams>,
virtual public MHABase {
public:
static std::string getTestCaseName(testing::TestParamInfo<ov::test::snippets::MHAWithDynamicMulParams> obj);

protected:
std::shared_ptr<SnippetsFunctionBase> get_subgraph() const override;
void init_params(std::vector<InputShape>& input_shapes, ov::element::Type& prc, ov::AnyMap& additional_config) override;
};

} // namespace snippets