Skip to content

Commit

Permalink
Introduce DQ_FULL property
Browse files Browse the repository at this point in the history
  • Loading branch information
smirnov-alexey committed Nov 21, 2024
1 parent 9f7ad50 commit e8c4185
Show file tree
Hide file tree
Showing 7 changed files with 48 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ DEFINE_OPT(NPUW_PLAN, std::string, "", npuw::partitioning::plan, CompileTime);
DEFINE_OPT(NPUW_FOLD, bool, false, npuw::partitioning::fold, CompileTime);
DEFINE_OPT(NPUW_CWAI, bool, false, npuw::partitioning::cwai, CompileTime);
DEFINE_OPT(NPUW_DQ, bool, false, npuw::partitioning::dyn_quant, CompileTime);
DEFINE_OPT(NPUW_DQ_FULL, bool, true, npuw::partitioning::dyn_quant_full, CompileTime);
DEFINE_OPT(NPUW_PMM, std::string, "2", npuw::partitioning::par_matmul_merge_dims, CompileTime);
DEFINE_OPT(NPUW_SLICE_OUT, bool, false, npuw::partitioning::slice_out, CompileTime);
DEFINE_OPT(NPUW_HOST_GATHER, bool, true, npuw::partitioning::host_gather, CompileTime);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,14 @@ static constexpr ov::Property<bool> cwai{"NPUW_CWAI"};
*/
static constexpr ov::Property<bool> dyn_quant{"NPUW_DQ"};

/**
* @brief
* Type: bool.
* Apply multiply shuffle and matmul unroll during dynamic quantization transformations.
* Default value: true.
*/
static constexpr ov::Property<bool> dyn_quant_full{"NPUW_DQ_FULL"};

/**
* @brief
* Type: string.
Expand Down
1 change: 1 addition & 0 deletions src/plugins/intel_npu/src/al/src/config/npuw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ void intel_npu::registerNPUWOptions(OptionsDesc& desc) {
desc.add<NPUW_FOLD>();
desc.add<NPUW_CWAI>();
desc.add<NPUW_DQ>();
desc.add<NPUW_DQ_FULL>();
desc.add<NPUW_PMM>();
desc.add<NPUW_SLICE_OUT>();
desc.add<NPUW_SPATIAL>();
Expand Down
1 change: 1 addition & 0 deletions src/plugins/intel_npu/src/plugin/npuw/compiled_model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -923,6 +923,7 @@ void ov::npuw::CompiledModel::implement_properties() {
BIND(npuw::partitioning::fold, NPUW_FOLD),
BIND(npuw::partitioning::cwai, NPUW_CWAI),
BIND(npuw::partitioning::dyn_quant, NPUW_DQ),
BIND(npuw::partitioning::dyn_quant_full, NPUW_DQ_FULL),
BIND(npuw::partitioning::par_matmul_merge_dims, NPUW_PMM),
BIND(npuw::partitioning::slice_out, NPUW_SLICE_OUT),
BIND(npuw::partitioning::spatial, NPUW_SPATIAL),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1956,9 +1956,10 @@ void Partitioner::optimize(const std::string& func_name) {
// Run "dynamic quantization"
ov::npuw::patterns::opt::Context ctx;
ctx.is_spatial = f._spatial.has_value();
ctx.mm_dq_full = cfg.get<::intel_npu::NPUW_DQ_FULL>();

ov::pass::GraphRewrite rewr;
rewr.add_matcher<ov::npuw::patterns::opt::DQMatMulCWi>();
rewr.add_matcher<ov::npuw::patterns::opt::DQMatMulCWi>(std::ref(ctx));
rewr.add_matcher<ov::npuw::patterns::opt::DQMatMulGQi>(std::ref(ctx));
rewr.add_matcher<ov::npuw::patterns::opt::DQMatMulGQ2i>(std::ref(ctx));
rewr.add_matcher<ov::npuw::patterns::opt::DQMatMulGQiP>(std::ref(ctx));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ namespace uat = ov::npuw::util::at;
// Param/Const(S) -> (Reshape) -> (to(f32)) -> Reshape -->
//

DQMatMulCWi::DQMatMulCWi() {
DQMatMulCWi::DQMatMulCWi(Context::Ref ctx) {
auto qweight = opp::wrap_type<ov::op::v0::Parameter>();
auto qcoeff = opp::any_input();
auto reshapew = opp::optional<ov::op::v1::Reshape>({qweight, opp::any_input()});
Expand Down Expand Up @@ -161,6 +161,10 @@ DQMatMulCWi::DQMatMulCWi() {
auto matched_node_qcoeff_out = uat::_(node_to_output).at_or_at_or_at(qcvtc, reshapec, qcoeff);
auto matched_node_muls_out = uat::_(node_to_output).at_or_at(qcvtm, qmuls);

if (!ctx.get().mm_dq_full) {
return false; // root hasn't changed
}

// Reconnect MatMul to read from Convert(W) directly.
// Note: ACT has to be converted too.
auto cvt_prec = matched_node_cvtw->output(0).get_element_type();
Expand Down Expand Up @@ -261,6 +265,12 @@ DQMatMulGQi::DQMatMulGQi(Context::Ref ctx) {
matched_qweight->validate_and_infer_types();
ctx.get().permute(matched_qweight, {0, 2, 1});

if (!ctx.get().mm_dq_full) {
// Only transpose MatMul
matched_matmul->set_transpose_b(true);
return false; // root hasn't changed
}

// Mark S closure to be lowered fo f16
ctx.get().to_f16(matched_qcoeff);

Expand Down Expand Up @@ -384,6 +394,12 @@ DQMatMulGQ2i::DQMatMulGQ2i(Context::Ref ctx) {
matched_qcoeff->set_partial_shape(ts_shape);
matched_qcoeff->validate_and_infer_types();

if (!ctx.get().mm_dq_full) {
// Only transpose MatMul
matched_matmul->set_transpose_b(false);
return false; // root hasn't changed
}

// Reshape the Act to group format
const auto NSPLIT = qweight_shape[1];
std::vector<std::size_t> rshp_act_v = {NSPLIT, 1, act_shape[2] / NSPLIT};
Expand Down Expand Up @@ -495,6 +511,12 @@ DQMatMulGQiP::DQMatMulGQiP(Context::Ref ctx) {
matched_qweight->validate_and_infer_types();
ctx.get().permute(matched_qweight, {0, 2, 1});

if (!ctx.get().mm_dq_full) {
// Only transpose MatMul
matched_matmul->set_transpose_b(true);
return false; // root hasn't changed
}

// Mark S closure to be lowered fo f16
matched_qcoeff->set_element_type(ov::element::f16);
matched_qcoeff->validate_and_infer_types();
Expand Down Expand Up @@ -619,6 +641,12 @@ DQMatMulGQ2iP::DQMatMulGQ2iP(Context::Ref ctx) {
matched_qcoeff->set_partial_shape(ts_shape);
matched_qcoeff->validate_and_infer_types();

if (!ctx.get().mm_dq_full) {
// Only transpose MatMul
matched_matmul->set_transpose_b(false);
return false; // root hasn't changed
}

// Select proper activation shape
std::size_t act_dim = act_shape[0] > act_shape[1] ? 0 : 1;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,10 @@ namespace npuw {
namespace patterns {
namespace opt {

class DQMatMulCWi : public ov::pass::MatcherPass {
public:
DQMatMulCWi();
};

struct Context {
std::string pmm_dims;
bool is_spatial = false;
bool mm_dq_full = true;

using PPtr = std::shared_ptr<ov::op::v0::Parameter>;
using NPtr = std::shared_ptr<ov::Node>;
Expand Down Expand Up @@ -66,6 +62,11 @@ struct Context {
using Ref = std::reference_wrapper<Context>;
};

class DQMatMulCWi : public ov::pass::MatcherPass {
public:
DQMatMulCWi(Context::Ref ctx);
};

class DQMatMulGQi : public ov::pass::MatcherPass {
public:
explicit DQMatMulGQi(Context::Ref ctx);
Expand Down

0 comments on commit e8c4185

Please sign in to comment.