Skip to content

Commit

Permalink
[ONNX] Added translator for MatMulNBits from com.microsoft domain (op…
Browse files Browse the repository at this point in the history
…envinotoolkit#26530)

### Details:
 - Added translator for MatMulNBits from com.microsoft domain

### Tickets:
 - CVS-152263
  • Loading branch information
gkrivor authored Sep 27, 2024
1 parent cdf7854 commit 1ef4e3e
Show file tree
Hide file tree
Showing 4 changed files with 450 additions and 0 deletions.
236 changes: 236 additions & 0 deletions src/frontends/onnx/frontend/src/op/com.microsoft/matmulnbits.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,236 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "cmath"
#include "core/operator_set.hpp"
#include "exceptions.hpp"
#include "openvino/frontend/exception.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/broadcast.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/convert_like.hpp"
#include "openvino/op/matmul.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/shape_of.hpp"
#include "openvino/op/slice.hpp"
#include "openvino/op/subtract.hpp"
#include "openvino/op/transpose.hpp"
#include "utils/reshape.hpp"

using namespace ov::op;

namespace ov {
namespace frontend {
namespace onnx {
namespace com_microsoft {
namespace opset_1 {
ov::OutputVector matmulnbits(const ov::frontend::onnx::Node& node) {
// Original documentation:
// https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.MatMulNBits
const auto inputs = node.get_ov_inputs();
FRONT_END_OP_CONVERSION_CHECK(inputs.size() >= 3, "Minimum 3 inputs are required. Got: ", inputs.size());
const auto& a = inputs[0]; // required
ov::Output<ov::Node> b;
const auto& b_quantized = inputs[1]; // required
const auto& scales = inputs[2]; // required
ov::Output<ov::Node> zero_points; // optional, input[3]
ov::Output<ov::Node> group_idx; // optional, input[4]
ov::Output<ov::Node> bias; // optional, input[5]
const auto K = node.get_attribute_value<int64_t>("K"); // required
const auto N = node.get_attribute_value<int64_t>("N"); // required
const auto accuracy_level = node.get_attribute_value<int64_t>("accuracy_level", 0); // optional, default unset(0)
const auto block_size = node.get_attribute_value<int64_t>("block_size"); // required
const auto bits = node.get_attribute_value<int64_t>(
"bits",
4); // required, in docs: number of bits used for weight quantization (default 4)

const uint64_t n_blocks_per_col = (K + block_size - 1) / block_size;
const auto blob_size = static_cast<int64_t>(ceil(block_size * bits / 8));

CHECK_VALID_NODE(node, n_blocks_per_col > 0, "Wrong blocks count: ", n_blocks_per_col);
CHECK_VALID_NODE(node, blob_size > 0, "Wrong blob size: ", blob_size);
// in documentation: ...Input B is a 2D constant Matrix.
CHECK_VALID_NODE(node,
dynamic_cast<v0::Constant*>(b_quantized.get_node()) != nullptr,
"MatMulNBits limitation: accepting only a constant as a B input");
CHECK_VALID_NODE(node,
b_quantized.get_partial_shape().rank() == 3,
"Expected rank of quantized weights is 3 [N][n_blocks_per_col][blob_size], got: ",
b_quantized.get_partial_shape().rank());
CHECK_VALID_NODE(node,
a.get_element_type() == ov::element::f16 || a.get_element_type() == ov::element::f32,
"Unsupported input A type, accepted FP16, FP32, got: ",
a.get_element_type());
CHECK_VALID_NODE(
node,
b_quantized.get_element_type() == ov::element::u8 || b_quantized.get_element_type() == ov::element::i32,
"Unsupported input B type, accepted FP16, FP32, got: ",
b_quantized.get_element_type());

CHECK_VALID_NODE(node,
block_size >= 16 && (block_size % 2 == 0),
"Wrong block size, should be >=16 and be a power of 2, got: ",
block_size);
CHECK_VALID_NODE(node, accuracy_level >= 0 && accuracy_level <= 4, "Unsupported accuracy level: ", accuracy_level);

if (inputs.size() > 3) {
zero_points = inputs[3];
CHECK_VALID_NODE(node,
zero_points.get_element_type() == ov::element::u8 ||
zero_points.get_element_type() == ov::element::i32 ||
zero_points.get_element_type() == ov::element::f32 ||
zero_points.get_element_type() == ov::element::f16,
"Unsupported input zero_points type, accepted U8, I32, FP16, FP32, got: ",
zero_points.get_element_type());
}

if (inputs.size() > 4) {
group_idx = inputs[4];
CHECK_VALID_NODE(node,
group_idx.get_element_type() == ov::element::i32,
"Unsupported input group_idx type, accepted I32, got: ",
group_idx.get_element_type());
}

if (inputs.size() > 5) {
bias = inputs[5];
CHECK_VALID_NODE(node,
bias.get_element_type() == a.get_element_type(),
"Unsupported input bias type, must be equal to input A type, got: ",
bias.get_element_type());
CHECK_VALID_NODE(node,
bias.get_partial_shape() == PartialShape{N},
"Wrong bias shape, expected [",
N,
"], got: ",
bias.get_partial_shape());
}

{
const auto b_const = std::dynamic_pointer_cast<v0::Constant>(b_quantized.get_node_shared_ptr());

ov::Output<ov::Node> casted_b;
ov::Shape casted_b_shape;
ov::Output<ov::Node> default_zp;
// Casting/converting data of source constant.
// For further calculations (sub and/or multiply) we need to reshape it from [N][n_blocks_per_col][blob_size *
// X] to [N * n_blocks_per_col][blob_size * X] (where X is amount of values in 1 byte) because scale and
// zero_point are represented as: ...with shape like: [N * n_blocks_per_col]...
switch (bits) {
case 2:
casted_b_shape = ov::Shape{static_cast<size_t>(N * n_blocks_per_col), static_cast<size_t>(blob_size * 4)};
casted_b = std::make_shared<v0::Constant>(ov::element::u2, casted_b_shape, b_const->get_data_ptr());
default_zp = std::make_shared<v0::Constant>(a.get_element_type(), Shape{}, 2);
break;
case 4:
casted_b_shape = ov::Shape{static_cast<size_t>(N * n_blocks_per_col), static_cast<size_t>(blob_size * 2)};
casted_b = std::make_shared<v0::Constant>(ov::element::u4, casted_b_shape, b_const->get_data_ptr());
default_zp = std::make_shared<v0::Constant>(a.get_element_type(), Shape{}, 8);
break;
case 8:
casted_b_shape = ov::Shape{static_cast<size_t>(N * n_blocks_per_col), static_cast<size_t>(blob_size)};
casted_b = op::util::reshape(b_const, casted_b_shape);
default_zp = std::make_shared<v0::Constant>(a.get_element_type(), Shape{}, 128);
break;
default:
FRONT_END_THROW("Unsupported bits count");
break;
}

// Possible issue with slice implementation, had to move convertion before slice, instead of slicing uint4
// TODO: Ticket
const auto converted_b = std::make_shared<v1::ConvertLike>(casted_b, a);

// TODO: Need to collect performance data in case constant folding is applied. Possible some perf/mem-gap

// Simple case
if (n_blocks_per_col == 1) {
// Removing unused items in case block is bigger than column count
// For example, if data is (uint8)[1,2,3,4,5,6] then block will be (uint8)[1,2,3,4,5,6,0,0,0,0,0,0,0,0,0,0].
// And last zeros are unused.
const auto zero_const = std::make_shared<v0::Constant>(ov::element::i32, Shape{1}, 0);
const auto one_const = std::make_shared<v0::Constant>(ov::element::i32, Shape{1}, 1);
const auto elements_const =
std::make_shared<v0::Constant>(ov::element::i32, Shape{1}, static_cast<int32_t>(K));
const auto axis_const = std::make_shared<v0::Constant>(ov::element::i32, Shape{1}, 1);
const auto slice_b =
std::make_shared<v8::Slice>(converted_b, zero_const, elements_const, one_const, axis_const);

// Transpose matrix
const auto transposed_shape =
std::make_shared<v0::Constant>(ov::element::i64, Shape{2}, std::vector<int64_t>{1, 0});
const auto transposed_b = std::make_shared<v1::Transpose>(slice_b, transposed_shape);

// If no zero-points provided - we generate default, depends on data size
if (!zero_points.get_node_shared_ptr()) {
zero_points = default_zp;
}
const auto sub_b = std::make_shared<v1::Subtract>(transposed_b, zero_points);

// Scaling
const auto scaled_b = std::make_shared<v1::Multiply>(sub_b, scales);

// Adding bias if required
if (!bias.get_node_shared_ptr()) {
b = scaled_b;
} else {
b = std::make_shared<v1::Add>(scaled_b, bias);
}
} else {
// Transpose matrix. Quantized B matrix is transposed and has a shape [N,K].
// To apply further operations on it which operand's shape is [N] we do this
// transpose to have a matrix [K,N]...
const auto transposed_shape =
std::make_shared<v0::Constant>(ov::element::i64, Shape{2}, std::vector<int64_t>{1, 0});
ov::Output<ov::Node> transposed_b = std::make_shared<v1::Transpose>(converted_b, transposed_shape);

// If no zero-points provided - we generate default, depends on data size
if (!zero_points.get_node_shared_ptr()) {
zero_points = default_zp;
}
const auto sub_b = std::make_shared<v1::Subtract>(transposed_b, zero_points);

// Scaling
const auto scaled_b = std::make_shared<v1::Multiply>(sub_b, scales);

// Transpose again to make reshaping and slicing
transposed_b = std::make_shared<v1::Transpose>(scaled_b, transposed_shape);

const auto reshaped_b =
op::util::reshape(transposed_b,
ov::Shape{static_cast<size_t>(casted_b_shape[0] / n_blocks_per_col),
static_cast<size_t>(casted_b_shape[1] * n_blocks_per_col)});

// Removing unused items in case block is bigger than column count (see description for
// Slice above)
const auto zero_const = std::make_shared<v0::Constant>(ov::element::i32, Shape{1}, 0);
const auto one_const = std::make_shared<v0::Constant>(ov::element::i32, Shape{1}, 1);
const auto elements_const =
std::make_shared<v0::Constant>(ov::element::i32, Shape{1}, static_cast<int32_t>(K));
const auto axis_const = std::make_shared<v0::Constant>(ov::element::i32, Shape{1}, 1);
const auto slice_b =
std::make_shared<v8::Slice>(reshaped_b, zero_const, elements_const, one_const, axis_const);

// Adding bias if required
if (!bias.get_node_shared_ptr()) {
return {std::make_shared<v0::MatMul>(a, slice_b, false, true)};
} else {
// Transpose again
transposed_b = std::make_shared<v1::Transpose>(slice_b, transposed_shape);

b = std::make_shared<v1::Add>(transposed_b, bias);
}
}
}

return {std::make_shared<v0::MatMul>(a, b)};
}

ONNX_OP("MatMulNBits", OPSET_SINCE(1), com_microsoft::opset_1::matmulnbits, MICROSOFT_DOMAIN);

} // namespace opset_1
} // namespace com_microsoft
} // namespace onnx
} // namespace frontend
} // namespace ov
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
ir_version: 3
producer_name: "OpenVINO ONNX Frontend"
producer_version: ""
model_version: 0
graph {
name: "test_matmul_2d"
node {
input: "a"
input: "b_Q4"
input: "b_scales"
output: "c"
op_type: "MatMulNBits"
attribute {
name: "K"
i: 17
type: INT
}
attribute {
name: "N"
i: 3
type: INT
}
attribute {
name: "accuracy_level"
i: 4
type: INT
}
attribute {
name: "bits"
i: 4
type: INT
}
attribute {
name: "block_size"
i: 16
type: INT
}
domain: "com.microsoft"
}
initializer {
dims: 3
dims: 2
dims: 8
data_type: 2
name: "b_Q4"
raw_data: "G\2025`\024G\2025\200\000\000\000\000\000\000\000Fq$X\003Fq$\210\000\000\000\000\000\000\0005`\024G\2025`\024\200\000\000\000\000\000\000\000"
}
initializer {
dims: 6
data_type: 1
name: "b_scales"
raw_data: "\000\000\220\277\000\000\220\277\000\000\220\277\000\000\000\200\000\000\220\277\000\000\000\276"
}
input {
name: "a"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 3
}
dim {
dim_value: 17
}
}
}
}
}
output {
name: "c"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 3
}
dim {
dim_value: 3
}
}
}
}
}
}
opset_import {
version: 7
}
opset_import {
version: 1
}
Loading

0 comments on commit 1ef4e3e

Please sign in to comment.