Skip to content

Commit

Permalink
Implement custom std::ostringstream::str() function
Browse files Browse the repository at this point in the history
  • Loading branch information
MirceaDan99 committed Sep 12, 2024
1 parent a1f4975 commit 64926cf
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 14 deletions.
43 changes: 29 additions & 14 deletions src/plugins/intel_npu/src/compiler/src/zero_compiler_in_driver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
// SPDX-License-Identifier: Apache-2.0
//

#include <intel_npu/utils/customstringbuf/customstringbuf.hpp>

#include "zero_compiler_in_driver.hpp"

#include <fstream>
Expand Down Expand Up @@ -362,24 +364,20 @@ void LevelZeroCompilerInDriver<TableExtension>::release(std::shared_ptr<const Ne
_logger.debug("release completed");
}

class CustomStringBuf : public std::stringbuf {
public:
CustomStringBuf(std::vector<uint8_t>& blob) {
this->setp(reinterpret_cast<char*>(blob.data()), reinterpret_cast<char*>(blob.data()) + blob.size());
this->setg(reinterpret_cast<char*>(blob.data()), reinterpret_cast<char*>(blob.data()) + 1, reinterpret_cast<char*>(blob.data()) + blob.size());
this->pbump(blob.size());
}
};

template <typename TableExtension>
void LevelZeroCompilerInDriver<TableExtension>::getCompiledNetwork(
std::shared_ptr<const NetworkDescription> networkDescription, std::ostream& stream) {

std::ostringstream* oStringStreamPtr = dynamic_cast<std::ostringstream*>(&stream);
uint8_t* blobData;
std::string blobStr;

if (networkDescription->metadata.graphHandle != nullptr && networkDescription->compiledNetwork.size() == 0) {
_logger.info("LevelZeroCompilerInDriver getCompiledNetwork get blob from graphHandle");
ze_graph_handle_t graphHandle = static_cast<ze_graph_handle_t>(networkDescription->metadata.graphHandle);

// Get blob size first
size_t blobSize = -1;
auto result = _graphDdiTableExt->pfnGetNativeBinary(graphHandle, &blobSize, nullptr);

OPENVINO_ASSERT(result == ZE_RESULT_SUCCESS,
Expand All @@ -392,9 +390,16 @@ void LevelZeroCompilerInDriver<TableExtension>::getCompiledNetwork(
". ",
getLatestBuildError());

std::const_pointer_cast<NetworkDescription>(networkDescription)->compiledNetwork.resize(blobSize);
if (oStringStreamPtr != nullptr) {
blobStr.resize(blobSize);
blobData = reinterpret_cast<uint8_t*>(&blobStr[0]);
} else {
std::const_pointer_cast<NetworkDescription>(networkDescription)->compiledNetwork.resize(blobSize);
blobData = std::const_pointer_cast<NetworkDescription>(networkDescription)->compiledNetwork.data();
}

// Get blob data
result = _graphDdiTableExt->pfnGetNativeBinary(graphHandle, &blobSize, std::const_pointer_cast<NetworkDescription>(networkDescription)->compiledNetwork.data());
result = _graphDdiTableExt->pfnGetNativeBinary(graphHandle, &blobSize, blobData);

OPENVINO_ASSERT(result == ZE_RESULT_SUCCESS,
"Failed to compile network. L0 pfnGetNativeBinary get blob data",
Expand All @@ -410,10 +415,20 @@ void LevelZeroCompilerInDriver<TableExtension>::getCompiledNetwork(
_logger.info("return the blob from network description");
}

std::ostringstream* oStringStreamPtr = dynamic_cast<std::ostringstream*>(&stream);
if (oStringStreamPtr != nullptr) {
CustomStringBuf customStringBuf(std::const_pointer_cast<NetworkDescription>(networkDescription)->compiledNetwork);
oStringStreamPtr->rdbuf()->swap(customStringBuf);
CustomStringBuf* customStringBufPtr = new CustomStringBuf(std::move(blobStr));
stream.rdbuf(customStringBufPtr);
oStringStreamPtr->rdbuf()->swap(*customStringBufPtr);

int index = stream.xalloc();
stream.pword(index) = customStringBufPtr;
stream.register_callback([](std::ios_base::event evt, std::ios_base& str, int idx){
if (evt == std::ios_base::erase_event)
{
CustomStringBuf* ptr = static_cast<CustomStringBuf*>(str.pword(idx));
delete ptr;
}
}, index);
} else {
stream.write(reinterpret_cast<const char*>(networkDescription->compiledNetwork.data()), networkDescription->compiledNetwork.size());
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
#pragma once

#include <sstream>

namespace intel_npu {

class CustomStringBuf : public std::stringbuf {
public:
CustomStringBuf(std::string&& str) {
this->_str = std::move(str);
this->setp(&this->_str[0], &this->_str[0] + this->_str.size());
this->setg(&this->_str[0], &this->_str[1], &this->_str[0] + this->_str.size());
this->pbump(this->_str.size());
}

std::string&& str() {
return std::move(_str);
}
private:
std::string _str;
};

} // namespace intel_npu

namespace std {

template<>
ostringstream::_Mystr ostringstream::str() const {
intel_npu::CustomStringBuf* customStringBufPtr = dynamic_cast<intel_npu::CustomStringBuf*>(ostream::rdbuf());
if (customStringBufPtr != nullptr) {
return customStringBufPtr->str();
} else {
return this->rdbuf()->str();
}
}

} // namespace std
1 change: 1 addition & 0 deletions src/plugins/intel_npu/tools/compile_tool/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ ov_add_target(ADD_CPPLINT
TYPE EXECUTABLE
NAME ${TARGET_NAME}
ROOT ${CMAKE_CURRENT_SOURCE_DIR}
INCLUDES "${NPU_PLUGIN_SOURCE_DIR}/src/utils/include"
LINK_LIBRARIES
PRIVATE
openvino::runtime
Expand Down
2 changes: 2 additions & 0 deletions src/plugins/intel_npu/tools/compile_tool/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
#include <unordered_map>
#include <vector>

#include <intel_npu/utils/customstringbuf/customstringbuf.hpp>

#include <gflags/gflags.h>

#include "openvino/core/partial_shape.hpp"
Expand Down

0 comments on commit 64926cf

Please sign in to comment.