Skip to content

Commit

Permalink
Added config parameter to disable MHA ops tokenization
Browse files Browse the repository at this point in the history
  • Loading branch information
a-sidorova authored and IvanNovoselov committed Nov 10, 2022
1 parent c7947a8 commit 964dd8f
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,13 @@ DECLARE_CONFIG_KEY(FORCE_DISABLE_CACHE);
*/
DECLARE_CONFIG_KEY(CONFIG_DEVICE_ID);

/**
* @brief Defines whether MHA-related ops may be tokenized in Snippets
* Softmax and Transpose should be tokenized in Snippets only in tests and in the MHA pattern
* @ingroup ie_dev_api_plugin_api
*/
DECLARE_CONFIG_KEY(SNIPPETS_MHA_OPS_TOKENIZATION_ENABLE);

} // namespace PluginConfigInternalParams

} // namespace InferenceEngine
8 changes: 8 additions & 0 deletions src/plugins/intel_cpu/src/config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,14 @@ void Config::readProperties(const std::map<std::string, std::string> &prop) {
IE_THROW() << "Wrong value for property key " << CPUConfigParams::KEY_CPU_DENORMALS_OPTIMIZATION
<< ". Expected only YES/NO";
}
} else if (key == PluginConfigInternalParams::KEY_SNIPPETS_MHA_OPS_TOKENIZATION_ENABLE) {
if (val == PluginConfigParams::YES)
tokenizeMHAOpsSnippets = true;
else if (val == PluginConfigParams::NO)
tokenizeMHAOpsSnippets = false;
else
IE_THROW() << "Wrong value for property key " << PluginConfigInternalParams::KEY_SNIPPETS_MHA_OPS_TOKENIZATION_ENABLE
<< ". Expected only YES/NO";
} else {
IE_THROW(NotFound) << "Unsupported property " << key << " by CPU plugin";
}
Expand Down
1 change: 1 addition & 0 deletions src/plugins/intel_cpu/src/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ struct Config {
bool collectPerfCounters = false;
bool exclusiveAsyncRequests = false;
bool enableDynamicBatch = false;
bool tokenizeMHAOpsSnippets = false;
std::string dumpToDot = "";
int batchLimit = 0;
size_t rtCacheCapacity = 5000ul;
Expand Down
22 changes: 13 additions & 9 deletions src/plugins/intel_cpu/src/plugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ Engine::~Engine() {
}

static void TransformationUpToCPUSpecificOpSet(std::shared_ptr<ngraph::Function> nGraphFunc, const bool _enableLPT, const bool _enableBF16,
const bool _enableSnippets, const bool isLegacyApi) {
const bool _enableSnippets, const bool _tokenizeSpecOpsSnippets, const bool isLegacyApi) {
ngraph::pass::Manager manager;
manager.set_per_pass_validation(false);
manager.register_pass<ngraph::pass::InitNodeInfo>();
Expand Down Expand Up @@ -635,12 +635,12 @@ static void TransformationUpToCPUSpecificOpSet(std::shared_ptr<ngraph::Function>
snippetsManager.register_pass<ngraph::snippets::pass::EnumerateNodes>();
snippetsManager.register_pass<ngraph::snippets::pass::TokenizeSnippets>();
snippetsManager.get_pass_config()->set_callback<ngraph::snippets::pass::TokenizeSnippets>(
[](const std::shared_ptr<const ov::Node>& n) -> bool {
[_tokenizeSpecOpsSnippets](const std::shared_ptr<const ov::Node>& n) -> bool {
// The CPU plugin supports Swish in a Subgraph via conversion to SwishCPU, which assumes the second input is a constant
if (ov::is_type<const ov::op::v4::Swish>(n)) {
if (n->inputs().size() > 1 && !ov::is_type<const ov::op::v0::Constant>(n->get_input_node_shared_ptr(1)))
return true;
}
const bool is_unsupported_swish = ov::is_type<const ov::op::v4::Swish>(n) && n->inputs().size() > 1 &&
!ov::is_type<const ov::op::v0::Constant>(n->get_input_node_shared_ptr(1));
const bool is_disabled_softmax_tokenization =
(ov::is_type<const ov::op::v1::Softmax>(n) || ov::is_type<const ov::op::v8::Softmax>(n)) && !_tokenizeSpecOpsSnippets;
const auto& inputs = n->inputs();
// todo: clarify whether we can evaluate snippets on const paths
const bool has_only_const_inputs = std::all_of(inputs.begin(), inputs.end(),
Expand All @@ -657,7 +657,7 @@ static void TransformationUpToCPUSpecificOpSet(std::shared_ptr<ngraph::Function>
const auto& outputs = n->outputs();
const bool bad_output_rank = std::any_of(outputs.begin(), outputs.end(),
[&](const ov::Output<const ov::Node>& out) {return rank_is_too_large(out.get_tensor());});
return has_only_const_inputs || bad_input_rank || bad_output_rank;
return has_only_const_inputs || bad_input_rank || bad_output_rank || is_unsupported_swish || is_disabled_softmax_tokenization;
});
snippetsManager.register_pass<ngraph::snippets::pass::CommonOptimizations>();
snippetsManager.run_passes(nGraphFunc);
Expand Down Expand Up @@ -829,8 +829,10 @@ Engine::LoadExeNetworkImpl(const InferenceEngine::CNNNetwork &network, const std
const bool enableDynamicBatch = (dynamicBatchProp != config.end() && dynamicBatchProp->second == PluginConfigParams::YES)
|| engConfig.enableDynamicBatch;
const bool enableSnippets = !(enableModelCache || enableDynamicBatch || enableBF16);
const auto& mhaOpsSnippetsProp = config.find(InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MHA_OPS_TOKENIZATION_ENABLE);
const bool tokenizeMHAOpSnippets = enableSnippets && (mhaOpsSnippetsProp != config.end() && mhaOpsSnippetsProp->second == PluginConfigParams::YES);
auto nGraphFunc = clonedNetwork.getFunction();
TransformationUpToCPUSpecificOpSet(nGraphFunc, enableLPT, enableBF16, enableSnippets, isLegacyAPI());
TransformationUpToCPUSpecificOpSet(nGraphFunc, enableLPT, enableBF16, enableSnippets, tokenizeMHAOpSnippets, isLegacyAPI());

// need to check that all outputs have static shapes
// checking that all inputs have static shapes is performed in the common part
Expand Down Expand Up @@ -1070,6 +1072,8 @@ QueryNetworkResult Engine::QueryNetwork(const CNNNetwork& network, const std::ma
|| Config::LPTransformsMode::On == engConfig.lpTransformsMode /* or already enabled */;
const bool enableSnippets = !(conf.cache_dir.empty() || conf.enableDynamicBatch || (conf.enforceBF16
&& dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core)));
const auto& mhaOpsSnippetsProp = config.find(InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MHA_OPS_TOKENIZATION_ENABLE);
const bool tokenizeMHAOpSnippets = enableSnippets && (mhaOpsSnippetsProp != config.end() && mhaOpsSnippetsProp->second == PluginConfigParams::YES);

auto model = network.getFunction();
if (model == nullptr) {
Expand All @@ -1078,7 +1082,7 @@ QueryNetworkResult Engine::QueryNetwork(const CNNNetwork& network, const std::ma

auto supported = GetSupportedNodes(model,
[&](std::shared_ptr<ov::Model>& model) {
TransformationUpToCPUSpecificOpSet(model, enableLPT, conf.enforceBF16, enableSnippets, isLegacyAPI());
TransformationUpToCPUSpecificOpSet(model, enableLPT, conf.enforceBF16, enableSnippets, tokenizeMHAOpSnippets, isLegacyAPI());
ConvertToCPUSpecificOpset(model);
},
[&](const std::shared_ptr<ngraph::Node>& op) {
Expand Down

0 comments on commit 964dd8f

Please sign in to comment.