refactor: unify the GPU device selection in ExaTrkX; local variables declaration in FRNN lib (acts-project#2925)

This PR moves the `deviceHint` of the previous implementation into the Config and constructor (as a `deviceID` option) and creates a `torch::Device` with some protections to ensure that both the model and the input tensors are loaded onto a specific GPU. The base of the FRNN repo is changed to avoid declaring global variables in CUDA code, which caused a segmentation fault at run time when running with the Triton Inference Server.

Tagging ExaTrkX aaS people here
@xju2 @ytchoutw @asnaylor @yongbinfeng @y19y19
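
For context, the common pattern the changed classes follow is sketched below. The constructor bodies are collapsed in this diff, so this is only a hedged illustration of the idea, not the actual ACTS code; the helper name `chooseDevice` and the exact checks are assumptions.

```cpp
// Hedged sketch (assumed helper, not the ACTS implementation): construct a
// torch::Device from the Config's deviceID, falling back to CPU when CUDA is
// unavailable and rejecting out-of-range device IDs.
#include <torch/torch.h>

#include <stdexcept>

torch::Device chooseDevice(int deviceID) {
  if (!torch::cuda::is_available()) {
    return torch::Device(torch::kCPU);
  }
  if (deviceID < 0 ||
      deviceID >= static_cast<int>(torch::cuda::device_count())) {
    throw std::invalid_argument("requested CUDA device ID is out of range");
  }
  return torch::Device(torch::kCUDA, static_cast<c10::DeviceIndex>(deviceID));
}
```

Each stage then keeps the resulting `torch::Device` in its `m_device` member (exposed through the new `device()` accessor), and the scripted model and input tensors can be moved there with `.to(device)`, so `run()` no longer needs a per-call `deviceHint`.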
hrzhao76 authored and Tim Adye committed Jun 27, 2024
1 parent 60d6563 commit 60f704d
Showing 19 changed files with 137 additions and 56 deletions.
@@ -250,12 +250,10 @@ ActsExamples::ProcessCode ActsExamples::TrackFindingAlgorithmExaTrkX::execute(
 
   // Run the pipeline
   const auto trackCandidates = [&]() {
-    const int deviceHint = -1;
     std::lock_guard<std::mutex> lock(m_mutex);
 
     Acts::ExaTrkXTiming timing;
-    auto res =
-        m_pipeline.run(features, spacepointIDs, deviceHint, *hook, &timing);
+    auto res = m_pipeline.run(features, spacepointIDs, *hook, &timing);
 
     m_timing.graphBuildingTime(timing.graphBuildingTime.count());
 
5 changes: 3 additions & 2 deletions Examples/Python/src/ExaTrkXTrackFinding.cpp
@@ -70,6 +70,7 @@ void addExaTrkXTrackFinding(Context &ctx) {
     ACTS_PYTHON_MEMBER(embeddingDim);
     ACTS_PYTHON_MEMBER(rVal);
     ACTS_PYTHON_MEMBER(knnVal);
+    ACTS_PYTHON_MEMBER(deviceID);
     ACTS_PYTHON_STRUCT_END();
   }
   {
@@ -93,6 +94,7 @@ void addExaTrkXTrackFinding(Context &ctx) {
     ACTS_PYTHON_MEMBER(cut);
     ACTS_PYTHON_MEMBER(nChunks);
     ACTS_PYTHON_MEMBER(undirected);
+    ACTS_PYTHON_MEMBER(deviceID);
     ACTS_PYTHON_STRUCT_END();
   }
   {
@@ -208,8 +210,7 @@ void addExaTrkXTrackFinding(Context &ctx) {
            py::arg("graphConstructor"), py::arg("edgeClassifiers"),
            py::arg("trackBuilder"), py::arg("level"))
       .def("run", &ExaTrkXPipeline::run, py::arg("features"),
-           py::arg("spacepoints"), py::arg("deviceHint") = -1,
-           py::arg("hook") = Acts::ExaTrkXHook{},
+           py::arg("spacepoints"), py::arg("hook") = Acts::ExaTrkXHook{},
            py::arg("timing") = nullptr);
 }
 
2 changes: 1 addition & 1 deletion Plugins/ExaTrkX/README.md
@@ -8,7 +8,7 @@ To build the plugin, enable the appropriate CMake options:
 
 ```bash
 cmake -B <build> -S <source> \
-  -D ACTS_BUILD_EXATRKX_PLUGIN=ON \
+  -D ACTS_BUILD_PLUGIN_EXATRKX=ON \
   -D ACTS_EXATRKX_ENABLE_TORCH=ON/OFF \
   -D ACTS_EXATRKX_ENABLE_ONNX=ON/OFF \
   -D ACTS_BUILD_EXAMPLES_EXATRKX=ON \
@@ -13,20 +13,24 @@
 
 #include <memory>
 
+#include <torch/script.h>
+
 namespace Acts {
 
 class BoostTrackBuilding final : public Acts::TrackBuildingBase {
  public:
   BoostTrackBuilding(std::unique_ptr<const Logger> logger)
-      : m_logger(std::move(logger)) {}
+      : m_logger(std::move(logger)), m_device(torch::Device(torch::kCPU)) {}
 
-  std::vector<std::vector<int>> operator()(std::any nodes, std::any edges,
-                                           std::any edge_weights,
-                                           std::vector<int> &spacepointIDs,
-                                           int deviceHint = -1) override;
+  std::vector<std::vector<int>> operator()(
+      std::any nodes, std::any edges, std::any edge_weights,
+      std::vector<int> &spacepointIDs,
+      torch::Device device = torch::Device(torch::kCPU)) override;
+  torch::Device device() const override { return m_device; };
 
  private:
   std::unique_ptr<const Acts::Logger> m_logger;
+  torch::Device m_device;
   const auto &logger() const { return *m_logger; }
 };
 
@@ -13,20 +13,24 @@
 
 #include <memory>
 
+#include <torch/script.h>
+
 namespace Acts {
 
 class CugraphTrackBuilding final : public Acts::TrackBuildingBase {
  public:
   CugraphTrackBuilding(std::unique_ptr<const Logger> logger)
-      : m_logger(std::move(logger)) {}
+      : m_logger(std::move(logger)), m_device(torch::Device(torch::kCPU)) {}
 
-  std::vector<std::vector<int>> operator()(std::any nodes, std::any edges,
-                                           std::any edge_weights,
-                                           std::vector<int> &spacepointIDs,
-                                           int deviceHint = -1) override;
+  std::vector<std::vector<int>> operator()(
+      std::any nodes, std::any edges, std::any edge_weights,
+      std::vector<int> &spacepointIDs,
+      torch::Device device = torch::Device(torch::kCPU)) override;
+  torch::Device device() const override { return m_device; };
 
  private:
   std::unique_ptr<const Acts::Logger> m_logger;
+  torch::Device m_device;
   const auto &logger() const { return *m_logger; }
 };
 
@@ -47,7 +47,6 @@ class ExaTrkXPipeline {
 
   std::vector<std::vector<int>> run(std::vector<float> &features,
                                     std::vector<int> &spacepointIDs,
-                                    int deviceHint = -1,
                                     const ExaTrkXHook &hook = {},
                                     ExaTrkXTiming *timing = nullptr) const;
 
@@ -13,6 +13,8 @@
 
 #include <memory>
 
+#include <torch/script.h>
+
 namespace Ort {
 class Env;
 class Session;
@@ -32,16 +34,18 @@ class OnnxEdgeClassifier final : public Acts::EdgeClassificationBase {
   ~OnnxEdgeClassifier();
 
   std::tuple<std::any, std::any, std::any> operator()(
-      std::any nodes, std::any edges, int deviceHint = -1) override;
+      std::any nodes, std::any edges,
+      torch::Device device = torch::Device(torch::kCPU)) override;
 
   Config config() const { return m_cfg; }
+  torch::Device device() const override { return m_device; };
 
  private:
   std::unique_ptr<const Acts::Logger> m_logger;
   const auto &logger() const { return *m_logger; }
 
   Config m_cfg;
-
+  torch::Device m_device;
   std::unique_ptr<Ort::Env> m_env;
   std::unique_ptr<Ort::Session> m_model;
 
@@ -13,6 +13,8 @@
 
 #include <memory>
 
+#include <torch/script.h>
+
 namespace Ort {
 class Env;
 class Session;
@@ -34,11 +36,12 @@ class OnnxMetricLearning final : public Acts::GraphConstructionBase {
   OnnxMetricLearning(const Config& cfg, std::unique_ptr<const Logger> logger);
   ~OnnxMetricLearning();
 
-  std::tuple<std::any, std::any> operator()(std::vector<float>& inputValues,
-                                            std::size_t numNodes,
-                                            int deviceHint = -1) override;
+  std::tuple<std::any, std::any> operator()(
+      std::vector<float>& inputValues, std::size_t numNodes,
+      torch::Device device = torch::Device(torch::kCPU)) override;
 
   Config config() const { return m_cfg; }
+  torch::Device device() const override { return m_device; };
 
  private:
   void buildEdgesWrapper(std::vector<float>& embedFeatures,
@@ -50,6 +53,7 @@ class OnnxMetricLearning final : public Acts::GraphConstructionBase {
   const auto& logger() const { return *m_logger; }
 
   Config m_cfg;
+  torch::Device m_device;
   std::unique_ptr<Ort::Env> m_env;
   std::unique_ptr<Ort::Session> m_model;
 };
22 changes: 16 additions & 6 deletions Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/Stages.hpp
@@ -11,6 +11,8 @@
 #include <any>
 #include <vector>
 
+#include <torch/script.h>
+
 namespace Acts {
 
 // TODO maybe replace std::any with some kind of variant<unique_ptr<torch>,
@@ -25,12 +27,14 @@ class GraphConstructionBase {
   /// @param inputValues Flattened input data
   /// @param numNodes Number of nodes. inputValues.size() / numNodes
   /// then gives the number of features
-  /// @param deviceHint Which GPU to pick. Not relevant for CPU-only builds
+  /// @param device Which GPU device to pick. Not relevant for CPU-only builds
   ///
   /// @return (node_tensor, edge_tensore)
   virtual std::tuple<std::any, std::any> operator()(
       std::vector<float> &inputValues, std::size_t numNodes,
-      int deviceHint = -1) = 0;
+      torch::Device device = torch::Device(torch::kCPU)) = 0;
+
+  virtual torch::Device device() const = 0;
 
   virtual ~GraphConstructionBase() = default;
 };
@@ -41,11 +45,14 @@ class EdgeClassificationBase {
   ///
   /// @param nodes Node tensor with shape (n_nodes, n_node_features)
   /// @param edges Edge-index tensor with shape (2, n_edges)
-  /// @param deviceHint Which GPU to pick. Not relevant for CPU-only builds
+  /// @param device Which GPU device to pick. Not relevant for CPU-only builds
   ///
   /// @return (node_tensor, edge_tensor, score_tensor)
   virtual std::tuple<std::any, std::any, std::any> operator()(
-      std::any nodes, std::any edges, int deviceHint = -1) = 0;
+      std::any nodes, std::any edges,
+      torch::Device device = torch::Device(torch::kCPU)) = 0;
+
+  virtual torch::Device device() const = 0;
 
   virtual ~EdgeClassificationBase() = default;
 };
@@ -58,12 +65,15 @@ class TrackBuildingBase {
   /// @param edges Edge-index tensor with shape (2, n_edges)
   /// @param edgeWeights Edge-weights of the previous edge classification phase
   /// @param spacepointIDs IDs of the nodes (must have size=n_nodes)
-  /// @param deviceHint Which GPU to pick. Not relevant for CPU-only builds
+  /// @param device Which GPU device to pick. Not relevant for CPU-only builds
   ///
   /// @return tracks (as vectors of node-IDs)
   virtual std::vector<std::vector<int>> operator()(
       std::any nodes, std::any edges, std::any edgeWeights,
-      std::vector<int> &spacepointIDs, int deviceHint = -1) = 0;
+      std::vector<int> &spacepointIDs,
+      torch::Device device = torch::Device(torch::kCPU)) = 0;
+
+  virtual torch::Device device() const = 0;
 
   virtual ~TrackBuildingBase() = default;
 };
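
The three stage interfaces above now take a `torch::Device` argument and expose a `device()` accessor. As a rough illustration of what a downstream implementation has to provide after this change, here is a hedged, minimal CPU-only stage; the class name and its trivial body are purely illustrative and not part of this PR.

```cpp
// Hedged sketch only: a minimal CPU-only stage written against the updated
// interface. Class name and body are illustrative, not part of this PR.
#include "Acts/Plugins/ExaTrkX/Stages.hpp"

#include <any>
#include <vector>

class MyCpuTrackBuilding final : public Acts::TrackBuildingBase {
 public:
  std::vector<std::vector<int>> operator()(
      std::any /*nodes*/, std::any /*edges*/, std::any /*edgeWeights*/,
      std::vector<int> &spacepointIDs,
      torch::Device /*device*/ = torch::Device(torch::kCPU)) override {
    // Placeholder logic: emit one "track" per spacepoint.
    std::vector<std::vector<int>> tracks;
    tracks.reserve(spacepointIDs.size());
    for (int id : spacepointIDs) {
      tracks.push_back({id});
    }
    return tracks;
  }

  // Stages now advertise the device they want their inputs on; the pipeline
  // queries this instead of receiving a deviceHint per call.
  torch::Device device() const override { return torch::Device(torch::kCPU); }
};
```

The pipeline (see the `ExaTrkXPipeline.cpp` hunks below) queries `device()` on each stage and passes it back into the call, instead of threading a `deviceHint` through `run()`.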
@@ -13,6 +13,8 @@
 
 #include <memory>
 
+#include <torch/script.h>
+
 namespace torch::jit {
 class Module;
 }
@@ -31,22 +33,26 @@ class TorchEdgeClassifier final : public Acts::EdgeClassificationBase {
     float cut = 0.21;
     int nChunks = 1; // NOTE for GNN use 1
     bool undirected = false;
+    int deviceID = 0;
   };
 
   TorchEdgeClassifier(const Config &cfg, std::unique_ptr<const Logger> logger);
   ~TorchEdgeClassifier();
 
   std::tuple<std::any, std::any, std::any> operator()(
-      std::any nodes, std::any edges, int deviceHint = -1) override;
+      std::any nodes, std::any edges,
+      torch::Device device = torch::Device(torch::kCPU)) override;
 
   Config config() const { return m_cfg; }
+  torch::Device device() const override { return m_device; };
 
  private:
   std::unique_ptr<const Acts::Logger> m_logger;
   const auto &logger() const { return *m_logger; }
 
   Config m_cfg;
   c10::DeviceType m_deviceType;
+  torch::Device m_device;
   std::unique_ptr<torch::jit::Module> m_model;
 };
 
@@ -32,23 +32,26 @@ class TorchMetricLearning final : public Acts::GraphConstructionBase {
     float rVal = 1.6;
     int knnVal = 500;
     bool shuffleDirections = false;
+    int deviceID = 0; // default is the first GPU if available
   };
 
   TorchMetricLearning(const Config &cfg, std::unique_ptr<const Logger> logger);
   ~TorchMetricLearning();
 
-  std::tuple<std::any, std::any> operator()(std::vector<float> &inputValues,
-                                            std::size_t numNodes,
-                                            int deviceHint = -1) override;
+  std::tuple<std::any, std::any> operator()(
+      std::vector<float> &inputValues, std::size_t numNodes,
+      torch::Device device = torch::Device(torch::kCPU)) override;
 
   Config config() const { return m_cfg; }
+  torch::Device device() const override { return m_device; };
 
  private:
   std::unique_ptr<const Acts::Logger> m_logger;
   const auto &logger() const { return *m_logger; }
 
   Config m_cfg;
   c10::DeviceType m_deviceType;
+  torch::Device m_device;
   std::unique_ptr<torch::jit::Module> m_model;
 };
 
2 changes: 1 addition & 1 deletion Plugins/ExaTrkX/src/BoostTrackBuilding.cpp
@@ -47,7 +47,7 @@ namespace Acts {
 
 std::vector<std::vector<int>> BoostTrackBuilding::operator()(
     std::any nodes, std::any edges, std::any weights,
-    std::vector<int>& spacepointIDs, int) {
+    std::vector<int>& spacepointIDs, torch::Device) {
   ACTS_DEBUG("Start track building");
   const auto edgeTensor = std::any_cast<torch::Tensor>(edges).to(torch::kCPU);
   const auto edgeWeightTensor =
2 changes: 1 addition & 1 deletion Plugins/ExaTrkX/src/CugraphTrackBuilding.cpp
@@ -18,7 +18,7 @@ namespace Acts {
 
 std::vector<std::vector<int>> CugraphTrackBuilding::operator()(
     std::any, std::any edges, std::any edge_weights,
-    std::vector<int> &spacepointIDs, int) {
+    std::vector<int> &spacepointIDs, torch::Device) {
   auto numSpacepoints = spacepointIDs.size();
   auto edgesAfterFiltering = std::any_cast<std::vector<std::int64_t>>(edges);
   auto numEdgesAfterF = edgesAfterFiltering.size() / 2;
20 changes: 11 additions & 9 deletions Plugins/ExaTrkX/src/ExaTrkXPipeline.cpp
@@ -36,10 +36,10 @@ ExaTrkXPipeline::ExaTrkXPipeline(
 
 std::vector<std::vector<int>> ExaTrkXPipeline::run(
     std::vector<float> &features, std::vector<int> &spacepointIDs,
-    int deviceHint, const ExaTrkXHook &hook, ExaTrkXTiming *timing) const {
+    const ExaTrkXHook &hook, ExaTrkXTiming *timing) const {
   auto t0 = std::chrono::high_resolution_clock::now();
-  auto [nodes, edges] =
-      (*m_graphConstructor)(features, spacepointIDs.size(), deviceHint);
+  auto [nodes, edges] = (*m_graphConstructor)(features, spacepointIDs.size(),
+                                              m_graphConstructor->device());
   auto t1 = std::chrono::high_resolution_clock::now();
 
   if (timing != nullptr) {
@@ -49,12 +49,14 @@ std::vector<std::vector<int>> ExaTrkXPipeline::run(
   hook(nodes, edges, {});
 
   std::any edge_weights;
-  timing->classifierTimes.clear();
+  if (timing != nullptr) {
+    timing->classifierTimes.clear();
+  }
 
   for (auto edgeClassifier : m_edgeClassifiers) {
     t0 = std::chrono::high_resolution_clock::now();
-    auto [newNodes, newEdges, newWeights] =
-        (*edgeClassifier)(std::move(nodes), std::move(edges), deviceHint);
+    auto [newNodes, newEdges, newWeights] = (*edgeClassifier)(
+        std::move(nodes), std::move(edges), edgeClassifier->device());
     t1 = std::chrono::high_resolution_clock::now();
 
     if (timing != nullptr) {
@@ -69,9 +71,9 @@ std::vector<std::vector<int>> ExaTrkXPipeline::run(
   }
 
   t0 = std::chrono::high_resolution_clock::now();
-  auto res =
-      (*m_trackBuilder)(std::move(nodes), std::move(edges),
-                        std::move(edge_weights), spacepointIDs, deviceHint);
+  auto res = (*m_trackBuilder)(std::move(nodes), std::move(edges),
+                               std::move(edge_weights), spacepointIDs,
+                               m_trackBuilder->device());
   t1 = std::chrono::high_resolution_clock::now();
 
   if (timing != nullptr) {
6 changes: 4 additions & 2 deletions Plugins/ExaTrkX/src/OnnxEdgeClassifier.cpp
@@ -19,7 +19,9 @@ namespace Acts {
 
 OnnxEdgeClassifier::OnnxEdgeClassifier(const Config &cfg,
                                        std::unique_ptr<const Logger> logger)
-    : m_logger(std::move(logger)), m_cfg(cfg) {
+    : m_logger(std::move(logger)),
+      m_cfg(cfg),
+      m_device(torch::Device(torch::kCPU)) {
   m_env = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING,
                                      "ExaTrkX - edge classifier");
 
@@ -44,7 +46,7 @@ OnnxEdgeClassifier::OnnxEdgeClassifier(const Config &cfg,
 OnnxEdgeClassifier::~OnnxEdgeClassifier() {}
 
 std::tuple<std::any, std::any, std::any> OnnxEdgeClassifier::operator()(
-    std::any inputNodes, std::any inputEdges, int) {
+    std::any inputNodes, std::any inputEdges, torch::Device) {
   Ort::AllocatorWithDefaultOptions allocator;
   auto memoryInfo = Ort::MemoryInfo::CreateCpu(
       OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault);