Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Snippets] Added PassConfig and PositionedPasses for PassPipeline #21382

18 changes: 16 additions & 2 deletions src/common/snippets/include/snippets/lowered/pass/pass.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
#pragma once

#include "snippets/lowered/linear_ir.hpp"

#include "snippets/lowered/pass/pass_config.hpp"
#include "snippets/pass/positioned_pass.hpp"
#include "openvino/core/rtti.hpp"
#include "openvino/core/type.hpp"

Expand Down Expand Up @@ -49,8 +50,12 @@ class Pass {

class PassPipeline {
public:
PassPipeline() = default;
using PositionedPassLowered = snippets::pass::PositionedPass<lowered::pass::Pass>;

PassPipeline();
PassPipeline(const std::shared_ptr<PassConfig>& pass_config);

void register_pass(const snippets::pass::PassPosition& position, const std::shared_ptr<Pass>& pass);
void register_pass(const std::shared_ptr<Pass>& pass);

template<typename T, class... Args>
Expand All @@ -59,10 +64,19 @@ class PassPipeline {
auto pass = std::make_shared<T>(std::forward<Args>(args)...);
register_pass(pass);
}
template<typename T, class Pos, class... Args, std::enable_if<std::is_same<snippets::pass::PassPosition, Pos>::value, bool>() = true>
void register_pass(const snippets::pass::PassPosition& position, Args&&... args) {
static_assert(std::is_base_of<Pass, T>::value, "Pass not derived from lowered::Pass");
auto pass = std::make_shared<T>(std::forward<Args>(args)...);
register_pass(position, pass);
}

void register_positioned_passes(const std::vector<PositionedPassLowered>& pos_passes);

void run(lowered::LinearIR& linear_ir) const;

private:
std::shared_ptr<PassConfig> m_pass_config;
std::vector<std::shared_ptr<Pass>> m_passes;
};

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "snippets/lowered/linear_ir.hpp"

#include "openvino/core/rtti.hpp"
#include "openvino/core/type.hpp"

namespace ov {
namespace snippets {
namespace lowered {
namespace pass {

/**
* @interface PassConfig
* @brief Represents a transformations config that is used for disabling/enabling
* passes registered inside lowered::pass::PassPipeline
* @ingroup snippets
*/
class PassConfig {
public:
PassConfig() = default;

void disable(const DiscreteTypeInfo& type_info);
template <class T>
void disable() {
disable(T::get_type_info_static());
}

void enable(const DiscreteTypeInfo& type_info);
template <class T>
void enable() {
enable(T::get_type_info_static());
}

bool is_disabled(const DiscreteTypeInfo& type_info) const;
template <class T>
bool is_disabled() const {
return is_disabled(T::get_type_info_static());
}

bool is_enabled(const DiscreteTypeInfo& type_info) const;
template <class T>
bool is_enabled() const {
return is_enabled(T::get_type_info_static());
}

private:
std::unordered_set<DiscreteTypeInfo> m_disabled;
std::unordered_set<DiscreteTypeInfo> m_enabled;
};

} // namespace pass
} // namespace lowered
} // namespace snippets
} // namespace ov
19 changes: 10 additions & 9 deletions src/common/snippets/include/snippets/op/subgraph.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "snippets/pass/manager.hpp"
#include "snippets/shape_inference/shape_inference.hpp"
#include "snippets/lowered/pass/pass.hpp"
#include "snippets/pass/positioned_pass.hpp"

#include "snippets/generator.hpp"

Expand Down Expand Up @@ -103,16 +104,16 @@ class Subgraph : public ov::op::util::SubGraphOp {
snippets::Schedule generate(const BlockedShapeVector& blocked_input_shapes = {},
const std::vector<ov::element::Type>& input_precisions = {},
const std::vector<ov::element::Type>& output_precisions = {},
const std::vector<pass::Manager::PositionedPass>& data_flow_passes = {},
const lowered::pass::PassPipeline& control_flow_passes_pre_common = {},
const lowered::pass::PassPipeline& control_flow_passes_post_common = {},
const std::vector<snippets::pass::Manager::PositionedPassBase>& data_flow_passes = {},
const std::shared_ptr<lowered::pass::PassConfig>& lowered_pass_config = std::make_shared<lowered::pass::PassConfig>(),
const std::vector<snippets::lowered::pass::PassPipeline::PositionedPassLowered>& lowered_backend_passes = {},
size_t min_parallel_work_amount = 8, size_t min_kernel_work_amount = 256,
const std::shared_ptr<IShapeInferSnippetsFactory>& factory = nullptr,
const void* compile_params = nullptr);

snippets::Schedule generate_from_linear_ir(const lowered::pass::PassPipeline& backend_passes_pre_common = {},
const lowered::pass::PassPipeline& backend_passes_post_common = {},
const void* compile_params = nullptr) const;
Schedule generate_from_linear_ir(const std::shared_ptr<lowered::pass::PassConfig>& lowered_pass_config = std::make_shared<lowered::pass::PassConfig>(),
const std::vector<snippets::lowered::pass::PassPipeline::PositionedPassLowered>& lowered_backend_passes = {},
const void* compile_params = nullptr) const;
IShapeInferSnippets::Result shape_infer(const std::vector<VectorDimsRef>& input_shapes);

// plugin sets generator for a snippet to some specific generator.
Expand Down Expand Up @@ -140,7 +141,7 @@ class Subgraph : public ov::op::util::SubGraphOp {
void data_flow_transformations(const BlockedShapeVector& blocked_input_shapes = {},
const std::vector<ov::element::Type>& input_precisions = {},
const std::vector<ov::element::Type>& output_precisions = {},
const std::vector<snippets::pass::Manager::PositionedPass>& = {});
const std::vector<snippets::pass::Manager::PositionedPassBase>& = {});
std::shared_ptr<lowered::LinearIR>
convert_body_to_linear_ir(size_t min_parallel_work_amount = 8, size_t min_kernel_work_amount = 256,
const std::shared_ptr<IShapeInferSnippetsFactory>& shape_infer_factory = std::make_shared<IShapeInferSnippetsFactory>());
Expand All @@ -149,8 +150,8 @@ class Subgraph : public ov::op::util::SubGraphOp {
private:
void control_flow_transformations(lowered::LinearIR& linear_ir,
LoweringResult& lowering_result,
const lowered::pass::PassPipeline& backend_passes_pre_common,
const lowered::pass::PassPipeline& backend_passes_post_common) const;
const std::shared_ptr<lowered::pass::PassConfig>& lowered_pass_config = std::make_shared<lowered::pass::PassConfig>(),
const std::vector<snippets::lowered::pass::PassPipeline::PositionedPassLowered>& lowered_backend_passes = {}) const;
void init_config();
// Count of Subgraph virtual ports:
// - Potential non-scalar Constants that will be created after some transformations (At the moment it's relevant only for FakeQuantize decomposition)
Expand Down
37 changes: 4 additions & 33 deletions src/common/snippets/include/snippets/pass/manager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

#pragma once

#include "positioned_pass.hpp"

#include "openvino/pass/manager.hpp"
#include "openvino/pass/pass.hpp"
#include "openvino/pass/validate.hpp"
Expand All @@ -24,38 +26,7 @@ class Manager : public ov::pass::Manager {
~Manager() override = default;
using PassBase = ov::pass::PassBase;
using Validate = ov::pass::Validate;
/**
* @brief PassPosition describes a particular position in a transformation pipeline,
* where a new transformation should be inserted.
* @param pass_name name of the anchor pass, the new pass will be inserted before/after it.
* Empty pass_name could mean either beginning or the end of the pipeline depending on the `after` flag.
* No default value. Note that pass names namespaces are not supported, ov::PassName and snippets::PassName
* are considered identical.
* @param after `true` if the new pass should be inserted before the anchor pass, `false` otherwise (default).
* If `pass_name` is empty, `true` means the end, and `false` - the beginning of the pipeline.
* @param pass_instance the number of the pass with matching `pass_name` to be considered as the anchor pass.
* 0 (default) means the first pass with `pass_name` will be considered as the anchor pass.
* @ingroup snippets
*/
class PassPosition {
public:
enum class Place {Before, After, PipelineStart, PipelineEnd};
using PassListType = std::vector<std::shared_ptr<ov::pass::PassBase>>;
explicit PassPosition(Place pass_place);
explicit PassPosition(Place pass_place, std::string pass_name, size_t pass_instance = 0);
PassListType::const_iterator get_insert_position(const PassListType& pass_list) const;
private:
const std::string m_pass_name;
const size_t m_pass_instance{0};
const Place m_place{Place::Before};
};
struct PositionedPass {
PassPosition position;
std::shared_ptr<PassBase> pass;
PositionedPass(PassPosition arg_pos, std::shared_ptr<PassBase> arg_pass)
: position(std::move(arg_pos)), pass(std::move(arg_pass)) {
}
};
using PositionedPassBase = PositionedPass<PassBase>;

template <typename T, class... Args>
std::shared_ptr<T> register_pass(Args&&... args) {
Expand All @@ -74,7 +45,7 @@ class Manager : public ov::pass::Manager {
}

std::shared_ptr<PassBase> register_pass_instance(const PassPosition& pass_id, const std::shared_ptr<PassBase>& pass);
void register_positioned_passes(const std::vector<PositionedPass>& pos_passes);
void register_positioned_passes(const std::vector<PositionedPassBase>& pos_passes);

protected:
std::shared_ptr<Manager::PassBase> insert_pass_instance(const PassPosition& position, const std::shared_ptr<PassBase>& pass);
Expand Down
84 changes: 84 additions & 0 deletions src/common/snippets/include/snippets/pass/positioned_pass.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/pass/manager.hpp"
#include "openvino/pass/pass.hpp"
#include "openvino/pass/validate.hpp"

#include <typeinfo>


namespace ov {
namespace snippets {
namespace pass {

/**
* @brief PassPosition describes a particular position in a transformation pipeline,
* where a new transformation should be inserted.
* @param pass_name name of the anchor pass, the new pass will be inserted before/after it.
* Empty pass_name could mean either beginning or the end of the pipeline depending on the `after` flag.
* No default value. Note that pass names namespaces are not supported, ov::PassName and snippets::PassName
* are considered identical.
* @param after `true` if the new pass should be inserted before the anchor pass, `false` otherwise (default).
* If `pass_name` is empty, `true` means the end, and `false` - the beginning of the pipeline.
* @param pass_instance the number of the pass with matching `pass_name` to be considered as the anchor pass.
* 0 (default) means the first pass with `pass_name` will be considered as the anchor pass.
* @ingroup snippets
*/
class PassPosition {
public:
enum class Place { Before, After, PipelineStart, PipelineEnd };

explicit PassPosition(Place pass_place);
explicit PassPosition(Place pass_place, const DiscreteTypeInfo& pass_type_info, size_t pass_instance = 0);

template<typename PassType>
typename std::vector<std::shared_ptr<PassType>>::const_iterator get_insert_position(const std::vector<std::shared_ptr<PassType>>& pass_list) const;

private:
const DiscreteTypeInfo m_pass_type_info = {};
const size_t m_pass_instance{0};
const Place m_place{Place::Before};
};

template<typename PassType>
struct PositionedPass {
PositionedPass(PassPosition arg_pos, std::shared_ptr<PassType> arg_pass)
: position(arg_pos), pass(std::move(arg_pass)) {}

PassPosition position;
std::shared_ptr<PassType> pass;
};

template<typename PassType>
typename std::vector<std::shared_ptr<PassType>>::const_iterator PassPosition::get_insert_position(
const std::vector<std::shared_ptr<PassType>>& pass_list) const {
size_t pass_count = 0;
auto match = [this, &pass_count](const std::shared_ptr<PassType>& p) {
if (p->get_type_info() == m_pass_type_info) {
if (m_pass_instance == pass_count)
return true;
pass_count++;
}
return false;
};
switch (m_place) {
case Place::PipelineStart: return pass_list.cbegin();
case Place::PipelineEnd: return pass_list.cend();
case Place::Before:
case Place::After: {
auto insert_it = std::find_if(pass_list.cbegin(), pass_list.cend(), match);
OPENVINO_ASSERT(insert_it != pass_list.cend(), "PassPosition ", m_pass_type_info, " cannot be found");
return m_place == Place::After ? std::next(insert_it) : insert_it;
}
default:
OPENVINO_THROW("Unsupported Place type in PassPosition::get_insert_position");
}
}

} // namespace pass
} // namespace snippets
} // namespace ov
21 changes: 21 additions & 0 deletions src/common/snippets/src/lowered/pass/pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,43 @@

#include "snippets/lowered/pass/pass.hpp"

#include "snippets/utils.hpp"

namespace ov {
namespace snippets {
namespace lowered {
namespace pass {

PassPipeline::PassPipeline() : m_pass_config(std::make_shared<PassConfig>()) {}
PassPipeline::PassPipeline(const std::shared_ptr<PassConfig>& pass_config) : m_pass_config(pass_config) {
OPENVINO_ASSERT(m_pass_config != nullptr, "PassConfig is not initialized!");
}

void PassPipeline::register_pass(const snippets::pass::PassPosition& position, const std::shared_ptr<Pass>& pass) {
OPENVINO_ASSERT(pass != nullptr, "PassPipeline cannot register empty pass!");
m_passes.insert(position.get_insert_position(m_passes), pass);
IvanNovoselov marked this conversation as resolved.
Show resolved Hide resolved
}

void PassPipeline::register_pass(const std::shared_ptr<Pass>& pass) {
OPENVINO_ASSERT(pass != nullptr, "PassPipeline cannot register empty pass!");
m_passes.push_back(pass);
}

void PassPipeline::run(LinearIR& linear_ir) const {
for (const auto& pass : m_passes) {
OPENVINO_ASSERT(pass != nullptr, "PassPipeline has empty pass!");
if (m_pass_config->is_disabled(pass->get_type_info())) {
continue;
}
pass->run(linear_ir);
}
}

void PassPipeline::register_positioned_passes(const std::vector<PositionedPassLowered>& pos_passes) {
for (const auto& pp : pos_passes)
register_pass(pp.position, pp.pass);
}

} // namespace pass
} // namespace lowered
} // namespace snippets
Expand Down
34 changes: 34 additions & 0 deletions src/common/snippets/src/lowered/pass/pass_config.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "snippets/lowered/pass/pass_config.hpp"


namespace ov {
namespace snippets {
namespace lowered {
namespace pass {

void PassConfig::disable(const DiscreteTypeInfo& type_info) {
m_enabled.erase(type_info);
m_disabled.insert(type_info);
}

void PassConfig::enable(const DiscreteTypeInfo& type_info) {
m_enabled.insert(type_info);
m_disabled.erase(type_info);
}

bool PassConfig::is_disabled(const DiscreteTypeInfo& type_info) const {
return m_disabled.count(type_info);
}

bool PassConfig::is_enabled(const DiscreteTypeInfo& type_info) const {
return m_enabled.count(type_info);
}

} // namespace pass
} // namespace lowered
} // namespace snippets
} // namespace ov
Loading
Loading