forked from openvinotoolkit/openvino
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Snippets][CPU] Added dynamic MatMul INT8/BF16 support (openvinotoolk…
…it#26493) ### Details: - *Implemented BrgemmCopyBExecutor* - *Added tests with dynamic input shapes for MatMul INT8/BF16* ### Tickets: - *151922* ### Prerequisites: - [x] openvinotoolkit#26413
- Loading branch information
1 parent
d2a9873
commit 9284684
Showing
32 changed files
with
1,350 additions
and
872 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
220 changes: 220 additions & 0 deletions
220
src/plugins/intel_cpu/src/emitters/plugin/x64/debug_capabilities.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,220 @@ | ||
// Copyright (C) 2018-2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#ifdef CPU_DEBUG_CAPS | ||
|
||
#include "debug_capabilities.hpp" | ||
#include <iostream> | ||
#include <sstream> | ||
|
||
namespace ov { | ||
namespace intel_cpu { | ||
|
||
using namespace Xbyak; | ||
using namespace dnnl::impl::cpu::x64; | ||
|
||
template void RegPrinter::print<float, Xmm>(jit_generator &h, Xmm reg, const char *name); | ||
template void RegPrinter::print<int, Xmm>(jit_generator &h, Xmm reg, const char *name); | ||
template void RegPrinter::print<float, Ymm>(jit_generator &h, Ymm reg, const char *name); | ||
template void RegPrinter::print<int, Ymm>(jit_generator &h, Ymm reg, const char *name); | ||
template void RegPrinter::print<float, Zmm>(jit_generator &h, Zmm reg, const char *name); | ||
template void RegPrinter::print<int, Zmm>(jit_generator &h, Zmm reg, const char *name); | ||
template void RegPrinter::print<float, Reg64>(jit_generator &h, Reg64 reg, const char *name); | ||
template void RegPrinter::print<int, Reg64>(jit_generator &h, Reg64 reg, const char *name); | ||
template void RegPrinter::print<float, Reg32>(jit_generator &h, Reg32 reg, const char *name); | ||
template void RegPrinter::print<int, Reg32>(jit_generator &h, Reg32 reg, const char *name); | ||
template void RegPrinter::print<char, Reg16>(jit_generator &h, Reg16 reg, const char *name); | ||
template void RegPrinter::print<unsigned char, Reg16>(jit_generator &h, Reg16 reg, const char *name); | ||
template void RegPrinter::print<char, Reg8>(jit_generator &h, Reg8 reg, const char *name); | ||
template void RegPrinter::print<unsigned char, Reg8>(jit_generator &h, Reg8 reg, const char *name); | ||
|
||
template <typename T> | ||
void RegPrinter::print_reg_prc(const char *name, const char *ori_name, T *ptr) { | ||
std::stringstream ss; | ||
if (name) ss << name << " | "; | ||
ss << ori_name << ": "; | ||
if (std::is_floating_point<T>::value) { | ||
ss << *ptr; | ||
} else { | ||
if (std::is_signed<T>::value) { | ||
ss << static_cast<int64_t>(*ptr); | ||
} else { | ||
ss << static_cast<uint64_t>(*ptr); | ||
} | ||
} | ||
ss << std::endl; | ||
std::cout << ss.str(); | ||
} | ||
|
||
template <typename PRC_T, size_t vlen> | ||
void RegPrinter::print_vmm_prc(const char *name, const char *ori_name, PRC_T *ptr) { | ||
std::stringstream ss; | ||
if (name) ss << name << " | "; | ||
ss << ori_name << ": {" << ptr[0]; | ||
for (size_t i = 1; i < vlen / sizeof(float); i++) { | ||
ss << ", " << ptr[i]; | ||
} | ||
ss << "}" << std::endl; | ||
std::cout << ss.str(); | ||
} | ||
template void RegPrinter::print_vmm_prc<float, 16>(const char *name, const char *ori_name, float *ptr); | ||
template void RegPrinter::print_vmm_prc<float, 32>(const char *name, const char *ori_name, float *ptr); | ||
template void RegPrinter::print_vmm_prc<float, 64>(const char *name, const char *ori_name, float *ptr); | ||
template void RegPrinter::print_vmm_prc<int, 16>(const char *name, const char *ori_name, int *ptr); | ||
template void RegPrinter::print_vmm_prc<int, 32>(const char *name, const char *ori_name, int *ptr); | ||
template void RegPrinter::print_vmm_prc<int, 64>(const char *name, const char *ori_name, int *ptr); | ||
|
||
template <typename Vmm> | ||
struct vmm_traits{}; | ||
|
||
template <> | ||
struct vmm_traits<Xmm> { | ||
static constexpr size_t vmm_len = 16; | ||
static constexpr size_t vmm_cnt = 16; | ||
}; | ||
|
||
template <> | ||
struct vmm_traits<Ymm> { | ||
static constexpr size_t vmm_len = 32; | ||
static constexpr size_t vmm_cnt = 16; | ||
}; | ||
|
||
template <> | ||
struct vmm_traits<Zmm> { | ||
static constexpr size_t vmm_len = 64; | ||
static constexpr size_t vmm_cnt = 32; | ||
}; | ||
|
||
template <typename T> | ||
void RegPrinter::save_vmm(jit_generator &h) { | ||
h.sub(h.rsp, vmm_traits<T>::vmm_len * vmm_traits<T>::vmm_cnt); | ||
for (size_t i = 0; i < vmm_traits<T>::vmm_cnt; i++) { | ||
h.uni_vmovups(h.ptr[h.rsp + i * vmm_traits<T>::vmm_len], T(i)); | ||
} | ||
} | ||
|
||
template <typename T> | ||
void RegPrinter::restore_vmm(jit_generator &h) { | ||
for (size_t i = 0; i < vmm_traits<T>::vmm_cnt; i++) { | ||
h.uni_vmovups(T(i), h.ptr[h.rsp + i * vmm_traits<T>::vmm_len]); | ||
} | ||
h.add(h.rsp, vmm_traits<T>::vmm_len * vmm_traits<T>::vmm_cnt); | ||
} | ||
|
||
void RegPrinter::save_reg(jit_generator &h) { | ||
h.sub(h.rsp, reg_len * reg_cnt); | ||
for (size_t i = 0; i < reg_cnt; i++) { | ||
h.mov(h.ptr[h.rsp + i * reg_len], Reg64(i)); | ||
} | ||
} | ||
|
||
void RegPrinter::restore_reg(jit_generator &h) { | ||
for (size_t i = 0; i < reg_cnt; i++) { | ||
h.mov(Reg64(i), h.ptr[h.rsp + i * reg_len]); | ||
} | ||
h.add(h.rsp, reg_len * reg_cnt); | ||
} | ||
|
||
void RegPrinter::preamble(jit_generator &h) { | ||
save_reg(h); | ||
mayiuse(cpu_isa_t::avx512_core) ? save_vmm<Zmm>(h) : (mayiuse(cpu_isa_t::avx2) ? | ||
save_vmm<Ymm>(h) : save_vmm<Xmm>(h)); | ||
} | ||
|
||
void RegPrinter::postamble(jit_generator &h) { | ||
mayiuse(cpu_isa_t::avx512_core) ? restore_vmm<Zmm>(h) : (mayiuse(cpu_isa_t::avx2) ? | ||
restore_vmm<Ymm>(h) : restore_vmm<Xmm>(h)); | ||
restore_reg(h); | ||
} | ||
|
||
// ABI requires 16-bype stack alignment before a call | ||
void RegPrinter::align_rsp(jit_generator &h) { | ||
constexpr int alignment = 16; | ||
h.mov(h.r15, h.rsp); | ||
h.and_(h.rsp, ~(alignment - 1)); | ||
} | ||
|
||
void RegPrinter::restore_rsp(jit_generator &h) { | ||
h.mov(h.rsp, h.r15); | ||
} | ||
|
||
template <typename PRC_T, typename REG_T> | ||
void RegPrinter::print_vmm(jit_generator &h, REG_T vmm, const char *name) { | ||
preamble(h); | ||
|
||
h.push(h.rax); | ||
h.push(abi_param1); | ||
h.push(abi_param2); | ||
h.push(abi_param3); | ||
{ | ||
const int vlen = vmm.isZMM() ? 64 : (vmm.isYMM() ? 32 : 16); | ||
h.sub(h.rsp, vlen); | ||
h.uni_vmovups(h.ptr[h.rsp], vmm); | ||
|
||
h.mov(abi_param3, h.rsp); | ||
h.mov(abi_param2, reinterpret_cast<size_t>(vmm.toString())); | ||
h.mov(abi_param1, reinterpret_cast<size_t>(name)); | ||
if (vmm.isZMM()) { | ||
auto p = &print_vmm_prc<PRC_T, 64>; | ||
h.mov(h.rax, reinterpret_cast<size_t>(p)); | ||
} else if (vmm.isYMM()) { | ||
auto p = &print_vmm_prc<PRC_T, 32>; | ||
h.mov(h.rax, reinterpret_cast<size_t>(p)); | ||
} else { | ||
auto p = &print_vmm_prc<PRC_T, 16>; | ||
h.mov(h.rax, reinterpret_cast<size_t>(p)); | ||
} | ||
align_rsp(h); | ||
h.call(h.rax); | ||
restore_rsp(h); | ||
|
||
h.add(h.rsp, vlen); | ||
} | ||
|
||
h.pop(abi_param3); | ||
h.pop(abi_param2); | ||
h.pop(abi_param1); | ||
h.pop(h.rax); | ||
|
||
postamble(h); | ||
} | ||
|
||
template <typename PRC_T, typename REG_T> | ||
void RegPrinter::print_reg(jit_generator &h, REG_T reg, const char *name) { | ||
preamble(h); | ||
|
||
h.push(h.rax); | ||
h.push(abi_param1); | ||
h.push(abi_param2); | ||
h.push(abi_param3); | ||
{ | ||
const int rlen = reg.getBit() / 8; | ||
h.sub(h.rsp, rlen); | ||
h.mov(h.ptr[h.rsp], reg); | ||
|
||
h.mov(abi_param3, h.rsp); | ||
h.mov(abi_param2, reinterpret_cast<size_t>(reg.toString())); | ||
h.mov(abi_param1, reinterpret_cast<size_t>(name)); | ||
auto p = &print_reg_prc<PRC_T>; | ||
h.mov(h.rax, reinterpret_cast<size_t>(p)); | ||
align_rsp(h); | ||
h.call(h.rax); | ||
restore_rsp(h); | ||
|
||
h.add(h.rsp, rlen); | ||
} | ||
|
||
h.pop(abi_param3); | ||
h.pop(abi_param2); | ||
h.pop(abi_param1); | ||
h.pop(h.rax); | ||
|
||
postamble(h); | ||
} | ||
|
||
} // namespace intel_cpu | ||
} // namespace ov | ||
|
||
|
||
#endif // CPU_DEBUG_CAPS |
97 changes: 97 additions & 0 deletions
97
src/plugins/intel_cpu/src/emitters/plugin/x64/debug_capabilities.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
// Copyright (C) 2018-2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#pragma once | ||
|
||
#ifdef CPU_DEBUG_CAPS | ||
|
||
#include "cpu/x64/jit_generator.hpp" | ||
|
||
namespace ov { | ||
namespace intel_cpu { | ||
|
||
// Usage | ||
// 1. Include this headfile where JIT kennels of CPU plugin are implemented for Register printing | ||
// 2. Invoke RegPrinter::print method. Here are some examples. Note that user friendly register name | ||
// will be printed, if it has been set. Current implementation doesn't buffer the name. So if you | ||
// choose to set a name for the register, do not use local variable to pass the name, just pass a | ||
// direct string to the interface like examples. While Original Xbyak register name will always be | ||
// printed. | ||
// Example 1: | ||
// Invocation: RegPrinter::print<float>(*this, vmm_val, "vmm_val"); | ||
// Console: vmm_val | ymm0: {30, 20, 25, 29, 24, 31, 27, 23} | ||
// | ||
// Example 2: | ||
// Invocation: RegPrinter::print<float>(*this, vmm_val); | ||
// Console: ymm0: {30, 20, 25, 29, 24, 31, 27, 23} | ||
// | ||
// Example 3: | ||
// Invocation: RegPrinter::print<int>(*this, vmm_idx, "vmm_idx"); | ||
// Console: vmm_idx | ymm1: {5, 6, 0, 2, 0, 6, 6, 6} | ||
// | ||
// Example 4: | ||
// Invocation: RegPrinter::print<int>(*this, reg_work_amount, "reg_work_amount"); | ||
// Console: reg_work_amount | r13: 8 | ||
// | ||
// Example 5: | ||
// Invocation: RegPrinter::print<int>(*this, reg_work_amount); | ||
// Console: r13: 8 | ||
// | ||
// Example 6: | ||
// Invocation: RegPrinter::print<float>(*this, reg_tmp_64, "reg_tmp_64"); | ||
// Console: reg_tmp_64 | r15: 1 | ||
// | ||
// Parameter | ||
// The following combinations of Register types and precisions are supported. | ||
// fp32 int32 int8 u8 | ||
// Xmm Yes Yes No No | ||
// Ymm Yes Yes No No | ||
// Zmm Yes Yes No No | ||
// Reg64 Yes Yes No No | ||
// Reg32 Yes Yes No No | ||
// Reg16 No No Yes Yes | ||
// Reg8 No No Yes Yes | ||
|
||
class RegPrinter { | ||
public: | ||
using jit_generator = dnnl::impl::cpu::x64::jit_generator; | ||
template <typename PRC_T, typename REG_T, | ||
typename std::enable_if<std::is_base_of<Xbyak::Xmm, REG_T>::value, int>::type = 0> | ||
static void print(jit_generator &h, REG_T reg, const char *name = nullptr) { | ||
print_vmm<PRC_T, REG_T>(h, reg, name); | ||
} | ||
template <typename PRC_T, typename REG_T, | ||
typename std::enable_if<!std::is_base_of<Xbyak::Xmm, REG_T>::value, int>::type = 0> | ||
static void print(jit_generator &h, REG_T reg, const char *name = nullptr) { | ||
print_reg<PRC_T, REG_T>(h, reg, name); | ||
} | ||
|
||
private: | ||
RegPrinter() {} | ||
template <typename PRC_T, typename REG_T> | ||
static void print_vmm(jit_generator &h, REG_T vmm, const char *name); | ||
template <typename PRC_T, typename REG_T> | ||
static void print_reg(jit_generator &h, REG_T reg, const char *name); | ||
template <typename PRC_T, size_t vlen> | ||
static void print_vmm_prc(const char *name, const char *ori_name, PRC_T *ptr); | ||
template <typename T> | ||
static void print_reg_prc(const char *name, const char *ori_name, T *val); | ||
static void preamble(jit_generator &h); | ||
static void postamble(jit_generator &h); | ||
template <typename T> | ||
static void save_vmm(jit_generator &h); | ||
template <typename T> | ||
static void restore_vmm(jit_generator &h); | ||
static void save_reg(jit_generator &h); | ||
static void restore_reg(jit_generator &h); | ||
static void align_rsp(jit_generator &h); | ||
static void restore_rsp(jit_generator &h); | ||
static constexpr size_t reg_len = 8; | ||
static constexpr size_t reg_cnt = 16; | ||
}; | ||
|
||
} // namespace intel_cpu | ||
} // namespace ov | ||
|
||
#endif // CPU_DEBUG_CAPS |
Oops, something went wrong.