Skip to content

Commit

Permalink
[Snippets] Specific loop iterations handler
Browse files Browse the repository at this point in the history
  • Loading branch information
v-Golubev committed Dec 14, 2023
1 parent 0a21205 commit a5fbd75
Show file tree
Hide file tree
Showing 35 changed files with 1,159 additions and 344 deletions.
6 changes: 3 additions & 3 deletions src/common/snippets/include/snippets/lowered/linear_ir.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,9 @@ class LinearIR {
LinearIR::container::const_iterator end,
ExressionMap& expression_map);

const container& get_ops() const {return m_expressions; }
const io_container& get_IO_ops() const {return m_io_expressions; }
Config get_config() {return m_config; }
const container& get_ops() const { return m_expressions; }
const io_container& get_IO_ops() const { return m_io_expressions; }
Config get_config() const { return m_config; }
void set_loop_depth(size_t loop_depth) { m_config.m_loop_depth = loop_depth; }

const ExpressionPtr& get_expr_by_node(const std::shared_ptr<Node>& n) const;
Expand Down
51 changes: 27 additions & 24 deletions src/common/snippets/include/snippets/lowered/loop_manager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@

#pragma once

#include "linear_ir.hpp"

#include <openvino/core/node.hpp>
#include <openvino/opsets/opset1.hpp>

#include "linear_ir.hpp"
#include "pass/iter_handler.hpp"
#include "port_descriptor.hpp"
#include "pass/pass_pipeline.hpp"

namespace ov {
namespace snippets {
Expand Down Expand Up @@ -45,9 +46,7 @@ class LinearIR::LoopManager {
LoopInfo(size_t work_amount, size_t increment,
const std::vector<LoopPort>& entries,
const std::vector<LoopPort>& exits,
bool outer_splited_loop = false)
: m_work_amount(work_amount), m_increment(increment),
m_entry_points(entries), m_exit_points(exits), m_outer_splited_loop(outer_splited_loop) {}
bool outer_splited_loop = false);
LoopInfo(size_t work_amount, size_t increment,
const std::vector<ExpressionPort>& entries,
const std::vector<ExpressionPort>& exits,
Expand All @@ -63,27 +62,16 @@ class LinearIR::LoopManager {
const std::vector<LoopPort>& get_exit_points() const;
bool get_outer_splited_loop() const;

/**
* \brief Inserts a separate body for first loop iteration processing if needed.
* Can also modify both main and first iter loop bodies.
* TODO: replace this temporary solution when ticket 119851 is implemented
*
* \param linear_ir LIR which should be modified
* \param loop_end_it iterator on LoopEnd expression for which the handler is called
*
* \return bool value which indicates whether the linear_ir was changed or not.
*/
using FirstIterHandler = std::function<bool(LinearIR&, LinearIR::constExprIt)>;
const FirstIterHandler& get_first_iter_handler() const;

// Sets dim_idx to all entry and exit points
void set_dim_idx(size_t dim_idx);
void set_work_amount(size_t work_amount);
void set_increment(size_t increment);
void set_entry_points(std::vector<LoopPort> entry_points);
void set_exit_points(std::vector<LoopPort> exit_points);
void set_outer_splited_loop(bool outer_splited_loop);
void set_first_iter_handler(FirstIterHandler handler);

enum {FIRST_ITER, MAIN_BODY, LAST_ITER};
std::vector<lowered::pass::SubgraphPassPipeline> handlers;

private:
size_t m_work_amount = 0;
Expand All @@ -96,7 +84,6 @@ class LinearIR::LoopManager {
std::vector<LoopPort> m_exit_points = {};
// True if this Loop is outer Loop for nested Loops that splits the same dimension
bool m_outer_splited_loop = false;
FirstIterHandler m_first_iter_handler = nullptr;
};
using LoopInfoPtr = std::shared_ptr<LoopInfo>;

Expand All @@ -118,16 +105,22 @@ class LinearIR::LoopManager {
size_t mark_loop(LinearIR::constExprIt loop_begin_pos,
LinearIR::constExprIt loop_end_pos,
size_t work_amount,
size_t work_amount_increment,
size_t increment,
size_t dim_idx,
const std::vector<T>& entries,
const std::vector<T>& exits) {
const auto loop_info = std::make_shared<LoopManager::LoopInfo>(work_amount, work_amount_increment, entries, exits);
const std::vector<T>& exits,
bool set_default_handlers = true) {
if (increment > work_amount)
increment = work_amount;
const auto loop_info = std::make_shared<LoopManager::LoopInfo>(work_amount, increment, entries, exits);
loop_info->set_dim_idx(dim_idx);
const auto loop_id = this->add_loop_info(loop_info);
for (auto expr_it = loop_begin_pos; expr_it != loop_end_pos; ++expr_it) {
insert_loop_id(*expr_it, loop_id);
}
if (set_default_handlers) {
set_default_loop_handlers(loop_info);
}
return loop_id;
}

Expand All @@ -137,12 +130,18 @@ class LinearIR::LoopManager {
size_t work_amount,
size_t increment,
const std::vector<T>& entries,
const std::vector<T>& exits) {
const std::vector<T>& exits,
bool set_default_handlers = true) {
if (increment > work_amount)
increment = work_amount;
const auto loop_info = std::make_shared<LoopManager::LoopInfo>(work_amount, increment, entries, exits);
const auto loop_id = this->add_loop_info(loop_info);
for (auto expr_it = loop_begin_pos; expr_it != loop_end_pos; ++expr_it) {
insert_loop_id(*expr_it, loop_id);
}
if (set_default_handlers) {
set_default_loop_handlers(loop_info);
}
return loop_id;
}

Expand Down Expand Up @@ -197,6 +196,7 @@ class LinearIR::LoopManager {
size_t loop_id, bool loop_ops_inserted = false);

LoopPort get_loop_port_by_expr_port(const ExpressionPort& expr_port, const size_t loop_id);
static void set_default_loop_handlers(const LoopInfoPtr& loop_info);

private:
static void get_io_loop_ports(LinearIR::constExprIt loop_begin_pos,
Expand All @@ -207,6 +207,9 @@ class LinearIR::LoopManager {
static void fuse_loop_ports(std::vector<LinearIR::LoopManager::LoopPort>& exit_points,
std::vector<LinearIR::LoopManager::LoopPort>& entry_points,
size_t loop_id);
static std::vector<lowered::pass::SubgraphPassPipeline> fuse_loop_handlers(
std::vector<lowered::pass::SubgraphPassPipeline>& lhs,
std::vector<lowered::pass::SubgraphPassPipeline>& rhs);

/* ===== The methods for work with Loop IDs of Expression ===== */
// Notes:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "pass.hpp"

namespace ov {
namespace snippets {
namespace lowered {
namespace pass {

class InsertSpecificIterations : public Pass {
public:
OPENVINO_RTTI("InsertSpecificIterations", "Pass")
bool run(LinearIR& linear_ir) override;

static LinearIR::container copy_loop(const LinearIR& linear_ir, const size_t loop_id);
};

} // namespace pass
} // namespace lowered
} // namespace snippets
} // namespace ov
76 changes: 76 additions & 0 deletions src/common/snippets/include/snippets/lowered/pass/iter_handler.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "snippets/lowered/linear_ir.hpp"
#include "snippets/lowered/pass/pass.hpp"
#include "snippets/lowered/pass/pass_pipeline.hpp"
#include "snippets/op/loop.hpp"

namespace ov {
namespace snippets {
namespace lowered {
namespace pass {

class SetSingleIterationWithWorkAmount : public pass::SubgraphPass {
public:
SetSingleIterationWithWorkAmount(size_t work_amount);
OPENVINO_RTTI("SetSingleIterationWithWorkAmount", "SubgraphPass")
bool run(const LinearIR& linear_ir, LinearIR::constExprIt begin, LinearIR::constExprIt end) override;

private:
size_t m_work_amount;
};

class UpdateMemoryAccessOps : public pass::SubgraphPass {
public:
UpdateMemoryAccessOps(size_t count);
OPENVINO_RTTI("UpdateMemoryAccessOps", "SubgraphPass")
bool run(const LinearIR& linear_ir, LinearIR::constExprIt begin, LinearIR::constExprIt end) override;

private:
size_t m_count;
};

class ReduceWorkAmount : public pass::SubgraphPass {
public:
ReduceWorkAmount(size_t reduce_value);
OPENVINO_RTTI("ReduceWorkAmount", "SubgraphPass")
bool run(const LinearIR& linear_ir, LinearIR::constExprIt begin, LinearIR::constExprIt end) override;

private:
size_t m_reduce_value;
};

class ZeroFinalizationOffsets : public pass::SubgraphPass {
public:
OPENVINO_RTTI("ZeroFinalizationOffsets", "SubgraphPass")
bool run(const LinearIR& linear_ir, LinearIR::constExprIt begin, LinearIR::constExprIt end) override;
};

class SetFillOffset : public pass::SubgraphPass {
public:
SetFillOffset(size_t offset);
OPENVINO_RTTI("SetFillOffset", "SubgraphPass")
bool run(const LinearIR& linear_ir, LinearIR::constExprIt begin, LinearIR::constExprIt end) override;

private:
size_t m_offset;
};

class TransformInnerSplitLoop : public pass::SubgraphPass {
public:
TransformInnerSplitLoop(size_t tail_size);
OPENVINO_RTTI("TransformInnerSplitLoop", "SubgraphPass")
bool run(const LinearIR& linear_ir, LinearIR::constExprIt begin, LinearIR::constExprIt end) override;

private:
size_t m_tail_size;
};

} // namespace pass
} // namespace lowered
} // namespace snippets
} // namespace ov
31 changes: 17 additions & 14 deletions src/common/snippets/include/snippets/lowered/pass/pass.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class Pass {
return get_type_info_static();
}

const char* get_type_name() const {
std::string get_name() const {
return get_type_info().name;
}

Expand All @@ -47,25 +47,28 @@ class Pass {
virtual bool run(lowered::LinearIR& linear_ir) = 0;
};

class PassPipeline {
class SubgraphPass {
public:
PassPipeline() = default;

void register_pass(const std::shared_ptr<Pass>& pass);
SubgraphPass() = default;
virtual ~SubgraphPass() = default;
// Note that get_type_info_static and get_type_info are needed to mimic OPENVINO_RTTI interface,
// so the standard OPENVINO_RTTI(...) macros could be used in derived classes.
_OPENVINO_HIDDEN_METHOD static const ::ov::DiscreteTypeInfo& get_type_info_static() {
static ::ov::DiscreteTypeInfo type_info_static {"SubgraphPass"};
type_info_static.hash();
return type_info_static;
}

template<typename T, class... Args>
void register_pass(Args&&... args) {
static_assert(std::is_base_of<Pass, T>::value, "Pass not derived from lowered::Pass");
auto pass = std::make_shared<T>(std::forward<Args>(args)...);
register_pass(pass);
virtual const DiscreteTypeInfo& get_type_info() const {
return get_type_info_static();
}

void run(lowered::LinearIR& linear_ir) const;
std::string get_name() const {
return get_type_info().name;
}

private:
std::vector<std::shared_ptr<Pass>> m_passes;
virtual bool run(const lowered::LinearIR& linear_ir, lowered::LinearIR::constExprIt begin, lowered::LinearIR::constExprIt end) = 0;
};

} // namespace pass
} // namespace lowered
} // namespace snippets
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "pass.hpp"
#include "snippets/pass/pass_position.hpp"

namespace ov {
namespace snippets {
namespace lowered {
namespace pass {
class PassPipeline {
public:
PassPipeline() = default;

void run(lowered::LinearIR& linear_ir) const;
void register_pass(const std::shared_ptr<Pass>& pass);

template<typename T, class... Args>
void register_pass(Args&&... args) {
static_assert(std::is_base_of<Pass, T>::value, "Pass not derived from lowered::Pass");
auto pass = std::make_shared<T>(std::forward<Args>(args)...);
register_pass(pass);
}

struct PositionedPass {
ov::snippets::pass::PassPosition position;
std::shared_ptr<Pass> pass;
PositionedPass(ov::snippets::pass::PassPosition arg_pos, std::shared_ptr<Pass> arg_pass)
: position(std::move(arg_pos)), pass(std::move(arg_pass)) {
}
};

template <typename T, class Pos, class... Args, std::enable_if<std::is_same<ov::snippets::pass::PassPosition, Pos>::value, bool>() = true>
void register_pass(const ov::snippets::pass::PassPosition& position, Args&&... args) {
static_assert(std::is_base_of<Pass, T>::value, "Attempt to insert pass that is not derived from Pass");
auto pass = std::make_shared<T>(std::forward<Args>(args)...);
insert_pass_instance(position, pass);
}

void register_positioned_passes(const std::vector<PositionedPass>& pos_passes);

protected:
void insert_pass_instance(const ov::snippets::pass::PassPosition& position, const std::shared_ptr<Pass>& pass);

private:
std::vector<std::shared_ptr<Pass>> m_passes;
};

class SubgraphPassPipeline {
public:
SubgraphPassPipeline() = default;

void run(const lowered::LinearIR& linear_ir, lowered::LinearIR::constExprIt begin, lowered::LinearIR::constExprIt end) const;
const std::vector<std::shared_ptr<SubgraphPass>>& get_passes() const;
void register_pass(const std::shared_ptr<SubgraphPass>& pass);
bool empty() const { return m_passes.empty(); }

template<typename T, class... Args>
void register_pass(Args&&... args) {
static_assert(std::is_base_of<SubgraphPass, T>::value, "Pass not derived from lowered::SubgraphPass");
auto pass = std::make_shared<T>(std::forward<Args>(args)...);
register_pass(pass);
}

struct PositionedPass {
ov::snippets::pass::PassPosition position;
std::shared_ptr<SubgraphPass> pass;
PositionedPass(ov::snippets::pass::PassPosition arg_pos, std::shared_ptr<SubgraphPass> arg_pass)
: position(std::move(arg_pos)), pass(std::move(arg_pass)) {
}
};

template <typename T, class Pos, class... Args, std::enable_if<std::is_same<ov::snippets::pass::PassPosition, Pos>::value, bool>() = true>
void register_pass(const ov::snippets::pass::PassPosition& position, Args&&... args) {
static_assert(std::is_base_of<SubgraphPass, T>::value, "Attempt to insert pass that is not derived from SubgraphPass");
auto pass = std::make_shared<T>(std::forward<Args>(args)...);
insert_pass_instance(position, pass);
}

void register_positioned_passes(const std::vector<PositionedPass>& pos_passes);

protected:
void insert_pass_instance(const ov::snippets::pass::PassPosition& position, const std::shared_ptr<SubgraphPass>& pass);

private:
std::vector<std::shared_ptr<SubgraphPass>> m_passes;
};

} // namespace pass
} // namespace lowered
} // namespace snippets
} // namespace ov
Loading

0 comments on commit a5fbd75

Please sign in to comment.