From 4d60602c91da506efdb510c5162e9f6536d9d7e5 Mon Sep 17 00:00:00 2001
From: Egor Duplenskii
Date: Mon, 25 Mar 2024 10:35:11 +0100
Subject: [PATCH] [CPU] Correct type configuration for i8 inner_product with
 f16 output (#23610)

### Tickets:
 - 136298
 - 136163
---
 .../fullyconnected_implementations.cpp | 31 ++++++++++---------
 1 file changed, 17 insertions(+), 14 deletions(-)

diff --git a/src/plugins/intel_cpu/src/nodes/executors/fullyconnected_implementations.cpp b/src/plugins/intel_cpu/src/nodes/executors/fullyconnected_implementations.cpp
index 0c66c37394ab52..eea989656e49b6 100644
--- a/src/plugins/intel_cpu/src/nodes/executors/fullyconnected_implementations.cpp
+++ b/src/plugins/intel_cpu/src/nodes/executors/fullyconnected_implementations.cpp
@@ -21,6 +21,7 @@
 #include "nodes/executors/mlas/mlas_gemm.hpp"
 #include "nodes/executors/precision_matcher.hpp"
 #include "nodes/executors/precision_translation.hpp"
+#include "nodes/executors/type_mask.hpp"
 #include "openvino/core/type/element_type.hpp"
 #include "ov_optional.hpp"
 #include "utils/cpp/maybe_unused.hpp"
@@ -39,21 +40,23 @@ static const LayoutConfig dnnlFCLayoutConfig{LayoutType::ncsp, LayoutType::ncsp,
 
 // clang-format off
 static const TypeMapping dnnlFCTypeMapping {
-    // {src, wei, bia, dst}                                 pt
-    {{_bf16, _bf16 | _f32, _any, _bf16 | _f32},             pt(bypass(), bypass(), use<3>(), use<3>())},
-    {{_f16, _f16, _any, _f16 | _f32},                       pt(bypass(), bypass(), use<3>(), use<3>())},
+    // {src, wei, bia, dst}                                       pt
+    {{_bf16, _bf16 | _f32, _any, _bf16 | _f32},                   pt(bypass(), bypass(), use<3>(), bypass())},
+    {{_f16, _f16, _any, _f16 | _f32},                             pt(bypass(), bypass(), use<3>(), bypass())},
     // integer precision outputs are not supported for float precision inputs
-    {{_f32 | _bf16 | _f16, _any, _any, _i8 | _u8},          pt(bypass(), bypass(), use<0>(), use<0>())},
+    {{_f32 | _bf16 | _f16, _any, _any, _i8 | _u8},                pt(bypass(), bypass(), use<0>(), use<0>())},
     // compresses float weights which do not match input data precision
-    {{_f32, _half_float, _any, _any | _any},                pt(bypass(), bypass(), use<0>(), use<0>())},
-    {{_bf16, _f16, _any, _any | _any},                      pt(bypass(), bypass(), use<0>(), use<0>())},
-    {{_f16, _bf16, _any, _any | _any},                      pt(bypass(), bypass(), use<0>(), use<0>())},
-    // quantization configuration (@todo more strict requrements for output precision?)
-    {{_u8 | _i8, _i8, _any, _any},                          pt(bypass(), bypass(), bypass(), use<3>())},
+    {{_f32, _half_float, _any, _any | _any},                      pt(bypass(), bypass(), use<0>(), use<0>())},
+    {{_bf16, _f16, _any, _any | _any},                            pt(bypass(), bypass(), use<0>(), use<0>())},
+    {{_f16, _bf16, _any, _any | _any},                            pt(bypass(), bypass(), use<0>(), use<0>())},
+    // quantization configuration
+    // int8 inner_product does not support f16 output and bias
+    {{_u8 | _i8, _i8, _any, _f16},                                pt(bypass(), bypass(), just(), just())},
+    {{_u8 | _i8, _i8, _any, _u8 | _i8 | _i32 | _bf16 | _f32},     pt(bypass(), bypass(), use<3>(), bypass())},
     // compresses int weights (@todo more strict requrements for output precision?)
-    {{_f32 | _bf16, _u8 | _nf4 | _u4 | _i4, _any, _any},    pt(bypass(), bypass(), use<0>(), use<0>())},
+    {{_f32 | _bf16, _u8 | _nf4 | _u4 | _i4, _any, _any},          pt(bypass(), bypass(), use<0>(), use<0>())},
     // @todo should we fallback to FPXX instead of _f32?
-    {{_any, _any, _any, _any},                              pt(just(), just(), just(), just())},
+    {{_any, _any, _any, _any},                                    pt(just(), just(), just(), just())},
     // @todo explicitly cover configuration limitations for oneDNN on ARM
 };
 
@@ -63,8 +66,8 @@ static const MappingNotation dnnlConvolutionMappingNotation {
 
 static const TypeMapping dnnlConvolutionTypeMapping {
     // {src, wei, bia, dst}                                 pt
-    {{_bf16, _bf16 | _f32, _any, _bf16 | _f32},             pt(bypass(), bypass(), use<3>(), use<3>())},
-    {{_f16, _f16, _any, _f16 | _f32},                       pt(bypass(), bypass(), use<3>(), use<3>())},
+    {{_bf16, _bf16 | _f32, _any, _bf16 | _f32},             pt(bypass(), bypass(), use<3>(), bypass())},
+    {{_f16, _f16, _any, _f16 | _f32},                       pt(bypass(), bypass(), use<3>(), bypass())},
     // integer precision outputs are not supported for float precision inputs
     {{_f32 | _bf16 | _f16, _any, _any, _i8 | _u8},          pt(bypass(), bypass(), use<0>(), use<0>())},
     // compresses float weights which do not match input data precision
@@ -72,7 +75,7 @@ static const TypeMapping dnnlConvolutionTypeMapping {
    {{_bf16, _f16, _any, _any | _any},                      pt(bypass(), bypass(), use<0>(), use<0>())},
    {{_f16, _bf16, _any, _any | _any},                      pt(bypass(), bypass(), use<0>(), use<0>())},
     // quantization configuration
-    {{_u8 | _i8, _i8, _any, _any},                          pt(bypass(), bypass(), use<3>(), use<3>())},
+    {{_u8 | _i8, _i8, _any, _any},                          pt(bypass(), bypass(), use<3>(), bypass())},
     // @todo should we fallback to _fxx instead of _f32 (currenly legacy logic is replicated)
     {{_any, _any, _any, _any},                              pt(just(), just(), just(), just())},
 };
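
Note (not part of the patch): the sketch below is a standalone illustration of the rule the two new FC mapping rows encode, namely that an int8 inner_product with a requested f16 destination falls back to f32 for both destination and bias, while other destination precisions keep the requested type and reuse it for the bias. `Prec`, `FCConfig`, and `resolveInt8FcTypes` are hypothetical names invented for this example and are not the plugin's actual API.

```cpp
#include <iostream>

// Simplified stand-ins for the plugin's precision configuration (hypothetical).
enum class Prec { u8, i8, i32, f16, bf16, f32 };

static const char* name(Prec p) {
    switch (p) {
        case Prec::u8:   return "u8";
        case Prec::i8:   return "i8";
        case Prec::i32:  return "i32";
        case Prec::f16:  return "f16";
        case Prec::bf16: return "bf16";
        case Prec::f32:  return "f32";
    }
    return "?";
}

struct FCConfig {
    Prec src, wei, bia, dst;
};

// Models the intent of the two new rows in dnnlFCTypeMapping (illustrative only).
FCConfig resolveInt8FcTypes(FCConfig requested) {
    const bool int8Inputs =
        (requested.src == Prec::u8 || requested.src == Prec::i8) && requested.wei == Prec::i8;
    if (int8Inputs && requested.dst == Prec::f16) {
        // {_u8 | _i8, _i8, _any, _f16}: f16 output/bias are not supported, force f32
        requested.bia = Prec::f32;
        requested.dst = Prec::f32;
    } else if (int8Inputs) {
        // {_u8 | _i8, _i8, _any, _u8 | _i8 | _i32 | _bf16 | _f32}: keep dst, bias follows dst
        requested.bia = requested.dst;
    }
    return requested;
}

int main() {
    FCConfig f16Out{Prec::u8, Prec::i8, Prec::f16, Prec::f16};
    FCConfig resolved = resolveInt8FcTypes(f16Out);
    // Prints "dst: f32, bia: f32" for the previously failing u8 x i8 -> f16 case.
    std::cout << "dst: " << name(resolved.dst) << ", bia: " << name(resolved.bia) << "\n";
    return 0;
}
```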