Skip to content

Commit

Permalink
[Snippets] Specific loop iterations handler
Browse files Browse the repository at this point in the history
  • Loading branch information
v-Golubev committed Feb 26, 2024
1 parent a5f6308 commit 4f44c34
Show file tree
Hide file tree
Showing 76 changed files with 1,262 additions and 836 deletions.
4 changes: 2 additions & 2 deletions src/common/snippets/include/snippets/lowered/expression.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ namespace lowered {

class LinearIR;
using ExpressionPtr = std::shared_ptr<Expression>;
using ExressionMap = std::unordered_map<Expression*, ExpressionPtr>;
using ExpressionMap = std::unordered_map<Expression*, ExpressionPtr>;
class Expression : public std::enable_shared_from_this<Expression> {
friend class LinearIR;
friend class ExpressionPort;
Expand Down Expand Up @@ -63,7 +63,7 @@ class Expression : public std::enable_shared_from_this<Expression> {
void set_loop_ids(const std::vector<size_t>& loops);
virtual ExpressionPtr clone_with_new_inputs(const std::vector<PortConnectorPtr>& new_inputs,
const std::shared_ptr<Node>& new_node) const;
ExpressionPtr clone_with_new_inputs(const ExressionMap& expr_map, const std::shared_ptr<Node>& new_node) const;
ExpressionPtr clone_with_new_inputs(const ExpressionMap& expr_map, const std::shared_ptr<Node>& new_node) const;

protected:
Expression(const Expression& other);
Expand Down
8 changes: 4 additions & 4 deletions src/common/snippets/include/snippets/lowered/linear_ir.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,11 @@ class LinearIR {
std::shared_ptr<LinearIR> clone() const;
static LinearIR::container deep_copy_range(LinearIR::container::const_iterator begin,
LinearIR::container::const_iterator end,
ExressionMap& expression_map);
ExpressionMap& expression_map);

const container& get_ops() const {return m_expressions; }
const io_container& get_IO_ops() const {return m_io_expressions; }
Config get_config() {return m_config; }
const container& get_ops() const { return m_expressions; }
const io_container& get_IO_ops() const { return m_io_expressions; }
const Config& get_config() const { return m_config; }
void set_loop_depth(size_t loop_depth) { m_config.m_loop_depth = loop_depth; }

const ExpressionPtr& get_expr_by_node(const std::shared_ptr<Node>& n) const;
Expand Down
106 changes: 75 additions & 31 deletions src/common/snippets/include/snippets/lowered/loop_manager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@

#pragma once

#include "linear_ir.hpp"

#include <openvino/core/node.hpp>
#include <openvino/opsets/opset1.hpp>
#include <utility>

#include "linear_ir.hpp"
#include "pass/iter_handler.hpp"
#include "pass/pass.hpp"
#include "port_descriptor.hpp"

namespace ov {
Expand Down Expand Up @@ -41,49 +42,82 @@ class LinearIR::LoopManager {
class LoopInfo {
public:
enum {UNDEFINED_DIM_IDX = std::numeric_limits<size_t>::max()};
/**
 * @brief Groups three pass pipelines: one for the first loop iteration, one for the main loop body
 *        and one for the last loop iteration.
 *        Handlers are added to a pipeline via register_handler<HandlerType, PassType>(ctor_args...).
 */
class SpecificIterationHandlers {
public:
    // Selects the pipeline a handler is registered in.
    enum class HandlerType { FIRST_ITER, MAIN_BODY, LAST_ITER };

    // Creates an object with three default-constructed pipelines.
    SpecificIterationHandlers() = default;
    // Builds the handlers from the loop work amount and increment (defined out of line).
    SpecificIterationHandlers(size_t loop_work_amount, size_t loop_increment);
    // Takes ready-made pipelines for the first iteration, the main body and the last iteration.
    SpecificIterationHandlers(lowered::pass::PassPipeline first_iter_handlers,
                              lowered::pass::PassPipeline main_body_handlers,
                              lowered::pass::PassPipeline last_iter_handlers);

    // NOTE(review): "handelrs" is a misspelling of "handlers"; the names are kept as-is
    // because they are part of the public interface used by callers.
    const lowered::pass::PassPipeline& get_first_iter_handelrs() const;
    const lowered::pass::PassPipeline& get_main_iter_handelrs() const;
    const lowered::pass::PassPipeline& get_last_iter_handelrs() const;
    // Combines two handler sets into one (defined out of line; merge policy is not visible here).
    static SpecificIterationHandlers merge_handlers(const SpecificIterationHandlers& lhs, const SpecificIterationHandlers& rhs);

    /**
     * @brief Registers pass T in the first-iteration pipeline.
     *        This overload is selected when Type == HandlerType::FIRST_ITER.
     * @param args arguments forwarded to PassPipeline::register_pass<T>
     */
    template <HandlerType Type,
              typename T,
              class... Args,
              typename std::enable_if<Type == HandlerType::FIRST_ITER, bool>::type = true>
    void register_handler(Args&&... args) {
        // Perfect-forward so that rvalue arguments are moved rather than copied.
        m_first_iter_handlers.register_pass<T>(std::forward<Args>(args)...);
    }

    /**
     * @brief Registers pass T in the main-body pipeline.
     *        This overload is selected when Type == HandlerType::MAIN_BODY.
     * @param args arguments forwarded to PassPipeline::register_pass<T>
     */
    template <HandlerType Type,
              typename T,
              class... Args,
              typename std::enable_if<Type == HandlerType::MAIN_BODY, bool>::type = true>
    void register_handler(Args&&... args) {
        m_main_body_handlers.register_pass<T>(std::forward<Args>(args)...);
    }

    /**
     * @brief Registers pass T in the last-iteration pipeline.
     *        This overload is selected when Type == HandlerType::LAST_ITER.
     * @param args arguments forwarded to PassPipeline::register_pass<T>
     */
    template <HandlerType Type,
              typename T,
              class... Args,
              typename std::enable_if<Type == HandlerType::LAST_ITER, bool>::type = true>
    void register_handler(Args&&... args) {
        m_last_iter_handlers.register_pass<T>(std::forward<Args>(args)...);
    }

private:
    lowered::pass::PassPipeline m_first_iter_handlers;
    lowered::pass::PassPipeline m_main_body_handlers;
    lowered::pass::PassPipeline m_last_iter_handlers;
};

LoopInfo() = default;
LoopInfo(size_t work_amount, size_t increment,
const std::vector<LoopPort>& entries,
const std::vector<LoopPort>& exits,
bool outer_splited_loop = false)
: m_work_amount(work_amount), m_increment(increment),
m_entry_points(entries), m_exit_points(exits), m_outer_splited_loop(outer_splited_loop) {}
const SpecificIterationHandlers& handlers = SpecificIterationHandlers());
LoopInfo(size_t work_amount, size_t increment,
const std::vector<ExpressionPort>& entries,
const std::vector<ExpressionPort>& exits,
bool outer_splited_loop = false);
const SpecificIterationHandlers& handlers = SpecificIterationHandlers());

std::shared_ptr<LoopInfo> clone_with_new_expr(const ExressionMap& expr_map) const;
std::shared_ptr<LoopInfo> clone_with_new_expr(const ExpressionMap& expr_map) const;

// Returns dimension index if dimension indices for all entry and exit points are equal, and UNDEFINED_DIM_IDX otherwise
size_t get_dim_idx() const;
size_t get_work_amount() const;
size_t get_increment() const;
const std::vector<LoopPort>& get_entry_points() const;
const std::vector<LoopPort>& get_exit_points() const;
bool get_outer_splited_loop() const;

/**
* \brief Inserts a separate body for first loop iteration processing if needed.
* Can also modify both main and first iter loop bodies.
* TODO: replace this temporary solution when ticket 119851 is implemented
*
* \param linear_ir LIR which should be modified
* \param loop_end_it iterator on LoopEnd expression for which the handler is called
*
* \return bool value which indicates whether the linear_ir was changed or not.
*/
using FirstIterHandler = std::function<bool(LinearIR&, LinearIR::constExprIt)>;
const FirstIterHandler& get_first_iter_handler() const;
const SpecificIterationHandlers& get_handlers() const;

// Sets dim_idx to all entry and exit points
void set_dim_idx(size_t dim_idx);
void set_work_amount(size_t work_amount);
void set_increment(size_t increment);
void set_entry_points(std::vector<LoopPort> entry_points);
void set_exit_points(std::vector<LoopPort> exit_points);
void set_outer_splited_loop(bool outer_splited_loop);
void set_first_iter_handler(FirstIterHandler handler);
void set_handlers(SpecificIterationHandlers handlers);

/**
 * @brief Registers handler T in the pipeline selected by Type by delegating to
 *        m_handlers.register_handler<Type, T>.
 * @param args arguments passed through to the handler registration
 */
template <SpecificIterationHandlers::HandlerType Type, typename T, class... Args>
void register_handler(Args&&... args) {
    // Perfect-forward so that rvalue arguments keep their value category.
    m_handlers.register_handler<Type, T>(std::forward<Args>(args)...);
}

// Update the parameters of existing LoopPorts
void update_entry_points(const std::function<void(LoopPort&)>& updater);
Expand All @@ -98,9 +132,7 @@ class LinearIR::LoopManager {
// Note: Scalars aren't entry expressions but can be before first entry expr in Linear IR
std::vector<LoopPort> m_entry_points = {};
std::vector<LoopPort> m_exit_points = {};
// True if this Loop is outer Loop for nested Loops that splits the same dimension
bool m_outer_splited_loop = false;
FirstIterHandler m_first_iter_handler = nullptr;
SpecificIterationHandlers m_handlers = {};
};
using LoopInfoPtr = std::shared_ptr<LoopInfo>;

Expand All @@ -109,7 +141,7 @@ class LinearIR::LoopManager {
* @param expr_map map of new and old expressions
* @return the copy
*/
std::shared_ptr<LoopManager> clone_with_new_expr(const ExressionMap& expr_map) const;
std::shared_ptr<LoopManager> clone_with_new_expr(const ExpressionMap& expr_map) const;

/**
* @brief Get target Loop Info
Expand Down Expand Up @@ -167,6 +199,7 @@ class LinearIR::LoopManager {
* @param dim_idx loop iterates by this index of dimension
* @param entries input loop ports
* @param exits output loop ports
* @param set_default_handlers whether to initialize the loop with the default set of SpecificIterationHandlers
* @return new loop ID
*/
template <typename T>
Expand All @@ -176,8 +209,13 @@ class LinearIR::LoopManager {
size_t increment,
size_t dim_idx,
const std::vector<T>& entries,
const std::vector<T>& exits) {
const auto loop_info = std::make_shared<LoopManager::LoopInfo>(work_amount, increment, entries, exits);
const std::vector<T>& exits,
bool set_default_handlers = true) {
const auto normalized_increment = std::min(increment, work_amount);
const auto handlers = set_default_handlers
? LoopInfo::SpecificIterationHandlers(work_amount, normalized_increment)
: LoopInfo::SpecificIterationHandlers();
const auto loop_info = std::make_shared<LoopManager::LoopInfo>(work_amount, normalized_increment, entries, exits, handlers);
loop_info->set_dim_idx(dim_idx);
const auto loop_id = this->add_loop_info(loop_info);
for (auto expr_it = loop_begin_pos; expr_it != loop_end_pos; ++expr_it) {
Expand All @@ -193,6 +231,7 @@ class LinearIR::LoopManager {
* @param increment the step of loop counter increment
* @param entries input loop ports
* @param exits output loop ports
* @param set_default_handlers whether to initialize the loop with the default set of SpecificIterationHandlers
* @return new loop ID
*/
template <typename T>
Expand All @@ -201,8 +240,13 @@ class LinearIR::LoopManager {
size_t work_amount,
size_t increment,
const std::vector<T>& entries,
const std::vector<T>& exits) {
const auto loop_info = std::make_shared<LoopManager::LoopInfo>(work_amount, increment, entries, exits);
const std::vector<T>& exits,
bool set_default_handlers = true) {
const auto normalized_increment = std::min(increment, work_amount);
const auto handlers = set_default_handlers
? LoopInfo::SpecificIterationHandlers(work_amount, normalized_increment)
: LoopInfo::SpecificIterationHandlers();
const auto loop_info = std::make_shared<LoopManager::LoopInfo>(work_amount, normalized_increment, entries, exits, handlers);
const auto loop_id = this->add_loop_info(loop_info);
for (auto expr_it = loop_begin_pos; expr_it != loop_end_pos; ++expr_it) {
insert_loop_id(*expr_it, loop_id);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,17 @@ namespace pass {
* The buffer scratchpad has one general data pointer. Each buffer has offset relative to the data pointer of buffer scratchpad.
* @ingroup snippets
*/
class AllocateBuffers: public Pass {
class AllocateBuffers: public RangedPass {
public:
OPENVINO_RTTI("AllocateBuffers", "Pass")
OPENVINO_RTTI("AllocateBuffers", "RangedPass")
AllocateBuffers(size_t& buffer_scratchpad_size, bool is_optimized = true);

/**
* @brief Apply the pass to the Linear IR
* @param linear_ir the target Linear IR
* @return status of the pass
*/
bool run(LinearIR& linear_ir) override;
bool run(LinearIR& linear_ir, lowered::LinearIR::constExprIt begin, lowered::LinearIR::constExprIt end) override;

/**
* @brief Sets the offset on a Buffer op and propagates it to the connected memory access ops
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,15 @@ namespace pass {
* This condition should be removed when Buffers stop being inplace by default.
* @ingroup snippets
*/
class CleanRepeatedDataPointerShifts: public Pass {
class CleanRepeatedDataPointerShifts: public RangedPass {
public:
OPENVINO_RTTI("CleanRepeatedDataPointerShifts", "Pass")
OPENVINO_RTTI("CleanRepeatedDataPointerShifts", "RangedPass")
CleanRepeatedDataPointerShifts() = default;

bool run(LinearIR& linear_ir) override;
bool run(lowered::LinearIR& linear_ir, lowered::LinearIR::constExprIt begin, lowered::LinearIR::constExprIt end) override;

private:
bool reuse_increments(const LinearIR& linear_ir, const ExpressionPtr& loop_end_expr);
bool reuse_increments(const ExpressionPtr& loop_end_expr);
};

} // namespace pass
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@ namespace pass {
* This transformation "fuses" the offsets with an outer loop's ptr_increments, and zeroes the offsets before Results.
* @ingroup snippets
*/
class CleanupLoopOffsets : public Pass {
class CleanupLoopOffsets : public RangedPass {
public:
OPENVINO_RTTI("CleanupLoopOffsets", "Pass")
bool run(LinearIR& linear_ir) override;
OPENVINO_RTTI("CleanupLoopOffsets", "RangedPass")
bool run(lowered::LinearIR& linear_ir, lowered::LinearIR::constExprIt begin, lowered::LinearIR::constExprIt end) override;
};

} // namespace pass
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ namespace pass {
* These passes should be executed separately before this pass!
* @ingroup snippets
*/
class DefineBufferClusters : public Pass {
class DefineBufferClusters : public RangedPass {
public:
OPENVINO_RTTI("DefineBufferClusters", "Pass")
OPENVINO_RTTI("DefineBufferClusters", "RangedPass")

DefineBufferClusters(AllocateBuffers::BufferClusters& clusters) : m_clusters(clusters) {}

Expand All @@ -42,7 +42,7 @@ class DefineBufferClusters : public Pass {
* @param linear_ir the target Linear IR
* @return status of the pass
*/
bool run(lowered::LinearIR& linear_ir) override;
bool run(lowered::LinearIR& linear_ir, lowered::LinearIR::constExprIt begin, lowered::LinearIR::constExprIt end) override;

private:
using BufferPorts = std::unordered_map<ExpressionPtr, std::set<size_t>>;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,11 @@ namespace pass {
* The main conditions of possible fusion is the equal increments and the equal/broadcastable work amounts.
* @ingroup snippets
*/
class FuseLoops : public Pass {
class FuseLoops : public RangedPass {
public:
OPENVINO_RTTI("FuseLoops", "Pass")
OPENVINO_RTTI("FuseLoops", "RangedPass")
FuseLoops();
bool run(LinearIR& linear_ir) override;
bool run(LinearIR& linear_ir, lowered::LinearIR::constExprIt begin, lowered::LinearIR::constExprIt end) override;

// This method checks that all ports which connect lower and upper loops are incremented.
// This helps to avoid fusing for the ports with incomplete data
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,17 @@ namespace pass {
* Note: should be called before ResetBuffer() pass to have correct offsets
* @ingroup snippets
*/
class IdentifyBuffers: public Pass {
class IdentifyBuffers: public RangedPass {
public:
OPENVINO_RTTI("IdentifyBuffers", "Pass")
OPENVINO_RTTI("IdentifyBuffers", "RangedPass")
IdentifyBuffers() = default;

/**
* @brief Apply the pass to the Linear IR
* @param linear_ir the target Linear IR
* @return status of the pass
*/
bool run(LinearIR& linear_ir) override;
bool run(LinearIR& linear_ir, lowered::LinearIR::constExprIt begin, lowered::LinearIR::constExprIt end) override;

struct ShiftPtrParams {
ShiftPtrParams() = default;
Expand Down Expand Up @@ -75,7 +75,7 @@ class IdentifyBuffers: public Pass {
* @param pool set of Buffers from the Linear IR
* @return adjacency matrix where True value means that Buffers are adjacent and cannot have the same ID
*/
static std::vector<bool> create_adjacency_matrix(const LinearIR& linear_ir, const BufferPool& pool);
static std::vector<bool> create_adjacency_matrix(lowered::LinearIR::constExprIt begin, lowered::LinearIR::constExprIt end, const BufferPool& pool);
/**
* @brief Algorithm of Graph coloring where vertices are Buffers
* @param buffers set of Buffers from the Linear IR
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ namespace pass {
* @ingroup snippets
*/

class InitBuffersDefault : public Pass {
class InitBuffersDefault : public RangedPass {
public:
OPENVINO_RTTI("InitBuffersDefault", "Pass")
OPENVINO_RTTI("InitBuffersDefault", "RangedPass")

InitBuffersDefault(size_t& buffer_scratchpad_size) : m_buffer_scratchpad_size(buffer_scratchpad_size) {
m_buffer_scratchpad_size = 0;
Expand All @@ -29,7 +29,7 @@ class InitBuffersDefault : public Pass {
* @param linear_ir the target Linear IR
* @return status of the pass
*/
bool run(lowered::LinearIR& linear_ir) override;
bool run(lowered::LinearIR& linear_ir, lowered::LinearIR::constExprIt begin, lowered::LinearIR::constExprIt end) override;

private:
size_t& m_buffer_scratchpad_size;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@ namespace pass {
* @brief Injects explicit Movebroadcast operations when the most varying dim is broadcasted
* @ingroup snippets
*/
class InsertBroadcastMove : public Pass {
class InsertBroadcastMove : public RangedPass {
public:
OPENVINO_RTTI("InsertBroadcastMove", "Pass")
bool run(LinearIR& linear_ir) override;
OPENVINO_RTTI("InsertBroadcastMove", "RangedPass")
bool run(LinearIR& linear_ir, lowered::LinearIR::constExprIt begin, lowered::LinearIR::constExprIt end) override;
};

} // namespace pass
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,17 @@ namespace pass {
* @param m_buffer_allocation_rank - rank of shape for memory allocation: shape[shape_rank - normalize(m_allocation_rank) : shape_rank]
* @ingroup snippets
*/
class InsertBuffers : public Pass {
class InsertBuffers : public RangedPass {
public:
OPENVINO_RTTI("InsertBuffers", "Pass")
OPENVINO_RTTI("InsertBuffers", "RangedPass")
InsertBuffers(int32_t buffer_allocation_rank);
bool run(LinearIR& linear_ir) override;
bool run(LinearIR& linear_ir, lowered::LinearIR::constExprIt begin, lowered::LinearIR::constExprIt end) override;

private:
void insertion(LinearIR& linear_ir, const LinearIR::constExprIt& expr_it, const LinearIR::LoopManagerPtr& loop_manager,
void insertion(LinearIR& linear_ir,
const LinearIR::constExprIt& begin_it,
const LinearIR::constExprIt& end_it,
const LinearIR::LoopManagerPtr& loop_manager,
const std::vector<LinearIR::LoopManager::LoopPort>& loop_entries,
const std::vector<LinearIR::LoopManager::LoopPort>& loop_exits);

Expand Down
Loading

0 comments on commit 4f44c34

Please sign in to comment.