[NPUW] New compute patterns #27618

Merged
@@ -25,8 +25,9 @@ namespace {
 static const std::map<std::string, std::string> ISOL_PRESETS = {{"COMPUTE",
 "P:DQMatMulGQu4/compute,P:DQMatMulCWu4/compute,"
 "P:DQMatMulGQi4/compute,P:DQMatMulCWi4/compute,"
+"P:DQMatMulConv/compute,"
 "P:VocabMatMul/compute,"
-"P:RMSNorm/compute"}};
+"P:RMSNorm/compute,P:RMSNorm2/compute"}};
}
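For reference, each preset entry above follows a `P:<PatternName>/<tag>` shape. A minimal standalone sketch of how such a preset string could be split into (pattern, tag) pairs — `parse_isolates` is a hypothetical helper written for illustration, not the actual NPUW parsing code:

```cpp
#include <cassert>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

// Hypothetical helper: split "P:DQMatMulGQu4/compute,P:RMSNorm/compute"
// into {"DQMatMulGQu4", "compute"}, {"RMSNorm", "compute"}.
std::vector<std::pair<std::string, std::string>> parse_isolates(const std::string& preset) {
    std::vector<std::pair<std::string, std::string>> out;
    std::stringstream ss(preset);
    std::string item;
    while (std::getline(ss, item, ',')) {
        // Each entry is "P:<PatternName>/<tag>"
        const auto colon = item.find(':');
        const auto slash = item.find('/');
        if (colon == std::string::npos || slash == std::string::npos || slash < colon) {
            continue;  // skip malformed entries
        }
        out.emplace_back(item.substr(colon + 1, slash - colon - 1), item.substr(slash + 1));
    }
    return out;
}
```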

// For missing declaration warning
@@ -465,28 +465,20 @@ void Snapshot::earlyRegroup() {
break;
}
case PatternType::PATTERN: {
-// FIXME: refactor as more patterns are supported
-if (isolate.pattern == "RMSNorm") {
-rewr.add_matcher<ov::npuw::patterns::compute::RMSNorm>(shared_from_this(), isolate.tag);
-handle_patterns = true;
-} else if (isolate.pattern == "DQMatMulCWu4") {
-rewr.add_matcher<ov::npuw::patterns::compute::DQMatMulCWu4>(shared_from_this(), isolate.tag);
-handle_patterns = true;
-} else if (isolate.pattern == "DQMatMulGQu4") {
-rewr.add_matcher<ov::npuw::patterns::compute::DQMatMulGQu4>(shared_from_this(), isolate.tag);
-handle_patterns = true;
-} else if (isolate.pattern == "DQMatMulCWi4") {
-rewr.add_matcher<ov::npuw::patterns::compute::DQMatMulCWi4>(shared_from_this(), isolate.tag);
-handle_patterns = true;
-} else if (isolate.pattern == "DQMatMulGQi4") {
-rewr.add_matcher<ov::npuw::patterns::compute::DQMatMulGQi4>(shared_from_this(), isolate.tag);
-handle_patterns = true;
-} else if (isolate.pattern == "VocabMatMul") {
-rewr.add_matcher<ov::npuw::patterns::compute::VocabMatMul>(shared_from_this(), isolate.tag);
-handle_patterns = true;
-} else {
-LOG_WARN("OPENVINO_NPUW_ISOLATE: unsupported pattern " << isolate.pattern << " is skipped!");
-}
+#define HNDL(p) \
+if (isolate.pattern == #p) { \
+rewr.add_matcher<ov::npuw::patterns::compute::p>(shared_from_this(), isolate.tag); \
+handle_patterns = true; \
+}
+HNDL(RMSNorm);
+HNDL(RMSNorm2);
+HNDL(DQMatMulCWu4);
+HNDL(DQMatMulGQu4);
+HNDL(DQMatMulCWi4);
+HNDL(DQMatMulGQi4);
+HNDL(DQMatMulConv);
+HNDL(VocabMatMul);
+#undef HNDL
}
}
}

Review comment (Contributor): Did we have a 1:1 mapping here between the type names (p) and their written mnemonics (.pattern) in the config?
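The `HNDL` macro relies on preprocessor stringification (`#p`): the string compared against `isolate.pattern` is generated from the very token used as the matcher type, so the name and the type cannot drift apart — which is the 1:1 mapping the reviewer asks about. A self-contained sketch of the idiom (the `dispatch` helper and `Result` struct are illustrative only, not part of the PR):

```cpp
#include <cassert>
#include <string>

struct Result {
    std::string matched;   // which pattern name fired
    bool handled = false;  // mirrors handle_patterns above
};

// Sketch of the stringify-dispatch idiom: #p turns the macro argument
// into the string literal it is compared against.
Result dispatch(const std::string& pattern) {
    Result r;
#define HNDL(p)            \
    if (pattern == #p) {   \
        r.matched = #p;    \
        r.handled = true;  \
    }
    HNDL(RMSNorm);
    HNDL(RMSNorm2);
    HNDL(DQMatMulConv);
#undef HNDL
    return r;
}
```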
@@ -723,6 +715,20 @@ std::shared_ptr<Repeated> Snapshot::tryMergeTriangles(const std::vector<Group::G
return {};
}

if (prods.size() < m_ctx.keep_blocks) {
// In some cases (specifically mixed precision) during MergeUniques() pass we could be left with
// E.g. 10 repeated blocks with tag AAA and 2 repeated blocks with tag BBB
// TryMergeTriangles() pass checks that producer and consumer have a different tag to be merged further.
// Let's say in our example 10 AAA blocks are finalized and cannot be merged further due to above check.
// However we will proceed to merge 3 BBB blocks with 3 AAA blocks since the tags are different.
// This will create a new tag CCC for the merged blocks and the merge will continue until those 3 blocks
// consume a large amount of legit AAA blocks.
// Later in CleanUpUniques() pass those repeated blocks will be stripped off repeated tag due to the same check
// in this "if". To prevent such cases where we would end up with small number of huge blocks this check was
// introduced.
return {};
}

Review comment (Contributor Author): Can add a property to enable those 2 checks so it doesn't break the default config. @dmatveev thoughts?
Review comment (Contributor Author): I've convinced myself that this check is harmless.
Review comment (Contributor): The missing part for this great comment is to indicate that the structure may look like AAA -> AAA -> AAA -> BBB -> AAA -> AAA -> AAA -> BBB. So our algorithm won't merge AAA + AAA but will keep merging AAA + (AAA)BBB so the RHS will grow enormously.
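The guard can be pictured with a toy model (hypothetical names, not the real `Snapshot` code): a merge is only considered when the producer side already has at least `keep_blocks` repeated blocks, so a rare tag like BBB cannot start a runaway merge chain through the plentiful AAA blocks:

```cpp
#include <cassert>
#include <cstddef>

// Toy version of the check above: reject a merge whose producer group
// count is below the configured keep_blocks threshold.
bool may_merge(std::size_t producer_blocks, std::size_t keep_blocks) {
    return producer_blocks >= keep_blocks;
}
```

In the comment's example with `keep_blocks == 10`, the 2 BBB producers fail the check and the BBB/AAA merge never starts, while AAA-sized groups still pass.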

// In this special case we only assume
// our vector of N repeating consumer groups
// 1. has the same size
@@ -939,6 +945,20 @@ std::shared_ptr<Repeated> Snapshot::tryMergeRepeating(const std::vector<Group::G
}
}

if (prods.size() < m_ctx.keep_blocks) {
// In some cases (specifically mixed precision) during MergeUniques() pass we could be left with
// E.g. 10 repeated blocks with tag AAA and 2 repeated blocks with tag BBB
// TryMergeRepeating() pass checks that producer and consumer have a different tag to be merged further.
// Let's say in our example 10 AAA blocks are finalized and cannot be merged further due to above check.
// However we will proceed to merge 3 BBB blocks with 3 AAA blocks since the tags are different.
// This will create a new tag CCC for the merged blocks and the merge will continue until those 3 blocks
// consume a large amount of legit AAA blocks.
// Later in CleanUpUniques() pass those repeated blocks will be stripped off repeated tag due to the same check
// in this "if". To prevent such cases where we would end up with small number of huge blocks this check was
// introduced.
return {};
}

Review comment (Contributor, on lines +949 to +958): If this is the full copy of the text above, I'd just refer to one - like "see comment in...".

std::shared_ptr<Repeated> new_rep = std::make_shared<Repeated>();

for (size_t i = 0; i < conss.size(); ++i) {
@@ -226,6 +226,51 @@ DQMatMulCWi4::DQMatMulCWi4(const std::shared_ptr<ov::npuw::online::Snapshot>& sn
register_matcher(std::make_shared<opp::Matcher>(qmm, "TagDQMatMulCWi4"), std::move(callback));
}

// Pattern:
// -> Transpose ------------------------------>
// Param/Const --> Convert(f32) --> Multiply -> Convolution -> Transpose ->
// Param/Const -> (Convert(f32)) ->

DQMatMulConv::DQMatMulConv(const std::shared_ptr<ov::npuw::online::Snapshot>& snapshot, const std::string& isol_tag) {
auto param = opp::any_input();
auto convert = opp::wrap_type<ov::op::v0::Convert>({param->output(0)});
auto param2 = opp::any_input();
auto convert2 = opp::optional<ov::op::v0::Convert>({param2->output(0)});
auto multiply = opp::wrap_type<ov::op::v1::Multiply>({convert, convert2});
auto transpose_in = opp::wrap_type<ov::op::v1::Transpose>({opp::any_input(), opp::any_input()});
auto conv = opp::wrap_type<ov::op::v1::Convolution>({transpose_in, multiply});
auto transpose_out = opp::wrap_type<ov::op::v1::Transpose>({conv, opp::any_input()});

auto node_to_gptr = snapshot->getNodeToGroupMap();

// Note: Use [=] to make sure the above objects stay alive in the callback
auto callback = [=](ov::pass::pattern::Matcher& m) {
auto& node_to_output = m.get_pattern_value_map();

auto matched_node_param = node_to_output.at(param);
auto matched_node_param2 = node_to_output.at(param2);

auto matched_node_transpose_in = node_to_output.at(transpose_in).get_node_shared_ptr();
auto matched_node_transpose_out = node_to_output.at(transpose_out).get_node_shared_ptr();
auto matched_node_multiply = node_to_output.at(multiply).get_node_shared_ptr();
auto matched_node_conv = node_to_output.at(conv).get_node_shared_ptr();

if ((matched_node_param.get_element_type() == ov::element::i4 ||
matched_node_param.get_element_type() == ov::element::i8) &&
(matched_node_param2.get_element_type() == ov::element::f32 ||
matched_node_param2.get_element_type() == ov::element::f16)) {
// Partitioning ignores Param/Const -> Convert nodes
node_to_gptr->at(matched_node_transpose_in)->isolate(isol_tag);
node_to_gptr->at(matched_node_transpose_out)->isolate(isol_tag);
node_to_gptr->at(matched_node_multiply)->isolate(isol_tag);
node_to_gptr->at(matched_node_conv)->isolate(isol_tag);
}

return false; // root hasn't changed
};
register_matcher(std::make_shared<opp::Matcher>(transpose_out, "TagDQMatMulConv"), std::move(callback));
}

// This is a case for Raw (f16/f32) MatMul connected directly to the Result.
//
// The following combinations are covered:
@@ -327,6 +372,40 @@ RMSNorm::RMSNorm(const std::shared_ptr<ov::npuw::online::Snapshot>& snapshot, co
register_matcher(std::make_shared<opp::Matcher>(multiply2, "TagRMSNorm"), std::move(callback));
}

// Pattern:
//         ---------------------------------->
// Add --> Power --> ReduceSum --> Sqrt --> Divide --> Multiply
RMSNorm2::RMSNorm2(const std::shared_ptr<ov::npuw::online::Snapshot>& snapshot, const std::string& isol_tag) {
auto hadd = opp::wrap_type<ov::op::v1::Add>({opp::any_input(), opp::any_input()});
auto power = opp::wrap_type<ov::op::v1::Power>({hadd, opp::any_input()});
auto reduce = opp::wrap_type<ov::op::v1::ReduceSum>({power, opp::any_input()});
auto sqrt = opp::wrap_type<ov::op::v0::Sqrt>({reduce});
auto div = opp::wrap_type<ov::op::v1::Divide>({hadd, sqrt});
auto multiply = opp::wrap_type<ov::op::v1::Multiply>({opp::any_input(), div});

Review comment (Contributor, on lines +377 to +382): How different is it compared to the normal RMSNorm? Can those be combined via optional or choice over add vs reduce, etc?
Reply (Contributor Author, @smirnov-alexey, Nov 20, 2024): Not possible, the first Add connects to different nodes in the pattern.

auto node_to_gptr = snapshot->getNodeToGroupMap();

// Note: Use [=] to make sure the above objects stay alive in the callback
auto callback = [=](ov::pass::pattern::Matcher& m) {
auto& node_to_output = m.get_pattern_value_map();

auto matched_hadd = node_to_output.at(hadd).get_node_shared_ptr();
auto matched_power = node_to_output.at(power).get_node_shared_ptr();
auto matched_reduce = node_to_output.at(reduce).get_node_shared_ptr();
auto matched_sqrt = node_to_output.at(sqrt).get_node_shared_ptr();
auto matched_div = node_to_output.at(div).get_node_shared_ptr();
auto matched_multiply = node_to_output.at(multiply).get_node_shared_ptr();

node_to_gptr->at(matched_hadd)->isolate(isol_tag);
node_to_gptr->at(matched_power)->isolate(isol_tag);
node_to_gptr->at(matched_reduce)->isolate(isol_tag);
node_to_gptr->at(matched_sqrt)->isolate(isol_tag);
node_to_gptr->at(matched_div)->isolate(isol_tag);
node_to_gptr->at(matched_multiply)->isolate(isol_tag);

return false; // root hasn't changed
};
register_matcher(std::make_shared<opp::Matcher>(multiply, "TagRMSNorm2"), std::move(callback));
}

} // namespace compute
} // namespace patterns
} // namespace npuw
@@ -41,6 +41,11 @@ class DQMatMulCWi4 : public ov::pass::MatcherPass {
DQMatMulCWi4(const std::shared_ptr<ov::npuw::online::Snapshot>& snapshot, const std::string& isol_tag);
};

class DQMatMulConv : public ov::pass::MatcherPass {
public:
DQMatMulConv(const std::shared_ptr<ov::npuw::online::Snapshot>& snapshot, const std::string& isol_tag);
};

class VocabMatMul : public ov::pass::MatcherPass {
public:
VocabMatMul(const std::shared_ptr<ov::npuw::online::Snapshot>& snapshot, const std::string& isol_tag);
@@ -51,6 +56,11 @@ class RMSNorm : public ov::pass::MatcherPass {
RMSNorm(const std::shared_ptr<ov::npuw::online::Snapshot>& snapshot, const std::string& isol_tag);
};

class RMSNorm2 : public ov::pass::MatcherPass {
public:
RMSNorm2(const std::shared_ptr<ov::npuw::online::Snapshot>& snapshot, const std::string& isol_tag);
};

} // namespace compute
} // namespace patterns
} // namespace npuw