Skip to content

Commit

Permalink
[CPU] Prohibit fc avx2_vnni_2 decompression for bf16 input (openvinot…
Browse files Browse the repository at this point in the history
…oolkit#23638)

### Details:
- The FC changes made in scope of openvinotoolkit#20486 were missed when rebasing
openvinotoolkit#20718
- The context is: Even the system and the node does support bf16
precision we have to fall back to f32 in/out precision
due to lack of support for decompression with bf16 avx2_vnni_2 in oneDNN
fork.
- To cover this limitation an additional type mapping parameter in form
of std::function was introduced for disabling particular type mapping
entry using a runtime check (isa support in this case)

### Tickets:
 - 122347
 - 136163
  • Loading branch information
EgorDuplensky authored Mar 25, 2024
1 parent 36d9360 commit c3c409e
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,13 @@ static const MappingNotation dnnlFCMappingNotation{ARG_SRC, ARG_WEI, ARG_BIAS, A
using LayoutConfig = std::vector<LayoutType>;
static const LayoutConfig dnnlFCLayoutConfig{LayoutType::ncsp, LayoutType::ncsp, LayoutType::ncsp, LayoutType::ncsp};

template<dnnl::impl::cpu::x64::cpu_isa_t ISA>
struct Require {
bool operator()() {
return dnnl::impl::cpu::x64::mayiuse(ISA);
}
};

// clang-format off
static const TypeMapping dnnlFCTypeMapping {
// {src, wei, bia, dst} pt<src, wei, bias, dst>
Expand All @@ -54,7 +61,10 @@ static const TypeMapping dnnlFCTypeMapping {
{{_u8 | _i8, _i8, _any, _f16}, pt(bypass(), bypass(), just<f32>(), just<f32>())},
{{_u8 | _i8, _i8, _any, _u8 | _i8 | _i32 | _bf16 | _f32}, pt(bypass(), bypass(), use<3>(), bypass())},
// compresses int weights (@todo more strict requrements for output precision?)
{{_f32 | _bf16, _u8 | _nf4 | _u4 | _i4, _any, _any}, pt(bypass(), bypass(), use<0>(), use<0>())},
{{_bf16, _u8 | _nf4 | _u4 | _i4, _any, _any}, pt(bypass(), bypass(), use<0>(), use<0>()),
Require<dnnl::impl::cpu::x64::avx512_core_bf16>()}, // Ticket 122347
{{_bf16, _u8 | _nf4 | _u4 | _i4, _any, _any}, pt(just<f32>(), bypass(), just<f32>(), just<f32>())},
{{_f32, _u8 | _nf4 | _u4 | _i4, _any, _any}, pt(bypass(), bypass(), use<0>(), use<0>())},
// @todo should we fallback to FPXX instead of _f32?
{{_any, _any, _any, _any}, pt(just<f32>(), just<f32>(), just<f32>(), just<f32>())},
// @todo explicitly cover configuration limitations for oneDNN on ARM
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,14 @@ InOutTypes getTypeConfiguration(const MemoryDescArgs& descriptors, const TypeMap
});

for (const auto& entry : mapping) {
const auto& pattern = entry.first;
if (!entry.enabled())
continue;

const auto& pattern = entry.mask();
if (!match(pattern, types))
continue;

const auto& translator = entry.second;
return translator(types);
return entry.translate(types);
}

OPENVINO_THROW("Failed to create a type configuration for the provided memory descriptors");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

#include <cassert>
#include <functional>
#include <utility>
#include <vector>

#include "nodes/executors/memory_arguments.hpp"
Expand Down Expand Up @@ -82,9 +81,43 @@ struct PortsTranslation {
// pros: should be more efficient and safe
// cons: more template instances (binary size) of the translation utility functions
using InOutTypes = std::vector<ov::element::Type>;
using PortsConfigurationImpl = std::function<InOutTypes(const InOutTypes&)>;
using TypeTranslationFunction = std::function<InOutTypes(const InOutTypes&)>;
using InOutTypeMask = std::vector<TypeMask>;
using TypeMapping = std::vector<std::pair<InOutTypeMask, PortsConfigurationImpl>>;

class TypeMappingEntry {
public:
using EnabledPredicate = std::function<bool(void)>;

TypeMappingEntry(InOutTypeMask mask,
TypeTranslationFunction translation,
EnabledPredicate enabled = {})
: m_mask(std::move(mask)),
m_translation(std::move(translation)),
m_enabled(std::move(enabled)) {}

const InOutTypeMask& mask() const {
return m_mask;
}

InOutTypes translate(const InOutTypes& types) const {
if (m_translation)
return m_translation(types);
return {};
}

bool enabled() const {
if (m_enabled)
return m_enabled();
return true;
}

private:
InOutTypeMask m_mask;
TypeTranslationFunction m_translation;
EnabledPredicate m_enabled;
};

using TypeMapping = std::vector<TypeMappingEntry>;
using MappingNotation = std::vector<int>;
using pt = PortsTranslation;

Expand Down

0 comments on commit c3c409e

Please sign in to comment.