Skip to content

Commit

Permalink
Fixed conditions for tensor in tokenization
Browse files Browse the repository at this point in the history
  • Loading branch information
a-sidorova committed Feb 21, 2023
1 parent bbb0a8c commit 67ed836
Showing 1 changed file with 12 additions and 7 deletions.
19 changes: 12 additions & 7 deletions src/common/snippets/src/pass/mha_tokenization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@
namespace {
/// Checks whether a tensor can be handled by MHA tokenization.
/// A tensor is supported when its shape is fully static and 4-dimensional.
/// Element-type validation is intentionally NOT done here: per this commit it
/// is performed separately where needed (see is_supported_transpose_tensor),
/// so ops like MatMul can be shape-checked independently of precision.
/// @param t tensor descriptor to validate
/// @return true if the tensor has a static 4D shape
auto is_supported_tensor(const ngraph::descriptor::Tensor& t) -> bool {
    // TODO: Add support of non-4D tensors
    return t.get_partial_shape().is_static() && t.get_shape().size() == 4;
}

// TODO: Add support of Reshape?
Expand All @@ -41,9 +40,12 @@ auto is_valid_transpose(const std::shared_ptr<ngraph::opset1::Transpose>& node,
return false;
return transpose_pattern->cast_vector<int64_t>() == expected_order;
};
auto is_supported_transpose_tensor = [](const ngraph::descriptor::Tensor& t) {
return is_supported_tensor(t) && ngraph::snippets::pass::TokenizeSnippets::supported_element_types.count(t.get_element_type()) != 0;
};

return node && node->get_output_target_inputs(0).size() == 1 && node->get_shape().size() == 4 &&
valid_transpose_order(node->get_input_node_shared_ptr(1)) && is_supported_tensor(node->get_input_tensor(0));
valid_transpose_order(node->get_input_node_shared_ptr(1)) && is_supported_transpose_tensor(node->get_input_tensor(0));
}

auto tokenize_broadcast(const std::shared_ptr<ov::Node>& interm_op, ov::NodeVector& ordered_ops) -> void {
Expand Down Expand Up @@ -99,8 +101,9 @@ auto tokenize_reshape_around_softmax(std::shared_ptr<ov::Node>& interm_op,
ngraph::NodeVector& ordered_ops) -> bool {
reshape = ngraph::as_type_ptr<ngraph::opset1::Reshape>(interm_op);
if (reshape) {
const auto shape = reshape->get_input_shape(0);
if (shape.back() != reshape->get_output_shape(0).back() || reshape->get_output_target_inputs(0).size() != 1)
const auto in_shape = reshape->get_input_shape(0);
const auto out_shape = reshape->get_output_shape(0);
if (in_shape.back() != out_shape.back() || reshape->get_output_target_inputs(0).size() != 1)
return false;
ordered_ops.push_back(reshape);
interm_op = reshape->get_output_target_inputs(0).begin()->get_node()->shared_from_this();
Expand Down Expand Up @@ -232,7 +235,8 @@ ngraph::snippets::pass::TokenizeMHASnippets::TokenizeMHASnippets() {
* MatMul1
*/
const auto matmul0 = ngraph::as_type_ptr<ngraph::opset1::MatMul>(pattern_to_output.at(m_matmul0).get_node_shared_ptr());
if (!matmul0 || matmul0->get_output_target_inputs(0).size() != 1 || matmul0->get_transpose_a())
if (!matmul0 || matmul0->get_output_target_inputs(0).size() != 1 || matmul0->get_transpose_a() ||
!is_supported_tensor(matmul0->get_input_tensor(0)) || !is_supported_tensor(matmul0->get_input_tensor(1)))
return false;

const auto matmul0_prc = op::Brgemm::get_output_type(matmul0->get_input_element_type(0), matmul0->get_input_element_type(1));
Expand Down Expand Up @@ -288,7 +292,8 @@ ngraph::snippets::pass::TokenizeMHASnippets::TokenizeMHASnippets() {

const auto matmul1 = ngraph::as_type_ptr<ngraph::opset1::MatMul>(interm_op);
if (!matmul1 || matmul1->get_output_target_inputs(0).size() != 1 || matmul1->get_transpose_a() || matmul1->get_transpose_b() ||
op::Brgemm::get_output_type(matmul1->get_input_element_type(0), matmul1->get_input_element_type(1)) == element::undefined)
op::Brgemm::get_output_type(matmul1->get_input_element_type(0), matmul1->get_input_element_type(1)) == element::undefined ||
!is_supported_tensor(matmul1->get_input_tensor(0)) || !is_supported_tensor(matmul1->get_input_tensor(1)))
return false;

if (transformation_callback(matmul1)) {
Expand Down

0 comments on commit 67ed836

Please sign in to comment.