From 6b72cb2bae6fd0c19486a4f84bdcb92cd2c36f94 Mon Sep 17 00:00:00 2001 From: chenhu-wang Date: Mon, 16 Dec 2024 00:55:24 +0800 Subject: [PATCH] code style apply --- .../x64/jit_brgemm_copy_b_emitter.hpp | 3 +- src/plugins/intel_cpu/src/nodes/subgraph.cpp | 20 ++++--- .../snippets/x64/op/brgemm_utils.cpp | 24 +++++---- .../snippets/x64/op/brgemm_utils.hpp | 9 ++-- .../snippets/x64/pass/enforce_precision.cpp | 2 +- .../transformation_pipeline.cpp | 23 ++++---- .../shared_tests_instances/snippets/mha.cpp | 53 ++++++++++--------- 7 files changed, 78 insertions(+), 56 deletions(-) diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_copy_b_emitter.hpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_copy_b_emitter.hpp index 0b0bce550f0f27..8c8bae1f9a2187 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_copy_b_emitter.hpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_copy_b_emitter.hpp @@ -18,7 +18,8 @@ class jit_brgemm_copy_b_emitter : public jit_emitter { const snippets::KernelExecutorTablePtr& kernel_table, const ov::intel_cpu::MultiCacheWeakPtr& compiled_kernel_cache); size_t get_inputs_num() const override {return 1;} - static std::set> get_supported_precisions(const std::shared_ptr& node = nullptr) { + static std::set> get_supported_precisions( + const std::shared_ptr& node = nullptr) { return {{element::i8}, {element::bf16}, {element::f16}, {element::f32}}; } diff --git a/src/plugins/intel_cpu/src/nodes/subgraph.cpp b/src/plugins/intel_cpu/src/nodes/subgraph.cpp index 4afcc2f32a407a..41b65c7602b76c 100644 --- a/src/plugins/intel_cpu/src/nodes/subgraph.cpp +++ b/src/plugins/intel_cpu/src/nodes/subgraph.cpp @@ -458,10 +458,12 @@ void Subgraph::initSupportedPrimitiveDescriptors() { config.inConfs.resize(inputShapes.size()); for (size_t i = 0; i < inputShapes.size(); i++) { const auto originalInputPrecision = getOriginalInputPrecisionAtPort(i); - const auto precision = ((originalInputPrecision == ov::element::f32) && - one_of(context->getConfig().inferencePrecision, ov::element::bf16, ov::element::f16) && - subgraph_attrs->snippet->has_domain_sensitive_ops()) ? - context->getConfig().inferencePrecision : originalInputPrecision; + const auto precision = + ((originalInputPrecision == ov::element::f32) && + one_of(context->getConfig().inferencePrecision, ov::element::bf16, ov::element::f16) && + subgraph_attrs->snippet->has_domain_sensitive_ops()) + ? context->getConfig().inferencePrecision + : originalInputPrecision; if (supportedPrecisions.count(precision) == 0) OPENVINO_THROW("Subgraph node with name `", getName(), "` doesn't support ", precision, " precision."); @@ -650,13 +652,17 @@ Subgraph::DataFlowPasses Subgraph::getDataFlowPasses() { SNIPPETS_REGISTER_PASS_ABSOLUTE_COMMON(Place::PipelineStart, ConvertToSwishCPU); SNIPPETS_REGISTER_PASS_RELATIVE_COMMON(Place::After, ov::snippets::pass::Canonicalization, ov::snippets::pass::AnalyzeBroadcastableInputs, broadcastable_inputs); - if (one_of(context->getConfig().inferencePrecision, ov::element::bf16, ov::element::f16) && subgraph_attrs->snippet->has_domain_sensitive_ops()) { + if (one_of(context->getConfig().inferencePrecision, ov::element::bf16, ov::element::f16) && + subgraph_attrs->snippet->has_domain_sensitive_ops()) { // enforce BF16 precisions to supported operations // MatMul has to be decomposed to Brgemm operations before enforcement // Note, MatMul decomposition will be run later again for case if BF16 enforcement is not happened SNIPPETS_REGISTER_PASS_ABSOLUTE_X86_64(Place::PipelineStart, ov::snippets::pass::MatMulToBrgemm); - SNIPPETS_REGISTER_PASS_RELATIVE_X86_64(Place::After, ov::snippets::pass::MatMulToBrgemm, - pass::EnforcePrecision, element::f32, context->getConfig().inferencePrecision); + SNIPPETS_REGISTER_PASS_RELATIVE_X86_64(Place::After, + ov::snippets::pass::MatMulToBrgemm, + pass::EnforcePrecision, + element::f32, + context->getConfig().inferencePrecision); } SNIPPETS_REGISTER_PASS_RELATIVE_X86_64(Place::Before, ov::snippets::pass::PropagatePrecision, diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_utils.cpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_utils.cpp index c2f3874bbade74..386941fd94bb98 100644 --- a/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_utils.cpp +++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_utils.cpp @@ -36,9 +36,11 @@ cpu_isa_t get_primitive_isa(const ov::element::Type& dt_in0, bool is_with_amx) { // Note: AMX might be not used even if it's supported by the hardware, check the BrgemmToBrgemmCPU pass for details if (is_with_amx) { if (dt_in0 == ov::element::f16) - SUPPORT_ONE(avx512_core_amx_fp16, "Unsupported hardware configuration: amx is supported only on avx512 platforms") + SUPPORT_ONE(avx512_core_amx_fp16, + "Unsupported hardware configuration: amx is supported only on avx512 platforms") else - SUPPORT_ONE(avx512_core_amx, "Unsupported hardware configuration: amx is supported only on avx512 platforms") + SUPPORT_ONE(avx512_core_amx, + "Unsupported hardware configuration: amx is supported only on avx512 platforms") } else if (dt_in0 == ov::element::bf16) { SUPPORT_ONE(avx512_core_bf16, "Unsupported hardware configuration: bf16 is supported only on avx512 platforms") } else if (one_of(dt_in0, ov::element::u8, ov::element::i8)) { @@ -69,8 +71,7 @@ BRGEMM_TYPE get_brgemm_type(const ov::element::Type& element_type_a, bool transp if (one_of(element_type_a, element::u8, element::i8, element::bf16) && dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_amx)) return BRGEMM_TYPE::WITH_AMX; - if (element_type_a == ov::element::f16 && - dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_amx_fp16)) + if (element_type_a == ov::element::f16 && dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_amx_fp16)) return BRGEMM_TYPE::WITH_AMX; // Note: this condition reproduces logic from the OneDNN Brgemm implementation. This is needed to align with the // backend requirements. More details in onednn/src/cpu/x64/brgemm/brgemm_utils.cpp @@ -99,11 +100,16 @@ size_t get_elems_in_vec(const ov::element::Type& precision) { namespace repacking { size_t compute_inner_n_block(const ov::element::Type& precision) { switch (precision) { - case element::i8: return 64; - case element::bf16: return 32; - case element::f16: return 32; - case element::f32: return 16; - default: OPENVINO_THROW("BrgemmCopyB doesn't support precision ", precision); + case element::i8: + return 64; + case element::bf16: + return 32; + case element::f16: + return 32; + case element::f32: + return 16; + default: + OPENVINO_THROW("BrgemmCopyB doesn't support precision ", precision); } } diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_utils.hpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_utils.hpp index 21ad288d173a39..6a7794d1cda454 100644 --- a/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_utils.hpp +++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_utils.hpp @@ -15,10 +15,11 @@ namespace intel_cpu { namespace brgemm_utils { enum class BRGEMM_TYPE { - STAND_ALONE, // No extra requirements, used for f32|f32 - WITH_AMX, // i8|i8 or bf16|bf16 on AMX system or fp16|fp16 on AMX_FP16 system - needs BrgemmCopyB and scratchpad - WITH_COMPENSATIONS, // i8|i8 (non-AMX system) - needs BrgemmCopyB for data repacking and compensations - REPACKING_ONLY, // u8|i8, or bf16|bf16 (non-AMX system), or brgemm with transpose_b=true - needs BrgemmCopyB on second input for data repacking + STAND_ALONE, // No extra requirements, used for f32|f32 + WITH_AMX, // i8|i8 or bf16|bf16 on AMX system or fp16|fp16 on AMX_FP16 system - needs BrgemmCopyB and scratchpad + WITH_COMPENSATIONS, // i8|i8 (non-AMX system) - needs BrgemmCopyB for data repacking and compensations + REPACKING_ONLY, // u8|i8, or bf16|bf16 (non-AMX system), or brgemm with transpose_b=true - needs BrgemmCopyB on + // second input for data repacking }; dnnl::impl::cpu::x64::cpu_isa_t get_primitive_isa(const ov::element::Type& dt_in0, bool is_with_amx); diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/enforce_precision.cpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/enforce_precision.cpp index ebd7bbe3339fb9..6b7d5d31a5b12f 100644 --- a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/enforce_precision.cpp +++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/enforce_precision.cpp @@ -120,7 +120,7 @@ bool EnforcePrecision::run_on_model(const std::shared_ptr& f) { } std::set> EnforcePrecision::get_supported_precisions_default( - const std::shared_ptr&op) noexcept { + const std::shared_ptr& op) noexcept { std::set> types; if (ov::is_type(op)) { if (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_amx_fp16)) diff --git a/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp b/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp index 79729183e162db..4013c1c3cd84f9 100644 --- a/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp +++ b/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp @@ -1039,12 +1039,12 @@ void Transformations::MainSnippets(void) { #if defined(OPENVINO_ARCH_ARM64) false; #else - (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx2) && - one_of(config.inferencePrecision, ov::element::f32, element::undefined)) || - (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core) && - one_of(config.inferencePrecision, ov::element::bf16, ov::element::f32, element::undefined)) || - (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_amx_fp16) && - one_of(config.inferencePrecision, ov::element::f16)); + (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx2) && + one_of(config.inferencePrecision, ov::element::f32, element::undefined)) || + (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core) && + one_of(config.inferencePrecision, ov::element::bf16, ov::element::f32, element::undefined)) || + (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_amx_fp16) && + one_of(config.inferencePrecision, ov::element::f16)); #endif if (!isMHASupported) { CPU_DISABLE_PASS_COMMON(snippetsManager, snippets::pass::TokenizeMHASnippets); @@ -1060,8 +1060,9 @@ void Transformations::MainSnippets(void) { const auto in_type1 = matmul->get_input_element_type(1); const auto is_fp32 = (in_type0 == ov::element::f32 && in_type1 == ov::element::f32 && one_of(config.inferencePrecision, element::f32, element::undefined)); - const auto is_fp16 = (in_type0 == ov::element::f16 || in_type1 == ov::element::f16) || - (in_type0 == element::f32 && in_type1 == ov::element::f32 && config.inferencePrecision == ov::element::f16); + const auto is_fp16 = + (in_type0 == ov::element::f16 || in_type1 == ov::element::f16) || + (in_type0 == element::f32 && in_type1 == ov::element::f32 && config.inferencePrecision == ov::element::f16); const auto is_bf16 = (in_type0 == ov::element::bf16 && in_type1 == ov::element::bf16) || ((in_type0 == element::f32 && in_type1 == ov::element::f32 && config.inferencePrecision == ov::element::bf16)); @@ -1082,8 +1083,10 @@ void Transformations::MainSnippets(void) { const auto& b_shape = matmul->get_input_partial_shape(1); const auto K = matmul->get_transpose_b() ? *b_shape.rbegin() : *++b_shape.rbegin(); const size_t brgemm_vnni_factor_for_real16 = 2; // 4/2(size in term of byte for bf16/fp16) - if (is_bf16 || is_fp16) return K.is_static() && (K.get_length() % brgemm_vnni_factor_for_real16 == 0); - if (is_int8) return K.is_static(); + if (is_bf16 || is_fp16) + return K.is_static() && (K.get_length() % brgemm_vnni_factor_for_real16 == 0); + if (is_int8) + return K.is_static(); } if (is_int8) return dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_vnni) || diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/mha.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/mha.cpp index ced702d31ca3cc..df0b69f99ef06d 100644 --- a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/mha.cpp +++ b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/mha.cpp @@ -15,33 +15,38 @@ namespace { std::vector> transposedShape_4D(bool with_static = true, bool with_dynamic = true) { std::vector> shapes; if (with_static) { - auto static_shapes = SNIPPETS_TESTS_STATIC_SHAPES( - {{1, 128, 12, 64}, {1, 128, 12, 64}, {1, 12, 128, 128}, {1, 128, 12, 64}}, - {{1, 128, 16, 64}, {1, 128, 16, 64}, {1, 16, 1, 1}, {1, 128, 16, 64}}, - {{1, 128, 16, 64}, {1, 128, 16, 64}, {1, 1, 1, 128}, {1, 128, 16, 64}}, - {{2, 68, 6, 92}, {2, 68, 6, 92}, {1, 1, 68, 68}, {2, 68, 6, 92}}, - {{1, 58, 16, 34}, {1, 58, 16, 34}, {1, 1, 1, 58}, {1, 58, 16, 34}}); + auto static_shapes = + SNIPPETS_TESTS_STATIC_SHAPES({{1, 128, 12, 64}, {1, 128, 12, 64}, {1, 12, 128, 128}, {1, 128, 12, 64}}, + {{1, 128, 16, 64}, {1, 128, 16, 64}, {1, 16, 1, 1}, {1, 128, 16, 64}}, + {{1, 128, 16, 64}, {1, 128, 16, 64}, {1, 1, 1, 128}, {1, 128, 16, 64}}, + {{2, 68, 6, 92}, {2, 68, 6, 92}, {1, 1, 68, 68}, {2, 68, 6, 92}}, + {{1, 58, 16, 34}, {1, 58, 16, 34}, {1, 1, 1, 58}, {1, 58, 16, 34}}); shapes.insert(shapes.end(), static_shapes.begin(), static_shapes.end()); } if (with_dynamic) { - std::vector> dynamic_shapes = {{ - {PartialShape{-1, -1, -1, 100}, {{1, 64, 4, 100}, {2, 16, 2, 100}, {1, 72, 4, 100}}}, - {PartialShape{-1, 128, -1, 100}, {{1, 128, 4, 100}, {2, 128, 2, 100}, {1, 128, 4, 100}}}, - {PartialShape{-1, -1, -1, 128}, {{1, 4, 64, 128}, {2, 2, 16, 128}, {1, 4, 72, 128}}}, - {PartialShape{-1, 128, -1, 100}, {{1, 128, 4, 100}, {2, 128, 2, 100}, {1, 128, 4, 100}}}, - }, - { - {PartialShape{-1, -1, -1, -1}, {{1, 128, 3, 64}, {2, 16, 2, 100}, {1, 128, 3, 64}}}, - {PartialShape{-1, -1, -1, -1}, {{1, 128, 1, 64}, {2, 128, 2, 100}, {1, 128, 1, 64}}}, - {PartialShape{-1, -1, -1, -1}, {{2, 1, 128, 128}, {2, 2, 16, 128}, {2, 1, 128, 128}}}, - {PartialShape{-1, -1, -1, -1}, {{1, 128, 3, 64}, {2, 128, 2, 100}, {1, 128, 3, 64}}}, - }, - { - {PartialShape{-1, -1, 12, 64}, {{1, 70, 12, 64}, {1, 20, 12, 64}, {1, 20, 12, 64}, {1, 20, 12, 64}, {1, 70, 12, 64}}}, - {PartialShape{-1, -1, 12, 64}, {{1, 35, 12, 64}, {2, 10, 12, 64}, {2, 1, 12, 64}, {2, 10, 12, 64}, {1, 35, 12, 64}}}, - {PartialShape{-1, 12, -1, -1}, {{2, 12, 70, 35}, {1, 12, 20, 10}, {1, 12, 20, 10}, {1, 12, 20, 1}, {2, 12, 70, 35}}}, - {PartialShape{-1, -1, 12, 64}, {{1, 35, 12, 64}, {1, 10, 12, 64}, {1, 10, 12, 64}, {1, 10, 12, 64}, {1, 35, 12, 64}}}, - }}; + std::vector> dynamic_shapes = { + { + {PartialShape{-1, -1, -1, 100}, {{1, 64, 4, 100}, {2, 16, 2, 100}, {1, 72, 4, 100}}}, + {PartialShape{-1, 128, -1, 100}, {{1, 128, 4, 100}, {2, 128, 2, 100}, {1, 128, 4, 100}}}, + {PartialShape{-1, -1, -1, 128}, {{1, 4, 64, 128}, {2, 2, 16, 128}, {1, 4, 72, 128}}}, + {PartialShape{-1, 128, -1, 100}, {{1, 128, 4, 100}, {2, 128, 2, 100}, {1, 128, 4, 100}}}, + }, + { + {PartialShape{-1, -1, -1, -1}, {{1, 128, 3, 64}, {2, 16, 2, 100}, {1, 128, 3, 64}}}, + {PartialShape{-1, -1, -1, -1}, {{1, 128, 1, 64}, {2, 128, 2, 100}, {1, 128, 1, 64}}}, + {PartialShape{-1, -1, -1, -1}, {{2, 1, 128, 128}, {2, 2, 16, 128}, {2, 1, 128, 128}}}, + {PartialShape{-1, -1, -1, -1}, {{1, 128, 3, 64}, {2, 128, 2, 100}, {1, 128, 3, 64}}}, + }, + { + {PartialShape{-1, -1, 12, 64}, + {{1, 70, 12, 64}, {1, 20, 12, 64}, {1, 20, 12, 64}, {1, 20, 12, 64}, {1, 70, 12, 64}}}, + {PartialShape{-1, -1, 12, 64}, + {{1, 35, 12, 64}, {2, 10, 12, 64}, {2, 1, 12, 64}, {2, 10, 12, 64}, {1, 35, 12, 64}}}, + {PartialShape{-1, 12, -1, -1}, + {{2, 12, 70, 35}, {1, 12, 20, 10}, {1, 12, 20, 10}, {1, 12, 20, 1}, {2, 12, 70, 35}}}, + {PartialShape{-1, -1, 12, 64}, + {{1, 35, 12, 64}, {1, 10, 12, 64}, {1, 10, 12, 64}, {1, 10, 12, 64}, {1, 35, 12, 64}}}, + }}; shapes.insert(shapes.end(), dynamic_shapes.begin(), dynamic_shapes.end()); } return shapes;