Skip to content

Commit

Permalink
Refactor changes
Browse files Browse the repository at this point in the history
  • Loading branch information
MirceaDan99 committed Sep 12, 2024
1 parent cb55b4b commit a1f4975
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 60 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ void LevelZeroCompilerAdapter::release(std::shared_ptr<const NetworkDescription>
void LevelZeroCompilerAdapter::getCompiledNetwork(
std::shared_ptr<const NetworkDescription> networkDescription, std::ostream& stream) {
_logger.info("getCompiledNetwork - using adapter to perform getCompiledNetwork(networkDescription)");
apiAdapter->getCompiledNetwork(std::move(networkDescription), stream);
apiAdapter->getCompiledNetwork(networkDescription, stream);
}

} // namespace driverCompilerAdapter
Expand Down
62 changes: 14 additions & 48 deletions src/plugins/intel_npu/src/compiler/src/zero_compiler_in_driver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -364,31 +364,17 @@ void LevelZeroCompilerInDriver<TableExtension>::release(std::shared_ptr<const Ne

class CustomStringBuf : public std::stringbuf {
public:
CustomStringBuf(std::string&& moveStr) {
this->obj = std::move(moveStr);
this->setp(&this->obj[0], &this->obj[0] + this->obj.size());
this->setg(&this->obj[0], &this->obj[0] + 1, &this->obj[0] + this->obj.size());
this->pbump(this->obj.size());
}

std::string str() {
return this->obj;
}
private:
std::string obj;
CustomStringBuf(std::vector<uint8_t>& blob) {
this->setp(reinterpret_cast<char*>(blob.data()), reinterpret_cast<char*>(blob.data()) + blob.size());
this->setg(reinterpret_cast<char*>(blob.data()), reinterpret_cast<char*>(blob.data()) + 1, reinterpret_cast<char*>(blob.data()) + blob.size());
this->pbump(blob.size());
}
};

// new method for export_model(ostringstream) ?
template <typename TableExtension>
void LevelZeroCompilerInDriver<TableExtension>::getCompiledNetwork(
std::shared_ptr<const NetworkDescription> networkDescription, std::ostream& stream) {

std::ostringstream* oStringStreamPtr = dynamic_cast<std::ostringstream*>(&stream);
std::vector<uint8_t> blob;
std::string blobStr;
uint8_t* blobData;
size_t blobSize = -1;

if (networkDescription->metadata.graphHandle != nullptr && networkDescription->compiledNetwork.size() == 0) {
_logger.info("LevelZeroCompilerInDriver getCompiledNetwork get blob from graphHandle");
ze_graph_handle_t graphHandle = static_cast<ze_graph_handle_t>(networkDescription->metadata.graphHandle);
Expand All @@ -405,17 +391,10 @@ void LevelZeroCompilerInDriver<TableExtension>::getCompiledNetwork(
uint64_t(result),
". ",
getLatestBuildError());

if (oStringStreamPtr != nullptr) {
blobStr.resize(blobSize);
blobData = reinterpret_cast<uint8_t*>(&blobStr[0]);
} else {
blob.resize(blobSize);
blobData = blob.data();
}


std::const_pointer_cast<NetworkDescription>(networkDescription)->compiledNetwork.resize(blobSize);
// Get blob data
result = _graphDdiTableExt->pfnGetNativeBinary(graphHandle, &blobSize, blobData);
result = _graphDdiTableExt->pfnGetNativeBinary(graphHandle, &blobSize, std::const_pointer_cast<NetworkDescription>(networkDescription)->compiledNetwork.data());

OPENVINO_ASSERT(result == ZE_RESULT_SUCCESS,
"Failed to compile network. L0 pfnGetNativeBinary get blob data",
Expand All @@ -427,29 +406,16 @@ void LevelZeroCompilerInDriver<TableExtension>::getCompiledNetwork(
". ",
getLatestBuildError());
_logger.info("LevelZeroCompilerInDriver getCompiledNetwork returning blob");
// return blob;
} else {
_logger.info("return the blob from network description");
if (oStringStreamPtr != nullptr) {
// some magic trick here so oStringStreamPtr->str(CustomStringThatWontCopyBuffer(networkDescription->compiledNetwork.data(), networkDescription->compiledNetwork.size());
} else {
blobData = std::const_pointer_cast<NetworkDescription>(networkDescription)->compiledNetwork.data();
blobSize = networkDescription->compiledNetwork.size();
}
// return networkDescription->compiledNetwork;
}
if (oStringStreamPtr == nullptr) {
stream.write(reinterpret_cast<const char*>(blobData), blobSize);

std::ostringstream* oStringStreamPtr = dynamic_cast<std::ostringstream*>(&stream);
if (oStringStreamPtr != nullptr) {
CustomStringBuf customStringBuf(std::const_pointer_cast<NetworkDescription>(networkDescription)->compiledNetwork);
oStringStreamPtr->rdbuf()->swap(customStringBuf);
} else {
{
// Only for CXX17
stream.rdbuf(new CustomStringBuf(std::move(blobStr)));
oStringStreamPtr->rdbuf()->swap(*dynamic_cast<CustomStringBuf*>(stream.rdbuf()));
}
// By CXX20 we may use move semantics and avoid use of CustomStringBuf class
// oStringStreamPtr->str(std::move(blobStr));
std::cout << oStringStreamPtr->str()[1] << std::endl;
std::cout << oStringStreamPtr->str()[2] << std::endl;
stream.write(reinterpret_cast<const char*>(networkDescription->compiledNetwork.data()), networkDescription->compiledNetwork.size());
}
}

Expand Down
23 changes: 12 additions & 11 deletions src/plugins/intel_npu/tools/compile_tool/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -572,18 +572,19 @@ int main(int argc, char* argv[]) {
outputName = getFileNameFromPath(fileNameNoExt(FLAGS_m)) + ".blob";
}

// std::ofstream outputFile{outputName, std::ios::out | std::ios::binary};
/*std::ofstream outputFile{outputName, std::ios::out | std::ios::binary};
if (!outputFile.is_open()) {
std::cout << "Outputting file " << outputName << " can't be opened for writing" << std::endl;
return EXIT_FAILURE;
} else {
std::cout << "Writing into file - " << outputName << std::endl;
compiledModel.export_model(outputFile);
}*/
std::ostringstream oStringStream;
// if (!outputFile.is_open()) {
// std::cout << "Outputting file " << outputName << " can't be opened for writing" << std::endl;
// return EXIT_FAILURE;
// } else {
// std::cout << "Writing into file - " << outputName << std::endl;
// compiledModel.export_model(outputFile);
compiledModel.export_model(oStringStream);
std::cout << oStringStream.str()[1] << std::endl;
std::cout << oStringStream.str()[2] << std::endl;
// }
compiledModel.export_model(oStringStream);
auto str = oStringStream.str();
std::cout << str[1] << std::endl;
std::cout << str[2] << std::endl;
std::cout << "Done. LoadNetwork time elapsed: " << loadNetworkTimeElapsed.count() << " ms" << std::endl;
} catch (const std::exception& error) {
std::cerr << error.what() << std::endl;
Expand Down

0 comments on commit a1f4975

Please sign in to comment.