From f3f81c4375525923e760a1812a20d16db8af33b5 Mon Sep 17 00:00:00 2001 From: "Zhai, Xuejun" Date: Mon, 19 Feb 2024 10:21:10 +0800 Subject: [PATCH] [AUTO BATCH] Clean legacy name in AUTO BATCH Plugin Signed-off-by: Zhai, Xuejun --- src/plugins/auto_batch/src/plugin.cpp | 13 ++++++++----- src/plugins/auto_batch/src/sync_infer_request.cpp | 4 ++-- .../tests/unit/async_infer_request_test.cpp | 9 +++++---- .../tests/unit/sync_infer_request_test.cpp | 7 ++++--- 4 files changed, 19 insertions(+), 14 deletions(-) diff --git a/src/plugins/auto_batch/src/plugin.cpp b/src/plugins/auto_batch/src/plugin.cpp index 4024051f45c602..dfc01663a09257 100644 --- a/src/plugins/auto_batch/src/plugin.cpp +++ b/src/plugins/auto_batch/src/plugin.cpp @@ -172,8 +172,10 @@ std::shared_ptr Plugin::compile_model(const std::shared_ptr< const auto& static_shape = input->get_shape(); if (static_shape[0] != 1) OPENVINO_THROW("Auto-batching does not reshape/re-batch originally batched networks!"); - batched_inputs.insert( - ov::op::util::get_ie_output_name(params[input_id]->output(0))); // batched dim for the input + batched_inputs.insert(params[input_id] + ->output(0) + .get_node_shared_ptr() + ->get_friendly_name()); // batched dim for the input } else { // if the 0-th dim is not for the batch, then we support only the case when NONE dimension is batch for (size_t s = 1; s < shape.size(); s++) @@ -191,8 +193,9 @@ std::shared_ptr Plugin::compile_model(const std::shared_ptr< if (shape[0] != 1) OPENVINO_THROW("Auto-batching does not reshape/re-batch originally batched networks!"); const auto& node = output->input_value(0); - batched_outputs.insert( - ov::op::util::get_ie_output_name(ov::Output(node.get_node(), node.get_index()))); + batched_outputs.insert(ov::Output(node.get_node(), node.get_index()) + .get_node_shared_ptr() + ->get_friendly_name()); } else { // if the 0-th dim is not for the batch, then we support only the case when NONE dimension is batch for (size_t s = 1; s < shape.size(); s++) @@ -269,7 +272,7 @@ std::shared_ptr Plugin::compile_model(const std::shared_ptr< std::map, ov::PartialShape> partial_shapes; for (auto& input : inputs) { auto input_shape = input.get_shape(); - if (batched_inputs.find(ov::op::util::get_ie_output_name(input)) != batched_inputs.end()) { + if (batched_inputs.find(input.get_node_shared_ptr()->get_friendly_name()) != batched_inputs.end()) { input_shape[0] = meta_device.device_batch_size; } partial_shapes.insert({input, ov::PartialShape(input_shape)}); diff --git a/src/plugins/auto_batch/src/sync_infer_request.cpp b/src/plugins/auto_batch/src/sync_infer_request.cpp index 707adedc3b9bad..c3f7de497f3eee 100644 --- a/src/plugins/auto_batch/src/sync_infer_request.cpp +++ b/src/plugins/auto_batch/src/sync_infer_request.cpp @@ -52,7 +52,7 @@ size_t SyncInferRequest::get_batch_size() const { void SyncInferRequest::share_tensors_with_batched_req(const std::set& batched_inputs, const std::set& batched_outputs) { for (const auto& it : get_inputs()) { - auto name = ov::op::util::get_ie_output_name(it); + auto name = it.get_node_shared_ptr()->get_friendly_name(); ov::SoPtr res; auto batched_tensor = m_batched_request_wrapper->_infer_request_batched->get_tensor(it); if (!batched_tensor._so) @@ -66,7 +66,7 @@ void SyncInferRequest::share_tensors_with_batched_req(const std::setinput_value(0)); + auto name = it.get_node_shared_ptr()->input_value(0).get_node_shared_ptr()->get_friendly_name(); ov::SoPtr res; auto batched_tensor = m_batched_request_wrapper->_infer_request_batched->get_tensor(it); if (!batched_tensor._so) diff --git a/src/plugins/auto_batch/tests/unit/async_infer_request_test.cpp b/src/plugins/auto_batch/tests/unit/async_infer_request_test.cpp index a78d71f79b4c58..21cfc9866bce36 100644 --- a/src/plugins/auto_batch/tests/unit/async_infer_request_test.cpp +++ b/src/plugins/auto_batch/tests/unit/async_infer_request_test.cpp @@ -124,7 +124,7 @@ class AutoBatchAsyncInferRequestTest : public ::testing::TestWithParam, ov::PartialShape> partial_shapes; for (auto& input : inputs) { auto input_shape = input.get_shape(); - if (m_batched_inputs.find(ov::op::util::get_ie_output_name(input)) != m_batched_inputs.end()) { + if (m_batched_inputs.find(input.get_node_shared_ptr()->get_friendly_name()) != m_batched_inputs.end()) { input_shape[0] = m_batch_size; } partial_shapes.insert({input, ov::PartialShape(input_shape)}); @@ -229,14 +229,15 @@ class AutoBatchAsyncInferRequestTest : public ::testing::TestWithParam& model, int batch_size) { const auto& params = model->get_parameters(); for (size_t i = 0; i < params.size(); i++) { - m_batched_inputs.insert(ov::op::util::get_ie_output_name(params[i]->output(0))); + m_batched_inputs.insert(params[i]->output(0).get_node_shared_ptr()->get_friendly_name()); } const auto& results = model->get_results(); for (size_t i = 0; i < results.size(); i++) { const auto& output = results[i]; const auto& node = output->input_value(0); - m_batched_outputs.insert( - ov::op::util::get_ie_output_name(ov::Output(node.get_node(), node.get_index()))); + m_batched_outputs.insert(ov::Output(node.get_node(), node.get_index()) + .get_node_shared_ptr() + ->get_friendly_name()); } } }; diff --git a/src/plugins/auto_batch/tests/unit/sync_infer_request_test.cpp b/src/plugins/auto_batch/tests/unit/sync_infer_request_test.cpp index 5b836efd97b0a8..de54d6ceffc3a4 100644 --- a/src/plugins/auto_batch/tests/unit/sync_infer_request_test.cpp +++ b/src/plugins/auto_batch/tests/unit/sync_infer_request_test.cpp @@ -148,14 +148,15 @@ class AutoBatchRequestTest : public ::testing::TestWithParam& model, int batch_size) { const auto& params = model->get_parameters(); for (size_t i = 0; i < params.size(); i++) { - m_batched_inputs.insert(ov::op::util::get_ie_output_name(params[i]->output(0))); + m_batched_inputs.insert(params[i]->output(0).get_node_shared_ptr()->get_friendly_name()); } const auto& results = model->get_results(); for (size_t i = 0; i < results.size(); i++) { const auto& output = results[i]; const auto& node = output->input_value(0); - m_batched_outputs.insert( - ov::op::util::get_ie_output_name(ov::Output(node.get_node(), node.get_index()))); + m_batched_outputs.insert(ov::Output(node.get_node(), node.get_index()) + .get_node_shared_ptr() + ->get_friendly_name()); } } };