Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Snippets] Fixed insertion position iterators #18023

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 20 additions & 4 deletions src/common/snippets/include/snippets/lowered/linear_ir.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ class LinearIR {
using io_container = std::list<std::shared_ptr<IOExpression>>;
using exprIt = container::iterator;
using constExprIt = container::const_iterator;
using exprReverseIt = container::reverse_iterator;
using constExprReverseIt = container::const_reverse_iterator;

LinearIR() = default;
explicit LinearIR(const std::shared_ptr<ov::Model>& m, Config config = {});
Expand Down Expand Up @@ -69,10 +71,10 @@ class LinearIR {
constExprIt end() const noexcept {return cend();}
constExprIt cbegin() const noexcept {return m_expressions.cbegin();}
constExprIt cend() const noexcept {return m_expressions.cend();}
container::reverse_iterator rbegin() noexcept {return m_expressions.rbegin();}
container::reverse_iterator rend() noexcept {return m_expressions.rend();}
container::const_reverse_iterator crbegin() const noexcept {return m_expressions.crbegin();}
container::const_reverse_iterator crend() const noexcept {return m_expressions.crend();}
exprReverseIt rbegin() noexcept {return m_expressions.rbegin();}
exprReverseIt rend() noexcept {return m_expressions.rend();}
constExprReverseIt crbegin() const noexcept {return m_expressions.crbegin();}
constExprReverseIt crend() const noexcept {return m_expressions.crend();}

exprIt insert(constExprIt pos, const ov::NodeVector& nodes);
exprIt insert(constExprIt pos, const std::shared_ptr<Node>& n);
Expand All @@ -84,6 +86,14 @@ class LinearIR {
exprIt erase(exprIt pos);
exprIt erase(constExprIt pos);

constExprIt find(const ExpressionPtr& target) const;
template<typename iterator>
iterator find(iterator begin, iterator end, const ExpressionPtr& target) const;
template<typename iterator>
iterator find_before(iterator it, const ExpressionPtr& target) const;
template<typename iterator>
iterator find_after(iterator it, const ExpressionPtr& target) const;

void init_emitters(const std::shared_ptr<TargetMachine>& target);
void serialize(const std::string& xml, const std::string& bin);

Expand All @@ -107,6 +117,12 @@ class LinearIR {
LoopManagerPtr m_loop_manager = nullptr;
};

template<typename iterator>
iterator LinearIR::find(iterator begin, iterator end, const ExpressionPtr& target) const {
auto found = std::find(begin, end, target);
OPENVINO_ASSERT(found != end, "Expression has not been found");
return found;
}
} // namespace lowered
} // namespace snippets
} // namespace ov
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class InsertBuffers : public Pass {
bool run(LinearIR& linear_ir) override;

private:
void insertion(LinearIR& linear_ir, const LinearIR::LoopManagerPtr& loop_manager,
void insertion(LinearIR& linear_ir, const LinearIR::constExprIt& expr_it, const LinearIR::LoopManagerPtr& loop_manager,
const std::vector<LinearIR::LoopManager::LoopPort>& loop_entries,
const std::vector<LinearIR::LoopManager::LoopPort>& loop_exits);

Expand Down
21 changes: 21 additions & 0 deletions src/common/snippets/src/lowered/linear_ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,27 @@ void LinearIR::move(LinearIR::constExprIt from, LinearIR::constExprIt to) {
m_expressions.splice(to, m_expressions, from);
}

LinearIR::constExprIt LinearIR::find(const ExpressionPtr& target) const {
return find(cbegin(), cend(), target);
}
template<>
LinearIR::constExprIt LinearIR::find_before(LinearIR::constExprIt it, const ExpressionPtr& target) const {
return find(cbegin(), it, target);
}
template<>
LinearIR::constExprReverseIt LinearIR::find_before(LinearIR::constExprReverseIt it, const ExpressionPtr& target) const {
return find(crbegin(), it, target);
}
template<>
LinearIR::constExprIt LinearIR::find_after(LinearIR::constExprIt it, const ExpressionPtr& target) const {
return find(it, cend(), target);
}
template<>
LinearIR::constExprReverseIt LinearIR::find_after(LinearIR::constExprReverseIt it, const ExpressionPtr& target) const {
return find(it, crend(), target);
}


}// namespace lowered
}// namespace snippets
}// namespace ov
8 changes: 3 additions & 5 deletions src/common/snippets/src/lowered/loop_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,7 @@ void LinearIR::LoopManager::get_loop_bounds(const LinearIR &linear_ir,
OPENVINO_ASSERT(!entries.empty(), "Loop must have entry points");
OPENVINO_ASSERT(!exits.empty(), "Loop must have entry points");
const auto& entry_expr = entries.front().expr_port->get_expr();
loop_begin_pos = std::find(linear_ir.begin(), linear_ir.end(), entry_expr);
OPENVINO_ASSERT(loop_begin_pos != linear_ir.end(), "Loop begin hasn't been found!");
loop_begin_pos = linear_ir.find(entry_expr);

// Some operations in Loop can be before first entry points: Scalars, VectorBuffer.
// We should iterate by them till the expr is in the corresponding Loop
Expand All @@ -104,12 +103,11 @@ void LinearIR::LoopManager::get_loop_bounds(const LinearIR &linear_ir,
const auto loop_end = loop_begin->get_loop_end();
OPENVINO_ASSERT(loop_end->get_id() == loop_id, "Failed explicit loop bounds getting: Loop bounds with correct ID have not been found");
loop_begin_pos = std::prev(loop_begin_pos);
loop_end_pos = std::find(loop_begin_pos, linear_ir.end(), linear_ir.get_expr_by_node(loop_end));
loop_end_pos = linear_ir.find_after(loop_begin_pos, linear_ir.get_expr_by_node(loop_end));
} else {
// At the moment all Loops must have exit points
const auto& exit_expr = exits.back().expr_port->get_expr();
loop_end_pos = std::next(std::find(loop_begin_pos, linear_ir.end(), exit_expr));
OPENVINO_ASSERT(loop_end_pos != linear_ir.end(), "Loop end hasn't been found!");
loop_end_pos = std::next(linear_ir.find_after(loop_begin_pos, exit_expr));
}
}

Expand Down
36 changes: 24 additions & 12 deletions src/common/snippets/src/lowered/pass/insert_buffers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,15 +83,11 @@ LinearIR::constExprIt InsertBuffers::insertion_position(const LinearIR& linear_i
const auto down_loops = down_expr->get_loop_ids();
// If upper expression is out of Loop, we can insert Buffer implicitly after him
if (up_loops.empty()) {
const auto it = std::find(linear_ir.cbegin(), linear_ir.cend(), up_expr);
OPENVINO_ASSERT(it != linear_ir.cend(), "Upper expression hasn't been found to insert Buffer after him!");
return std::next(it);
return std::next(linear_ir.find(up_expr));
}
// If lower expression is out of Loop, we can insert Buffer implicitly before him
if (down_loops.empty()) {
const auto it = std::find(linear_ir.cbegin(), linear_ir.cend(), down_expr);
OPENVINO_ASSERT(it != linear_ir.cend(), "Lower expression hasn't been found to insert Buffer after him!");
return it;
return linear_ir.find(down_expr);
}

const auto up_loop_count = up_loops.size();
Expand Down Expand Up @@ -120,7 +116,7 @@ LinearIR::constExprIt InsertBuffers::insertion_position(const LinearIR& linear_i
OPENVINO_THROW("Incorrect configuration for Buffer insertion!");
}

void InsertBuffers::insertion(LinearIR& linear_ir, const LinearIR::LoopManagerPtr& loop_manager,
void InsertBuffers::insertion(LinearIR& linear_ir, const LinearIR::constExprIt& expr_it, const LinearIR::LoopManagerPtr& loop_manager,
const std::vector<LinearIR::LoopManager::LoopPort>& loop_entries,
const std::vector<LinearIR::LoopManager::LoopPort>& loop_exits) {
for (const auto& entry_point : loop_entries) {
Expand Down Expand Up @@ -221,16 +217,31 @@ void InsertBuffers::insertion(LinearIR& linear_ir, const LinearIR::LoopManagerPt
const auto buffer_consumers_inputs = buffer_out->get_consumers();
linear_ir.replace_input(buffer_consumers_inputs, output_connector);
potential_consumers.insert(buffer_consumers_inputs.begin(), buffer_consumers_inputs.end());
linear_ir.erase(std::find(linear_ir.begin(), linear_ir.end(), buffer));
linear_ir.erase(linear_ir.find_after(expr_it, buffer));
}
}

// potential_consumers is unsorted by linear IR set.
// We have to find first expr in Linear IR from the set to insert Buffer before *all* consumers
// [113536]: Remove this logic with `std::find` using, when expression numeration will be supported
OPENVINO_ASSERT(!potential_consumers.empty(), "Buffer should have one consumer at least");
auto consumer_expr = potential_consumers.begin()->get_expr();
if (potential_consumers.size() > 1) {
std::set<ExpressionPtr> consumers;
for (const auto& port : potential_consumers)
consumers.insert(port.get_expr());
const auto it = std::find_if(expr_it, linear_ir.cend(),
[&consumers](const ExpressionPtr& expr) { return consumers.count(expr) > 0; });
OPENVINO_ASSERT(it != linear_ir.cend(), "Consumer of Buffer has not been found in Linear IR");
consumer_expr = *it;
}

// We should insert Buffer between first different Loops.
// Example: Current expr Loop identifies: 3, 2, 1
// Target consumers Loop identifies: 3, 4, 6
// Need to insert after 2nd Loops
// Note: All potential consumers must have the same count of first equal Loop identifies and the same count of different last identifies
const auto pos = insertion_position(linear_ir, loop_manager, expr, (*potential_consumers.begin()).get_expr());
const auto pos = insertion_position(linear_ir, loop_manager, expr, consumer_expr);

const auto allocation_shape = compute_allocation_shape(loop_manager,
buffer_loop_ids,
Expand Down Expand Up @@ -266,10 +277,11 @@ bool InsertBuffers::run(LinearIR& linear_ir) {
const auto loop_info = loop_data.second;
const auto loop_entries = loop_info->entry_points;
const auto loop_exits = loop_info->exit_points;
insertion(linear_ir, loop_manager, loop_entries, loop_exits);
// using begin() as expr_it because we work with LoopInfo, not expressions in Linear IR
insertion(linear_ir, linear_ir.cbegin(), loop_manager, loop_entries, loop_exits);
}

for (auto expr_it = linear_ir.begin(); expr_it != linear_ir.end(); expr_it++) {
for (auto expr_it = linear_ir.cbegin(); expr_it != linear_ir.cend(); expr_it++) {
const auto expr = *expr_it;
const auto node = (*expr_it)->get_node();
const auto ma = ov::as_type_ptr<op::MemoryAccess>(node);
Expand All @@ -286,7 +298,7 @@ bool InsertBuffers::run(LinearIR& linear_ir) {
loop_exits[p.first] = expr->get_output_port(p.first);
}

insertion(linear_ir, loop_manager, loop_entries, loop_exits);
insertion(linear_ir, expr_it, loop_manager, loop_entries, loop_exits);
}

return true;
Expand Down
5 changes: 2 additions & 3 deletions src/common/snippets/src/lowered/pass/insert_load_store.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ bool InsertLoadStore::insert_load(LinearIR& linear_ir, const LinearIR::constExpr
const auto load = std::make_shared<op::Load>(data_node->output(0), get_count(data_expr->get_output_port_descriptor(0)));
PortDescriptorUtils::set_port_descriptor_ptr(load->output(0), consumer_input.get_descriptor_ptr()->clone());
const auto load_expr = linear_ir.create_expression(load, {output_connector});
linear_ir.insert(std::find(data_expr_it, linear_ir.cend(), consumer_expr), load_expr);
linear_ir.insert(linear_ir.find_after(data_expr_it, consumer_expr), load_expr);
linear_ir.replace_input(consumer_input, load_expr->get_output_port_connector(0));
// Copy Loop identifies
load_expr->set_loop_ids(loop_ids);
Expand Down Expand Up @@ -80,8 +80,7 @@ bool InsertLoadStore::insert_store(LinearIR& linear_ir, const LinearIR::constExp
const auto store = std::make_shared<op::Store>(parent->output(port), get_count(data_expr->get_input_port_descriptor(0)));
PortDescriptorUtils::set_port_descriptor_ptr(store->output(0), parent_output.get_descriptor_ptr()->clone());
const auto store_expr = linear_ir.create_expression(store, {input_connector});
const auto& reverse_insertion_pos = std::find(std::reverse_iterator<LinearIR::constExprIt>(data_expr_it), linear_ir.crend(), parent_expr);
const auto& insertion_pos = reverse_insertion_pos.base();
const auto& insertion_pos = linear_ir.find_after(std::reverse_iterator<LinearIR::constExprIt>(data_expr_it), parent_expr).base();
linear_ir.insert(insertion_pos, store_expr);
linear_ir.replace_input(data_expr->get_input_port(0), store_expr->get_output_port_connector(0));
// Copy Loop identifies
Expand Down
5 changes: 2 additions & 3 deletions src/common/snippets/src/lowered/pass/insert_tail_loop.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,7 @@ void InsertTailLoop::tail_transformations(LinearIR& linear_ir,
// Skip inner Loops
const auto loop_begin = ov::as_type_ptr<op::LoopBegin>(expr_it->get()->get_node());
if (loop_begin) {
expr_it = std::find(expr_it, tail_end, linear_ir.get_expr_by_node(loop_begin->get_loop_end()));
OPENVINO_ASSERT(expr_it != tail_end, "LoopEnd has not been found");
expr_it = linear_ir.find(expr_it, tail_end, linear_ir.get_expr_by_node(loop_begin->get_loop_end()));
continue;
}
// We should fill vector regs by float_min and zero to have
Expand Down Expand Up @@ -198,7 +197,7 @@ bool InsertTailLoop::run(LinearIR& linear_ir) {
// finalization offsets which are supported by LoopEnd.
if (need_tail) {
const auto loop_begin = loop_end->get_loop_begin();
const auto begin_it = std::find(linear_ir.begin(), linear_ir.end(), linear_ir.get_expr_by_node(loop_begin));
const auto begin_it = linear_ir.find(linear_ir.get_expr_by_node(loop_begin));
LinearIR::constExprIt tail_begin, tail_end;
const auto tail_loop_end = create_tail_loop(linear_ir, begin_it, std::next(expr_it), tail_begin, tail_end,
loop_end, need_vector_loop, tail_size, tail_finalization_offsets);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ bool LoadMoveBroadcastToBroadcastLoad::run(LinearIR& linear_ir) {
const auto& loop_manager = linear_ir.get_loop_manager();
bool modified = false;

for (auto expr_it = linear_ir.begin(); expr_it != linear_ir.end(); expr_it++) {
for (auto expr_it = linear_ir.cbegin(); expr_it != linear_ir.cend(); expr_it++) {
const auto& expr = *expr_it;
const auto& op = expr->get_node();
// Match on MoveBroadcast because MoveBroadcast is rare node in bodies
Expand Down Expand Up @@ -56,7 +56,7 @@ bool LoadMoveBroadcastToBroadcastLoad::run(LinearIR& linear_ir) {
const auto mv_expr_it = expr_it;
const auto insertion_pos = std::next(expr_it);
expr_it = linear_ir.insert(insertion_pos, broadcastload_expr);
linear_ir.erase(std::find(linear_ir.begin(), mv_expr_it, parent_expr));
linear_ir.erase(linear_ir.find_before(mv_expr_it, parent_expr));
linear_ir.erase(mv_expr_it);
linear_ir.replace_input(move_consumers, broadcastload_expr->get_output_port_connector(0));
modified |= true;
Expand Down