diff --git a/src/plugins/intel_npu/src/plugin/npuw/partitioning/online/compiler.cpp b/src/plugins/intel_npu/src/plugin/npuw/partitioning/online/compiler.cpp index 2951ea353968da..a81beea0b5ab0c 100644 --- a/src/plugins/intel_npu/src/plugin/npuw/partitioning/online/compiler.cpp +++ b/src/plugins/intel_npu/src/plugin/npuw/partitioning/online/compiler.cpp @@ -27,7 +27,7 @@ static const std::map ISOL_PRESETS = {{"COMPUTE", "P:DQMatMulGQi4/compute,P:DQMatMulCWi4/compute," "P:DQMatMulConv/compute," "P:VocabMatMul/compute," - "P:RMSNorm/compute"}}; + "P:RMSNorm/compute,P:RMSNorm2/compute"}}; } // For missing declaration warning diff --git a/src/plugins/intel_npu/src/plugin/npuw/partitioning/online/snapshot.cpp b/src/plugins/intel_npu/src/plugin/npuw/partitioning/online/snapshot.cpp index d8fbad01b1f775..6767e7314f62dc 100644 --- a/src/plugins/intel_npu/src/plugin/npuw/partitioning/online/snapshot.cpp +++ b/src/plugins/intel_npu/src/plugin/npuw/partitioning/online/snapshot.cpp @@ -469,6 +469,9 @@ void Snapshot::earlyRegroup() { if (isolate.pattern == "RMSNorm") { rewr.add_matcher(shared_from_this(), isolate.tag); handle_patterns = true; + } else if (isolate.pattern == "RMSNorm2") { + rewr.add_matcher(shared_from_this(), isolate.tag); + handle_patterns = true; } else if (isolate.pattern == "DQMatMulCWu4") { rewr.add_matcher(shared_from_this(), isolate.tag); handle_patterns = true; diff --git a/src/plugins/intel_npu/src/plugin/npuw/partitioning/patterns/compute.cpp b/src/plugins/intel_npu/src/plugin/npuw/partitioning/patterns/compute.cpp index 14c91b70e279c2..b6da99e8b4d5ea 100644 --- a/src/plugins/intel_npu/src/plugin/npuw/partitioning/patterns/compute.cpp +++ b/src/plugins/intel_npu/src/plugin/npuw/partitioning/patterns/compute.cpp @@ -348,6 +348,7 @@ RMSNorm::RMSNorm(const std::shared_ptr& snapshot, co // Note: Use [=] to make sure the above objects stay alive in the callback auto callback = [=](ov::pass::pattern::Matcher& m) { + std::cout << "RMSNorm MATCHED!" << std::endl; auto& node_to_output = m.get_pattern_value_map(); auto matched_hadd = node_to_output.at(hadd).get_node_shared_ptr(); @@ -373,6 +374,40 @@ RMSNorm::RMSNorm(const std::shared_ptr& snapshot, co register_matcher(std::make_shared(multiply2, "TagRMSNorm"), std::move(callback)); } +// TODO: visualize +RMSNorm2::RMSNorm2(const std::shared_ptr& snapshot, const std::string& isol_tag) { + auto hadd = opp::wrap_type({opp::any_input(), opp::any_input()}); + auto power = opp::wrap_type({hadd, opp::any_input()}); + auto reduce = opp::wrap_type({power, opp::any_input()}); + auto sqrt = opp::wrap_type({reduce}); + auto div = opp::wrap_type({hadd, sqrt}); + auto multiply = opp::wrap_type({opp::any_input(), div}); + + auto node_to_gptr = snapshot->getNodeToGroupMap(); + + // Note: Use [=] to make sure the above objects stay alive in the callback + auto callback = [=](ov::pass::pattern::Matcher& m) { + auto& node_to_output = m.get_pattern_value_map(); + + auto matched_hadd = node_to_output.at(hadd).get_node_shared_ptr(); + auto matched_power = node_to_output.at(power).get_node_shared_ptr(); + auto matched_reduce = node_to_output.at(reduce).get_node_shared_ptr(); + auto matched_sqrt = node_to_output.at(sqrt).get_node_shared_ptr(); + auto matched_div = node_to_output.at(div).get_node_shared_ptr(); + auto matched_multiply = node_to_output.at(multiply).get_node_shared_ptr(); + + node_to_gptr->at(matched_hadd)->isolate(isol_tag); + node_to_gptr->at(matched_power)->isolate(isol_tag); + node_to_gptr->at(matched_reduce)->isolate(isol_tag); + node_to_gptr->at(matched_sqrt)->isolate(isol_tag); + node_to_gptr->at(matched_div)->isolate(isol_tag); + node_to_gptr->at(matched_multiply)->isolate(isol_tag); + + return false; // root hasn't changed + }; + register_matcher(std::make_shared(multiply, "TagRMSNorm2"), std::move(callback)); +} + } // namespace compute } // namespace patterns } // namespace npuw diff --git a/src/plugins/intel_npu/src/plugin/npuw/partitioning/patterns/compute.hpp b/src/plugins/intel_npu/src/plugin/npuw/partitioning/patterns/compute.hpp index ab6c177d8ef924..77bc9fb3f90418 100644 --- a/src/plugins/intel_npu/src/plugin/npuw/partitioning/patterns/compute.hpp +++ b/src/plugins/intel_npu/src/plugin/npuw/partitioning/patterns/compute.hpp @@ -56,6 +56,11 @@ class RMSNorm : public ov::pass::MatcherPass { RMSNorm(const std::shared_ptr& snapshot, const std::string& isol_tag); }; +class RMSNorm2 : public ov::pass::MatcherPass { +public: + RMSNorm2(const std::shared_ptr& snapshot, const std::string& isol_tag); +}; + } // namespace compute } // namespace patterns } // namespace npuw