From 046080f8565ee03422d9e42fd7509ca36483f166 Mon Sep 17 00:00:00 2001 From: Dmitry Matveev Date: Fri, 27 Sep 2024 14:34:09 +0100 Subject: [PATCH] NPUW: Head/tail optimizations (#26633) ### Details: - This PR enables WS and DQ for head/tail - Some efficiency problems were fixed by introducing a host-side Gather (ON by default, can be turned OFF for evaluation purposes) ### Tickets: - E-139867 --- .../al/include/intel_npu/al/config/npuw.hpp | 1 + .../al/include/npuw_private_properties.hpp | 8 + .../intel_npu/src/al/src/config/npuw.cpp | 1 + .../src/plugin/npuw/compiled_model.cpp | 13 + .../src/plugin/npuw/compiled_model.hpp | 2 + .../plugin/npuw/just_sync_infer_request.cpp | 9 + .../plugin/npuw/partitioning/partitioning.cpp | 133 ++++- .../plugin/npuw/partitioning/partitioning.hpp | 8 + .../plugin/npuw/partitioning/patterns/opt.cpp | 563 +++++++++++++++++- .../plugin/npuw/partitioning/patterns/opt.hpp | 70 +++ .../intel_npu/src/plugin/npuw/util.cpp | 136 ++++- .../intel_npu/src/plugin/npuw/util.hpp | 2 + 12 files changed, 902 insertions(+), 44 deletions(-) diff --git a/src/plugins/intel_npu/src/al/include/intel_npu/al/config/npuw.hpp b/src/plugins/intel_npu/src/al/include/intel_npu/al/config/npuw.hpp index ca80ebfd074265..76a3b23259f1f5 100644 --- a/src/plugins/intel_npu/src/al/include/intel_npu/al/config/npuw.hpp +++ b/src/plugins/intel_npu/src/al/include/intel_npu/al/config/npuw.hpp @@ -43,6 +43,7 @@ DEFINE_OPT(NPUW_FOLD, bool, false, npuw::partitioning::fold, CompileTime); DEFINE_OPT(NPUW_CWAI, bool, false, npuw::partitioning::cwai, CompileTime); DEFINE_OPT(NPUW_DQ, bool, false, npuw::partitioning::dyn_quant, CompileTime); DEFINE_OPT(NPUW_PMM, std::string, "2", npuw::partitioning::par_matmul_merge_dims, CompileTime); +DEFINE_OPT(NPUW_HOST_GATHER, bool, true, npuw::partitioning::host_gather, CompileTime); DEFINE_OPT(NPUW_DCOFF_TYPE, std::string, "", npuw::partitioning::dcoff_type, CompileTime); DEFINE_OPT(NPUW_DCOFF_SCALE, bool, false, npuw::partitioning::dcoff_with_scale, CompileTime); DEFINE_OPT(NPUW_FUNCALL_FOR_ALL, bool, false, npuw::partitioning::funcall_for_all, CompileTime); diff --git a/src/plugins/intel_npu/src/al/include/npuw_private_properties.hpp b/src/plugins/intel_npu/src/al/include/npuw_private_properties.hpp index 27e88efb1c5be5..039c3abdf67561 100644 --- a/src/plugins/intel_npu/src/al/include/npuw_private_properties.hpp +++ b/src/plugins/intel_npu/src/al/include/npuw_private_properties.hpp @@ -177,6 +177,14 @@ static constexpr ov::Property dyn_quant{"NPUW_DQ"}; */ static constexpr ov::Property par_matmul_merge_dims{"NPUW_PMM"}; +/** + * @brief + * Type: boolean + * When applicable, do embedding gather on host. + * Default value: true. + */ +static constexpr ov::Property host_gather{"NPUW_HOST_GATHER"}; + /** * @brief * Type: std::string. diff --git a/src/plugins/intel_npu/src/al/src/config/npuw.cpp b/src/plugins/intel_npu/src/al/src/config/npuw.cpp index 32c21ec43c668c..798b5344c4ea62 100644 --- a/src/plugins/intel_npu/src/al/src/config/npuw.cpp +++ b/src/plugins/intel_npu/src/al/src/config/npuw.cpp @@ -28,6 +28,7 @@ void intel_npu::registerNPUWOptions(OptionsDesc& desc) { desc.add(); desc.add(); desc.add(); + desc.add(); desc.add(); desc.add(); desc.add(); diff --git a/src/plugins/intel_npu/src/plugin/npuw/compiled_model.cpp b/src/plugins/intel_npu/src/plugin/npuw/compiled_model.cpp index eea817b52614bd..43cb5ec1aef931 100644 --- a/src/plugins/intel_npu/src/plugin/npuw/compiled_model.cpp +++ b/src/plugins/intel_npu/src/plugin/npuw/compiled_model.cpp @@ -18,6 +18,7 @@ #include "openvino/runtime/internal_properties.hpp" #include "openvino/runtime/properties.hpp" #include "openvino/util/common_util.hpp" +#include "partitioning/patterns/opt.hpp" #include "plugin.hpp" #include "util.hpp" @@ -135,6 +136,16 @@ ov::npuw::CompiledModel::CompiledModel(const std::shared_ptr& model, // FIXME: Find a better place to call this transformation ov::pass::ConvertPrecision(ov::element::bf16, ov::element::f16).run_on_model(model); + if (m_cfg.get<::intel_npu::NPUW_FOLD>() && m_cfg.get<::intel_npu::NPUW_FUNCALL_FOR_ALL>()) { + // If there's folding enabled AND non-repeating graphs are forced to be + // functions, do extra lifting for gather (if any) + ov::pass::GraphRewrite rewr; + rewr.add_matcher(); + rewr.add_matcher(); + rewr.add_matcher(); + rewr.run_on_model(model); + } + auto partitioning = getPartitioning(model, m_cfg); m_total_stat.gflops = partitioning.total_gflops; m_total_stat.ops = partitioning.total_ops; @@ -271,6 +282,7 @@ ov::npuw::CompiledModel::CompiledModel(const std::shared_ptr& model, m_compiled_submodels[id].replaced_by = compiled_fcn_iter->second; LOG_INFO("Subgraph[" << id << "] is a function call to [" << compiled_fcn_iter->second << "]"); } + m_compiled_submodels[id].host_gather = subgraph._host_gather; m_compiled_submodels[id].param_base = fcn_template._param_offset; m_compiled_submodels[id].closure = subgraph._closure; m_compiled_submodels[id].scales = subgraph._scales; @@ -799,6 +811,7 @@ void ov::npuw::CompiledModel::implement_properties() { BIND(npuw::partitioning::cwai, NPUW_CWAI), BIND(npuw::partitioning::dyn_quant, NPUW_DQ), BIND(npuw::partitioning::par_matmul_merge_dims, NPUW_PMM), + BIND(npuw::partitioning::host_gather, NPUW_HOST_GATHER), BIND(npuw::partitioning::funcall_for_all, NPUW_FUNCALL_FOR_ALL), BIND(npuw::partitioning::dcoff_type, NPUW_DCOFF_TYPE), BIND(npuw::partitioning::dcoff_with_scale, NPUW_DCOFF_SCALE), diff --git a/src/plugins/intel_npu/src/plugin/npuw/compiled_model.hpp b/src/plugins/intel_npu/src/plugin/npuw/compiled_model.hpp index c2e9a9a6235dd3..1ddaf3f543eaa8 100644 --- a/src/plugins/intel_npu/src/plugin/npuw/compiled_model.hpp +++ b/src/plugins/intel_npu/src/plugin/npuw/compiled_model.hpp @@ -118,6 +118,8 @@ class CompiledModel : public ov::ICompiledModel { // FIXME: This is a 1:1 copy of the ov::npuw::Subgraph structure // w.r.t. function calls + Subgraph::Gather host_gather; + std::size_t param_base = 0; std::vector closure; std::vector scales; diff --git a/src/plugins/intel_npu/src/plugin/npuw/just_sync_infer_request.cpp b/src/plugins/intel_npu/src/plugin/npuw/just_sync_infer_request.cpp index ba1f56b060e0c4..6638fbcbe12a57 100644 --- a/src/plugins/intel_npu/src/plugin/npuw/just_sync_infer_request.cpp +++ b/src/plugins/intel_npu/src/plugin/npuw/just_sync_infer_request.cpp @@ -381,6 +381,15 @@ void ov::npuw::JustInferRequest::bind_global_parameters(std::size_t idx) { it.first->copy_to(dst._ptr); }); + // Run host-side gather, if required + if (comp_model_desc.host_gather.dst_idx != -1) { + auto& dst = comp_model_desc.closure[comp_model_desc.host_gather.dst_idx - comp_model_desc.param_base]; + const auto& vocab = comp_model_desc.closure[comp_model_desc.host_gather.src_idx - comp_model_desc.param_base]; + const auto& lport = comp_model_desc.compiled_model->inputs()[comp_model_desc.host_gather.idx_idx]; + const auto lookup = subr->get_tensor(lport); + ov::npuw::util::gather(ov::get_tensor_impl(vocab), lookup, ov::get_tensor_impl(dst)); + } + LOG_DEBUG("Done"); } diff --git a/src/plugins/intel_npu/src/plugin/npuw/partitioning/partitioning.cpp b/src/plugins/intel_npu/src/plugin/npuw/partitioning/partitioning.cpp index cf82694e0601b7..22dfc6e103f719 100644 --- a/src/plugins/intel_npu/src/plugin/npuw/partitioning/partitioning.cpp +++ b/src/plugins/intel_npu/src/plugin/npuw/partitioning/partitioning.cpp @@ -14,6 +14,7 @@ #include "openvino/op/slice.hpp" #include "openvino/op/util/op_types.hpp" #include "openvino/pass/validate.hpp" +#include "openvino/runtime/make_tensor.hpp" #include "openvino/util/common_util.hpp" #include "openvino/util/xml_parse_utils.hpp" #include "patterns/dcoff.hpp" @@ -1565,16 +1566,59 @@ void Partitioner::optimize(const std::string& func_name) { ov::npuw::Function& f = P.functions.at(func_name); auto& func_group = all_functions.at(func_name); + auto do_permute = [&](ov::npuw::patterns::opt::Context& ctx) { + for (auto&& p : ctx.closures_to_permute) { + auto param_idx = f._model->get_parameter_index(p.first); + auto closure_idx = param_idx - f._param_offset; + ov::parallel_for(func_group.refs.size(), [&](std::size_t f_idx) { + auto& funcall = func_group.refs[f_idx].get(); + ov::npuw::util::permute(funcall._closure[closure_idx], p.second); + }); + } + }; + auto do_cvtf16 = [&](ov::npuw::patterns::opt::Context& ctx) { + for (auto&& p : ctx.closures_to_f16) { + auto param_idx = f._model->get_parameter_index(p); + auto closure_idx = param_idx - f._param_offset; + ov::parallel_for(func_group.refs.size(), [&](std::size_t f_idx) { + auto& funcall = func_group.refs[f_idx].get(); + ov::npuw::util::to_f16(funcall._closure[closure_idx]); + }); + } + }; + // Regardless of DQ setting, run this first { ov::npuw::patterns::opt::Context ctx; ctx.pmm_dims = cfg.get<::intel_npu::NPUW_PMM>(); + + // Run Head/Tail passes + ov::pass::GraphRewrite rewr; + rewr.add_matcher(std::ref(ctx)); + rewr.add_matcher(std::ref(ctx)); + rewr.add_matcher(std::ref(ctx)); + // NB: This pass is disabled for reason! It doesn't make things better + // rewr.add_matcher(std::ref(ctx)); + rewr.add_matcher(std::ref(ctx)); + rewr.add_matcher(std::ref(ctx)); + rewr.run_on_model(f._model); + + // Move Gather to host, if required + if (cfg.get<::intel_npu::NPUW_HOST_GATHER>()) { + ov::pass::GraphRewrite rewr2; + rewr2.add_matcher(std::ref(ctx)); + rewr2.add_matcher(std::ref(ctx)); + rewr2.run_on_model(f._model); + } + + // Run parallel matmul merge mergeParallelMatMuls(f._model, ctx); - // Concatenate closures for "concatenated" parameters ov::ParameterVector new_params; std::vector to_remove; std::set to_remove_idx; + + // Concatenate closures for "concatenated" parameters for (auto&& p : ctx.params_to_concat) { new_params.push_back(p.first); const auto& params_to_concat = p.second.first; @@ -1596,6 +1640,59 @@ void Partitioner::optimize(const std::string& func_name) { funcall._closure.push_back(ov::npuw::util::concat(to_concat, axis)); }); } + + // Unpack closures in compile time, where requested + for (auto&& p : ctx.params_to_unpack) { + const auto& tensor_to_unpack = p.second; + auto w_idx = f._model->get_parameter_index(tensor_to_unpack.w); + auto z_idx = f._model->get_parameter_index(tensor_to_unpack.z); + auto s_idx = f._model->get_parameter_index(tensor_to_unpack.s); + + new_params.push_back(p.first); + to_remove.push_back(tensor_to_unpack.w); + to_remove.push_back(tensor_to_unpack.s); + to_remove_idx.insert(w_idx); + to_remove_idx.insert(s_idx); + + if (tensor_to_unpack.z) { + to_remove.push_back(tensor_to_unpack.z); + to_remove_idx.insert(z_idx); + } + + ov::parallel_for(func_group.refs.size(), [&](std::size_t f_idx) { + auto& funcall = func_group.refs[f_idx].get(); + ov::Tensor cw = funcall._closure[w_idx - f._param_offset]; + ov::Tensor cz = z_idx != -1 ? funcall._closure[z_idx - f._param_offset] : ov::Tensor{}; + ov::Tensor cs = funcall._closure[s_idx - f._param_offset]; + ov::Tensor dst(p.first->get_element_type(), p.first->get_shape()); + + const auto& gti = ov::get_tensor_impl; + if (cw && cz && cs) { + ov::npuw::util::unpack(gti(cw), gti(cz), gti(cs), gti(dst)); + } else if (cw && cs) { + ov::npuw::util::unpack(gti(cw), gti(cs), gti(dst)); + } else { + NPUW_ASSERT(false && "Unsupported combination"); + } + funcall._closure.push_back(std::move(dst)); + }); + } + + // Convert parameters to f16 where required + do_cvtf16(ctx); + + // Host-side gather, pt 1. Add new parameters first + if (ctx.params_to_gather) { + auto& params_to_gather = *ctx.params_to_gather; + new_params.push_back(params_to_gather.pnew); + for (auto&& funcall : func_group.refs) { + auto new_elem_type = params_to_gather.pnew->get_element_type(); + auto new_shape = params_to_gather.pnew->get_shape(); + funcall.get()._closure.push_back(ov::Tensor(new_elem_type, new_shape)); + } + } + + // Add all new parameters introduced by this change f._model->add_parameters(new_params); // Remove parameters and closures that were concatenated @@ -1613,7 +1710,19 @@ void Partitioner::optimize(const std::string& func_name) { for (auto&& now_remove : to_remove) { f._model->remove_parameter(now_remove); } + f._model->validate_nodes_and_infer_types(); + + // Host-side gather, pt. 2: Write the gather mappings to funcall + if (ctx.params_to_gather) { + auto& params_to_gather = *ctx.params_to_gather; + auto gather_dst_id = f._model->get_parameter_index(params_to_gather.pnew); + auto gather_src_id = f._model->get_parameter_index(params_to_gather.pold); + auto gather_idx_id = f._model->get_parameter_index(params_to_gather.pids); + for (auto&& funcall : func_group.refs) { + funcall.get()._host_gather = ov::npuw::Subgraph::Gather{gather_dst_id, gather_src_id, gather_idx_id}; + } + } } if (!cfg.get<::intel_npu::NPUW_DQ>()) { @@ -1625,6 +1734,7 @@ void Partitioner::optimize(const std::string& func_name) { LOG_VERB("Optimize function " << func_name << " in model " << model->get_friendly_name() << "..."); LOG_BLOCK(); + // Run "dynamic quantization" ov::npuw::patterns::opt::Context ctx; ov::pass::GraphRewrite rewr; rewr.add_matcher(); @@ -1635,25 +1745,8 @@ void Partitioner::optimize(const std::string& func_name) { rewr.run_on_model(f._model); ov::pass::Validate().run_on_model(f._model); - // Permute tensors where required - for (auto&& p : ctx.closures_to_permute) { - auto param_idx = f._model->get_parameter_index(p.first); - auto closure_idx = param_idx - f._param_offset; - ov::parallel_for(func_group.refs.size(), [&](std::size_t f_idx) { - auto& funcall = func_group.refs[f_idx].get(); - ov::npuw::util::permute(funcall._closure[closure_idx], p.second); - }); - } - - // Convert tensors where required - for (auto&& p : ctx.closures_to_f16) { - auto param_idx = f._model->get_parameter_index(p); - auto closure_idx = param_idx - f._param_offset; - ov::parallel_for(func_group.refs.size(), [&](std::size_t f_idx) { - auto& funcall = func_group.refs[f_idx].get(); - ov::npuw::util::to_f16(funcall._closure[closure_idx]); - }); - } + do_permute(ctx); + do_cvtf16(ctx); LOG_VERB("Done"); } diff --git a/src/plugins/intel_npu/src/plugin/npuw/partitioning/partitioning.hpp b/src/plugins/intel_npu/src/plugin/npuw/partitioning/partitioning.hpp index ae8e16546d4272..35c4eacfeffe8b 100644 --- a/src/plugins/intel_npu/src/plugin/npuw/partitioning/partitioning.hpp +++ b/src/plugins/intel_npu/src/plugin/npuw/partitioning/partitioning.hpp @@ -41,6 +41,14 @@ struct Subgraph { std::vector _scales; // Scale coeffs for manual unpacking std::vector _zerops; // Zero points for manual unpacking + struct Gather { + // NB.: int64_t is strange but it is used by OV to refer to parameters + int64_t dst_idx = -1; + int64_t src_idx = -1; + int64_t idx_idx = -1; + }; + Gather _host_gather; + using Ref = std::reference_wrapper; }; diff --git a/src/plugins/intel_npu/src/plugin/npuw/partitioning/patterns/opt.cpp b/src/plugins/intel_npu/src/plugin/npuw/partitioning/patterns/opt.cpp index 7fab6298bc989f..4ec72e02260884 100644 --- a/src/plugins/intel_npu/src/plugin/npuw/partitioning/patterns/opt.cpp +++ b/src/plugins/intel_npu/src/plugin/npuw/partitioning/patterns/opt.cpp @@ -7,17 +7,17 @@ #include "../../logging.hpp" #include "../../util.hpp" #include "openvino/op/add.hpp" +#include "openvino/op/broadcast.hpp" #include "openvino/op/concat.hpp" #include "openvino/op/convert.hpp" -#include "openvino/op/divide.hpp" +#include "openvino/op/gather.hpp" #include "openvino/op/matmul.hpp" #include "openvino/op/multiply.hpp" #include "openvino/op/reduce_sum.hpp" #include "openvino/op/reshape.hpp" #include "openvino/op/slice.hpp" #include "openvino/op/split.hpp" -#include "openvino/op/squeeze.hpp" -#include "openvino/op/unsqueeze.hpp" +#include "openvino/op/subtract.hpp" #include "openvino/op/util/op_types.hpp" #include "openvino/pass/pattern/op/label.hpp" // any_input #include "openvino/pass/pattern/op/optional.hpp" @@ -35,6 +35,9 @@ void Context::permute(PPtr orig_param, const Context::Axes& order) { void Context::to_f16(PPtr orig_param) { closures_to_f16.insert(orig_param); + + orig_param->set_element_type(ov::element::f16); + orig_param->validate_and_infer_types(); } void Context::register_parallel_matmul(O multiply, std::size_t axis, DQParMM&& mm) { @@ -66,6 +69,50 @@ Context::PPtr Context::concat(ov::ParameterVector&& v, std::size_t dim) { return new_param; } +Context::PPtr Context::unpack(Context::PPtr w, Context::PPtr z, Context::PPtr s, ov::element::Type type) { + // FIXME: Assume CW only + NPUW_ASSERT(w->get_shape().size() == 2); + NPUW_ASSERT(z->get_shape().size() == 2); + NPUW_ASSERT(s->get_shape().size() == 2); + auto new_param = std::make_shared(type, w->get_shape()); + params_to_unpack[new_param] = {w, z, s}; + return new_param; +} + +Context::PPtr Context::unpack(Context::PPtr w, Context::PPtr s, ov::element::Type type) { + const auto w_shape = w->get_shape(); + const auto s_shape = s->get_shape(); + + Context::PPtr new_param; + if (w_shape.size() == 3 && s_shape.size() == 3) { + // Assume already reshaped tensor (as it does with unpack) + ov::Shape new_shape = {w_shape[0], w_shape[1] * w_shape[2]}; + new_param = std::make_shared(type, new_shape); + } else if (w_shape.size() == 2 && s_shape.size() == 2) { + new_param = std::make_shared(type, w_shape); + } else { + NPUW_ASSERT(false && "Yet unsupported combination"); + } + + NPUW_ASSERT(new_param); + params_to_unpack[new_param] = {w, {}, s}; + return new_param; +} + +Context::PPtr Context::host_gather(Context::PPtr w, Context::PPtr ids) { + const auto w_shape = w->get_shape(); + const auto ids_shape = ids->get_shape(); + + NPUW_ASSERT(w_shape.size() == 2); + NPUW_ASSERT(ids_shape.size() == 2); + NPUW_ASSERT(ids_shape[0] == 1); + + ov::Shape new_shape = {1, ids_shape[1], w_shape[1]}; + auto new_param = std::make_shared(w->get_element_type(), new_shape); + params_to_gather = Gather{new_param, w, ids}; + return new_param; +} + namespace opp = ov::pass::pattern; // FROM: @@ -207,8 +254,6 @@ DQMatMulGQi::DQMatMulGQi(Context::Ref ctx) { ctx.get().permute(matched_qweight, {0, 2, 1}); // Mark S closure to be lowered fo f16 - matched_qcoeff->set_element_type(ov::element::f16); - matched_qcoeff->validate_and_infer_types(); ctx.get().to_f16(matched_qcoeff); // Reshape the Act to group format @@ -614,6 +659,8 @@ DQMatMulGQ2iP::DQMatMulGQ2iP(Context::Ref ctx) { register_matcher(std::make_shared(qmm, "OptDQMatMulGQ2iP"), std::move(callback)); } +//////////////////////////////////////////////////////////////////////////////// +// Parallel matmuls // Identifies this pattern // // Multiply -----------------------------------> MatMul @@ -656,10 +703,6 @@ DQParMMGQ::DQParMMGQ(Context::Ref ctx) { } void mergeParallelMatMuls(const std::shared_ptr& m, Context& ctx) { - ov::pass::GraphRewrite rewr; - rewr.add_matcher(std::ref(ctx)); - rewr.run_on_model(m); - for (auto&& mul_to_mms : ctx.par_dq_mms) { auto& parallel_matmuls = mul_to_mms.second; if (parallel_matmuls.size() < 2) { @@ -747,6 +790,508 @@ void mergeParallelMatMuls(const std::shared_ptr& m, Context& ctx) { } } +//////////////////////////////////////////////////////////////////////////////// +// Head/tail (Gather + Vocab) + +// Identify a Gather+DQ Asym CW MatMul pattern, lift Gather up +// Note: this pattern is applied on the full model before any partitioning +DQLiftGatherAsymCW::DQLiftGatherAsymCW() { + auto qweight = opp::wrap_type(); + auto qzerop = opp::wrap_type(); + auto qcoeff = opp::wrap_type(); + auto qcvtw = opp::wrap_type({qweight}); + auto qcvtz = opp::wrap_type({qzerop}); + auto qsubz = opp::wrap_type({qcvtw, qcvtz}); + auto qmuls = opp::wrap_type({qsubz, qcoeff}); + auto qcvtm = opp::wrap_type({qmuls}); + + auto pids = opp::wrap_type(); + auto cvtids = opp::wrap_type({pids}); + auto gather = opp::wrap_type({qcvtm, cvtids, opp::any_input()}); + + // Note: Use [=] to make sure the above objects stay alive in the callback + auto callback = [=](ov::pass::pattern::Matcher& m) { + auto& node_to_output = m.get_pattern_value_map(); + + // Create new gathers on W, Z, and S respectively + auto matched_out_w = node_to_output.at(qweight); + auto matched_out_z = node_to_output.at(qzerop); + auto matched_out_s = node_to_output.at(qcoeff); + auto matched_out_ids = node_to_output.at(cvtids); + auto matched_out_gather = node_to_output.at(gather); + + // Replicate the compute part + auto gather_c = std::make_shared(ov::element::i32, ov::Shape{}, 0); + auto new_g_w = std::make_shared(matched_out_w, matched_out_ids, gather_c); + auto new_g_z = std::make_shared(matched_out_z, matched_out_ids, gather_c); + auto new_g_s = std::make_shared(matched_out_s, matched_out_ids, gather_c); + + auto new_cvt_w = std::make_shared(new_g_w, ov::element::f16); + auto new_cvt_z = std::make_shared(new_g_z, ov::element::f16); + auto new_sub = std::make_shared(new_cvt_w, new_cvt_z); + auto new_mul = std::make_shared(new_sub, new_g_s); + auto new_out = std::make_shared(new_mul, ov::element::f32); + + // Reconnect old gather readers to the new Multiply + for (auto&& r : matched_out_gather.get_target_inputs()) { + r.replace_source_output(new_out); + } + return true; // root was changed + }; + register_matcher(std::make_shared(gather, "DQGatherAsymCW"), std::move(callback)); +} + +// Identify a Gather+DQ Sym CW MatMul pattern, lift Gather up +// Note: this pattern is applied on the full model before any partitioning +DQLiftGatherSymCW::DQLiftGatherSymCW() { + auto qweight = opp::wrap_type(); + auto qcoeff = opp::wrap_type(); + auto qcvtw = opp::wrap_type({qweight}); + auto qmuls = opp::wrap_type({qcvtw, qcoeff}); + auto qcvtm = opp::wrap_type({qmuls}); + + auto pids = opp::wrap_type(); + auto cvtids = opp::wrap_type({pids}); + auto gather = opp::wrap_type({qcvtm, cvtids, opp::any_input()}); + + // Note: Use [=] to make sure the above objects stay alive in the callback + auto callback = [=](ov::pass::pattern::Matcher& m) { + auto& node_to_output = m.get_pattern_value_map(); + + auto matched_out_w = node_to_output.at(qweight); + auto matched_out_s = node_to_output.at(qcoeff); + auto matched_out_ids = node_to_output.at(cvtids); + auto matched_out_gather = node_to_output.at(gather); + + // Create new gathers on W and S, connect respectively + auto new_cvt_w = std::make_shared(matched_out_w, ov::element::f16); + auto gather_c = std::make_shared(ov::element::i32, ov::Shape{}, 0); + auto new_g_w = std::make_shared(new_cvt_w, matched_out_ids, gather_c); + auto new_g_s = std::make_shared(matched_out_s, matched_out_ids, gather_c); + auto new_mul = std::make_shared(new_g_w, new_g_s); + auto new_out = std::make_shared(new_mul, ov::element::f32); + + // Reconnect old gather readers to the new Multiply + for (auto&& r : matched_out_gather.get_target_inputs()) { + r.replace_source_output(new_out); + } + return true; // root was changed + }; + register_matcher(std::make_shared(gather, "DQGatherSymCW"), std::move(callback)); +} + +// Identify a Gather+DQ Sym GQ MatMul pattern, lift Gather up +// Note(1): this pattern is applied on the full model before any partitioning +// Note(2): here's a difference, the new lifted Gathers stay behind Convert(W) & Convert(S) +DQLiftGatherSymGQ::DQLiftGatherSymGQ() { + auto qweight = opp::wrap_type(); + auto qcoeff = opp::wrap_type(); + auto qcvtw = opp::wrap_type({qweight}); + auto qmuls = opp::wrap_type({qcvtw, qcoeff}); + auto qreshp = opp::wrap_type({qmuls, opp::any_input()}); + auto qcvtm = opp::wrap_type({qreshp}); + + auto pids = opp::wrap_type(); + auto cvtids = opp::wrap_type({pids}); + auto gather = opp::wrap_type({qcvtm, cvtids, opp::any_input()}); + + // Note: Use [=] to make sure the above objects stay alive in the callback + auto callback = [=](ov::pass::pattern::Matcher& m) { + auto& node_to_output = m.get_pattern_value_map(); + + // Create new gathers on W and S respectively + auto matched_out_w = node_to_output.at(qweight); + auto matched_out_s = node_to_output.at(qcoeff); + auto matched_out_ids = node_to_output.at(cvtids); + auto matched_out_gather = node_to_output.at(gather); + + auto matched_gather_shape = matched_out_gather.get_shape(); + + // Replicate the compute part + auto new_cvt_w = std::make_shared(matched_out_w, ov::element::f16); + + auto gather_c = std::make_shared(ov::element::i32, ov::Shape{}, 0); + auto new_g_w = std::make_shared(new_cvt_w, matched_out_ids, gather_c); + auto new_g_s = std::make_shared(matched_out_s, matched_out_ids, gather_c); + + auto new_mul = std::make_shared(new_g_w, new_g_s); + + auto new_rshp_c = std::make_shared(ov::element::i32, + ov::Shape{matched_gather_shape.size()}, + matched_gather_shape); + auto new_reshape = std::make_shared(new_mul, new_rshp_c, false); + + auto new_out = std::make_shared(new_reshape, ov::element::f32); + + // Reconnect old gather readers to the new Multiply + for (auto&& r : matched_out_gather.get_target_inputs()) { + r.replace_source_output(new_out); + } + return true; // root was changed + }; + register_matcher(std::make_shared(gather, "DQGatherSymGQ"), std::move(callback)); +} + +// This is a companion to DQLiftGatherAsymCW step. This pass runs if +// the respective block (mainly, a head) was turned a function +// (e.g. with FUNCALL_FOR_ALL) As in this case the DQDictMatMulCWu +// compile-time converts asymmetric MM to fp16, do the same thing here +DQUnpackDictGatherCWu::DQUnpackDictGatherCWu(Context::Ref ctx) { + auto pids = opp::wrap_type(); + auto cvtids = opp::wrap_type({pids}); + + auto qweight = opp::wrap_type(); + auto qzerop = opp::wrap_type(); + auto qcoeff = opp::wrap_type(); + auto qgthrw = opp::wrap_type({qweight, cvtids, opp::any_input()}); + auto qgthrz = opp::wrap_type({qzerop, cvtids, opp::any_input()}); + auto qgthrs = opp::wrap_type({qcoeff, cvtids, opp::any_input()}); + + auto qcvtw = opp::wrap_type({qgthrw}); + auto qcvtz = opp::wrap_type({qgthrz}); + auto qsubz = opp::wrap_type({qcvtw, qcvtz}); + auto qmuls = opp::wrap_type({qsubz, qgthrs}); + auto qcvtm = opp::wrap_type({qmuls}); + + // Note: Use [=] to make sure the above objects stay alive in the callback + auto callback = [=](ov::pass::pattern::Matcher& m) { + auto& node_to_output = m.get_pattern_value_map(); + + auto matched_node_qweight = node_to_output.at(qweight).get_node_shared_ptr(); + auto matched_node_qzerop = node_to_output.at(qzerop).get_node_shared_ptr(); + auto matched_node_qcoeff = node_to_output.at(qcoeff).get_node_shared_ptr(); + auto matched_out_ids = node_to_output.at(cvtids); + auto matched_node_cvt = node_to_output.at(qcvtm).get_node_shared_ptr(); + + auto matched_qweight = std::static_pointer_cast(matched_node_qweight); + auto matched_qzerop = std::static_pointer_cast(matched_node_qzerop); + auto matched_qcoeff = std::static_pointer_cast(matched_node_qcoeff); + + // Strip down the DQ subgraph, replace the original Q-ed closure tensor with unpacked fp16 + auto new_wi = ctx.get().unpack(matched_qweight, matched_qzerop, matched_qcoeff, ov::element::f16); + auto gather_c = std::make_shared(ov::element::i32, ov::Shape{}, 0); + auto new_g = std::make_shared(new_wi, matched_out_ids, gather_c); + + matched_node_cvt->input(0).replace_source_output(new_g); + + return true; // root has changed + }; + register_matcher(std::make_shared(qcvtm, "DQDictGatherCWu"), std::move(callback)); +} + +// This is a follow-up to DQLiftGatherSymGQ step, which happens if the respective +// block (mainly, a head) was turned a function (e.g. with FUNCALL_FOR_ALL) +DQUnpackDictGatherGQi::DQUnpackDictGatherGQi(Context::Ref ctx) { + auto pids = opp::wrap_type(); + auto cvtids = opp::wrap_type({pids}); + + auto qweight = opp::wrap_type(); + auto qcoeff = opp::wrap_type(); + auto qgthrw = opp::wrap_type({qweight, cvtids, opp::any_input()}); + auto qgthrs = opp::wrap_type({qcoeff, cvtids, opp::any_input()}); + + auto qcvtw = opp::wrap_type({qgthrw}); + auto qmuls = opp::wrap_type({qcvtw, qgthrs}); + auto qrshp = opp::wrap_type({qmuls, opp::any_input()}); + auto qcvtm = opp::wrap_type({qrshp}); + + // Note: Use [=] to make sure the above objects stay alive in the callback + auto callback = [=](ov::pass::pattern::Matcher& m) { + auto& node_to_output = m.get_pattern_value_map(); + + auto matched_node_qweight = node_to_output.at(qweight).get_node_shared_ptr(); + auto matched_node_qcoeff = node_to_output.at(qcoeff).get_node_shared_ptr(); + auto matched_out_ids = node_to_output.at(cvtids); + auto matched_node_cvt = node_to_output.at(qcvtm).get_node_shared_ptr(); + + auto matched_qweight = std::static_pointer_cast(matched_node_qweight); + auto matched_qcoeff = std::static_pointer_cast(matched_node_qcoeff); + + // Strip down the DQ subgraph, replace the original Q-ed closure tensor with unpacked fp16 + auto new_wi = ctx.get().unpack(matched_qweight, matched_qcoeff, ov::element::f16); + + auto gather_c = std::make_shared(ov::element::i32, ov::Shape{}, 0); + auto new_g = std::make_shared(new_wi, matched_out_ids, gather_c); + matched_node_cvt->input(0).replace_source_output(new_g); + + return true; // root has changed + }; + register_matcher(std::make_shared(qcvtm, "DQDictGatherCWu"), std::move(callback)); +} + +// Identify the case* where the FP16/32 vocab tensor is gathered with +// input_ids and the embedding size is high. In this case, substitute +// gather with a host-side op. Lower vocab tensor to f16. +// * - This case normally happens as a result of other +// * - DictGather-related transformations +HostGather::HostGather(Context::Ref ctx) { + auto pids = opp::wrap_type(); + auto cvtids = opp::wrap_type({pids}); + + auto qweight = opp::wrap_type(); + auto qgthrw = opp::wrap_type({qweight, cvtids, opp::any_input()}); + + auto callback = [=](ov::pass::pattern::Matcher& m) { + auto& node_to_output = m.get_pattern_value_map(); + auto out_shape = node_to_output.at(qgthrw).get_shape(); + auto matched_out_qweight = node_to_output.at(qweight); + auto qweight_type = matched_out_qweight.get_element_type(); + + auto matched_out_gather = node_to_output.at(qgthrw); + + auto sole_reader = [](ov::Output out) { + const auto readers = out.get_target_inputs(); + NPUW_ASSERT(readers.size() >= 1); + return readers.begin()->get_node(); + }; + + if (out_shape.back() >= 2048 && (qweight_type == ov::element::f16 || qweight_type == ov::element::f32) && + (matched_out_gather.get_target_inputs().size() > 1 || + ov::is_type(sole_reader(matched_out_gather)))) { + auto matched_node_qweight = node_to_output.at(qweight).get_node_shared_ptr(); + auto matched_node_ids = node_to_output.at(pids).get_node_shared_ptr(); + auto matched_out_gthr = node_to_output.at(qgthrw); + auto matched_qweight = std::static_pointer_cast(matched_node_qweight); + auto matched_ids = std::static_pointer_cast(matched_node_ids); + + if (qweight_type == ov::element::f32) { + ctx.get().to_f16(matched_qweight); + } + auto new_param = ctx.get().host_gather(matched_qweight, matched_ids); + std::shared_ptr new_cvt; + if (qweight_type == ov::element::f16) { + new_cvt = new_param; + } else { + new_cvt = std::make_shared(new_param, ov::element::f32); + } + NPUW_ASSERT(new_cvt); + for (auto&& r : matched_out_gthr.get_target_inputs()) { + r.replace_source_output(new_cvt); + } + return true; // Root has changed + } + return false; // Root hasn't changed (yet) + }; + register_matcher(std::make_shared(qgthrw, "HostGather"), std::move(callback)); +} + +// Identify the case* where the gather is applied on a compressed +// (symmetric) vocab tensor. Both CW and GQ paths are supported. +// +// FIXME: This may be inefficient: 4x-es the memory consumption +// due to i4-to-fp16 conversion. +HostGatherDQ::HostGatherDQ(Context::Ref ctx) { + auto pids = opp::wrap_type(); + auto cvtids = opp::wrap_type({pids}); + + auto qweight = opp::wrap_type(); + auto qcvtw = opp::wrap_type({qweight}); + auto qcoeff = opp::wrap_type(); + + auto qgthrw = opp::wrap_type({qcvtw, cvtids, opp::any_input()}); + auto qgthrc = opp::wrap_type({qcoeff, cvtids, opp::any_input()}); + auto qmul = opp::wrap_type({qgthrw, qgthrc}); + + auto callback = [=](ov::pass::pattern::Matcher& m) { + auto& node_to_output = m.get_pattern_value_map(); + auto matched_out_mul = node_to_output.at(qmul); + auto out_shape = matched_out_mul.get_shape(); + + if (out_shape.size() != 3 && out_shape.size() != 4) { + return false; + } + + // shape=3 == CW model, 1 x N x Hs + // shape=4 == GQ model, 1 x G x(N/G) x Hs + // were Hs = hidden size, G is # of groups, N is the prompt size. + auto out_len = out_shape.size() == 3 ? out_shape[2] : out_shape[2] * out_shape[3]; + + auto matched_out_qweight = node_to_output.at(qweight); + auto qweight_type = matched_out_qweight.get_element_type(); + + if (out_len >= 2048 && qweight_type == ov::element::i4) { + auto matched_node_qweight = node_to_output.at(qweight).get_node_shared_ptr(); + auto matched_node_qcoeff = node_to_output.at(qcoeff).get_node_shared_ptr(); + auto matched_node_ids = node_to_output.at(pids).get_node_shared_ptr(); + + auto matched_qweight = std::static_pointer_cast(matched_node_qweight); + auto matched_qcoeff = std::static_pointer_cast(matched_node_qcoeff); + auto matched_ids = std::static_pointer_cast(matched_node_ids); + + auto fp16vocab = ctx.get().unpack(matched_qweight, matched_qcoeff, ov::element::f16); + auto new_param = ctx.get().host_gather(fp16vocab, matched_ids); + for (auto&& r : matched_out_mul.get_target_inputs()) { + r.replace_source_output(new_param); + } + return true; // Root has changed + } + return false; // Root hasn't changed (yet) + }; + register_matcher(std::make_shared(qmul, "HostGatherDQ"), std::move(callback)); +} + +// FROM: +// Param(W) -> to(f16) -> +// Param(Z) -> to(f16) -> Subtract +// Param(S) ---------------------> Multiply -> to(f32) -> MatMul -> Result +// ???(Act) --------------------------------------------> +// +// TO: +// Param(W) ------------> +// ???(Act) -> to(f16) -> MatMul -> to(f32) -> Result + +DQUnpackDictMatMulCWu::DQUnpackDictMatMulCWu(Context::Ref ctx) { + auto qweight = opp::wrap_type(); + auto qcoeff = opp::wrap_type(); + auto qzerop = opp::wrap_type(); + auto qcvtw = opp::wrap_type({qweight}); + auto qcvtz = opp::wrap_type({qzerop}); + auto qsub = opp::wrap_type({qcvtw, qcvtz}); + auto qmuls = opp::wrap_type({qsub, qcoeff}); + auto qcvtm = opp::wrap_type({qmuls}); + auto qmmi = opp::any_input(); + auto qmm = opp::wrap_type({qmmi, qcvtm}); + auto qres = opp::wrap_type({qmm}); + + // Note: Use [=] to make sure the above objects stay alive in the callback + auto callback = [=](ov::pass::pattern::Matcher& m) { + auto& node_to_output = m.get_pattern_value_map(); + + auto matched_node_qweight = node_to_output.at(qweight).get_node_shared_ptr(); + auto matched_node_qzerop = node_to_output.at(qzerop).get_node_shared_ptr(); + auto matched_node_cvtw = node_to_output.at(qcvtw).get_node_shared_ptr(); + auto matched_node_cvtz = node_to_output.at(qcvtz).get_node_shared_ptr(); + auto matched_node_qcoeff = node_to_output.at(qcoeff).get_node_shared_ptr(); + auto matched_node_matmul = node_to_output.at(qmm).get_node_shared_ptr(); + auto matched_mmi = node_to_output.at(qmmi); + auto matched_node_res = node_to_output.at(qres).get_node_shared_ptr(); + + auto matched_qweight = std::static_pointer_cast(matched_node_qweight); + auto matched_qzerop = std::static_pointer_cast(matched_node_qzerop); + auto matched_qcoeff = std::static_pointer_cast(matched_node_qcoeff); + auto matched_matmul = std::static_pointer_cast(matched_node_matmul); + auto matched_result = std::static_pointer_cast(matched_node_res); + + auto qcoeff_shape = matched_qcoeff->output(0).get_shape(); + auto qzerop_shape = matched_qzerop->output(0).get_shape(); + auto act_shape = matched_mmi.get_shape(); + + if (ov::element::u8 == matched_qweight->get_element_type() && qcoeff_shape[1] == 1 && + !matched_matmul->get_transpose_a() && matched_matmul->get_transpose_b()) { + auto new_cvt_a = std::make_shared(matched_mmi, ov::element::f16); + + auto new_wi = ctx.get().unpack(matched_qweight, matched_qzerop, matched_qcoeff, ov::element::f16); + auto new_mm = std::make_shared(new_cvt_a, new_wi, false, true); + auto new_out = std::make_shared(new_mm, ov::element::f32); + + matched_result->input(0).replace_source_output(new_out); + } + return false; // root has changed (yet) + }; + register_matcher(std::make_shared(qres, "OptDQDictMatMulCWu"), std::move(callback)); +} + +// FROM: +// Param(W) -> to(f16) -> +// Param(S) ------------> Multiply -> Reshape -> to(f32) -> MatMul -> Result +// ???(Act) ----------------------------------------------> +// +// TO: +// Param(W) ------------> +// ???(Act) -> to(f16) -> MatMul -> to(f32) -> Result +// NB: This pass only worsens the performance so is disabled +DQUnpackDictMatMulGQi::DQUnpackDictMatMulGQi(Context::Ref ctx) { + auto qweight = opp::wrap_type(); + auto qcoeff = opp::wrap_type(); + auto qcvtw = opp::wrap_type({qweight}); + auto qmuls = opp::wrap_type({qcvtw, qcoeff}); + auto qreshp = opp::wrap_type({qmuls, opp::any_input()}); + auto qcvtm = opp::wrap_type({qreshp}); + auto qcvtr = opp::wrap_type({qreshp}); + auto qmmi = opp::any_input(); + auto qmm = opp::wrap_type({qmmi, qcvtr}); + auto qres = opp::wrap_type({qmm}); + + // Note: Use [=] to make sure the above objects stay alive in the callback + auto callback = [=](ov::pass::pattern::Matcher& m) { + auto& node_to_output = m.get_pattern_value_map(); + + auto matched_node_qweight = node_to_output.at(qweight).get_node_shared_ptr(); + auto matched_node_cvtw = node_to_output.at(qcvtw).get_node_shared_ptr(); + auto matched_node_qcoeff = node_to_output.at(qcoeff).get_node_shared_ptr(); + auto matched_node_matmul = node_to_output.at(qmm).get_node_shared_ptr(); + auto matched_mmi = node_to_output.at(qmmi); + auto matched_node_res = node_to_output.at(qres).get_node_shared_ptr(); + + auto matched_qweight = std::static_pointer_cast(matched_node_qweight); + auto matched_qcoeff = std::static_pointer_cast(matched_node_qcoeff); + auto matched_matmul = std::static_pointer_cast(matched_node_matmul); + auto matched_result = std::static_pointer_cast(matched_node_res); + + auto qcoeff_shape = matched_qcoeff->output(0).get_shape(); + auto act_shape = matched_mmi.get_shape(); + + if (ov::element::i4 == matched_qweight->get_element_type() && qcoeff_shape.size() == 3) { + auto new_cvt_a = std::make_shared(matched_mmi, ov::element::f16); + + auto new_wi = ctx.get().unpack(matched_qweight, matched_qcoeff, ov::element::f16); + auto new_mm = std::make_shared(new_cvt_a, + new_wi, + matched_matmul->get_transpose_a(), + matched_matmul->get_transpose_b()); + auto new_out = std::make_shared(new_mm, ov::element::f32); + + matched_result->input(0).replace_source_output(new_out); + } + return false; // root has changed (yet) + }; + register_matcher(std::make_shared(qres, "OptDQDictMatMulGQi"), std::move(callback)); +} + +// FROM: +// Param(W):f32 -> +// ???(Act) -----> MatMul -> Result +// +// TO: +// Param(W):f16 --------> +// ???(Act) -> to(f16) -> MatMul -> to(f32) -> Result +// NB: This pass only worsens the performance so is disabled +CompressDictMatMulf32::CompressDictMatMulf32(Context::Ref ctx) { + auto weight = opp::wrap_type(); + auto mmi = opp::any_input(); + auto mm = opp::wrap_type({mmi, weight}); + auto res = opp::wrap_type({mm}); + + // Note: Use [=] to make sure the above objects stay alive in the callback + auto callback = [=](ov::pass::pattern::Matcher& m) { + auto& node_to_output = m.get_pattern_value_map(); + + auto matched_node_weight = node_to_output.at(weight).get_node_shared_ptr(); + auto matched_node_matmul = node_to_output.at(mm).get_node_shared_ptr(); + auto matched_mmi = node_to_output.at(mmi); + auto matched_node_res = node_to_output.at(res).get_node_shared_ptr(); + + auto matched_weight = std::static_pointer_cast(matched_node_weight); + auto matched_matmul = std::static_pointer_cast(matched_node_matmul); + auto matched_result = std::static_pointer_cast(matched_node_res); + + if (ov::element::f32 == matched_weight->get_element_type()) { + auto new_cvt_a = std::make_shared(matched_mmi, ov::element::f16); + + ctx.get().to_f16(matched_weight); + auto new_mm = std::make_shared(new_cvt_a, + matched_weight, + matched_matmul->get_transpose_a(), + matched_matmul->get_transpose_b()); + auto new_out = std::make_shared(new_mm, ov::element::f32); + + matched_result->input(0).replace_source_output(new_out); + } + return false; // root has changed (yet) + }; + register_matcher(std::make_shared(res, "OptCompressDictMatMulf32"), std::move(callback)); +} + } // namespace opt } // namespace patterns } // namespace npuw diff --git a/src/plugins/intel_npu/src/plugin/npuw/partitioning/patterns/opt.hpp b/src/plugins/intel_npu/src/plugin/npuw/partitioning/patterns/opt.hpp index b51b32df23f2a2..530d0a52cc515f 100644 --- a/src/plugins/intel_npu/src/plugin/npuw/partitioning/patterns/opt.hpp +++ b/src/plugins/intel_npu/src/plugin/npuw/partitioning/patterns/opt.hpp @@ -6,6 +6,7 @@ #include #include +#include #include #include "openvino/openvino.hpp" @@ -48,6 +49,19 @@ struct Context { std::map> params_to_concat; PPtr concat(ov::ParameterVector&& v, std::size_t dim); + struct DQUnpack { + PPtr w, z, s; + }; + std::map params_to_unpack; + PPtr unpack(PPtr w, PPtr z, PPtr s, ov::element::Type type); + PPtr unpack(PPtr w, PPtr s, ov::element::Type type); + + struct Gather { + PPtr pnew, pold, pids; + }; + std::optional params_to_gather; + PPtr host_gather(PPtr w, PPtr ids); + using Ref = std::reference_wrapper; }; @@ -78,6 +92,62 @@ class DQParMMGQ : public ov::pass::MatcherPass { void mergeParallelMatMuls(const std::shared_ptr& m, Context& ctx); +// Gather-related passes + +class DQLiftGatherAsymCW : public ov::pass::MatcherPass { +public: + DQLiftGatherAsymCW(); +}; + +class DQLiftGatherSymCW : public ov::pass::MatcherPass { +public: + DQLiftGatherSymCW(); +}; + +class DQLiftGatherSymGQ : public ov::pass::MatcherPass { +public: + DQLiftGatherSymGQ(); +}; + +// Head vocab unpacks + +class DQUnpackDictGatherCWu : public ov::pass::MatcherPass { +public: + DQUnpackDictGatherCWu(Context::Ref ctx); +}; + +class DQUnpackDictGatherGQi : public ov::pass::MatcherPass { +public: + DQUnpackDictGatherGQi(Context::Ref ctx); +}; + +class HostGather : public ov::pass::MatcherPass { +public: + HostGather(Context::Ref ctx); +}; + +class HostGatherDQ : public ov::pass::MatcherPass { +public: + HostGatherDQ(Context::Ref ctx); +}; + +// Tail vocab unpacks + +class DQUnpackDictMatMulCWu : public ov::pass::MatcherPass { +public: + DQUnpackDictMatMulCWu(Context::Ref ctx); +}; + +class DQUnpackDictMatMulGQi : public ov::pass::MatcherPass { +public: + DQUnpackDictMatMulGQi(Context::Ref ctx); +}; + +class CompressDictMatMulf32 : public ov::pass::MatcherPass { +public: + CompressDictMatMulf32(Context::Ref ctx); +}; + } // namespace opt } // namespace patterns } // namespace npuw diff --git a/src/plugins/intel_npu/src/plugin/npuw/util.cpp b/src/plugins/intel_npu/src/plugin/npuw/util.cpp index 9f7a404e0d9f4d..c67b0c62d0889c 100644 --- a/src/plugins/intel_npu/src/plugin/npuw/util.cpp +++ b/src/plugins/intel_npu/src/plugin/npuw/util.cpp @@ -135,6 +135,14 @@ inline __m128i avx2_u8tof16_lo(__m128i vu8, __m256 z, __m256 s) { return avx2_u8tof16_hi(vu8h, z, s); } +inline __m128i avx2_u8tof16(__m128i vi8, __m256 z, __m256 s) { + __m256i i32vec = _mm256_cvtepu8_epi32(vi8); // extend: 8 x i8 -> 8 x i32 [256b of 256b] + __m256 f32vec = _mm256_cvtepi32_ps(i32vec); // convert: 8 x i32 -> 8 x f32 [256b of 256b] + __m256 f32sub = _mm256_sub_ps(f32vec, z); // subtract: 8 x f32 -> 8 x f32 [256b of 256b] + __m256 f32scl = _mm256_mul_ps(f32sub, s); // scale: 8 x f32 -> 8 x f32 [256b of 256b] + return _mm256_cvtps_ph(f32scl, _MM_FROUND_TO_NEAREST_INT); // convert: 8 x f32 -> 8 x f16 [128b] +} + // NOTE: This routine implements the NEW ORDER inline void avx2_u4tof16(__m256i vinput, __m128i vout[8], __m256 zvalVec, __m256 svalVec[8]) { // vinput - 64 x u4 elements - 256 bits @@ -1232,6 +1240,56 @@ void unpack_i8f16(const ov::SoPtr& from, } // sindex } +void unpack_u8f16(const ov::SoPtr& from, + const ov::SoPtr& zerop, + const ov::SoPtr& scale, + const ov::SoPtr& to, + const ov::npuw::util::UnpackOptions& _options) { + NPUW_ASSERT(from->is_continuous()); + NPUW_ASSERT(zerop->is_continuous()); + NPUW_ASSERT(scale->is_continuous()); + NPUW_ASSERT(to->is_continuous()); + NPUW_ASSERT(from->get_size() == to->get_size()); + NPUW_ASSERT(from->get_size() % 8 == 0); + NPUW_ASSERT(scale->get_shape()[0] == from->get_shape()[0]); + NPUW_ASSERT(scale->get_shape()[1] == 1); + NPUW_ASSERT(zerop->get_shape()[0] == from->get_shape()[0]); + NPUW_ASSERT(zerop->get_shape()[1] == 1); + + const auto scale_elem_type = scale->get_element_type(); + NPUW_ASSERT(scale_elem_type == ov::element::f32 || scale_elem_type == ov::element::f16); + + const auto zerop_elem_type = zerop->get_element_type(); + NPUW_ASSERT(zerop_elem_type == ov::element::u8); + + constexpr std::size_t VECSIZE = 8; + + const std::size_t total = from->get_size(); + const std::size_t stotal = scale->get_size(); + uint8_t const* pSrc = from->data(); + uint8_t const* pZrp = zerop->data(); + int8_t const* pScl = static_cast(scale->data()); + int16_t* pDst = static_cast(to->data()); + + for (std::size_t sindex = 0u; sindex < stotal; sindex++) { + __m256 svec = avx2_load_scale(pScl, scale_elem_type); + __m128i u8zp = _mm_set1_epi8(*pZrp); // bcast: 8 x u8 + __m256i u32zp = _mm256_cvtepu8_epi32(u8zp); // i32 zero point + __m256 f32zp = _mm256_cvtepi32_ps(u32zp); // f32 zero point + for (std::size_t index = 0u; index < (total / stotal); index += VECSIZE) { + __m128i const* pSrcV = reinterpret_cast(pSrc); + __m128i* pDstV = reinterpret_cast<__m128i*>(pDst); + __m128i u8in = _mm_loadl_epi64(pSrcV); // load: 8 x u8 + __m128i f16vec = avx2_u8tof16(u8in, f32zp, svec); // convert & scale + _mm_store_si128(pDstV, f16vec); // store: 8 x f16 + pSrc += VECSIZE; + pDst += VECSIZE; + } // index + pScl += scale_elem_type.size(); + pZrp++; + } // sindex +} + } // namespace void ov::npuw::util::unpack(const ov::SoPtr& from, @@ -1298,10 +1356,17 @@ void ov::npuw::util::unpack(const ov::SoPtr& from, const auto type_scale = scale->get_element_type(); const auto type_to = to->get_element_type(); - NPUW_ASSERT(type_from == ov::element::u4); - NPUW_ASSERT(type_zerop == ov::element::u4 || type_zerop == ov::element::f16 || type_zerop == ov::element::f32); - NPUW_ASSERT(type_scale == ov::element::f16 || type_scale == ov::element::f32); - NPUW_ASSERT(type_to == ov::element::f16); + if (type_from == ov::element::u4) { + NPUW_ASSERT(type_zerop == ov::element::u4 || type_zerop == ov::element::f16 || type_zerop == ov::element::f32); + NPUW_ASSERT(type_scale == ov::element::f16 || type_scale == ov::element::f32); + NPUW_ASSERT(type_to == ov::element::f16); + } else if (type_from == ov::element::u8) { + NPUW_ASSERT(type_zerop == ov::element::u8); + NPUW_ASSERT(type_scale == ov::element::f16); + NPUW_ASSERT(type_to == ov::element::f16); + } else { + NPUW_ASSERT(false && "Unsupported combination"); + } // This function determines the appropriate unpacking strategy for tensor multiplication // based on the 'scale' shape and 'from' shape. @@ -1325,20 +1390,61 @@ void ov::npuw::util::unpack(const ov::SoPtr& from, const auto& from_shape = from->get_shape(); const auto& scale_shape = scale->get_shape(); - if (scale_shape.size() == 3 && scale_shape[0] == from_shape[0] && scale_shape[1] == 1 && - scale_shape[2] == from_shape[2]) { - unpack_u4f16_z(from, zerop, scale, to, unpack_options); - } else if (scale_shape.size() == 3 && scale_shape[0] == from_shape[0] && scale_shape[1] == from_shape[1] && - scale_shape[2] == 1) { - if (zerop->get_size() == 1) { + if (type_from == ov::element::u4) { + if (scale_shape.size() == 3 && scale_shape[0] == from_shape[0] && scale_shape[1] == 1 && + scale_shape[2] == from_shape[2]) { + unpack_u4f16_z(from, zerop, scale, to, unpack_options); + } else if (scale_shape.size() == 3 && scale_shape[0] == from_shape[0] && scale_shape[1] == from_shape[1] && + scale_shape[2] == 1) { + if (zerop->get_size() == 1) { + unpack_u4f16(from, zerop, scale, to, unpack_options); + } else { + unpack_u4f16_asymm_zp(from, zerop, scale, to, unpack_options); + } + } else if (scale_shape.size() == 2 && scale_shape[0] == from_shape[0] && scale_shape[1] == 1) { unpack_u4f16(from, zerop, scale, to, unpack_options); } else { - unpack_u4f16_asymm_zp(from, zerop, scale, to, unpack_options); + NPUW_ASSERT(false); } - } else if (scale_shape.size() == 2 && scale_shape[0] == from_shape[0] && scale_shape[1] == 1) { - unpack_u4f16(from, zerop, scale, to, unpack_options); - } else { - NPUW_ASSERT(false); + } else if (type_from == ov::element::u8) { + // Only support CW for now + if (scale_shape.size() == 2 && scale_shape[0] == from_shape[0] && scale_shape[1] == 1) { + unpack_u8f16(from, zerop, scale, to, unpack_options); + } else { + NPUW_ASSERT(false); + } + } +} + +void ov::npuw::util::gather(const ov::SoPtr& src, + const ov::SoPtr& idx, + const ov::SoPtr& dst) { + const auto src_type = src->get_element_type(); + const auto dst_type = dst->get_element_type(); + NPUW_ASSERT(idx->get_element_type() == ov::element::i64); + NPUW_ASSERT(src_type == ov::element::f16 || src_type == ov::element::f32); + NPUW_ASSERT(src_type == dst_type); + + const auto idx_shape = idx->get_shape(); + NPUW_ASSERT(idx_shape.size() == 2); + NPUW_ASSERT(idx_shape[0] == 1); + + const auto src_shape = src->get_shape(); + NPUW_ASSERT(src_shape.size() == 2); + + const auto dst_shape = dst->get_shape(); + NPUW_ASSERT(dst_shape.size() == 3); + NPUW_ASSERT(src_shape[1] == dst_shape[2]); + + const int64_t* pIdx = idx->data(); + const uint8_t* pSrc = static_cast(src->data()); + uint8_t* pDst = static_cast(dst->data()); + + for (std::size_t r = 0; r < idx_shape[1]; r++) { + auto srcRowIdx = pIdx[r]; + auto pSrcRow = pSrc + src_shape[1] * srcRowIdx * src_type.size(); + std::copy_n(pSrcRow, src_shape[1] * src_type.size(), pDst); + pDst += dst_shape[2] * dst_type.size(); } } diff --git a/src/plugins/intel_npu/src/plugin/npuw/util.hpp b/src/plugins/intel_npu/src/plugin/npuw/util.hpp index 3a935d0a656b17..6012ce0e587352 100644 --- a/src/plugins/intel_npu/src/plugin/npuw/util.hpp +++ b/src/plugins/intel_npu/src/plugin/npuw/util.hpp @@ -50,6 +50,8 @@ void unpack(const ov::SoPtr& from, const ov::SoPtr& to, const UnpackOptions& unpack_options = UnpackOptions{true, 16, false}); +void gather(const ov::SoPtr& src, const ov::SoPtr& idx, const ov::SoPtr& dst); + void to_f32(const ov::Tensor& in, ov::Tensor& out); void to_f16(ov::Tensor& t); void transpose(ov::Tensor& t);