Skip to content

Commit

Permalink
Correct Transpose tokenization in tests
Browse files Browse the repository at this point in the history
  • Loading branch information
v-Golubev committed Nov 19, 2024
1 parent fb62330 commit 0fb5b64
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 3 deletions.
5 changes: 4 additions & 1 deletion src/common/snippets/src/pass/collapse_subgraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,12 @@ auto is_supported_op(const std::shared_ptr<const Node> &n) -> bool {
const auto parent = transpose->get_input_node_shared_ptr(0);
const auto child = transpose->get_output_target_inputs(0).begin()->get_node()->shared_from_this();
auto is_brgemm_case = ov::is_type<opset1::MatMul>(parent) || ov::is_type<opset1::MatMul>(child);
auto decomposition_case = true;
// Check for Transpose parent is MatMul inside Subgraph
if (const auto subgraph = ov::as_type_ptr<const op::Subgraph>(parent)) {
if (GetSnippetsSubgraphType(subgraph) != SnippetsSubgraphType::Completed) {
// Transpose decomposition is supported only for Transpose nodes right after Subgraph's parameters
decomposition_case = false;
const auto body = subgraph->body_ptr();
const auto subgraph_output = body->get_results()[transpose->input_value(0).get_index()]->get_input_node_shared_ptr(0);
is_brgemm_case = is_brgemm_case || ov::is_type<opset1::MatMul>(subgraph_output);
Expand All @@ -63,7 +66,7 @@ auto is_supported_op(const std::shared_ptr<const Node> &n) -> bool {
const auto& order = as_type_ptr<const opset1::Constant>(n->get_input_node_shared_ptr(1));
if (order) {
const auto order_value = order->cast_vector<int>();
return (TransposeDecomposition::is_supported_transpose_order(order_value)) ||
return (decomposition_case && TransposeDecomposition::is_supported_transpose_order(order_value)) ||
(is_brgemm_case && FuseTransposeBrgemm::is_supported_transpose_order(order_value));
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAEnforceBF16,
::testing::Values(ov::element::bf16),
::testing::ValuesIn({false}),
::testing::Values(MHA::default_thread_count),
::testing::Values(7),
::testing::Values(9),
::testing::Values(6),
::testing::Values(ov::test::utils::DEVICE_CPU),
::testing::Values(CPUTestUtils::cpu_bf16_plugin_config)),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ INSTANTIATE_TEST_SUITE_P(
::testing::ValuesIn(precision_f32(5)),
::testing::Values(ov::element::bf16),
::testing::Values(MHA::default_thread_count),
::testing::Values(8), // MHA + 1 Transpose on output + 6 Converts around
::testing::Values(10), // MHA + 1 Transpose on output + 6 Converts around + 2 Transposes on Matmul's B inputs
::testing::Values(7), // MHA + 6 Converts around
::testing::Values(ov::test::utils::DEVICE_CPU),
::testing::Values(CPUTestUtils::empty_plugin_config)),
Expand Down

0 comments on commit 0fb5b64

Please sign in to comment.