Skip to content

Commit

Permalink
[Snippets] Applied Ivan comments
Browse files Browse the repository at this point in the history
  • Loading branch information
a-sidorova committed Jun 21, 2023
1 parent ac36486 commit ca05b08
Show file tree
Hide file tree
Showing 7 changed files with 105 additions and 32 deletions.
37 changes: 33 additions & 4 deletions src/common/snippets/include/snippets/lowered/linear_ir.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ class LinearIR {
using io_container = std::list<std::shared_ptr<IOExpression>>;
using exprIt = container::iterator;
using constExprIt = container::const_iterator;
using exprReverseIt = container::reverse_iterator;
using constExprReverseIt = container::const_reverse_iterator;

LinearIR() = default;
explicit LinearIR(const std::shared_ptr<ov::Model>& m, Config config = {});
Expand Down Expand Up @@ -69,10 +71,10 @@ class LinearIR {
constExprIt end() const noexcept {return cend();}
constExprIt cbegin() const noexcept {return m_expressions.cbegin();}
constExprIt cend() const noexcept {return m_expressions.cend();}
container::reverse_iterator rbegin() noexcept {return m_expressions.rbegin();}
container::reverse_iterator rend() noexcept {return m_expressions.rend();}
container::const_reverse_iterator crbegin() const noexcept {return m_expressions.crbegin();}
container::const_reverse_iterator crend() const noexcept {return m_expressions.crend();}
exprReverseIt rbegin() noexcept {return m_expressions.rbegin();}
exprReverseIt rend() noexcept {return m_expressions.rend();}
constExprReverseIt crbegin() const noexcept {return m_expressions.crbegin();}
constExprReverseIt crend() const noexcept {return m_expressions.crend();}

exprIt insert(constExprIt pos, const ov::NodeVector& nodes);
exprIt insert(constExprIt pos, const std::shared_ptr<Node>& n);
Expand All @@ -84,6 +86,21 @@ class LinearIR {
exprIt erase(exprIt pos);
exprIt erase(constExprIt pos);

template<typename iterator>
iterator find(iterator begin, iterator end, const ExpressionPtr& target);
template<typename const_iterator>
const_iterator find(const_iterator begin, const_iterator end, const ExpressionPtr& target) const;
exprIt find(const ExpressionPtr& target);
constExprIt find(const ExpressionPtr& target) const;
template<typename iterator>
iterator find_before(iterator it, const ExpressionPtr& target);
template<typename const_iterator>
const_iterator find_before(const_iterator it, const ExpressionPtr& target) const;
template<typename iterator>
iterator find_after(iterator it, const ExpressionPtr& target);
template<typename const_iterator>
const_iterator find_after(const_iterator it, const ExpressionPtr& target) const;

void init_emitters(const std::shared_ptr<TargetMachine>& target);
void serialize(const std::string& xml, const std::string& bin);

Expand All @@ -107,6 +124,18 @@ class LinearIR {
LoopManagerPtr m_loop_manager = nullptr;
};

template<typename iterator>
iterator LinearIR::find(iterator begin, iterator end, const ExpressionPtr& target) {
auto found = std::find(begin, end, target);
OPENVINO_ASSERT(found != end, "Expression has not been found");
return found;
}
template<typename const_iterator>
const_iterator LinearIR::find(const_iterator begin, const_iterator end, const ExpressionPtr& target) const {
auto found = std::find(begin, end, target);
OPENVINO_ASSERT(found != end, "Expression has not been found");
return found;
}
} // namespace lowered
} // namespace snippets
} // namespace ov
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class InsertBuffers : public Pass {
bool run(LinearIR& linear_ir) override;

private:
void insertion(LinearIR& linear_ir, const LinearIR::LoopManagerPtr& loop_manager,
void insertion(LinearIR& linear_ir, const LinearIR::constExprIt& expr_it, const LinearIR::LoopManagerPtr& loop_manager,
const std::vector<LinearIR::LoopManager::LoopPort>& loop_entries,
const std::vector<LinearIR::LoopManager::LoopPort>& loop_exits);

Expand Down
56 changes: 56 additions & 0 deletions src/common/snippets/src/lowered/linear_ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,62 @@ void LinearIR::move(LinearIR::constExprIt from, LinearIR::constExprIt to) {
m_expressions.splice(to, m_expressions, from);
}

LinearIR::exprIt LinearIR::find(const ExpressionPtr& target) {
return find(begin(), end(), target);
}
LinearIR::constExprIt LinearIR::find(const ExpressionPtr& target) const {
return find(cbegin(), cend(), target);
}
template<>
LinearIR::exprIt LinearIR::find_before(LinearIR::exprIt it, const ExpressionPtr& target) {
return find(begin(), it, target);
}
template<>
LinearIR::constExprIt LinearIR::find_before(LinearIR::constExprIt it, const ExpressionPtr& target) {
return find(cbegin(), it, target);
}
template<>
LinearIR::constExprIt LinearIR::find_before(LinearIR::constExprIt it, const ExpressionPtr& target) const {
return find(cbegin(), it, target);
}
template<>
LinearIR::exprReverseIt LinearIR::find_before(LinearIR::exprReverseIt it, const ExpressionPtr& target) {
return find(rbegin(), it, target);
}
template<>
LinearIR::constExprReverseIt LinearIR::find_before(LinearIR::constExprReverseIt it, const ExpressionPtr& target) {
return find(crbegin(), it, target);
}
template<>
LinearIR::constExprReverseIt LinearIR::find_before(LinearIR::constExprReverseIt it, const ExpressionPtr& target) const {
return find(crbegin(), it, target);
}
template<>
LinearIR::exprIt LinearIR::find_after(LinearIR::exprIt it, const ExpressionPtr& target) {
return find(it, end(), target);
}
template<>
LinearIR::constExprIt LinearIR::find_after(LinearIR::constExprIt it, const ExpressionPtr& target) {
return find(it, cend(), target);
}
template<>
LinearIR::constExprIt LinearIR::find_after(LinearIR::constExprIt it, const ExpressionPtr& target) const {
return find(it, cend(), target);
}
template<>
LinearIR::exprReverseIt LinearIR::find_after(LinearIR::exprReverseIt it, const ExpressionPtr& target) {
return find(it, rend(), target);
}
template<>
LinearIR::constExprReverseIt LinearIR::find_after(LinearIR::constExprReverseIt it, const ExpressionPtr& target) {
return find(it, crend(), target);
}
template<>
LinearIR::constExprReverseIt LinearIR::find_after(LinearIR::constExprReverseIt it, const ExpressionPtr& target) const {
return find(it, crend(), target);
}


}// namespace lowered
}// namespace snippets
}// namespace ov
6 changes: 2 additions & 4 deletions src/common/snippets/src/lowered/loop_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,7 @@ void LinearIR::LoopManager::get_loop_bounds(const LinearIR &linear_ir,
OPENVINO_ASSERT(!entries.empty(), "Loop must have entry points");
OPENVINO_ASSERT(!exits.empty(), "Loop must have entry points");
const auto& entry_expr = entries.front().expr_port->get_expr();
loop_begin_pos = std::find(linear_ir.begin(), linear_ir.end(), entry_expr);
OPENVINO_ASSERT(loop_begin_pos != linear_ir.end(), "Loop begin hasn't been found!");
loop_begin_pos = linear_ir.find(entry_expr);

// Some operations in Loop can be before first entry points: Scalars, VectorBuffer.
// We should iterate by them till the expr is in the corresponding Loop
Expand All @@ -92,8 +91,7 @@ void LinearIR::LoopManager::get_loop_bounds(const LinearIR &linear_ir,

// At the moment all Loops must have exit points
const auto& exit_expr = exits.back().expr_port->get_expr();
loop_end_pos = std::next(std::find(loop_begin_pos, linear_ir.end(), exit_expr));
OPENVINO_ASSERT(loop_end_pos != linear_ir.end(), "Loop end hasn't been found!");
loop_end_pos = std::next(linear_ir.find_after(loop_begin_pos, exit_expr));
}

void LinearIR::LoopManager::get_io_loop_ports(LinearIR::constExprIt loop_begin_pos,
Expand Down
24 changes: 10 additions & 14 deletions src/common/snippets/src/lowered/pass/insert_buffers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,11 @@ LinearIR::constExprIt InsertBuffers::insertion_position(const LinearIR& linear_i
const auto down_loops = down_expr->get_loop_ids();
// If upper expression is out of Loop, we can insert Buffer implicitly after him
if (up_loops.empty()) {
const auto it = std::find(linear_ir.cbegin(), linear_ir.cend(), up_expr);
OPENVINO_ASSERT(it != linear_ir.cend(), "Upper expression hasn't been found to insert Buffer after him!");
return std::next(it);
return std::next(linear_ir.find(up_expr));
}
// If lower expression is out of Loop, we can insert Buffer implicitly before him
if (down_loops.empty()) {
const auto it = std::find(linear_ir.cbegin(), linear_ir.cend(), down_expr);
OPENVINO_ASSERT(it != linear_ir.cend(), "Lower expression hasn't been found to insert Buffer after him!");
return it;
return linear_ir.find(down_expr);
}

const auto up_loop_count = up_loops.size();
Expand Down Expand Up @@ -60,7 +56,7 @@ LinearIR::constExprIt InsertBuffers::insertion_position(const LinearIR& linear_i
OPENVINO_THROW("Incorrect configuration for Buffer insertion!");
}

void InsertBuffers::insertion(LinearIR& linear_ir, const LinearIR::LoopManagerPtr& loop_manager,
void InsertBuffers::insertion(LinearIR& linear_ir, const LinearIR::constExprIt& expr_it, const LinearIR::LoopManagerPtr& loop_manager,
const std::vector<LinearIR::LoopManager::LoopPort>& loop_entries,
const std::vector<LinearIR::LoopManager::LoopPort>& loop_exits) {
for (const auto& entry_point : loop_entries) {
Expand Down Expand Up @@ -163,21 +159,20 @@ void InsertBuffers::insertion(LinearIR& linear_ir, const LinearIR::LoopManagerPt
const auto buffer_consumers_inputs = buffer_out->get_consumers();
linear_ir.replace_input(buffer_consumers_inputs, output_connector);
potential_consumers.insert(buffer_consumers_inputs.begin(), buffer_consumers_inputs.end());
const auto buffer_pos = std::find(linear_ir.cbegin(), linear_ir.cend(), buffer);
OPENVINO_ASSERT(buffer_pos != linear_ir.cend(), "Buffer has not been found in Linear IR");
linear_ir.erase(buffer_pos);
linear_ir.erase(linear_ir.find_after(expr_it, buffer));
}
}

// potential_consumers is unsorted by linear IR set.
// We have to find first expr in Linear IR from the set to insert Buffer before *all* consumers
// [113536]: Remove this logic with `std::find` using, when expression numeration will be supported
OPENVINO_ASSERT(!potential_consumers.empty(), "Buffer should have one consumer at least");
auto consumer_expr = potential_consumers.begin()->get_expr();
if (potential_consumers.size() > 1) {
std::set<ExpressionPtr> consumers;
for (const auto& port : potential_consumers)
consumers.insert(port.get_expr());
const auto it = std::find_if(linear_ir.cbegin(), linear_ir.cend(),
const auto it = std::find_if(expr_it, linear_ir.cend(),
[&consumers](const ExpressionPtr& expr) { return consumers.count(expr) > 0; });
OPENVINO_ASSERT(it != linear_ir.cend(), "Consumer of Buffer has not been found in Linear IR");
consumer_expr = *it;
Expand Down Expand Up @@ -219,10 +214,11 @@ bool InsertBuffers::run(LinearIR& linear_ir) {
const auto loop_info = loop_data.second;
const auto loop_entries = loop_info->entry_points;
const auto loop_exits = loop_info->exit_points;
insertion(linear_ir, loop_manager, loop_entries, loop_exits);
// using begin() as expr_it because we work with LoopInfo, not expressions in Linear IR
insertion(linear_ir, linear_ir.cbegin(), loop_manager, loop_entries, loop_exits);
}

for (auto expr_it = linear_ir.begin(); expr_it != linear_ir.end(); expr_it++) {
for (auto expr_it = linear_ir.cbegin(); expr_it != linear_ir.cend(); expr_it++) {
const auto expr = *expr_it;
const auto node = (*expr_it)->get_node();
const auto ma = ov::as_type_ptr<op::MemoryAccess>(node);
Expand All @@ -239,7 +235,7 @@ bool InsertBuffers::run(LinearIR& linear_ir) {
loop_exits[p.first] = expr->get_output_port(p.first);
}

insertion(linear_ir, loop_manager, loop_entries, loop_exits);
insertion(linear_ir, expr_it, loop_manager, loop_entries, loop_exits);
}

return true;
Expand Down
8 changes: 2 additions & 6 deletions src/common/snippets/src/lowered/pass/insert_load_store.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,7 @@ bool InsertLoadStore::insert_load(LinearIR& linear_ir, const LinearIR::constExpr
const auto load = std::make_shared<op::Load>(data_node->output(0), get_count(data_expr->get_output_port_descriptor(0)));
PortDescriptorUtils::set_port_descriptor_ptr(load->output(0), consumer_input.get_descriptor_ptr()->clone());
const auto load_expr = linear_ir.create_expression(load, {output_connector});
const auto insertion_pos = std::find(data_expr_it, linear_ir.cend(), consumer_expr);
OPENVINO_ASSERT(insertion_pos != linear_ir.cend(), "Consumer should be after data producer in Linear IR");
linear_ir.insert(insertion_pos, load_expr);
linear_ir.insert(linear_ir.find_after(data_expr_it, consumer_expr), load_expr);
linear_ir.replace_input(consumer_input, load_expr->get_output_port_connector(0));
// Copy Loop identifies
load_expr->set_loop_ids(loop_ids);
Expand Down Expand Up @@ -100,9 +98,7 @@ bool InsertLoadStore::insert_store(LinearIR& linear_ir, const LinearIR::constExp
const auto store = std::make_shared<op::Store>(parent->output(port), get_count(data_expr->get_input_port_descriptor(0)));
PortDescriptorUtils::set_port_descriptor_ptr(store->output(0), parent_output.get_descriptor_ptr()->clone());
const auto store_expr = linear_ir.create_expression(store, {input_connector});
const auto& reverse_insertion_pos = std::find(std::reverse_iterator<LinearIR::constExprIt>(data_expr_it), linear_ir.crend(), parent_expr);
OPENVINO_ASSERT(reverse_insertion_pos != linear_ir.crend(), "Consumer should be after data producer in Linear IR");
const auto& insertion_pos = reverse_insertion_pos.base();
const auto& insertion_pos = linear_ir.find_after(std::reverse_iterator<LinearIR::constExprIt>(data_expr_it), parent_expr).base();
linear_ir.insert(insertion_pos, store_expr);
linear_ir.replace_input(data_expr->get_input_port(0), store_expr->get_output_port_connector(0));
// Copy Loop identifies
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,7 @@ bool LoadMoveBroadcastToBroadcastLoad::run(LinearIR& linear_ir) {
const auto mv_expr_it = expr_it;
const auto insertion_pos = std::next(expr_it);
expr_it = linear_ir.insert(insertion_pos, broadcastload_expr);
const auto load_it = std::find(linear_ir.begin(), mv_expr_it, parent_expr);
OPENVINO_ASSERT(load_it != mv_expr_it, "Failed fuse Load and MoveBroadcast: Load should be before MoveBroadcast in Linear IR");
linear_ir.erase(load_it);
linear_ir.erase(linear_ir.find_before(mv_expr_it, parent_expr));
linear_ir.erase(mv_expr_it);
linear_ir.replace_input(move_consumers, broadcastload_expr->get_output_port_connector(0));
modified |= true;
Expand Down

0 comments on commit ca05b08

Please sign in to comment.