-
Notifications
You must be signed in to change notification settings - Fork 2.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Snippets] Assign registers and ABI call optimizations
- Loading branch information
1 parent
357eb54
commit 890beb9
Showing
50 changed files
with
1,858 additions
and
672 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
33 changes: 33 additions & 0 deletions
33
src/common/snippets/include/snippets/lowered/pass/init_live_ranges.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
// Copyright (C) 2023 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#pragma once | ||
|
||
#include "pass.hpp" | ||
#include "snippets/generator.hpp" | ||
#include "snippets/lowered/reg_manager.hpp" | ||
|
||
namespace ov { | ||
namespace snippets { | ||
namespace lowered { | ||
namespace pass { | ||
|
||
/** | ||
* @interface InitLiveRanges | ||
* @brief Calculates live ranges of registers. This information will be used to assign registers and optimize ABI reg spills. | ||
* @ingroup snippets | ||
*/ | ||
class InitLiveRanges : public Pass { | ||
public: | ||
OPENVINO_RTTI("InitLiveRanges", "Pass") | ||
explicit InitLiveRanges(RegManager& reg_manager) : m_reg_manager(reg_manager) {} | ||
bool run(LinearIR& linear_ir) override; | ||
private: | ||
RegManager& m_reg_manager; | ||
}; | ||
|
||
} // namespace pass | ||
} // namespace lowered | ||
} // namespace snippets | ||
} // namespace ov |
32 changes: 32 additions & 0 deletions
32
src/common/snippets/include/snippets/lowered/pass/insert_reg_spills.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
// Copyright (C) 2023 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#pragma once | ||
|
||
#include "pass.hpp" | ||
#include "snippets/lowered/reg_manager.hpp" | ||
|
||
namespace ov { | ||
namespace snippets { | ||
namespace lowered { | ||
namespace pass { | ||
|
||
/** | ||
* @interface InsertRegSpills | ||
* @brief Insert RegSpill and RegRestore operations for binary call emitters to comply with ABI conventions. | ||
* @ingroup snippets | ||
*/ | ||
class InsertRegSpills : public Pass { | ||
public: | ||
OPENVINO_RTTI("InsertRegSpills", "Pass") | ||
explicit InsertRegSpills(RegManager& reg_manager) : m_reg_manager(reg_manager) {} | ||
bool run(LinearIR& linear_ir) override; | ||
|
||
RegManager& m_reg_manager; | ||
}; | ||
|
||
} // namespace pass | ||
} // namespace lowered | ||
} // namespace snippets | ||
} // namespace ov |
67 changes: 67 additions & 0 deletions
67
src/common/snippets/include/snippets/lowered/reg_manager.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
// Copyright (C) 2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#pragma once | ||
#include "openvino/core/node.hpp" | ||
#include "snippets/emitter.hpp" | ||
#include "snippets/lowered/expression.hpp" | ||
#include "snippets/generator.hpp" | ||
#include "snippets/op/kernel.hpp" | ||
|
||
/** | ||
* @interface RegManager | ||
* @brief The class holds supplementary info about assigned registers and live ranges | ||
* @ingroup snippets | ||
*/ | ||
namespace ov { | ||
namespace snippets { | ||
namespace lowered { | ||
|
||
using RegTypeMapper = std::function<RegType(const ov::Output<Node>& out)>; | ||
using LiveInterval = std::pair<double, double>; | ||
class RegManager { | ||
public: | ||
RegManager() = delete; | ||
RegManager(const std::shared_ptr<const Generator>& generator) : m_generator(generator) {} | ||
inline RegType get_reg_type(const ov::Output<Node>& out) const { return m_generator->get_op_out_reg_type(out); } | ||
inline std::vector<Reg> get_vec_reg_pool() const { return m_generator->get_target_machine()->get_vec_reg_pool(); } | ||
|
||
inline void set_live_range(const Reg& reg, const LiveInterval& interval, bool force = false) { | ||
OPENVINO_ASSERT(force || m_reg_live_range.count(reg) == 0, "Live range for this reg is already set"); | ||
m_reg_live_range[reg] = interval; | ||
} | ||
|
||
inline std::vector<Reg> get_kernel_call_regs(const std::shared_ptr<snippets::op::Kernel>& kernel) const { | ||
const auto& abi_regs = m_generator->get_target_machine()->get_abi_arg_regs(); | ||
const auto num_kernel_args = kernel->get_num_call_args(); | ||
OPENVINO_ASSERT(abi_regs.size() > num_kernel_args, "Too many kernel args requested"); | ||
return {abi_regs.begin(), abi_regs.begin() + static_cast<int64_t>(num_kernel_args)}; | ||
} | ||
|
||
inline std::vector<Reg> get_gp_regs_except_kernel_call(const std::shared_ptr<snippets::op::Kernel>& kernel) const { | ||
auto res = m_generator->get_target_machine()->get_gp_reg_pool(); | ||
std::set<Reg> kernel_call; | ||
for (auto r : get_kernel_call_regs(kernel)) | ||
kernel_call.insert(r); | ||
res.erase(std::remove_if(res.begin(), res.end(), [&kernel_call](const Reg& r) {return kernel_call.count(r) != 0; }), res.end()); | ||
return res; | ||
} | ||
|
||
inline const LiveInterval& get_live_range(const Reg& reg) { | ||
OPENVINO_ASSERT(m_reg_live_range.count(reg), "Live range for this reg was not set"); | ||
return m_reg_live_range[reg]; | ||
} | ||
inline std::map<Reg, LiveInterval> get_live_range_map() const { | ||
return m_reg_live_range; | ||
} | ||
|
||
private: | ||
// Maps Register to {Start, Stop} pairs | ||
std::map<Reg, LiveInterval> m_reg_live_range; | ||
const std::shared_ptr<const Generator> m_generator; | ||
}; | ||
|
||
} // namespace lowered | ||
} // namespace snippets | ||
} // namespace ov |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
// Copyright (C) 2018-2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#pragma once | ||
|
||
#include "snippets/emitter.hpp" | ||
|
||
#include "openvino/op/op.hpp" | ||
#include "snippets/shape_inference/shape_inference.hpp" | ||
|
||
namespace ov { | ||
namespace snippets { | ||
namespace op { | ||
|
||
/** | ||
* @interface RegSpillBase | ||
* @brief Base class for RegSpillBegin and RegSpillEnd ops | ||
* @ingroup snippets | ||
*/ | ||
class RegSpillBase : public ov::op::Op { | ||
public: | ||
OPENVINO_OP("RegSpillBaseBase", "SnippetsOpset"); | ||
RegSpillBase(const std::vector<Output<Node>>& args); | ||
RegSpillBase() = default; | ||
virtual std::set<Reg> get_regs_to_spill() const = 0; | ||
bool visit_attributes(AttributeVisitor& visitor) override; | ||
protected: | ||
}; | ||
class RegSpillEnd; | ||
/** | ||
* @interface RegSpillBegin | ||
* @brief Marks the start of the register spill region. | ||
* @ingroup snippets | ||
*/ | ||
class RegSpillBegin : public RegSpillBase { | ||
public: | ||
OPENVINO_OP("RegSpillBegin", "SnippetsOpset", RegSpillBase); | ||
RegSpillBegin(std::set<Reg> regs_to_spill); | ||
|
||
void validate_and_infer_types() override; | ||
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& inputs) const override; | ||
std::shared_ptr<RegSpillEnd> get_reg_spill_end() const; | ||
std::set<Reg> get_regs_to_spill() const override { return m_regs_to_spill; } | ||
|
||
class ShapeInfer : public IShapeInferSnippets { | ||
size_t num_out_shapes = 0; | ||
public: | ||
explicit ShapeInfer(const std::shared_ptr<ov::Node>& n); | ||
Result infer(const std::vector<VectorDimsRef>& input_shapes) override; | ||
}; | ||
protected: | ||
void validate_and_infer_types_except_RegSpillEnd(); | ||
std::set<Reg> m_regs_to_spill = {}; | ||
}; | ||
/** | ||
* @interface RegSpillEnd | ||
* @brief Marks the end of the register spill region. | ||
* @ingroup snippets | ||
*/ | ||
class RegSpillEnd : public RegSpillBase { | ||
public: | ||
OPENVINO_OP("RegSpillEnd", "SnippetsOpset", RegSpillBase); | ||
RegSpillEnd() = default; | ||
RegSpillEnd(const Output<Node>& reg_spill_begin); | ||
|
||
void validate_and_infer_types() override; | ||
|
||
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& inputs) const override; | ||
std::shared_ptr<RegSpillBegin> get_reg_spill_begin() const { | ||
auto reg_spill_begin = ov::as_type_ptr<RegSpillBegin>(get_input_node_shared_ptr(0)); | ||
OPENVINO_ASSERT(reg_spill_begin, "Can't get reg_spill_begin from reg_spill_end"); | ||
return reg_spill_begin; | ||
} | ||
std::set<Reg> get_regs_to_spill() const override { | ||
return get_reg_spill_begin()->get_regs_to_spill(); | ||
} | ||
}; | ||
|
||
} // namespace op | ||
} // namespace snippets | ||
} // namespace ov |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.