Skip to content

Commit

Permalink
Disable Transpose tokenization if internal config flag is not set
Browse files Browse the repository at this point in the history
  • Loading branch information
IvanNovoselov committed Nov 10, 2022
1 parent 964dd8f commit d783853
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 3 deletions.
10 changes: 7 additions & 3 deletions src/plugins/intel_cpu/src/plugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -639,8 +639,12 @@ static void TransformationUpToCPUSpecificOpSet(std::shared_ptr<ngraph::Function>
// CPU Plugin support Swish in Subgraph via conversion to SwichCPU which assumes second input to be constant
const bool is_unsupported_swish = ov::is_type<const ov::op::v4::Swish>(n) && n->inputs().size() > 1 &&
!ov::is_type<const ov::op::v0::Constant>(n->get_input_node_shared_ptr(1));
const bool is_disabled_softmax_tokenization =
(ov::is_type<const ov::op::v1::Softmax>(n) || ov::is_type<const ov::op::v8::Softmax>(n)) && !_tokenizeSpecOpsSnippets;
// todo: general tokenization flow is not currently supported for these operations.
// they can be tokenized only as a part of complex patterns
const bool is_disabled_tokenization = !_tokenizeSpecOpsSnippets &&
(ov::is_type<const ov::op::v1::Softmax>(n) ||
ov::is_type<const ov::op::v8::Softmax>(n) ||
ov::is_type<const ov::op::v1::Transpose>(n));
const auto& inputs = n->inputs();
// todo: clarify whether we can evaluate snippets on const paths
const bool has_only_const_inputs = std::all_of(inputs.begin(), inputs.end(),
Expand All @@ -657,7 +661,7 @@ static void TransformationUpToCPUSpecificOpSet(std::shared_ptr<ngraph::Function>
const auto& outputs = n->outputs();
const bool bad_output_rank = std::any_of(outputs.begin(), outputs.end(),
[&](const ov::Output<const ov::Node>& out) {return rank_is_too_large(out.get_tensor());});
return has_only_const_inputs || bad_input_rank || bad_output_rank || is_unsupported_swish || is_disabled_softmax_tokenization;
return has_only_const_inputs || bad_input_rank || bad_output_rank || is_unsupported_swish || is_disabled_tokenization;
});
snippetsManager.register_pass<ngraph::snippets::pass::CommonOptimizations>();
snippetsManager.run_passes(nGraphFunc);
Expand Down
5 changes: 5 additions & 0 deletions src/tests/functional/plugin/shared/src/snippets/transpose.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "snippets/transpose.hpp"
#include "subgraph_permute.hpp"
#include "functional_test_utils/skip_tests_config.hpp"
#include "cpp_interfaces/interface/ie_internal_plugin_config.hpp"

namespace ov {
namespace test {
Expand Down Expand Up @@ -35,6 +36,10 @@ void Transpose::SetUp() {

auto f = ov::test::snippets::TransposeSinhFunction({inputShape}, order);
function = f.getOriginal();
if (!configuration.count(InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MHA_OPS_TOKENIZATION_ENABLE)) {
configuration.insert({InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MHA_OPS_TOKENIZATION_ENABLE,
InferenceEngine::PluginConfigParams::YES});
}
}

TEST_P(Transpose, CompareWithRefImpl) {
Expand Down

0 comments on commit d783853

Please sign in to comment.