Skip to content

Commit

Permalink
Add another pattern
Browse files Browse the repository at this point in the history
  • Loading branch information
smirnov-alexey committed Nov 19, 2024
1 parent 8456ced commit c821d30
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ static const std::map<std::string, std::string> ISOL_PRESETS = {{"COMPUTE",
"P:DQMatMulGQi4/compute,P:DQMatMulCWi4/compute,"
"P:DQMatMulConv/compute,"
"P:VocabMatMul/compute,"
"P:RMSNorm/compute"}};
"P:RMSNorm/compute,P:RMSNorm2/compute"}};
}

// For missing declaration warning
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -469,6 +469,9 @@ void Snapshot::earlyRegroup() {
if (isolate.pattern == "RMSNorm") {
rewr.add_matcher<ov::npuw::patterns::compute::RMSNorm>(shared_from_this(), isolate.tag);
handle_patterns = true;
} else if (isolate.pattern == "RMSNorm2") {
rewr.add_matcher<ov::npuw::patterns::compute::RMSNorm2>(shared_from_this(), isolate.tag);
handle_patterns = true;
} else if (isolate.pattern == "DQMatMulCWu4") {
rewr.add_matcher<ov::npuw::patterns::compute::DQMatMulCWu4>(shared_from_this(), isolate.tag);
handle_patterns = true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,7 @@ RMSNorm::RMSNorm(const std::shared_ptr<ov::npuw::online::Snapshot>& snapshot, co

// Note: Use [=] to make sure the above objects stay alive in the callback
auto callback = [=](ov::pass::pattern::Matcher& m) {
std::cout << "RMSNorm MATCHED!" << std::endl;
auto& node_to_output = m.get_pattern_value_map();

auto matched_hadd = node_to_output.at(hadd).get_node_shared_ptr();
Expand All @@ -373,6 +374,40 @@ RMSNorm::RMSNorm(const std::shared_ptr<ov::npuw::online::Snapshot>& snapshot, co
register_matcher(std::make_shared<opp::Matcher>(multiply2, "TagRMSNorm"), std::move(callback));
}

// TODO: visualize
RMSNorm2::RMSNorm2(const std::shared_ptr<ov::npuw::online::Snapshot>& snapshot, const std::string& isol_tag) {
auto hadd = opp::wrap_type<ov::op::v1::Add>({opp::any_input(), opp::any_input()});
auto power = opp::wrap_type<ov::op::v1::Power>({hadd, opp::any_input()});
auto reduce = opp::wrap_type<ov::op::v1::ReduceSum>({power, opp::any_input()});
auto sqrt = opp::wrap_type<ov::op::v0::Sqrt>({reduce});
auto div = opp::wrap_type<ov::op::v1::Divide>({hadd, sqrt});
auto multiply = opp::wrap_type<ov::op::v1::Multiply>({opp::any_input(), div});

auto node_to_gptr = snapshot->getNodeToGroupMap();

// Note: Use [=] to make sure the above objects stay alive in the callback
auto callback = [=](ov::pass::pattern::Matcher& m) {
auto& node_to_output = m.get_pattern_value_map();

auto matched_hadd = node_to_output.at(hadd).get_node_shared_ptr();
auto matched_power = node_to_output.at(power).get_node_shared_ptr();
auto matched_reduce = node_to_output.at(reduce).get_node_shared_ptr();
auto matched_sqrt = node_to_output.at(sqrt).get_node_shared_ptr();
auto matched_div = node_to_output.at(div).get_node_shared_ptr();
auto matched_multiply = node_to_output.at(multiply).get_node_shared_ptr();

node_to_gptr->at(matched_hadd)->isolate(isol_tag);
node_to_gptr->at(matched_power)->isolate(isol_tag);
node_to_gptr->at(matched_reduce)->isolate(isol_tag);
node_to_gptr->at(matched_sqrt)->isolate(isol_tag);
node_to_gptr->at(matched_div)->isolate(isol_tag);
node_to_gptr->at(matched_multiply)->isolate(isol_tag);

return false; // root hasn't changed
};
register_matcher(std::make_shared<opp::Matcher>(multiply, "TagRMSNorm2"), std::move(callback));
}

} // namespace compute
} // namespace patterns
} // namespace npuw
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,11 @@ class RMSNorm : public ov::pass::MatcherPass {
RMSNorm(const std::shared_ptr<ov::npuw::online::Snapshot>& snapshot, const std::string& isol_tag);
};

class RMSNorm2 : public ov::pass::MatcherPass {
public:
RMSNorm2(const std::shared_ptr<ov::npuw::online::Snapshot>& snapshot, const std::string& isol_tag);
};

} // namespace compute
} // namespace patterns
} // namespace npuw
Expand Down

0 comments on commit c821d30

Please sign in to comment.