diff --git a/src/inference/src/dev/isync_infer_request.cpp b/src/inference/src/dev/isync_infer_request.cpp index 5f386460ec3f46..c137f562f22441 100644 --- a/src/inference/src/dev/isync_infer_request.cpp +++ b/src/inference/src/dev/isync_infer_request.cpp @@ -121,6 +121,16 @@ const std::shared_ptr<const ov::ICompiledModel>& ov::ISyncInferRequest::get_comp } ov::ISyncInferRequest::FoundPort ov::ISyncInferRequest::find_port(const ov::Output<const ov::Node>& port) const { + // check if the tensor names of target port is a subset of source port's tensor names + auto check_tensor_names = [](const std::unordered_set<std::string>& source, + const std::unordered_set<std::string>& target) { + for (auto const& name : target) { + if (source.find(name) == source.end()) + return false; + } + return true; + }; + // This function is hotspot, need optimization. auto check_nodes = [](const ov::Node* node1, const ov::Node* node2) { return node1 == node2 || @@ -143,8 +153,8 @@ ov::ISyncInferRequest::FoundPort ov::ISyncInferRequest::find_port(const ov::Outp ov::ISyncInferRequest::FoundPort::Type type = ov::ISyncInferRequest::FoundPort::Type::INPUT; for (const auto& ports : {get_inputs(), get_outputs()}) { for (size_t i = 0; i < ports.size(); i++) { - if (ports[i].get_index() == port.get_index() && ports[i].get_names() == port.get_names() && - check_nodes(ports[i].get_node(), port.get_node())) { + if (ports[i].get_index() == port.get_index() && check_nodes(ports[i].get_node(), port.get_node()) && + check_tensor_names(ports[i].get_names(), port.get_names())) { std::lock_guard<std::mutex> lock(m_cache_mutex); m_cached_ports[port_hash] = {i, type}; return m_cached_ports[port_hash]; diff --git a/src/inference/tests/unit/iplugin_test.cpp b/src/inference/tests/unit/iplugin_test.cpp index f8debebdcc232c..1bfad957263035 100644 --- a/src/inference/tests/unit/iplugin_test.cpp +++ b/src/inference/tests/unit/iplugin_test.cpp @@ -28,7 +28,7 @@ class IPluginTest : public ::testing::Test { std::shared_ptr<ov::Model> create_model() { auto param = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, 
ov::PartialShape{1, 3, 2, 2}); param->set_friendly_name("Param"); - param->output(0).set_names({"param"}); + param->output(0).set_names({"param", "name1", "name2", "name3"}); auto relu = std::make_shared<ov::op::v0::Relu>(param); relu->set_friendly_name("ReLU"); @@ -81,6 +81,42 @@ MATCHER_P(blob_in_map_pointer_is_same, ref_blob, "") { return reinterpret_cast<float*>(arg.begin()->second->buffer()) == reinterpret_cast<float*>(ref_blob->buffer()); } +TEST_F(IPluginTest, SetTensorWithIncorrectPortNames) { + ov::SoPtr<ov::ITensor> tensor = ov::make_tensor(ov::element::f32, {1, 3, 2, 2}); + auto updated_param = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::PartialShape{1, 3, 2, 2}); + updated_param->set_friendly_name("Param"); + + updated_param->output(0).set_names({"new_name"}); + EXPECT_THROW(mock_infer_request->set_tensor(updated_param->output(0), tensor), ov::Exception); + + updated_param->output(0).set_names({"param", "new_name"}); + EXPECT_THROW(mock_infer_request->set_tensor(updated_param->output(0), tensor), ov::Exception); + + updated_param->output(0).set_names({"new_name", "name2"}); + EXPECT_THROW(mock_infer_request->set_tensor(updated_param->output(0), tensor), ov::Exception); +} + +TEST_F(IPluginTest, SetTensorWithCorrectPortNames) { + ov::SoPtr<ov::ITensor> tensor = ov::make_tensor(ov::element::f32, {1, 3, 2, 2}); + auto updated_param = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::PartialShape{1, 3, 2, 2}); + updated_param->set_friendly_name("Param"); + + updated_param->output(0).set_names({"param"}); + EXPECT_NO_THROW(mock_infer_request->set_tensor(updated_param->output(0), tensor)); + + updated_param->output(0).set_names({"name1", "param"}); + EXPECT_NO_THROW(mock_infer_request->set_tensor(updated_param->output(0), tensor)); + + updated_param->output(0).set_names({"name1", "name2"}); + EXPECT_NO_THROW(mock_infer_request->set_tensor(updated_param->output(0), tensor)); + + updated_param->output(0).set_names({"param", "name1", "name2"}); + EXPECT_NO_THROW(mock_infer_request->set_tensor(updated_param->output(0), tensor)); + + 
updated_param->output(0).set_names({"param", "name1", "name2", "name3"}); + EXPECT_NO_THROW(mock_infer_request->set_tensor(updated_param->output(0), tensor)); +} + TEST_F(IPluginTest, failToSetTensorWithIncorrectPort) { auto incorrect_param = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::PartialShape{1, 2}); ov::SoPtr<ov::ITensor> tensor = ov::make_tensor(ov::element::f32, {1, 1, 1, 1});