diff --git a/src/common/snippets/CMakeLists.txt b/src/common/snippets/CMakeLists.txt index 962d939c563ebd..6321a375838f9e 100644 --- a/src/common/snippets/CMakeLists.txt +++ b/src/common/snippets/CMakeLists.txt @@ -26,9 +26,10 @@ ie_faster_build(${TARGET_NAME} ) target_link_libraries(${TARGET_NAME} PUBLIC openvino::runtime - PRIVATE ngraph_reference openvino::runtime::dev) + PRIVATE ngraph_reference ov_shape_inference openvino::runtime::dev) -target_include_directories(${TARGET_NAME} PUBLIC $) +target_include_directories(${TARGET_NAME} PUBLIC $ + PRIVATE $) add_cpplint_target(${TARGET_NAME}_cpplint FOR_TARGETS ${TARGET_NAME}) diff --git a/src/common/snippets/include/snippets/generator.hpp b/src/common/snippets/include/snippets/generator.hpp index f21a6951fedd62..7540c950e32253 100644 --- a/src/common/snippets/include/snippets/generator.hpp +++ b/src/common/snippets/include/snippets/generator.hpp @@ -112,12 +112,21 @@ class Generator { * @brief Default destructor */ virtual ~Generator() = default; + /** + * @interface GeneratorConfig + * @brief Allows tweaking the lowering process. + */ + class GeneratorConfig { + public: + // True if the lowered Emitters need to be accessed during runtime. Normally they're destroyed after code emission. + bool m_save_lowered_code = false; + }; /** * @brief virtual method any specific implementation should implement * @param m model in canonical form for table-based code generation * @return pointer to generated code */ - code generate(std::shared_ptr& m, const void* compile_params = nullptr) const; + code generate(std::shared_ptr& m, const GeneratorConfig& config, const void* compile_params = nullptr); /** * @brief gets target machine @@ -127,6 +136,9 @@ class Generator { protected: std::shared_ptr target; + // todo: we need to save the lowered code to access compiled brgemm kernels at execution time (normally lowered is destructed by then). + // This is a temporary solution, remove this when kernel caching is implemented. Don't forget to make generate const method. 
+ std::vector lowered_saved; }; } // namespace snippets diff --git a/src/common/snippets/include/snippets/op/brgemm.hpp b/src/common/snippets/include/snippets/op/brgemm.hpp new file mode 100644 index 00000000000000..83471c04d0553a --- /dev/null +++ b/src/common/snippets/include/snippets/op/brgemm.hpp @@ -0,0 +1,33 @@ +// Copyright (C) 2018-2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "ngraph/op/op.hpp" +#include "ngraph/op/matmul.hpp" + +namespace ngraph { +namespace snippets { +namespace op { + +/** + * @interface Brgemm + * @brief Brgemm is a batch-reduced matrix multiplication with support for arbitrary strides between matrix rows + * @ingroup snippets + */ +class Brgemm : public ngraph::op::v0::MatMul { +public: + OPENVINO_OP("Brgemm", "SnippetsOpset", ngraph::op::v0::MatMul); + Brgemm(const Output& A, const Output& B); + Brgemm() = default; + + void validate_and_infer_types() override; + std::shared_ptr clone_with_new_inputs(const OutputVector& new_args) const override; + + bool has_evaluate() const override { return false; } +}; + +} // namespace op +} // namespace snippets +} // namespace ngraph \ No newline at end of file diff --git a/src/common/snippets/include/snippets/op/subgraph.hpp b/src/common/snippets/include/snippets/op/subgraph.hpp index 7d2aed25bde76d..31975978695c5f 100644 --- a/src/common/snippets/include/snippets/op/subgraph.hpp +++ b/src/common/snippets/include/snippets/op/subgraph.hpp @@ -132,6 +132,7 @@ class Subgraph : public ngraph::op::Op { private: void align_element_types(const BlockedShapeVector& outputShapes, const BlockedShapeVector& inputShapes); void convert_to_snippet_dialect(); + void init_config(); // Count of potential non-scalar Constants that will be created after some transformations // At the moment it's relevant only for FakeQuantize decomposition // NOTE: To avoid overheads in each calculation of this count (for example, in validate_and_infer_types()), @@ -144,9 +145,16 @@ class Subgraph : public ngraph::op::Op { // TODO: Change logic of insert Converts. This exec element type can be different for plugins const ov::element::Type execution_element_type = ov::element::f32; - // Config to know which transformations should be called. - // It helps to avoid overheads of extra transformation calls - struct { + ov::PartialShape master_shape; + size_t tileRank = 0; // set by plugin to specify the number of dimensions processed in a single kernel call + + /** + * @interface SubgraphConfig + * @brief Config to optimize the IR transformation pipeline. It indicates which transformations are necessary + * so the irrelevant ones could be skipped. + */ + class SubgraphConfig { + public: // True if Subgraph contains FakeQuantize -> FQ decomposition should be called bool m_is_quantized = false; // True if we should align element types inside body bool m_is_needed_to_align_precision = false; // True if Subgraph contains TypeRelaxed nodes -> for several streams in tp mode we should copy body using mutexes // because TypeRelaxed::copy_with_new_inputs() isn't a thread-safe method bool m_has_type_relaxed_ops = false; + // True if we should check runtime info for nodes to call specific needed transformations + bool m_need_fill_tail_register = false; // True if body has operations that don't support plugin-side domain optimizations - // (e.g. Transpose, Softmax and MatMul in general don't support dimension collapsing) bool m_has_domain_sensitive_ops = false; } config; - - ov::PartialShape master_shape; - size_t tileRank = 0; // set by plugin to specify the number of dimensions processed in a single kernel call }; static inline std::ostream& operator<<(std::ostream& os, const op::Subgraph::BlockedShape& blocked_shape) { diff --git a/src/common/snippets/include/snippets/pass/fuse_transpose_brgemm.hpp b/src/common/snippets/include/snippets/pass/fuse_transpose_brgemm.hpp new file mode 100644 index 00000000000000..1c2eaa11ea039f --- /dev/null +++ b/src/common/snippets/include/snippets/pass/fuse_transpose_brgemm.hpp @@ -0,0 +1,30 @@ +// Copyright (C) 2018-2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "ngraph/pass/graph_rewrite.hpp" +#include "ngraph/pattern/matcher.hpp" + +namespace ngraph { +namespace snippets { +namespace pass { + +/** + * @interface FuseTransposeBrgemm + * @brief Fuses Transpose with a Brgemm node; fusing on both Brgemm inputs and on the output is supported. Applicable to + * Transposes that don't change the position of the last dimension (since Brgemm supports strided row i/o), + * but only the 0213 Transpose is currently supported. + * @ingroup snippets + */ +class FuseTransposeBrgemm: public ngraph::pass::MatcherPass { +public: + OPENVINO_RTTI("FuseTransposeBrgemm", "0"); + FuseTransposeBrgemm(); + static const std::set> supported_cases; +}; + +} // namespace pass +} // namespace snippets +} // namespace ngraph \ No newline at end of file
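[Reviewer note] The {0, 2, 1, 3} restriction above keeps the innermost dimension in place, so matrix rows stay dense in memory and the Transpose can be absorbed into Brgemm's strided leading dimensions. Below is a standalone sketch of the reordering arithmetic, simplified to static shapes; reorder_planar_shape is a hypothetical illustrative helper that mirrors, but is not, the utils::get_reordered_planar_shape function added later in this patch.

#include <cassert>
#include <cstddef>
#include <vector>

// Reorder a planar shape by a layout vector: dimension i of the reordered
// view takes its extent from dimension layout[i] of the original shape.
static std::vector<size_t> reorder_planar_shape(const std::vector<size_t>& shape,
                                                const std::vector<size_t>& layout) {
    std::vector<size_t> reordered(layout.size());
    for (size_t i = 0; i < layout.size(); i++)
        reordered[i] = shape[layout[i]];
    return reordered;
}

int main() {
    // {0, 2, 1, 3} swaps the two middle dimensions and leaves the last one
    // untouched, which is why rows remain contiguous after the fusion.
    const std::vector<size_t> shape{1, 49, 2, 23};
    assert((reorder_planar_shape(shape, {0, 2, 1, 3}) == std::vector<size_t>{1, 2, 49, 23}));
    return 0;
}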
diff --git a/src/common/snippets/include/snippets/pass/matmul_to_brgemm.hpp b/src/common/snippets/include/snippets/pass/matmul_to_brgemm.hpp new file mode 100644 index 00000000000000..1f00b944b56808 --- /dev/null +++ b/src/common/snippets/include/snippets/pass/matmul_to_brgemm.hpp @@ -0,0 +1,28 @@ +// Copyright (C) 2018-2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "ngraph/pass/graph_rewrite.hpp" +#include "ngraph/pattern/matcher.hpp" + +namespace ngraph { +namespace snippets { +namespace pass { + +/** + * @interface MatMulToBrgemm + * @brief Replaces ngraph::MatMul with the snippets::op::Brgemm operation (only non-transposing MatMuls are currently supported) + * @ingroup snippets + */ +class MatMulToBrgemm: public ngraph::pass::MatcherPass { +public: + OPENVINO_RTTI("MatMulToBrgemm", "0"); + MatMulToBrgemm(); +}; + + +} // namespace pass +} // namespace snippets +} // namespace ngraph diff --git a/src/common/snippets/include/snippets/snippets_isa.hpp b/src/common/snippets/include/snippets/snippets_isa.hpp index 1137de1db0c76c..20ce6444682b82 100644 --- a/src/common/snippets/include/snippets/snippets_isa.hpp +++ b/src/common/snippets/include/snippets/snippets_isa.hpp @@ -18,6 +18,7 @@ #include "op/powerstatic.hpp" #include "op/store.hpp" #include "op/loop.hpp" +#include "op/brgemm.hpp" namespace ngraph { namespace snippets { diff --git a/src/common/snippets/include/snippets/snippets_isa_tbl.hpp b/src/common/snippets/include/snippets/snippets_isa_tbl.hpp index 255a4f3a5e23d1..b0a68fd57d8afc 100644 --- a/src/common/snippets/include/snippets/snippets_isa_tbl.hpp +++ b/src/common/snippets/include/snippets/snippets_isa_tbl.hpp @@ -11,6 +11,10 @@ // SnippetS dialect NGRAPH_OP(Load, ngraph::snippets::op) +NGRAPH_OP(LoadReshape, ngraph::snippets::op) +NGRAPH_OP(LoopBegin, ngraph::snippets::op) +NGRAPH_OP(LoopEnd, ngraph::snippets::op) +NGRAPH_OP(Brgemm, ngraph::snippets::op) NGRAPH_OP(BroadcastLoad, 
ngraph::snippets::op) NGRAPH_OP(Store, ngraph::snippets::op) diff --git a/src/common/snippets/include/snippets/utils.hpp b/src/common/snippets/include/snippets/utils.hpp index 975479432d852b..1d08a786922bfb 100644 --- a/src/common/snippets/include/snippets/utils.hpp +++ b/src/common/snippets/include/snippets/utils.hpp @@ -23,6 +23,12 @@ inline auto is_scalar_constant(const std::shared_ptr& source_outpu return ngraph::is_type(source_output_node) && ngraph::shape_size(source_output_node->get_shape()) == 1; } + +ov::PartialShape get_port_planar_shape(const Output& out); +ov::PartialShape get_reordered_planar_shape(const ov::PartialShape& shape, const std::vector& layout); +std::vector get_node_output_layout(const std::shared_ptr& node); +std::vector get_node_output_layout(const Node* node); + } // namespace utils } // namespace snippets } // namespace ngraph \ No newline at end of file diff --git a/src/common/snippets/src/generator.cpp b/src/common/snippets/src/generator.cpp index 2b1457a958e672..3d0060b3805925 100644 --- a/src/common/snippets/src/generator.cpp +++ b/src/common/snippets/src/generator.cpp @@ -36,11 +36,13 @@ auto getRegisters(const std::shared_ptr &n) -> RegInfo { if (it_rt != rt.end()) rin.push_back(it_rt->second.as()); } + return std::make_pair(rin, rout); } ngraph::snippets::code ngraph::snippets::Generator::generate(std::shared_ptr& m, - const void* compile_params) const { + const GeneratorConfig& config, + const void* compile_params) { OV_ITT_SCOPED_TASK(ngraph::pass::itt::domains::SnippetsTransform, "Snippets::Generator::generate") if (!target->is_supported()) throw ngraph_error("unsupported architecture for code generation"); @@ -157,6 +159,12 @@ ngraph::snippets::code ngraph::snippets::Generator::generate(std::shared_ptremit_data(); } OV_ITT_TASK_NEXT(GENERATE, "::GetSnippet") + + // todo: we save the lowered code to access compiled brgemm kernels at execution time (normally lowered is destructed by then) + // remove this when kernel caching is implemented. Don't forget to make generate const method. 
+ if (config.m_save_lowered_code) + lowered_saved = lowered; + return target->get_snippet(); } diff --git a/src/common/snippets/src/op/brgemm.cpp b/src/common/snippets/src/op/brgemm.cpp new file mode 100644 index 00000000000000..e48b599b96a22b --- /dev/null +++ b/src/common/snippets/src/op/brgemm.cpp @@ -0,0 +1,55 @@ +// Copyright (C) 2018-2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "snippets/itt.hpp" +#include "snippets/op/brgemm.hpp" +#include "ngraph/runtime/host_tensor.hpp" +#include "openvino/core/rt_info.hpp" +#include "snippets/utils.hpp" +#include "matmul_shape_inference.hpp" + +namespace ngraph { +namespace snippets { +namespace op { + +Brgemm::Brgemm(const Output& A, const Output& B) : MatMul() { + set_arguments({A, B}); + set_output_size(1); + constructor_validate_and_infer_types(); +} + +void Brgemm::validate_and_infer_types() { + INTERNAL_OP_SCOPE(Brgemm_validate_and_infer_types); + element::Type result_et; + NODE_VALIDATION_CHECK(this, + element::Type::merge(result_et, get_input_element_type(0), get_input_element_type(1)), + "Arguments do not have the same element type (arg0 element type: ", + get_input_element_type(0), + ", arg1 element type: ", + get_input_element_type(1), + ")."); + // If no leading dimensions are provided, assume dense row-major inputs-outputs + NODE_VALIDATION_CHECK(this, get_input_partial_shape(0).is_static() && get_input_partial_shape(1).is_static(), + "Brgemm currently supports only static shapes."); + + std::vector planar_input_shapes; + for (const auto& in : input_values()) + planar_input_shapes.emplace_back(utils::get_port_planar_shape(in)); + + std::vector output_shapes = {ov::PartialShape{}}; + ov::op::v0::shape_infer(this, planar_input_shapes, output_shapes); + const auto& output_layout = utils::get_node_output_layout(this); + output_shapes[0] = utils::get_reordered_planar_shape(output_shapes[0], output_layout); + set_output_type(0, result_et, output_shapes[0]); +} + +std::shared_ptr Brgemm::clone_with_new_inputs(const OutputVector& new_args) const { + INTERNAL_OP_SCOPE(Brgemm_clone_with_new_inputs); + check_new_args_count(this, new_args); + return std::make_shared(new_args.at(0), new_args.at(1)); +} + +} // namespace op +} // namespace snippets +} // namespace ngraph
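[Reviewer note] For intuition, here is a naive scalar reference of the computation a single brgemm tile performs: a GEMM whose row strides (leading dimensions LDA/LDB/LDC) are decoupled from the logical M/N/K extents, with beta choosing between overwriting and accumulating into C. The function naive_brgemm_tile is a hypothetical illustration only, not the oneDNN kernel the CPU emitter later in this patch dispatches to.

#include <cstddef>

// C[m][n] = beta * C[m][n] + sum_k A[m][k] * B[k][n], with arbitrary row
// strides (leading dimensions), so strided/transposed-layout inputs can be
// consumed without materializing a dense copy.
// Note: a production kernel special-cases beta == 0 instead of reading C,
// so uninitialized output memory is never consumed; kept simple here.
static void naive_brgemm_tile(const float* A, const float* B, float* C,
                              size_t M, size_t N, size_t K,
                              size_t LDA, size_t LDB, size_t LDC, float beta) {
    for (size_t m = 0; m < M; m++) {
        for (size_t n = 0; n < N; n++) {
            float acc = beta * C[m * LDC + n];  // beta == 1.0f accumulates K-tail results
            for (size_t k = 0; k < K; k++)
                acc += A[m * LDA + k] * B[k * LDB + n];
            C[m * LDC + n] = acc;
        }
    }
}

int main() {
    const float A[2 * 3] = {1, 2, 3, 4, 5, 6};           // M = 2, K = 3, LDA = 3
    const float B[3 * 2] = {1, 0, 0, 1, 1, 1};           // K = 3, N = 2, LDB = 2
    float C[2 * 2] = {};
    naive_brgemm_tile(A, B, C, 2, 2, 3, 3, 2, 2, 0.0f);  // beta = 0: overwrite
    // C is now {4, 5, 10, 11}
    return 0;
}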
diff --git a/src/common/snippets/src/op/subgraph.cpp b/src/common/snippets/src/op/subgraph.cpp index 639f3c07faa58b..933e05b89fca7c 100644 --- a/src/common/snippets/src/op/subgraph.cpp +++ b/src/common/snippets/src/op/subgraph.cpp @@ -18,6 +18,8 @@ #include "snippets/pass/transpose_decomposition.hpp" #include "snippets/pass/transform_convert.hpp" #include "snippets/pass/align_element_type.hpp" +#include "snippets/pass/matmul_to_brgemm.hpp" +#include "snippets/pass/fuse_transpose_brgemm.hpp" #include "snippets/utils.hpp" #include "transformations/common_optimizations/nop_elimination.hpp" @@ -43,17 +45,31 @@ void snippets::op::Subgraph::set_non_scalar_constants_count(const size_t count) m_non_scalar_constants_count = count; } -snippets::op::Subgraph::Subgraph(const OutputVector& args, std::shared_ptr body) - : Op(args), m_body(std::move(body)), m_generator(nullptr) { +void snippets::op::Subgraph::init_config() { const auto ops = m_body->get_ops(); for (const auto& op : ops) { - config.m_is_quantized = config.m_is_quantized || ov::is_type(op); - config.m_has_type_relaxed_ops = config.m_has_type_relaxed_ops || std::dynamic_pointer_cast(op); - config.m_is_needed_to_align_precision = config.m_is_needed_to_align_precision || is_quantized() || has_type_relaxed_ops() || + config.m_is_quantized = config.m_is_quantized || + ov::is_type(op); + config.m_need_fill_tail_register = config.m_need_fill_tail_register || + ov::is_type(op) || + ov::is_type(op); + config.m_has_type_relaxed_ops = config.m_has_type_relaxed_ops || + std::dynamic_pointer_cast(op); + config.m_is_needed_to_align_precision = config.m_is_needed_to_align_precision || + is_quantized() || + has_type_relaxed_ops() || snippets::pass::AlignElementType::opNeedsAlignElementType(op, execution_element_type); - config.m_has_domain_sensitive_ops = config.m_has_domain_sensitive_ops || ov::is_type(op); + config.m_has_domain_sensitive_ops = config.m_has_domain_sensitive_ops || + ov::is_type(op) || + ov::is_type(op) || + ov::is_type(op) || + ov::is_type(op); } +} +snippets::op::Subgraph::Subgraph(const OutputVector& args, std::shared_ptr body) + : Op(args), m_body(body), m_generator(nullptr) { + init_config(); constructor_validate_and_infer_types(); } @@ -251,9 +267,11 @@ ov::PartialShape snippets::op::Subgraph::canonicalize(const BlockedShapeVector& "Snippets canonicalization got input shapes of equal ranks but different layouts, which is not supported"); } ov::PartialShape tmpPShape(baseShape); - NODE_VALIDATION_CHECK(this, - PartialShape::broadcast_merge_into(tmpPShape, inShape, ::ngraph::op::AutoBroadcastType::NUMPY), - "Failed to create broadcastable shapes in snippets canonicalization"); + // todo: we need to generalize canonicalization for domain-sensitive ops. E.g. MatMul inputs can't be broadcast to one another + if (!config.m_has_domain_sensitive_ops) + NODE_VALIDATION_CHECK(this, + PartialShape::broadcast_merge_into(tmpPShape, inShape, ::ngraph::op::AutoBroadcastType::NUMPY), + "Failed to create broadcastable shapes in snippets canonicalization"); const auto paramShape = m_body->get_parameters()[i]->get_partial_shape(); const auto paramType = m_body->get_parameters()[i]->get_element_type(); if (paramShape.size() != inShape.size() || !equal(paramShape.begin(), paramShape.end(), inShape.begin())) @@ -276,20 +294,31 @@ ov::PartialShape snippets::op::Subgraph::canonicalize(const BlockedShapeVector& // Check that output shapes are broadcastable => can be scheduled const auto& body_results = m_body->get_results(); PartialShape outPShape = body_results[0]->get_input_partial_shape(0); - for (size_t i = 0; i < body_results.size(); i++) { - auto shape_i = body_results[i]->get_input_partial_shape(0); - auto outputShape_i = std::get<0>(outputShapes[i]); - // Check that the produced output shape corresponds to the passed shape - // Some produced shapes may have been changed to be broadcastable (e.g. 
blocked + planar outputs), - // so we need to remove leading and trailing "1" before the comparison - PartialShape pShape_i(skipStartEndOnes(shape_i)); - bool compatibleWithPassedShape = PartialShape::broadcast_merge_into(pShape_i, skipStartEndOnes(outputShape_i), - ::ngraph::op::AutoBroadcastType::NUMPY); - NODE_VALIDATION_CHECK(this, compatibleWithPassedShape, "Inferred and passed results shapes are incompatible for snippet "); - // Check that output shapes are broadcastable to each other => can be scheduled - bool compatibleWithOtherOutputs = PartialShape::broadcast_merge_into(outPShape, shape_i, - ::ngraph::op::AutoBroadcastType::NUMPY); - NODE_VALIDATION_CHECK(this, compatibleWithOtherOutputs, "Snippets output shapes must be numpy broadcastable"); + // todo: we need a slightly more general approach for backward ROI propagation + const auto& result_parent = body_results[0]->get_input_node_shared_ptr(0); + if (body_results.size() == 1 && + ov::is_type(result_parent) && + ov::is_type(result_parent->get_input_node_shared_ptr(0))) { + outPShape = result_parent->get_input_partial_shape(0); + } else { + for (size_t i = 0; i < body_results.size(); i++) { + auto shape_i = body_results[i]->get_input_partial_shape(0); + auto outputShape_i = std::get<0>(outputShapes[i]); + // Check that the produced output shape corresponds to the passed shape + // Some produced shapes may have been changed to be broadcastable (e.g. blocked + planar outputs), + // so we need to remove leading and trailing "1" before the comparison + PartialShape pShape_i(skipStartEndOnes(shape_i)); + bool compatibleWithPassedShape = PartialShape::broadcast_merge_into(pShape_i, + skipStartEndOnes(outputShape_i), + ::ngraph::op::AutoBroadcastType::NUMPY); + NODE_VALIDATION_CHECK(this, compatibleWithPassedShape, + "Inferred and passed results shapes are incompatible for snippet "); + // Check that output shapes are broadcastable to each other => can be scheduled + bool compatibleWithOtherOutputs = PartialShape::broadcast_merge_into(outPShape, shape_i, + ::ngraph::op::AutoBroadcastType::NUMPY); + NODE_VALIDATION_CHECK(this, compatibleWithOtherOutputs, + "Snippets output shapes must be numpy broadcastable"); + } } // We should insert Converts after Parameters and Constant and before Results @@ -357,6 +386,8 @@ void snippets::op::Subgraph::convert_to_snippet_dialect() { ngraph::pass::Manager manager; manager.register_pass(); manager.register_pass(); + manager.register_pass(); + manager.register_pass(); manager.register_pass(); manager.register_pass(count); manager.register_pass(count); @@ -429,12 +460,12 @@ snippets::Schedule snippets::op::Subgraph::generate(ngraph::pass::Manager& opt, convert_to_snippet_dialect(); opt.run_passes(m_body); - snippets::pass::AssignRegisters().run_on_model(m_body); - // schedule generation should go here and be target agnostic + ngraph::snippets::Generator::GeneratorConfig generatorConfig; + generatorConfig.m_save_lowered_code = config.m_has_domain_sensitive_ops; // actual code emission - ngraph::snippets::code ptr = m_generator->generate(m_body, compile_params); + ngraph::snippets::code ptr = m_generator->generate(m_body, generatorConfig, compile_params); // check that body doesn't have constants for scheduling std::vector> constants; diff --git a/src/common/snippets/src/pass/assign_registers.cpp b/src/common/snippets/src/pass/assign_registers.cpp index 7478ed39263ff1..dd40f6640a3a10 100644 --- a/src/common/snippets/src/pass/assign_registers.cpp +++ b/src/common/snippets/src/pass/assign_registers.cpp @@ 
-5,6 +5,7 @@ #include #include "snippets/pass/assign_registers.hpp" #include "snippets/snippets_isa.hpp" +#include bool ngraph::snippets::pass::AssignRegisters::run_on_model(const std::shared_ptr& f) { RUN_ON_MODEL_SCOPE(AssignRegisters); @@ -22,7 +23,8 @@ bool ngraph::snippets::pass::AssignRegisters::run_on_model(const std::shared_ptr if (std::dynamic_pointer_cast(op) || std::dynamic_pointer_cast(op) || std::dynamic_pointer_cast(op) || - std::dynamic_pointer_cast(op)) + std::dynamic_pointer_cast(op) || + std::dynamic_pointer_cast(op)) return gpr2gpr; else if (std::dynamic_pointer_cast(op) || std::dynamic_pointer_cast(op)) @@ -87,7 +89,7 @@ bool ngraph::snippets::pass::AssignRegisters::run_on_model(const std::shared_ptr std::set result; for (const auto& t : tensors) { if (reg_map.count(t) == 0) - ngraph::ngraph_error("Assign registers: attempt to access not enumerated tensor"); + throw ngraph::ngraph_error("Assign registers: attempt to access not enumerated tensor"); Reg reg_id = reg_map.at(t); if (reg_id != IS_MANUALLY_ALLOCATED_REG) result.insert(reg_id); @@ -252,7 +254,7 @@ bool ngraph::snippets::pass::AssignRegisters::run_on_model(const std::shared_ptr if (reg.second == IS_MANUALLY_ALLOCATED_REG) continue; if (unique2reused.count(reg.second) == 0) - ngraph::ngraph_error("Assign registers failed to allocate register for a tensor"); + throw ngraph::ngraph_error("Assign registers failed to allocate register for a tensor"); assigned_regs[reg.first] = unique2reused.at(reg.second); } }; diff --git a/src/common/snippets/src/pass/collapse_subgraph.cpp b/src/common/snippets/src/pass/collapse_subgraph.cpp index 02928e75f7a4c3..4501eb0797467d 100644 --- a/src/common/snippets/src/pass/collapse_subgraph.cpp +++ b/src/common/snippets/src/pass/collapse_subgraph.cpp @@ -7,6 +7,7 @@ #include "snippets/pass/collapse_subgraph.hpp" #include "snippets/pass/transpose_decomposition.hpp" +#include "snippets/pass/fuse_transpose_brgemm.hpp" #include "snippets/op/subgraph.hpp" #include "snippets/utils.hpp" @@ -47,6 +48,11 @@ auto outputs_are_not_broadcastable(const std::shared_ptr& node) -> b auto is_supported_op(const std::shared_ptr &n) -> bool { OV_ITT_SCOPED_TASK(ngraph::pass::itt::domains::SnippetsTransform, "Snippets::is_supported_op") + auto is_supported_matmul = [](const std::shared_ptr& n) -> bool { + const auto& matmul = is_type(n); + const auto& out_shape = n->get_output_partial_shape(0); + return matmul && out_shape.is_static() && out_shape.size() == 4; + }; auto is_supported_transpose = [](const std::shared_ptr& n) -> bool { const auto& transpose = as_type_ptr(n); const auto& out_shape = n->get_output_partial_shape(0); @@ -54,7 +60,8 @@ auto is_supported_op(const std::shared_ptr &n) -> bool { const auto& order = as_type_ptr(n->get_input_node_shared_ptr(1)); if (order) { const auto order_value = order->cast_vector(); - return TransposeDecomposition::supported_cases.count(order_value) != 0; + return TransposeDecomposition::supported_cases.count(order_value) != 0 || + FuseTransposeBrgemm::supported_cases.count(order_value) != 0; } } return false; @@ -116,7 +123,7 @@ auto is_supported_op(const std::shared_ptr &n) -> bool { || ov::is_type(n); }; return is_supported_unary_eltwise_op(n) || is_supported_binary_eltwise_op(n) || - is_supported_transpose(n) || is_supported_fq_op(n); + is_supported_transpose(n) || is_supported_fq_op(n) || is_supported_matmul(n); } auto has_supported_in_out(const std::shared_ptr &n) -> bool { @@ -230,7 +237,12 @@ TokenizeSnippets::TokenizeSnippets() { continuation_strategy 
strategy = continuation_strategy::reset; auto label = std::make_shared(pattern::any_input(), [](const std::shared_ptr &n) { - return GetSnippetsNodeType(n) != SnippetsNodeType::SkippedByPlugin && AppropriateForSubgraph(n); + // todo: MatMul and Transpose ops are always skipped by the SnippetsMarkSkipped pass. + // This is a temporary solution. Either modify SnippetsMarkSkipped + // or align this with the custom MHA tokenization pass. + return (GetSnippetsNodeType(n) != SnippetsNodeType::SkippedByPlugin || + ov::is_type(n) || ov::is_type(n)) + && AppropriateForSubgraph(n); }); ngraph::graph_rewrite_callback callback = [&, strategy](ngraph::pattern::Matcher &m) -> bool { OV_ITT_SCOPED_TASK(ngraph::pass::itt::domains::SnippetsTransform, "Snippets::CreateSubgraph_callback") diff --git a/src/common/snippets/src/pass/fuse_transpose_brgemm.cpp b/src/common/snippets/src/pass/fuse_transpose_brgemm.cpp new file mode 100644 index 00000000000000..73347c6475bba0 --- /dev/null +++ b/src/common/snippets/src/pass/fuse_transpose_brgemm.cpp @@ -0,0 +1,86 @@ +// Copyright (C) 2018-2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "snippets/itt.hpp" + +#include "snippets/pass/fuse_transpose_brgemm.hpp" +#include "snippets/snippets_isa.hpp" + +#include "snippets/utils.hpp" + +#include "ngraph/opsets/opset1.hpp" +#include "ngraph/rt_info.hpp" +#include "ngraph/pattern/op/wrap_type.hpp" +#include "openvino/pass/pattern/op/or.hpp" + +namespace ngraph { +namespace snippets { +namespace pass { +const std::set> FuseTransposeBrgemm::supported_cases = {{0, 2, 1, 3}}; +FuseTransposeBrgemm::FuseTransposeBrgemm() { + MATCHER_SCOPE(FuseTransposeBrgemm); + auto transpose_is_supported = [](const Output& transpose_port) { + const auto transpose_node = transpose_port.get_node_shared_ptr(); + // It's safe to do so because of the patterns we used. Alternatively, we can do it through pattern_values_map + const auto& constant = as_type_ptr(transpose_node->get_input_node_shared_ptr(1)); + // if Transpose in and out layout is not empty => something was already fused on this port + if (!utils::get_node_output_layout(transpose_node).empty() || + !utils::get_node_output_layout(transpose_node->get_input_node_shared_ptr(0)).empty()) + return false; + const auto& transpose_order = constant->cast_vector(); + // todo: this limitation is due to the fact that offsets are calculated in Kernel, and the only way + // to calculate them in a non-default way is to set the Parameter rt_info field. 
This limitation can be removed if + // the rt_info is properly propagated to the corresponding parameter + if (!is_type(transpose_node->get_input_node_shared_ptr(0)) || + supported_cases.count(transpose_order) == 0) + return false; + return true; + }; + auto constant = pattern::wrap_type(); + auto transpose = pattern::wrap_type({pattern::any_input(), constant}, transpose_is_supported); + auto transpose_matcher = std::make_shared(transpose); + auto brgemm_any = pattern::wrap_type({pattern::any_input(), pattern::any_input()}); + + auto brgemm_in0 = pattern::wrap_type({transpose, pattern::any_input()}); + auto brgemm_in1 = pattern::wrap_type({pattern::any_input(), transpose}); + auto brgemm_out0 = pattern::wrap_type({brgemm_any, constant}); + auto brgemm_or_transpose = std::make_shared(OutputVector{brgemm_in0, brgemm_in1, brgemm_out0}); + + auto callback = [=](pattern::Matcher& m) { + OV_ITT_SCOPED_TASK(ngraph::pass::itt::domains::SnippetsTransform, "FuseTransposeBrgemm") + auto set_layout_from_order = [](const std::shared_ptr& node, const ov::Output& port) { + const auto& const_order = as_type_ptr(node->get_input_node_shared_ptr(1)); + std::vector layout = const_order->cast_vector(); + auto& rt_info = port.get_node_shared_ptr()->get_rt_info(); + rt_info["Layout"] = layout; + }; + auto brgemm = as_type_ptr(m.get_match_root()); + // Transpose on the Brgemm's output + if (!brgemm) { + brgemm = as_type_ptr(m.get_match_root()->get_input_node_shared_ptr(0)); + const auto& brgemm_out = brgemm->output(0); + const auto& transpose_out = m.get_match_value(); + for (const auto& in : transpose_out.get_target_inputs()) + in.replace_source_output(brgemm->output(0)); + set_layout_from_order(as_type_ptr(transpose_out.get_node_shared_ptr()), brgemm_out); + } + for (int i = 0; i < brgemm->get_input_size(); i++) { + const auto& in_value = brgemm->input_value(i); + if (transpose_matcher->match(in_value)) { + const auto& transpose = as_type_ptr(in_value.get_node_shared_ptr()); + set_layout_from_order(transpose, transpose->input_value(0)); + brgemm->set_argument(i, transpose->input_value(0)); + } + } + // need to run validate_and_infer_types manually: either input shapes were updated or + // output Layout was updated (out shape will be updated in validate_and_infer_types()) + brgemm->validate_and_infer_types(); + return true; + }; + register_matcher(std::make_shared(brgemm_or_transpose, matcher_name), callback); +} + +} // namespace pass +} // namespace snippets +} // namespace ngraph \ No newline at end of file diff --git a/src/common/snippets/src/pass/insert_load_store.cpp b/src/common/snippets/src/pass/insert_load_store.cpp index 81353444185920..d22d094fdd207c 100644 --- a/src/common/snippets/src/pass/insert_load_store.cpp +++ b/src/common/snippets/src/pass/insert_load_store.cpp @@ -21,14 +21,17 @@ ngraph::snippets::pass::InsertLoad::InsertLoad(const size_t count) { auto root = m.get_match_root(); // check if already has Load as an output - for (auto output : root->outputs()) { - for (auto consumer : output.get_target_inputs()) { + for (const auto& output : root->outputs()) { + for (const auto& consumer : output.get_target_inputs()) { // if a parameter is connected to a Load => we don't need another one // if a parameter is connected to LoopBegin => there must be Load inside the Loop + // if a parameter is connected to MatMul => we don't need Load (read/write is encapsulated into the brgemm emitter) // (it's the responsibility of transformation that inserted the Loops) const auto& consumer_node = 
consumer.get_node(); if (ov::is_type(consumer_node) || - ov::is_type(consumer_node)) { + ov::is_type(consumer_node) || + ov::is_type(consumer_node) || + ov::is_type(consumer_node)) { return false; } } @@ -38,8 +41,8 @@ ngraph::snippets::pass::InsertLoad::InsertLoad(const size_t count) { ngraph::copy_runtime_info(root, load); bool rewritten = false; - for (auto output : root->outputs()) { - for (auto consumer : output.get_target_inputs()) { + for (const auto& output : root->outputs()) { + for (const auto& consumer : output.get_target_inputs()) { if (consumer.get_node()->shared_from_this() != load) { consumer.replace_source_output(load); rewritten |= true; @@ -60,10 +63,12 @@ ngraph::snippets::pass::InsertStore::InsertStore(const size_t count) { auto root = m.get_match_root(); // check if already has Store as an input - for (auto input : root->inputs()) { + for (const auto& input : root->inputs()) { const auto& parent_node = input.get_source_output().get_node(); if (ov::is_type(parent_node) || - ov::is_type(parent_node)) { + ov::is_type(parent_node) || + ov::is_type(parent_node) || + ov::is_type(parent_node)) { return false; } } diff --git a/src/common/snippets/src/pass/matmul_to_brgemm.cpp b/src/common/snippets/src/pass/matmul_to_brgemm.cpp new file mode 100644 index 00000000000000..b74fb3e68cc47e --- /dev/null +++ b/src/common/snippets/src/pass/matmul_to_brgemm.cpp @@ -0,0 +1,45 @@ +// Copyright (C) 2018-2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "snippets/itt.hpp" + +#include "snippets/pass/matmul_to_brgemm.hpp" + +#include "snippets/op/brgemm.hpp" + +#include "ngraph/opsets/opset1.hpp" +#include "ngraph/rt_info.hpp" +#include "ngraph/pattern/op/wrap_type.hpp" + +namespace ngraph { +namespace snippets { +namespace pass { + +MatMulToBrgemm::MatMulToBrgemm() { + MATCHER_SCOPE(MatMulToBrgemm); + auto matmul_pattern = ngraph::pattern::wrap_type({ngraph::pattern::any_input(), + ngraph::pattern::any_input()}); + + auto callback = [=](ngraph::pattern::Matcher& m) { + OV_ITT_SCOPED_TASK(ngraph::pass::itt::domains::SnippetsTransform, "ov::intel_cpu::pass::MatMulToBrgemm") + auto& pm = m.get_pattern_value_map(); + const auto matmul = as_type_ptr(pm.at(matmul_pattern).get_node_shared_ptr()); + // Brgemm doesn't support transposed inputs currently, so we don't convert such matmuls + if (matmul->get_transpose_a() || matmul->get_transpose_b()) + return false; + + auto brgemm = std::make_shared(matmul->get_input_source_output(0), matmul->get_input_source_output(1)); + brgemm->set_friendly_name(matmul->get_friendly_name()); + ngraph::copy_runtime_info(matmul, brgemm); + ngraph::replace_node(matmul, brgemm); + return true; + }; + + auto m = std::make_shared(matmul_pattern, matcher_name); + register_matcher(m, callback); +} + +} // namespace pass +} // namespace snippets +} // namespace ngraph diff --git a/src/common/snippets/src/pass/transpose_decomposition.cpp b/src/common/snippets/src/pass/transpose_decomposition.cpp index 21f8e256693651..db9b00bf5b8f2a 100644 --- a/src/common/snippets/src/pass/transpose_decomposition.cpp +++ b/src/common/snippets/src/pass/transpose_decomposition.cpp @@ -37,11 +37,11 @@ ngraph::snippets::pass::TransposeDecomposition::TransposeDecomposition() { auto order_value = order->cast_vector(); if (supported_cases.count(order_value) == 0) - throw ngraph::ngraph_error("TransposeDecomposition: unsupported order"); + return false; auto data_input = pattern_to_output.at(match_data); const auto& data_node = 
pattern_to_output.at(match_data).get_node_shared_ptr(); - auto &param_rt = data_input.get_tensor_ptr()->get_rt_info(); + auto &param_rt = data_node->get_rt_info(); // Note: store and usage inside emitters as size_t is more convenient, so static_cast here const auto& access_pattern = order->cast_vector(); param_rt["Layout"] = access_pattern; diff --git a/src/common/snippets/src/utils.cpp b/src/common/snippets/src/utils.cpp index e6f3bcbedda11b..d904317d6029f7 100644 --- a/src/common/snippets/src/utils.cpp +++ b/src/common/snippets/src/utils.cpp @@ -6,8 +6,11 @@ #include "snippets/pass/fq_decomposition.hpp" +namespace ngraph { +namespace snippets { +namespace utils { -auto ngraph::snippets::utils::get_non_scalar_constant_count_for_fq(const std::shared_ptr& fq) -> size_t { +auto get_non_scalar_constant_count_for_fq(const std::shared_ptr& fq) -> size_t { std::vector out_scales; std::vector cl, ch, isc, ish, osc, osh; const bool status = ngraph::snippets::pass::FakeQuantizeDecomposition::getScalesAndShifts(fq, cl, ch, isc, ish, osc, osh); @@ -55,3 +58,54 @@ auto ngraph::snippets::utils::get_non_scalar_constant_count_for_fq(const std::sh return 1; return 0; } +std::vector get_node_output_layout(const std::shared_ptr& node) { + return get_node_output_layout(node.get()); +} +std::vector get_node_output_layout(const Node* node) { + if (!node) + return {}; + if (node->is_dynamic()) + throw ngraph_error("It's illegal to call get_node_output_layout for dynamic nodes"); + auto &rt = node->get_rt_info(); + const auto rinfo = rt.find("Layout"); + if (rinfo != rt.end()) { + std::vector layout(rinfo->second.as>()); + // This might be a little costly, but still a useful sanity check. Remove if proved to be unacceptably heavy. + std::set unique_elements(layout.begin(), layout.end()); + if (unique_elements.size() < layout.size()) + throw ngraph_error("Layout must contain only unique dimension indexes"); + return layout; + } else { + return {}; + } +} + +ov::PartialShape get_reordered_planar_shape(const ov::PartialShape& shape, const std::vector& layout) { + if (layout.empty()) + return shape; + std::vector reordered_shape(layout.size()); + if (shape.rank().is_dynamic()) + throw ngraph_error("get_reordered_planar_shape can't be called for outputs with dynamic rank"); + const size_t rank = shape.rank().get_length(); + if (layout.size() > rank) + throw ngraph_error("Layout rank can't be larger than tensor rank"); + // Note that it can be smaller though, for example tensor shape can be prepended with 1 for scheduling purposes + if (std::any_of(layout.begin(), layout.end(), [=](size_t x) {return x >= rank;})) + throw ngraph_error("Invalid layout detected: all layout indexes must be smaller than the tensor rank"); + for (int i = 0; i < layout.size(); i++) + reordered_shape[i] = shape[layout[i]]; + return reordered_shape; +} + +ov::PartialShape get_port_planar_shape(const Output& out) { + std::vector layout = get_node_output_layout(out.get_node_shared_ptr()); + const auto& tensor = out.get_tensor_ptr(); + if (!tensor) + throw ngraph_error("get_port_planar_shape can't be called for an uninitialized output tensor"); + auto tensor_shape = tensor->get_partial_shape(); + return get_reordered_planar_shape(tensor_shape, layout); +} + +} // namespace utils +} // namespace snippets +} // namespace ngraph diff --git a/src/common/snippets/tests/include/pass/fuse_transpose_brgemm.hpp b/src/common/snippets/tests/include/pass/fuse_transpose_brgemm.hpp new file mode 100644 index 00000000000000..20c2fa1b272958 --- /dev/null +++ 
b/src/common/snippets/tests/include/pass/fuse_transpose_brgemm.hpp @@ -0,0 +1,32 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "lowering_utils.hpp" +#include "snippets_helpers.hpp" + +/* The main purpose is to test that FuseTransposeBrgemm properly fuses 0213 Transposes on both inputs, as well as on output + */ + +namespace ov { +namespace test { +namespace snippets { + +typedef std::tuple< + std::vector, // Input shapes + size_t // Transpose position +> fuseTransposeBrgemmParams; + +class FuseTransposeBrgemmTests : public LoweringTests, public testing::WithParamInterface { +public: + static std::string getTestCaseName(testing::TestParamInfo obj); +protected: + void SetUp() override; + std::shared_ptr snippets_function; +}; + +} // namespace snippets +} // namespace test +} // namespace ov diff --git a/src/common/snippets/tests/src/lowering_utils.cpp b/src/common/snippets/tests/src/lowering_utils.cpp index 7c9f15a6bc48e9..ef5b74a08b910d 100644 --- a/src/common/snippets/tests/src/lowering_utils.cpp +++ b/src/common/snippets/tests/src/lowering_utils.cpp @@ -32,6 +32,7 @@ DummyTargetMachine::DummyTargetMachine() { jitters[ngraph::snippets::op::Kernel::get_type_info_static()] = dummy_functor; jitters[ngraph::snippets::op::LoopBegin::get_type_info_static()] = dummy_functor; jitters[ngraph::snippets::op::LoopEnd::get_type_info_static()] = dummy_functor; + jitters[ngraph::snippets::op::Brgemm::get_type_info_static()] = dummy_functor; } void LoweringTests::SetUp() { diff --git a/src/common/snippets/tests/src/pass/collapse_subgraph.cpp b/src/common/snippets/tests/src/pass/collapse_subgraph.cpp index aa26ecfe4cdb74..dc5d4831fe44dd 100644 --- a/src/common/snippets/tests/src/pass/collapse_subgraph.cpp +++ b/src/common/snippets/tests/src/pass/collapse_subgraph.cpp @@ -17,6 +17,11 @@ void CollapseSubgraphTests::run() { std::string name; manager.register_pass(); manager.register_pass(); + // todo: This is a temporary work-around. 
remove when MatMul tokenization is supported through general pipeline + manager.get_pass_config()->set_callback( + [](const std::shared_ptr& n) -> bool { + return ov::is_type(n); + }); } TEST_F(CollapseSubgraphTests, smoke_Snippets_Eltwise) { diff --git a/src/common/snippets/tests/src/pass/fuse_transpose_brgemm.cpp b/src/common/snippets/tests/src/pass/fuse_transpose_brgemm.cpp new file mode 100644 index 00000000000000..a3f60e4656abc1 --- /dev/null +++ b/src/common/snippets/tests/src/pass/fuse_transpose_brgemm.cpp @@ -0,0 +1,56 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include "pass/fuse_transpose_brgemm.hpp" +#include "common_test_utils/common_utils.hpp" +#include "subgraph_matmul.hpp" +#include "subgraph_lowered.hpp" + +namespace ov { +namespace test { +namespace snippets { + +std::string FuseTransposeBrgemmTests::getTestCaseName(testing::TestParamInfo obj) { + std::vector input_shapes(2); + size_t transpose_position; + std::tie(input_shapes, transpose_position) = obj.param; + std::ostringstream result; + result << "IS[0]=" << CommonTestUtils::partialShape2str({input_shapes[0]}) << "_"; + result << "IS[1]=" << CommonTestUtils::partialShape2str({input_shapes[1]}) << "_"; + result << "Pos=" << transpose_position << "_"; + return result.str(); +} + +void FuseTransposeBrgemmTests::SetUp() { + LoweringTests::SetUp(); + std::vector input_shapes(2); + size_t transpose_position; + std::tie(input_shapes, transpose_position) = this->GetParam(); + + snippets_function = std::make_shared(input_shapes, transpose_position); +} + +TEST_P(FuseTransposeBrgemmTests, FuseTransposeMatmul) { + auto subgraph = getLoweredSubgraph(snippets_function->getOriginal(), master_shape); + function = subgraph->get_body(); + function_ref = snippets_function->getLowered(); +} + +namespace FuseTransposeBrgemmTestsInstantiation { +using ov::Shape; +std::vector test_params{ + {{{1, 49, 2, 23}, {2, 2, 23, 39}}, 0}, + {{{1, 2, 49, 23}, {2, 23, 1, 39}}, 1}, + {{{1, 2, 49, 23}, {2, 2, 23, 39}}, 2}, +}; + +INSTANTIATE_TEST_SUITE_P(smoke_Snippets_FuseTransposeMatMul, FuseTransposeBrgemmTests, + ::testing::ValuesIn(test_params), + FuseTransposeBrgemmTests::getTestCaseName); + +} // namespace FuseTransposeBrgemmTestsInstantiation +} // namespace snippets +} // namespace test +} // namespace ov \ No newline at end of file diff --git a/src/plugins/intel_cpu/src/emitters/cpu_generator.cpp b/src/plugins/intel_cpu/src/emitters/cpu_generator.cpp index 5233dc97ebd25f..1438fc286ce4e4 100644 --- a/src/plugins/intel_cpu/src/emitters/cpu_generator.cpp +++ b/src/plugins/intel_cpu/src/emitters/cpu_generator.cpp @@ -17,6 +17,7 @@ #include "snippets_transformations/op/load_convert.hpp" #include "snippets_transformations/op/store_convert.hpp" +#include "snippets/op/brgemm.hpp" #include "ngraph_transformations/op/swish_cpu.hpp" #include @@ -126,6 +127,7 @@ ov::intel_cpu::CPUTargetMachine::CPUTargetMachine(dnnl::impl::cpu::x64::cpu_isa_ jitters[ngraph::snippets::op::Kernel::get_type_info_static()] = CREATE_EMITTER(KernelEmitter); jitters[ngraph::snippets::op::LoopBegin::get_type_info_static()] = CREATE_EMITTER(LoopBeginEmitter); jitters[ngraph::snippets::op::LoopEnd::get_type_info_static()] = CREATE_EMITTER(LoopEndEmitter); + jitters[ngraph::snippets::op::Brgemm::get_type_info_static()] = CREATE_EMITTER(BrgemmEmitter); } size_t ov::intel_cpu::CPUTargetMachine::get_lanes() const { diff --git a/src/plugins/intel_cpu/src/emitters/jit_snippets_emitters.cpp 
b/src/plugins/intel_cpu/src/emitters/jit_snippets_emitters.cpp index c41f625d5c18e1..327e6acd258438 100644 --- a/src/plugins/intel_cpu/src/emitters/jit_snippets_emitters.cpp +++ b/src/plugins/intel_cpu/src/emitters/jit_snippets_emitters.cpp @@ -7,7 +7,9 @@ #include #include "jit_snippets_emitters.hpp" +#include "snippets/op/brgemm.hpp" #include "snippets/op/subgraph.hpp" +#include "snippets/utils.hpp" using namespace Xbyak; using ngraph::snippets::op::Subgraph; @@ -62,7 +64,8 @@ void jit_container_emitter::map_abstract_registers(mapping_info& gpr_map_pool, // todo: Note that LoopBeginEmitter and LoopEndEmitter demonstrate new paradigm, // where all utility emitters align with conventional Op emitters if (std::dynamic_pointer_cast(emitter) || - std::dynamic_pointer_cast(emitter)) + std::dynamic_pointer_cast(emitter) || + std::dynamic_pointer_cast(emitter)) in_physical_regs = std::move(map_regs(in_abstract_regs, gpr_map_pool)); else in_physical_regs = std::move(in_abstract_regs); @@ -111,24 +114,19 @@ KernelEmitter::KernelEmitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl: IE_THROW() << "KernelEmitter can't calc offsets for dynamic shapes"; return pshape.get_shape(); }; - const auto get_access_pattern = [](const Output& out, std::vector& shape) { - std::vector access_pattern{}; - auto &rt = out.get_tensor_ptr()->get_rt_info(); - const auto rinfo = rt.find("Layout"); + const auto get_data_layout = [](const Output& out, std::vector& shape) { + const auto& layout = ngraph::snippets::utils::get_node_output_layout(out.get_node_shared_ptr()); // default access pattern - if (rinfo != rt.end()) { - access_pattern = rinfo->second.as>(); - const int64_t pattern_shape_diff = static_cast(shape.size()) - static_cast(access_pattern.size()); + if (!layout.empty()) { + const auto layout_shape_diff = static_cast(shape.size()) - static_cast(layout.size()); // Plugin can (and usually does) prepend shapes with 1's to facilitate scheduling, here we can safely remove leading 1's - if (pattern_shape_diff > 0) { - if (std::any_of(shape.begin(), shape.begin() + pattern_shape_diff, [](size_t x){return x != 1;})) + if (layout_shape_diff > 0) { + if (std::any_of(shape.begin(), shape.begin() + layout_shape_diff, [](size_t x){return x != 1;})) IE_THROW() << "KernelEmitter detected shape vs access pattern conflict: only leading 1's can be removed from the shape"; - shape.erase(shape.begin(), shape.begin() + pattern_shape_diff); - } else if (pattern_shape_diff < 0) { - IE_THROW() << "KernelEmitter detected invalid access pattern: pattern size can't be larger than shape size"; + shape.erase(shape.begin(), shape.begin() + layout_shape_diff); } } - return access_pattern; + return layout; }; auto params = model->get_parameters(); auto results = model->get_results(); @@ -149,8 +147,8 @@ KernelEmitter::KernelEmitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl: io_shapes = new_shapes; } for (int i = 0; i < io_nodes.size(); i++) { - const auto& out = io_nodes[i]->output(0); - data_access_pattern.push_back(get_access_pattern(out, io_shapes[i])); + const auto& out = i < num_inputs ? 
io_nodes[i]->output(0) : io_nodes[i]->input_value(0); + data_layout.push_back(get_data_layout(out, io_shapes[i])); io_data_size.push_back(out.get_element_type().size()); } // Initialize pools of gp and vec registers @@ -178,7 +176,11 @@ KernelEmitter::KernelEmitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl: [](const AllocatedEmitter& code){ const auto& emitter = code.first; const auto emitter_type = std::dynamic_pointer_cast(emitter)->get_in_out_type(); - return emitter_type == gpr_to_vec || emitter_type == vec_to_gpr; + // todo: how this will be handled if Brgemm in & out are op::Buffer + // Brgemm is a special case since it incorporates input and output (we use onednn kernel) + // Just like Load & Store it requires offsets calculation + const auto is_brgemm = std::dynamic_pointer_cast(emitter) != nullptr; + return emitter_type == gpr_to_vec || emitter_type == vec_to_gpr || is_brgemm; }); // Note that we can't use reg_indexes_idx or reg_const_params_idx to store data pointers because these two // regs are used to calculate offsets for the data pointers @@ -222,7 +224,7 @@ void KernelEmitter::init_data_pointers(size_t num_inputs, size_t num_params, //const size_t tile_rank = jcp.tile_rank; std::vector> data_offsets(num_params, std::vector{}); auto offset_calculation = [=](const std::vector& shape, - const std::vector& access_pattern, const size_t data_size) { + const std::vector& layout, const size_t data_size) { // Strides represent distance between consecutive elements of corresponding dimension. // If a dim size == 1, then the next dim starts immediately and the stride is 0 // case 1: @@ -239,10 +241,10 @@ void KernelEmitter::init_data_pointers(size_t num_inputs, size_t num_params, strides[k] = shape[k] != 1 ? dim_step * data_size : 0; } // Note: this is an extra copy, but let's keep it for clarity - if (!access_pattern.empty()) { + if (!layout.empty()) { std::vector reordered_strides(strides.size()); - for (auto i = 0; i < access_pattern.size(); i++) - reordered_strides[i] = strides[access_pattern[i]]; + for (auto i = 0; i < layout.size(); i++) + reordered_strides[i] = strides[layout[i]]; strides = std::move(reordered_strides); } // the last stride is ignored, since the entire last dim is processed by kernel @@ -257,7 +259,7 @@ void KernelEmitter::init_data_pointers(size_t num_inputs, size_t num_params, return strides; }; for (size_t i = 0; i < num_params; i++) { - data_offsets[i] = offset_calculation(io_shapes[i], data_access_pattern[i], io_data_size[i]); + data_offsets[i] = offset_calculation(io_shapes[i], data_layout[i], io_data_size[i]); } // master_shape size must be valid in both static and dynamic cases std::function&, Reg64)> init_ptr_with_offset; @@ -719,6 +721,286 @@ void StoreConvertEmitter::emit_isa(const std::vector &in, const std::vec void StoreConvertEmitter::emit_data() const { store_emitter->emit_data(); } +size_t BrgemmEmitter::getBrgIdx(size_t mIdx, size_t kIdx, size_t nIdx) const { + return mIdx * 4 + kIdx * 2 + nIdx; +} +BrgemmEmitter::BrgemmEmitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl::cpu::x64::cpu_isa_t isa, + const std::shared_ptr& node) : jit_emitter(h, isa, node) { + in_out_type_ = emitter_in_out_map::gpr_to_gpr; + const auto& brgemm_node = as_type_ptr(node); + if (brgemm_node->is_dynamic()) + IE_THROW() << "Snippets don't support code generation for dynamic Brgemm"; + const OutputVector io_values {brgemm_node->input_value(0), brgemm_node->input_value(1), brgemm_node->output(0)}; + std::vector leading_dimensions; + std::vector> 
io_layouts; + for (const auto& val : io_values) { + const auto& layout = ngraph::snippets::utils::get_node_output_layout(val.get_node_shared_ptr()); + const auto& io_shape = val.get_shape(); + if (layout.empty()) { + // empty value indicates a planar layout + leading_dimensions.push_back(io_shape.back()); + std::vector default_layout(io_shape.size()); + std::iota(default_layout.begin(), default_layout.end(), 0); + io_layouts.push_back(default_layout); + } else { + // The idea here is to find "2" (for 4D shapes) in the layout and multiply dimensions that are to the right + // This implies that "3" is the last layout value, otherwise this layout is not supported. + // counting from the end since shape could be prepended with ones + const int64_t num_last_dims = layout.end() - std::find(layout.begin(), layout.end(), layout.size() - 2) - 1; + if (layout.back() != layout.size() - 1 || num_last_dims < 1) + IE_THROW() << "BrgemmEmitter detected invalid layout values: " << + "check that this shape + layout combination is schedulable"; + leading_dimensions.emplace_back( + std::accumulate(io_shape.end() - num_last_dims, io_shape.end(), 1, std::multiplies())); + io_layouts.push_back(layout); + } + } + // todo: leave AMX and VNNI related code for now, it'll help to enable int8 and bf16 support + bool isAMXSupported = mayiuse(avx512_core_bf16_amx_int8) || mayiuse(avx512_core_bf16_amx_bf16); + + const auto& A_shape = io_values[0].get_shape(); + const auto& A_layout = io_layouts[0]; + const auto& C_shape = io_values[2].get_shape(); + const auto& C_layout = io_layouts[2]; + + M = C_shape[C_layout[2]]; + K = A_shape[A_layout[3]]; + M_blk = matmulOptimalM; + M_tail = M % M_blk; + // B_shape[B_layout[3]] + N = C_shape[C_layout[3]]; + + auto brg0Prc = InferenceEngine::details::convertPrecision(brgemm_node->get_input_element_type(0)); + auto brg1Prc = InferenceEngine::details::convertPrecision(brgemm_node->get_input_element_type(1)); + io_data_size = {brg0Prc.size(), brg1Prc.size(), brgemm_node->get_output_element_type(0).size()}; + brg0VnniFactor = 4 / brg0Prc.size(); + bool brg0WithAMX = isAMXSupported && brg0Prc != Precision::FP32 && (K % brg0VnniFactor == 0) && (N % brg0VnniFactor == 0); + + N_blk = brg0Prc == Precision::FP32 ? N : + brg0Prc == Precision::BF16 ? 32 : 64; + N_tail = N % N_blk; + K_blk = brg0WithAMX ? brg0Prc == Precision::BF16 ? 32 : 64 + : K; + K_tail = K % K_blk; + + size_t brg0BaseIdx = -1; + for (size_t m = 0; m < 2; m++) { + for (size_t k = 0; k < 2; k++) { + for (size_t n = 0; n < 2; n++) { + auto& brgemmCtx = brgCtxs0[getBrgIdx(m, k, n)]; + + auto M_ = m ? M_tail + : M < M_blk ? 0 : M_blk; + auto N_ = n ? N_tail : N - N_tail; + auto K_ = k ? K_tail : K - K_tail; + auto beta = k && brgCtxs0[getBrgIdx(m, 0, n)].K != 0 ? 
1.0f : 0.0f; + + brgemmCtx.M = M_; + brgemmCtx.N = N_; + brgemmCtx.K = K_; + brgemmCtx.LDA = leading_dimensions[0]; + brgemmCtx.LDB = leading_dimensions[1]; + brgemmCtx.LDC = leading_dimensions[2]; + brgemmCtx.dt_in0 = static_cast(DnnlExtensionUtils::IEPrecisionToDataType(brg0Prc)); + brgemmCtx.dt_in1 = static_cast(DnnlExtensionUtils::IEPrecisionToDataType(brg1Prc)); + brgemmCtx.beta = beta; + + // don't create brgemm kernels for empty tiles + if (M_ != 0 && K_ != 0 && N_ != 0) { + if (brg0BaseIdx == -1) + brg0BaseIdx = getBrgIdx(m, k, n); + initBrgemm(brgemmCtx, brgKernels0[getBrgIdx(m, k, n)], brg0WithAMX); + } + } + } + } +} + +void BrgemmEmitter::initBrgemm(brgemmCtx& ctx, std::unique_ptr& brgKernel, bool use_amx) const { + brgemm_t brgDesc; + brgemm_strides_t strides {static_cast(ctx.M * ctx.K), static_cast(ctx.K * ctx.N)}; + // When implementing int8 support, note that isa logics is more complicated in the MHA node + auto status = brgemm_desc_init(&brgDesc, host_isa_, brgemm_strd, ctx.dt_in0, ctx.dt_in1, + false, false, brgemm_row_major, 1.f, ctx.beta, ctx.LDA, ctx.LDB, ctx.LDC, ctx.M, ctx.N, ctx.K, &strides); + if (status != dnnl_success) + IE_THROW() << "BrgemmEmitter cannot initialize brgemm descriptor due to invalid params"; + ctx.is_with_amx = use_amx; + status = brgemm_init_tiles(brgDesc, ctx.palette); + if (use_amx) + amx_tile_configure(ctx.palette); + + ctx.is_with_comp = ctx.dt_in0 == dnnl_data_type_t::dnnl_s8 && !ctx.is_with_amx; + + brgemm_kernel_t* brgKernel_ = nullptr; + status = brgemm_kernel_create(&brgKernel_, brgDesc); + if (status != dnnl_success) + IE_THROW() << "BrgemmEmitter cannot create brgemm kernel due to invalid params"; + brgKernel.reset(brgKernel_); +} + +void BrgemmEmitter::emit_impl(const std::vector& in, + const std::vector& out, + const std::vector& pool, + const std::vector& gpr, + const ov::intel_cpu::emitter_context *emit_context) const { + if (host_isa_ == cpu::x64::sse41 || host_isa_ == cpu::x64::avx2) { + IE_THROW() << "BrgemmEmitter requires at least avx512_core instruction set"; + } else if (host_isa_ == cpu::x64::avx512_core) { + emit_isa(in, out); + } else { + assert(!"unsupported isa"); + } +} +template +void BrgemmEmitter::emit_brgemm_kernel_call(const brgemm_kernel_t *brgKernel, int bs, + Reg64 addr_A, Reg64 addr_B, + const brgemm_batch_element_t *batch, Reg64 addr_C, void *scratch) const { + using Vmm = typename dnnl::impl::utils::conditional3::type; + size_t gpr_size = 8; + Xbyak::Operand gprs_to_save[] = {h->r8, h->r9, h->r10, h->r11, h->rax, + h->rcx, h->rdx, h->rdi, h->rsi, h->rbp, h->rbx}; + size_t n_gprs_to_save = sizeof(gprs_to_save) / sizeof(gprs_to_save[0]); + + h->sub(h->rsp, n_gprs_to_save * gpr_size); + for (size_t i = 0; i < n_gprs_to_save; ++i) + h->mov(h->ptr[h->rsp + i * gpr_size], gprs_to_save[i]); + + // caller obligation to save k-regs as callee may use them + size_t n_k_regs_to_save = 8; + if (isa == cpu::x64::avx512_core) { + h->sub(h->rsp, n_k_regs_to_save * k_mask_size); + for (size_t i = 0; i < n_k_regs_to_save; ++i) { + if (mayiuse(avx512_core)) + h->kmovq(h->ptr[h->rsp + i * k_mask_size], Opmask(static_cast(i))); + else + h->kmovw(h->ptr[h->rsp + i * k_mask_size], Opmask(static_cast(i))); + } + } + + // 1. Caller obligation to save vector registers as callee may use them. + // 2. There is an implicit assumption that the host code uses the same + // `isa` as the injector. Once the assumption is wrong, `vecs_count` and + // `vlen` should be replaced with `host_isa::vlen` and + // `host_isa::vecs_count`. 
+ h->sub(h->rsp, get_max_vecs_count() * get_vec_length()); + for (size_t i = 0; i < get_max_vecs_count(); ++i) + h->uni_vmovups(h->ptr[h->rsp + i * get_vec_length()], Vmm(i)); + + // save function address in gpr to pass in call instruction + const auto& brgemm_kernel_overload = static_cast(brgemm_kernel_execute); + h->mov(h->rbp, reinterpret_cast(brgemm_kernel_overload)); + // todo: several of addr_{A, B, C} could be also abi_paramX, so one of them could be corrupted + // if moving directly h->uni_vmovq(abi_paramX, adr_X). Save them to vector regs to avoid corruption. + // It's likely that a more efficient solution exists. + h->uni_vmovq(Xmm(0), addr_A); + h->uni_vmovq(Xmm(1), addr_B); + h->uni_vmovq(Xmm(2), addr_C); + // todo: Windows ABI : requires different num of arguments passed in regs and on the stack. Need to align. + h->mov(abi_param1, reinterpret_cast(brgKernel)); + h->mov(abi_param2, bs); + h->uni_vmovq(abi_param3, Xmm(0)); + h->uni_vmovq(abi_param4, Xmm(1)); + size_t num_args_passed_on_stack = 1; +#ifdef _WIN32 + num_args_passed_on_stack = 3; + h->sub(h->rsp, gpr_size * num_args_passed_on_stack); + h->sub(h->rsp, gpr_size); + h->mov(h->qword[h->rsp], reinterpret_cast(scratch)); + h->mov(h->qword[h->rsp + gpr_size], reinterpret_cast(batch)); + h->mov(h->qword[h->rsp + 2 * gpr_size], Xmm(2)); +#else + h->mov(abi_param5, reinterpret_cast(batch)); + h->uni_vmovq(abi_param6, Xmm(2)); + h->sub(h->rsp, gpr_size); + h->mov(h->qword[h->rsp], reinterpret_cast(scratch)); +#endif + // align stack on 16-byte as ABI requires + // note that RBX must not be changed by the callee + h->mov(h->rbx, h->rsp); + h->and_(h->rbx, 0xf); + h->sub(h->rsp, h->rbx); + + h->call(h->rbp); + + h->add(h->rsp, h->rbx); + h->add(h->rsp, gpr_size * num_args_passed_on_stack); + // restore vector registers + for (int i = static_cast(get_max_vecs_count()) - 1; i >= 0; --i) { + h->uni_vmovups(Vmm(i), h->ptr[h->rsp + i * get_vec_length()]); + } + h->add(h->rsp, (get_max_vecs_count()) * get_vec_length()); + + // restore k registers + if (isa == cpu::x64::avx512_core) { + for (int i = n_k_regs_to_save - 1; i >= 0; --i) { + if (mayiuse(avx512_core)) + h->kmovq(Opmask(i), h->ptr[h->rsp + i * k_mask_size]); + else + h->kmovw(Opmask(i), h->ptr[h->rsp + i * k_mask_size]); + } + h->add(h->rsp, n_k_regs_to_save * k_mask_size); + } + + // restore gpr registers + for (int i = n_gprs_to_save - 1; i >= 0; --i) + h->mov(gprs_to_save[i], h->ptr[h->rsp + i * gpr_size]); + h->add(h->rsp, n_gprs_to_save * gpr_size); +} + +template +void BrgemmEmitter::emit_isa(const std::vector &in, const std::vector &out) const { + using Vmm = typename dnnl::impl::utils::conditional3::type; + Reg64 input_0(static_cast(in[0])); + Reg64 input_1(static_cast(in[1])); + Reg64 output_0(static_cast(out[0])); + + for (size_t mb = 0; mb < div_up(M, M_blk); mb++) { + const bool is_M_tail = (M - mb * M_blk < M_blk); + + size_t brgIdx0 = getBrgIdx(0, 0, 0); + size_t K0_step0 = brgCtxs0[brgIdx0].K; + size_t K0_step1 = brgCtxs0[brgIdx0].K * brgCtxs0[brgIdx0].LDB; + size_t N0_step0 = brgCtxs0[brgIdx0].N * brg0VnniFactor; + size_t N0_step1 = brgCtxs0[brgIdx0].N; + for (size_t n = 0; n < 2; n++) { + for (size_t k = 0; k < 2; k++) { + size_t mIdx = is_M_tail ? 
+template <cpu::x64::cpu_isa_t isa>
+void BrgemmEmitter::emit_isa(const std::vector<size_t> &in, const std::vector<size_t> &out) const {
+    using Vmm = typename dnnl::impl::utils::conditional3<isa == cpu::x64::sse41, Xmm, isa == cpu::x64::avx2, Ymm, Zmm>::type;
+    Reg64 input_0(static_cast<int>(in[0]));
+    Reg64 input_1(static_cast<int>(in[1]));
+    Reg64 output_0(static_cast<int>(out[0]));
+
+    for (size_t mb = 0; mb < div_up(M, M_blk); mb++) {
+        const bool is_M_tail = (M - mb * M_blk < M_blk);
+
+        size_t brgIdx0 = getBrgIdx(0, 0, 0);
+        size_t K0_step0 = brgCtxs0[brgIdx0].K;
+        size_t K0_step1 = brgCtxs0[brgIdx0].K * brgCtxs0[brgIdx0].LDB;
+        size_t N0_step0 = brgCtxs0[brgIdx0].N * brg0VnniFactor;
+        size_t N0_step1 = brgCtxs0[brgIdx0].N;
+        for (size_t n = 0; n < 2; n++) {
+            for (size_t k = 0; k < 2; k++) {
+                size_t mIdx = is_M_tail ? 1 : 0;
+                auto& brgemmCtx = brgCtxs0[getBrgIdx(mIdx, k, n)];
+
+                if (brgemmCtx.K != 0 && brgemmCtx.N != 0) {
+                    const size_t in0_offset = (k * K0_step0 + mb * M_blk * brgemmCtx.LDA) * io_data_size[0];
+                    const size_t in1_offset = (k * K0_step1 + n * N0_step0) * io_data_size[1];
+                    const size_t out0_offset = (n * N0_step1 + mb * M_blk * brgemmCtx.LDC) * io_data_size[2];
+                    if (in0_offset != 0)
+                        h->add(input_0, in0_offset);
+                    if (in1_offset != 0)
+                        h->add(input_1, in1_offset);
+                    if (out0_offset != 0)
+                        h->add(output_0, out0_offset);
+                    emit_brgemm_kernel_call<isa>(brgKernels0[getBrgIdx(mIdx, k, n)].get(),
+                                                 1,
+                                                 input_0,
+                                                 input_1,
+                                                 nullptr,
+                                                 output_0,
+                                                 nullptr);
+                    if (in0_offset != 0)
+                        h->sub(input_0, in0_offset);
+                    if (in1_offset != 0)
+                        h->sub(input_1, in1_offset);
+                    if (out0_offset != 0)
+                        h->sub(output_0, out0_offset);
+                }
+            }
+        }
+    }
+}
 } // namespace intel_cpu
 } // namespace ov
diff --git a/src/plugins/intel_cpu/src/emitters/jit_snippets_emitters.hpp b/src/plugins/intel_cpu/src/emitters/jit_snippets_emitters.hpp
index 1d054833aa48e6..c559f2421f0235 100644
--- a/src/plugins/intel_cpu/src/emitters/jit_snippets_emitters.hpp
+++ b/src/plugins/intel_cpu/src/emitters/jit_snippets_emitters.hpp
@@ -12,6 +12,11 @@
 #include "jit_load_store_emitters.hpp"
 #include "snippets_transformations/op/store_convert.hpp"
+// Matmul support:
+#include <cpu/x64/brgemm/brgemm.hpp>
+#include <cpu/x64/matmul/brgemm_matmul_copy_utils.hpp>
+#include <cpu/x64/matmul/brgemm_matmul_utils.hpp>
+#include <cpu/x64/amx_tile_configure.hpp>
 using namespace Xbyak;
 using ngraph::snippets::AllocatedEmitter;
@@ -98,7 +103,7 @@ class KernelEmitter : public jit_container_emitter {
     // Vector of indices (length = input tensor rank) per every input and output that describes in which order
     // corresponding tensor dimensions are accessed (default: consecutive dense, e.g. 0,1,2,3 for 4D tensor).
     // Needed to calc i/o offsets.
-    std::vector<std::vector<size_t>> data_access_pattern;
+    std::vector<std::vector<size_t>> data_layout;
     std::vector<std::vector<size_t>> io_shapes = {};
     std::vector<size_t> io_data_size {};
@@ -355,5 +360,50 @@ class StoreConvertEmitter : public MemoryEmitter {
     size_t count;
     std::unique_ptr<jit_store_emitter> store_emitter = nullptr;
 };
+
+class BrgemmEmitter : public jit_emitter {
+public:
+    BrgemmEmitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl::cpu::x64::cpu_isa_t isa, const std::shared_ptr<ov::Node>& n);
+
+    size_t get_inputs_num() const override {return 2;}
+
+private:
+    void emit_impl(const std::vector<size_t>& in,
+                   const std::vector<size_t>& out,
+                   const std::vector<size_t>& pool,
+                   const std::vector<size_t>& gpr,
+                   const ov::intel_cpu::emitter_context *emit_context) const override;
+
+    template <dnnl::impl::cpu::x64::cpu_isa_t isa>
+    void emit_isa(const std::vector<size_t> &in, const std::vector<size_t> &out) const;
+    std::vector<size_t> io_data_size {};
+    struct brgemmCtx {
+        size_t M, N, K, LDA, LDB, LDC;
+        dnnl_data_type_t dt_in0, dt_in1;
+        char palette[64];
+        bool is_with_amx;
+        bool is_with_comp;
+        float beta;
+    };
+    void initBrgemm(brgemmCtx& ctx, std::unique_ptr<brgemm_kernel_t>& brgKernel, bool use_amx) const;
+    template <dnnl::impl::cpu::x64::cpu_isa_t isa>
+    void callBrgemm(brgemmCtx& ctx, std::unique_ptr<brgemm_kernel_t>& brgKernel, const void* pin0, const void* pin1, void* pout, void* wsp) const;
+    size_t getBrgIdx(size_t mIdx, size_t kIdx, size_t nIdx) const;
+    template <dnnl::impl::cpu::x64::cpu_isa_t isa>
+    void emit_brgemm_kernel_call(const brgemm_kernel_t *brg_kernel, int bs,
+                                 Reg64 addr_A, Reg64 addr_B,
+                                 const brgemm_batch_element_t *batch, Reg64 addr_C, void *scratch) const;
+
+    static constexpr size_t BRGEMM_KERNELS_NUM = 8;
+    static constexpr size_t matmulOptimalM = 32;
+    brgemmCtx brgCtxs0[BRGEMM_KERNELS_NUM];
+    std::unique_ptr<brgemm_kernel_t> brgKernels0[BRGEMM_KERNELS_NUM];
+
+    size_t M, M_blk, M_tail;
+    size_t K, K_blk, K_tail;
+    size_t N, N_blk, N_tail;
+    size_t brg0VnniFactor;
+};
+
 } // namespace intel_cpu
 } // namespace ov
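The data_access_pattern -> data_layout rename in this header deserves a concrete example. A layout such as {0, 2, 1, 3} lists the dimensions of a 4D tensor in their physical (memory) order; a hypothetical offset helper, not part of the patch, makes the semantics explicit:

#include <cstddef>
#include <vector>

// Computes the physical offset of a logical index under a permuted layout.
size_t physical_offset(const std::vector<size_t>& idx,      // logical index, e.g. {n, c, h, w}
                       const std::vector<size_t>& dims,     // logical dims,  e.g. {N, C, H, W}
                       const std::vector<size_t>& layout) { // physical dim order, e.g. {0, 2, 1, 3}
    size_t off = 0;
    for (size_t d : layout)            // walk dimensions from outermost to innermost
        off = off * dims[d] + idx[d];
    return off;
}

With layout {0, 1, 2, 3} this reduces to the usual dense offset; {0, 2, 1, 3} yields the 0213-transposed addressing that the Brgemm emitter consumes through its leading dimensions (LDA/LDB/LDC).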
diff --git a/src/plugins/intel_cpu/src/nodes/subgraph.cpp b/src/plugins/intel_cpu/src/nodes/subgraph.cpp
index 6e7b692ce46e4a..d2a8f5381c9174 100644
--- a/src/plugins/intel_cpu/src/nodes/subgraph.cpp
+++ b/src/plugins/intel_cpu/src/nodes/subgraph.cpp
@@ -309,7 +309,7 @@ ov::PartialShape Snippet::canonicalizeBody() {
         output_blocked_shapes.push_back(blockedShape);
     }
-    const auto canonicalShape = snippet->canonicalize(output_blocked_shapes, input_blocked_shapes);
+    const auto& canonicalShape = snippet->canonicalize(output_blocked_shapes, input_blocked_shapes);
     return canonicalShape;
 }
 void Snippet::createPrimitive() {
diff --git a/src/plugins/intel_cpu/src/plugin.cpp b/src/plugins/intel_cpu/src/plugin.cpp
index 7612784bd522b9..fc33ea556e66ae 100644
--- a/src/plugins/intel_cpu/src/plugin.cpp
+++ b/src/plugins/intel_cpu/src/plugin.cpp
@@ -655,6 +655,7 @@ static void TransformationUpToCPUSpecificOpSet(std::shared_ptr
             // they can be tokenized only as a part of complex patterns
             const bool is_disabled_tokenization = (ov::is_type(n) ||
                                                    ov::is_type(n) ||
+                                                   ov::is_type(n) ||
                                                    ov::is_type(n));
             const auto& inputs = n->inputs();
             // todo: clarify whether we can evaluate snippets on const paths
diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/conv_eltwise.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/conv_eltwise.cpp
index bdf0fd38a50136..dcb2f96f2087e5 100644
--- a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/conv_eltwise.cpp
+++ b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/conv_eltwise.cpp
@@ -10,7 +10,7 @@ namespace test {
 namespace snippets {
 namespace {
-    ov::Shape convInputShape {1, 10, 16, 16};
+    ov::Shape convInputShape {1, 2, 16, 16};
     INSTANTIATE_TEST_SUITE_P(smoke_Snippets_ConvAdd, ConvEltwise,
                              ::testing::Combine(
                                      ::testing::Values(convInputShape),
diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/matmul.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/matmul.cpp
new file mode 100644
index 00000000000000..11fb9e9cc2a6fb
--- /dev/null
+++ b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/matmul.cpp
@@ -0,0 +1,34 @@
+// Copyright (C) 2022 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "snippets/matmul.hpp"
+#include "common_test_utils/test_constants.hpp"
+
+namespace ov {
+namespace test {
+namespace snippets {
+
+
+namespace {
+std::vector<std::vector<ov::PartialShape>> input_shapes{
+        {{2, 1, 3, 5}, {1, 3, 5, 3}},
+        {{3, 1, 32, 14}, {1, 2, 14, 32}},
+        {{1, 2, 37, 23}, {2, 1, 23, 37}},
+        {{1, 1, 37, 23}, {1, 2, 23, 33}},
+        {{2, 1, 69, 43}, {1, 1, 43, 49}}
+};
+std::vector<ov::element::Type> precisions{element::f32};
+INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MatMult, MatMul,
+                         ::testing::Combine(
+                                 ::testing::ValuesIn(input_shapes),
+                                 ::testing::ValuesIn(precisions),
+                                 ::testing::Values(3), // Sinh * 2 + MatMul
+                                 ::testing::Values(1), // Tokenized MatMul
+                                 ::testing::Values(CommonTestUtils::DEVICE_CPU)),
+                         MatMul::getTestCaseName);
+
+} // namespace
+} // namespace snippets
+} // namespace test
+} // namespace ov
\ No newline at end of file
diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/transpose_matmul.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/transpose_matmul.cpp
new file mode 100644
index 00000000000000..b573b5f36ff330
--- /dev/null
+++ b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/transpose_matmul.cpp
@@ -0,0 +1,63 @@
+// Copyright (C) 2022 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "snippets/transpose_matmul.hpp"
+#include "common_test_utils/test_constants.hpp"
+
+namespace ov {
+namespace test {
+namespace snippets {
+
+
+namespace {
+std::vector<ov::element::Type> precisions{element::f32};
+namespace transpose_zero_input {
+std::vector<std::vector<ov::PartialShape>> transpose_input_shapes{
+        {{1, 49, 2, 23}, {2, 2, 23, 39}}
+};
+INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MatMult, TransposeMatMul,
+                         ::testing::Combine(
+                                 ::testing::ValuesIn(transpose_input_shapes),
+                                 ::testing::Values(0), // Transpose on 0th Matmul input
+                                 ::testing::ValuesIn(precisions),
+                                 ::testing::Values(3), // Sinh * 2 + MatMul
+                                 ::testing::Values(1), // Tokenized MatMul + FusedTranspose
+                                 ::testing::Values(CommonTestUtils::DEVICE_CPU)),
+                         TransposeMatMul::getTestCaseName);
+} // namespace transpose_zero_input
+
+namespace transpose_first_input {
+std::vector<std::vector<ov::PartialShape>> transpose_input_shapes{
+        {{2, 1, 49, 13}, {1, 13, 3, 39}}
+};
+INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MatMult, TransposeMatMul,
+                         ::testing::Combine(
+                                 ::testing::ValuesIn(transpose_input_shapes),
+                                 ::testing::Values(1), // Transpose on 1st Matmul input
+                                 ::testing::ValuesIn(precisions),
+                                 ::testing::Values(3), // Sinh * 2 + MatMul
+                                 ::testing::Values(1), // Tokenized MatMul + FusedTranspose
+                                 ::testing::Values(CommonTestUtils::DEVICE_CPU)),
+                         TransposeMatMul::getTestCaseName);
+} // namespace transpose_first_input
+
+namespace transpose_output {
+std::vector<std::vector<ov::PartialShape>> transpose_input_shapes{
+        {{2, 1, 49, 13}, {1, 2, 13, 39}}
+};
+INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MatMult, TransposeMatMul,
+                         ::testing::Combine(
+                                 ::testing::ValuesIn(transpose_input_shapes),
+                                 ::testing::Values(2), // Transpose on Matmul output
+                                 ::testing::ValuesIn(precisions),
+                                 ::testing::Values(3), // Sinh * 2 + MatMul
+                                 ::testing::Values(1), // Tokenized MatMul + FusedTranspose
+                                 ::testing::Values(CommonTestUtils::DEVICE_CPU)),
+                         TransposeMatMul::getTestCaseName);
+} // namespace transpose_output
+
+} // namespace
+} // namespace snippets
+} // namespace test
+} // namespace ov
\ No newline at end of file
diff --git a/src/plugins/intel_cpu/tests/unit/ngraph_transformations/snipptes_mark_skipped.cpp b/src/plugins/intel_cpu/tests/unit/ngraph_transformations/snipptes_mark_skipped.cpp
index 9aab3ffdfe7a01..33dffb5be79fd9 100644
--- a/src/plugins/intel_cpu/tests/unit/ngraph_transformations/snipptes_mark_skipped.cpp
+++ b/src/plugins/intel_cpu/tests/unit/ngraph_transformations/snipptes_mark_skipped.cpp
@@ -19,6 +19,12 @@ class SnippetsMarkSkippedTests : public TransformationTestsF {
         manager.register_pass();
         manager.register_pass();
         manager.register_pass();
+        //
+        // todo: This is a temporary work-around. Remove when MatMul tokenization is supported through the general pipeline
+        manager.get_pass_config()->set_callback(
+                [](const std::shared_ptr<const ov::Node>& n) -> bool {
+                    return ov::is_type(n);
+                });
     }
 };
diff --git a/src/tests/functional/plugin/shared/include/snippets/matmul.hpp b/src/tests/functional/plugin/shared/include/snippets/matmul.hpp
new file mode 100644
index 00000000000000..ba213cc0da5597
--- /dev/null
+++ b/src/tests/functional/plugin/shared/include/snippets/matmul.hpp
@@ -0,0 +1,32 @@
+// Copyright (C) 2022 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include "shared_test_classes/base/snippets_test_utils.hpp"
+
+namespace ov {
+namespace test {
+namespace snippets {
+
+typedef std::tuple<
+        std::vector<ov::PartialShape>, // Input Shapes
+        ov::element::Type,             // Element type
+        size_t,                        // Expected num nodes
+        size_t,                        // Expected num subgraphs
+        std::string                    // Target Device
+> MatMulParams;
+
+class MatMul : public testing::WithParamInterface<MatMulParams>,
+               virtual public ov::test::SnippetsTestsCommon {
+public:
+    static std::string getTestCaseName(testing::TestParamInfo<MatMulParams> obj);
+
+protected:
+    void SetUp() override;
+};
+
+} // namespace snippets
+} // namespace test
+} // namespace ov
\ No newline at end of file
diff --git a/src/tests/functional/plugin/shared/include/snippets/transpose_matmul.hpp b/src/tests/functional/plugin/shared/include/snippets/transpose_matmul.hpp
new file mode 100644
index 00000000000000..f949e9df9d5c3b
--- /dev/null
+++ b/src/tests/functional/plugin/shared/include/snippets/transpose_matmul.hpp
@@ -0,0 +1,33 @@
+// Copyright (C) 2022 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include "shared_test_classes/base/snippets_test_utils.hpp"
+
+namespace ov {
+namespace test {
+namespace snippets {
+
+typedef std::tuple<
+        std::vector<ov::PartialShape>, // Input Shapes
+        size_t,                        // Transpose position
+        ov::element::Type,             // Element type
+        size_t,                        // Expected num nodes
+        size_t,                        // Expected num subgraphs
+        std::string                    // Target Device
+> TransposeMatMulParams;
+
+class TransposeMatMul : public testing::WithParamInterface<TransposeMatMulParams>,
+                        virtual public ov::test::SnippetsTestsCommon {
+public:
+    static std::string getTestCaseName(testing::TestParamInfo<TransposeMatMulParams> obj);
+
+protected:
+    void SetUp() override;
+};
+
+} // namespace snippets
+} // namespace test
+} // namespace ov
\ No newline at end of file
diff --git a/src/tests/functional/plugin/shared/src/snippets/matmul.cpp b/src/tests/functional/plugin/shared/src/snippets/matmul.cpp
new file mode 100644
index 00000000000000..0cbfc85a972e79
--- /dev/null
+++ b/src/tests/functional/plugin/shared/src/snippets/matmul.cpp
@@ -0,0 +1,54 @@
+// Copyright (C) 2022 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "common_test_utils/common_utils.hpp"
+#include "snippets/matmul.hpp"
+#include "subgraph_matmul.hpp"
+#include "functional_test_utils/skip_tests_config.hpp"
+#include "cpp_interfaces/interface/ie_internal_plugin_config.hpp"
+
+namespace ov {
+namespace test {
+namespace snippets {
+
+std::string MatMul::getTestCaseName(testing::TestParamInfo<MatMulParams> obj) {
+    std::vector<ov::PartialShape> input_shapes;
+    ov::element::Type elem_type;
+    std::string targetDevice;
+    size_t num_nodes, num_subgraphs;
+    std::tie(input_shapes, elem_type, num_nodes, num_subgraphs, targetDevice) = obj.param;
+    if (input_shapes.size() != 2)
+        IE_THROW() << "Invalid input shapes vector size";
+    std::ostringstream result;
+    result << "IS[0]=" << CommonTestUtils::partialShape2str({input_shapes[0]}) << "_";
+    result << "IS[1]=" << CommonTestUtils::partialShape2str({input_shapes[1]}) << "_";
+    result << "T=" << elem_type << "_";
+    result << "#N=" << num_nodes << "_";
+    result << "#S=" << num_subgraphs << "_";
+    result << "targetDevice=" << targetDevice;
+    return result.str();
+}
+
+void MatMul::SetUp() {
+    std::vector<ov::PartialShape> input_shapes;
+    ov::element::Type elem_type;
+    std::tie(input_shapes, elem_type, ref_num_nodes, ref_num_subgraphs, targetDevice) = this->GetParam();
+    init_input_shapes(dynamic_shapes_to_test_representation(input_shapes));
+
+    auto f = ov::test::snippets::MatMulSinhFunction(input_shapes);
+    function = f.getOriginal();
+    if (!configuration.count(InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE)) {
+        configuration.insert({InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE,
+                              InferenceEngine::PluginConfigInternalParams::IGNORE_CALLBACK});
+    }
+}
+
+TEST_P(MatMul, CompareWithRefImpl) {
+    run();
+    validateNumSubgraphs();
+}
+
+} // namespace snippets
+} // namespace test
+} // namespace ov
diff --git a/src/tests/functional/plugin/shared/src/snippets/transpose_matmul.cpp b/src/tests/functional/plugin/shared/src/snippets/transpose_matmul.cpp
new file mode 100644
index 00000000000000..ed3d057a1ab242
--- /dev/null
+++ b/src/tests/functional/plugin/shared/src/snippets/transpose_matmul.cpp
@@ -0,0 +1,57 @@
+// Copyright (C) 2022 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "common_test_utils/common_utils.hpp"
+#include "snippets/transpose_matmul.hpp"
+#include "subgraph_matmul.hpp"
+#include "functional_test_utils/skip_tests_config.hpp"
+#include "cpp_interfaces/interface/ie_internal_plugin_config.hpp"
+
+namespace ov {
+namespace test {
+namespace snippets {
+
+std::string TransposeMatMul::getTestCaseName(testing::TestParamInfo<TransposeMatMulParams> obj) {
+    std::vector<ov::PartialShape> input_shapes;
+    size_t transpose_position;
+    ov::element::Type elem_type;
+    std::string targetDevice;
+    size_t num_nodes, num_subgraphs;
+    std::tie(input_shapes, transpose_position, elem_type, num_nodes, num_subgraphs, targetDevice) = obj.param;
+    if (input_shapes.size() != 2)
+        IE_THROW() << "Invalid input shapes vector size";
+    std::ostringstream result;
+    result << "IS[0]=" << CommonTestUtils::partialShape2str({input_shapes[0]}) << "_";
+    result << "IS[1]=" << CommonTestUtils::partialShape2str({input_shapes[1]}) << "_";
+    result << "Pos=" << transpose_position << "_";
+    result << "T=" << elem_type << "_";
+    result << "#N=" << num_nodes << "_";
+    result << "#S=" << num_subgraphs << "_";
+    result << "targetDevice=" << targetDevice;
+    return result.str();
+}
+
+void TransposeMatMul::SetUp() {
+    std::vector<ov::PartialShape> input_shapes;
+    size_t transpose_position;
+    ov::element::Type elem_type;
+    std::tie(input_shapes, transpose_position, elem_type, ref_num_nodes, ref_num_subgraphs, targetDevice) = this->GetParam();
+    init_input_shapes(dynamic_shapes_to_test_representation(input_shapes));
+
+    auto f = ov::test::snippets::Transpose0213MatMulSinhFunction(input_shapes, transpose_position);
+    function = f.getOriginal();
+    if (!configuration.count(InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE)) {
+        configuration.insert({InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE,
+                              InferenceEngine::PluginConfigInternalParams::IGNORE_CALLBACK});
+    }
+}
+
+TEST_P(TransposeMatMul, CompareWithRefImpl) {
+    run();
+    validateNumSubgraphs();
+}
+
+} // namespace snippets
+} // namespace test
+} // namespace ov
diff --git a/src/tests/ngraph_helpers/snippets_ngraph_functions/include/subgraph_lowered.hpp b/src/tests/ngraph_helpers/snippets_ngraph_functions/include/subgraph_lowered.hpp
index 69027e96452751..7218f192a8dbcf 100644
--- a/src/tests/ngraph_helpers/snippets_ngraph_functions/include/subgraph_lowered.hpp
+++ b/src/tests/ngraph_helpers/snippets_ngraph_functions/include/subgraph_lowered.hpp
@@ -8,8 +8,9 @@
 #include "snippets_helpers.hpp"
 #include "subgraph_simple.hpp"
 #include "subgraph_converts.hpp"
+#include "subgraph_matmul.hpp"
-/* This file provides lowered representations (after the generate() was calles) for some simple functions.
+/* This file provides lowered representations (after the generate() was called) for some simple functions.
 * This is required to test snippets lowering and optimization passes. All the functions are expected to be direct
 * descendants of SnippetsFunctionCustomizable (defined here) and one of the SnippetsFunctionBase derived classes
 * (declared in subgraph_simple.hpp). Note that the corresponding SnippetsFunctionBase child should use virtual inheritance
@@ -51,6 +52,16 @@ class EltwiseThreeInputsLoweredFunction : public EltwiseThreeInputsFunction {
     std::vector<ov::PartialShape> broadcast_shapes;
 };
+
+class Transpose0213MatMulSinhLoweredFunction : public Transpose0213MatMulSinhFunction {
+public:
+    explicit Transpose0213MatMulSinhLoweredFunction(const std::vector<ov::PartialShape>& inputShapes, size_t position = 0) :
+            Transpose0213MatMulSinhFunction(inputShapes, position, false) {
+    }
+
+protected:
+    std::shared_ptr<ov::Model> initLowered() const override;
+};
+
 } // namespace snippets
 } // namespace test
 } // namespace ov
diff --git a/src/tests/ngraph_helpers/snippets_ngraph_functions/include/subgraph_matmul.hpp b/src/tests/ngraph_helpers/snippets_ngraph_functions/include/subgraph_matmul.hpp
new file mode 100644
index 00000000000000..374d24029bd6e6
--- /dev/null
+++ b/src/tests/ngraph_helpers/snippets_ngraph_functions/include/subgraph_matmul.hpp
@@ -0,0 +1,64 @@
+// Copyright (C) 2022 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include "ngraph/ngraph.hpp"
+#include "./snippets_helpers.hpp"
+
+/* This file contains definitions of relatively simple functions (models) that will be used
+ * to test snippets-specific behavior. All the functions are expected to be direct descendants of
+ * SnippetsFunctionBase, so their constructors take only one (inputShapes) argument.
+ */
+
+namespace ov {
+namespace test {
+namespace snippets {
+/// Minimal graph to test MatMul support.
+/// Works because Sinh is not supported by tokenization yet.
+/// Tokenized simply by starting a subgraph:
+//   in1     in2
+//   Sinh    Sinh
+//      Matmul
+//      Result
+// todo: remove Sinh once the "no subgraph after input" limitation is relaxed
+class MatMulSinhFunction : public SnippetsFunctionBase {
+public:
+    explicit MatMulSinhFunction(const std::vector<ov::PartialShape>& inputShapes)
+            : SnippetsFunctionBase(inputShapes) {
+        NGRAPH_CHECK(input_shapes.size() == 2, "Got invalid number of input shapes");
+    }
+protected:
+    std::shared_ptr<ov::Model> initOriginal() const override;
+    std::shared_ptr<ov::Model> initReference() const override;
+};
+
+/// Minimal graph to test MatMul+Transpose combinations. Transpose location is specified via the position argument:
+/// 0 - before the first MatMul input; 1 - before the second MatMul input; 2 - after the MatMul output.
+/// Tokenized simply by starting a subgraph:
+//   in1       in2
+//   Sinh      Sinh
+// Transpose   /
+//      Matmul
+//      Result
+// todo: remove Sinh once the "no subgraph after input" limitation is relaxed
+class Transpose0213MatMulSinhFunction : public SnippetsFunctionBase {
+public:
+    explicit Transpose0213MatMulSinhFunction(const std::vector<ov::PartialShape>& inputShapes, size_t position = 0,
+                                             bool insert_guard = true)
+            : SnippetsFunctionBase(inputShapes), transpose_position(position), insert_guard(insert_guard) {
+        NGRAPH_CHECK(input_shapes.size() == 2, "Got invalid number of input shapes");
+        NGRAPH_CHECK(input_shapes[0].rank().get_length() == 4 && input_shapes[1].rank().get_length() == 4,
+                     "Only rank 4 input shapes are supported by this test");
+        NGRAPH_CHECK(transpose_position >= 0 && transpose_position <= 2, "Got invalid transpose position");
+    }
+protected:
+    std::shared_ptr<ov::Model> initOriginal() const override;
+    size_t transpose_position;
+    bool insert_guard; // true if Sinh ops should be inserted after inputs
+};
+
+} // namespace snippets
+} // namespace test
+} // namespace ov
diff --git a/src/tests/ngraph_helpers/snippets_ngraph_functions/src/subgraph_lowered.cpp b/src/tests/ngraph_helpers/snippets_ngraph_functions/src/subgraph_lowered.cpp
index afea8266be0e04..86d07b912f9ea2 100644
--- a/src/tests/ngraph_helpers/snippets_ngraph_functions/src/subgraph_lowered.cpp
+++ b/src/tests/ngraph_helpers/snippets_ngraph_functions/src/subgraph_lowered.cpp
@@ -105,6 +105,24 @@ std::shared_ptr EltwiseThreeInputsLoweredFunction::initLowered() const {
     }
     return model;
 }
+
+std::shared_ptr<ov::Model> Transpose0213MatMulSinhLoweredFunction::initLowered() const {
+    ParameterVector data{std::make_shared<op::v0::Parameter>(precision, input_shapes[0]),
+                         std::make_shared<op::v0::Parameter>(precision, input_shapes[1])};
+    std::vector<size_t> layout{0, 2, 1, 3};
+    // Note: validity of transpose_position values is checked in Transpose0213MatMulSinhFunction constructor
+    if (transpose_position <= 1) {
+        auto& rt_info = data[transpose_position]->get_rt_info();
+        rt_info["Layout"] = layout;
+    }
+    auto matmul = std::make_shared<ngraph::snippets::op::Brgemm>(data[0], data[1]);
+    if (transpose_position == 2) {
+        auto& rt_info = matmul->get_rt_info();
+        rt_info["Layout"] = layout;
+        matmul->validate_and_infer_types();
+    }
+    return std::make_shared<ov::Model>(NodeVector{matmul}, data);
+}
 } // namespace snippets
 } // namespace test
 } // namespace ov
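Note how initLowered() above communicates the fused transpose: no Transpose op is inserted; instead a "Layout" vector is attached to the node's rt_info for the plugin emitters to pick up. A minimal read-back helper could look like this (hypothetical, for illustration only):

#include <memory>
#include <vector>
#include <openvino/core/node.hpp>

// Hypothetical consumer of the "Layout" rt_info hint set in initLowered():
std::vector<size_t> get_fused_layout(const std::shared_ptr<ov::Node>& node) {
    const auto& rt = node->get_rt_info();
    const auto it = rt.find("Layout");
    if (it == rt.end())
        return {};                                // no fused transpose on this node
    return it->second.as<std::vector<size_t>>();  // e.g. {0, 2, 1, 3}
}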
diff --git a/src/tests/ngraph_helpers/snippets_ngraph_functions/src/subgraph_matmul.cpp b/src/tests/ngraph_helpers/snippets_ngraph_functions/src/subgraph_matmul.cpp
new file mode 100644
index 00000000000000..266593a6ff8624
--- /dev/null
+++ b/src/tests/ngraph_helpers/snippets_ngraph_functions/src/subgraph_matmul.cpp
@@ -0,0 +1,58 @@
+// Copyright (C) 2022 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "subgraph_matmul.hpp"
+#include "common_test_utils/data_utils.hpp"
+#include <snippets/op/subgraph.hpp>
+
+namespace ov {
+namespace test {
+namespace snippets {
+std::shared_ptr<ov::Model> MatMulSinhFunction::initOriginal() const {
+    auto data0 = std::make_shared<op::v0::Parameter>(precision, input_shapes[0]);
+    auto sinh0 = std::make_shared<op::v0::Sinh>(data0);
+    auto data1 = std::make_shared<op::v0::Parameter>(precision, input_shapes[1]);
+    auto sinh1 = std::make_shared<op::v0::Sinh>(data1);
+    auto matmul = std::make_shared<op::v0::MatMul>(sinh0, sinh1);
+    return std::make_shared<ov::Model>(NodeVector{matmul}, ParameterVector{data0, data1});
+}
+std::shared_ptr<ov::Model> MatMulSinhFunction::initReference() const {
+    auto data0 = std::make_shared<op::v0::Parameter>(precision, input_shapes[0]);
+    auto sinh0 = std::make_shared<op::v0::Sinh>(data0);
+    auto data1 = std::make_shared<op::v0::Parameter>(precision, input_shapes[1]);
+    auto sinh1 = std::make_shared<op::v0::Sinh>(data1);
+    auto indata0 = std::make_shared<op::v0::Parameter>(precision, sinh0->get_output_partial_shape(0));
+    auto indata1 = std::make_shared<op::v0::Parameter>(precision, sinh1->get_output_partial_shape(0));
+    auto matmul = std::make_shared<ngraph::snippets::op::Subgraph>(NodeVector{sinh0, sinh1},
+                                    std::make_shared<ov::Model>(NodeVector{std::make_shared<op::v0::MatMul>(indata0, indata1)},
+                                                                ParameterVector{indata0, indata1}));
+    return std::make_shared<ov::Model>(NodeVector{matmul}, ParameterVector{data0, data1});
+}
+std::shared_ptr<ov::Model> Transpose0213MatMulSinhFunction::initOriginal() const {
+    auto data0 = std::make_shared<op::v0::Parameter>(precision, input_shapes[0]);
+    auto data0_guarded = insert_guard ? std::make_shared<op::v0::Sinh>(data0)->output(0) : data0->output(0);
+    auto data1 = std::make_shared<op::v0::Parameter>(precision, input_shapes[1]);
+    auto data1_guarded = insert_guard ? std::make_shared<op::v0::Sinh>(data1)->output(0) : data1->output(0);
+    auto const_order = std::make_shared<op::v0::Constant>(ov::element::i32, Shape {4}, std::vector<int>{0, 2, 1, 3});
+    std::shared_ptr<ov::Node> result;
+    switch (transpose_position) {
+        case 0: {
+            auto transpose = std::make_shared<op::v1::Transpose>(data0_guarded, const_order);
+            result = std::make_shared<op::v0::MatMul>(transpose, data1_guarded);
+            break;
+        } case 1: {
+            auto transpose = std::make_shared<op::v1::Transpose>(data1_guarded, const_order);
+            result = std::make_shared<op::v0::MatMul>(data0_guarded, transpose);
+            break;
+        } case 2: {
+            auto matmul = std::make_shared<op::v0::MatMul>(data0_guarded, data1_guarded);
+            result = std::make_shared<op::v1::Transpose>(matmul, const_order);
+            break;
+        }
+    }
+    return std::make_shared<ov::Model>(NodeVector{result}, ParameterVector{data0, data1});
+}
+} // namespace snippets
+} // namespace test
+} // namespace ov
\ No newline at end of file
diff --git a/src/tests/ngraph_helpers/snippets_ngraph_functions/src/subgraph_simple.cpp b/src/tests/ngraph_helpers/snippets_ngraph_functions/src/subgraph_simple.cpp
index 6fa4648a5548a9..d58660a6714eef 100644
--- a/src/tests/ngraph_helpers/snippets_ngraph_functions/src/subgraph_simple.cpp
+++ b/src/tests/ngraph_helpers/snippets_ngraph_functions/src/subgraph_simple.cpp
@@ -147,7 +147,9 @@ std::shared_ptr EltwiseMaxNumParamsSinhFunction::initOriginal() const {
 std::shared_ptr<ov::Model> MatMulEltwiseBranchesFunction::initOriginal() const {
     auto data_1 = std::make_shared<op::v0::Parameter>(precision, input_shapes[0]);
     auto data_2 = std::make_shared<op::v0::Parameter>(precision, input_shapes[1]);
-    auto non_snippet_op = std::make_shared<op::v0::MatMul>(data_1, data_2);
+    auto sinh_1 = std::make_shared<op::v0::Sinh>(data_1);
+    auto sinh_2 = std::make_shared<op::v0::Sinh>(data_2);
+    auto non_snippet_op = std::make_shared<op::v0::MatMul>(sinh_1, sinh_2);
     const std::vector<float> const_values = CommonTestUtils::generate_float_numbers(4, -10., 10.);
     auto mul_const_1 = op::v0::Constant::create(precision, {1}, {const_values[0]});
     auto mul_1 = std::make_shared<op::v1::Multiply>(non_snippet_op, mul_const_1);
@@ -170,9 +172,11 @@ std::shared_ptr MatMulEltwiseBranchesFunction::initReference() const {
     auto data_1 = std::make_shared<op::v0::Parameter>(precision, input_shapes[0]);
     auto data_2 = std::make_shared<op::v0::Parameter>(precision, input_shapes[1]);
+    auto sinh_1 = std::make_shared<op::v0::Sinh>(data_1);
+    auto sinh_2 = std::make_shared<op::v0::Sinh>(data_2);
     const std::vector<float> const_values = CommonTestUtils::generate_float_numbers(4, -10., 10.);
     // snippet inputs
-    auto non_snippet_op = std::make_shared<op::v0::MatMul>(data_1, data_2);
+    auto non_snippet_op = std::make_shared<op::v0::MatMul>(sinh_1, sinh_2);
     auto mul_const_1 = std::make_shared<op::v0::Constant>(precision, Shape{1}, const_values[0]);
     auto add_const_1 = std::make_shared<op::v0::Constant>(precision, Shape{1}, const_values[1]);
     auto mul_const_2 = std::make_shared<op::v0::Constant>(precision, Shape{1}, const_values[2]);
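A closing remark on the Sinh insertions that appear throughout these helpers: as the comments in subgraph_matmul.hpp note, tokenization cannot currently start right at a Parameter, so an op that snippets does not tokenize is placed between the input and the pattern under test. Schematically (illustration only; shapes are arbitrary):

#include <memory>
#include <openvino/core/model.hpp>
#include <openvino/op/matmul.hpp>
#include <openvino/op/parameter.hpp>
#include <openvino/op/sinh.hpp>

// The "Sinh guard" pattern used by the test graphs above:
std::shared_ptr<ov::Model> make_guarded_matmul() {
    auto param = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{1, 2, 16, 16});
    auto guard = std::make_shared<ov::op::v0::Sinh>(param);  // keeps the Parameter outside the tokenized subgraph
    auto matmul = std::make_shared<ov::op::v0::MatMul>(guard, guard);
    return std::make_shared<ov::Model>(ov::NodeVector{matmul}, ov::ParameterVector{param});
}

// As the sources' todo notes say, these guards can be dropped once the
// "no subgraph after input" limitation is relaxed.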