-
Notifications
You must be signed in to change notification settings - Fork 2.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[NPUW] New compute patterns #27618
[NPUW] New compute patterns #27618
Changes from 5 commits
d4cd1e2
8456ced
c821d30
295a5b0
65686c2
bf1c5a9
224c0e3
0c1ee14
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -465,28 +465,20 @@ void Snapshot::earlyRegroup() { | |
break; | ||
} | ||
case PatternType::PATTERN: { | ||
// FIXME: refactor as more patterns are supported | ||
if (isolate.pattern == "RMSNorm") { | ||
rewr.add_matcher<ov::npuw::patterns::compute::RMSNorm>(shared_from_this(), isolate.tag); | ||
handle_patterns = true; | ||
} else if (isolate.pattern == "DQMatMulCWu4") { | ||
rewr.add_matcher<ov::npuw::patterns::compute::DQMatMulCWu4>(shared_from_this(), isolate.tag); | ||
handle_patterns = true; | ||
} else if (isolate.pattern == "DQMatMulGQu4") { | ||
rewr.add_matcher<ov::npuw::patterns::compute::DQMatMulGQu4>(shared_from_this(), isolate.tag); | ||
handle_patterns = true; | ||
} else if (isolate.pattern == "DQMatMulCWi4") { | ||
rewr.add_matcher<ov::npuw::patterns::compute::DQMatMulCWi4>(shared_from_this(), isolate.tag); | ||
handle_patterns = true; | ||
} else if (isolate.pattern == "DQMatMulGQi4") { | ||
rewr.add_matcher<ov::npuw::patterns::compute::DQMatMulGQi4>(shared_from_this(), isolate.tag); | ||
handle_patterns = true; | ||
} else if (isolate.pattern == "VocabMatMul") { | ||
rewr.add_matcher<ov::npuw::patterns::compute::VocabMatMul>(shared_from_this(), isolate.tag); | ||
handle_patterns = true; | ||
} else { | ||
LOG_WARN("OPENVINO_NPUW_ISOLATE: unsupported pattern " << isolate.pattern << " is skipped!"); | ||
} | ||
#define HNDL(p) \ | ||
if (isolate.pattern == #p) { \ | ||
rewr.add_matcher<ov::npuw::patterns::compute::p>(shared_from_this(), isolate.tag); \ | ||
handle_patterns = true; \ | ||
} | ||
HNDL(RMSNorm); | ||
HNDL(RMSNorm2); | ||
HNDL(DQMatMulCWu4); | ||
HNDL(DQMatMulGQu4); | ||
HNDL(DQMatMulCWi4); | ||
HNDL(DQMatMulGQi4); | ||
HNDL(DQMatMulConv); | ||
HNDL(VocabMatMul); | ||
#undef HNDL | ||
} | ||
} | ||
} | ||
|
@@ -723,6 +715,20 @@ std::shared_ptr<Repeated> Snapshot::tryMergeTriangles(const std::vector<Group::G | |
return {}; | ||
} | ||
|
||
if (prods.size() < m_ctx.keep_blocks) { | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We can add a property to enable those 2 checks so it doesn't break the default config. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. @dmatveev thoughts? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I've convinced myself that this check is harmless. |
||
// In some cases (specifically mixed precision) during MergeUniques() pass we could be left with | ||
// E.g. 10 repeated blocks with tag AAA and 2 repeated blocks with tag BBB | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The missing part for this great comment is to indicate that the structure may look like
So our algorithm won't merge |
||
// TryMergeTriangles() pass checks that producer and consumer have a different tag to be merged further. | ||
// Let's say in our example 10 AAA blocks are finalized and cannot be merged further due to above check. | ||
// However we will proceed to merge 3 BBB blocks with 3 AAA blocks since the tags are different. | ||
// This will create a new tag CCC for the merged blocks and the merge will continue until those 3 blocks | ||
// consume a large amount of legit AAA blocks. | ||
// Later in CleanUpUniques() pass those repeated blocks will be stripped off repeated tag due to the same check | ||
// in this "if". To prevent such cases where we would end up with small number of huge blocks this check was | ||
// introduced. | ||
return {}; | ||
} | ||
|
||
// In this special case we only assume | ||
// our vector of N repeating consumer groups | ||
// 1. has the same size | ||
|
@@ -939,6 +945,20 @@ std::shared_ptr<Repeated> Snapshot::tryMergeRepeating(const std::vector<Group::G | |
} | ||
} | ||
|
||
if (prods.size() < m_ctx.keep_blocks) { | ||
// In some cases (specifically mixed precision) during MergeUniques() pass we could be left with | ||
// E.g. 10 repeated blocks with tag AAA and 2 repeated blocks with tag BBB | ||
// TryMergeRepeating() pass checks that producer and consumer have a different tag to be merged further. | ||
// Let's say in our example 10 AAA blocks are finalized and cannot be merged further due to above check. | ||
// However we will proceed to merge 3 BBB blocks with 3 AAA blocks since the tags are different. | ||
// This will create a new tag CCC for the merged blocks and the merge will continue until those 3 blocks | ||
// consume a large amount of legit AAA blocks. | ||
// Later in CleanUpUniques() pass those repeated blocks will be stripped off repeated tag due to the same check | ||
// in this "if". To prevent such cases where we would end up with small number of huge blocks this check was | ||
// introduced. | ||
Comment on lines +949 to +958
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. If this is a full copy of the text above, I'd just refer to one of them - like "see comment in...". |
||
return {}; | ||
} | ||
|
||
std::shared_ptr<Repeated> new_rep = std::make_shared<Repeated>(); | ||
|
||
for (size_t i = 0; i < conss.size(); ++i) { | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -226,6 +226,51 @@ DQMatMulCWi4::DQMatMulCWi4(const std::shared_ptr<ov::npuw::online::Snapshot>& sn | |
register_matcher(std::make_shared<opp::Matcher>(qmm, "TagDQMatMulCWi4"), std::move(callback)); | ||
} | ||
|
||
// Pattern: | ||
// -> Transpose ------------------------------> | ||
// Param/Const --> Convert(f32) --> Multiply -> Convolution -> Transpose -> | ||
// Param/Const -> (Convert(f32)) -> | ||
|
||
// Registers a matcher for the Convolution-based dequantized MatMul pattern drawn above and | ||
// isolates the matched compute nodes under the given tag. | ||
// | ||
// snapshot - online partitioning snapshot; its node-to-group map is used to find the group | ||
//            of every matched node. | ||
// isol_tag - tag passed to isolate() for each matched group. | ||
DQMatMulConv::DQMatMulConv(const std::shared_ptr<ov::npuw::online::Snapshot>& snapshot, const std::string& isol_tag) { | ||
// First Multiply operand: any input followed by a mandatory Convert | ||
auto param = opp::any_input(); | ||
auto convert = opp::wrap_type<ov::op::v0::Convert>({param->output(0)}); | ||
// Second Multiply operand: any input, optionally followed by a Convert | ||
auto param2 = opp::any_input(); | ||
auto convert2 = opp::optional<ov::op::v0::Convert>({param2->output(0)}); | ||
auto multiply = opp::wrap_type<ov::op::v1::Multiply>({convert, convert2}); | ||
// The Convolution takes the input Transpose as its first input and the Multiply result as its second | ||
auto transpose_in = opp::wrap_type<ov::op::v1::Transpose>({opp::any_input(), opp::any_input()}); | ||
auto conv = opp::wrap_type<ov::op::v1::Convolution>({transpose_in, multiply}); | ||
// The output Transpose is the root of the pattern | ||
auto transpose_out = opp::wrap_type<ov::op::v1::Transpose>({conv, opp::any_input()}); | ||
|
||
auto node_to_gptr = snapshot->getNodeToGroupMap(); | ||
|
||
// Note: Use [=] to make sure the above objects stay alive in the callback | ||
auto callback = [=](ov::pass::pattern::Matcher& m) { | ||
auto& node_to_output = m.get_pattern_value_map(); | ||
|
||
// Matched outputs of the two Multiply operand sources; only their element types are inspected | ||
auto matched_node_param = node_to_output.at(param); | ||
auto matched_node_param2 = node_to_output.at(param2); | ||
|
||
auto matched_node_transpose_in = node_to_output.at(transpose_in).get_node_shared_ptr(); | ||
auto matched_node_transpose_out = node_to_output.at(transpose_out).get_node_shared_ptr(); | ||
auto matched_node_multiply = node_to_output.at(multiply).get_node_shared_ptr(); | ||
auto matched_node_conv = node_to_output.at(conv).get_node_shared_ptr(); | ||
|
||
// Tag only when the first source is i4 or i8 and the second source is f32 or f16 | ||
// NOTE(review): presumably these are the quantized weights and their scale - confirm | ||
if ((matched_node_param.get_element_type() == ov::element::i4 || | ||
matched_node_param.get_element_type() == ov::element::i8) && | ||
(matched_node_param2.get_element_type() == ov::element::f32 || | ||
matched_node_param2.get_element_type() == ov::element::f16)) { | ||
// Partitioning ignores Param/Const -> Convert nodes | ||
node_to_gptr->at(matched_node_transpose_in)->isolate(isol_tag); | ||
node_to_gptr->at(matched_node_transpose_out)->isolate(isol_tag); | ||
node_to_gptr->at(matched_node_multiply)->isolate(isol_tag); | ||
node_to_gptr->at(matched_node_conv)->isolate(isol_tag); | ||
} | ||
|
||
return false; // root hasn't changed | ||
}; | ||
// The matcher is rooted at the output Transpose | ||
register_matcher(std::make_shared<opp::Matcher>(transpose_out, "TagDQMatMulConv"), std::move(callback)); | ||
} | ||
|
||
// This is a case for Raw (f16/f32) MatMul connected directly to the Result. | ||
// | ||
// The following combinations are covered: | ||
|
@@ -327,6 +372,40 @@ RMSNorm::RMSNorm(const std::shared_ptr<ov::npuw::online::Snapshot>& snapshot, co | |
register_matcher(std::make_shared<opp::Matcher>(multiply2, "TagRMSNorm"), std::move(callback)); | ||
} | ||
|
||
// Pattern (inputs not drawn are unconstrained, i.e. any_input): | ||
// | ||
//   Add -> Power -> ReduceSum -> Sqrt -> Divide -> Multiply -> | ||
//    `--------------------------------->  | ||
// | ||
// The Add result feeds both the Power and the first input of the Divide; | ||
// the Divide result is the second input of the final Multiply. | ||
// | ||
// Registers a matcher for this pattern and isolates every matched node's group under isol_tag. | ||
RMSNorm2::RMSNorm2(const std::shared_ptr<ov::npuw::online::Snapshot>& snapshot, const std::string& isol_tag) { | ||
auto hadd = opp::wrap_type<ov::op::v1::Add>({opp::any_input(), opp::any_input()}); | ||
auto power = opp::wrap_type<ov::op::v1::Power>({hadd, opp::any_input()}); | ||
auto reduce = opp::wrap_type<ov::op::v1::ReduceSum>({power, opp::any_input()}); | ||
auto sqrt = opp::wrap_type<ov::op::v0::Sqrt>({reduce}); | ||
auto div = opp::wrap_type<ov::op::v1::Divide>({hadd, sqrt}); | ||
auto multiply = opp::wrap_type<ov::op::v1::Multiply>({opp::any_input(), div}); | ||
Comment on lines
+377
to
+382
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How different it is compared to the normal RMSNorm? Can those be combined via optional or choice over add vs reduce, etc? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not possible, the first Add connects to different nodes in the pattern |
||
|
||
auto node_to_gptr = snapshot->getNodeToGroupMap(); | ||
|
||
// Note: Use [=] to make sure the above objects stay alive in the callback | ||
auto callback = [=](ov::pass::pattern::Matcher& m) { | ||
auto& node_to_output = m.get_pattern_value_map(); | ||
|
||
// Resolve every pattern node to the graph node it matched | ||
auto matched_hadd = node_to_output.at(hadd).get_node_shared_ptr(); | ||
auto matched_power = node_to_output.at(power).get_node_shared_ptr(); | ||
auto matched_reduce = node_to_output.at(reduce).get_node_shared_ptr(); | ||
auto matched_sqrt = node_to_output.at(sqrt).get_node_shared_ptr(); | ||
auto matched_div = node_to_output.at(div).get_node_shared_ptr(); | ||
auto matched_multiply = node_to_output.at(multiply).get_node_shared_ptr(); | ||
|
||
// Unlike DQMatMulConv, there is no element type check here: every match is tagged | ||
node_to_gptr->at(matched_hadd)->isolate(isol_tag); | ||
node_to_gptr->at(matched_power)->isolate(isol_tag); | ||
node_to_gptr->at(matched_reduce)->isolate(isol_tag); | ||
node_to_gptr->at(matched_sqrt)->isolate(isol_tag); | ||
node_to_gptr->at(matched_div)->isolate(isol_tag); | ||
node_to_gptr->at(matched_multiply)->isolate(isol_tag); | ||
|
||
return false; // root hasn't changed | ||
}; | ||
// The matcher is rooted at the final Multiply | ||
register_matcher(std::make_shared<opp::Matcher>(multiply, "TagRMSNorm2"), std::move(callback)); | ||
} | ||
|
||
} // namespace compute | ||
} // namespace patterns | ||
} // namespace npuw | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Did we have a 1:1 mapping here between the type names (`p`) and their written mnemonics (`.pattern`) in the config?