[CPU] sns f16_mha_on_avx512_core_amx_f16_target (#27514)
### Details:
 - *Support f16 precision MHA on GNR*

### Tickets:
 - *CVS-122494, CVS-122495*
chenhu-wang authored Dec 17, 2024
1 parent b543d0b commit 8f0094d
Showing 15 changed files with 201 additions and 61 deletions.
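As a quick orientation before the per-file diffs, here is a minimal, hypothetical sketch of how the new path is exercised from the user side (the model path is made up; FP16 MHA tokenization only takes effect on CPUs supporting avx512_core_amx_fp16, as the changes below enforce):

```cpp
#include <openvino/openvino.hpp>

int main() {
    ov::Core core;
    // Hypothetical IR containing an MHA pattern (MatMul -> Softmax -> MatMul).
    auto model = core.read_model("mha_model.xml");

    // Requesting f16 inference precision now enforces FP16 on Snippets MHA subgraphs,
    // mirroring the existing bf16 behavior (see subgraph.cpp and transformations.cpp below).
    auto compiled = core.compile_model(model, "CPU", ov::hint::inference_precision(ov::element::f16));

    auto request = compiled.create_infer_request();
    return 0;
}
```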
3 changes: 2 additions & 1 deletion src/common/snippets/src/op/brgemm.cpp
@@ -87,7 +87,8 @@ ov::element::Type Brgemm::get_output_type(const ov::element::Type& in_type0, con
const bool is_f32 = utils::everyone_is(element::f32, in_type0, in_type1);
const bool is_int8 = utils::one_of(in_type0, element::i8, element::u8) && in_type1 == element::i8;
const bool is_bf16 = utils::everyone_is(element::bf16, in_type0, in_type1);
if (is_f32 || is_bf16) {
const bool is_f16 = utils::everyone_is(element::f16, in_type0, in_type1);
if (is_f32 || is_bf16 || is_f16) {
return element::f32;
} else if (is_int8) {
return element::i32;
@@ -17,13 +17,12 @@ class jit_brgemm_copy_b_emitter : public jit_emitter {
const ov::snippets::lowered::ExpressionPtr& expr,
const snippets::KernelExecutorTablePtr& kernel_table,
const ov::intel_cpu::MultiCacheWeakPtr& compiled_kernel_cache);

size_t get_inputs_num() const override {
return 1;
}
static std::set<std::vector<element::Type>> get_supported_precisions(
const std::shared_ptr<ov::Node>& node = nullptr) {
return {{element::i8}, {element::bf16}, {element::f32}};
return {{element::i8}, {element::bf16}, {element::f16}, {element::f32}};
}

private:
@@ -79,7 +79,8 @@ std::set<std::vector<element::Type>> jit_brgemm_emitter::get_supported_precision
} else if (brgemm->get_type() == BRGEMM_TYPE::WITH_AMX) {
return {{element::i8, element::i8, element::u8},
{element::u8, element::i8, element::u8},
{element::bf16, element::bf16, element::u8}};
{element::bf16, element::bf16, element::u8},
{element::f16, element::f16, element::u8}};
}
OV_CPU_JIT_EMITTER_THROW("got BrgemmCPU node with unsupported type");
}
15 changes: 8 additions & 7 deletions src/plugins/intel_cpu/src/nodes/subgraph.cpp
@@ -458,11 +458,12 @@ void Subgraph::initSupportedPrimitiveDescriptors() {
config.inConfs.resize(inputShapes.size());
for (size_t i = 0; i < inputShapes.size(); i++) {
const auto originalInputPrecision = getOriginalInputPrecisionAtPort(i);
const auto precision = ((originalInputPrecision == ov::element::f32) &&
context->getConfig().inferencePrecision == ov::element::bf16 &&
subgraph_attrs->snippet->has_domain_sensitive_ops())
? static_cast<ov::element::Type>(ov::element::bf16)
: originalInputPrecision;
const auto precision =
((originalInputPrecision == ov::element::f32) &&
one_of(context->getConfig().inferencePrecision, ov::element::bf16, ov::element::f16) &&
subgraph_attrs->snippet->has_domain_sensitive_ops())
? context->getConfig().inferencePrecision
: originalInputPrecision;
if (supportedPrecisions.count(precision) == 0)
OPENVINO_THROW("Subgraph node with name `", getName(), "` doesn't support ", precision, " precision.");

@@ -653,7 +654,7 @@ Subgraph::DataFlowPasses Subgraph::getDataFlowPasses() {
ov::snippets::pass::Canonicalization,
ov::snippets::pass::AnalyzeBroadcastableInputs,
broadcastable_inputs);
if (context->getConfig().inferencePrecision == ov::element::bf16 &&
if (one_of(context->getConfig().inferencePrecision, ov::element::bf16, ov::element::f16) &&
subgraph_attrs->snippet->has_domain_sensitive_ops()) {
// enforce BF16 precisions to supported operations
// MatMul has to be decomposed to Brgemm operations before enforcement
@@ -663,7 +664,7 @@
ov::snippets::pass::MatMulToBrgemm,
pass::EnforcePrecision,
element::f32,
element::bf16);
context->getConfig().inferencePrecision);
}
SNIPPETS_REGISTER_PASS_RELATIVE_X86_64(Place::Before,
ov::snippets::pass::PropagatePrecision,
@@ -91,7 +91,7 @@ void BrgemmCopyB::validate_and_infer_types() {
}

void BrgemmCopyB::validate_element_type(const ov::element::Type& element_type) {
OPENVINO_ASSERT(one_of(element_type, element::f32, element::bf16, element::i8),
OPENVINO_ASSERT(one_of(element_type, element::f32, element::bf16, element::f16, element::i8),
"BrgemmCopyB doesn't support element type" + element_type.get_type_name());
}

@@ -35,7 +35,12 @@ cpu_isa_t get_primitive_isa(const ov::element::Type& dt_in0, bool is_with_amx) {

// Note: AMX might be not used even if it's supported by the hardware, check the BrgemmToBrgemmCPU pass for details
if (is_with_amx) {
SUPPORT_ONE(avx512_core_amx, "Unsupported hardware configuration: amx is supported only on avx512 platforms")
if (dt_in0 == ov::element::f16)
SUPPORT_ONE(avx512_core_amx_fp16,
"Unsupported hardware configuration: amx is supported only on avx512 platforms")
else
SUPPORT_ONE(avx512_core_amx,
"Unsupported hardware configuration: amx is supported only on avx512 platforms")
} else if (dt_in0 == ov::element::bf16) {
SUPPORT_ONE(avx512_core_bf16, "Unsupported hardware configuration: bf16 is supported only on avx512 platforms")
} else if (one_of(dt_in0, ov::element::u8, ov::element::i8)) {
@@ -59,12 +64,15 @@ BRGEMM_TYPE get_brgemm_type(const ov::element::Type& element_type_a, bool transp
return transpose_b ? BRGEMM_TYPE::REPACKING_ONLY : BRGEMM_TYPE::STAND_ALONE;

OPENVINO_ASSERT(element_type_a != element::bf16 || mayiuse(dnnl::impl::cpu::x64::avx512_core_bf16),
"BF16 precision is not supported on this hardware");
"BrgemmCPU BF16 precision is not supported on non avx512_core_bf16 system");
OPENVINO_ASSERT(element_type_a != element::f16 || mayiuse(dnnl::impl::cpu::x64::avx512_core_amx_fp16),
"BrgemmCPU FP16 precision is not supported on non avx512_core_amx_fp16 system");

if (one_of(element_type_a, element::u8, element::i8, element::bf16) &&
dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_amx))
return BRGEMM_TYPE::WITH_AMX;

if (element_type_a == ov::element::f16 && dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_amx_fp16))
return BRGEMM_TYPE::WITH_AMX;
// Note: this condition reproduces logic from the OneDNN Brgemm implementation. This is needed to align with the
// backend requirements. More details in onednn/src/cpu/x64/brgemm/brgemm_utils.cpp
if (element_type_a == ov::element::i8)
@@ -96,6 +104,8 @@ size_t compute_inner_n_block(const ov::element::Type& precision) {
return 64;
case element::bf16:
return 32;
case element::f16:
return 32;
case element::f32:
return 16;
default:
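A condensed, standalone restatement of the f16-related dispatch added above; the enum and helper names below are illustrative only, not the plugin's API:

```cpp
#include <cstddef>
#include <stdexcept>

enum class BrgemmType { STAND_ALONE, WITH_AMX, WITH_COMPENSATIONS, REPACKING_ONLY };

// f16|f16 Brgemm follows the bf16 route: it is only legal on avx512_core_amx_fp16
// hardware, where it becomes WITH_AMX (BrgemmCopyB repacking + scratchpad buffer).
BrgemmType pick_f16_brgemm_type(bool host_has_amx_fp16) {
    if (!host_has_amx_fp16)
        throw std::runtime_error("BrgemmCPU FP16 precision is not supported on non avx512_core_amx_fp16 system");
    return BrgemmType::WITH_AMX;
}

// Inner N block of the repacked B matrix: 64 for one-byte types, 32 for two-byte
// types (bf16 and, with this change, f16), 16 for f32.
std::size_t inner_n_block(std::size_t element_size_in_bytes) {
    return 64 / element_size_in_bytes;
}
```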
@@ -15,8 +15,8 @@ namespace intel_cpu {
namespace brgemm_utils {

enum class BRGEMM_TYPE {
STAND_ALONE, // No extra requirements, used for f32|f32
WITH_AMX, // i8|i8 or bf16|bf16 on AMX system - needs BrgemmCopyB and scratchpad
STAND_ALONE, // No extra requirements, used for f32|f32
WITH_AMX, // i8|i8 or bf16|bf16 on AMX system or fp16|fp16 on AMX_FP16 system - needs BrgemmCopyB and scratchpad
WITH_COMPENSATIONS, // i8|i8 (non-AMX system) - needs BrgemmCopyB for data repacking and compensations
REPACKING_ONLY, // u8|i8, or bf16|bf16 (non-AMX system), or brgemm with transpose_b=true - needs BrgemmCopyB on
// second input for data repacking
@@ -25,7 +25,7 @@ namespace pass {
* \ Buffer (with repacked data) Buffer (with compensations)
* \ | /
* BrgemmCPU
* - f32|f32 with transpose_b, u8|i8, i8|i8 or bf16|bf16 on AMX system:
* - f32|f32 with transpose_b, u8|i8, i8|i8 or bf16|bf16 on AMX system or fp16|fp16 on AMX_FP16 system:
* \ BrgemmCopyB
* \ Buffer (with repacked data) Buffer (with new memory)
* \ | /
@@ -121,9 +121,12 @@ bool EnforcePrecision::run_on_model(const std::shared_ptr<ov::Model>& f) {

std::set<std::vector<ov::element::Type>> EnforcePrecision::get_supported_precisions_default(
const std::shared_ptr<ov::Node>& op) noexcept {
if (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_bf16) &&
ov::is_type<snippets::op::Brgemm>(op)) {
return {{element::bf16, element::bf16}};
std::set<std::vector<ov::element::Type>> types;
if (ov::is_type<snippets::op::Brgemm>(op)) {
if (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_amx_fp16))
types.insert({element::f16, element::f16});
if (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_bf16))
types.insert({element::bf16, element::bf16});
}
return {};
return types;
}
@@ -866,7 +866,7 @@ void Transformations::PostLpt() {
postLPTPassManager,
[](const std::shared_ptr<const ov::Node>& node) -> bool {
if (!ov::is_type<const ov::op::v0::FakeQuantize>(node) &&
node->get_output_element_type(0) != node->get_input_element_type(0))
node->get_output_element_type(0).size() > node->get_input_element_type(0).size())
return true;
if (node->get_input_size() >= 2) {
return node->get_input_element_type(1) == ov::element::i8 ||
@@ -986,7 +986,7 @@ void Transformations::MainSnippets(void) {
// MatMul and Result. However there may be Convert [f32->bf16] before Result since:
// - bf16 Brgemm has f32 output;
// - CPU Node Subgraph requires bf16 on output when inference precision is bf16.
// To avoid sitations when Transpose is not alone node between MatMul and Result,
// To avoid situations when Transpose is not alone node between MatMul and Result,
// Plugin disables Transpose tokenization on output
bool mha_token_enable_transpose_on_output = one_of(config.inferencePrecision, element::f32, element::undefined);
size_t concurrency = config.streamExecutorConfig.get_threads_per_stream();
@@ -1023,6 +1023,7 @@

ov::pass::Manager snippetsManager("CPU:Snippets");
snippetsManager.set_per_pass_validation(false);
// If a callback is needed for better performance, enable SnippetsMarkSkipped and disable TokenizeFCSnippets.
if (!ignoreCallback) {
#if defined(OPENVINO_ARCH_ARM64)
CPU_REGISTER_PASS_ARM(snippetsManager, SnippetsMarkSkipped);
@@ -1033,17 +1034,17 @@
}
CPU_REGISTER_PASS_COMMON(snippetsManager, snippets::pass::SnippetsTokenization, tokenization_config);

// - MHA has BRGEMM that is supported only on AVX512 platforms
// - CPU Plugin Subgraph supports only f32, bf16 (and quantized) BRGEMM
// [122494] Need to add support of f16
// - CPU Plugin Subgraph supports f32, bf16, quantized and fp16 (on avx512_core_amx_fp16 target) BRGEMM
const bool isMHASupported =
#if defined(OPENVINO_ARCH_ARM64)
false;
#else
(dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx2) &&
one_of(config.inferencePrecision, ov::element::f32, element::undefined)) ||
(dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core) &&
one_of(config.inferencePrecision, ov::element::bf16, ov::element::f32, element::undefined));
one_of(config.inferencePrecision, ov::element::bf16, ov::element::f32, element::undefined)) ||
(dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_amx_fp16) &&
one_of(config.inferencePrecision, ov::element::f16));
#endif
if (!isMHASupported) {
CPU_DISABLE_PASS_COMMON(snippetsManager, snippets::pass::TokenizeMHASnippets);
@@ -1059,13 +1060,13 @@
const auto in_type1 = matmul->get_input_element_type(1);
const auto is_fp32 = (in_type0 == ov::element::f32 && in_type1 == ov::element::f32 &&
one_of(config.inferencePrecision, element::f32, element::undefined));
const auto is_fp16 = (in_type0 == ov::element::f16 || in_type1 == ov::element::f16);
const auto is_fp16 =
(in_type0 == ov::element::f16 || in_type1 == ov::element::f16) ||
(in_type0 == element::f32 && in_type1 == ov::element::f32 && config.inferencePrecision == ov::element::f16);
const auto is_bf16 = (in_type0 == ov::element::bf16 && in_type1 == ov::element::bf16) ||
((in_type0 == element::f32 && in_type1 == ov::element::f32 &&
config.inferencePrecision == ov::element::bf16));
const auto is_int8 = in_type0 == ov::element::i8;
if (is_fp16)
return false;
if (is_fp32)
return true;
// Only FP32 dynamic MHA is supported
@@ -1076,13 +1077,14 @@
// brgemm_copy_b kernel
if (matmul->get_transpose_a() || matmul->get_transpose_b())
return false;
// [150842] The execution of Brgemm INT8/BF16 on AMX platforms depends on the value of "K % VNNIFactor".
// [150842] The execution of Brgemm INT8/BF16/FP16 on AMX platforms depends on the value of "K % VNNIFactor".
// For more details, please take a look at the ticket 150842
if (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_amx)) {
const auto& b_shape = matmul->get_input_partial_shape(1);
const auto K = matmul->get_transpose_b() ? *b_shape.rbegin() : *++b_shape.rbegin();
if (is_bf16)
return K.is_static() && (K.get_length() % 2 == 0);
const size_t brgemm_vnni_factor_for_real16 = 2;  // 4 / 2 (element size in bytes for bf16/fp16)
if (is_bf16 || is_fp16)
return K.is_static() && (K.get_length() % brgemm_vnni_factor_for_real16 == 0);
if (is_int8)
return K.is_static();
}
@@ -1091,6 +1093,8 @@
dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx2_vnni);
if (is_bf16)
return dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_bf16);
if (is_fp16)
return dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_amx_fp16);
return true;
};
auto is_unsupported_parallel_work_amount = [&](const std::shared_ptr<const ov::Node>& n,
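The K % VNNIFactor restriction above boils down to a small check; the sketch below is a standalone illustration (the helper name is made up) of the rule that for two-byte types the factor is 4 / sizeof(element) = 2, and that a dynamic K disables AMX tokenization for bf16/f16:

```cpp
#include <cstdint>

// True when a bf16/f16 MatMul may be tokenized for AMX: K must be static
// (-1 means "dynamic" here) and divisible by the VNNI factor for two-byte types.
constexpr bool k_ok_for_amx_real16(std::int64_t k_static) {
    constexpr std::int64_t vnni_factor_real16 = 4 / static_cast<std::int64_t>(sizeof(std::uint16_t));  // == 2
    return k_static >= 0 && (k_static % vnni_factor_real16 == 0);
}

static_assert(k_ok_for_amx_real16(64), "even static K is accepted");
static_assert(!k_ok_for_amx_real16(63), "odd K is rejected");
static_assert(!k_ok_for_amx_real16(-1), "dynamic K is rejected");
```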
@@ -189,7 +189,7 @@ class MHATest : public testing::WithParamInterface<MHATuple>, virtual public Sub
for (size_t i = 0; i < funcInputs.size(); ++i) {
const auto& funcInput = funcInputs[i];
ov::Tensor tensor;
if (funcInput.get_element_type() == ov::element::bf16) {
if (funcInput.get_element_type() == ov::element::bf16 || funcInput.get_element_type() == ov::element::f16) {
ov::test::utils::InputGenerateData in_data;
in_data.start_from = -1;
in_data.range = 2;
@@ -232,6 +232,9 @@
configuration.insert({ov::hint::inference_precision(ov::element::bf16)});
}

if (inputPrecisions[0] == ElementType::f16)
configuration.insert({ov::hint::inference_precision(ov::element::f16)});

// Snippets MHA tokenization has limitations to avoid performance degradations. These limitations depend on
// target machine. Just for testing, we disable these limitations to allow Snippets to tokenize pattern on all
// machines for validation.
@@ -253,6 +256,9 @@ TEST_P(MHATest, CompareWithRefs) {
if (inputPrecisions[0] == ElementType::bf16 && !ov::with_cpu_x86_bfloat16())
GTEST_SKIP();

if (inputPrecisions[0] == ElementType::f16 && !ov::with_cpu_x86_avx512_core_amx_fp16())
GTEST_SKIP();

if (!ov::with_cpu_x86_avx512_core())
GTEST_SKIP();

@@ -308,6 +314,20 @@ INSTANTIATE_TEST_SUITE_P(
::testing::Values(ov::test::utils::DEVICE_CPU)),
MHATest::getTestCaseName);

INSTANTIATE_TEST_SUITE_P(
smoke_MHA_FP16,
MHATest,
::testing::Combine(
::testing::ValuesIn(static_shapes_to_test_representation(inputShapes)),
::testing::Values(
std::vector<ElementType>{ElementType::f16, ElementType::f16, ElementType::f16, ElementType::f16}),
::testing::ValuesIn(matMulIn0Precisions),
::testing::ValuesIn(patternTypes),
::testing::Values(ExpectedNodes{{"Subgraph", 1},
{"Transpose", 1}}), // Plugin disables tokenization of Transpose on output
::testing::Values(ov::test::utils::DEVICE_CPU)),
MHATest::getTestCaseName);

} // namespace

static std::shared_ptr<ov::Model> initMHAQuantSubgraph0(std::vector<ov::PartialShape>& inputDynamicShapes,
@@ -565,6 +565,11 @@ std::vector<std::string> disabledTestPatterns() {
retVector.emplace_back(R"(.*smoke_Snippets_MHA.*EnforceBF16.*)");
retVector.emplace_back(R"(.*ConcatSDPTest.*bf16.*)");
}
// MHA FP16 precision is only supported on amx_fp16 platform
if (!ov::with_cpu_x86_avx512_core_amx_fp16()) {
retVector.emplace_back(R"(.*smoke_Snippets_MHA.*FP16.*)");
}

#ifdef SNIPPETS_LIBXSMM_TPP
// GN in TPP requires exposing tmp Buffer results outside the loop (ticket: 151234)
retVector.emplace_back(R"(.*smoke_Snippets_GroupNormalization.*)");