Skip to content

Commit

Permalink
[Snippets] Applied Ivan comments 2
Browse files Browse the repository at this point in the history
  • Loading branch information
a-sidorova committed Aug 19, 2024
1 parent b0d06e4 commit 2336757
Show file tree
Hide file tree
Showing 10 changed files with 225 additions and 137 deletions.
35 changes: 32 additions & 3 deletions src/common/snippets/include/snippets/lowered/loop_info.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,17 @@ namespace snippets {
namespace lowered {

class LoopInfo;
using LoopInfoMap = std::map<LoopInfo*, std::shared_ptr<LoopInfo>>;
using LoopInfoMap = std::unordered_map<const LoopInfo*, std::shared_ptr<LoopInfo>>;
using LoopInfoSet = std::unordered_set<const LoopInfo*>;
using LoopInfoPtr = std::shared_ptr<LoopInfo>;

/**
* @interface LoopInfo
* @brief The base class that contains the common information about a Loop in Linear Intermediate Representation (Linear IR):
* work amount of the Loop, step of loop counter increment, input and output ports of the Loop.
* @ingroup snippets
*/
class LoopInfo {
class LoopInfo : public std::enable_shared_from_this<LoopInfo> {
public:
enum {UNDEFINED_DIM_IDX = std::numeric_limits<size_t>::max()};

Expand All @@ -38,6 +40,13 @@ class LoopInfo {
*/
virtual std::shared_ptr<LoopInfo> clone_with_new_expr(const ExpressionMap& expr_map, LoopInfoMap& loop_map) const = 0;

/**
* @brief Apply the passed function to the current LoopInfo
* @param func function for applying
* @param applied_loops set of already updated loops
*/
virtual void apply(const std::function<void(const LoopInfoPtr&)>& func, LoopInfoSet& applied_loops) = 0;

/**
* @brief Check if some parameters of Loop are dynamic (undefined)
* @return True if some parameters of Loop are unknown, False if all parameters are static
Expand Down Expand Up @@ -182,7 +191,6 @@ class LoopInfo {
std::vector<LoopPort> m_input_ports = {};
std::vector<LoopPort> m_output_ports = {};
};
using LoopInfoPtr = std::shared_ptr<LoopInfo>;

/**
* @interface UnifiedLoopInfo
Expand Down Expand Up @@ -232,6 +240,13 @@ class UnifiedLoopInfo : public LoopInfo {
*/
std::shared_ptr<LoopInfo> clone_with_new_expr(const ExpressionMap& expr_map, LoopInfoMap& loop_map) const override;

/**
* @brief Apply the passed function on the current LoopInfo.
* @param func function for applying
* @param applied_loops set of already updated loops
*/
void apply(const std::function<void(const LoopInfoPtr&)>& func, LoopInfoSet& applied_loops) override;

/**
* @brief Check if some parameters of Loop are dynamic (undefined)
* @return True if some parameters of Loop are unknown, False if all parameters are static
Expand Down Expand Up @@ -392,6 +407,13 @@ class InnerSplittedUnifiedLoopInfo : public UnifiedLoopInfo {
*/
std::shared_ptr<LoopInfo> clone_with_new_expr(const ExpressionMap& expr_map, LoopInfoMap& loop_map) const override;

/**
* @brief Apply the passed function on OuterSplittedLoopInfo and then on the current LoopInfo.
* @param func function for applying
* @param applied_loops set of already updated loops
*/
void apply(const std::function<void(const LoopInfoPtr&)>& func, LoopInfoSet& applied_loops) override;

/**
* @brief Returns work amount of the Loop.
* @return m_work_amount
Expand Down Expand Up @@ -443,6 +465,13 @@ class ExpandedLoopInfo : public LoopInfo {
*/
std::shared_ptr<LoopInfo> clone_with_new_expr(const ExpressionMap& expr_map, LoopInfoMap& loop_map) const override;

/**
* @brief Apply the passed function on UnifiedLoopInfo and then on the current LoopInfo.
* @param func function for applying
* @param applied_loops set of already updated loops
*/
void apply(const std::function<void(const LoopInfoPtr&)>& func, LoopInfoSet& applied_loops) override;

/**
* @brief Check if some parameters of Loop are dynamic (undefined)
* @return True if some parameters of Loop are unknown, False if all parameters are static
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,14 @@ namespace pass {
class SplitLoops : public RangedPass {
public:
OPENVINO_RTTI("SplitLoops", "RangedPass")
SplitLoops();
SplitLoops() = default;
bool run(LinearIR& linear_ir, lowered::LinearIR::constExprIt begin, lowered::LinearIR::constExprIt end) override;

static void split(LinearIR& linear_ir, size_t loop_to_split_id, size_t outer_increment);

private:
static bool can_be_split(const UnifiedLoopInfoPtr& current, const UnifiedLoopInfoPtr& target);

static void split(LinearIR& linear_ir, size_t loop_to_split_id, size_t outer_increment);

/**
* @interface TransformInnerSplitLoop
* @brief The pass replace existing inner splitted LoopInfo with new InnerSplittedUnifiedLoopInfo and
Expand Down
70 changes: 48 additions & 22 deletions src/common/snippets/src/lowered/loop_info.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -201,11 +201,21 @@ UnifiedLoopInfo::UnifiedLoopInfo(size_t work_amount, size_t increment,
}

std::shared_ptr<LoopInfo> UnifiedLoopInfo::clone_with_new_expr(const ExpressionMap& expr_map, LoopInfoMap& loop_map) const {
const auto& new_input_ports = clone_loop_ports(expr_map, m_input_ports);
const auto& new_output_ports = clone_loop_ports(expr_map, m_output_ports);
if (loop_map.count(this) == 0) {
const auto& new_input_ports = clone_loop_ports(expr_map, m_input_ports);
const auto& new_output_ports = clone_loop_ports(expr_map, m_output_ports);

return std::make_shared<UnifiedLoopInfo>(m_work_amount, m_increment, new_input_ports, new_output_ports,
m_input_port_descs, m_output_port_descs, m_handlers);
loop_map[this] = std::make_shared<UnifiedLoopInfo>(m_work_amount, m_increment, new_input_ports, new_output_ports,
m_input_port_descs, m_output_port_descs, m_handlers);
}
return loop_map.at(this);
}

void UnifiedLoopInfo::apply(const std::function<void(const LoopInfoPtr&)>& func, LoopInfoSet& applied_loops) {
if (applied_loops.count(this) == 0) {
func(this->shared_from_this());
applied_loops.insert(this);
}
}

bool UnifiedLoopInfo::is_dynamic() const {
Expand Down Expand Up @@ -364,16 +374,24 @@ InnerSplittedUnifiedLoopInfo::InnerSplittedUnifiedLoopInfo(size_t increment, con
}

std::shared_ptr<LoopInfo> InnerSplittedUnifiedLoopInfo::clone_with_new_expr(const ExpressionMap& expr_map, LoopInfoMap& loop_map) const {
if (loop_map.count(m_outer_splitted_loop_info.get()) == 0)
loop_map[m_outer_splitted_loop_info.get()] = m_outer_splitted_loop_info->clone_with_new_expr(expr_map, loop_map);

const auto cloned_outer_splitted_loop_info = loop_map.at(m_outer_splitted_loop_info.get());
const auto& new_input_ports = clone_loop_ports(expr_map, m_input_ports);
const auto& new_output_ports = clone_loop_ports(expr_map, m_output_ports);
if (loop_map.count(this) == 0) {
auto cloned_outer_splitted_loop_info = m_outer_splitted_loop_info->clone_with_new_expr(expr_map, loop_map);
const auto& new_input_ports = clone_loop_ports(expr_map, m_input_ports);
const auto& new_output_ports = clone_loop_ports(expr_map, m_output_ports);

loop_map[this] = std::make_shared<InnerSplittedUnifiedLoopInfo>(m_increment, new_input_ports, new_output_ports,
m_input_port_descs, m_output_port_descs, m_handlers,
std::move(cloned_outer_splitted_loop_info));
}
return loop_map.at(this);
}

return std::make_shared<InnerSplittedUnifiedLoopInfo>(m_increment, new_input_ports, new_output_ports,
m_input_port_descs, m_output_port_descs, m_handlers,
std::move(cloned_outer_splitted_loop_info));
void InnerSplittedUnifiedLoopInfo::apply(const std::function<void(const LoopInfoPtr&)>& func, LoopInfoSet& applied_loops) {
if (applied_loops.count(this) == 0) {
m_outer_splitted_loop_info->apply(func, applied_loops);
func(this->shared_from_this());
applied_loops.insert(this);
}
}

size_t InnerSplittedUnifiedLoopInfo::get_work_amount() const {
Expand Down Expand Up @@ -406,16 +424,24 @@ ExpandedLoopInfo::ExpandedLoopInfo(size_t work_amount, size_t increment,
}

std::shared_ptr<LoopInfo> ExpandedLoopInfo::clone_with_new_expr(const ExpressionMap& expr_map, LoopInfoMap& loop_map) const {
if (loop_map.count(m_unified_loop_info.get()) == 0)
loop_map[m_unified_loop_info.get()] = m_unified_loop_info->clone_with_new_expr(expr_map, loop_map);

const auto cloned_unified_loop_info = ov::as_type_ptr<UnifiedLoopInfo>(loop_map.at(m_unified_loop_info.get()));
const auto& new_input_ports = clone_loop_ports(expr_map, m_input_ports);
const auto& new_output_ports = clone_loop_ports(expr_map, m_output_ports);
if (loop_map.count(this) == 0) {
auto cloned_unified_loop_info = ov::as_type_ptr<UnifiedLoopInfo>(m_unified_loop_info->clone_with_new_expr(expr_map, loop_map));
const auto& new_input_ports = clone_loop_ports(expr_map, m_input_ports);
const auto& new_output_ports = clone_loop_ports(expr_map, m_output_ports);

loop_map[this] = std::make_shared<ExpandedLoopInfo>(m_work_amount, m_increment, new_input_ports, new_output_ports,
m_ptr_increments, m_finalization_offsets, m_data_sizes, m_type,
std::move(cloned_unified_loop_info), m_evaluate_once);
}
return loop_map.at(this);
}

return std::make_shared<ExpandedLoopInfo>(m_work_amount, m_increment, new_input_ports, new_output_ports,
m_ptr_increments, m_finalization_offsets, m_data_sizes, m_type,
std::move(cloned_unified_loop_info), m_evaluate_once);
void ExpandedLoopInfo::apply(const std::function<void(const LoopInfoPtr&)>& func, LoopInfoSet& applied_loops) {
if (applied_loops.count(this) == 0) {
m_unified_loop_info->apply(func, applied_loops);
func(this->shared_from_this());
applied_loops.insert(this);
}
}

bool ExpandedLoopInfo::is_dynamic() const {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ size_t ComputeBufferAllocationSize::get_allocation_size(const LoopManagerPtr& lo
auto it = std::find_if(output_ports.begin(), output_ports.end(), hard_equal);
// [149219] : Try to find original loop port if this LoopInfo is cloned after InsertSpecificIterations
// and ports are not mapped on the original ExpressionPorts
// Note: this check is needed only in Splitted Loops
if (it == output_ports.end()) {
it = std::find_if(output_ports.begin(), output_ports.end(), soft_equal);
OPENVINO_ASSERT(it != output_ports.end(), "compute_allocation_shape: output port of parent loop can not be found");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ std::vector<size_t> get_reordered_loop_ids(const LoopManagerPtr& loop_manager) {
auto sorter = [&](size_t lhs, size_t rhs) {
const auto lhs_last_expr = loop_manager->get_loop_info(lhs)->get_output_ports().back().expr_port->get_expr();
const auto rhs_last_expr = loop_manager->get_loop_info(rhs)->get_output_ports().back().expr_port->get_expr();
// If LoopEnd is the same expression - first executive Loop has inner ID in expression loop IDs.
// If last output loop ports are the same expressions - first executive Loop has inner ID in expression loop IDs.
if (lhs_last_expr == rhs_last_expr) {
for (const auto& id : lhs_last_expr->get_loop_ids()) {
if (id == lhs) return false;
Expand Down
2 changes: 0 additions & 2 deletions src/common/snippets/src/lowered/pass/split_loops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@ namespace snippets {
namespace lowered {
namespace pass {

SplitLoops::SplitLoops() : RangedPass() {}

bool SplitLoops::can_be_split(const UnifiedLoopInfoPtr& loop_to_split, const UnifiedLoopInfoPtr& loop_to_fuse) {
OPENVINO_ASSERT(loop_to_split != nullptr && loop_to_fuse != nullptr, "LoopInfo is nullptr!");
const auto current_dim_idx = loop_to_split->get_dim_idx();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ SpecificIterationHandlers::SpecificIterationHandlers(size_t loop_work_amount, si
if (!utils::is_dynamic_value(loop_work_amount)) {
last_iter_increment = loop_work_amount % loop_increment;
} else if (utils::is_dynamic_value(loop_work_amount) && processing_dim_idx == 0) {
// Last Iterations of Loop processed last dimensions with Eltwise nodes inside should have increment = 1
// [149935] : Last Iterations of Loop processed last dimensions with Eltwise nodes inside should have increment = 1
last_iter_increment = 1;
}
if (last_iter_increment != 0) {
Expand Down
39 changes: 29 additions & 10 deletions src/common/snippets/src/runtime_configurator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -185,17 +185,19 @@ void RuntimeConfigurator::init_buffer_info(const lowered::LinearIRCPtr& linear_i

void RuntimeConfigurator::update_loop_info(const lowered::LinearIRCPtr& linear_ir,
LoopInfoRuntimeParamsMap& initializated_info_map) const {
const auto& loop_manager = linear_ir->get_loop_manager();
for (const auto& loop_id : m_ordered_loop_ids) {
const auto& expanded_loop_info = loop_manager->get_loop_info<lowered::ExpandedLoopInfo>(loop_id);
OPENVINO_ASSERT(expanded_loop_info, "UpdateLoopInfo expects ExpandedLoopInfo in LoopManager");
lowered::LoopInfoSet updated_loops;
std::function<void(const lowered::LoopInfoPtr& loop_info)> update_loop_info;

// First visiting of unified (whole) loop
const auto& current_unified_loop_info = expanded_loop_info->get_unified_loop_info();
if (initializated_info_map.count(current_unified_loop_info) == 0) {
lowered::pass::InitLoops::update_runtime_parameters(current_unified_loop_info);
initializated_info_map[current_unified_loop_info] = compute_runtime_params(current_unified_loop_info);
auto update_unified_loop_info = [&](const lowered::UnifiedLoopInfoPtr& unified_loop_info) {
if (initializated_info_map.count(unified_loop_info) == 0) {
lowered::pass::InitLoops::update_runtime_parameters(unified_loop_info);
initializated_info_map[unified_loop_info] = compute_runtime_params(unified_loop_info);
}
};

auto update_expanded_loop_info = [&](const lowered::ExpandedLoopInfoPtr& expanded_loop_info) {
const auto& current_unified_loop_info = expanded_loop_info->get_unified_loop_info();
current_unified_loop_info->apply(update_loop_info, updated_loops);

auto& initializated_info = initializated_info_map.at(current_unified_loop_info);
auto& current_work_amount = initializated_info.work_amount;
Expand All @@ -209,7 +211,7 @@ void RuntimeConfigurator::update_loop_info(const lowered::LinearIRCPtr& linear_i
expanded_loop_info->set_work_amount(0);
if (expanded_loop_info->is_evaluate_once())
expanded_loop_info->set_increment(0);
continue;
return;
}

const auto work_amount =
Expand All @@ -229,6 +231,23 @@ void RuntimeConfigurator::update_loop_info(const lowered::LinearIRCPtr& linear_i
expanded_loop_info->update_ptr_increments(ptr_increments);
}
expanded_loop_info->update_finalization_offsets(updated_finalization_offsets);
};

update_loop_info = [&](const lowered::LoopInfoPtr& loop_info) {
if (const auto unified_loop_info = ov::as_type_ptr<lowered::UnifiedLoopInfo>(loop_info)) {
update_unified_loop_info(unified_loop_info);
} else if (const auto expanded_loop_info = ov::as_type_ptr<lowered::ExpandedLoopInfo>(loop_info)) {
update_expanded_loop_info(expanded_loop_info);
} else {
OPENVINO_THROW("Failed to update loop info: unknown type!");
}
};

const auto& loop_map = linear_ir->get_loop_manager()->get_map();
for (const auto& p : loop_map) {
const auto& expanded_loop_info = ov::as_type_ptr<lowered::ExpandedLoopInfo>(p.second);
OPENVINO_ASSERT(expanded_loop_info, "UpdateLoopInfo expects ExpandedLoopInfo in LoopManager");
expanded_loop_info->apply(update_loop_info, updated_loops);
}
}

Expand Down
Loading

0 comments on commit 2336757

Please sign in to comment.