Skip to content

Commit

Permalink
Fix the issue caused by the changed tensor names after complex transf…
Browse files Browse the repository at this point in the history
…ormation (openvinotoolkit#23081)

### Details:
- Update logic of port name checking to check if queried port tensor
names is subset instead of op "==".
 - *...*

### Tickets:
 - CVS-122932
  • Loading branch information
yangwang201911 authored Mar 7, 2024
1 parent cc12899 commit 4f8f573
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 3 deletions.
14 changes: 12 additions & 2 deletions src/inference/src/dev/isync_infer_request.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,16 @@ const std::shared_ptr<const ov::ICompiledModel>& ov::ISyncInferRequest::get_comp
}

ov::ISyncInferRequest::FoundPort ov::ISyncInferRequest::find_port(const ov::Output<const ov::Node>& port) const {
// check if the tensor names of target port is a subset of source port's tensor names
auto check_tensor_names = [](const std::unordered_set<std::string>& source,
const std::unordered_set<std::string>& target) {
for (auto const& name : target) {
if (source.find(name) == source.end())
return false;
}
return true;
};

// This function is hotspot, need optimization.
auto check_nodes = [](const ov::Node* node1, const ov::Node* node2) {
return node1 == node2 ||
Expand All @@ -143,8 +153,8 @@ ov::ISyncInferRequest::FoundPort ov::ISyncInferRequest::find_port(const ov::Outp
ov::ISyncInferRequest::FoundPort::Type type = ov::ISyncInferRequest::FoundPort::Type::INPUT;
for (const auto& ports : {get_inputs(), get_outputs()}) {
for (size_t i = 0; i < ports.size(); i++) {
if (ports[i].get_index() == port.get_index() && ports[i].get_names() == port.get_names() &&
check_nodes(ports[i].get_node(), port.get_node())) {
if (ports[i].get_index() == port.get_index() && check_nodes(ports[i].get_node(), port.get_node()) &&
check_tensor_names(ports[i].get_names(), port.get_names())) {
std::lock_guard<std::mutex> lock(m_cache_mutex);
m_cached_ports[port_hash] = {i, type};
return m_cached_ports[port_hash];
Expand Down
38 changes: 37 additions & 1 deletion src/inference/tests/unit/iplugin_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class IPluginTest : public ::testing::Test {
std::shared_ptr<ov::Model> create_model() {
auto param = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::PartialShape{1, 3, 2, 2});
param->set_friendly_name("Param");
param->output(0).set_names({"param"});
param->output(0).set_names({"param", "name1", "name2", "name3"});

auto relu = std::make_shared<ov::op::v0::Relu>(param);
relu->set_friendly_name("ReLU");
Expand Down Expand Up @@ -81,6 +81,42 @@ MATCHER_P(blob_in_map_pointer_is_same, ref_blob, "") {
return reinterpret_cast<float*>(arg.begin()->second->buffer()) == reinterpret_cast<float*>(ref_blob->buffer());
}

TEST_F(IPluginTest, SetTensorWithIncorrectPortNames) {
ov::SoPtr<ov::ITensor> tensor = ov::make_tensor(ov::element::f32, {1, 3, 2, 2});
auto updated_param = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::PartialShape{1, 3, 2, 2});
updated_param->set_friendly_name("Param");

updated_param->output(0).set_names({"new_name"});
EXPECT_THROW(mock_infer_request->set_tensor(updated_param->output(0), tensor), ov::Exception);

updated_param->output(0).set_names({"param", "new_name"});
EXPECT_THROW(mock_infer_request->set_tensor(updated_param->output(0), tensor), ov::Exception);

updated_param->output(0).set_names({"new_name", "name2"});
EXPECT_THROW(mock_infer_request->set_tensor(updated_param->output(0), tensor), ov::Exception);
}

TEST_F(IPluginTest, SetTensorWithCorrectPortNames) {
ov::SoPtr<ov::ITensor> tensor = ov::make_tensor(ov::element::f32, {1, 3, 2, 2});
auto updated_param = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::PartialShape{1, 3, 2, 2});
updated_param->set_friendly_name("Param");

updated_param->output(0).set_names({"param"});
EXPECT_NO_THROW(mock_infer_request->set_tensor(updated_param->output(0), tensor));

updated_param->output(0).set_names({"name1", "param"});
EXPECT_NO_THROW(mock_infer_request->set_tensor(updated_param->output(0), tensor));

updated_param->output(0).set_names({"name1", "name2"});
EXPECT_NO_THROW(mock_infer_request->set_tensor(updated_param->output(0), tensor));

updated_param->output(0).set_names({"param", "name1", "name2"});
EXPECT_NO_THROW(mock_infer_request->set_tensor(updated_param->output(0), tensor));

updated_param->output(0).set_names({"param", "name1", "name2", "name3"});
EXPECT_NO_THROW(mock_infer_request->set_tensor(updated_param->output(0), tensor));
}

TEST_F(IPluginTest, failToSetTensorWithIncorrectPort) {
auto incorrect_param = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::PartialShape{1, 2});
ov::SoPtr<ov::ITensor> tensor = ov::make_tensor(ov::element::f32, {1, 1, 1, 1});
Expand Down

0 comments on commit 4f8f573

Please sign in to comment.