Skip to content

Commit

Permalink
[Snippets] Fixed validation checks in BrgemmKernelExecutor::update_co…
Browse files Browse the repository at this point in the history
…nfig (#25978)

### Details:
- *`BrgemmKernelExecutor::update_config` checks Loops with the target
`Brgemm` inside that this `Loop` is really by `M, K or N`. Previously we
also checked `LoopPort.is_incremented == true`. However, the
optimization `CleanRepeatedDataPointerShifts` may set `false` to this
attribute when these ports are connected to Buffers with the same GPR
(to avoid double pointer increments in execution). This PR lefts only
checks for `LoopPort.is_incremented == false` for not incremented
ports.*

### Tickets:
 - *149082*
  • Loading branch information
a-sidorova authored Aug 9, 2024
1 parent d8cbe92 commit 4e00130
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -224,9 +224,12 @@ void BrgemmKernelExecutor::update_config(const ov::snippets::lowered::Expression
// Quick validation check: Should we check that port is really Brgemm port?
// If BrgemmCopyB in the Loop by M -> first input port will be BrgemmCopyB with `incremented=false`
// to avoid extra checks, we validate only first input port
OPENVINO_ASSERT(in_ports.size() > 1 && in_ports.front().is_incremented && in_ports.front().dim_idx == 1 &&
out_ports.size() == 1 && out_ports.front().is_incremented && out_ports.front().dim_idx == 1,
"Incorrect Loop by Brgemm dimension N");
// Note: We check `is_incremented` attribute only for not incremented ports because
// this `is_incremented = true` can be changed by `CleanRepeatedDataPointerShifts` optimization
auto check_port = [&](const ov::snippets::lowered::LoopPort& p) { return p.dim_idx == 1; };
OPENVINO_ASSERT(in_ports.size() > 1 && std::all_of(in_ports.cbegin(), in_ports.cend(), check_port) &&
out_ports.size() == 1 && check_port(out_ports.back()),
"Incorrect Loop by Brgemm dimension M");
M = current_expanded_loop_info->get_increment();
input_pds[0]->set_subtensor_dim(1, M);
output_pds[0]->set_subtensor_dim(1, M);
Expand All @@ -240,8 +243,11 @@ void BrgemmKernelExecutor::update_config(const ov::snippets::lowered::Expression
const auto& in_ports = current_expanded_loop_info->get_input_ports();
const auto& out_ports = current_expanded_loop_info->get_output_ports();
// Quick validation check: Should we check that port is really Brgemm port?
OPENVINO_ASSERT(in_ports.size() == 2 && !in_ports.front().is_incremented && in_ports.back().is_incremented && in_ports.back().dim_idx == 0 &&
out_ports.size() == 1 && out_ports.front().is_incremented && out_ports.front().dim_idx == 0,
// Note: We check `is_incremented` attribute only for not incremented ports because
// this `is_incremented = true` can be changed by `CleanRepeatedDataPointerShifts` optimization
auto check_port = [&](const ov::snippets::lowered::LoopPort& p) { return p.dim_idx == 0; };
OPENVINO_ASSERT(in_ports.size() == 2 && !in_ports.front().is_incremented && std::all_of(in_ports.cbegin(), in_ports.cend(), check_port) &&
out_ports.size() == 1 && check_port(out_ports.back()),
"Incorrect Loop by Brgemm dimension N");
N = current_expanded_loop_info->get_increment();
input_pds[1]->set_subtensor_dim(0, N);
Expand All @@ -261,8 +267,9 @@ void BrgemmKernelExecutor::update_config(const ov::snippets::lowered::Expression
const auto& in_ports = current_expanded_loop_info->get_input_ports();
const auto& out_ports = current_expanded_loop_info->get_output_ports();
// Quick validation check: Should we check that port is really Brgemm port?
OPENVINO_ASSERT(in_ports.size() == 2 && in_ports.front().is_incremented && in_ports.front().dim_idx == 0 &&
in_ports.back().is_incremented && in_ports.back().dim_idx == 1 &&
// Note: We check `is_incremented` attribute only for not incremented ports because
// this `is_incremented = true` can be changed by `CleanRepeatedDataPointerShifts` optimization
OPENVINO_ASSERT(in_ports.size() == 2 && in_ports.front().dim_idx == 0 && in_ports.back().dim_idx == 1 &&
out_ports.size() == 1 && !out_ports.front().is_incremented,
"Incorrect Loop by Brgemm dimension K");
K = current_expanded_loop_info->get_increment();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ const auto& inputShapes_4D = STATIC_SHAPES(

const auto& inputShapes_3D = STATIC_SHAPES(
{{128, 12, 64}, {128, 12, 64}, {12, 128, 128}, {128, 12, 64}},
{{68, 6, 92}, {68, 6, 92}, {1, 68, 68}, {68, 6, 92}});
{{68, 6, 92}, {68, 6, 92}, {1, 68, 68}, {68, 6, 92}},
{{16, 2, 92}, {68, 2, 92}, {1, 16, 68}, {68, 2, 92}});

static inline bool is_bf16_supported() {
return ov::with_cpu_x86_bfloat16() || ov::with_cpu_x86_avx512_core_amx_bf16();
Expand Down

0 comments on commit 4e00130

Please sign in to comment.