Skip to content

Commit

Permalink
Fix passes with DQ_FULL:NO
Browse files Browse the repository at this point in the history
  • Loading branch information
smirnov-alexey committed Nov 21, 2024
1 parent e8c4185 commit 4ff6bec
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 27 deletions.
102 changes: 78 additions & 24 deletions src/plugins/intel_npu/src/plugin/npuw/partitioning/patterns/opt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,10 @@ DQMatMulCWi::DQMatMulCWi(Context::Ref ctx) {
auto matched_node_muls_out = uat::_(node_to_output).at_or_at(qcvtm, qmuls);

if (!ctx.get().mm_dq_full) {
const auto& matm_mul_out_shape = matched_matmul->get_output_shape(0);
const auto& matm_mul_in_shape = matched_matmul->get_input_shape(1);
NPUW_ASSERT(matm_mul_out_shape.back() == matm_mul_in_shape.front());
NPUW_ASSERT(matched_matmul->get_transpose_b());
return false; // root hasn't changed
}

Expand Down Expand Up @@ -243,6 +247,7 @@ DQMatMulGQi::DQMatMulGQi(Context::Ref ctx) {
auto matched_node_qweight = node_to_output.at(qweight).get_node_shared_ptr();
auto matched_node_qcoeff = node_to_output.at(qcoeff).get_node_shared_ptr();
auto matched_node_matmul = node_to_output.at(qmm).get_node_shared_ptr();
auto matched_node_qreshp = node_to_output.at(qreshp).get_node_shared_ptr();
auto matched_out_mmi = node_to_output.at(qmmi);

auto matched_qweight = std::static_pointer_cast<ov::op::v0::Parameter>(matched_node_qweight);
Expand All @@ -259,18 +264,40 @@ DQMatMulGQi::DQMatMulGQi(Context::Ref ctx) {
act_shape.size() == 3 && act_shape[1] == 1 && // single-token case
qcoeff_shape[0] == qweight_shape[0] && qcoeff_shape[1] == 1 && qcoeff_shape[2] == qweight_shape[2] &&
!matched_matmul->get_transpose_a() && !matched_matmul->get_transpose_b()) {
if (!ctx.get().mm_dq_full) {
// Transpose weight and coeff
ov::Shape tw_shape = {qweight_shape[2], qweight_shape[0], qweight_shape[1]};
matched_qweight->set_partial_shape(tw_shape);
matched_qweight->validate_and_infer_types();
ctx.get().permute(matched_qweight, {2, 0, 1});

ov::Shape tc_shape = {qcoeff_shape[2], qcoeff_shape[0], qcoeff_shape[1]};
matched_qcoeff->set_partial_shape(tc_shape);
matched_qcoeff->validate_and_infer_types();
ctx.get().permute(matched_qcoeff, {2, 0, 1});

// Change Reshape's shape
std::vector<std::size_t> transposed_shape = {qweight_shape[2], qweight_shape[0] * qweight_shape[1]};
auto transposed_shape_c =
std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{2}, transposed_shape);
matched_node_qreshp->input(1).replace_source_output(transposed_shape_c);
matched_node_qreshp->validate_and_infer_types();

matched_matmul->set_transpose_b(true);
matched_matmul->validate_and_infer_types();

const auto& matm_mul_out_shape = matched_matmul->get_output_shape(0);
const auto& matm_mul_in_shape = matched_matmul->get_input_shape(1);
NPUW_ASSERT(matm_mul_out_shape.back() == matm_mul_in_shape.front());
return false; // root hasn't changed
}

// Mark W closure to transpose, and transpose the respective parameter
ov::Shape tw_shape = {qweight_shape[0], qweight_shape[2], qweight_shape[1]};
matched_qweight->set_partial_shape(tw_shape);
matched_qweight->validate_and_infer_types();
ctx.get().permute(matched_qweight, {0, 2, 1});

if (!ctx.get().mm_dq_full) {
// Only transpose MatMul
matched_matmul->set_transpose_b(true);
return false; // root hasn't changed
}

// Mark S closure to be lowered fo f16
ctx.get().to_f16(matched_qcoeff);

Expand Down Expand Up @@ -380,6 +407,14 @@ DQMatMulGQ2i::DQMatMulGQ2i(Context::Ref ctx) {
act_shape.size() == 3 && act_shape[0] == 1 && act_shape[1] == 1 && qcoeff_shape[0] == qweight_shape[0] &&
qcoeff_shape[2] == 1 && qcoeff_shape[1] == qweight_shape[1] && !matched_matmul->get_transpose_a() &&
matched_matmul->get_transpose_b()) {
if (!ctx.get().mm_dq_full) {
const auto& matm_mul_out_shape = matched_matmul->get_output_shape(0);
const auto& matm_mul_in_shape = matched_matmul->get_input_shape(1);
NPUW_ASSERT(matm_mul_out_shape.back() == matm_mul_in_shape.front());
NPUW_ASSERT(matched_matmul->get_transpose_b());
return false; // root hasn't changed
}

// Mark W closure to transpose, and transpose the respective parameter
ctx.get().permute(matched_qweight, {1, 0, 2});

Expand All @@ -394,12 +429,6 @@ DQMatMulGQ2i::DQMatMulGQ2i(Context::Ref ctx) {
matched_qcoeff->set_partial_shape(ts_shape);
matched_qcoeff->validate_and_infer_types();

if (!ctx.get().mm_dq_full) {
// Only transpose MatMul
matched_matmul->set_transpose_b(false);
return false; // root hasn't changed
}

// Reshape the Act to group format
const auto NSPLIT = qweight_shape[1];
std::vector<std::size_t> rshp_act_v = {NSPLIT, 1, act_shape[2] / NSPLIT};
Expand Down Expand Up @@ -490,6 +519,7 @@ DQMatMulGQiP::DQMatMulGQiP(Context::Ref ctx) {
auto matched_node_qweight = node_to_output.at(qweight).get_node_shared_ptr();
auto matched_node_qcoeff = node_to_output.at(qcoeff).get_node_shared_ptr();
auto matched_node_matmul = node_to_output.at(qmm).get_node_shared_ptr();
auto matched_node_qreshp = node_to_output.at(qreshp).get_node_shared_ptr();
auto matched_out_mmi = node_to_output.at(qmmi);

auto matched_qweight = std::static_pointer_cast<ov::op::v0::Parameter>(matched_node_qweight);
Expand All @@ -505,18 +535,40 @@ DQMatMulGQiP::DQMatMulGQiP(Context::Ref ctx) {
act_shape.size() == 3 && act_shape[1] > 1 && // multi-token case
qcoeff_shape[0] == qweight_shape[0] && qcoeff_shape[1] == 1 && qcoeff_shape[2] == qweight_shape[2] &&
!matched_matmul->get_transpose_a() && !matched_matmul->get_transpose_b()) {
if (!ctx.get().mm_dq_full) {
// Transpose weight and coeff
ov::Shape tw_shape = {qweight_shape[2], qweight_shape[0], qweight_shape[1]};
matched_qweight->set_partial_shape(tw_shape);
matched_qweight->validate_and_infer_types();
ctx.get().permute(matched_qweight, {2, 0, 1});

ov::Shape tc_shape = {qcoeff_shape[2], qcoeff_shape[0], qcoeff_shape[1]};
matched_qcoeff->set_partial_shape(tc_shape);
matched_qcoeff->validate_and_infer_types();
ctx.get().permute(matched_qcoeff, {2, 0, 1});

// Change Reshape's shape
std::vector<std::size_t> transposed_shape = {qweight_shape[2], qweight_shape[0] * qweight_shape[1]};
auto transposed_shape_c =
std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{2}, transposed_shape);
matched_node_qreshp->input(1).replace_source_output(transposed_shape_c);
matched_node_qreshp->validate_and_infer_types();

matched_matmul->set_transpose_b(true);
matched_matmul->validate_and_infer_types();

const auto& matm_mul_out_shape = matched_matmul->get_output_shape(0);
const auto& matm_mul_in_shape = matched_matmul->get_input_shape(1);
NPUW_ASSERT(matm_mul_out_shape.back() == matm_mul_in_shape.front());
return false; // root hasn't changed
}

// Mark W closure to transpose, and transpose the respective parameter
ov::Shape tw_shape = {qweight_shape[0], qweight_shape[2], qweight_shape[1]};
matched_qweight->set_partial_shape(tw_shape);
matched_qweight->validate_and_infer_types();
ctx.get().permute(matched_qweight, {0, 2, 1});

if (!ctx.get().mm_dq_full) {
// Only transpose MatMul
matched_matmul->set_transpose_b(true);
return false; // root hasn't changed
}

// Mark S closure to be lowered fo f16
matched_qcoeff->set_element_type(ov::element::f16);
matched_qcoeff->validate_and_infer_types();
Expand Down Expand Up @@ -628,6 +680,14 @@ DQMatMulGQ2iP::DQMatMulGQ2iP(Context::Ref ctx) {
act_shape.size() == 3 && just_one(act_shape[0], act_shape[1]) && // multi-token case
qcoeff_shape[0] == qweight_shape[0] && qcoeff_shape[1] == qweight_shape[1] && qcoeff_shape[2] == 1 &&
!matched_matmul->get_transpose_a() && matched_matmul->get_transpose_b()) {
if (!ctx.get().mm_dq_full) {
const auto& matm_mul_out_shape = matched_matmul->get_output_shape(0);
const auto& matm_mul_in_shape = matched_matmul->get_input_shape(1);
NPUW_ASSERT(matm_mul_out_shape.back() == matm_mul_in_shape.front());
NPUW_ASSERT(matched_matmul->get_transpose_b());
return false; // root hasn't changed
}

// Mark W closure to transpose, and transpose the respective parameter
ov::Shape tw_shape = {qweight_shape[1], qweight_shape[0], qweight_shape[2]};
matched_qweight->set_partial_shape(tw_shape);
Expand All @@ -641,12 +701,6 @@ DQMatMulGQ2iP::DQMatMulGQ2iP(Context::Ref ctx) {
matched_qcoeff->set_partial_shape(ts_shape);
matched_qcoeff->validate_and_infer_types();

if (!ctx.get().mm_dq_full) {
// Only transpose MatMul
matched_matmul->set_transpose_b(false);
return false; // root hasn't changed
}

// Select proper activation shape
std::size_t act_dim = act_shape[0] > act_shape[1] ? 0 : 1;

Expand Down
28 changes: 25 additions & 3 deletions src/plugins/intel_npu/src/plugin/npuw/util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,13 @@ inline uint8_t tread_4b(const ov::Tensor& t, std::size_t r, std::size_t c, std::
return hi4(*telem);
}

inline float tread_f32(const ov::Tensor& t, std::size_t r, std::size_t c, std::size_t COLS) {
const float* tdata = static_cast<float*>(t.data());
const float* trow = tdata + r * COLS;
const float* telem = trow + c;
return *telem;
}

inline void twrite_4b(ov::Tensor& t, uint8_t value, std::size_t r, std::size_t c, std::size_t COLS) {
uint8_t* tdata = static_cast<uint8_t*>(t.data());
uint8_t* trow = tdata + r * COLS / 2;
Expand All @@ -450,10 +457,17 @@ inline void twrite_4b(ov::Tensor& t, uint8_t value, std::size_t r, std::size_t c
}
}

inline void twrite_f32(ov::Tensor& t, float value, std::size_t r, std::size_t c, std::size_t COLS) {
float* tdata = static_cast<float*>(t.data());
float* trow = tdata + r * COLS;
float* telem = trow + c;
*telem = value;
}

ov::Tensor ov::npuw::util::transpose(const ov::Tensor& t) {
ov::Shape shape = t.get_shape();
NPUW_ASSERT(shape.size() == 3); // Yes, so far only transpose 3D tensors
NPUW_ASSERT(t.get_element_type() == ov::element::i4);
NPUW_ASSERT(t.get_element_type() == ov::element::i4 || t.get_element_type() == ov::element::f32);

ov::Shape tshape = {shape[2], shape[0], shape[1]};
ov::Tensor tnew(t.get_element_type(), tshape);
Expand All @@ -462,8 +476,16 @@ ov::Tensor ov::npuw::util::transpose(const ov::Tensor& t) {
const auto IN_COLS = shape[2];
for (std::size_t i = 0; i < IN_ROWS; i++) {
for (std::size_t j = 0; j < IN_COLS; j++) {
uint8_t value = tread_4b(t, i, j, IN_COLS);
twrite_4b(tnew, value, j, i, IN_ROWS);
switch (t.get_element_type()) {
case ov::element::i4:
twrite_4b(tnew, tread_4b(t, i, j, IN_COLS), j, i, IN_ROWS);
break;
case ov::element::f32:
twrite_f32(tnew, tread_f32(t, i, j, IN_COLS), j, i, IN_ROWS);
break;
default:
NPUW_ASSERT("Element type is not supported yet");
}
}
}
return tnew;
Expand Down

0 comments on commit 4ff6bec

Please sign in to comment.