Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ORT GPU implementation #13755

Merged
merged 31 commits into from
Dec 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
6289adc
Adding tasks for ONNX Runtime GPU inference on AMD GPUs
ChSonnabend Aug 15, 2024
09b3a59
Merge branch 'AliceO2Group:dev' into onnxruntime-gpu
ChSonnabend Aug 15, 2024
79ced32
Merge branch 'AliceO2Group:dev' into onnxruntime-gpu
ChSonnabend Aug 16, 2024
e585588
Working header file for ONNX model executions
ChSonnabend Aug 28, 2024
792e7ec
Finally fixing casting issue with Ort::Float16_t and OrtDataType::Flo…
ChSonnabend Sep 16, 2024
f675ddd
Merge branch 'AliceO2Group:dev' into onnxruntime-gpu
ChSonnabend Sep 17, 2024
59ba4d9
Using reinterpret_cast for type conversion
ChSonnabend Sep 17, 2024
ae21472
Modifying test script
ChSonnabend Sep 17, 2024
c9cd4fa
Adding source for float16 implementation
ChSonnabend Sep 17, 2024
f4b74de
Minor update on the 3rd party libraries
ChSonnabend Sep 27, 2024
a910402
Merge branch 'AliceO2Group:dev' into onnxruntime-gpu
ChSonnabend Sep 27, 2024
62eadea
Changing to #ifdef statements for O2 compilation
ChSonnabend Oct 3, 2024
224a137
Please consider the following formatting changes
alibuild Oct 3, 2024
71548d9
Merge pull request #5 from alibuild/alibot-cleanup-13522
ChSonnabend Nov 4, 2024
ced7a4a
Merge branch 'AliceO2Group:dev' into onnxruntime-gpu
ChSonnabend Nov 15, 2024
cbf93d3
Merge branch 'AliceO2Group:dev' into onnxruntime-gpu
ChSonnabend Nov 21, 2024
9c8f167
Changing names to CamelCase and adding test task for onnx model
ChSonnabend Nov 28, 2024
de1ae50
Fixing warning of narrowing conversion
ChSonnabend Nov 28, 2024
a3bf6af
Adding mapping of ORT logging to InfoLogger and disabling telemetry e…
ChSonnabend Nov 29, 2024
43cc7fb
Merge branch 'dev' into onnxruntime-gpu
ChSonnabend Nov 29, 2024
1911707
Merge branch 'AliceO2Group:dev' into onnxruntime-gpu
ChSonnabend Nov 29, 2024
a19e595
Removing old files and adding whitespace
ChSonnabend Nov 29, 2024
9df2dfb
Removing add_subdirectory (duplicate)
ChSonnabend Nov 29, 2024
c4bc6b6
Reformatting to adjust to current dev branch
ChSonnabend Nov 29, 2024
e427f0a
Adding whitespace
ChSonnabend Nov 29, 2024
e436204
Removing test task
ChSonnabend Nov 29, 2024
27fa752
Adding back the white space
ChSonnabend Nov 29, 2024
7b82026
Removing brackets
ChSonnabend Nov 29, 2024
43eb177
Removing curly braces
ChSonnabend Nov 29, 2024
2cc5d3e
Please consider the following formatting changes
alibuild Nov 29, 2024
900e542
Merge pull request #9 from alibuild/alibot-cleanup-13755
ChSonnabend Nov 29, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion Common/ML/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,21 @@
# granted to it by virtue of its status as an Intergovernmental Organization
# or submit itself to any jurisdiction.

# Pass ORT variables as a preprocessor definition
if(DEFINED ENV{ORT_ROCM_BUILD})
add_compile_definitions(ORT_ROCM_BUILD=$ENV{ORT_ROCM_BUILD})
endif()
if(DEFINED ENV{ORT_CUDA_BUILD})
add_compile_definitions(ORT_CUDA_BUILD=$ENV{ORT_CUDA_BUILD})
endif()
if(DEFINED ENV{ORT_MIGRAPHX_BUILD})
add_compile_definitions(ORT_MIGRAPHX_BUILD=$ENV{ORT_MIGRAPHX_BUILD})
endif()
if(DEFINED ENV{ORT_TENSORRT_BUILD})
add_compile_definitions(ORT_TENSORRT_BUILD=$ENV{ORT_TENSORRT_BUILD})
endif()

o2_add_library(ML
SOURCES src/ort_interface.cxx
SOURCES src/OrtInterface.cxx
TARGETVARNAME targetName
PRIVATE_LINK_LIBRARIES O2::Framework ONNXRuntime::ONNXRuntime)
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@
// granted to it by virtue of its status as an Intergovernmental Organization
// or submit itself to any jurisdiction.

/// \file ort_interface.h
/// \file OrtInterface.h
/// \author Christian Sonnabend <[email protected]>
/// \brief A header library for loading ONNX models and running inference with them on CPU and GPU

#ifndef O2_ML_ONNX_INTERFACE_H
#define O2_ML_ONNX_INTERFACE_H
#ifndef O2_ML_ORTINTERFACE_H
#define O2_ML_ORTINTERFACE_H

// C++ and system includes
#include <vector>
Expand Down Expand Up @@ -89,4 +89,4 @@ class OrtModel

} // namespace o2

#endif // O2_ML_ORT_INTERFACE_H
#endif // O2_ML_ORTINTERFACE_H
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@
// granted to it by virtue of its status as an Intergovernmental Organization
// or submit itself to any jurisdiction.

/// \file ort_interface.cxx
/// \file OrtInterface.cxx
/// \author Christian Sonnabend <[email protected]>
/// \brief Implementation of the library for loading ONNX models and running inference with them on CPU and GPU

#include "ML/ort_interface.h"
#include "ML/OrtInterface.h"
#include "ML/3rdparty/GPUORTFloat16.h"

// ONNX includes
Expand Down Expand Up @@ -50,29 +50,35 @@ void OrtModel::reset(std::unordered_map<std::string, std::string> optionsMap)
deviceId = (optionsMap.contains("device-id") ? std::stoi(optionsMap["device-id"]) : 0);
allocateDeviceMemory = (optionsMap.contains("allocate-device-memory") ? std::stoi(optionsMap["allocate-device-memory"]) : 0);
intraOpNumThreads = (optionsMap.contains("intra-op-num-threads") ? std::stoi(optionsMap["intra-op-num-threads"]) : 0);
loggingLevel = (optionsMap.contains("logging-level") ? std::stoi(optionsMap["logging-level"]) : 0);
loggingLevel = (optionsMap.contains("logging-level") ? std::stoi(optionsMap["logging-level"]) : 2);
enableProfiling = (optionsMap.contains("enable-profiling") ? std::stoi(optionsMap["enable-profiling"]) : 0);
enableOptimizations = (optionsMap.contains("enable-optimizations") ? std::stoi(optionsMap["enable-optimizations"]) : 0);

std::string dev_mem_str = "Hip";
#ifdef ORT_ROCM_BUILD
#if defined(ORT_ROCM_BUILD)
#if ORT_ROCM_BUILD == 1
if (device == "ROCM") {
Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_ROCM(pImplOrt->sessionOptions, deviceId));
LOG(info) << "(ORT) ROCM execution provider set";
}
#endif
#ifdef ORT_MIGRAPHX_BUILD
#endif
#if defined(ORT_MIGRAPHX_BUILD)
#if ORT_MIGRAPHX_BUILD == 1
if (device == "MIGRAPHX") {
Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_MIGraphX(pImplOrt->sessionOptions, deviceId));
LOG(info) << "(ORT) MIGraphX execution provider set";
}
#endif
#ifdef ORT_CUDA_BUILD
#endif
#if defined(ORT_CUDA_BUILD)
#if ORT_CUDA_BUILD == 1
if (device == "CUDA") {
Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(pImplOrt->sessionOptions, deviceId));
LOG(info) << "(ORT) CUDA execution provider set";
dev_mem_str = "Cuda";
}
#endif
#endif

if (allocateDeviceMemory) {
Expand Down Expand Up @@ -106,7 +112,27 @@ void OrtModel::reset(std::unordered_map<std::string, std::string> optionsMap)
(pImplOrt->sessionOptions).SetGraphOptimizationLevel(GraphOptimizationLevel(enableOptimizations));
(pImplOrt->sessionOptions).SetLogSeverityLevel(OrtLoggingLevel(loggingLevel));

pImplOrt->env = std::make_shared<Ort::Env>(OrtLoggingLevel(loggingLevel), (optionsMap["onnx-environment-name"].empty() ? "onnx_model_inference" : optionsMap["onnx-environment-name"].c_str()));
pImplOrt->env = std::make_shared<Ort::Env>(
OrtLoggingLevel(loggingLevel),
(optionsMap["onnx-environment-name"].empty() ? "onnx_model_inference" : optionsMap["onnx-environment-name"].c_str()),
// Integrate ORT logging into FairLogger
[](void* param, OrtLoggingLevel severity, const char* category, const char* logid, const char* code_location, const char* message) {
if (severity == ORT_LOGGING_LEVEL_VERBOSE) {
LOG(debug) << "(ORT) [" << logid << "|" << category << "|" << code_location << "]: " << message;
} else if (severity == ORT_LOGGING_LEVEL_INFO) {
LOG(info) << "(ORT) [" << logid << "|" << category << "|" << code_location << "]: " << message;
} else if (severity == ORT_LOGGING_LEVEL_WARNING) {
LOG(warning) << "(ORT) [" << logid << "|" << category << "|" << code_location << "]: " << message;
} else if (severity == ORT_LOGGING_LEVEL_ERROR) {
LOG(error) << "(ORT) [" << logid << "|" << category << "|" << code_location << "]: " << message;
} else if (severity == ORT_LOGGING_LEVEL_FATAL) {
LOG(fatal) << "(ORT) [" << logid << "|" << category << "|" << code_location << "]: " << message;
} else {
LOG(info) << "(ORT) [" << logid << "|" << category << "|" << code_location << "]: " << message;
}
},
(void*)3);
(pImplOrt->env)->DisableTelemetryEvents(); // Disable telemetry events
pImplOrt->session = std::make_shared<Ort::Session>(*(pImplOrt->env), modelPath.c_str(), pImplOrt->sessionOptions);

for (size_t i = 0; i < (pImplOrt->session)->GetInputCount(); ++i) {
Expand All @@ -130,16 +156,14 @@ void OrtModel::reset(std::unordered_map<std::string, std::string> optionsMap)
[&](const std::string& str) { return str.c_str(); });

// Print names
if (loggingLevel > 1) {
LOG(info) << "Input Nodes:";
for (size_t i = 0; i < mInputNames.size(); i++) {
LOG(info) << "\t" << mInputNames[i] << " : " << printShape(mInputShapes[i]);
}
LOG(info) << "\tInput Nodes:";
for (size_t i = 0; i < mInputNames.size(); i++) {
LOG(info) << "\t\t" << mInputNames[i] << " : " << printShape(mInputShapes[i]);
}

LOG(info) << "Output Nodes:";
for (size_t i = 0; i < mOutputNames.size(); i++) {
LOG(info) << "\t" << mOutputNames[i] << " : " << printShape(mOutputShapes[i]);
}
LOG(info) << "\tOutput Nodes:";
for (size_t i = 0; i < mOutputNames.size(); i++) {
LOG(info) << "\t\t" << mOutputNames[i] << " : " << printShape(mOutputShapes[i]);
}
}

Expand Down