
Commit

code style apply
chenhu-wang committed Dec 15, 2024
1 parent 1503bc4 commit 6b72cb2
Showing 7 changed files with 78 additions and 56 deletions.
@@ -18,7 +18,8 @@ class jit_brgemm_copy_b_emitter : public jit_emitter {
const snippets::KernelExecutorTablePtr& kernel_table,
const ov::intel_cpu::MultiCacheWeakPtr& compiled_kernel_cache);
size_t get_inputs_num() const override {return 1;}
static std::set<std::vector<element::Type>> get_supported_precisions(const std::shared_ptr<ov::Node>& node = nullptr) {
static std::set<std::vector<element::Type>> get_supported_precisions(
const std::shared_ptr<ov::Node>& node = nullptr) {
return {{element::i8}, {element::bf16}, {element::f16}, {element::f32}};
}
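Not part of this commit: a minimal usage sketch of the supported-precision set declared above; the free-function name is hypothetical and the emitter header (with its namespace) is assumed to be included.

bool copy_b_precision_supported(const ov::element::Type& prc) {
    // get_supported_precisions() lists the single-input precision combinations: i8, bf16, f16, f32
    const auto supported = jit_brgemm_copy_b_emitter::get_supported_precisions();
    return supported.count({prc}) != 0;
}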

20 changes: 13 additions & 7 deletions src/plugins/intel_cpu/src/nodes/subgraph.cpp
@@ -458,10 +458,12 @@ void Subgraph::initSupportedPrimitiveDescriptors() {
config.inConfs.resize(inputShapes.size());
for (size_t i = 0; i < inputShapes.size(); i++) {
const auto originalInputPrecision = getOriginalInputPrecisionAtPort(i);
const auto precision = ((originalInputPrecision == ov::element::f32) &&
one_of(context->getConfig().inferencePrecision, ov::element::bf16, ov::element::f16) &&
subgraph_attrs->snippet->has_domain_sensitive_ops()) ?
context->getConfig().inferencePrecision : originalInputPrecision;
const auto precision =
((originalInputPrecision == ov::element::f32) &&
one_of(context->getConfig().inferencePrecision, ov::element::bf16, ov::element::f16) &&
subgraph_attrs->snippet->has_domain_sensitive_ops())
? context->getConfig().inferencePrecision
: originalInputPrecision;
if (supportedPrecisions.count(precision) == 0)
OPENVINO_THROW("Subgraph node with name `", getName(), "` doesn't support ", precision, " precision.");
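Not part of this commit: a standalone sketch of the decision the reformatted ternary above encodes, with assumed helper and parameter names.

ov::element::Type choose_input_precision(const ov::element::Type& original,
                                         const ov::element::Type& inference_precision,
                                         bool has_domain_sensitive_ops) {
    // Enforce the configured bf16/f16 inference precision only on f32 inputs of
    // subgraphs that contain domain-sensitive ops; otherwise keep the original precision.
    const bool enforce = original == ov::element::f32 &&
                         (inference_precision == ov::element::bf16 || inference_precision == ov::element::f16) &&
                         has_domain_sensitive_ops;
    return enforce ? inference_precision : original;
}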

@@ -650,13 +652,17 @@ Subgraph::DataFlowPasses Subgraph::getDataFlowPasses() {
SNIPPETS_REGISTER_PASS_ABSOLUTE_COMMON(Place::PipelineStart, ConvertToSwishCPU);
SNIPPETS_REGISTER_PASS_RELATIVE_COMMON(Place::After, ov::snippets::pass::Canonicalization,
ov::snippets::pass::AnalyzeBroadcastableInputs, broadcastable_inputs);
if (one_of(context->getConfig().inferencePrecision, ov::element::bf16, ov::element::f16) && subgraph_attrs->snippet->has_domain_sensitive_ops()) {
if (one_of(context->getConfig().inferencePrecision, ov::element::bf16, ov::element::f16) &&
subgraph_attrs->snippet->has_domain_sensitive_ops()) {
// enforce BF16 precision on supported operations
// MatMul has to be decomposed to Brgemm operations before enforcement
// Note: MatMul decomposition will be run again later in case BF16 enforcement does not happen
SNIPPETS_REGISTER_PASS_ABSOLUTE_X86_64(Place::PipelineStart, ov::snippets::pass::MatMulToBrgemm);
SNIPPETS_REGISTER_PASS_RELATIVE_X86_64(Place::After, ov::snippets::pass::MatMulToBrgemm,
pass::EnforcePrecision, element::f32, context->getConfig().inferencePrecision);
SNIPPETS_REGISTER_PASS_RELATIVE_X86_64(Place::After,
ov::snippets::pass::MatMulToBrgemm,
pass::EnforcePrecision,
element::f32,
context->getConfig().inferencePrecision);
}
SNIPPETS_REGISTER_PASS_RELATIVE_X86_64(Place::Before,
ov::snippets::pass::PropagatePrecision,
@@ -36,9 +36,11 @@ cpu_isa_t get_primitive_isa(const ov::element::Type& dt_in0, bool is_with_amx) {
// Note: AMX might not be used even if it's supported by the hardware; check the BrgemmToBrgemmCPU pass for details
if (is_with_amx) {
if (dt_in0 == ov::element::f16)
SUPPORT_ONE(avx512_core_amx_fp16, "Unsupported hardware configuration: amx is supported only on avx512 platforms")
SUPPORT_ONE(avx512_core_amx_fp16,
"Unsupported hardware configuration: amx is supported only on avx512 platforms")
else
SUPPORT_ONE(avx512_core_amx, "Unsupported hardware configuration: amx is supported only on avx512 platforms")
SUPPORT_ONE(avx512_core_amx,
"Unsupported hardware configuration: amx is supported only on avx512 platforms")
} else if (dt_in0 == ov::element::bf16) {
SUPPORT_ONE(avx512_core_bf16, "Unsupported hardware configuration: bf16 is supported only on avx512 platforms")
} else if (one_of(dt_in0, ov::element::u8, ov::element::i8)) {
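Not part of this commit: a sketch of what the AMX branch above reduces to, assuming SUPPORT_ONE(isa, msg) returns the ISA when mayiuse() reports it and throws msg otherwise (the macro itself is not shown in this diff).

dnnl::impl::cpu::x64::cpu_isa_t pick_amx_isa(const ov::element::Type& dt_in0) {
    using namespace dnnl::impl::cpu::x64;
    // f16 inputs need the AMX-FP16 extension; i8/bf16 inputs use plain AMX.
    const auto isa = (dt_in0 == ov::element::f16) ? avx512_core_amx_fp16 : avx512_core_amx;
    if (!mayiuse(isa))
        OPENVINO_THROW("Unsupported hardware configuration: amx is supported only on avx512 platforms");
    return isa;
}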
@@ -69,8 +71,7 @@ BRGEMM_TYPE get_brgemm_type(const ov::element::Type& element_type_a, bool transp
if (one_of(element_type_a, element::u8, element::i8, element::bf16) &&
dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_amx))
return BRGEMM_TYPE::WITH_AMX;
if (element_type_a == ov::element::f16 &&
dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_amx_fp16))
if (element_type_a == ov::element::f16 && dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_amx_fp16))
return BRGEMM_TYPE::WITH_AMX;
// Note: this condition reproduces logic from the OneDNN Brgemm implementation. This is needed to align with the
// backend requirements. More details in onednn/src/cpu/x64/brgemm/brgemm_utils.cpp
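Not part of this commit: the two WITH_AMX cases above can be read as a single predicate; a sketch with an assumed helper name.

bool brgemm_uses_amx(const ov::element::Type& element_type_a) {
    using namespace dnnl::impl::cpu::x64;
    const bool int8_or_bf16 = element_type_a == ov::element::u8 || element_type_a == ov::element::i8 ||
                              element_type_a == ov::element::bf16;
    // u8/i8/bf16 need avx512_core_amx; f16 needs avx512_core_amx_fp16.
    return (int8_or_bf16 && mayiuse(avx512_core_amx)) ||
           (element_type_a == ov::element::f16 && mayiuse(avx512_core_amx_fp16));
}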
@@ -99,11 +100,16 @@ size_t get_elems_in_vec(const ov::element::Type& precision) {
namespace repacking {
size_t compute_inner_n_block(const ov::element::Type& precision) {
switch (precision) {
case element::i8: return 64;
case element::bf16: return 32;
case element::f16: return 32;
case element::f32: return 16;
default: OPENVINO_THROW("BrgemmCopyB doesn't support precision ", precision);
case element::i8:
return 64;
case element::bf16:
return 32;
case element::f16:
return 32;
case element::f32:
return 16;
default:
OPENVINO_THROW("BrgemmCopyB doesn't support precision ", precision);
}
}
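Not part of this commit, just an observation: every case above yields a 64-byte inner block (64 x 1 B for i8, 32 x 2 B for bf16/f16, 16 x 4 B for f32), so for the listed precisions the switch is equivalent to this sketch.

size_t compute_inner_n_block_sketch(const ov::element::Type& precision) {
    constexpr size_t inner_block_bytes = 64;
    return inner_block_bytes / precision.size();  // element::Type::size() returns the element width in bytes
}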

@@ -15,10 +15,11 @@ namespace intel_cpu {
namespace brgemm_utils {

enum class BRGEMM_TYPE {
STAND_ALONE, // No extra requirements, used for f32|f32
WITH_AMX, // i8|i8 or bf16|bf16 on AMX system or fp16|fp16 on AMX_FP16 system - needs BrgemmCopyB and scratchpad
WITH_COMPENSATIONS, // i8|i8 (non-AMX system) - needs BrgemmCopyB for data repacking and compensations
REPACKING_ONLY, // u8|i8, or bf16|bf16 (non-AMX system), or brgemm with transpose_b=true - needs BrgemmCopyB on second input for data repacking
STAND_ALONE, // No extra requirements, used for f32|f32
WITH_AMX, // i8|i8 or bf16|bf16 on AMX system or fp16|fp16 on AMX_FP16 system - needs BrgemmCopyB and scratchpad
WITH_COMPENSATIONS, // i8|i8 (non-AMX system) - needs BrgemmCopyB for data repacking and compensations
REPACKING_ONLY, // u8|i8, or bf16|bf16 (non-AMX system), or brgemm with transpose_b=true - needs BrgemmCopyB on
// second input for data repacking
};

dnnl::impl::cpu::x64::cpu_isa_t get_primitive_isa(const ov::element::Type& dt_in0, bool is_with_amx);
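Not part of this header: a minimal sketch of how the enum values above split on repacking, following the comments (every type except STAND_ALONE requires a BrgemmCopyB node on the second input); the helper name is assumed.

inline bool needs_copy_b(BRGEMM_TYPE type) {
    return type != BRGEMM_TYPE::STAND_ALONE;
}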
@@ -120,7 +120,7 @@ bool EnforcePrecision::run_on_model(const std::shared_ptr<ov::Model>& f) {
}

std::set<std::vector<ov::element::Type>> EnforcePrecision::get_supported_precisions_default(
const std::shared_ptr<ov::Node>&op) noexcept {
const std::shared_ptr<ov::Node>& op) noexcept {
std::set<std::vector<ov::element::Type>> types;
if (ov::is_type<snippets::op::Brgemm>(op)) {
if (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_amx_fp16))
@@ -1039,12 +1039,12 @@ void Transformations::MainSnippets(void) {
#if defined(OPENVINO_ARCH_ARM64)
false;
#else
(dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx2) &&
one_of(config.inferencePrecision, ov::element::f32, element::undefined)) ||
(dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core) &&
one_of(config.inferencePrecision, ov::element::bf16, ov::element::f32, element::undefined)) ||
(dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_amx_fp16) &&
one_of(config.inferencePrecision, ov::element::f16));
(dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx2) &&
one_of(config.inferencePrecision, ov::element::f32, element::undefined)) ||
(dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core) &&
one_of(config.inferencePrecision, ov::element::bf16, ov::element::f32, element::undefined)) ||
(dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_amx_fp16) &&
one_of(config.inferencePrecision, ov::element::f16));
#endif
if (!isMHASupported) {
CPU_DISABLE_PASS_COMMON(snippetsManager, snippets::pass::TokenizeMHASnippets);
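Not part of this commit: a sketch of the x86-64 ISA/precision combinations accepted by the reformatted isMHASupported expression above; the helper name is assumed.

bool mha_tokenization_supported(const ov::element::Type& p) {
    using namespace dnnl::impl::cpu::x64;
    return (mayiuse(avx2) && (p == ov::element::f32 || p == ov::element::undefined)) ||
           (mayiuse(avx512_core) &&
            (p == ov::element::bf16 || p == ov::element::f32 || p == ov::element::undefined)) ||
           (mayiuse(avx512_core_amx_fp16) && p == ov::element::f16);
}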
@@ -1060,8 +1060,9 @@ void Transformations::MainSnippets(void) {
const auto in_type1 = matmul->get_input_element_type(1);
const auto is_fp32 = (in_type0 == ov::element::f32 && in_type1 == ov::element::f32 &&
one_of(config.inferencePrecision, element::f32, element::undefined));
const auto is_fp16 = (in_type0 == ov::element::f16 || in_type1 == ov::element::f16) ||
(in_type0 == element::f32 && in_type1 == ov::element::f32 && config.inferencePrecision == ov::element::f16);
const auto is_fp16 =
(in_type0 == ov::element::f16 || in_type1 == ov::element::f16) ||
(in_type0 == element::f32 && in_type1 == ov::element::f32 && config.inferencePrecision == ov::element::f16);
const auto is_bf16 = (in_type0 == ov::element::bf16 && in_type1 == ov::element::bf16) ||
((in_type0 == element::f32 && in_type1 == ov::element::f32 &&
config.inferencePrecision == ov::element::bf16));
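Not part of this commit: the reformatted is_fp16 condition above, factored into a helper with assumed names, to make the two ways a MatMul is treated as fp16 explicit.

bool is_fp16_matmul(const ov::element::Type& in0,
                    const ov::element::Type& in1,
                    const ov::element::Type& inference_precision) {
    // Either an input is already f16, or both inputs are f32 and f16 inference precision is enforced.
    return in0 == ov::element::f16 || in1 == ov::element::f16 ||
           (in0 == ov::element::f32 && in1 == ov::element::f32 && inference_precision == ov::element::f16);
}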
@@ -1082,8 +1083,10 @@ void Transformations::MainSnippets(void) {
const auto& b_shape = matmul->get_input_partial_shape(1);
const auto K = matmul->get_transpose_b() ? *b_shape.rbegin() : *++b_shape.rbegin();
const size_t brgemm_vnni_factor_for_real16 = 2;  // = 4 / 2 (4-byte vnni group / 2-byte bf16|fp16 element)
if (is_bf16 || is_fp16) return K.is_static() && (K.get_length() % brgemm_vnni_factor_for_real16 == 0);
if (is_int8) return K.is_static();
if (is_bf16 || is_fp16)
return K.is_static() && (K.get_length() % brgemm_vnni_factor_for_real16 == 0);
if (is_int8)
return K.is_static();
}
if (is_int8)
return dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_vnni) ||
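Not part of this commit: a sketch of the K-dimension requirement encoded above (the surrounding control flow is truncated in this hunk, so the fall-through behaviour is an assumption).

bool k_dimension_ok(const ov::Dimension& K, bool is_real16, bool is_int8) {
    const int64_t vnni_factor_for_real16 = 2;  // bf16/f16 require K to be a multiple of 2
    if (is_real16)
        return K.is_static() && K.get_length() % vnni_factor_for_real16 == 0;
    if (is_int8)
        return K.is_static();
    return true;  // other precisions are handled by further checks in the caller
}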
@@ -15,33 +15,38 @@ namespace {
std::vector<std::vector<InputShape>> transposedShape_4D(bool with_static = true, bool with_dynamic = true) {
std::vector<std::vector<ov::test::InputShape>> shapes;
if (with_static) {
auto static_shapes = SNIPPETS_TESTS_STATIC_SHAPES(
{{1, 128, 12, 64}, {1, 128, 12, 64}, {1, 12, 128, 128}, {1, 128, 12, 64}},
{{1, 128, 16, 64}, {1, 128, 16, 64}, {1, 16, 1, 1}, {1, 128, 16, 64}},
{{1, 128, 16, 64}, {1, 128, 16, 64}, {1, 1, 1, 128}, {1, 128, 16, 64}},
{{2, 68, 6, 92}, {2, 68, 6, 92}, {1, 1, 68, 68}, {2, 68, 6, 92}},
{{1, 58, 16, 34}, {1, 58, 16, 34}, {1, 1, 1, 58}, {1, 58, 16, 34}});
auto static_shapes =
SNIPPETS_TESTS_STATIC_SHAPES({{1, 128, 12, 64}, {1, 128, 12, 64}, {1, 12, 128, 128}, {1, 128, 12, 64}},
{{1, 128, 16, 64}, {1, 128, 16, 64}, {1, 16, 1, 1}, {1, 128, 16, 64}},
{{1, 128, 16, 64}, {1, 128, 16, 64}, {1, 1, 1, 128}, {1, 128, 16, 64}},
{{2, 68, 6, 92}, {2, 68, 6, 92}, {1, 1, 68, 68}, {2, 68, 6, 92}},
{{1, 58, 16, 34}, {1, 58, 16, 34}, {1, 1, 1, 58}, {1, 58, 16, 34}});
shapes.insert(shapes.end(), static_shapes.begin(), static_shapes.end());
}
if (with_dynamic) {
std::vector<std::vector<ov::test::InputShape>> dynamic_shapes = {{
{PartialShape{-1, -1, -1, 100}, {{1, 64, 4, 100}, {2, 16, 2, 100}, {1, 72, 4, 100}}},
{PartialShape{-1, 128, -1, 100}, {{1, 128, 4, 100}, {2, 128, 2, 100}, {1, 128, 4, 100}}},
{PartialShape{-1, -1, -1, 128}, {{1, 4, 64, 128}, {2, 2, 16, 128}, {1, 4, 72, 128}}},
{PartialShape{-1, 128, -1, 100}, {{1, 128, 4, 100}, {2, 128, 2, 100}, {1, 128, 4, 100}}},
},
{
{PartialShape{-1, -1, -1, -1}, {{1, 128, 3, 64}, {2, 16, 2, 100}, {1, 128, 3, 64}}},
{PartialShape{-1, -1, -1, -1}, {{1, 128, 1, 64}, {2, 128, 2, 100}, {1, 128, 1, 64}}},
{PartialShape{-1, -1, -1, -1}, {{2, 1, 128, 128}, {2, 2, 16, 128}, {2, 1, 128, 128}}},
{PartialShape{-1, -1, -1, -1}, {{1, 128, 3, 64}, {2, 128, 2, 100}, {1, 128, 3, 64}}},
},
{
{PartialShape{-1, -1, 12, 64}, {{1, 70, 12, 64}, {1, 20, 12, 64}, {1, 20, 12, 64}, {1, 20, 12, 64}, {1, 70, 12, 64}}},
{PartialShape{-1, -1, 12, 64}, {{1, 35, 12, 64}, {2, 10, 12, 64}, {2, 1, 12, 64}, {2, 10, 12, 64}, {1, 35, 12, 64}}},
{PartialShape{-1, 12, -1, -1}, {{2, 12, 70, 35}, {1, 12, 20, 10}, {1, 12, 20, 10}, {1, 12, 20, 1}, {2, 12, 70, 35}}},
{PartialShape{-1, -1, 12, 64}, {{1, 35, 12, 64}, {1, 10, 12, 64}, {1, 10, 12, 64}, {1, 10, 12, 64}, {1, 35, 12, 64}}},
}};
std::vector<std::vector<ov::test::InputShape>> dynamic_shapes = {
{
{PartialShape{-1, -1, -1, 100}, {{1, 64, 4, 100}, {2, 16, 2, 100}, {1, 72, 4, 100}}},
{PartialShape{-1, 128, -1, 100}, {{1, 128, 4, 100}, {2, 128, 2, 100}, {1, 128, 4, 100}}},
{PartialShape{-1, -1, -1, 128}, {{1, 4, 64, 128}, {2, 2, 16, 128}, {1, 4, 72, 128}}},
{PartialShape{-1, 128, -1, 100}, {{1, 128, 4, 100}, {2, 128, 2, 100}, {1, 128, 4, 100}}},
},
{
{PartialShape{-1, -1, -1, -1}, {{1, 128, 3, 64}, {2, 16, 2, 100}, {1, 128, 3, 64}}},
{PartialShape{-1, -1, -1, -1}, {{1, 128, 1, 64}, {2, 128, 2, 100}, {1, 128, 1, 64}}},
{PartialShape{-1, -1, -1, -1}, {{2, 1, 128, 128}, {2, 2, 16, 128}, {2, 1, 128, 128}}},
{PartialShape{-1, -1, -1, -1}, {{1, 128, 3, 64}, {2, 128, 2, 100}, {1, 128, 3, 64}}},
},
{
{PartialShape{-1, -1, 12, 64},
{{1, 70, 12, 64}, {1, 20, 12, 64}, {1, 20, 12, 64}, {1, 20, 12, 64}, {1, 70, 12, 64}}},
{PartialShape{-1, -1, 12, 64},
{{1, 35, 12, 64}, {2, 10, 12, 64}, {2, 1, 12, 64}, {2, 10, 12, 64}, {1, 35, 12, 64}}},
{PartialShape{-1, 12, -1, -1},
{{2, 12, 70, 35}, {1, 12, 20, 10}, {1, 12, 20, 10}, {1, 12, 20, 1}, {2, 12, 70, 35}}},
{PartialShape{-1, -1, 12, 64},
{{1, 35, 12, 64}, {1, 10, 12, 64}, {1, 10, 12, 64}, {1, 10, 12, 64}, {1, 35, 12, 64}}},
}};
shapes.insert(shapes.end(), dynamic_shapes.begin(), dynamic_shapes.end());
}
return shapes;
