From 2efd2b196f461983d2563c32d5d330a51cf88e93 Mon Sep 17 00:00:00 2001 From: Georgy Krivoruchko <georgy.krivoruchko@intel.com> Date: Tue, 8 Oct 2024 18:56:03 +0400 Subject: [PATCH] [ONNX] Added support for dynamic input shapes in com.microsoft.MatMulNBits (#26898) ### Details: - Added option to receive an input with dynamic shape, shape must be calculated later while shape inference ### Tickets: - N/A --- .../src/op/com.microsoft/matmulnbits.cpp | 33 +++++++++++++++---- 1 file changed, 27 insertions(+), 6 deletions(-) diff --git a/src/frontends/onnx/frontend/src/op/com.microsoft/matmulnbits.cpp b/src/frontends/onnx/frontend/src/op/com.microsoft/matmulnbits.cpp index 5b8a439933efd1..0db8f11c4f214f 100644 --- a/src/frontends/onnx/frontend/src/op/com.microsoft/matmulnbits.cpp +++ b/src/frontends/onnx/frontend/src/op/com.microsoft/matmulnbits.cpp @@ -59,8 +59,9 @@ ov::OutputVector matmulnbits(const ov::frontend::onnx::Node& node) { "Expected rank of quantized weights is 3 [N][n_blocks_per_col][blob_size], got: ", b_quantized.get_partial_shape().rank()); CHECK_VALID_NODE(node, - a.get_element_type() == ov::element::f16 || a.get_element_type() == ov::element::f32, - "Unsupported input A type, accepted FP16, FP32, got: ", + a.get_element_type() == ov::element::f16 || a.get_element_type() == ov::element::f32 || + a.get_element_type() == ov::element::dynamic, + "Unsupported input A type, accepted dynamic, FP16, FP32, got: ", a.get_element_type()); CHECK_VALID_NODE( node, @@ -96,7 +97,9 @@ ov::OutputVector matmulnbits(const ov::frontend::onnx::Node& node) { if (inputs.size() > 5) { bias = inputs[5]; CHECK_VALID_NODE(node, - bias.get_element_type() == a.get_element_type(), + bias.get_element_type() == a.get_element_type() || + a.get_element_type() == ov::element::dynamic || + bias.get_element_type() == ov::element::dynamic, "Unsupported input bias type, must be equal to input A type, got: ", bias.get_element_type()); CHECK_VALID_NODE(node, @@ -121,17 +124,35 @@ ov::OutputVector matmulnbits(const ov::frontend::onnx::Node& node) { case 2: casted_b_shape = ov::Shape{static_cast<size_t>(N * n_blocks_per_col), static_cast<size_t>(blob_size * 4)}; casted_b = std::make_shared<v0::Constant>(ov::element::u2, casted_b_shape, b_const->get_data_ptr()); - default_zp = std::make_shared<v0::Constant>(a.get_element_type(), Shape{}, 2); + if (a.get_element_type() != ov::element::dynamic) { + default_zp = std::make_shared<v0::Constant>(a.get_element_type(), Shape{}, 2); + } else { + default_zp = + std::make_shared<v1::ConvertLike>(a, + std::make_shared<v0::Constant>(ov::element::f32, Shape{}, 2.f)); + } break; case 4: casted_b_shape = ov::Shape{static_cast<size_t>(N * n_blocks_per_col), static_cast<size_t>(blob_size * 2)}; casted_b = std::make_shared<v0::Constant>(ov::element::u4, casted_b_shape, b_const->get_data_ptr()); - default_zp = std::make_shared<v0::Constant>(a.get_element_type(), Shape{}, 8); + if (a.get_element_type() != ov::element::dynamic) { + default_zp = std::make_shared<v0::Constant>(a.get_element_type(), Shape{}, 8); + } else { + default_zp = + std::make_shared<v1::ConvertLike>(a, + std::make_shared<v0::Constant>(ov::element::f32, Shape{}, 8.f)); + } break; case 8: casted_b_shape = ov::Shape{static_cast<size_t>(N * n_blocks_per_col), static_cast<size_t>(blob_size)}; casted_b = op::util::reshape(b_const, casted_b_shape); - default_zp = std::make_shared<v0::Constant>(a.get_element_type(), Shape{}, 128); + if (a.get_element_type() != ov::element::dynamic) { + default_zp = std::make_shared<v0::Constant>(a.get_element_type(), Shape{}, 128); + } else { + default_zp = + std::make_shared<v1::ConvertLike>(a, + std::make_shared<v0::Constant>(ov::element::f32, Shape{}, 128.f)); + } break; default: FRONT_END_THROW("Unsupported bits count");