Skip to content

Commit

Permalink
Alexandra's comments applied: 3rd part
Browse files Browse the repository at this point in the history
  • Loading branch information
v-Golubev committed Jan 25, 2024
1 parent c175d24 commit 9e963bc
Show file tree
Hide file tree
Showing 8 changed files with 34 additions and 22 deletions.
4 changes: 2 additions & 2 deletions src/common/snippets/include/snippets/lowered/expression.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ namespace lowered {

class LinearIR;
using ExpressionPtr = std::shared_ptr<Expression>;
using ExressionMap = std::unordered_map<Expression*, ExpressionPtr>;
using ExpressionMap = std::unordered_map<Expression*, ExpressionPtr>;
class Expression : public std::enable_shared_from_this<Expression> {
friend class LinearIR;
friend class ExpressionPort;
Expand Down Expand Up @@ -63,7 +63,7 @@ class Expression : public std::enable_shared_from_this<Expression> {
void set_loop_ids(const std::vector<size_t>& loops);
virtual ExpressionPtr clone_with_new_inputs(const std::vector<PortConnectorPtr>& new_inputs,
const std::shared_ptr<Node>& new_node) const;
ExpressionPtr clone_with_new_inputs(const ExressionMap& expr_map, const std::shared_ptr<Node>& new_node) const;
ExpressionPtr clone_with_new_inputs(const ExpressionMap& expr_map, const std::shared_ptr<Node>& new_node) const;

protected:
Expression(const Expression& other);
Expand Down
2 changes: 1 addition & 1 deletion src/common/snippets/include/snippets/lowered/linear_ir.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class LinearIR {
std::shared_ptr<LinearIR> clone() const;
static LinearIR::container deep_copy_range(LinearIR::container::const_iterator begin,
LinearIR::container::const_iterator end,
ExressionMap& expression_map);
ExpressionMap& expression_map);

const container& get_ops() const { return m_expressions; }
const io_container& get_IO_ops() const { return m_io_expressions; }
Expand Down
4 changes: 2 additions & 2 deletions src/common/snippets/include/snippets/lowered/loop_manager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ class LinearIR::LoopManager {
const std::vector<ExpressionPort>& exits,
const SpecificIterationHandlers& handlers = SpecificIterationHandlers());

std::shared_ptr<LoopInfo> clone_with_new_expr(const ExressionMap& expr_map) const;
std::shared_ptr<LoopInfo> clone_with_new_expr(const ExpressionMap& expr_map) const;

// Returns dimension index if dimension indices for all entry and exit points are equal, and UNDEFINED_DIM_IDX otherwise
size_t get_dim_idx() const;
Expand Down Expand Up @@ -125,7 +125,7 @@ class LinearIR::LoopManager {
};
using LoopInfoPtr = std::shared_ptr<LoopInfo>;

std::shared_ptr<LoopManager> clone_with_new_expr(const ExressionMap& expr_map) const;
std::shared_ptr<LoopManager> clone_with_new_expr(const ExpressionMap& expr_map) const;
size_t add_loop_info(const LoopInfoPtr& loop);
void remove_loop_info(size_t index);
LoopInfoPtr get_loop_info(size_t index) const;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,16 @@ class InsertSpecificIterations : public RangedPass {
OPENVINO_RTTI("InsertSpecificIterations", "RangedPass")
bool run(LinearIR& linear_ir, lowered::LinearIR::constExprIt begin, lowered::LinearIR::constExprIt end) override;

static LinearIR::container copy_loop(const LinearIR& linear_ir, const size_t loop_id);
/**
* @brief Makes a copy of a loop body with id 'loop_id' and inserts it to the LinearIR before the 'insert_pos' position
* @param linear_ir LinearIR which should be modified
* @param loop_id id of the loop which should be copied
* @param insert_pos position before which the loop body copy should be inserted
* @return iterator which points on the LoopBegin copy
*/
static LinearIR::constExprIt insert_copy_loop(LinearIR& linear_ir,
const size_t loop_id,
const LinearIR::constExprIt& insert_pos);
};

} // namespace pass
Expand Down
2 changes: 1 addition & 1 deletion src/common/snippets/src/lowered/expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ ExpressionPtr Expression::clone_with_new_inputs(const std::vector<PortConnectorP
return expr;
}

ExpressionPtr Expression::clone_with_new_inputs(const ExressionMap& expr_map,
ExpressionPtr Expression::clone_with_new_inputs(const ExpressionMap& expr_map,
const std::shared_ptr<Node>& new_node) const {
std::vector<PortConnectorPtr> new_inputs;
new_inputs.reserve(m_input_port_connectors.size());
Expand Down
4 changes: 2 additions & 2 deletions src/common/snippets/src/lowered/linear_ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ std::shared_ptr<LinearIR> LinearIR::clone() const {
auto cloned = std::make_shared<LinearIR>();
cloned->m_config = m_config;

ExressionMap expression_map;
ExpressionMap expression_map;
cloned->m_expressions = deep_copy_range(m_expressions.cbegin(), m_expressions.cend(), expression_map);
for (const auto& expr : cloned->m_expressions) {
cloned->m_node2expression_map[expr->get_node()] = expr;
Expand Down Expand Up @@ -106,7 +106,7 @@ ov::NodeVector LinearIR::get_ordered_ops(const std::shared_ptr<ov::Model>& m) {

LinearIR::container LinearIR::deep_copy_range(LinearIR::container::const_iterator begin,
LinearIR::container::const_iterator end,
ExressionMap& expression_map) {
ExpressionMap& expression_map) {
OPENVINO_ASSERT(expression_map.empty(), "deep_copy_range expects empty expression_map as an input");
LinearIR::container result;
NodeVector original_nodes;
Expand Down
4 changes: 2 additions & 2 deletions src/common/snippets/src/lowered/loop_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ LoopInfo::LoopInfo(size_t work_amount,
m_exit_points.emplace_back(port);
}

std::shared_ptr<LoopInfo> LoopInfo::clone_with_new_expr(const ExressionMap& expr_map) const {
std::shared_ptr<LoopInfo> LoopInfo::clone_with_new_expr(const ExpressionMap& expr_map) const {
auto clone_loop_ports = [&expr_map](const std::vector<LoopPort>& port_points) {
std::vector<LoopPort> cloned_port_points;
cloned_port_points.reserve(port_points.size());
Expand Down Expand Up @@ -197,7 +197,7 @@ bool operator<(const LinearIR::LoopManager::LoopPort& lhs, const LinearIR::LoopM
(lhs.is_incremented == rhs.is_incremented && lhs.dim_idx < rhs.dim_idx)));
}

std::shared_ptr<LoopManager> LoopManager::clone_with_new_expr(const ExressionMap& expr_map) const {
std::shared_ptr<LoopManager> LoopManager::clone_with_new_expr(const ExpressionMap& expr_map) const {
auto new_loop_manager = std::make_shared<LoopManager>();
for (const auto& id_info : m_map)
new_loop_manager->m_map.insert({id_info.first, id_info.second->clone_with_new_expr(expr_map)});
Expand Down
25 changes: 14 additions & 11 deletions src/common/snippets/src/lowered/pass/insert_specific_iterations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,14 @@ namespace snippets {
namespace lowered {
namespace pass {

LinearIR::container InsertSpecificIterations::copy_loop(const LinearIR& linear_ir, const size_t loop_id) {
LinearIR::constExprIt InsertSpecificIterations::insert_copy_loop(LinearIR& linear_ir, const size_t loop_id, const LinearIR::constExprIt& insert_pos) {
const auto& loop_manager = linear_ir.get_loop_manager();
LinearIR::constExprIt loop_begin_pos, loop_end_pos;
loop_manager->get_loop_bounds(linear_ir, loop_id, loop_begin_pos, loop_end_pos, true);
ExressionMap expression_map;
ExpressionMap expression_map;
const auto& loop_copy_range = LinearIR::deep_copy_range(loop_begin_pos, std::next(loop_end_pos), expression_map);
const auto new_loop_begin_pos = linear_ir.insert(insert_pos, loop_copy_range.begin(), loop_copy_range.end());
const auto new_loop_end_pos = insert_pos;

const auto original_loop_info = loop_manager->get_loop_info(loop_id);
std::vector<LinearIR::LoopManager::LoopPort> new_entry_points, new_exit_points;
Expand All @@ -44,8 +46,6 @@ LinearIR::container InsertSpecificIterations::copy_loop(const LinearIR& linear_i
loop_manager->update_loops_port(outer_loop_ids, expr->get_output_port(i), {expr->get_output_port(i), new_expr->get_output_port(i)}, false);
}

const auto new_loop_begin_pos = loop_copy_range.begin();
const auto new_loop_end_pos = loop_copy_range.end();
const auto new_id = loop_manager->replace_with_new_loop(linear_ir,
std::next(new_loop_begin_pos),
std::prev(new_loop_end_pos),
Expand All @@ -57,7 +57,7 @@ LinearIR::container InsertSpecificIterations::copy_loop(const LinearIR& linear_i
const auto loop_end = ov::as_type_ptr<op::LoopEnd>(std::prev(new_loop_end_pos)->get()->get_node());
OPENVINO_ASSERT(loop_end, "Cloned Loop does not contain LoopEnd op at the expected place.");
loop_end->set_id(new_id);
return loop_copy_range;
return new_loop_begin_pos;
}

using LoopInfo = LinearIR::LoopManager::LoopInfo;
Expand Down Expand Up @@ -100,13 +100,16 @@ bool InsertSpecificIterations::run(LinearIR& linear_ir, lowered::LinearIR::const
};

auto copy_and_run_specific_handlers = [&](const PassPipeline& handlers) {
const auto& cloned_body = copy_loop(linear_ir, loop_end->get_id());
lowered::LinearIR::constExprIt start = linear_ir.insert(main_loop_begin_it, cloned_body.begin(), cloned_body.end());
const auto cloned_loop_end = *std::prev(cloned_body.end());
auto end = linear_ir.find_after(start, cloned_loop_end);
const auto new_loop_begin_pos = insert_copy_loop(linear_ir, loop_end->get_id(), main_loop_begin_it);
const auto new_loop_begin = ov::as_type_ptr<op::LoopBegin>(new_loop_begin_pos->get()->get_node());
OPENVINO_ASSERT(new_loop_begin, "Cloned Loop does not contain LoopBegin op at the expected place.");
const auto new_loop_end = new_loop_begin->get_loop_end();
const auto new_loop_end_pos = linear_ir.find_after(new_loop_begin_pos, linear_ir.get_expr_by_node(new_loop_end));
OPENVINO_ASSERT(new_loop_end, "Cloned Loop does not contain LoopEnd op at the expected place.");

// Note: handlers must be run on the range started with the first operation in the loop body.
handlers.run(linear_ir, std::next(start), end);
return ov::as_type_ptr<op::LoopEnd>(cloned_loop_end->get_node());
handlers.run(linear_ir, std::next(new_loop_begin_pos), new_loop_end_pos);
return new_loop_end;
};

const bool specific_first_iteration = !handlers.get_first_iter_handelrs().empty();
Expand Down

0 comments on commit 9e963bc

Please sign in to comment.