Skip to content

Commit

Permalink
Replacing NetworkMetadata::findByName with lambda functions
Browse files Browse the repository at this point in the history
  • Loading branch information
razvanapetroaie committed Jul 31, 2024
1 parent 92e79eb commit d9334df
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -123,9 +123,6 @@ struct NetworkMetadata final {
*/
void bindRelatedDescriptors();

private:
std::optional<size_t> findByName(const std::vector<IODescriptor>& descriptors, const std::string_view targetName);

}; // namespace intel_npu

/**
Expand Down
50 changes: 24 additions & 26 deletions src/plugins/intel_npu/src/al/src/icompiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,6 @@

namespace intel_npu {

std::optional<size_t> NetworkMetadata::findByName(const std::vector<IODescriptor>& descriptors,
const std::string_view targetName) {
for (size_t descriptorIndex = 0; descriptorIndex < descriptors.size(); ++descriptorIndex) {
if (descriptors.at(descriptorIndex).nameFromCompiler == targetName) {
return descriptorIndex;
}
}

return std::nullopt;
}

void NetworkMetadata::bindRelatedDescriptors() {
size_t ioIndex = 0;

Expand All @@ -27,18 +16,24 @@ void NetworkMetadata::bindRelatedDescriptors() {
}

if (input.isStateInput) {
const std::optional<size_t> relatedDescriptorIndex = findByName(outputs, input.nameFromCompiler);

if (relatedDescriptorIndex.has_value()) {
input.relatedDescriptorIndex = relatedDescriptorIndex;
outputs.at(*relatedDescriptorIndex).relatedDescriptorIndex = std::optional(ioIndex);
const auto relatedDescriptorIterator =
std::find_if(outputs.begin(), outputs.end(), [&](const IODescriptor& output) {
return output.isStateOutput && (output.nameFromCompiler == input.nameFromCompiler);
});

if (relatedDescriptorIterator != outputs.end()) {
input.relatedDescriptorIndex = std::distance(outputs.begin(), relatedDescriptorIterator);
outputs.at(*input.relatedDescriptorIndex).relatedDescriptorIndex = ioIndex;
}
} else if (input.isShapeTensor) {
const std::optional<size_t> relatedDescriptorIndex = findByName(inputs, input.nameFromCompiler);

if (relatedDescriptorIndex.has_value() && *relatedDescriptorIndex != ioIndex) {
input.relatedDescriptorIndex = relatedDescriptorIndex;
inputs.at(*relatedDescriptorIndex).relatedDescriptorIndex = std::optional(ioIndex);
const auto relatedDescriptorIterator =
std::find_if(inputs.begin(), inputs.end(), [&](const IODescriptor& candidate) {
return !candidate.isShapeTensor && (candidate.nameFromCompiler == input.nameFromCompiler);
});

if (relatedDescriptorIterator != inputs.end()) {
input.relatedDescriptorIndex = std::distance(inputs.begin(), relatedDescriptorIterator);
inputs.at(*input.relatedDescriptorIndex).relatedDescriptorIndex = ioIndex;
}
}

Expand All @@ -54,11 +49,14 @@ void NetworkMetadata::bindRelatedDescriptors() {
}

if (output.isShapeTensor) {
const std::optional<size_t> relatedDescriptorIndex = findByName(outputs, output.nameFromCompiler);

if (relatedDescriptorIndex.has_value() && *relatedDescriptorIndex != ioIndex) {
output.relatedDescriptorIndex = relatedDescriptorIndex;
outputs.at(*relatedDescriptorIndex).relatedDescriptorIndex = std::optional(ioIndex);
const auto relatedDescriptorIterator =
std::find_if(outputs.begin(), outputs.end(), [&](const IODescriptor& candidate) {
return !candidate.isShapeTensor && (candidate.nameFromCompiler == output.nameFromCompiler);
});

if (relatedDescriptorIterator != outputs.end()) {
output.relatedDescriptorIndex = std::distance(outputs.begin(), relatedDescriptorIterator);
outputs.at(*output.relatedDescriptorIndex).relatedDescriptorIndex = ioIndex;
}
}

Expand Down

0 comments on commit d9334df

Please sign in to comment.