From 2efd2b196f461983d2563c32d5d330a51cf88e93 Mon Sep 17 00:00:00 2001
From: Georgy Krivoruchko <georgy.krivoruchko@intel.com>
Date: Tue, 8 Oct 2024 18:56:03 +0400
Subject: [PATCH] [ONNX] Added support for dynamic input shapes in
 com.microsoft.MatMulNBits (#26898)

### Details:
- Added the option to accept an input with a dynamic shape; the shape is
calculated later, during shape inference

### Tickets:
 - N/A
---
 .../src/op/com.microsoft/matmulnbits.cpp      | 33 +++++++++++++++----
 1 file changed, 27 insertions(+), 6 deletions(-)

diff --git a/src/frontends/onnx/frontend/src/op/com.microsoft/matmulnbits.cpp b/src/frontends/onnx/frontend/src/op/com.microsoft/matmulnbits.cpp
index 5b8a439933efd1..0db8f11c4f214f 100644
--- a/src/frontends/onnx/frontend/src/op/com.microsoft/matmulnbits.cpp
+++ b/src/frontends/onnx/frontend/src/op/com.microsoft/matmulnbits.cpp
@@ -59,8 +59,9 @@ ov::OutputVector matmulnbits(const ov::frontend::onnx::Node& node) {
                      "Expected rank of quantized weights is 3 [N][n_blocks_per_col][blob_size], got: ",
                      b_quantized.get_partial_shape().rank());
     CHECK_VALID_NODE(node,
-                     a.get_element_type() == ov::element::f16 || a.get_element_type() == ov::element::f32,
-                     "Unsupported input A type, accepted FP16, FP32, got: ",
+                     a.get_element_type() == ov::element::f16 || a.get_element_type() == ov::element::f32 ||
+                         a.get_element_type() == ov::element::dynamic,
+                     "Unsupported input A type, accepted dynamic, FP16, FP32, got: ",
                      a.get_element_type());
     CHECK_VALID_NODE(
         node,
@@ -96,7 +97,9 @@ ov::OutputVector matmulnbits(const ov::frontend::onnx::Node& node) {
     if (inputs.size() > 5) {
         bias = inputs[5];
         CHECK_VALID_NODE(node,
-                         bias.get_element_type() == a.get_element_type(),
+                         bias.get_element_type() == a.get_element_type() ||
+                             a.get_element_type() == ov::element::dynamic ||
+                             bias.get_element_type() == ov::element::dynamic,
                          "Unsupported input bias type, must be equal to input A type, got: ",
                          bias.get_element_type());
         CHECK_VALID_NODE(node,
@@ -121,17 +124,35 @@ ov::OutputVector matmulnbits(const ov::frontend::onnx::Node& node) {
         case 2:
             casted_b_shape = ov::Shape{static_cast<size_t>(N * n_blocks_per_col), static_cast<size_t>(blob_size * 4)};
             casted_b = std::make_shared<v0::Constant>(ov::element::u2, casted_b_shape, b_const->get_data_ptr());
-            default_zp = std::make_shared<v0::Constant>(a.get_element_type(), Shape{}, 2);
+            if (a.get_element_type() != ov::element::dynamic) {
+                default_zp = std::make_shared<v0::Constant>(a.get_element_type(), Shape{}, 2);
+            } else {
+                default_zp =
+                    std::make_shared<v1::ConvertLike>(a,
+                                                      std::make_shared<v0::Constant>(ov::element::f32, Shape{}, 2.f));
+            }
             break;
         case 4:
             casted_b_shape = ov::Shape{static_cast<size_t>(N * n_blocks_per_col), static_cast<size_t>(blob_size * 2)};
             casted_b = std::make_shared<v0::Constant>(ov::element::u4, casted_b_shape, b_const->get_data_ptr());
-            default_zp = std::make_shared<v0::Constant>(a.get_element_type(), Shape{}, 8);
+            if (a.get_element_type() != ov::element::dynamic) {
+                default_zp = std::make_shared<v0::Constant>(a.get_element_type(), Shape{}, 8);
+            } else {
+                default_zp =
+                    std::make_shared<v1::ConvertLike>(a,
+                                                      std::make_shared<v0::Constant>(ov::element::f32, Shape{}, 8.f));
+            }
             break;
         case 8:
             casted_b_shape = ov::Shape{static_cast<size_t>(N * n_blocks_per_col), static_cast<size_t>(blob_size)};
             casted_b = op::util::reshape(b_const, casted_b_shape);
-            default_zp = std::make_shared<v0::Constant>(a.get_element_type(), Shape{}, 128);
+            if (a.get_element_type() != ov::element::dynamic) {
+                default_zp = std::make_shared<v0::Constant>(a.get_element_type(), Shape{}, 128);
+            } else {
+                default_zp =
+                    std::make_shared<v1::ConvertLike>(a,
+                                                      std::make_shared<v0::Constant>(ov::element::f32, Shape{}, 128.f));
+            }
             break;
         default:
             FRONT_END_THROW("Unsupported bits count");