Skip to content

Commit

Permalink
[Snippets] Specific loop iterations handler
Browse files Browse the repository at this point in the history
  • Loading branch information
v-Golubev committed Dec 15, 2023
1 parent b92958b commit e240aee
Show file tree
Hide file tree
Showing 27 changed files with 908 additions and 161 deletions.
6 changes: 3 additions & 3 deletions src/common/snippets/include/snippets/lowered/linear_ir.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,9 @@ class LinearIR {
LinearIR::container::const_iterator end,
ExressionMap& expression_map);

const container& get_ops() const {return m_expressions; }
const io_container& get_IO_ops() const {return m_io_expressions; }
Config get_config() {return m_config; }
const container& get_ops() const { return m_expressions; }
const io_container& get_IO_ops() const { return m_io_expressions; }
Config get_config() const { return m_config; }
void set_loop_depth(size_t loop_depth) { m_config.m_loop_depth = loop_depth; }

const ExpressionPtr& get_expr_by_node(const std::shared_ptr<Node>& n) const;
Expand Down
51 changes: 27 additions & 24 deletions src/common/snippets/include/snippets/lowered/loop_manager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@

#pragma once

#include "linear_ir.hpp"

#include <openvino/core/node.hpp>
#include <openvino/opsets/opset1.hpp>

#include "linear_ir.hpp"
#include "pass/iter_handler.hpp"
#include "pass/pass.hpp"
#include "port_descriptor.hpp"

namespace ov {
Expand Down Expand Up @@ -45,9 +46,7 @@ class LinearIR::LoopManager {
LoopInfo(size_t work_amount, size_t increment,
const std::vector<LoopPort>& entries,
const std::vector<LoopPort>& exits,
bool outer_splited_loop = false)
: m_work_amount(work_amount), m_increment(increment),
m_entry_points(entries), m_exit_points(exits), m_outer_splited_loop(outer_splited_loop) {}
bool outer_splited_loop = false);
LoopInfo(size_t work_amount, size_t increment,
const std::vector<ExpressionPort>& entries,
const std::vector<ExpressionPort>& exits,
Expand All @@ -63,27 +62,16 @@ class LinearIR::LoopManager {
const std::vector<LoopPort>& get_exit_points() const;
bool get_outer_splited_loop() const;

/**
* \brief Inserts a separate body for first loop iteration processing if needed.
* Can also modify both main and first iter loop bodies.
* TODO: replace this temporary solution when ticket 119851 is implemented
*
* \param linear_ir LIR which should be modified
* \param loop_end_it iterator on LoopEnd expression for which the handler is called
*
* \return bool value which indicates whether the linear_ir was changed or not.
*/
using FirstIterHandler = std::function<bool(LinearIR&, LinearIR::constExprIt)>;
const FirstIterHandler& get_first_iter_handler() const;

// Sets dim_idx to all entry and exit points
void set_dim_idx(size_t dim_idx);
void set_work_amount(size_t work_amount);
void set_increment(size_t increment);
void set_entry_points(std::vector<LoopPort> entry_points);
void set_exit_points(std::vector<LoopPort> exit_points);
void set_outer_splited_loop(bool outer_splited_loop);
void set_first_iter_handler(FirstIterHandler handler);

enum {FIRST_ITER, MAIN_BODY, LAST_ITER};
std::vector<lowered::pass::SubgraphPassPipeline> handlers;

private:
size_t m_work_amount = 0;
Expand All @@ -96,7 +84,6 @@ class LinearIR::LoopManager {
std::vector<LoopPort> m_exit_points = {};
// True if this Loop is outer Loop for nested Loops that splits the same dimension
bool m_outer_splited_loop = false;
FirstIterHandler m_first_iter_handler = nullptr;
};
using LoopInfoPtr = std::shared_ptr<LoopInfo>;

Expand All @@ -118,16 +105,22 @@ class LinearIR::LoopManager {
size_t mark_loop(LinearIR::constExprIt loop_begin_pos,
LinearIR::constExprIt loop_end_pos,
size_t work_amount,
size_t work_amount_increment,
size_t increment,
size_t dim_idx,
const std::vector<T>& entries,
const std::vector<T>& exits) {
const auto loop_info = std::make_shared<LoopManager::LoopInfo>(work_amount, work_amount_increment, entries, exits);
const std::vector<T>& exits,
bool set_default_handlers = true) {
if (increment > work_amount)
increment = work_amount;
const auto loop_info = std::make_shared<LoopManager::LoopInfo>(work_amount, increment, entries, exits);
loop_info->set_dim_idx(dim_idx);
const auto loop_id = this->add_loop_info(loop_info);
for (auto expr_it = loop_begin_pos; expr_it != loop_end_pos; ++expr_it) {
insert_loop_id(*expr_it, loop_id);
}
if (set_default_handlers) {
set_default_loop_handlers(loop_info);
}
return loop_id;
}

Expand All @@ -137,12 +130,18 @@ class LinearIR::LoopManager {
size_t work_amount,
size_t increment,
const std::vector<T>& entries,
const std::vector<T>& exits) {
const std::vector<T>& exits,
bool set_default_handlers = true) {
if (increment > work_amount)
increment = work_amount;
const auto loop_info = std::make_shared<LoopManager::LoopInfo>(work_amount, increment, entries, exits);
const auto loop_id = this->add_loop_info(loop_info);
for (auto expr_it = loop_begin_pos; expr_it != loop_end_pos; ++expr_it) {
insert_loop_id(*expr_it, loop_id);
}
if (set_default_handlers) {
set_default_loop_handlers(loop_info);
}
return loop_id;
}

Expand Down Expand Up @@ -197,6 +196,7 @@ class LinearIR::LoopManager {
size_t loop_id, bool loop_ops_inserted = false);

LoopPort get_loop_port_by_expr_port(const ExpressionPort& expr_port, const size_t loop_id);
static void set_default_loop_handlers(const LoopInfoPtr& loop_info);

private:
static void get_io_loop_ports(LinearIR::constExprIt loop_begin_pos,
Expand All @@ -207,6 +207,9 @@ class LinearIR::LoopManager {
static void fuse_loop_ports(std::vector<LinearIR::LoopManager::LoopPort>& exit_points,
std::vector<LinearIR::LoopManager::LoopPort>& entry_points,
size_t loop_id);
static std::vector<lowered::pass::SubgraphPassPipeline> fuse_loop_handlers(
std::vector<lowered::pass::SubgraphPassPipeline>& lhs,
std::vector<lowered::pass::SubgraphPassPipeline>& rhs);

/* ===== The methods for work with Loop IDs of Expression ===== */
// Notes:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "pass.hpp"

namespace ov {
namespace snippets {
namespace lowered {
namespace pass {

class InsertSpecificIterations : public Pass {
public:
OPENVINO_RTTI("InsertSpecificIterations", "Pass")
bool run(LinearIR& linear_ir) override;

static LinearIR::container copy_loop(const LinearIR& linear_ir, const size_t loop_id);
};

} // namespace pass
} // namespace lowered
} // namespace snippets
} // namespace ov
75 changes: 75 additions & 0 deletions src/common/snippets/include/snippets/lowered/pass/iter_handler.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "snippets/lowered/linear_ir.hpp"
#include "snippets/lowered/pass/pass.hpp"
#include "snippets/op/loop.hpp"

namespace ov {
namespace snippets {
namespace lowered {
namespace pass {

class SetSingleIterationWithWorkAmount : public pass::SubgraphPass {
public:
SetSingleIterationWithWorkAmount(size_t work_amount);
OPENVINO_RTTI("SetSingleIterationWithWorkAmount", "SubgraphPass")
bool run(const LinearIR& linear_ir, LinearIR::constExprIt begin, LinearIR::constExprIt end) override;

private:
size_t m_work_amount;
};

class UpdateMemoryAccessOps : public pass::SubgraphPass {
public:
UpdateMemoryAccessOps(size_t count);
OPENVINO_RTTI("UpdateMemoryAccessOps", "SubgraphPass")
bool run(const LinearIR& linear_ir, LinearIR::constExprIt begin, LinearIR::constExprIt end) override;

private:
size_t m_count;
};

class ReduceWorkAmount : public pass::SubgraphPass {
public:
ReduceWorkAmount(size_t reduce_value);
OPENVINO_RTTI("ReduceWorkAmount", "SubgraphPass")
bool run(const LinearIR& linear_ir, LinearIR::constExprIt begin, LinearIR::constExprIt end) override;

private:
size_t m_reduce_value;
};

class ZeroFinalizationOffsets : public pass::SubgraphPass {
public:
OPENVINO_RTTI("ZeroFinalizationOffsets", "SubgraphPass")
bool run(const LinearIR& linear_ir, LinearIR::constExprIt begin, LinearIR::constExprIt end) override;
};

class SetFillOffset : public pass::SubgraphPass {
public:
SetFillOffset(size_t offset);
OPENVINO_RTTI("SetFillOffset", "SubgraphPass")
bool run(const LinearIR& linear_ir, LinearIR::constExprIt begin, LinearIR::constExprIt end) override;

private:
size_t m_offset;
};

class TransformInnerSplitLoop : public pass::SubgraphPass {
public:
TransformInnerSplitLoop(size_t tail_size);
OPENVINO_RTTI("TransformInnerSplitLoop", "SubgraphPass")
bool run(const LinearIR& linear_ir, LinearIR::constExprIt begin, LinearIR::constExprIt end) override;

private:
size_t m_tail_size;
};

} // namespace pass
} // namespace lowered
} // namespace snippets
} // namespace ov
57 changes: 57 additions & 0 deletions src/common/snippets/include/snippets/lowered/pass/pass.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,63 @@ class PassPipeline {
std::vector<std::shared_ptr<Pass>> m_passes;
};

class SubgraphPass {
public:
SubgraphPass() = default;
virtual ~SubgraphPass() = default;
// Note that get_type_info_static and get_type_info are needed to mimic OPENVINO_RTTI interface,
// so the standard OPENVINO_RTTI(...) macros could be used in derived classes.
_OPENVINO_HIDDEN_METHOD static const ::ov::DiscreteTypeInfo& get_type_info_static() {
static ::ov::DiscreteTypeInfo type_info_static {"SubgraphPass"};
type_info_static.hash();
return type_info_static;
}

virtual const DiscreteTypeInfo& get_type_info() const {
return get_type_info_static();
}

const char* get_type_name() const {
return get_type_info().name;
}

virtual bool run(const lowered::LinearIR& linear_ir, lowered::LinearIR::constExprIt begin, lowered::LinearIR::constExprIt end) = 0;
};

class SubgraphPassPipeline {
public:
using PositionedSubgraphPassLowered = snippets::pass::PositionedPass<lowered::pass::SubgraphPass>;

SubgraphPassPipeline();
SubgraphPassPipeline(const std::shared_ptr<PassConfig>& pass_config);

void run(const lowered::LinearIR& linear_ir, lowered::LinearIR::constExprIt begin, lowered::LinearIR::constExprIt end) const;
const std::vector<std::shared_ptr<SubgraphPass>>& get_passes() const { return m_passes; }
bool empty() const { return m_passes.empty(); }

void register_pass(const snippets::pass::PassPosition& position, const std::shared_ptr<SubgraphPass>& pass);
void register_pass(const std::shared_ptr<SubgraphPass>& pass);

template<typename T, class... Args>
void register_pass(Args&&... args) {
static_assert(std::is_base_of<SubgraphPass, T>::value, "Pass not derived from lowered::SubgraphPass");
auto pass = std::make_shared<T>(std::forward<Args>(args)...);
register_pass(pass);
}
template<typename T, class Pos, class... Args, std::enable_if<std::is_same<snippets::pass::PassPosition, Pos>::value, bool>() = true>
void register_pass(const snippets::pass::PassPosition& position, Args&&... args) {
static_assert(std::is_base_of<SubgraphPass, T>::value, "Pass not derived from lowered::SubgraphPass");
auto pass = std::make_shared<T>(std::forward<Args>(args)...);
register_pass(position, pass);
}

void register_positioned_passes(const std::vector<PositionedSubgraphPassLowered>& pos_passes);

private:
std::shared_ptr<PassConfig> m_pass_config;
std::vector<std::shared_ptr<SubgraphPass>> m_passes;
};

} // namespace pass
} // namespace lowered
} // namespace snippets
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "snippets/lowered/linear_ir.hpp"
#include "snippets/lowered/pass/pass.hpp"

namespace ov {
namespace snippets {
namespace lowered {
namespace pass {

class UpdateSubtensors : public pass::SubgraphPass {
public:
UpdateSubtensors(size_t tail_size);
OPENVINO_RTTI("UpdateSubtensors", "Pass")
bool run(const LinearIR& linear_ir, LinearIR::constExprIt begin, LinearIR::constExprIt end) override;

private:
size_t m_tail_size;
};

} // namespace pass
} // namespace lowered
} // namespace snippets
} // namespace ov
6 changes: 4 additions & 2 deletions src/common/snippets/include/snippets/op/subgraph.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@
#pragma once

#include <memory>

#include <openvino/core/model.hpp>
#include <openvino/op/util/sub_graph_base.hpp>
#include "openvino/op/op.hpp"

#include "openvino/core/rt_info.hpp"
#include "openvino/op/op.hpp"
#include "snippets/generator.hpp"
#include "snippets/lowered/pass/pass.hpp"
#include "snippets/pass/manager.hpp"
#include "snippets/shape_inference/shape_inference.hpp"
#include "snippets/lowered/pass/pass.hpp"
Expand Down
7 changes: 2 additions & 5 deletions src/common/snippets/include/snippets/pass/manager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,6 @@
#include "openvino/pass/pass.hpp"
#include "openvino/pass/validate.hpp"

#include <typeinfo>


namespace ov {
namespace snippets {
namespace pass {
Expand All @@ -36,7 +33,7 @@ class Manager : public ov::pass::Manager {
std::shared_ptr<T> register_pass(const PassPosition& position, Args&&... args) {
static_assert(std::is_base_of<PassBase, T>::value, "Attempt to insert pass that is not derived from PassBase");
auto pass = std::make_shared<T>(std::forward<Args>(args)...);
auto rc = insert_pass_instance(position, pass);
auto rc = insert_pass_instance(position, pass);
rc->set_pass_config(m_pass_config);
if (!m_pass_config->is_enabled<T>()) {
m_pass_config->disable<T>();
Expand All @@ -48,7 +45,7 @@ class Manager : public ov::pass::Manager {
void register_positioned_passes(const std::vector<PositionedPassBase>& pos_passes);

protected:
std::shared_ptr<Manager::PassBase> insert_pass_instance(const PassPosition& position, const std::shared_ptr<PassBase>& pass);
std::shared_ptr<PassBase> insert_pass_instance(const PassPosition& position, const std::shared_ptr<PassBase>& pass);
};

} // namespace pass
Expand Down
Loading

0 comments on commit e240aee

Please sign in to comment.