-
Notifications
You must be signed in to change notification settings - Fork 2.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Snippets] Move BrgemmCopyB repacking logic outside the Subgraph (#27007)
### Details: Currently, CopyB repacking is always performed inside the Subgraph. When the batch on the B Matmul input is significantly smaller than the batch on the A Matmul input, and the parallel work amount is big enough, this may lead to inefficient execution, since repacking of the B input is performed in each parallel task whereas a single repacking iteration per B batch is enough. Within this PR, CopyB repacking is moved outside the snippets kernel and performed via the common reorder primitive just before the snippets kernel execution. ### Tickets: - *CVS-154383*
- Loading branch information
Showing
34 changed files
with
969 additions
and
465 deletions.
There are no files selected for viewing
53 changes: 53 additions & 0 deletions
53
src/common/snippets/include/snippets/lowered/pass/mha_parallel_wa_optimizer.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
// Copyright (C) 2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#pragma once | ||
|
||
#include "snippets/lowered/linear_ir.hpp" | ||
#include "snippets/lowered/loop_info.hpp" | ||
#include "snippets/lowered/pass/runtime_optimizer.hpp" | ||
|
||
namespace ov { | ||
namespace snippets { | ||
namespace lowered { | ||
namespace pass { | ||
/** | ||
* @class MHAParallelWAOptimizer | ||
* @brief Optimizes the dynamic MHA execution by increasing the parallel work amount: Brgemm's "M" dimension is divided into "parallel_m" | ||
* and "kernel_m". Uses heuristics from snippets::pass::SplitDimensionM for dimension splitting. | ||
* The optimizer performs the following steps: | ||
* - Identifies applicable Brgemm operations within the LinearIR. | ||
* - Finds parameters whose shapes and layouts need to be adjusted after the split. | ||
* - Determines loops that should be adjusted. | ||
*/ | ||
class MHAParallelWAOptimizer : public lowered::pass::RuntimeOptimizer { | ||
public: | ||
MHAParallelWAOptimizer() = default; | ||
MHAParallelWAOptimizer(const lowered::LinearIRCPtr& linear_ir, const RuntimeConfigurator* configurator); | ||
|
||
bool run(const lowered::LinearIR& linear_ir) override; | ||
bool applicable() const override { return !m_loops_to_split.empty(); } | ||
|
||
private: | ||
static std::unordered_set<lowered::ExpressionPtr> find_applicable_brgemms(const lowered::LinearIRCPtr& linear_ir); | ||
static std::unordered_set<size_t> find_unsqueezed_params( | ||
const lowered::LinearIRCPtr& linear_ir, | ||
const std::unordered_set<lowered::ExpressionPtr>& brgemms); | ||
static std::vector<lowered::ExpandedLoopInfoPtr> find_loops_to_split( | ||
const lowered::LinearIRCPtr& linear_ir, | ||
const std::unordered_set<size_t>& unsqueezed_params); | ||
|
||
std::vector<lowered::ExpandedLoopInfoPtr> m_loops_to_split{}; | ||
std::unordered_set<size_t> m_unsqueezed_params{}; | ||
std::vector<std::vector<size_t>> m_optimized_layouts{}; | ||
std::vector<size_t> m_dim_M_idces{}; | ||
size_t m_concurrency = 0; | ||
|
||
static const size_t m_dim_M_idx; | ||
}; | ||
|
||
} // namespace pass | ||
} // namespace lowered | ||
} // namespace snippets | ||
} // namespace ov |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
52 changes: 52 additions & 0 deletions
52
src/common/snippets/include/snippets/lowered/pass/runtime_optimizer.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
// Copyright (C) 2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#pragma once | ||
|
||
#include "snippets/lowered/linear_ir.hpp" | ||
#include "snippets/lowered/pass/pass.hpp" | ||
#include "snippets/runtime_configurator.hpp" | ||
|
||
namespace ov { | ||
namespace snippets { | ||
namespace lowered { | ||
namespace pass { | ||
/** | ||
* @class RuntimeOptimizer | ||
* @brief Base class for runtime optimizers that operate on LinearIR and RuntimeConfigurator during | ||
* RuntimeConfigurator::update stage. | ||
*/ | ||
class RuntimeOptimizer : public ConstPass { | ||
public: | ||
RuntimeOptimizer() = default; | ||
RuntimeOptimizer(const RuntimeConfigurator* configurator) : m_configurator(configurator) { | ||
OPENVINO_ASSERT(configurator, "RuntimeConfigurator musn't be nullptr"); | ||
} | ||
/** | ||
* @brief Defines if this pass is applicable. If it is not applicable, its registration in pass pipeline can be skipped. | ||
*/ | ||
virtual bool applicable() const = 0; | ||
|
||
/** | ||
* @brief Creates an instance of the specified pass type and checks if it is applicable. | ||
* If the pass is applicable, it is registered in the provided pipeline. | ||
* @param pipeline The pipeline in which the pass should be registered. | ||
* @param args The arguments to be forwarded to the pass constructor. | ||
*/ | ||
template <typename OptimizerType, typename... Args, typename = std::enable_if<std::is_base_of<RuntimeOptimizer, OptimizerType>::value>> | ||
static void register_if_applicable(PassPipeline& pipeline, Args&&... args) { | ||
auto pass = std::make_shared<OptimizerType>(std::forward<Args>(args)...); | ||
if (pass->applicable()) { | ||
pipeline.register_pass(pass); | ||
} | ||
} | ||
|
||
protected: | ||
const RuntimeConfigurator* m_configurator = nullptr; | ||
}; | ||
|
||
} // namespace pass | ||
} // namespace lowered | ||
} // namespace snippets | ||
} // namespace ov |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.