Skip to content

Commit

Permalink
[ONNX] MatMulInteger (openvinotoolkit#7825)
Browse files Browse the repository at this point in the history
  • Loading branch information
dkozykowski authored Oct 6, 2021
1 parent 9add27f commit 659daf6
Show file tree
Hide file tree
Showing 17 changed files with 1,331 additions and 16 deletions.
55 changes: 55 additions & 0 deletions ngraph/frontend/onnx/frontend/src/op/matmul_integer.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "op/matmul_integer.hpp"

#include <cstddef>
#include <memory>
#include <vector>

#include "default_opset.hpp"

namespace ngraph {
namespace onnx_import {
namespace op {
namespace set_1 {
// Imports an ONNX MatMulInteger node as an nGraph subgraph: both operands and
// both zero points are converted to i32, each zero point is subtracted from its
// operand, and the shifted operands are multiplied with MatMul.
//
// Inputs are read positionally: 0 = A, 1 = B, 2 = A zero point, 3 = B zero point.
// A zero point whose position lies beyond the number of inputs is replaced with
// a one-element i32 constant holding 0.
// NOTE(review): only the input count is inspected here; presumably an optional
// input that is left blank but followed by a later input would not receive the
// default - confirm against how get_ng_inputs() represents omitted inputs.
OutputVector matmul_integer(const Node& node) {
    const OutputVector& inputs = node.get_ng_inputs();

    // Picks the input at `position`, or the default zero point when it is absent.
    const auto zero_point_or_default = [&inputs](std::size_t position) -> Output<ngraph::Node> {
        if (inputs.size() > position) {
            return inputs.at(position);
        }
        return ngraph::op::Constant::create(ngraph::element::i32, {1}, {0});
    };
    // Wraps a value in a Convert-to-i32 node.
    const auto to_i32 = [](const Output<ngraph::Node>& value) {
        return std::make_shared<default_opset::Convert>(value, element::i32);
    };

    const Output<ngraph::Node> lhs = inputs.at(0);
    const Output<ngraph::Node> rhs = inputs.at(1);
    const Output<ngraph::Node> lhs_zero_point = zero_point_or_default(2);
    const Output<ngraph::Node> rhs_zero_point = zero_point_or_default(3);

    const auto lhs_i32 = to_i32(lhs);
    const auto rhs_i32 = to_i32(rhs);
    const auto lhs_zero_point_i32 = to_i32(lhs_zero_point);
    const auto rhs_zero_point_i32 = to_i32(rhs_zero_point);

    // A zero point of static rank 1 gets an extra axis at position 1, turning
    // shape [M] into [M, 1]; presumably so that a per-row zero point lines up
    // with the rows of A when subtracted. Any other rank is used as is.
    Output<ngraph::Node> lhs_offset = lhs_zero_point_i32;
    const auto lhs_zero_point_rank = lhs_zero_point.get_partial_shape().rank();
    if (lhs_zero_point_rank.is_static() && lhs_zero_point_rank.get_length() == 1) {
        const auto unsqueeze_axis = ngraph::op::Constant::create(ngraph::element::i32, {1}, {1});
        lhs_offset = std::make_shared<default_opset::Unsqueeze>(lhs_zero_point_i32, unsqueeze_axis);
    }

    const auto lhs_shifted = std::make_shared<default_opset::Subtract>(lhs_i32, lhs_offset);
    const auto rhs_shifted = std::make_shared<default_opset::Subtract>(rhs_i32, rhs_zero_point_i32);

    return {std::make_shared<default_opset::MatMul>(lhs_shifted, rhs_shifted)};
}
} // namespace set_1
} // namespace op
} // namespace onnx_import
} // namespace ngraph
24 changes: 24 additions & 0 deletions ngraph/frontend/onnx/frontend/src/op/matmul_integer.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "ngraph/node.hpp"
#include "onnx_import/core/node.hpp"

namespace ngraph {
namespace onnx_import {
namespace op {
namespace set_1 {
/// \brief Performs ONNX MatMulInteger operation.
///
/// \param node The ONNX node object representing this operation.
///
/// \return The vector containing nGraph nodes producing the output of the ONNX
/// quantized integer matrix multiplication operation.
OutputVector matmul_integer(const Node& node);
} // namespace set_1
} // namespace op
} // namespace onnx_import
} // namespace ngraph
2 changes: 2 additions & 0 deletions ngraph/frontend/onnx/frontend/src/ops_bridge.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@
#include "op/lrn.hpp"
#include "op/lstm.hpp"
#include "op/matmul.hpp"
#include "op/matmul_integer.hpp"
#include "op/max.hpp"
#include "op/max_pool.hpp"
#include "op/mean.hpp"
Expand Down Expand Up @@ -352,6 +353,7 @@ OperatorsBridge::OperatorsBridge() {
REGISTER_OPERATOR("LpNormalization", 1, lp_norm);
REGISTER_OPERATOR("LRN", 1, lrn);
REGISTER_OPERATOR("LSTM", 1, lstm);
REGISTER_OPERATOR("MatMulInteger", 1, matmul_integer);
REGISTER_OPERATOR("MatMul", 1, matmul);
REGISTER_OPERATOR("MaxPool", 1, max_pool);
REGISTER_OPERATOR("Max", 1, max);
Expand Down
91 changes: 91 additions & 0 deletions ngraph/test/models/onnx/matmul_integer.prototxt
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
ir_version: 7
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "A"
input: "B"
input: "a_zero_point"
input: "b_zero_point"
output: "Y"
op_type: "MatMulInteger"
}
name: "MatMulInt"
input {
name: "A"
type {
tensor_type {
elem_type: 2
shape {
dim {
dim_value: 4
}
dim {
dim_value: 3
}
}
}
}
}
input {
name: "B"
type {
tensor_type {
elem_type: 2
shape {
dim {
dim_value: 3
}
dim {
dim_value: 2
}
}
}
}
}
input {
name: "a_zero_point"
type {
tensor_type {
elem_type: 2
shape {
dim {
dim_value: 1
}
}
}
}
}
input {
name: "b_zero_point"
type {
tensor_type {
elem_type: 2
shape {
dim {
dim_value: 1
}
}
}
}
}
output {
name: "Y"
type {
tensor_type {
elem_type: 6
shape {
dim {
dim_value: 4
}
dim {
dim_value: 2
}
}
}
}
}
}
opset_import {
domain: ""
version: 13
}
83 changes: 83 additions & 0 deletions ngraph/test/models/onnx/matmul_integer_2d_x_3d.prototxt
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
ir_version: 7
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "A"
input: "B"
input: "a_zero_point"
output: "Y"
op_type: "MatMulInteger"
}
name: "MatMulInt"
input {
name: "A"
type {
tensor_type {
elem_type: 3
shape {
dim {
dim_value: 2
}
dim {
dim_value: 4
}
}
}
}
}
input {
name: "B"
type {
tensor_type {
elem_type: 3
shape {
dim {
dim_value: 2
}
dim {
dim_value: 4
}
dim {
dim_value: 3
}
}
}
}
}
input {
name: "a_zero_point"
type {
tensor_type {
elem_type: 3
shape {
dim {
dim_value: 1
}
}
}
}
}
output {
name: "Y"
type {
tensor_type {
elem_type: 6
shape {
dim {
dim_value: 2
}
dim {
dim_value: 2
}
dim {
dim_value: 3
}
}
}
}
}
}
opset_import {
domain: ""
version: 10
}
106 changes: 106 additions & 0 deletions ngraph/test/models/onnx/matmul_integer_3d.prototxt
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
ir_version: 7
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "A"
input: "B"
input: "a_zero_point"
input: "b_zero_point"
output: "Y"
op_type: "MatMulInteger"
}
name: "MatMulInt"
input {
name: "A"
type {
tensor_type {
elem_type: 2
shape {
dim {
dim_value: 2
}
dim {
dim_value: 2
}
dim {
dim_value: 4
}
}
}
}
}
input {
name: "B"
type {
tensor_type {
elem_type: 3
shape {
dim {
dim_value: 2
}
dim {
dim_value: 4
}
dim {
dim_value: 4
}
}
}
}
}
input {
name: "a_zero_point"
type {
tensor_type {
elem_type: 2
shape {
dim {
dim_value: 1
}
}
}
}
}
input {
name: "b_zero_point"
type {
tensor_type {
elem_type: 3
shape {
dim {
dim_value: 2
}
dim {
dim_value: 1
}
dim {
dim_value: 4
}
}
}
}
}
output {
name: "Y"
type {
tensor_type {
elem_type: 6
shape {
dim {
dim_value: 2
}
dim {
dim_value: 2
}
dim {
dim_value: 4
}
}
}
}
}
}
opset_import {
domain: ""
version: 10
}
Loading

0 comments on commit 659daf6

Please sign in to comment.