Skip to content

Commit

Permalink
[Snippets] Specific loop iterations handler
Browse files Browse the repository at this point in the history
  • Loading branch information
v-Golubev committed Dec 22, 2023
1 parent 3ef00f4 commit 2ec7fd8
Show file tree
Hide file tree
Showing 27 changed files with 870 additions and 647 deletions.
6 changes: 3 additions & 3 deletions src/common/snippets/include/snippets/lowered/linear_ir.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,9 @@ class LinearIR {
LinearIR::container::const_iterator end,
ExressionMap& expression_map);

const container& get_ops() const {return m_expressions; }
const io_container& get_IO_ops() const {return m_io_expressions; }
Config get_config() {return m_config; }
const container& get_ops() const { return m_expressions; }
const io_container& get_IO_ops() const { return m_io_expressions; }
Config get_config() const { return m_config; }
void set_loop_depth(size_t loop_depth) { m_config.m_loop_depth = loop_depth; }

const ExpressionPtr& get_expr_by_node(const std::shared_ptr<Node>& n) const;
Expand Down
51 changes: 27 additions & 24 deletions src/common/snippets/include/snippets/lowered/loop_manager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@

#pragma once

#include "linear_ir.hpp"

#include <openvino/core/node.hpp>
#include <openvino/opsets/opset1.hpp>

#include "linear_ir.hpp"
#include "pass/iter_handler.hpp"
#include "pass/pass.hpp"
#include "port_descriptor.hpp"

namespace ov {
Expand Down Expand Up @@ -45,9 +46,7 @@ class LinearIR::LoopManager {
LoopInfo(size_t work_amount, size_t increment,
const std::vector<LoopPort>& entries,
const std::vector<LoopPort>& exits,
bool outer_splited_loop = false)
: m_work_amount(work_amount), m_increment(increment),
m_entry_points(entries), m_exit_points(exits), m_outer_splited_loop(outer_splited_loop) {}
bool outer_splited_loop = false);
LoopInfo(size_t work_amount, size_t increment,
const std::vector<ExpressionPort>& entries,
const std::vector<ExpressionPort>& exits,
Expand All @@ -63,27 +62,16 @@ class LinearIR::LoopManager {
const std::vector<LoopPort>& get_exit_points() const;
bool get_outer_splited_loop() const;

/**
* \brief Inserts a separate body for first loop iteration processing if needed.
* Can also modify both main and first iter loop bodies.
* TODO: replace this temporary solution when ticket 119851 is implemented
*
* \param linear_ir LIR which should be modified
* \param loop_end_it iterator on LoopEnd expression for which the handler is called
*
* \return bool value which indicates whether the linear_ir was changed or not.
*/
using FirstIterHandler = std::function<bool(LinearIR&, LinearIR::constExprIt)>;
const FirstIterHandler& get_first_iter_handler() const;

// Sets dim_idx to all entry and exit points
void set_dim_idx(size_t dim_idx);
void set_work_amount(size_t work_amount);
void set_increment(size_t increment);
void set_entry_points(std::vector<LoopPort> entry_points);
void set_exit_points(std::vector<LoopPort> exit_points);
void set_outer_splited_loop(bool outer_splited_loop);
void set_first_iter_handler(FirstIterHandler handler);

enum {FIRST_ITER, MAIN_BODY, LAST_ITER};
std::vector<lowered::pass::PassPipeline> handlers;

private:
size_t m_work_amount = 0;
Expand All @@ -96,7 +84,6 @@ class LinearIR::LoopManager {
std::vector<LoopPort> m_exit_points = {};
// True if this Loop is outer Loop for nested Loops that splits the same dimension
bool m_outer_splited_loop = false;
FirstIterHandler m_first_iter_handler = nullptr;
};
using LoopInfoPtr = std::shared_ptr<LoopInfo>;

Expand All @@ -118,16 +105,22 @@ class LinearIR::LoopManager {
size_t mark_loop(LinearIR::constExprIt loop_begin_pos,
LinearIR::constExprIt loop_end_pos,
size_t work_amount,
size_t work_amount_increment,
size_t increment,
size_t dim_idx,
const std::vector<T>& entries,
const std::vector<T>& exits) {
const auto loop_info = std::make_shared<LoopManager::LoopInfo>(work_amount, work_amount_increment, entries, exits);
const std::vector<T>& exits,
bool set_default_handlers = true) {
if (increment > work_amount)
increment = work_amount;
const auto loop_info = std::make_shared<LoopManager::LoopInfo>(work_amount, increment, entries, exits);
loop_info->set_dim_idx(dim_idx);
const auto loop_id = this->add_loop_info(loop_info);
for (auto expr_it = loop_begin_pos; expr_it != loop_end_pos; ++expr_it) {
insert_loop_id(*expr_it, loop_id);
}
if (set_default_handlers) {
set_default_loop_handlers(loop_info);
}
return loop_id;
}

Expand All @@ -137,12 +130,18 @@ class LinearIR::LoopManager {
size_t work_amount,
size_t increment,
const std::vector<T>& entries,
const std::vector<T>& exits) {
const std::vector<T>& exits,
bool set_default_handlers = true) {
if (increment > work_amount)
increment = work_amount;
const auto loop_info = std::make_shared<LoopManager::LoopInfo>(work_amount, increment, entries, exits);
const auto loop_id = this->add_loop_info(loop_info);
for (auto expr_it = loop_begin_pos; expr_it != loop_end_pos; ++expr_it) {
insert_loop_id(*expr_it, loop_id);
}
if (set_default_handlers) {
set_default_loop_handlers(loop_info);
}
return loop_id;
}

Expand Down Expand Up @@ -197,6 +196,7 @@ class LinearIR::LoopManager {
size_t loop_id, bool loop_ops_inserted = false);

LoopPort get_loop_port_by_expr_port(const ExpressionPort& expr_port, const size_t loop_id);
static void set_default_loop_handlers(const LoopInfoPtr& loop_info);

private:
static void get_io_loop_ports(LinearIR::constExprIt loop_begin_pos,
Expand All @@ -207,6 +207,9 @@ class LinearIR::LoopManager {
static void fuse_loop_ports(std::vector<LinearIR::LoopManager::LoopPort>& exit_points,
std::vector<LinearIR::LoopManager::LoopPort>& entry_points,
size_t loop_id);
static std::vector<lowered::pass::PassPipeline> fuse_loop_handlers(
std::vector<lowered::pass::PassPipeline>& lhs,
std::vector<lowered::pass::PassPipeline>& rhs);

/* ===== The methods for work with Loop IDs of Expression ===== */
// Notes:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "pass.hpp"

namespace ov {
namespace snippets {
namespace lowered {
namespace pass {

class InsertSpecificIterations : public Pass {
public:
OPENVINO_RTTI("InsertSpecificIterations", "Pass")
bool run(LinearIR& linear_ir) override;

static LinearIR::container copy_loop(const LinearIR& linear_ir, const size_t loop_id);
};

} // namespace pass
} // namespace lowered
} // namespace snippets
} // namespace ov

This file was deleted.

48 changes: 48 additions & 0 deletions src/common/snippets/include/snippets/lowered/pass/iter_handler.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "snippets/lowered/linear_ir.hpp"
#include "snippets/lowered/pass/pass.hpp"
#include "snippets/op/loop.hpp"

namespace ov {
namespace snippets {
namespace lowered {
namespace pass {
class UpdateMemoryAccessOps : public pass::RangedPass {
public:
UpdateMemoryAccessOps(size_t count);
OPENVINO_RTTI("UpdateMemoryAccessOps", "RangedPass")
bool run(LinearIR& linear_ir, LinearIR::constExprIt begin, LinearIR::constExprIt end) override;

private:
size_t m_count;
};

class SetFillOffset : public pass::RangedPass {
public:
SetFillOffset(size_t offset);
OPENVINO_RTTI("SetFillOffset", "RangedPass")
bool run(LinearIR& linear_ir, LinearIR::constExprIt begin, LinearIR::constExprIt end) override;

private:
size_t m_offset;
};

class TransformInnerSplitLoop : public pass::RangedPass {
public:
TransformInnerSplitLoop(size_t tail_size);
OPENVINO_RTTI("TransformInnerSplitLoop", "RangedPass")
bool run(LinearIR& linear_ir, LinearIR::constExprIt begin, LinearIR::constExprIt end) override;

private:
size_t m_tail_size;
};

} // namespace pass
} // namespace lowered
} // namespace snippets
} // namespace ov
51 changes: 40 additions & 11 deletions src/common/snippets/include/snippets/lowered/pass/pass.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,18 @@ namespace lowered {
namespace pass {

/**
* @interface Pass
* @interface PassBase
* @brief Base class for transformations on linear IR
* @ingroup snippets
*/
class Pass {
class PassBase {
public:
Pass() = default;
virtual ~Pass() = default;
PassBase() = default;
virtual ~PassBase() = default;
// Note that get_type_info_static and get_type_info are needed to mimic OPENVINO_RTTI interface,
// so the standard OPENVINO_RTTI(...) macros could be used in derived classes.
_OPENVINO_HIDDEN_METHOD static const ::ov::DiscreteTypeInfo& get_type_info_static() {
static ::ov::DiscreteTypeInfo type_info_static {"Pass"};
static ::ov::DiscreteTypeInfo type_info_static {"PassBase"};
type_info_static.hash();
return type_info_static;
}
Expand All @@ -39,7 +39,15 @@ class Pass {
const char* get_type_name() const {
return get_type_info().name;
}
};

/**
* @interface Pass
* @brief Base class for LIR passes which are performed on a full LIR body
* @ingroup snippets
*/
class Pass : public PassBase {
public:
/**
* @brief Apply the pass to the Linear IR
* @param linear_ir the target Linear IR
Expand All @@ -48,36 +56,57 @@ class Pass {
virtual bool run(lowered::LinearIR& linear_ir) = 0;
};

/**
* @interface Pass
* @brief Base class for LIR passes which are performed on a range of a LIR body
* @ingroup snippets
*/
class RangedPass : public PassBase {
public:
/**
* @brief Apply the pass to the Linear IR
* @param linear_ir the target Linear IR
* @param begin begin of the range on which the pass is performed
* @param end end of the range on which the pass is performed
* @return status of the pass
*/
virtual bool run(lowered::LinearIR& linear_ir, lowered::LinearIR::constExprIt begin, lowered::LinearIR::constExprIt end) = 0;
};

class PassPipeline {
public:
using PositionedPassLowered = snippets::pass::PositionedPass<lowered::pass::Pass>;
using PositionedPassLowered = snippets::pass::PositionedPass<lowered::pass::PassBase>;

PassPipeline();
PassPipeline(const std::shared_ptr<PassConfig>& pass_config);

void register_pass(const snippets::pass::PassPosition& position, const std::shared_ptr<Pass>& pass);
void register_pass(const std::shared_ptr<Pass>& pass);
const std::vector<std::shared_ptr<PassBase>>& get_passes() const { return m_passes; }
bool empty() const { return m_passes.empty(); }

void register_pass(const snippets::pass::PassPosition& position, const std::shared_ptr<PassBase>& pass);
void register_pass(const std::shared_ptr<PassBase>& pass);

template<typename T, class... Args>
void register_pass(Args&&... args) {
static_assert(std::is_base_of<Pass, T>::value, "Pass not derived from lowered::Pass");
static_assert(std::is_base_of<PassBase, T>::value, "Pass not derived from lowered::Pass");
auto pass = std::make_shared<T>(std::forward<Args>(args)...);
register_pass(pass);
}
template<typename T, class Pos, class... Args, std::enable_if<std::is_same<snippets::pass::PassPosition, Pos>::value, bool>() = true>
void register_pass(const snippets::pass::PassPosition& position, Args&&... args) {
static_assert(std::is_base_of<Pass, T>::value, "Pass not derived from lowered::Pass");
static_assert(std::is_base_of<PassBase, T>::value, "Pass not derived from lowered::Pass");
auto pass = std::make_shared<T>(std::forward<Args>(args)...);
register_pass(position, pass);
}

void register_positioned_passes(const std::vector<PositionedPassLowered>& pos_passes);

void run(lowered::LinearIR& linear_ir) const;
void run(lowered::LinearIR& linear_ir, lowered::LinearIR::constExprIt begin, lowered::LinearIR::constExprIt end) const;

private:
std::shared_ptr<PassConfig> m_pass_config;
std::vector<std::shared_ptr<Pass>> m_passes;
std::vector<std::shared_ptr<PassBase>> m_passes;
};

} // namespace pass
Expand Down
Loading

0 comments on commit 2ec7fd8

Please sign in to comment.