
[AUTO BATCH] Clean up legacy name handling in the Auto Batch plugin
Signed-off-by: Zhai, Xuejun <[email protected]>
zhaixuejun1993 committed Feb 19, 2024
1 parent 3986f55 commit f3f81c4
Showing 4 changed files with 19 additions and 14 deletions.
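In short, the commit replaces every call to the legacy helper ov::op::util::get_ie_output_name() with a direct lookup of the producing node's friendly name via get_node_shared_ptr()->get_friendly_name(); that friendly name is now the key stored in the plugin's batched_inputs / batched_outputs sets. The standalone snippet below is only an illustrative sketch of the new key retrieval; the tiny Parameter -> Abs -> Result model and the variable names are not taken from the commit.

#include <iostream>
#include <memory>
#include <set>
#include <string>

#include "openvino/core/model.hpp"
#include "openvino/op/abs.hpp"
#include "openvino/op/parameter.hpp"
#include "openvino/op/result.hpp"

int main() {
    // Tiny Parameter -> Abs -> Result model, used only to have an ov::Output to query.
    auto param = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{1, 3, 224, 224});
    auto abs = std::make_shared<ov::op::v0::Abs>(param);
    auto result = std::make_shared<ov::op::v0::Result>(abs);
    auto model = std::make_shared<ov::Model>(ov::ResultVector{result}, ov::ParameterVector{param});

    std::set<std::string> batched_inputs;

    // Legacy key (removed by this commit):
    //     batched_inputs.insert(ov::op::util::get_ie_output_name(param->output(0)));
    // Key used after this commit: the friendly name of the node producing the output.
    batched_inputs.insert(param->output(0).get_node_shared_ptr()->get_friendly_name());

    std::cout << "batched input key: " << *batched_inputs.begin() << std::endl;
    return 0;
}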
src/plugins/auto_batch/src/plugin.cpp (13 changes: 8 additions & 5 deletions)
@@ -172,8 +172,10 @@ std::shared_ptr<ov::ICompiledModel> Plugin::compile_model(const std::shared_ptr<
             const auto& static_shape = input->get_shape();
             if (static_shape[0] != 1)
                 OPENVINO_THROW("Auto-batching does not reshape/re-batch originally batched networks!");
-            batched_inputs.insert(
-                ov::op::util::get_ie_output_name(params[input_id]->output(0)));  // batched dim for the input
+            batched_inputs.insert(params[input_id]
+                                      ->output(0)
+                                      .get_node_shared_ptr()
+                                      ->get_friendly_name());  // batched dim for the input
         } else {
             // if the 0-th dim is not for the batch, then we support only the case when NONE dimension is batch
             for (size_t s = 1; s < shape.size(); s++)
@@ -191,8 +193,9 @@ std::shared_ptr<ov::ICompiledModel> Plugin::compile_model(const std::shared_ptr<
             if (shape[0] != 1)
                 OPENVINO_THROW("Auto-batching does not reshape/re-batch originally batched networks!");
             const auto& node = output->input_value(0);
-            batched_outputs.insert(
-                ov::op::util::get_ie_output_name(ov::Output<const ov::Node>(node.get_node(), node.get_index())));
+            batched_outputs.insert(ov::Output<const ov::Node>(node.get_node(), node.get_index())
+                                       .get_node_shared_ptr()
+                                       ->get_friendly_name());
         } else {
             // if the 0-th dim is not for the batch, then we support only the case when NONE dimension is batch
             for (size_t s = 1; s < shape.size(); s++)
@@ -269,7 +272,7 @@ std::shared_ptr<ov::ICompiledModel> Plugin::compile_model(const std::shared_ptr<
         std::map<ov::Output<ov::Node>, ov::PartialShape> partial_shapes;
         for (auto& input : inputs) {
             auto input_shape = input.get_shape();
-            if (batched_inputs.find(ov::op::util::get_ie_output_name(input)) != batched_inputs.end()) {
+            if (batched_inputs.find(input.get_node_shared_ptr()->get_friendly_name()) != batched_inputs.end()) {
                 input_shape[0] = meta_device.device_batch_size;
             }
             partial_shapes.insert({input, ov::PartialShape(input_shape)});
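For context, the friendly-name keys collected earlier in compile_model are what drive the re-batching shape map in the last hunk above. The helper below is a minimal sketch of that flow, not plugin code: the function name is hypothetical, and it assumes the Output-keyed overload of ov::Model::reshape is the consumer of the map the plugin builds.

#include <map>
#include <memory>
#include <set>
#include <string>

#include "openvino/core/model.hpp"

// Hypothetical helper (not from the commit): set the 0-th dimension of every
// input whose friendly name is in `batched_inputs` to `device_batch_size`.
std::shared_ptr<ov::Model> rebatch(std::shared_ptr<ov::Model> model,
                                   const std::set<std::string>& batched_inputs,
                                   size_t device_batch_size) {
    std::map<ov::Output<ov::Node>, ov::PartialShape> partial_shapes;
    for (auto& input : model->inputs()) {
        auto input_shape = input.get_shape();
        // Same membership test the plugin now uses: key by the node's friendly name.
        if (batched_inputs.count(input.get_node_shared_ptr()->get_friendly_name())) {
            input_shape[0] = device_batch_size;
        }
        partial_shapes.insert({input, ov::PartialShape(input_shape)});
    }
    model->reshape(partial_shapes);  // assumes the Output-keyed reshape overload
    return model;
}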
src/plugins/auto_batch/src/sync_infer_request.cpp (4 changes: 2 additions & 2 deletions)
@@ -52,7 +52,7 @@ size_t SyncInferRequest::get_batch_size() const {
 void SyncInferRequest::share_tensors_with_batched_req(const std::set<std::string>& batched_inputs,
                                                        const std::set<std::string>& batched_outputs) {
     for (const auto& it : get_inputs()) {
-        auto name = ov::op::util::get_ie_output_name(it);
+        auto name = it.get_node_shared_ptr()->get_friendly_name();
         ov::SoPtr<ov::ITensor> res;
         auto batched_tensor = m_batched_request_wrapper->_infer_request_batched->get_tensor(it);
         if (!batched_tensor._so)
@@ -66,7 +66,7 @@ void SyncInferRequest::share_tensors_with_batched_req(const std::set<std::string
     }

     for (const auto& it : get_outputs()) {
-        auto name = ov::op::util::get_ie_output_name(it.get_node_shared_ptr()->input_value(0));
+        auto name = it.get_node_shared_ptr()->input_value(0).get_node_shared_ptr()->get_friendly_name();
         ov::SoPtr<ov::ITensor> res;
         auto batched_tensor = m_batched_request_wrapper->_infer_request_batched->get_tensor(it);
         if (!batched_tensor._so)
@@ -124,7 +124,7 @@ class AutoBatchAsyncInferRequestTest : public ::testing::TestWithParam<AutoBatch
         std::map<ov::Output<ov::Node>, ov::PartialShape> partial_shapes;
         for (auto& input : inputs) {
             auto input_shape = input.get_shape();
-            if (m_batched_inputs.find(ov::op::util::get_ie_output_name(input)) != m_batched_inputs.end()) {
+            if (m_batched_inputs.find(input.get_node_shared_ptr()->get_friendly_name()) != m_batched_inputs.end()) {
                 input_shape[0] = m_batch_size;
             }
             partial_shapes.insert({input, ov::PartialShape(input_shape)});
@@ -229,14 +229,15 @@ class AutoBatchAsyncInferRequestTest : public ::testing::TestWithParam<AutoBatch
     void prepare_input(std::shared_ptr<ov::Model>& model, int batch_size) {
         const auto& params = model->get_parameters();
         for (size_t i = 0; i < params.size(); i++) {
-            m_batched_inputs.insert(ov::op::util::get_ie_output_name(params[i]->output(0)));
+            m_batched_inputs.insert(params[i]->output(0).get_node_shared_ptr()->get_friendly_name());
         }
         const auto& results = model->get_results();
         for (size_t i = 0; i < results.size(); i++) {
             const auto& output = results[i];
             const auto& node = output->input_value(0);
-            m_batched_outputs.insert(
-                ov::op::util::get_ie_output_name(ov::Output<const ov::Node>(node.get_node(), node.get_index())));
+            m_batched_outputs.insert(ov::Output<const ov::Node>(node.get_node(), node.get_index())
+                                         .get_node_shared_ptr()
+                                         ->get_friendly_name());
         }
     }
 };
src/plugins/auto_batch/tests/unit/sync_infer_request_test.cpp (7 changes: 4 additions & 3 deletions)
@@ -148,14 +148,15 @@ class AutoBatchRequestTest : public ::testing::TestWithParam<AutoBatchRequestTes
     void prepare_input(std::shared_ptr<ov::Model>& model, int batch_size) {
         const auto& params = model->get_parameters();
         for (size_t i = 0; i < params.size(); i++) {
-            m_batched_inputs.insert(ov::op::util::get_ie_output_name(params[i]->output(0)));
+            m_batched_inputs.insert(params[i]->output(0).get_node_shared_ptr()->get_friendly_name());
         }
         const auto& results = model->get_results();
         for (size_t i = 0; i < results.size(); i++) {
             const auto& output = results[i];
             const auto& node = output->input_value(0);
-            m_batched_outputs.insert(
-                ov::op::util::get_ie_output_name(ov::Output<const ov::Node>(node.get_node(), node.get_index())));
+            m_batched_outputs.insert(ov::Output<const ov::Node>(node.get_node(), node.get_index())
+                                         .get_node_shared_ptr()
+                                         ->get_friendly_name());
         }
     }
 };
