Skip to content

Commit

Permalink
[Snippets][CPU] Added dynamic MatMul INT8/BF16 support (openvinotoolk…
Browse files Browse the repository at this point in the history
…it#26493)

### Details:
 - *Implemented BrgemmCopyBExecutor*
 - *Added tests with dynamic input shapes for MatMul INT8/BF16*

### Tickets:
 - *151922*

### Prerequisites:
- [x] openvinotoolkit#26413
  • Loading branch information
a-sidorova authored Oct 1, 2024
1 parent d2a9873 commit 9284684
Show file tree
Hide file tree
Showing 32 changed files with 1,350 additions and 872 deletions.
18 changes: 15 additions & 3 deletions src/common/snippets/include/snippets/utils/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,21 @@ constexpr inline bool implication(bool cause, bool cond) {
}

template <typename T, typename U>
inline T div_up(const T a, const U b) {
OPENVINO_ASSERT(b != 0, "Divider must not be zero");
return static_cast<T>((a + b - 1) / b);
static inline auto div_up(const T lhs, const U rhs) -> decltype((lhs + rhs - 1) / rhs) {
OPENVINO_ASSERT(rhs != 0, "Divider must not be zero");
if (((std::is_same<T, size_t>::value || std::is_same<T, int64_t>::value) && utils::is_dynamic_value(lhs)) ||
((std::is_same<U, size_t>::value || std::is_same<U, int64_t>::value) && utils::is_dynamic_value(rhs)))
return utils::get_dynamic_value<T>();
return (lhs + rhs - 1) / rhs;
}

template<typename T, typename U>
static inline auto rnd_up(const T lhs, const U rhs) -> decltype(div_up(lhs, rhs) * rhs) {
const T div_up_res = div_up(lhs, rhs);
if (((std::is_same<T, size_t>::value || std::is_same<T, int64_t>::value) && utils::is_dynamic_value(div_up_res)) ||
((std::is_same<U, size_t>::value || std::is_same<U, int64_t>::value) && utils::is_dynamic_value(rhs)))
return utils::get_dynamic_value<T>();
return div_up_res * rhs;
}

inline bool is_dynamic_vdims(const VectorDims& shape) {
Expand Down
220 changes: 220 additions & 0 deletions src/plugins/intel_cpu/src/emitters/plugin/x64/debug_capabilities.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#ifdef CPU_DEBUG_CAPS

#include "debug_capabilities.hpp"
#include <iostream>
#include <sstream>

namespace ov {
namespace intel_cpu {

using namespace Xbyak;
using namespace dnnl::impl::cpu::x64;

template void RegPrinter::print<float, Xmm>(jit_generator &h, Xmm reg, const char *name);
template void RegPrinter::print<int, Xmm>(jit_generator &h, Xmm reg, const char *name);
template void RegPrinter::print<float, Ymm>(jit_generator &h, Ymm reg, const char *name);
template void RegPrinter::print<int, Ymm>(jit_generator &h, Ymm reg, const char *name);
template void RegPrinter::print<float, Zmm>(jit_generator &h, Zmm reg, const char *name);
template void RegPrinter::print<int, Zmm>(jit_generator &h, Zmm reg, const char *name);
template void RegPrinter::print<float, Reg64>(jit_generator &h, Reg64 reg, const char *name);
template void RegPrinter::print<int, Reg64>(jit_generator &h, Reg64 reg, const char *name);
template void RegPrinter::print<float, Reg32>(jit_generator &h, Reg32 reg, const char *name);
template void RegPrinter::print<int, Reg32>(jit_generator &h, Reg32 reg, const char *name);
template void RegPrinter::print<char, Reg16>(jit_generator &h, Reg16 reg, const char *name);
template void RegPrinter::print<unsigned char, Reg16>(jit_generator &h, Reg16 reg, const char *name);
template void RegPrinter::print<char, Reg8>(jit_generator &h, Reg8 reg, const char *name);
template void RegPrinter::print<unsigned char, Reg8>(jit_generator &h, Reg8 reg, const char *name);

template <typename T>
void RegPrinter::print_reg_prc(const char *name, const char *ori_name, T *ptr) {
std::stringstream ss;
if (name) ss << name << " | ";
ss << ori_name << ": ";
if (std::is_floating_point<T>::value) {
ss << *ptr;
} else {
if (std::is_signed<T>::value) {
ss << static_cast<int64_t>(*ptr);
} else {
ss << static_cast<uint64_t>(*ptr);
}
}
ss << std::endl;
std::cout << ss.str();
}

template <typename PRC_T, size_t vlen>
void RegPrinter::print_vmm_prc(const char *name, const char *ori_name, PRC_T *ptr) {
std::stringstream ss;
if (name) ss << name << " | ";
ss << ori_name << ": {" << ptr[0];
for (size_t i = 1; i < vlen / sizeof(float); i++) {
ss << ", " << ptr[i];
}
ss << "}" << std::endl;
std::cout << ss.str();
}
template void RegPrinter::print_vmm_prc<float, 16>(const char *name, const char *ori_name, float *ptr);
template void RegPrinter::print_vmm_prc<float, 32>(const char *name, const char *ori_name, float *ptr);
template void RegPrinter::print_vmm_prc<float, 64>(const char *name, const char *ori_name, float *ptr);
template void RegPrinter::print_vmm_prc<int, 16>(const char *name, const char *ori_name, int *ptr);
template void RegPrinter::print_vmm_prc<int, 32>(const char *name, const char *ori_name, int *ptr);
template void RegPrinter::print_vmm_prc<int, 64>(const char *name, const char *ori_name, int *ptr);

template <typename Vmm>
struct vmm_traits{};

template <>
struct vmm_traits<Xmm> {
static constexpr size_t vmm_len = 16;
static constexpr size_t vmm_cnt = 16;
};

template <>
struct vmm_traits<Ymm> {
static constexpr size_t vmm_len = 32;
static constexpr size_t vmm_cnt = 16;
};

template <>
struct vmm_traits<Zmm> {
static constexpr size_t vmm_len = 64;
static constexpr size_t vmm_cnt = 32;
};

template <typename T>
void RegPrinter::save_vmm(jit_generator &h) {
h.sub(h.rsp, vmm_traits<T>::vmm_len * vmm_traits<T>::vmm_cnt);
for (size_t i = 0; i < vmm_traits<T>::vmm_cnt; i++) {
h.uni_vmovups(h.ptr[h.rsp + i * vmm_traits<T>::vmm_len], T(i));
}
}

template <typename T>
void RegPrinter::restore_vmm(jit_generator &h) {
for (size_t i = 0; i < vmm_traits<T>::vmm_cnt; i++) {
h.uni_vmovups(T(i), h.ptr[h.rsp + i * vmm_traits<T>::vmm_len]);
}
h.add(h.rsp, vmm_traits<T>::vmm_len * vmm_traits<T>::vmm_cnt);
}

void RegPrinter::save_reg(jit_generator &h) {
h.sub(h.rsp, reg_len * reg_cnt);
for (size_t i = 0; i < reg_cnt; i++) {
h.mov(h.ptr[h.rsp + i * reg_len], Reg64(i));
}
}

void RegPrinter::restore_reg(jit_generator &h) {
for (size_t i = 0; i < reg_cnt; i++) {
h.mov(Reg64(i), h.ptr[h.rsp + i * reg_len]);
}
h.add(h.rsp, reg_len * reg_cnt);
}

void RegPrinter::preamble(jit_generator &h) {
save_reg(h);
mayiuse(cpu_isa_t::avx512_core) ? save_vmm<Zmm>(h) : (mayiuse(cpu_isa_t::avx2) ?
save_vmm<Ymm>(h) : save_vmm<Xmm>(h));
}

void RegPrinter::postamble(jit_generator &h) {
mayiuse(cpu_isa_t::avx512_core) ? restore_vmm<Zmm>(h) : (mayiuse(cpu_isa_t::avx2) ?
restore_vmm<Ymm>(h) : restore_vmm<Xmm>(h));
restore_reg(h);
}

// ABI requires 16-bype stack alignment before a call
void RegPrinter::align_rsp(jit_generator &h) {
constexpr int alignment = 16;
h.mov(h.r15, h.rsp);
h.and_(h.rsp, ~(alignment - 1));
}

void RegPrinter::restore_rsp(jit_generator &h) {
h.mov(h.rsp, h.r15);
}

template <typename PRC_T, typename REG_T>
void RegPrinter::print_vmm(jit_generator &h, REG_T vmm, const char *name) {
preamble(h);

h.push(h.rax);
h.push(abi_param1);
h.push(abi_param2);
h.push(abi_param3);
{
const int vlen = vmm.isZMM() ? 64 : (vmm.isYMM() ? 32 : 16);
h.sub(h.rsp, vlen);
h.uni_vmovups(h.ptr[h.rsp], vmm);

h.mov(abi_param3, h.rsp);
h.mov(abi_param2, reinterpret_cast<size_t>(vmm.toString()));
h.mov(abi_param1, reinterpret_cast<size_t>(name));
if (vmm.isZMM()) {
auto p = &print_vmm_prc<PRC_T, 64>;
h.mov(h.rax, reinterpret_cast<size_t>(p));
} else if (vmm.isYMM()) {
auto p = &print_vmm_prc<PRC_T, 32>;
h.mov(h.rax, reinterpret_cast<size_t>(p));
} else {
auto p = &print_vmm_prc<PRC_T, 16>;
h.mov(h.rax, reinterpret_cast<size_t>(p));
}
align_rsp(h);
h.call(h.rax);
restore_rsp(h);

h.add(h.rsp, vlen);
}

h.pop(abi_param3);
h.pop(abi_param2);
h.pop(abi_param1);
h.pop(h.rax);

postamble(h);
}

template <typename PRC_T, typename REG_T>
void RegPrinter::print_reg(jit_generator &h, REG_T reg, const char *name) {
preamble(h);

h.push(h.rax);
h.push(abi_param1);
h.push(abi_param2);
h.push(abi_param3);
{
const int rlen = reg.getBit() / 8;
h.sub(h.rsp, rlen);
h.mov(h.ptr[h.rsp], reg);

h.mov(abi_param3, h.rsp);
h.mov(abi_param2, reinterpret_cast<size_t>(reg.toString()));
h.mov(abi_param1, reinterpret_cast<size_t>(name));
auto p = &print_reg_prc<PRC_T>;
h.mov(h.rax, reinterpret_cast<size_t>(p));
align_rsp(h);
h.call(h.rax);
restore_rsp(h);

h.add(h.rsp, rlen);
}

h.pop(abi_param3);
h.pop(abi_param2);
h.pop(abi_param1);
h.pop(h.rax);

postamble(h);
}

} // namespace intel_cpu
} // namespace ov


#endif // CPU_DEBUG_CAPS
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#ifdef CPU_DEBUG_CAPS

#include "cpu/x64/jit_generator.hpp"

namespace ov {
namespace intel_cpu {

// Usage
// 1. Include this headfile where JIT kennels of CPU plugin are implemented for Register printing
// 2. Invoke RegPrinter::print method. Here are some examples. Note that user friendly register name
// will be printed, if it has been set. Current implementation doesn't buffer the name. So if you
// choose to set a name for the register, do not use local variable to pass the name, just pass a
// direct string to the interface like examples. While Original Xbyak register name will always be
// printed.
// Example 1:
// Invocation: RegPrinter::print<float>(*this, vmm_val, "vmm_val");
// Console: vmm_val | ymm0: {30, 20, 25, 29, 24, 31, 27, 23}
//
// Example 2:
// Invocation: RegPrinter::print<float>(*this, vmm_val);
// Console: ymm0: {30, 20, 25, 29, 24, 31, 27, 23}
//
// Example 3:
// Invocation: RegPrinter::print<int>(*this, vmm_idx, "vmm_idx");
// Console: vmm_idx | ymm1: {5, 6, 0, 2, 0, 6, 6, 6}
//
// Example 4:
// Invocation: RegPrinter::print<int>(*this, reg_work_amount, "reg_work_amount");
// Console: reg_work_amount | r13: 8
//
// Example 5:
// Invocation: RegPrinter::print<int>(*this, reg_work_amount);
// Console: r13: 8
//
// Example 6:
// Invocation: RegPrinter::print<float>(*this, reg_tmp_64, "reg_tmp_64");
// Console: reg_tmp_64 | r15: 1
//
// Parameter
// The following combinations of Register types and precisions are supported.
// fp32 int32 int8 u8
// Xmm Yes Yes No No
// Ymm Yes Yes No No
// Zmm Yes Yes No No
// Reg64 Yes Yes No No
// Reg32 Yes Yes No No
// Reg16 No No Yes Yes
// Reg8 No No Yes Yes

class RegPrinter {
public:
using jit_generator = dnnl::impl::cpu::x64::jit_generator;
template <typename PRC_T, typename REG_T,
typename std::enable_if<std::is_base_of<Xbyak::Xmm, REG_T>::value, int>::type = 0>
static void print(jit_generator &h, REG_T reg, const char *name = nullptr) {
print_vmm<PRC_T, REG_T>(h, reg, name);
}
template <typename PRC_T, typename REG_T,
typename std::enable_if<!std::is_base_of<Xbyak::Xmm, REG_T>::value, int>::type = 0>
static void print(jit_generator &h, REG_T reg, const char *name = nullptr) {
print_reg<PRC_T, REG_T>(h, reg, name);
}

private:
RegPrinter() {}
template <typename PRC_T, typename REG_T>
static void print_vmm(jit_generator &h, REG_T vmm, const char *name);
template <typename PRC_T, typename REG_T>
static void print_reg(jit_generator &h, REG_T reg, const char *name);
template <typename PRC_T, size_t vlen>
static void print_vmm_prc(const char *name, const char *ori_name, PRC_T *ptr);
template <typename T>
static void print_reg_prc(const char *name, const char *ori_name, T *val);
static void preamble(jit_generator &h);
static void postamble(jit_generator &h);
template <typename T>
static void save_vmm(jit_generator &h);
template <typename T>
static void restore_vmm(jit_generator &h);
static void save_reg(jit_generator &h);
static void restore_reg(jit_generator &h);
static void align_rsp(jit_generator &h);
static void restore_rsp(jit_generator &h);
static constexpr size_t reg_len = 8;
static constexpr size_t reg_cnt = 16;
};

} // namespace intel_cpu
} // namespace ov

#endif // CPU_DEBUG_CAPS
Loading

0 comments on commit 9284684

Please sign in to comment.