Some fixes for transformations
a-sidorova committed Jan 10, 2023
1 parent 2a4ab82 commit ee4aa7a
Showing 5 changed files with 25 additions and 19 deletions.

@@ -19,6 +19,7 @@ namespace pass {
  */
 class TokenizeMHASnippets: public ngraph::pass::MatcherPass {
 public:
+    OPENVINO_RTTI("TokenizeMHASnippets", "0");
     TokenizeMHASnippets();
 };
 
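The added OPENVINO_RTTI line gives the pass its own static type info, which is what PassConfig's type-keyed disable<>()/enable<>() API resolves against; the disable<TokenizeMHASnippets>() calls added later in this commit depend on it. A minimal sketch of the mechanism, with an assumed include path (not shown in the diff):

#include <ngraph/pass/manager.hpp>
#include "snippets/pass/mha_tokenization.hpp"  // assumed include path for the header in this hunk

void disable_mha_tokenization(ngraph::pass::Manager& manager) {
    // PassConfig::disable<T>() keys on T::get_type_info_static(), which
    // OPENVINO_RTTI("TokenizeMHASnippets", "0") now defines for this pass.
    manager.get_pass_config()->disable<ngraph::snippets::pass::TokenizeMHASnippets>();
}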
src/common/snippets/src/pass/convert_constants.cpp (3 additions, 1 deletion)

@@ -24,7 +24,9 @@ ngraph::snippets::pass::ConvertConstantsToScalars::ConvertConstantsToScalars() {
         return false;
     // Note that all Constants {1,1,1,1} are converted to Scalar {1} here
     // This is needed to simplify shape inference, otherwise {1,1,1,1} Constants can increase output rank
-    auto scalar = std::make_shared<snippets::op::Scalar>(ov::op::v0::Constant(*constant, ov::Shape{1}));
+    // Also, some operations support only scalar shapes, so we need to distinguish true scalars from shape [1]
+    const auto shape = constant->get_output_shape(0).size() == 0 ? ov::Shape{} : ov::Shape{1};
+    auto scalar = std::make_shared<snippets::op::Scalar>(ov::op::v0::Constant(*constant, shape));
     scalar->set_friendly_name(constant->get_friendly_name());
     ngraph::copy_runtime_info(constant, scalar);
     ngraph::replace_node(constant, scalar);
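The distinction the new code introduces: ov::Shape{} is a rank-0 (true scalar) shape, while ov::Shape{1} is a rank-1 shape with one element. A standalone illustration of the two constants (illustrative code, not from the commit):

#include <openvino/op/constant.hpp>
#include <iostream>
#include <vector>

int main() {
    // Rank-0 constant: a true scalar, shape {}
    ov::op::v0::Constant scalar(ov::element::f32, ov::Shape{}, std::vector<float>{1.0f});
    // Rank-1 constant: a one-element tensor, shape {1}
    ov::op::v0::Constant one_elem(ov::element::f32, ov::Shape{1}, std::vector<float>{1.0f});

    std::cout << scalar.get_shape().size() << '\n';    // prints 0
    std::cout << one_elem.get_shape().size() << '\n';  // prints 1
}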
src/common/snippets/src/pass/mha_tokenization.cpp (15 additions, 9 deletions)

@@ -241,18 +241,24 @@ ngraph::snippets::pass::TokenizeMHASnippets::TokenizeMHASnippets() {
         }
 
         auto transpose1 = ngraph::as_type_ptr<ngraph::opset1::Transpose>(parent);
-        if (!matmul0->get_transpose_b() && is_valid_transpose(transpose1, {0, 2, 3, 1})) {
-            ordered_ops.insert(ordered_ops.begin(), transpose1);
-        } else if (matmul0->get_transpose_b() && is_valid_transpose(transpose1, {0, 2, 1, 3})) {
-            // We can support several ops between MatMul0 with transposed_b and Transpose1 with 0213 order
-            // only if these ops have scalar shapes on other inputs.
-            // There is a transformation, ExplicitTransposeMatMulInputs, that sets the supported order and transposed_b(false).
-            // We can allow calling this pass only if the ops have scalar shapes, to avoid shape mismatches
-            if (are_weights_scalar) {
-                ordered_ops.insert(ordered_ops.begin(), transpose1);
+        if (matmul0->get_transpose_b()) {
+            if (is_valid_transpose(transpose1, {0, 2, 1, 3})) {
+                // We can support several ops between MatMul0 with transposed_b and Transpose1 with 0213 order
+                // only if these ops have scalar shapes on other inputs.
+                // There is a transformation, ExplicitTransposeMatMulInputs, that sets the supported order and transposed_b(false).
+                // We can allow calling this pass only if the ops have scalar shapes, to avoid shape mismatches
+                if (are_weights_scalar) {
+                    ordered_ops.insert(ordered_ops.begin(), transpose1);
+                } else {
+                    return false;
+                }
             } else {
                 return false;
             }
+        } else {
+            if (is_valid_transpose(transpose1, {0, 2, 3, 1})) {
+                ordered_ops.insert(ordered_ops.begin(), transpose1);
+            }
         }
 
         // TODO: Add Reshape Support for all Transposes
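Why the two orders pair with the two transpose_b values: for a typical MHA key tensor laid out as [batch, seq_len, num_heads, head_size], order {0, 2, 3, 1} already places head_size ahead of seq_len, so MatMul0 can consume it with transpose_b = false, while {0, 2, 1, 3} leaves seq_len x head_size and therefore needs transpose_b = true. A small sketch of the shape arithmetic (the permute helper and sample sizes are illustrative, not from the codebase):

#include <array>
#include <cstddef>
#include <iostream>

// Illustrative helper: apply a transpose order to a 4D shape.
std::array<std::size_t, 4> permute(const std::array<std::size_t, 4>& shape,
                                   const std::array<std::size_t, 4>& order) {
    std::array<std::size_t, 4> out{};
    for (std::size_t i = 0; i < 4; ++i)
        out[i] = shape[order[i]];
    return out;
}

int main() {
    // Assumed pre-transpose key layout: [batch, seq_len, num_heads, head_size]
    const std::array<std::size_t, 4> key{1, 128, 12, 64};

    // {0, 2, 3, 1} -> [1, 12, 64, 128]: head_size x seq_len, K already transposed,
    // so MatMul0 works with transpose_b = false.
    for (auto d : permute(key, {0, 2, 3, 1})) std::cout << d << ' ';
    std::cout << '\n';

    // {0, 2, 1, 3} -> [1, 12, 128, 64]: seq_len x head_size, which matches
    // MatMul0 with transpose_b = true.
    for (auto d : permute(key, {0, 2, 1, 3})) std::cout << d << ' ';
    std::cout << '\n';
}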
src/common/snippets/src/pass/tokenization.cpp (1 addition, 1 deletion)

@@ -53,7 +53,7 @@ bool EnumerateNodes::run_on_model(const std::shared_ptr<ov::Model> &m) {

 bool SnippetsTokenization::run_on_model(const std::shared_ptr<ov::Model>& m) {
     RUN_ON_FUNCTION_SCOPE(SnippetsTokenization);
-    ngraph::pass::Manager manager;
+    ngraph::pass::Manager manager(get_pass_config());
     manager.set_per_pass_validation(false);
 
     manager.register_pass<EnumerateNodes>();
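This one-line change is behavioral, not cosmetic: a Manager constructed from an existing PassConfig shares that config with the passes it registers, so enable/disable decisions made by the caller now propagate into this nested pipeline, which is exactly what the CPU plugin change below relies on when it disables TokenizeMHASnippets. A sketch of the pattern, with assumed include paths (guesses, not from the diff):

#include <memory>
#include <ngraph/pass/manager.hpp>
#include <openvino/core/model.hpp>
#include "snippets/pass/tokenization.hpp"       // assumed include path
#include "snippets/pass/mha_tokenization.hpp"   // assumed include path

void tokenize_without_mha(const std::shared_ptr<ov::Model>& model) {
    ngraph::pass::Manager outer;
    outer.register_pass<ngraph::snippets::pass::SnippetsTokenization>();
    // register_pass<> shares the manager's PassConfig with the registered pass, and
    // the fixed line above now hands that same config to the nested Manager. Before
    // this commit the nested Manager was default-constructed, so this disable<>
    // never reached the inner pipeline.
    outer.get_pass_config()->disable<ngraph::snippets::pass::TokenizeMHASnippets>();
    outer.run_passes(model);
}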
src/plugins/intel_cpu/src/transformation_pipeline.cpp (5 additions, 8 deletions)

@@ -544,10 +544,9 @@ void Transformations::PostLpt() {
     });
 
     // Float MHA is supported by snippets now
-    auto postLPTPassConfig = postLPTPassManager.get_pass_config();
     if (!enableBF16) {
-        postLPTPassConfig->disable<MHAFloatFusion>();
-        postLPTPassConfig->disable<MHAFloatFusion2>();
+        postLPTPassManager.get_pass_config()->disable<MHAFloatFusion>();
+        postLPTPassManager.get_pass_config()->disable<MHAFloatFusion2>();
     }
 
     // Execute before snippets. Otherwise FQ will be converted to Subgraph
@@ -567,8 +566,7 @@ void Transformations::MainSnippets(void) {

     if (enableBF16) {
         // TODO: Need to add BF16 support for MHA in Snippets
-        const auto snippetsConfig = snippetsManager.get_pass_config();
-        snippetsConfig->disable<ngraph::snippets::pass::TokenizeMHASnippets>();
+        snippetsManager.get_pass_config()->disable<ngraph::snippets::pass::TokenizeMHASnippets>();
     }
     if (snippetsMode != Config::SnippetsMode::IgnoreCallback) {
         snippetsManager.get_pass_config()->set_callback<ngraph::snippets::pass::TokenizeMHASnippets>(
@@ -590,9 +588,8 @@
             // - parallelism support on JIT level
             const auto needed_num_of_threads = 12lu;
             const auto l2_cache_size = dnnl::utils::get_cache_size(2, true);
-            const auto is_unsupported_parallel_work_amount = IMPLICATION(
-                parallel_get_num_threads() / 2 > parallel_work_amount,
-                parallel_work_amount < needed_num_of_threads);
+            const auto is_unsupported_parallel_work_amount = parallel_get_num_threads() / 2 > parallel_work_amount &&
+                                                             parallel_work_amount < needed_num_of_threads;
             const auto is_unsupported_kernel_work_amount = kernel_buffer_size > l2_cache_size;
             return is_unsupported_parallel_work_amount || is_unsupported_kernel_work_amount;
         });
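The last hunk is a logic fix rather than a cleanup. IMPLICATION(cause, effect) is the usual oneDNN-style logical implication, !cause || effect, which is true whenever the premise is false; the old expression therefore flagged the work amount as unsupported precisely when parallel_work_amount was already large enough. The replacement requires both conditions. A quick check of the two forms:

#include <iostream>

// The oneDNN-style macro written as a function:
// IMPLICATION(cause, effect) == !cause || effect
constexpr bool implication(bool cause, bool effect) { return !cause || effect; }

int main() {
    // a: parallel_get_num_threads() / 2 > parallel_work_amount (too few chunks of work)
    // b: parallel_work_amount < needed_num_of_threads          (below the 12-thread target)
    for (bool a : {false, true})
        for (bool b : {false, true})
            std::cout << "a=" << a << " b=" << b
                      << "  IMPLICATION(a,b)=" << implication(a, b)
                      << "  a&&b=" << (a && b) << '\n';
    // IMPLICATION(a,b) is true for every a=false row, so the old check rejected
    // exactly the well-parallelized cases; a && b flags only the genuinely
    // under-parallelized one.
}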
