[CPU] Correct type configuration for i8 inner_product with f16 output (openvinotoolkit#23610)

### Tickets:
 - 136298
 - 136163
EgorDuplensky authored and bbielawx committed Apr 12, 2024
1 parent 3690d43 commit 4d60602
Showing 1 changed file with 17 additions and 14 deletions.
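
For readers skimming the diff: each row of a TypeMapping table pairs a mask pattern over the requested {src, wei, bia, dst} precisions with one translation policy per argument, where bypass() keeps the requested precision, use<N>() borrows the precision resolved for argument N (3 being dst), and just<T>() forces T. The first matching row wins. Below is a minimal, self-contained C++ sketch of that lookup scheme; the Type, Mask, Rule, and resolve() names are invented for illustration and are not the plugin's actual API from precision_translation.hpp / type_mask.hpp.

```cpp
#include <array>
#include <cstdint>
#include <functional>
#include <iostream>
#include <vector>

// Hypothetical, simplified stand-ins for the plugin's type-mask machinery.
enum class Type : uint32_t { f32 = 1, bf16 = 2, f16 = 4, i8 = 8, u8 = 16, i32 = 32 };

struct Mask {
    uint32_t bits;
    bool matches(Type t) const { return (bits & static_cast<uint32_t>(t)) != 0; }
};
inline Mask operator|(Mask a, Mask b) { return Mask{a.bits | b.bits}; }

const Mask _f32{1}, _bf16{2}, _f16{4}, _i8{8}, _u8{16}, _i32{32}, _any{~0u};

using Config = std::array<Type, 4>;  // requested {src, wei, bia, dst} precisions
using Policy = std::function<Type(const Config&, size_t)>;

// bypass(): keep the precision requested for this argument.
Policy bypass() {
    return [](const Config& c, size_t i) { return c[i]; };
}
// use<N>(): take the precision requested for argument N (3 == dst).
template <size_t N>
Policy use() {
    return [](const Config& c, size_t) { return c[N]; };
}
// just<T>(): force precision T regardless of the request.
template <Type T>
Policy just() {
    return [](const Config&, size_t) { return T; };
}

struct Rule {
    std::array<Mask, 4> pattern;      // which requested configs this row matches
    std::array<Policy, 4> translate;  // how to rewrite each argument's precision
};

// First matching row wins; exactly one policy is applied per argument.
Config resolve(const std::vector<Rule>& rules, const Config& requested) {
    for (const auto& rule : rules) {
        bool hit = true;
        for (size_t i = 0; i < 4; ++i)
            hit = hit && rule.pattern[i].matches(requested[i]);
        if (!hit)
            continue;
        Config out{};
        for (size_t i = 0; i < 4; ++i)
            out[i] = rule.translate[i](requested, i);
        return out;
    }
    return requested;  // no row matched; leave the request untouched
}

int main() {
    // The corrected quantization rows from the FC table below, in sketch form.
    const std::vector<Rule> rules{
        // int8 inner_product does not support f16 output and bias -> force f32
        {{_u8 | _i8, _i8, _any, _f16},
         {bypass(), bypass(), just<Type::f32>(), just<Type::f32>()}},
        {{_u8 | _i8, _i8, _any, _u8 | _i8 | _i32 | _bf16 | _f32},
         {bypass(), bypass(), use<3>(), bypass()}},
    };
    const Config requested{Type::u8, Type::i8, Type::f16, Type::f16};
    const Config resolved = resolve(rules, requested);
    // bias and dst come out as f32 (enum value 1); src/wei stay u8/i8.
    std::cout << static_cast<uint32_t>(resolved[2]) << ' '
              << static_cast<uint32_t>(resolved[3]) << '\n';
}
```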
@@ -21,6 +21,7 @@
 #include "nodes/executors/mlas/mlas_gemm.hpp"
 #include "nodes/executors/precision_matcher.hpp"
 #include "nodes/executors/precision_translation.hpp"
+#include "nodes/executors/type_mask.hpp"
 #include "openvino/core/type/element_type.hpp"
 #include "ov_optional.hpp"
 #include "utils/cpp/maybe_unused.hpp"
@@ -39,21 +40,23 @@ static const LayoutConfig dnnlFCLayoutConfig{LayoutType::ncsp, LayoutType::ncsp,
 // clang-format off
 static const TypeMapping dnnlFCTypeMapping {
     // {src, wei, bia, dst}                        pt<src, wei, bias, dst>
-    {{_bf16, _bf16 | _f32, _any, _bf16 | _f32},    pt(bypass(), bypass(), use<3>(), use<3>())},
-    {{_f16, _f16, _any, _f16 | _f32},              pt(bypass(), bypass(), use<3>(), use<3>())},
+    {{_bf16, _bf16 | _f32, _any, _bf16 | _f32},    pt(bypass(), bypass(), use<3>(), bypass())},
+    {{_f16, _f16, _any, _f16 | _f32},              pt(bypass(), bypass(), use<3>(), bypass())},
     // integer precision outputs are not supported for float precision inputs
     {{_f32 | _bf16 | _f16, _any, _any, _i8 | _u8}, pt(bypass(), bypass(), use<0>(), use<0>())},
     // compresses float weights which do not match input data precision
     {{_f32, _half_float, _any, _any | _any},       pt(bypass(), bypass(), use<0>(), use<0>())},
     {{_bf16, _f16, _any, _any | _any},             pt(bypass(), bypass(), use<0>(), use<0>())},
     {{_f16, _bf16, _any, _any | _any},             pt(bypass(), bypass(), use<0>(), use<0>())},
-    // quantization configuration (@todo more strict requirements for output precision?)
-    {{_u8 | _i8, _i8, _any, _any},                 pt(bypass(), bypass(), bypass(), use<3>())},
+    // quantization configuration
+    // int8 inner_product does not support f16 output and bias
+    {{_u8 | _i8, _i8, _any, _f16},                 pt(bypass(), bypass(), just<f32>(), just<f32>())},
+    {{_u8 | _i8, _i8, _any, _u8 | _i8 | _i32 | _bf16 | _f32}, pt(bypass(), bypass(), use<3>(), bypass())},
     // compresses int weights (@todo more strict requirements for output precision?)
     {{_f32 | _bf16, _u8 | _nf4 | _u4 | _i4, _any, _any}, pt(bypass(), bypass(), use<0>(), use<0>())},
     // @todo should we fallback to FPXX instead of _f32?
     {{_any, _any, _any, _any},                     pt(just<f32>(), just<f32>(), just<f32>(), just<f32>())},
     // @todo explicitly cover configuration limitations for oneDNN on ARM
 };
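
Reading the corrected quantization rows through the sketch above (a hypothetical walk-through, not actual plugin output): a u8/i8 request asking for f16 bias and output now hits the dedicated f16 row, which pins bias and destination to f32 so the int8 inner_product primitive remains usable; any other supported output precision is kept as requested, with the bias following the destination.

```cpp
// Assuming the rules/resolve() sketch above (hypothetical helper, not plugin API):
const Config pinned = resolve(rules, {Type::u8, Type::i8, Type::f16, Type::f16});
// -> {u8, i8, f32, f32}: f16 bias/dst are redirected to f32 by the new row
const Config kept = resolve(rules, {Type::u8, Type::i8, Type::f32, Type::bf16});
// -> {u8, i8, bf16, bf16}: bias follows dst (use<3>), dst is bypassed unchanged
```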

@@ -63,16 +66,16 @@ static const MappingNotation dnnlConvolutionMappingNotation {

 static const TypeMapping dnnlConvolutionTypeMapping {
     // {src, wei, bia, dst}                        pt<src, wei, bias, dst>
-    {{_bf16, _bf16 | _f32, _any, _bf16 | _f32},    pt(bypass(), bypass(), use<3>(), use<3>())},
-    {{_f16, _f16, _any, _f16 | _f32},              pt(bypass(), bypass(), use<3>(), use<3>())},
+    {{_bf16, _bf16 | _f32, _any, _bf16 | _f32},    pt(bypass(), bypass(), use<3>(), bypass())},
+    {{_f16, _f16, _any, _f16 | _f32},              pt(bypass(), bypass(), use<3>(), bypass())},
     // integer precision outputs are not supported for float precision inputs
     {{_f32 | _bf16 | _f16, _any, _any, _i8 | _u8}, pt(bypass(), bypass(), use<0>(), use<0>())},
     // compresses float weights which do not match input data precision
     {{_f32, _half_float, _any, _any | _any},       pt(bypass(), bypass(), use<0>(), use<0>())},
     {{_bf16, _f16, _any, _any | _any},             pt(bypass(), bypass(), use<0>(), use<0>())},
     {{_f16, _bf16, _any, _any | _any},             pt(bypass(), bypass(), use<0>(), use<0>())},
     // quantization configuration
-    {{_u8 | _i8, _i8, _any, _any},                 pt(bypass(), bypass(), use<3>(), use<3>())},
+    {{_u8 | _i8, _i8, _any, _any},                 pt(bypass(), bypass(), use<3>(), bypass())},
     // @todo should we fallback to _fxx instead of _f32 (currently legacy logic is replicated)
     {{_any, _any, _any, _any},                     pt(just<f32>(), just<f32>(), just<f32>(), just<f32>())},
 };
