From caef0ab2d2e862e20dcb475d6b887a77886839a2 Mon Sep 17 00:00:00 2001 From: Alexey Smirnov Date: Mon, 25 Nov 2024 16:59:52 +0000 Subject: [PATCH] [NPUW] New compute patterns (#27618) --- .../npuw/partitioning/online/compiler.cpp | 3 +- .../npuw/partitioning/online/snapshot.cpp | 64 +++++++++------ .../npuw/partitioning/patterns/compute.cpp | 79 +++++++++++++++++++ .../npuw/partitioning/patterns/compute.hpp | 10 +++ 4 files changed, 133 insertions(+), 23 deletions(-) diff --git a/src/plugins/intel_npu/src/plugin/npuw/partitioning/online/compiler.cpp b/src/plugins/intel_npu/src/plugin/npuw/partitioning/online/compiler.cpp index 89a0e0d2da9b23..a81beea0b5ab0c 100644 --- a/src/plugins/intel_npu/src/plugin/npuw/partitioning/online/compiler.cpp +++ b/src/plugins/intel_npu/src/plugin/npuw/partitioning/online/compiler.cpp @@ -25,8 +25,9 @@ namespace { static const std::map ISOL_PRESETS = {{"COMPUTE", "P:DQMatMulGQu4/compute,P:DQMatMulCWu4/compute," "P:DQMatMulGQi4/compute,P:DQMatMulCWi4/compute," + "P:DQMatMulConv/compute," "P:VocabMatMul/compute," - "P:RMSNorm/compute"}}; + "P:RMSNorm/compute,P:RMSNorm2/compute"}}; } // For missing declaration warning diff --git a/src/plugins/intel_npu/src/plugin/npuw/partitioning/online/snapshot.cpp b/src/plugins/intel_npu/src/plugin/npuw/partitioning/online/snapshot.cpp index c8a27c47665021..f1ef604033481d 100644 --- a/src/plugins/intel_npu/src/plugin/npuw/partitioning/online/snapshot.cpp +++ b/src/plugins/intel_npu/src/plugin/npuw/partitioning/online/snapshot.cpp @@ -465,28 +465,20 @@ void Snapshot::earlyRegroup() { break; } case PatternType::PATTERN: { - // FIXME: refactor as more patterns are supported - if (isolate.pattern == "RMSNorm") { - rewr.add_matcher(shared_from_this(), isolate.tag); - handle_patterns = true; - } else if (isolate.pattern == "DQMatMulCWu4") { - rewr.add_matcher(shared_from_this(), isolate.tag); - handle_patterns = true; - } else if (isolate.pattern == "DQMatMulGQu4") { - rewr.add_matcher(shared_from_this(), isolate.tag); - handle_patterns = true; - } else if (isolate.pattern == "DQMatMulCWi4") { - rewr.add_matcher(shared_from_this(), isolate.tag); - handle_patterns = true; - } else if (isolate.pattern == "DQMatMulGQi4") { - rewr.add_matcher(shared_from_this(), isolate.tag); - handle_patterns = true; - } else if (isolate.pattern == "VocabMatMul") { - rewr.add_matcher(shared_from_this(), isolate.tag); - handle_patterns = true; - } else { - LOG_WARN("OPENVINO_NPUW_ISOLATE: unsupported pattern " << isolate.pattern << " is skipped!"); - } +#define HNDL(p) \ + if (isolate.pattern == #p) { \ + rewr.add_matcher(shared_from_this(), isolate.tag); \ + handle_patterns = true; \ + } + HNDL(RMSNorm); + HNDL(RMSNorm2); + HNDL(DQMatMulCWu4); + HNDL(DQMatMulGQu4); + HNDL(DQMatMulCWi4); + HNDL(DQMatMulGQi4); + HNDL(DQMatMulConv); + HNDL(VocabMatMul); +#undef HNDL } } } @@ -723,6 +715,20 @@ std::shared_ptr Snapshot::tryMergeTriangles(const std::vector Snapshot::tryMergeRepeating(const std::vector new_rep = std::make_shared(); for (size_t i = 0; i < conss.size(); ++i) { diff --git a/src/plugins/intel_npu/src/plugin/npuw/partitioning/patterns/compute.cpp b/src/plugins/intel_npu/src/plugin/npuw/partitioning/patterns/compute.cpp index b082d67037db7d..d39c2363b1cd64 100644 --- a/src/plugins/intel_npu/src/plugin/npuw/partitioning/patterns/compute.cpp +++ b/src/plugins/intel_npu/src/plugin/npuw/partitioning/patterns/compute.cpp @@ -226,6 +226,51 @@ DQMatMulCWi4::DQMatMulCWi4(const std::shared_ptr& sn register_matcher(std::make_shared(qmm, "TagDQMatMulCWi4"), std::move(callback)); } +// Pattern: +// -> Transpose ------------------------------> +// Param/Const --> Convert(f32) --> Multiply -> Convolution -> Transpose -> +// Param/Const -> (Convert(f32)) -> + +DQMatMulConv::DQMatMulConv(const std::shared_ptr& snapshot, const std::string& isol_tag) { + auto param = opp::any_input(); + auto convert = opp::wrap_type({param->output(0)}); + auto param2 = opp::any_input(); + auto convert2 = opp::optional({param2->output(0)}); + auto multiply = opp::wrap_type({convert, convert2}); + auto transpose_in = opp::wrap_type({opp::any_input(), opp::any_input()}); + auto conv = opp::wrap_type({transpose_in, multiply}); + auto transpose_out = opp::wrap_type({conv, opp::any_input()}); + + auto node_to_gptr = snapshot->getNodeToGroupMap(); + + // Note: Use [=] to make sure the above objects stay alive in the callback + auto callback = [=](ov::pass::pattern::Matcher& m) { + auto& node_to_output = m.get_pattern_value_map(); + + auto matched_node_param = node_to_output.at(param); + auto matched_node_param2 = node_to_output.at(param2); + + auto matched_node_transpose_in = node_to_output.at(transpose_in).get_node_shared_ptr(); + auto matched_node_transpose_out = node_to_output.at(transpose_out).get_node_shared_ptr(); + auto matched_node_multiply = node_to_output.at(multiply).get_node_shared_ptr(); + auto matched_node_conv = node_to_output.at(conv).get_node_shared_ptr(); + + if ((matched_node_param.get_element_type() == ov::element::i4 || + matched_node_param.get_element_type() == ov::element::i8) && + (matched_node_param2.get_element_type() == ov::element::f32 || + matched_node_param2.get_element_type() == ov::element::f16)) { + // Partitioning ignores Param/Const -> Convert nodes + node_to_gptr->at(matched_node_transpose_in)->isolate(isol_tag); + node_to_gptr->at(matched_node_transpose_out)->isolate(isol_tag); + node_to_gptr->at(matched_node_multiply)->isolate(isol_tag); + node_to_gptr->at(matched_node_conv)->isolate(isol_tag); + } + + return false; // root hasn't changed + }; + register_matcher(std::make_shared(transpose_out, "TagDQMatMulConv"), std::move(callback)); +} + // This is a case for Raw (f16/f32) MatMul connected directly to the Result. // // The following combinations are covered: @@ -327,6 +372,40 @@ RMSNorm::RMSNorm(const std::shared_ptr& snapshot, co register_matcher(std::make_shared(multiply2, "TagRMSNorm"), std::move(callback)); } +// TODO: visualize +RMSNorm2::RMSNorm2(const std::shared_ptr& snapshot, const std::string& isol_tag) { + auto hadd = opp::wrap_type({opp::any_input(), opp::any_input()}); + auto power = opp::wrap_type({hadd, opp::any_input()}); + auto reduce = opp::wrap_type({power, opp::any_input()}); + auto sqrt = opp::wrap_type({reduce}); + auto div = opp::wrap_type({hadd, sqrt}); + auto multiply = opp::wrap_type({opp::any_input(), div}); + + auto node_to_gptr = snapshot->getNodeToGroupMap(); + + // Note: Use [=] to make sure the above objects stay alive in the callback + auto callback = [=](ov::pass::pattern::Matcher& m) { + auto& node_to_output = m.get_pattern_value_map(); + + auto matched_hadd = node_to_output.at(hadd).get_node_shared_ptr(); + auto matched_power = node_to_output.at(power).get_node_shared_ptr(); + auto matched_reduce = node_to_output.at(reduce).get_node_shared_ptr(); + auto matched_sqrt = node_to_output.at(sqrt).get_node_shared_ptr(); + auto matched_div = node_to_output.at(div).get_node_shared_ptr(); + auto matched_multiply = node_to_output.at(multiply).get_node_shared_ptr(); + + node_to_gptr->at(matched_hadd)->isolate(isol_tag); + node_to_gptr->at(matched_power)->isolate(isol_tag); + node_to_gptr->at(matched_reduce)->isolate(isol_tag); + node_to_gptr->at(matched_sqrt)->isolate(isol_tag); + node_to_gptr->at(matched_div)->isolate(isol_tag); + node_to_gptr->at(matched_multiply)->isolate(isol_tag); + + return false; // root hasn't changed + }; + register_matcher(std::make_shared(multiply, "TagRMSNorm2"), std::move(callback)); +} + } // namespace compute } // namespace patterns } // namespace npuw diff --git a/src/plugins/intel_npu/src/plugin/npuw/partitioning/patterns/compute.hpp b/src/plugins/intel_npu/src/plugin/npuw/partitioning/patterns/compute.hpp index faa2fe3f0f9578..77bc9fb3f90418 100644 --- a/src/plugins/intel_npu/src/plugin/npuw/partitioning/patterns/compute.hpp +++ b/src/plugins/intel_npu/src/plugin/npuw/partitioning/patterns/compute.hpp @@ -41,6 +41,11 @@ class DQMatMulCWi4 : public ov::pass::MatcherPass { DQMatMulCWi4(const std::shared_ptr& snapshot, const std::string& isol_tag); }; +class DQMatMulConv : public ov::pass::MatcherPass { +public: + DQMatMulConv(const std::shared_ptr& snapshot, const std::string& isol_tag); +}; + class VocabMatMul : public ov::pass::MatcherPass { public: VocabMatMul(const std::shared_ptr& snapshot, const std::string& isol_tag); @@ -51,6 +56,11 @@ class RMSNorm : public ov::pass::MatcherPass { RMSNorm(const std::shared_ptr& snapshot, const std::string& isol_tag); }; +class RMSNorm2 : public ov::pass::MatcherPass { +public: + RMSNorm2(const std::shared_ptr& snapshot, const std::string& isol_tag); +}; + } // namespace compute } // namespace patterns } // namespace npuw