-
Notifications
You must be signed in to change notification settings - Fork 2.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
9b059c8
commit 4dc5bbb
Showing
7 changed files
with
174 additions
and
101 deletions.
There are no files selected for viewing
36 changes: 36 additions & 0 deletions
36
src/common/snippets/include/snippets/pass/analyze_broadcastable_inputs.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
// Copyright (C) 2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#pragma once | ||
|
||
#include "openvino/pass/pass.hpp" | ||
|
||
namespace ov { | ||
namespace snippets { | ||
namespace pass { | ||
|
||
/** | ||
* @interface AnalyzeBroadcastableInputs | ||
* @brief Analyzes body parameters which affects inputs of broadcastable operations(`Broadcast` op should be inserted there). | ||
* Initializes special map `BroadcastableInputsMap = [Index of Parameter -> Index of broadcastable dimension from end]` | ||
* Notes: | ||
* - Must be called after Canonicalization pass | ||
* - Doesn't support `layouts` in PortDescriptors | ||
* @ingroup snippets | ||
*/ | ||
class AnalyzeBroadcastableInputs : public ov::pass::ModelPass { | ||
public: | ||
OPENVINO_RTTI("AnalyzeBroadcastableInputs"); | ||
using BroadcastableInputsMap = std::map<size_t, size_t>; | ||
AnalyzeBroadcastableInputs(BroadcastableInputsMap& map); | ||
|
||
bool run_on_model(const std::shared_ptr<ov::Model>& m) override; | ||
|
||
private: | ||
BroadcastableInputsMap& m_broadcastable_inputs; | ||
}; | ||
|
||
} // namespace pass | ||
} // namespace snippets | ||
} // namespace ov |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
106 changes: 106 additions & 0 deletions
106
src/common/snippets/src/pass/analyze_broadcastable_inputs.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
// Copyright (C) 2018-2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include "snippets/pass/analyze_broadcastable_inputs.hpp" | ||
|
||
#include "snippets/lowered/pass/insert_broadcastmove.hpp" | ||
#include "snippets/utils/utils.hpp" | ||
#include "snippets/itt.hpp" | ||
|
||
namespace ov { | ||
namespace snippets { | ||
namespace pass { | ||
|
||
AnalyzeBroadcastableInputs::AnalyzeBroadcastableInputs(BroadcastableInputsMap& map) : m_broadcastable_inputs(map) {} | ||
|
||
bool pass::AnalyzeBroadcastableInputs::run_on_model(const std::shared_ptr<ov::Model>& body) { | ||
RUN_ON_MODEL_SCOPE(AnalyzeBroadcastableInputs); | ||
OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::AnalyzeBroadcastableInputs") | ||
// Snippets supports tokenization of the following operations: | ||
// - Unary, Binary and Ternary (Select) Elementwise ops | ||
// - Softmax, MatMul, Transpose, GroupNorm | ||
// Binary Elementwise ops (+ Select) requires explicit Broadcast op | ||
// on inputs if broadcasting of latest dimensions is needed. | ||
// These ops will be start points of DFS - need to go to Parameters and update `broadcastable_inputs_map`. | ||
// We iterates through all ops by execution order. So if we already analyzied some op in the input branch - skip this branch. | ||
// However, there some ops which can change `processing_dim_idx`: | ||
// - Transpose has order which changes `processing_dim_idx`. But Transpose can be only after Parameters and before Results. | ||
// - MatMul's first input doesn't affect output latest dimension - skip this branch. | ||
// Also MatMul has `transposed_b` which changes `processing_dim_idx` | ||
m_broadcastable_inputs.clear(); | ||
// Currently Broadcasting can be changed only if there are several Parameters in body | ||
if (body->get_parameters().size() < 2) | ||
return false; | ||
|
||
const auto& ops = body->get_ordered_ops(); | ||
std::set<std::shared_ptr<ov::Node>> visited_ops = {}; | ||
for (const auto& op : ops) { | ||
if (!ov::snippets::lowered::pass::InsertBroadcastMove::is_broadcasting_supported(op)) | ||
continue; | ||
|
||
size_t processing_dim_idx = 0; | ||
|
||
// We need to propagate `processing_dim_idx` from input of the current node to the parameter. | ||
// To do it we use DFS | ||
std::stack<std::shared_ptr<ov::Node>> nodes_to_calculate; | ||
nodes_to_calculate.push(op); | ||
while (!nodes_to_calculate.empty()) { | ||
auto current_node = nodes_to_calculate.top(); | ||
nodes_to_calculate.pop(); | ||
|
||
if (const auto& param = ov::as_type_ptr<ov::op::v0::Parameter>(current_node)) { | ||
const auto consumers = param->get_output_target_inputs(0); | ||
if (std::any_of(consumers.cbegin(), consumers.cend(), | ||
[](const ov::Input<ov::Node>& in) { return ov::is_type<ov::op::v1::Transpose>(in.get_node()); })) { | ||
OPENVINO_ASSERT(consumers.size() == 1, "Incorrect count of outputs of Parameter!"); | ||
const auto transpose = consumers.begin()->get_node(); | ||
std::vector<size_t> order; | ||
const auto& constant = ov::as_type_ptr<const opset1::Constant>(transpose->get_input_node_shared_ptr(1)); | ||
OPENVINO_ASSERT(constant, "Unsupported order node of Transpose"); | ||
order = constant->cast_vector<size_t>(); | ||
if (order.empty()) { | ||
order.resize(transpose->get_output_partial_shape(0).size()); | ||
std::iota(order.rbegin(), order.rend(), 0); | ||
} | ||
// `processing_dim_idx` starts from the end | ||
processing_dim_idx = order.size() - 1 - ov::snippets::utils::get_input_dim_idx(order, processing_dim_idx); | ||
} | ||
const auto param_idx = body->get_parameter_index(param); | ||
if (m_broadcastable_inputs.count(param_idx) == 0) { | ||
m_broadcastable_inputs[param_idx] = processing_dim_idx; | ||
} else { | ||
OPENVINO_ASSERT(m_broadcastable_inputs.at(param_idx) == processing_dim_idx, | ||
"Parameter has been already analyzed and has another processing dim index!"); | ||
} | ||
processing_dim_idx = 0; | ||
continue; | ||
} else if (ov::is_type<ov::op::v0::Constant>(current_node)) { | ||
visited_ops.insert(op); | ||
continue; | ||
} | ||
|
||
ov::OutputVector inputs = current_node->input_values(); | ||
if (const auto mm = ov::as_type_ptr<ov::op::v0::MatMul>(current_node)) { | ||
inputs = { current_node->input_value(1) }; | ||
processing_dim_idx = static_cast<size_t>(mm->get_transpose_b()); | ||
} | ||
|
||
// not a leaf - continue to search | ||
for (const auto& input_value : inputs) { | ||
const auto& input_node = input_value.get_node()->shared_from_this(); | ||
if (visited_ops.count(input_node) == 0) { | ||
nodes_to_calculate.push(input_node); | ||
} | ||
} | ||
} | ||
|
||
visited_ops.insert(op); | ||
} | ||
|
||
return true; | ||
} | ||
|
||
} // namespace pass | ||
} // namespace snippets | ||
} // namespace ov |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters