Skip to content

Commit

Permalink
Fixed MatMul on Win
Browse files Browse the repository at this point in the history
  • Loading branch information
a-sidorova committed Jan 11, 2023
1 parent 85e3a20 commit 2166a14
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 35 deletions.
2 changes: 1 addition & 1 deletion src/common/snippets/src/op/buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ using namespace std;
using namespace ngraph;

/**
 * Resolves an allocation rank that may be given relative to the end of the shape.
 *
 * @param allocation_rank requested rank; a negative value is offset by shape_rank
 *                        (e.g. -1 with shape_rank == 3 yields 2)
 * @param shape_rank      rank of the shape the allocation rank refers to
 * @return allocation_rank unchanged when it is non-negative, otherwise allocation_rank + shape_rank
 */
auto normalize_rank(int32_t allocation_rank, const size_t shape_rank) -> int32_t {
    // Cast explicitly so the sum is computed in int32_t. Without the cast the expression is
    // promoted to size_t and then implicitly narrowed back on return (signed/unsigned mix).
    return allocation_rank < 0 ? allocation_rank + static_cast<int32_t>(shape_rank) : allocation_rank;
}

snippets::op::Buffer::Buffer(const Output<Node>& x, const int32_t allocation_rank) : Op({x}), m_allocation_rank(allocation_rank) {
Expand Down
61 changes: 28 additions & 33 deletions src/plugins/intel_cpu/src/emitters/jit_snippets_emitters.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -918,51 +918,27 @@ void BrgemmEmitter::emit_brgemm_kernel_call(const brgemm_kernel_t *brgKernel, in
h->uni_vmovups(h->ptr[h->rsp + i * get_vec_length()], Vmm(i));

// save function address in gpr to pass in call instruction
const auto& brgemm_kernel_overload = static_cast<void (*)(const brgemm_kernel_t*,
int,
const void*,
const void*,
const brgemm_batch_element_t*,
void*,
void*)>(brgemm_kernel_execute);
const auto& brgemm_kernel_overload = static_cast<void (*)(const brgemm_kernel_t*,
const void*,
const void*,
void*)>(kernel_execute);
h->mov(h->rbp, reinterpret_cast<uintptr_t>(brgemm_kernel_overload));
// todo: several of addr_{A, B, C} could be also abi_paramX, so one of them could be corrupted
// if moving directly h->uni_vmovq(abi_paramX, adr_X). Save them to vector regs to avoid corruption.
// It's likely that a more efficient solution exists.
h->uni_vmovq(Xmm(0), addr_A);
h->uni_vmovq(Xmm(1), addr_B);
h->uni_vmovq(Xmm(2), addr_C);
// todo: Windows ABI : requires different num of arguments passed in regs and on the stack. Need to align.
h->mov(abi_param1, reinterpret_cast<uintptr_t>(brgKernel));
h->mov(abi_param2, bs);

const auto data_ptr_reg = [&](Xmm xmm, Xbyak::Reg64 reg, size_t bytes_offset) {
h->uni_vmovq(reg, xmm);
if (bytes_offset) h->add(reg, bytes_offset);
};
data_ptr_reg(Xmm(0), abi_param3, in0_kernel_offset);
data_ptr_reg(Xmm(1), abi_param4, in1_kernel_offset);

size_t num_args_passed_on_stack = 0;
#ifdef _WIN32
const auto data_ptr_stack = [&](Xmm xmm, size_t idx, size_t bytes_offset) {
h->uni_vmovq(h->qword[h->rsp + idx * gpr_size], xmm);
if (bytes_offset) h->add(h->qword[h->rsp + idx * gpr_size], bytes_offset);
};
h->mov(abi_param1, reinterpret_cast<uintptr_t>(brgKernel));
data_ptr_reg(Xmm(0), abi_param2, in0_kernel_offset);
data_ptr_reg(Xmm(1), abi_param3, in1_kernel_offset);
data_ptr_reg(Xmm(1), abi_param4, out0_kernel_offset);

num_args_passed_on_stack = 3;
h->sub(h->rsp, num_args_passed_on_stack * gpr_size);
h->mov(h->qword[h->rsp + 0 * gpr_size], reinterpret_cast<uintptr_t>(batch));
data_ptr_stack(Xmm(2), 1, out0_kernel_offset);
h->mov(h->qword[h->rsp + 2 * gpr_size], reinterpret_cast<uintptr_t>(scratch));
#else
h->mov(abi_param5, reinterpret_cast<uintptr_t>(batch));
data_ptr_reg(Xmm(2), abi_param6, out0_kernel_offset);

num_args_passed_on_stack = 1;
h->sub(h->rsp, num_args_passed_on_stack * gpr_size);
h->mov(h->qword[h->rsp], reinterpret_cast<uintptr_t>(scratch));
#endif
// align stack on 16-byte as ABI requires
// note that RBX must not be changed by the callee
h->mov(h->rbx, h->rsp);
Expand All @@ -972,7 +948,6 @@ void BrgemmEmitter::emit_brgemm_kernel_call(const brgemm_kernel_t *brgKernel, in
h->call(h->rbp);

h->add(h->rsp, h->rbx);
h->add(h->rsp, gpr_size * num_args_passed_on_stack);
// restore vector registers
for (int i = static_cast<int>(get_max_vecs_count()) - 1; i >= 0; --i) {
h->uni_vmovups(Vmm(i), h->ptr[h->rsp + i * get_vec_length()]);
Expand All @@ -996,6 +971,26 @@ void BrgemmEmitter::emit_brgemm_kernel_call(const brgemm_kernel_t *brgKernel, in
h->add(h->rsp, n_gprs_to_save * gpr_size);
}

void BrgemmEmitter::kernel_execute(const brgemm_kernel_t *brg_kernel, const void *A, const void *B, void *C) {
// TODO: There are 4 available abi_params on Windows so we have the copy of brgemm_kernel_execute() function
// with 4 runtime parameters (kernel and I/O) and 4 default parameter values (batch, bs and scratch)
brgemm_kernel_params_t brgemm_p;

brgemm_p.batch = nullptr; // default value
brgemm_p.ptr_A = A;
brgemm_p.ptr_B = B;
brgemm_p.ptr_C = C;
brgemm_p.ptr_D = C;
brgemm_p.ptr_buf = nullptr; // default value
brgemm_p.ptr_bias = nullptr;
brgemm_p.do_post_ops = 0;
brgemm_p.do_apply_comp = 0;
brgemm_p.skip_accm = 0;
brgemm_p.BS = 1; // default value
assert(brg_kernel);
(*brg_kernel)(&brgemm_p);
}

template <dnnl::impl::cpu::x64::cpu_isa_t isa>
void BrgemmEmitter::emit_isa(const std::vector<size_t> &in, const std::vector<size_t> &out) const {
using Vmm = typename dnnl::impl::utils::conditional3<isa == cpu::x64::sse41, Xmm, isa == cpu::x64::avx2, Ymm, Zmm>::type;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,7 @@ class BrgemmEmitter : public jit_emitter {
Reg64 addr_A, Reg64 addr_B,
const brgemm_batch_element_t *batch, Reg64 addr_C, void *scratch,
const size_t in0_kernel_offset, const size_t in1_kernel_offset, const size_t out0_kernel_offset) const;

static void kernel_execute(const brgemm_kernel_t *brg_kernel, const void *A, const void *B, void *C);
static constexpr size_t BRGEMM_KERNELS_NUM = 8;
static constexpr size_t matmulOptimalM = 32;
brgemmCtx brgCtxs0[BRGEMM_KERNELS_NUM];
Expand Down

0 comments on commit 2166a14

Please sign in to comment.