Commit

WIP

smirnov-alexey committed Nov 12, 2024
1 parent eb41f1f commit 1cfa0e1

Showing 3 changed files with 53 additions and 94 deletions.
@@ -1799,7 +1799,6 @@ void Partitioner::optimize(const std::string& func_name) {
rewr.add_matcher<ov::npuw::patterns::opt::DQParMMGQ>(std::ref(ctx));
// Convert specific convolutions to matmuls
rewr.add_matcher<ov::npuw::patterns::opt::ConvToMatmul>(std::ref(ctx));
rewr.add_matcher<ov::npuw::patterns::opt::ConvToMatmul2>(std::ref(ctx));
rewr.run_on_model(f._model);

// Move Gather to host, if required
141 changes: 53 additions & 88 deletions src/plugins/intel_npu/src/plugin/npuw/partitioning/patterns/opt.cpp
@@ -193,23 +193,24 @@ DQMatMulCWi::DQMatMulCWi() {
}

// FROM:
// ???(Act) ------------------------>
// Param(W) -> to(f32) -> Multiply -> MatMul
// Param(S) ------------>
// ???(Act) ------------------------------------------->
// Param(W) -------> Reshape --> to(f32) --> Multiply -> MatMul
// Param/Const(S) -> Reshape -> (to(f32)) ->
//
// TO:
// ???(Act) -> to(f32) ->
// Param(W) -> to(f32) -> MatMul -> Multiply
// Param(S) -> Reshape ----------->
// ???(Act) -> to(f32) -------------------->
// Param(W) -------> Reshape --> to(f32) --> MatMul --> Multiply
// Param/Const(S) -> Reshape -> (to(f32)) -> Reshape ->
//
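As a sanity check, a minimal self-contained sketch (illustrative values only, not part of this patch; the dequantization Convert is folded into plain floats for brevity) of why the per-output-channel scale can be moved past the MatMul: with W of shape [O, I] and S broadcast along I, every output element accumulates exactly one scale factor, so multiplying the MatMul result by S (reshaped to [1, O]) matches scaling the weights up front.

#include <cassert>
#include <cmath>
#include <vector>

int main() {
    const int M = 2, I = 3, O = 4;                           // activation rows, reduction dim, output channels
    const std::vector<float> A = {1, 2, 3, 4, 5, 6};         // activation, [M, I]
    const std::vector<float> W = {1, 0, 2, 3, 1, 0,          // weight after Convert(f32), [O, I]
                                  0, 2, 1, 1, 1, 1};
    const std::vector<float> S = {0.5f, 2.f, 1.5f, 0.25f};   // per-output-channel scale, [O, 1] in the graph

    for (int m = 0; m < M; ++m) {
        for (int o = 0; o < O; ++o) {
            float scale_first = 0.f, scale_last = 0.f;
            for (int i = 0; i < I; ++i) {
                scale_first += A[m * I + i] * (W[o * I + i] * S[o]);  // FROM: Multiply(W, S) feeds MatMul
                scale_last  += A[m * I + i] * W[o * I + i];           // TO: MatMul on unscaled weights...
            }
            scale_last *= S[o];  // ...followed by Multiply with S reshaped to [1, O]
            assert(std::fabs(scale_first - scale_last) < 1e-5f);
        }
    }
    return 0;
}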

DQMatMulCWi2::DQMatMulCWi2() {
auto qweight = opp::wrap_type<ov::op::v0::Parameter>();
auto qcoeff = opp::wrap_type<ov::op::v0::Parameter>();
auto qcoeff = opp::any_input();
auto reshape = opp::wrap_type<ov::op::v1::Reshape>({qweight, opp::any_input()});
auto reshape2 = opp::wrap_type<ov::op::v1::Reshape>({qcoeff, opp::any_input()});
auto qcvtw = opp::wrap_type<ov::op::v0::Convert>({reshape});
auto qmuls = opp::wrap_type<ov::op::v1::Multiply>({qcvtw, reshape2});
auto qcvtc = opp::optional<ov::op::v0::Convert>({reshape2->output(0)});
auto qmuls = opp::wrap_type<ov::op::v1::Multiply>({qcvtw, qcvtc});
auto qmmi = opp::any_input();
auto qmm = opp::wrap_type<ov::op::v0::MatMul>({qmmi, qmuls});

@@ -235,8 +236,6 @@ DQMatMulCWi2::DQMatMulCWi2() {
auto matched_node_muls = node_to_output.at(qmuls).get_node_shared_ptr();
auto matched_node_mmi = node_to_output.at(qmmi).get_node_shared_ptr();

// !!! FIXME !!! everything is f32

// Reconnect MatMul to read from Convert(W) directly.
// Note: ACT is f32 so has to be converted too.
auto new_cvt_act = std::make_shared<ov::op::v0::Convert>(matched_node_mmi, ov::element::f32);
@@ -1539,11 +1538,23 @@ SliceLastMatmulMultiply::SliceLastMatmulMultiply() {
register_matcher(std::make_shared<opp::Matcher>(res, "SliceLastMatmulMultiply"), std::move(callback));
}

// FROM:
// -> Transpose ------------------------------>
// Param --------> Convert(f32) --> Multiply -> Convolution -> Transpose ->
// Param/Const -> (Convert(f32)) ->
//
// TO:
// ------------------------------------------------------>
// Param -------> Reshape --> Convert(f32) --> Multiply -> MatMul ->
// Param/Const -> Reshape -> (Convert(f32)) ->
//
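For intuition, a minimal self-contained sketch (arbitrary sizes and values, not part of this patch; the dequantization Multiply is orthogonal and omitted) of the equivalence this pass exploits: a 1x1 Convolution over an NCHW tensor computes a dot product across input channels at every spatial position, which is exactly a MatMul between the Transpose'd activation viewed as [H*W, I] and the [O, I] weight matrix with transpose_b = true, so the surrounding Transposes can be dropped.

#include <cassert>
#include <cmath>
#include <vector>

int main() {
    const int I = 3, O = 2, H = 2, W = 2;            // input/output channels, spatial dims
    std::vector<float> act(I * H * W), wgt(O * I);   // activation [1, I, H, W]; 1x1 weights [O, I, 1, 1]
    for (std::size_t k = 0; k < act.size(); ++k) act[k] = 0.1f * static_cast<float>(k);
    for (std::size_t k = 0; k < wgt.size(); ++k) wgt[k] = 1.0f + static_cast<float>(k);

    // 1x1 Convolution in NCHW: conv[o][h][w] = sum_i act[i][h][w] * wgt[o][i]
    std::vector<float> conv(O * H * W, 0.f);
    for (int o = 0; o < O; ++o)
        for (int h = 0; h < H; ++h)
            for (int w = 0; w < W; ++w)
                for (int i = 0; i < I; ++i)
                    conv[(o * H + h) * W + w] += act[(i * H + h) * W + w] * wgt[o * I + i];

    // The same numbers as a MatMul: X is the activation viewed as [H*W, I]
    // (what the input Transpose produces), times the [O, I] weights with transpose_b = true.
    std::vector<float> mm(H * W * O, 0.f);
    for (int p = 0; p < H * W; ++p)
        for (int o = 0; o < O; ++o)
            for (int i = 0; i < I; ++i)
                mm[p * O + o] += act[i * H * W + p] * wgt[o * I + i];

    for (int o = 0; o < O; ++o)
        for (int p = 0; p < H * W; ++p)
            assert(std::fabs(conv[o * H * W + p] - mm[p * O + o]) < 1e-5f);
    return 0;
}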

ConvToMatmul::ConvToMatmul(Context::Ref ctx) {
auto param = opp::wrap_type<ov::op::v0::Parameter>();
auto convert = opp::optional<ov::op::v0::Convert>({param->output(0)});
auto param2 = opp::wrap_type<ov::op::v0::Parameter>();
auto multiply = opp::wrap_type<ov::op::v1::Multiply>({convert, param2});
auto convert = opp::wrap_type<ov::op::v0::Convert>({param->output(0)});
auto param2 = opp::any_input();
auto convert2 = opp::optional<ov::op::v0::Convert>({param2->output(0)});
auto multiply = opp::wrap_type<ov::op::v1::Multiply>({convert, convert2});
auto tr_input = opp::any_input();
auto transpose_in = opp::wrap_type<ov::op::v1::Transpose>({tr_input, opp::any_input()});
auto conv = opp::wrap_type<ov::op::v1::Convolution>({transpose_in, multiply});
@@ -1556,25 +1567,33 @@ ConvToMatmul::ConvToMatmul(Context::Ref ctx) {
auto matched_node_param = node_to_output.at(param).get_node_shared_ptr();
auto matched_node_param2 = node_to_output.at(param2).get_node_shared_ptr();
auto matched_node_convert = node_to_output.at(convert).get_node_shared_ptr();
auto matched_node_tr_input = node_to_output.at(tr_input).get_node_shared_ptr();
auto matched_node_convert2 = uat::_(node_to_output).at_or_at(convert2, param2);
auto matched_node_tr_input = node_to_output.at(tr_input);
auto matched_node_transpose_in = node_to_output.at(transpose_in).get_node_shared_ptr();
auto matched_node_transpose_out = node_to_output.at(transpose_out).get_node_shared_ptr();
auto matched_node_multiply = node_to_output.at(multiply).get_node_shared_ptr();
auto matched_node_conv = node_to_output.at(conv).get_node_shared_ptr();

auto matched_param = std::static_pointer_cast<ov::op::v0::Parameter>(matched_node_param);
auto matched_param2 = std::static_pointer_cast<ov::op::v0::Parameter>(matched_node_param2);

const auto& shape = matched_param->get_shape();
const auto& shape2 = matched_param2->get_shape();
const auto& shape = matched_node_param->get_shape();
const auto& shape2 = matched_node_param2->get_shape();
const auto& tr_in_shape = matched_node_transpose_in->input(0).get_shape();
const auto& tr_out_shape = matched_node_transpose_out->output(0).get_shape();

if (matched_param->get_element_type() == ov::element::i4 &&
matched_param2->get_element_type() == ov::element::f32 && shape.size() == 4 && shape2.size() == 4 &&
shape[2] == 1 && shape[3] == 1 && shape2[2] == 1 && shape2[3] == 1 && tr_in_shape.size() == 4 &&
tr_out_shape.size() == 4 && tr_in_shape[0] == 1 && tr_in_shape[1] == 1 && tr_out_shape[0] == 1 &&
tr_out_shape[1] == 1) {
auto check_shape = [](const ov::Shape& shape) {
// last 2 dims are 1
return shape.size() == 4 && shape[2] == 1 && shape[3] == 1;
};

auto check_transpose_shape = [](const ov::Shape& shape) {
// first 2 dims are 1
return shape.size() == 4 && shape[0] == 1 && shape[1] == 1;
};

if (((matched_node_param->get_element_type() == ov::element::i4 &&
      matched_node_param2->get_element_type() == ov::element::f32 &&
      ov::op::util::is_parameter(matched_node_param2)) ||
     (matched_node_param->get_element_type() == ov::element::i8 &&
      matched_node_param2->get_element_type() == ov::element::f16 &&
      ov::op::util::is_constant(matched_node_param2))) &&
    check_shape(shape) && check_shape(shape2) &&
    check_transpose_shape(tr_in_shape) && check_transpose_shape(tr_out_shape)) {
// Add Reshape before Params/Const
auto new_dims = std::vector<std::size_t>{shape[0], shape[1]};
auto new_const = std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{2}, new_dims);
auto new_reshape = std::make_shared<ov::op::v1::Reshape>(matched_node_param, new_const, false);
@@ -1584,71 +1603,18 @@ ConvToMatmul::ConvToMatmul(Context::Ref ctx) {
auto new_dims2 = std::vector<std::size_t>{shape2[0], shape2[1]};
auto new_const2 = std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{2}, new_dims2);
auto new_reshape2 = std::make_shared<ov::op::v1::Reshape>(matched_node_param2, new_const2, false);
matched_node_multiply->input(1).replace_source_output(new_reshape2);
matched_node_multiply->validate_and_infer_types();

auto matmul =
std::make_shared<ov::op::v0::MatMul>(matched_node_tr_input, matched_node_multiply, false, true);

for (auto&& r : matched_node_transpose_out->output(0).get_target_inputs()) {
r.replace_source_output(matmul);
// Connect to Reshape
if (ov::op::util::is_parameter(matched_node_param2)) {
matched_node_multiply->input(1).replace_source_output(new_reshape2);
matched_node_multiply->validate_and_infer_types();
} else { // constant -> convert -> multiply
node_to_output.at(convert2).get_node_shared_ptr()->input(0).replace_source_output(new_reshape2);
node_to_output.at(convert2).get_node_shared_ptr()->validate_and_infer_types();
matched_node_multiply->validate_and_infer_types();
}
return true; // root has changed
}
return false; // root hasn't changed
};
register_matcher(std::make_shared<opp::Matcher>(transpose_out, "ConvToMatmul"), std::move(callback));
}

ConvToMatmul2::ConvToMatmul2(Context::Ref ctx) {
auto param = opp::wrap_type<ov::op::v0::Parameter>();
auto convert = opp::optional<ov::op::v0::Convert>({param->output(0)});
auto constant = opp::wrap_type<ov::op::v0::Constant>();
auto convert2 = opp::optional<ov::op::v0::Convert>({constant->output(0)});
auto multiply = opp::wrap_type<ov::op::v1::Multiply>({convert, convert2});
auto tr_input = opp::any_input();
auto transpose_in = opp::wrap_type<ov::op::v1::Transpose>({tr_input, opp::any_input()});
auto conv = opp::wrap_type<ov::op::v1::Convolution>({transpose_in, multiply});
auto transpose_out = opp::wrap_type<ov::op::v1::Transpose>({conv, opp::any_input()});

// Note: Use [=] to make sure the above objects stay alive in the callback
auto callback = [=](ov::pass::pattern::Matcher& m) {
auto& node_to_output = m.get_pattern_value_map();

auto matched_node_param = node_to_output.at(param).get_node_shared_ptr();
auto matched_node_constant = node_to_output.at(constant).get_node_shared_ptr();
auto matched_node_convert = node_to_output.at(convert).get_node_shared_ptr();
auto matched_node_tr_input = node_to_output.at(tr_input).get_node_shared_ptr();
auto matched_node_transpose_in = node_to_output.at(transpose_in).get_node_shared_ptr();
auto matched_node_transpose_out = node_to_output.at(transpose_out).get_node_shared_ptr();
auto matched_node_multiply = node_to_output.at(multiply).get_node_shared_ptr();
auto matched_node_conv = node_to_output.at(conv).get_node_shared_ptr();

auto matched_param = std::static_pointer_cast<ov::op::v0::Parameter>(matched_node_param);
auto matched_constant = std::static_pointer_cast<ov::op::v0::Constant>(matched_node_constant);

const auto& shape = matched_param->get_shape();
const auto& shape2 = matched_constant->get_shape();
const auto& tr_in_shape = matched_node_transpose_in->input(0).get_shape();
const auto& tr_out_shape = matched_node_transpose_out->output(0).get_shape();

if (matched_param->get_element_type() == ov::element::i8 &&
matched_constant->get_element_type() == ov::element::f32 && shape.size() == 4 && shape2.size() == 4 &&
shape[2] == 1 && shape[3] == 1 && shape2[2] == 1 && shape2[3] == 1 && tr_in_shape.size() == 4 &&
tr_out_shape.size() == 4 && tr_in_shape[0] == 1 && tr_in_shape[1] == 1 && tr_out_shape[0] == 1 &&
tr_out_shape[1] == 1) {
auto new_dims = std::vector<std::size_t>{shape[0], shape[1]};
auto new_const = std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{2}, new_dims);
auto new_reshape = std::make_shared<ov::op::v1::Reshape>(matched_node_param, new_const, false);
matched_node_convert->input(0).replace_source_output(new_reshape);
matched_node_convert->validate_and_infer_types();

auto new_dims2 = std::vector<std::size_t>{shape2[0], shape2[1]};
auto new_const2 = std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{2}, new_dims2);
auto new_reshape2 = std::make_shared<ov::op::v1::Reshape>(matched_constant, new_const2, false);
matched_node_multiply->input(1).replace_source_output(new_reshape2);
matched_node_multiply->validate_and_infer_types();

// Get rid of Transposes
auto matmul =
std::make_shared<ov::op::v0::MatMul>(matched_node_tr_input, matched_node_multiply, false, true);

@@ -1657,10 +1623,9 @@ ConvToMatmul2::ConvToMatmul2(Context::Ref ctx) {
}
return true; // root has changed
}

return false; // root hasn't changed
};
register_matcher(std::make_shared<opp::Matcher>(transpose_out, "ConvToMatmul2"), std::move(callback));
register_matcher(std::make_shared<opp::Matcher>(transpose_out, "ConvToMatmul"), std::move(callback));
}

} // namespace opt
@@ -181,11 +181,6 @@ class ConvToMatmul : public ov::pass::MatcherPass {
ConvToMatmul(Context::Ref ctx);
};

class ConvToMatmul2 : public ov::pass::MatcherPass {
public:
ConvToMatmul2(Context::Ref ctx);
};

} // namespace opt
} // namespace patterns
} // namespace npuw
