Skip to content

Commit

Permalink
Introduced a virtual method 'can_be_merged' for lowered passes
Browse files Browse the repository at this point in the history
  • Loading branch information
v-Golubev committed Jan 24, 2024
1 parent 5e66fff commit 67a46b3
Show file tree
Hide file tree
Showing 15 changed files with 92 additions and 32 deletions.
10 changes: 7 additions & 3 deletions src/common/snippets/include/snippets/lowered/loop_manager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ class LinearIR::LoopManager {
const lowered::pass::PassPipeline& get_first_iter_handelrs() const;
const lowered::pass::PassPipeline& get_main_iter_handelrs() const;
const lowered::pass::PassPipeline& get_last_iter_handelrs() const;
SpecificIterationHandlers merge(const SpecificIterationHandlers& other) const;

template <HandlerType Type, typename T, class... Args>
void register_handler(Args&&... args) {
Expand All @@ -72,8 +73,6 @@ class LinearIR::LoopManager {
}
}

static SpecificIterationHandlers merge_loop_handlers(const SpecificIterationHandlers& lhs, const SpecificIterationHandlers& rhs);

private:
lowered::pass::PassPipeline m_first_iter_handlers;
lowered::pass::PassPipeline m_main_body_handlers;
Expand All @@ -98,7 +97,7 @@ class LinearIR::LoopManager {
size_t get_increment() const;
const std::vector<LoopPort>& get_entry_points() const;
const std::vector<LoopPort>& get_exit_points() const;
SpecificIterationHandlers& get_handlers();
const SpecificIterationHandlers& get_handlers() const;

// Sets dim_idx to all entry and exit points
void set_dim_idx(size_t dim_idx);
Expand All @@ -108,6 +107,11 @@ class LinearIR::LoopManager {
void set_exit_points(std::vector<LoopPort> exit_points);
void set_handlers(SpecificIterationHandlers handlers);

/**
 * @brief Registers a handler of type T for the iteration kind selected by Type.
 *        The call is delegated to the handlers object stored in m_handlers.
 * @tparam Type iteration kind the handler is registered for (e.g. FIRST_ITER, LAST_ITER)
 * @tparam T type of the handler to register
 * @param args arguments handed over to m_handlers.register_handler<Type, T>
 */
template <SpecificIterationHandlers::HandlerType Type, typename T, class... Args>
void register_handler(Args&&... args) {
    // `args` is a forwarding-reference pack: forward it so that the value category
    // supplied by the caller is preserved instead of always passing lvalues.
    m_handlers.register_handler<Type, T>(std::forward<Args>(args)...);
}

private:
size_t m_work_amount = 0;
size_t m_increment = 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ class UpdateMemoryAccessCounts : public pass::RangedPass {
UpdateMemoryAccessCounts(size_t count);
OPENVINO_RTTI("UpdateMemoryAccessCounts", "RangedPass")
bool run(LinearIR& linear_ir, LinearIR::constExprIt begin, LinearIR::constExprIt end) override;
bool can_be_merged(const std::shared_ptr<pass::PassBase>& other) override;

private:
size_t m_count;
Expand All @@ -40,6 +41,7 @@ class SetFillOffset : public pass::RangedPass {
SetFillOffset(size_t offset);
OPENVINO_RTTI("SetFillOffset", "RangedPass")
bool run(LinearIR& linear_ir, LinearIR::constExprIt begin, LinearIR::constExprIt end) override;
bool can_be_merged(const std::shared_ptr<pass::PassBase>& other) override;

private:
size_t m_offset;
Expand All @@ -56,6 +58,7 @@ class TransformInnerSplitLoop : public pass::RangedPass {
TransformInnerSplitLoop(size_t tail_size);
OPENVINO_RTTI("TransformInnerSplitLoop", "RangedPass")
bool run(LinearIR& linear_ir, LinearIR::constExprIt begin, LinearIR::constExprIt end) override;
bool can_be_merged(const std::shared_ptr<pass::PassBase>& other) override;

private:
size_t m_tail_size;
Expand Down
12 changes: 12 additions & 0 deletions src/common/snippets/include/snippets/lowered/pass/pass.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,15 @@ class PassBase {
const char* get_type_name() const {
return get_type_info().name;
}

/**
 * @brief Checks whether this pass can be merged with another one (e.g. when two pass pipelines are fused).
 *        This base implementation always returns false.
 * @param other Pointer to the other pass.
 * @return true if the passes can be merged, false otherwise
 */
virtual bool can_be_merged(const std::shared_ptr<PassBase>& other) {
return false;
}
};

/**
Expand Down Expand Up @@ -81,6 +90,7 @@ class PassPipeline {
PassPipeline(const std::shared_ptr<PassConfig>& pass_config);

// Read-only access to the passes stored in m_passes.
const std::vector<std::shared_ptr<PassBase>>& get_passes() const { return m_passes; }
// Read-only access to the PassConfig pointer stored in m_pass_config.
const std::shared_ptr<PassConfig>& get_pass_config() const { return m_pass_config; }
// Returns true when m_passes holds no passes.
bool empty() const { return m_passes.empty(); }

void register_pass(const snippets::pass::PassPosition& position, const std::shared_ptr<PassBase>& pass);
Expand All @@ -104,6 +114,8 @@ class PassPipeline {
void run(lowered::LinearIR& linear_ir) const;
void run(lowered::LinearIR& linear_ir, lowered::LinearIR::constExprIt begin, lowered::LinearIR::constExprIt end) const;

static PassPipeline merge_pipelines(const PassPipeline& lhs, const PassPipeline& rhs);

private:
std::shared_ptr<PassConfig> m_pass_config;
std::vector<std::shared_ptr<PassBase>> m_passes;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ class PassConfig {
return is_enabled(T::get_type_info_static());
}

friend bool operator==(const PassConfig& lhs, const PassConfig& rhs);
friend bool operator!=(const PassConfig& lhs, const PassConfig& rhs);

private:
std::unordered_set<DiscreteTypeInfo> m_disabled;
std::unordered_set<DiscreteTypeInfo> m_enabled;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ class UpdateSubtensors : public pass::RangedPass {
UpdateSubtensors(size_t tail_size);
OPENVINO_RTTI("UpdateSubtensors", "RangedPass")
bool run(LinearIR& linear_ir, LinearIR::constExprIt begin, LinearIR::constExprIt end) override;
bool can_be_merged(const std::shared_ptr<pass::PassBase>& other) override;

private:
size_t m_tail_size;
Expand Down
31 changes: 7 additions & 24 deletions src/common/snippets/src/lowered/loop_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,28 +66,11 @@ const lowered::pass::PassPipeline& LoopInfo::SpecificIterationHandlers::get_last
return m_last_iter_handlers;
}

LoopInfo::SpecificIterationHandlers LoopInfo::SpecificIterationHandlers::merge_loop_handlers(
const LoopInfo::SpecificIterationHandlers& lhs,
const LoopInfo::SpecificIterationHandlers& rhs) {
auto merge_handlers_pipelines = [](const lowered::pass::PassPipeline& lhs, const lowered::pass::PassPipeline& rhs) {
auto merged_pipeline = lhs;
const auto& merged_passes = merged_pipeline.get_passes();
for (const auto& pass : rhs.get_passes()) {
auto pred = [&pass](const std::shared_ptr<lowered::pass::PassBase>& p) {
return p->get_type_info() == pass->get_type_info();
};
if (std::find_if(merged_passes.begin(), merged_passes.end(), pred) == merged_passes.end()) {
merged_pipeline.register_pass(pass);
}
}
return merged_pipeline;
};

LoopInfo::SpecificIterationHandlers merged_handlers(
merge_handlers_pipelines(lhs.get_first_iter_handelrs(), rhs.get_first_iter_handelrs()),
merge_handlers_pipelines(lhs.get_main_iter_handelrs(), rhs.get_main_iter_handelrs()),
merge_handlers_pipelines(lhs.get_last_iter_handelrs(), rhs.get_last_iter_handelrs()));
return merged_handlers;
// Produces a new handlers object: each of the three pipelines (first / main / last iteration)
// is the result of PassPipeline::merge_pipelines applied to this object's pipeline and the
// corresponding pipeline of `other`.
LoopInfo::SpecificIterationHandlers LoopInfo::SpecificIterationHandlers::merge(const LoopInfo::SpecificIterationHandlers& other) const {
    using Pipeline = lowered::pass::PassPipeline;
    using Handlers = LoopInfo::SpecificIterationHandlers;
    return Handlers(Pipeline::merge_pipelines(m_first_iter_handlers, other.get_first_iter_handelrs()),
                    Pipeline::merge_pipelines(m_main_body_handlers, other.get_main_iter_handelrs()),
                    Pipeline::merge_pipelines(m_last_iter_handlers, other.get_last_iter_handelrs()));
}

LoopInfo::LoopInfo(size_t work_amount,
Expand Down Expand Up @@ -151,7 +134,7 @@ const std::vector<LoopPort>& LoopInfo::get_exit_points() const {
return m_exit_points;
}

LoopInfo::SpecificIterationHandlers& LoopInfo::get_handlers() {
// Read-only accessor for the handlers stored in m_handlers.
const LoopInfo::SpecificIterationHandlers& LoopInfo::get_handlers() const {
return m_handlers;
}

Expand Down Expand Up @@ -482,7 +465,7 @@ void LinearIR::LoopManager::fuse_loops(LinearIR::constExprIt loop_begin_target,
loop_info->set_entry_points(new_entries);
loop_info->set_exit_points(new_exits);

loop_info->set_handlers(LoopInfo::SpecificIterationHandlers::merge_loop_handlers(loop_info_upper->get_handlers(), loop_info_lower->get_handlers()));
loop_info->set_handlers(loop_info_upper->get_handlers().merge(loop_info_lower->get_handlers()));
// Since fusion can be called for broadcastable loops (one of the loops has work_amount = increment = 1),
// maximum value is set to the fused loop
loop_info->set_work_amount(std::max(loop_info_upper->get_work_amount(), loop_info_lower->get_work_amount()));
Expand Down
15 changes: 15 additions & 0 deletions src/common/snippets/src/lowered/pass/iter_handler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ bool UpdateMemoryAccessCounts::run(LinearIR& linear_ir, LinearIR::constExprIt be
return true;
}

// Mergeable only with another UpdateMemoryAccessCounts that carries the same m_count.
bool UpdateMemoryAccessCounts::can_be_merged(const std::shared_ptr<pass::PassBase>& other) {
    const auto same_type_pass = ov::as_type_ptr<UpdateMemoryAccessCounts>(other);
    if (!same_type_pass) {
        // `other` is not an UpdateMemoryAccessCounts pass
        return false;
    }
    return same_type_pass->m_count == m_count;
}

SetFillOffset::SetFillOffset(size_t offset) : RangedPass(), m_offset(offset) {}

bool SetFillOffset::run(LinearIR& linear_ir, LinearIR::constExprIt begin, LinearIR::constExprIt end) {
Expand All @@ -57,6 +62,11 @@ bool SetFillOffset::run(LinearIR& linear_ir, LinearIR::constExprIt begin, Linear
return true;
}

// Mergeable only with another SetFillOffset that carries the same m_offset.
bool SetFillOffset::can_be_merged(const std::shared_ptr<pass::PassBase>& other) {
    const auto same_type_pass = ov::as_type_ptr<SetFillOffset>(other);
    if (!same_type_pass) {
        // `other` is not a SetFillOffset pass
        return false;
    }
    return same_type_pass->m_offset == m_offset;
}

TransformInnerSplitLoop::TransformInnerSplitLoop(size_t tail_size) : RangedPass(), m_tail_size(tail_size) {}

bool TransformInnerSplitLoop::run(LinearIR& linear_ir, LinearIR::constExprIt begin, LinearIR::constExprIt end) {
Expand Down Expand Up @@ -101,6 +111,11 @@ bool TransformInnerSplitLoop::run(LinearIR& linear_ir, LinearIR::constExprIt beg
return modified;
}

// Mergeable only with another TransformInnerSplitLoop that carries the same m_tail_size.
bool TransformInnerSplitLoop::can_be_merged(const std::shared_ptr<pass::PassBase>& other) {
    const auto same_type_pass = ov::as_type_ptr<TransformInnerSplitLoop>(other);
    if (!same_type_pass) {
        // `other` is not a TransformInnerSplitLoop pass
        return false;
    }
    return same_type_pass->m_tail_size == m_tail_size;
}

} // namespace pass
} // namespace lowered
} // namespace snippets
Expand Down
20 changes: 20 additions & 0 deletions src/common/snippets/src/lowered/pass/pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,26 @@ void PassPipeline::register_positioned_passes(const std::vector<PositionedPassLo
register_pass(pp.position, pp.pass);
}

/**
 * @brief Merges two pass pipelines into a new one.
 *        The result starts as a copy of `lhs`. Every pass of `rhs` whose type info is not present
 *        in `lhs` is registered in the result. For a pass of `rhs` whose type info is already present
 *        in `lhs`, nothing is added: the `lhs` pass of that type must report can_be_merged() for it.
 * @param lhs pipeline used as the base of the merged pipeline
 * @param rhs pipeline whose passes are added to the base
 * @return the merged pipeline
 * @note OPENVINO_ASSERT fails if a pass config is missing, if the pass configs differ,
 *       or if two passes of the same type can't be merged.
 */
PassPipeline PassPipeline::merge_pipelines(const PassPipeline& lhs, const PassPipeline& rhs) {
    const auto& lhs_config = lhs.get_pass_config();
    const auto& rhs_config = rhs.get_pass_config();
    // The configs are dereferenced below, so make sure both pointers are set.
    OPENVINO_ASSERT(lhs_config && rhs_config, "Pipelines without PassConfig can't be merged.");
    OPENVINO_ASSERT(*lhs_config == *rhs_config, "2 pipelines with different PassConfigs can't be merged.");

    const auto& lhs_passes = lhs.get_passes();
    // Non-owning lookup table "type info -> lhs pass". `lhs` keeps the passes alive for the whole call,
    // so no shared_ptr copies are needed. If `lhs` holds several passes of one type, the last one is kept.
    std::unordered_map<ov::DiscreteTypeInfo, PassBase*> lhs_pass_by_type;
    lhs_pass_by_type.reserve(lhs_passes.size());
    for (const auto& pass : lhs_passes) {
        lhs_pass_by_type[pass->get_type_info()] = pass.get();
    }

    auto merged_pipeline = lhs;
    for (const auto& pass : rhs.get_passes()) {
        const auto lhs_pass_it = lhs_pass_by_type.find(pass->get_type_info());
        if (lhs_pass_it == lhs_pass_by_type.end()) {
            merged_pipeline.register_pass(pass);
        } else {
            OPENVINO_ASSERT(lhs_pass_it->second->can_be_merged(pass), "2 passes with type info ", pass->get_type_info(), " can't be merged.");
        }
    }
    return merged_pipeline;
}

} // namespace pass
} // namespace lowered
} // namespace snippets
Expand Down
8 changes: 8 additions & 0 deletions src/common/snippets/src/lowered/pass/pass_config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,14 @@ bool PassConfig::is_enabled(const DiscreteTypeInfo& type_info) const {
return m_enabled.count(type_info);
}

// Two configs are equal when both their disabled and their enabled sets are equal.
bool operator==(const PassConfig& lhs, const PassConfig& rhs) {
    if (!(lhs.m_disabled == rhs.m_disabled)) {
        return false;
    }
    return lhs.m_enabled == rhs.m_enabled;
}

// Negation of operator== for PassConfig.
bool operator!=(const PassConfig& lhs, const PassConfig& rhs) {
    const bool configs_equal = (lhs == rhs);
    return !configs_equal;
}

} // namespace pass
} // namespace lowered
} // namespace snippets
Expand Down
5 changes: 5 additions & 0 deletions src/common/snippets/src/lowered/pass/propagate_subtensors.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,11 @@ bool UpdateSubtensors::run(LinearIR& linear_ir, LinearIR::constExprIt begin, Lin
return true;
}

// Mergeable only with another UpdateSubtensors that carries the same m_tail_size.
bool UpdateSubtensors::can_be_merged(const std::shared_ptr<pass::PassBase>& other) {
    const auto same_type_pass = ov::as_type_ptr<UpdateSubtensors>(other);
    if (!same_type_pass) {
        // `other` is not an UpdateSubtensors pass
        return false;
    }
    return same_type_pass->m_tail_size == m_tail_size;
}

} // namespace pass
} // namespace lowered
} // namespace snippets
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ bool SoftmaxDecomposition::run(LinearIR& linear_ir, lowered::LinearIR::constExpr
const auto& reduce_max_loop_info = loop_manager->get_loop_info(reduce_max_loop_id);
const auto tail_size = inner_work_amount % inner_increment;
if (tail_size != 0) {
reduce_max_loop_info->get_handlers().register_handler<HandlerType::LAST_ITER, SetFillOffset>(tail_size);
reduce_max_loop_info->register_handler<HandlerType::LAST_ITER, SetFillOffset>(tail_size);
}
const auto broadcast_horizon_max = push_node(std::make_shared<op::BroadcastMove>(horizon_max.second, broadcasted_dim));
const auto vector_buffer_sum = push_node(std::make_shared<op::VectorBuffer>());
Expand All @@ -101,7 +101,7 @@ bool SoftmaxDecomposition::run(LinearIR& linear_ir, lowered::LinearIR::constExpr
(*sum.first)->get_output_port(0)});
const auto& reduce_sum_loop_info = loop_manager->get_loop_info(reduce_sum_loop_id);
if (tail_size != 0) {
reduce_sum_loop_info->get_handlers().register_handler<HandlerType::LAST_ITER, SetFillOffset>(tail_size);
reduce_sum_loop_info->register_handler<HandlerType::LAST_ITER, SetFillOffset>(tail_size);
}

// Divide is an expensive operation, so we decompose it into 1 / x * y, where 1 / x is executed outside the loop
Expand Down
4 changes: 2 additions & 2 deletions src/common/snippets/tests/src/lowered/pass/loop.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ static void init_linear_ir(const std::vector<ov::PartialShape>& in_shapes, Linea
const auto& outer_loop_info = loop_manager->get_loop_info(loop_id);
const auto outer_tail_size = outer_wa % outer_inc;
if (outer_tail_size != 0) {
outer_loop_info->get_handlers().register_handler<LinearIR::LoopManager::LoopInfo::SpecificIterationHandlers::HandlerType::LAST_ITER,
pass::TransformInnerSplitLoop>(outer_tail_size);
outer_loop_info->register_handler<LinearIR::LoopManager::LoopInfo::SpecificIterationHandlers::HandlerType::LAST_ITER,
pass::TransformInnerSplitLoop>(outer_tail_size);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ bool BrgemmBlocking::run(LinearIR& linear_ir, LinearIR::constExprIt begin, Linea
std::vector<LoopPort> exits{LoopPort(brgemm_expr->get_output_port(0), false)};
const auto id = loop_manager->mark_loop(loop_begin_it, loop_end_it, k, block_size_k, entries, exits);
const auto loop_info = loop_manager->get_loop_info(id);
loop_info->get_handlers().register_handler<LoopInfo::SpecificIterationHandlers::HandlerType::FIRST_ITER, SetBrgemmBeta>(0.f);
loop_info->register_handler<LoopInfo::SpecificIterationHandlers::HandlerType::FIRST_ITER, SetBrgemmBeta>(0.f);
};

apply_k_blocking();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@ bool SetBrgemmBeta::run(LinearIR& linear_ir, LinearIR::constExprIt begin, Linear
}
return true;
}

// Mergeable only with another SetBrgemmBeta whose m_beta compares equal to this one's.
bool SetBrgemmBeta::can_be_merged(const std::shared_ptr<snippets::lowered::pass::PassBase>& other) {
    const auto same_type_pass = ov::as_type_ptr<SetBrgemmBeta>(other);
    if (!same_type_pass) {
        // `other` is not a SetBrgemmBeta pass
        return false;
    }
    return same_type_pass->m_beta == m_beta;
}
} // namespace pass
} // namespace intel_cpu
} // namespace ov
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class SetBrgemmBeta : public snippets::lowered::pass::RangedPass {
bool run(snippets::lowered::LinearIR& linear_ir,
snippets::lowered::LinearIR::constExprIt begin,
snippets::lowered::LinearIR::constExprIt end) override;
bool can_be_merged(const std::shared_ptr<snippets::lowered::pass::PassBase>& other) override;

private:
float m_beta;
Expand Down

0 comments on commit 67a46b3

Please sign in to comment.