Skip to content

Commit

Permalink
Add suggested changes
Browse files Browse the repository at this point in the history
  • Loading branch information
MirceaDan99 committed Dec 12, 2024
1 parent 60a36c4 commit 0c2e868
Show file tree
Hide file tree
Showing 10 changed files with 18 additions and 86 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,83 +8,21 @@
#include <mutex>
#include <vector>

#include "intel_npu/common/blob_container.hpp"
#include "intel_npu/network_metadata.hpp"
#include "intel_npu/utils/zero/zero_init.hpp"
#include "intel_npu/utils/zero/zero_utils.hpp"
#include "intel_npu/utils/zero/zero_wrappers.hpp"
#include "openvino/runtime/profiling_info.hpp"
#include "openvino/runtime/shared_buffer.hpp"

namespace intel_npu {

class BlobContainer {
public:
virtual void* get_ptr() {
OPENVINO_THROW("const BlobContainer::get_ptr() method is not implemented!");
}

virtual size_t size() const {
OPENVINO_THROW("BlobContainer::size() method is not implemented!");
}

virtual bool release_from_memory() {
OPENVINO_THROW("BlobContainer::release_from_memory() method is not implemented!");
}

virtual ~BlobContainer() = default;
};

class BlobContainerVector : public BlobContainer {
public:
BlobContainerVector(std::vector<uint8_t> blob) : _ownershipBlob(std::move(blob)) {}

void* get_ptr() override {
return reinterpret_cast<void*>(_ownershipBlob.data());
}

size_t size() const override {
return _ownershipBlob.size();
}

bool release_from_memory() override {
_ownershipBlob.clear();
_ownershipBlob.shrink_to_fit();
return true;
}

private:
std::vector<uint8_t> _ownershipBlob;
};

class BlobContainerAlignedBuffer : public BlobContainer {
public:
BlobContainerAlignedBuffer(const std::shared_ptr<ov::AlignedBuffer>& blobSO, size_t offset)
: _ownershipBlob(blobSO),
_offset(offset) {}

void* get_ptr() override {
return _ownershipBlob->get_ptr(_offset);
}

size_t size() const override {
return _ownershipBlob->size();
}

bool release_from_memory() override {
return false;
}

private:
std::shared_ptr<ov::AlignedBuffer> _ownershipBlob;
size_t _offset;
};

class IGraph : public std::enable_shared_from_this<IGraph> {
public:
IGraph(ze_graph_handle_t handle,
NetworkMetadata metadata,
const Config& config,
std::optional<std::unique_ptr<BlobContainer>> blobPtr);
std::unique_ptr<BlobContainer> blobPtr);

virtual void export_blob(std::ostream& stream) const = 0;

Expand Down
9 changes: 3 additions & 6 deletions src/plugins/intel_npu/src/common/src/igraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,11 @@ namespace intel_npu {
IGraph::IGraph(ze_graph_handle_t handle,
NetworkMetadata metadata,
const Config& config,
std::optional<std::unique_ptr<BlobContainer>> blobPtr)
std::unique_ptr<BlobContainer> blobPtr)
: _handle(handle),
_metadata(std::move(metadata)),
_logger("IGraph", config.get<LOG_LEVEL>()) {
if (blobPtr.has_value()) {
_blobPtr = std::move(*blobPtr);
}
}
_logger("IGraph", config.get<LOG_LEVEL>()),
_blobPtr(std::move(blobPtr)) {}

const NetworkMetadata& IGraph::get_metadata() const {
return _metadata;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class DriverGraph final : public IGraph {
ze_graph_handle_t graphHandle,
NetworkMetadata metadata,
const Config& config,
std::optional<std::unique_ptr<BlobContainer>> blob);
std::unique_ptr<BlobContainer> blob);

void export_blob(std::ostream& stream) const override;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class ZeGraphExtWrappers {
const std::string& buildFlags,
const uint32_t& flags) const;

ze_graph_handle_t getGraphHandle(const uint8_t* data, size_t size) const;
ze_graph_handle_t getGraphHandle(const uint8_t& data, size_t size) const;

NetworkMetadata getNetworkMeta(ze_graph_handle_t graphHandle) const;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ std::shared_ptr<IGraph> DriverCompilerAdapter::compile(const std::shared_ptr<con
graphHandle,
std::move(networkMeta),
config,
std::nullopt);
nullptr);
}

std::shared_ptr<IGraph> DriverCompilerAdapter::parse(std::unique_ptr<BlobContainer> blobPtr,
Expand All @@ -209,7 +209,7 @@ std::shared_ptr<IGraph> DriverCompilerAdapter::parse(std::unique_ptr<BlobContain

_logger.debug("parse start");
ze_graph_handle_t graphHandle =
_zeGraphExt->getGraphHandle(reinterpret_cast<const uint8_t*>(blobPtr->get_ptr()), blobPtr->size());
_zeGraphExt->getGraphHandle(*reinterpret_cast<const uint8_t*>(blobPtr->get_ptr()), blobPtr->size());
_logger.debug("parse end");

OV_ITT_TASK_NEXT(PARSE_BLOB, "getNetworkMeta");
Expand All @@ -220,7 +220,7 @@ std::shared_ptr<IGraph> DriverCompilerAdapter::parse(std::unique_ptr<BlobContain
graphHandle,
std::move(networkMeta),
config,
std::optional<std::unique_ptr<BlobContainer>>(std::move(blobPtr)));
std::move(blobPtr));
}

ov::SupportedOpsMap DriverCompilerAdapter::query(const std::shared_ptr<const ov::Model>& model,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ DriverGraph::DriverGraph(const std::shared_ptr<ZeGraphExtWrappers>& zeGraphExt,
ze_graph_handle_t graphHandle,
NetworkMetadata metadata,
const Config& config,
std::optional<std::unique_ptr<BlobContainer>> blobPtr)
std::unique_ptr<BlobContainer> blobPtr)
: IGraph(graphHandle, std::move(metadata), config, std::move(blobPtr)),
_zeGraphExt(zeGraphExt),
_zeroInitStruct(zeroInitStruct),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ std::shared_ptr<IGraph> PluginCompilerAdapter::compile(const std::shared_ptr<con
// Depending on the config, we may get an error when trying to get the graph handle from the compiled network
try {
graphHandle =
_zeGraphExt->getGraphHandle(reinterpret_cast<const uint8_t*>(blobPtr->get_ptr()), blobPtr->size());
_zeGraphExt->getGraphHandle(*reinterpret_cast<const uint8_t*>(blobPtr->get_ptr()), blobPtr->size());
} catch (...) {
_logger.info("Failed to obtain the level zero graph handle. Inference requests for this model are not "
"allowed. Only exports are available");
Expand Down Expand Up @@ -122,7 +122,7 @@ std::shared_ptr<IGraph> PluginCompilerAdapter::parse(std::unique_ptr<BlobContain

if (_zeGraphExt) {
graphHandle =
_zeGraphExt->getGraphHandle(reinterpret_cast<const uint8_t*>(blobPtr->get_ptr()), blobPtr->size());
_zeGraphExt->getGraphHandle(*reinterpret_cast<const uint8_t*>(blobPtr->get_ptr()), blobPtr->size());
}

return std::make_shared<PluginGraph>(_zeGraphExt,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,7 @@ PluginGraph::PluginGraph(const std::shared_ptr<ZeGraphExtWrappers>& zeGraphExt,
NetworkMetadata metadata,
std::unique_ptr<BlobContainer> blobPtr,
const Config& config)
: IGraph(graphHandle,
std::move(metadata),
config,
std::optional<std::unique_ptr<BlobContainer>>(std::move(blobPtr))),
: IGraph(graphHandle, std::move(metadata), config, std::move(blobPtr)),
_zeGraphExt(zeGraphExt),
_zeroInitStruct(zeroInitStruct),
_compiler(compiler),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -363,15 +363,15 @@ ze_graph_handle_t ZeGraphExtWrappers::getGraphHandle(std::pair<size_t, std::shar
return graphHandle;
}

ze_graph_handle_t ZeGraphExtWrappers::getGraphHandle(const uint8_t* blobData, size_t blobSize) const {
ze_graph_handle_t ZeGraphExtWrappers::getGraphHandle(const uint8_t& blobData, size_t blobSize) const {
ze_graph_handle_t graphHandle;

if (blobData == nullptr || blobSize == 0) {
if (blobSize == 0) {
OPENVINO_THROW("Empty blob");
}

ze_graph_desc_t desc =
{ZE_STRUCTURE_TYPE_GRAPH_DESC_PROPERTIES, nullptr, ZE_GRAPH_FORMAT_NATIVE, blobSize, blobData, nullptr};
{ZE_STRUCTURE_TYPE_GRAPH_DESC_PROPERTIES, nullptr, ZE_GRAPH_FORMAT_NATIVE, blobSize, &blobData, nullptr};

_logger.debug("getGraphHandle - perform pfnCreate");
auto result = _zeroInitStruct->getGraphDdiTable().pfnCreate(_zeroInitStruct->getContext(),
Expand Down
2 changes: 1 addition & 1 deletion src/plugins/intel_npu/src/plugin/src/plugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -759,7 +759,7 @@ std::shared_ptr<ov::ICompiledModel> Plugin::import_model(std::istream& stream, c
_properties.erase(ov::internal::cached_model_buffer.name());
}

const std::map<std::string, std::string> propertiesMap = any_copy(_properties);
const auto propertiesMap = any_copy(_properties);
auto localConfig = merge_configs(_globalConfig, propertiesMap, OptionMode::RunTime);
_logger.setLevel(localConfig.get<LOG_LEVEL>());
const auto platform = _backends->getCompilationPlatform(localConfig.get<PLATFORM>(), localConfig.get<DEVICE_ID>());
Expand Down

0 comments on commit 0c2e868

Please sign in to comment.