Skip to content

Commit

Permalink
Fixed Load+Broadcast and added FuseLoadStoreConvert support
Browse files Browse the repository at this point in the history
  • Loading branch information
a-sidorova committed Mar 27, 2023
1 parent 4ad4e9f commit 3604db5
Show file tree
Hide file tree
Showing 10 changed files with 65 additions and 201 deletions.
3 changes: 3 additions & 0 deletions src/common/snippets/include/snippets/generator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "emitter.hpp"
#include "target_machine.hpp"
#include "lowered_expr.hpp"
#include "pass/lowered/linear_IR_transformation.hpp"

namespace ngraph {
namespace snippets {
Expand Down Expand Up @@ -86,6 +87,8 @@ class Generator {
std::shared_ptr<const TargetMachine> get_target_machine() const;

protected:
virtual pass::lowered::LinearIRTransformationPipeline target_specific_transformations() const;

std::shared_ptr<TargetMachine> target;
// todo: we need to save lowered code to access compiled brgemm kernels on execution time (normally lowered is destructed by then).
// This is temporary solution, remove this when kernel caching is implemented. Don't forget to make generate const method.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,25 @@ class LinearIRTransformation {
virtual bool run(LoweredExprIR& linear_ir) = 0;
};

// Ordered collection of lowered-IR transformations.
// Transformations are appended via register_transformation() and executed
// sequentially over a LoweredExprIR by run().
class LinearIRTransformationPipeline {
public:
    LinearIRTransformationPipeline() = default;

    // Appends an already-constructed transformation to the end of the pipeline.
    void register_transformation(const std::shared_ptr<pass::lowered::LinearIRTransformation>& transformation);

    // Convenience overload: constructs T in place from args and appends it.
    // T must derive from LinearIRTransformation (enforced at compile time).
    template<typename T, class... Args>
    void register_transformation(Args&&... args) {
        static_assert(std::is_base_of<LinearIRTransformation, T>::value, "Transformation not derived from LinearIRTransformation");
        auto transformation = std::make_shared<T>(std::forward<Args>(args)...);
        register_transformation(transformation);
    }

    // Runs the registered transformations on linear_ir, in registration order.
    void run(LoweredExprIR& linear_ir);

private:
    // Transformations stored in registration (execution) order.
    std::vector<std::shared_ptr<pass::lowered::LinearIRTransformation>> m_transformations;
};

} // namespace lowered
} // namespace pass
} // namespace snippets
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ namespace lowered {
*/
// Lowered transformation that replaces a Load followed by a broadcasting move
// with a single BroadcastLoad expression (see run() implementation, which
// inserts the fused BroadcastLoad and erases the original pair).
class LoadMoveBroadcastToBroadcastLoad: public LinearIRTransformation {
public:
    LoadMoveBroadcastToBroadcastLoad() = default;
    OPENVINO_RTTI("LoadMoveBroadcastToBroadcastLoad", "LinearIRTransformation")
    // Returns true if at least one fusion was applied to linear_ir.
    bool run(LoweredExprIR& linear_ir) override;
};
Expand Down
53 changes: 30 additions & 23 deletions src/common/snippets/src/generator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,30 +41,33 @@ Generator::LoweringResult Generator::generate(std::shared_ptr<ov::Model>& m, con
// Note: The pass LoopInit uses LoopInfo that contains entry and exit points of the corresponding Loop.
// To avoid the Loop information corruption, we should call the passes with Load/Store work
// (for example, LoadMoveBroadcastToBroadcastLoad()) after explicit Loop insertion (LoopInit())
auto propagate_buffer_offsets = std::make_shared<pass::lowered::PropagateOffsetAndResetBuffer>();
std::vector<std::shared_ptr<pass::lowered::LinearIRTransformation>> transformation_pipeline {
std::make_shared<pass::lowered::LoopMarkup>(vector_size),
std::make_shared<pass::lowered::SoftmaxDecomposition>(vector_size),
std::make_shared<pass::lowered::LoopFusion>(),
std::make_shared<pass::lowered::MoveResultOutOfLoop>(),
std::make_shared<pass::lowered::BufferInsertion>(buffer_allocation_rank),
std::make_shared<pass::lowered::LoadStoreInsertion>(vector_size),
std::make_shared<pass::lowered::SetScalarCountForLoadStore>(),
std::make_shared<pass::lowered::LoopInit>(),
std::make_shared<pass::lowered::MoveScalarToConsumer>(),
std::make_shared<pass::lowered::LoadMoveBroadcastToBroadcastLoad>(),
std::make_shared<pass::lowered::PropagateLayout>(),
propagate_buffer_offsets,
std::make_shared<pass::lowered::CleanupLoopOffsets>(),
std::make_shared<pass::lowered::AssignRegisters>(),
std::make_shared<pass::lowered::InsertTailLoop>()
};
for (const auto& transform : transformation_pipeline) {
transform->run(linear_ir);
}
const auto propagate_buffer_offsets = std::make_shared<pass::lowered::PropagateOffsetAndResetBuffer>();
pass::lowered::LinearIRTransformationPipeline common_pipeline;
common_pipeline.register_transformation<pass::lowered::LoopMarkup>(vector_size);
common_pipeline.register_transformation<pass::lowered::SoftmaxDecomposition>(vector_size);
common_pipeline.register_transformation<pass::lowered::LoopFusion>();
common_pipeline.register_transformation<pass::lowered::MoveResultOutOfLoop>();
common_pipeline.register_transformation<pass::lowered::BufferInsertion>(buffer_allocation_rank);
common_pipeline.register_transformation<pass::lowered::LoadStoreInsertion>(vector_size);
common_pipeline.register_transformation<pass::lowered::SetScalarCountForLoadStore>();
common_pipeline.register_transformation<pass::lowered::LoopInit>();
common_pipeline.register_transformation<pass::lowered::MoveScalarToConsumer>();
common_pipeline.register_transformation<pass::lowered::LoadMoveBroadcastToBroadcastLoad>();
common_pipeline.register_transformation<pass::lowered::PropagateLayout>();
common_pipeline.register_transformation(propagate_buffer_offsets);
common_pipeline.register_transformation<pass::lowered::CleanupLoopOffsets>();
common_pipeline.run(linear_ir);

pass::lowered::LinearIRTransformationPipeline target_pipeline = target_specific_transformations();
target_pipeline.run(linear_ir);

pass::lowered::LinearIRTransformationPipeline final_pipeline;
final_pipeline.register_transformation<pass::lowered::AssignRegisters>();
final_pipeline.register_transformation<pass::lowered::InsertTailLoop>();
final_pipeline.run(linear_ir);

const auto buffer_scratchpad_size = propagate_buffer_offsets->get_scratchpad_size();
linear_ir.init_emitters(target);

OV_ITT_TASK_NEXT(GENERATE, "::EmitCode")
auto loops2DKernel = std::make_shared<op::Kernel>(linear_ir);
loops2DKernel->compile_params = compile_params;
Expand All @@ -83,12 +86,16 @@ Generator::LoweringResult Generator::generate(std::shared_ptr<ov::Model>& m, con
if (config.m_save_lowered_code)
lowered_saved = linear_ir;

return {target->get_snippet(), buffer_scratchpad_size};
return {target->get_snippet(), propagate_buffer_offsets->get_scratchpad_size()};
}

std::shared_ptr<const TargetMachine> Generator::get_target_machine() const {
return target;
}

// Default (base-class) hook for backend-specific lowered transformations.
// Returns an empty pipeline, i.e. no extra passes; derived generators
// override this to register their own transformations.
pass::lowered::LinearIRTransformationPipeline Generator::target_specific_transformations() const {
    return {};
}

}// namespace snippets
}// namespace ngraph
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@ namespace pass {
namespace lowered {


LoadMoveBroadcastToBroadcastLoad::LoadMoveBroadcastToBroadcastLoad() {}

bool LoadMoveBroadcastToBroadcastLoad::run(LoweredExprIR& linear_ir) {
OV_ITT_SCOPED_TASK(ngraph::pass::itt::domains::SnippetsTransform, "Snippets::LoadMoveBroadcastToBroadcastLoad")
bool modified = false;
Expand Down Expand Up @@ -47,9 +45,10 @@ bool LoadMoveBroadcastToBroadcastLoad::run(LoweredExprIR& linear_ir) {
const auto in_td = std::vector<TensorDescriptorPtr>{ parent_expr->get_inputs().front() };
const auto out_td = std::vector<TensorDescriptorPtr>{ (*expr_it)->get_outputs().front() };
const auto mv_expr_it = expr_it;
expr_it = linear_ir.insert(std::next(expr_it), std::make_shared<LoweredExpr>(broadcastload, in_td, out_td));
const auto insertion_pos = std::next(expr_it);
linear_ir.erase(std::find(linear_ir.begin(), mv_expr_it, parent_expr));
linear_ir.erase(mv_expr_it);
expr_it = linear_ir.insert(insertion_pos, std::make_shared<LoweredExpr>(broadcastload, in_td, out_td));
modified |= true;
}
}
Expand Down
7 changes: 7 additions & 0 deletions src/plugins/intel_cpu/src/emitters/cpu_generator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "jit_dnnl_ext_emitters.hpp"
#include "jit_conversion_emitters.hpp"

#include "snippets_transformations/lowered/fuse_load_store_and_convert.hpp"
#include "snippets_transformations/op/load_convert.hpp"
#include "snippets_transformations/op/store_convert.hpp"
#include "snippets_transformations/op/fused_mul_add.hpp"
Expand Down Expand Up @@ -169,3 +170,9 @@ code ov::intel_cpu::CPUTargetMachine::get_snippet() const {

// Builds the CPU code generator on top of a CPUTargetMachine configured
// for the requested instruction-set architecture.
ov::intel_cpu::CPUGenerator::CPUGenerator(dnnl::impl::cpu::x64::cpu_isa_t isa)
    : Generator(std::make_shared<CPUTargetMachine>(isa)) {}

// CPU backend override: supplies the lowered transformations that are specific
// to this target. Currently registers only FuseLoadStoreConvert.
ngraph::snippets::pass::lowered::LinearIRTransformationPipeline ov::intel_cpu::CPUGenerator::target_specific_transformations() const {
    ngraph::snippets::pass::lowered::LinearIRTransformationPipeline pipeline;
    pipeline.register_transformation<ov::intel_cpu::pass::FuseLoadStoreConvert>();
    return pipeline;
}
3 changes: 3 additions & 0 deletions src/plugins/intel_cpu/src/emitters/cpu_generator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ class CPUTargetMachine : public ngraph::snippets::TargetMachine {
// Snippets code generator specialized for the Intel CPU plugin.
class CPUGenerator : public ngraph::snippets::Generator {
public:
    // NOTE(review): single-argument constructor could be marked explicit to
    // avoid implicit conversions from cpu_isa_t — confirm no caller relies on it.
    CPUGenerator(dnnl::impl::cpu::x64::cpu_isa_t isa);

protected:
    // Extends the base lowering flow with CPU-specific transformations
    // (registered in the corresponding .cpp implementation).
    ngraph::snippets::pass::lowered::LinearIRTransformationPipeline target_specific_transformations() const override;
};

} // namespace intel_cpu
Expand Down
17 changes: 0 additions & 17 deletions src/plugins/intel_cpu/src/nodes/subgraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
#include <snippets/op/subgraph.hpp>
#include "emitters/cpu_generator.hpp"
#include "utils/cpu_utils.hpp"
#include "snippets_transformations/fuse_load_store_and_convert.hpp"
#include "snippets_transformations/mul_add_to_fma.hpp"
#include "snippets_transformations/remove_converts.hpp"
#include "ngraph_transformations/convert_to_swish_cpu.hpp"
Expand Down Expand Up @@ -552,22 +551,6 @@ void Snippet::generate(const jit_snippets_compile_args* jcp) {

ov::pass::Manager post_precision;
post_precision.register_pass<ov::intel_cpu::pass::RemoveConverts>();
post_precision.register_pass<ov::intel_cpu::pass::FuseLoadConvert>();
post_precision.register_pass<ov::intel_cpu::pass::FuseStoreConvert>();
// LoadConvert uses Load emitter that support conversion from any type to only f32
post_precision.get_pass_config()->set_callback<ov::intel_cpu::pass::FuseLoadConvert>(
[](const std::shared_ptr<const ov::Node>& n) -> bool {
if (const auto& convert = std::dynamic_pointer_cast<const ov::op::v0::Convert>(n))
return convert->get_destination_type() != ov::element::f32;
return true;
});
// StoreConvert uses Store emitter that support conversion from only f32 to any types
post_precision.get_pass_config()->set_callback<ov::intel_cpu::pass::FuseStoreConvert>(
[](const std::shared_ptr<const ov::Node>& n) -> bool {
if (const auto& convert = std::dynamic_pointer_cast<const ov::op::v0::Convert>(n))
return convert->get_input_element_type(0) != ov::element::f32;
return true;
});
post_precision.register_pass<ov::intel_cpu::pass::MulAddToFMA>();

schedule = snippet->generate(
Expand Down

This file was deleted.

This file was deleted.

0 comments on commit 3604db5

Please sign in to comment.