Skip to content

Commit

Permalink
Add suggested changes
Browse files Browse the repository at this point in the history
  • Loading branch information
MirceaDan99 committed Dec 12, 2024
1 parent 60a36c4 commit b9f9c3d
Show file tree
Hide file tree
Showing 11 changed files with 88 additions and 86 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <memory>
#include <vector>

#include "openvino/runtime/shared_buffer.hpp"

namespace intel_npu {

class BlobContainer {
public:
virtual void* get_ptr() = 0;

virtual size_t size() const = 0;

virtual bool release_from_memory() = 0;

virtual ~BlobContainer() = default;
};

class BlobContainerVector : public BlobContainer {
public:
BlobContainerVector(std::vector<uint8_t> blob) : _ownershipBlob(std::move(blob)) {}

void* get_ptr() override {
return reinterpret_cast<void*>(_ownershipBlob.data());
}

size_t size() const override {
return _ownershipBlob.size();
}

bool release_from_memory() override {
_ownershipBlob.clear();
_ownershipBlob.shrink_to_fit();
return true;
}

private:
std::vector<uint8_t> _ownershipBlob;
};

class BlobContainerAlignedBuffer : public BlobContainer {
public:
BlobContainerAlignedBuffer(const std::shared_ptr<ov::AlignedBuffer>& blobSO, size_t offset)
: _ownershipBlob(blobSO),
_offset(offset) {}

void* get_ptr() override {
return _ownershipBlob->get_ptr(_offset);
}

size_t size() const override {
return _ownershipBlob->size();
}

bool release_from_memory() override {
return false;
}

private:
std::shared_ptr<ov::AlignedBuffer> _ownershipBlob;
size_t _offset;
};

} // namespace intel_npu
Original file line number Diff line number Diff line change
Expand Up @@ -8,83 +8,21 @@
#include <mutex>
#include <vector>

#include "intel_npu/common/blob_container.hpp"
#include "intel_npu/network_metadata.hpp"
#include "intel_npu/utils/zero/zero_init.hpp"
#include "intel_npu/utils/zero/zero_utils.hpp"
#include "intel_npu/utils/zero/zero_wrappers.hpp"
#include "openvino/runtime/profiling_info.hpp"
#include "openvino/runtime/shared_buffer.hpp"

namespace intel_npu {

class BlobContainer {
public:
virtual void* get_ptr() {
OPENVINO_THROW("const BlobContainer::get_ptr() method is not implemented!");
}

virtual size_t size() const {
OPENVINO_THROW("BlobContainer::size() method is not implemented!");
}

virtual bool release_from_memory() {
OPENVINO_THROW("BlobContainer::release_from_memory() method is not implemented!");
}

virtual ~BlobContainer() = default;
};

class BlobContainerVector : public BlobContainer {
public:
BlobContainerVector(std::vector<uint8_t> blob) : _ownershipBlob(std::move(blob)) {}

void* get_ptr() override {
return reinterpret_cast<void*>(_ownershipBlob.data());
}

size_t size() const override {
return _ownershipBlob.size();
}

bool release_from_memory() override {
_ownershipBlob.clear();
_ownershipBlob.shrink_to_fit();
return true;
}

private:
std::vector<uint8_t> _ownershipBlob;
};

class BlobContainerAlignedBuffer : public BlobContainer {
public:
BlobContainerAlignedBuffer(const std::shared_ptr<ov::AlignedBuffer>& blobSO, size_t offset)
: _ownershipBlob(blobSO),
_offset(offset) {}

void* get_ptr() override {
return _ownershipBlob->get_ptr(_offset);
}

size_t size() const override {
return _ownershipBlob->size();
}

bool release_from_memory() override {
return false;
}

private:
std::shared_ptr<ov::AlignedBuffer> _ownershipBlob;
size_t _offset;
};

class IGraph : public std::enable_shared_from_this<IGraph> {
public:
IGraph(ze_graph_handle_t handle,
NetworkMetadata metadata,
const Config& config,
std::optional<std::unique_ptr<BlobContainer>> blobPtr);
std::unique_ptr<BlobContainer> blobPtr);

virtual void export_blob(std::ostream& stream) const = 0;

Expand Down
9 changes: 3 additions & 6 deletions src/plugins/intel_npu/src/common/src/igraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,11 @@ namespace intel_npu {
IGraph::IGraph(ze_graph_handle_t handle,
NetworkMetadata metadata,
const Config& config,
std::optional<std::unique_ptr<BlobContainer>> blobPtr)
std::unique_ptr<BlobContainer> blobPtr)
: _handle(handle),
_metadata(std::move(metadata)),
_logger("IGraph", config.get<LOG_LEVEL>()) {
if (blobPtr.has_value()) {
_blobPtr = std::move(*blobPtr);
}
}
_logger("IGraph", config.get<LOG_LEVEL>()),
_blobPtr(std::move(blobPtr)) {}

const NetworkMetadata& IGraph::get_metadata() const {
return _metadata;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class DriverGraph final : public IGraph {
ze_graph_handle_t graphHandle,
NetworkMetadata metadata,
const Config& config,
std::optional<std::unique_ptr<BlobContainer>> blob);
std::unique_ptr<BlobContainer> blob);

void export_blob(std::ostream& stream) const override;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class ZeGraphExtWrappers {
const std::string& buildFlags,
const uint32_t& flags) const;

ze_graph_handle_t getGraphHandle(const uint8_t* data, size_t size) const;
ze_graph_handle_t getGraphHandle(const uint8_t& data, size_t size) const;

NetworkMetadata getNetworkMeta(ze_graph_handle_t graphHandle) const;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ std::shared_ptr<IGraph> DriverCompilerAdapter::compile(const std::shared_ptr<con
graphHandle,
std::move(networkMeta),
config,
std::nullopt);
nullptr);
}

std::shared_ptr<IGraph> DriverCompilerAdapter::parse(std::unique_ptr<BlobContainer> blobPtr,
Expand All @@ -209,7 +209,7 @@ std::shared_ptr<IGraph> DriverCompilerAdapter::parse(std::unique_ptr<BlobContain

_logger.debug("parse start");
ze_graph_handle_t graphHandle =
_zeGraphExt->getGraphHandle(reinterpret_cast<const uint8_t*>(blobPtr->get_ptr()), blobPtr->size());
_zeGraphExt->getGraphHandle(*reinterpret_cast<const uint8_t*>(blobPtr->get_ptr()), blobPtr->size());
_logger.debug("parse end");

OV_ITT_TASK_NEXT(PARSE_BLOB, "getNetworkMeta");
Expand All @@ -220,7 +220,7 @@ std::shared_ptr<IGraph> DriverCompilerAdapter::parse(std::unique_ptr<BlobContain
graphHandle,
std::move(networkMeta),
config,
std::optional<std::unique_ptr<BlobContainer>>(std::move(blobPtr)));
std::move(blobPtr));
}

ov::SupportedOpsMap DriverCompilerAdapter::query(const std::shared_ptr<const ov::Model>& model,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ DriverGraph::DriverGraph(const std::shared_ptr<ZeGraphExtWrappers>& zeGraphExt,
ze_graph_handle_t graphHandle,
NetworkMetadata metadata,
const Config& config,
std::optional<std::unique_ptr<BlobContainer>> blobPtr)
std::unique_ptr<BlobContainer> blobPtr)
: IGraph(graphHandle, std::move(metadata), config, std::move(blobPtr)),
_zeGraphExt(zeGraphExt),
_zeroInitStruct(zeroInitStruct),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ std::shared_ptr<IGraph> PluginCompilerAdapter::compile(const std::shared_ptr<con
// Depending on the config, we may get an error when trying to get the graph handle from the compiled network
try {
graphHandle =
_zeGraphExt->getGraphHandle(reinterpret_cast<const uint8_t*>(blobPtr->get_ptr()), blobPtr->size());
_zeGraphExt->getGraphHandle(*reinterpret_cast<const uint8_t*>(blobPtr->get_ptr()), blobPtr->size());
} catch (...) {
_logger.info("Failed to obtain the level zero graph handle. Inference requests for this model are not "
"allowed. Only exports are available");
Expand Down Expand Up @@ -122,7 +122,7 @@ std::shared_ptr<IGraph> PluginCompilerAdapter::parse(std::unique_ptr<BlobContain

if (_zeGraphExt) {
graphHandle =
_zeGraphExt->getGraphHandle(reinterpret_cast<const uint8_t*>(blobPtr->get_ptr()), blobPtr->size());
_zeGraphExt->getGraphHandle(*reinterpret_cast<const uint8_t*>(blobPtr->get_ptr()), blobPtr->size());
}

return std::make_shared<PluginGraph>(_zeGraphExt,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,7 @@ PluginGraph::PluginGraph(const std::shared_ptr<ZeGraphExtWrappers>& zeGraphExt,
NetworkMetadata metadata,
std::unique_ptr<BlobContainer> blobPtr,
const Config& config)
: IGraph(graphHandle,
std::move(metadata),
config,
std::optional<std::unique_ptr<BlobContainer>>(std::move(blobPtr))),
: IGraph(graphHandle, std::move(metadata), config, std::move(blobPtr)),
_zeGraphExt(zeGraphExt),
_zeroInitStruct(zeroInitStruct),
_compiler(compiler),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -363,15 +363,15 @@ ze_graph_handle_t ZeGraphExtWrappers::getGraphHandle(std::pair<size_t, std::shar
return graphHandle;
}

ze_graph_handle_t ZeGraphExtWrappers::getGraphHandle(const uint8_t* blobData, size_t blobSize) const {
ze_graph_handle_t ZeGraphExtWrappers::getGraphHandle(const uint8_t& blobData, size_t blobSize) const {
ze_graph_handle_t graphHandle;

if (blobData == nullptr || blobSize == 0) {
if (blobSize == 0) {
OPENVINO_THROW("Empty blob");
}

ze_graph_desc_t desc =
{ZE_STRUCTURE_TYPE_GRAPH_DESC_PROPERTIES, nullptr, ZE_GRAPH_FORMAT_NATIVE, blobSize, blobData, nullptr};
{ZE_STRUCTURE_TYPE_GRAPH_DESC_PROPERTIES, nullptr, ZE_GRAPH_FORMAT_NATIVE, blobSize, &blobData, nullptr};

_logger.debug("getGraphHandle - perform pfnCreate");
auto result = _zeroInitStruct->getGraphDdiTable().pfnCreate(_zeroInitStruct->getContext(),
Expand Down
2 changes: 1 addition & 1 deletion src/plugins/intel_npu/src/plugin/src/plugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -759,7 +759,7 @@ std::shared_ptr<ov::ICompiledModel> Plugin::import_model(std::istream& stream, c
_properties.erase(ov::internal::cached_model_buffer.name());
}

const std::map<std::string, std::string> propertiesMap = any_copy(_properties);
const auto propertiesMap = any_copy(_properties);
auto localConfig = merge_configs(_globalConfig, propertiesMap, OptionMode::RunTime);
_logger.setLevel(localConfig.get<LOG_LEVEL>());
const auto platform = _backends->getCompilationPlatform(localConfig.get<PLATFORM>(), localConfig.get<DEVICE_ID>());
Expand Down

0 comments on commit b9f9c3d

Please sign in to comment.