refactor: unify the GPU device selection in ExaTrkX; local variables declaration in FRNN lib #2925

Merged · 13 commits · Jun 21, 2024

Changes from 5 commits
@@ -250,12 +250,10 @@ ActsExamples::ProcessCode ActsExamples::TrackFindingAlgorithmExaTrkX::execute(
 
   // Run the pipeline
   const auto trackCandidates = [&]() {
-    const int deviceHint = -1;
     std::lock_guard<std::mutex> lock(m_mutex);
 
     Acts::ExaTrkXTiming timing;
-    auto res =
-        m_pipeline.run(features, spacepointIDs, deviceHint, *hook, &timing);
+    auto res = m_pipeline.run(features, spacepointIDs, *hook, &timing);
 
     m_timing.graphBuildingTime(timing.graphBuildingTime.count());
 
5 changes: 3 additions & 2 deletions Examples/Python/src/ExaTrkXTrackFinding.cpp
@@ -70,6 +70,7 @@ void addExaTrkXTrackFinding(Context &ctx) {
     ACTS_PYTHON_MEMBER(embeddingDim);
     ACTS_PYTHON_MEMBER(rVal);
     ACTS_PYTHON_MEMBER(knnVal);
+    ACTS_PYTHON_MEMBER(deviceID);
     ACTS_PYTHON_STRUCT_END();
   }
   {
@@ -93,6 +94,7 @@ void addExaTrkXTrackFinding(Context &ctx) {
     ACTS_PYTHON_MEMBER(cut);
     ACTS_PYTHON_MEMBER(nChunks);
     ACTS_PYTHON_MEMBER(undirected);
+    ACTS_PYTHON_MEMBER(deviceID);
     ACTS_PYTHON_STRUCT_END();
   }
   {
@@ -208,8 +210,7 @@ void addExaTrkXTrackFinding(Context &ctx) {
                 py::arg("graphConstructor"), py::arg("edgeClassifiers"),
                 py::arg("trackBuilder"), py::arg("level"))
            .def("run", &ExaTrkXPipeline::run, py::arg("features"),
-                py::arg("spacepoints"), py::arg("deviceHint") = -1,
-                py::arg("hook") = Acts::ExaTrkXHook{},
+                py::arg("spacepoints"), py::arg("hook") = Acts::ExaTrkXHook{},
                 py::arg("timing") = nullptr);
   }
 
2 changes: 1 addition & 1 deletion Plugins/ExaTrkX/README.md
@@ -8,7 +8,7 @@ To build the plugin, enable the appropriate CMake options:
 
 ```bash
 cmake -B <build> -S <source> \
-  -D ACTS_BUILD_EXATRKX_PLUGIN=ON \
+  -D ACTS_BUILD_PLUGIN_EXATRKX=ON \
   -D ACTS_EXATRKX_ENABLE_TORCH=ON/OFF \
  -D ACTS_EXATRKX_ENABLE_ONNX=ON/OFF \
  -D ACTS_BUILD_EXAMPLES_EXATRKX=ON \
@@ -13,20 +13,24 @@
 
 #include <memory>
 
+#include <torch/script.h>
+
 namespace Acts {
 
 class BoostTrackBuilding final : public Acts::TrackBuildingBase {
  public:
   BoostTrackBuilding(std::unique_ptr<const Logger> logger)
-      : m_logger(std::move(logger)) {}
+      : m_logger(std::move(logger)), m_device(torch::Device(torch::kCPU)) {}
 
   std::vector<std::vector<int>> operator()(std::any nodes, std::any edges,
                                            std::any edge_weights,
                                            std::vector<int> &spacepointIDs,
-                                           int deviceHint = -1) override;
+                                           torch::Device device) override;
+  torch::Device device() const override { return m_device; };
 
  private:
   std::unique_ptr<const Acts::Logger> m_logger;
+  torch::Device m_device;
   const auto &logger() const { return *m_logger; }
 };
 
@@ -13,20 +13,24 @@
 
 #include <memory>
 
+#include <torch/script.h>
+
 namespace Acts {
 
 class CugraphTrackBuilding final : public Acts::TrackBuildingBase {
  public:
   CugraphTrackBuilding(std::unique_ptr<const Logger> logger)
-      : m_logger(std::move(logger)) {}
+      : m_logger(std::move(logger)), m_device(torch::Device(torch::kCPU)) {}
 
   std::vector<std::vector<int>> operator()(std::any nodes, std::any edges,
                                            std::any edge_weights,
                                            std::vector<int> &spacepointIDs,
-                                           int deviceHint = -1) override;
+                                           torch::Device device) override;
+  torch::Device device() const override { return m_device; };
 
  private:
   std::unique_ptr<const Acts::Logger> m_logger;
+  torch::Device m_device;
   const auto &logger() const { return *m_logger; }
 };
 
@@ -46,7 +46,6 @@ class ExaTrkXPipeline {
 
   std::vector<std::vector<int>> run(std::vector<float> &features,
                                     std::vector<int> &spacepointIDs,
-                                    int deviceHint = -1,
                                     const ExaTrkXHook &hook = {},
                                     ExaTrkXTiming *timing = nullptr) const;
 
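For orientation, here is a minimal sketch of how a caller could assemble and run the pipeline against the refactored interface: the GPU is now selected per stage through `Config::deviceID` instead of per call through `deviceHint`. The stage and `run()` signatures follow the headers in this diff, but the constructor wiring of `ExaTrkXPipeline`, the `modelPath` fields, and the logger names are assumptions for illustration, not code from this PR.

```cpp
#include <Acts/Plugins/ExaTrkX/BoostTrackBuilding.hpp>
#include <Acts/Plugins/ExaTrkX/ExaTrkXPipeline.hpp>
#include <Acts/Plugins/ExaTrkX/TorchEdgeClassifier.hpp>
#include <Acts/Plugins/ExaTrkX/TorchMetricLearning.hpp>

#include <memory>
#include <vector>

std::vector<std::vector<int>> findTracks(std::vector<float> &features,
                                         std::vector<int> &spacepointIDs) {
  using namespace Acts;

  // Each stage owns its device selection now (assumption: the constructors
  // translate Config::deviceID into the torch::Device reported by device()).
  TorchMetricLearning::Config gcCfg;
  gcCfg.modelPath = "embedding.pt";  // hypothetical model file
  gcCfg.deviceID = 0;

  TorchEdgeClassifier::Config ecCfg;
  ecCfg.modelPath = "gnn.pt";  // hypothetical model file
  ecCfg.deviceID = 0;

  auto graphConstructor = std::make_shared<TorchMetricLearning>(
      gcCfg, getDefaultLogger("MetricLearning", Logging::INFO));
  auto edgeClassifier = std::make_shared<TorchEdgeClassifier>(
      ecCfg, getDefaultLogger("EdgeClassifier", Logging::INFO));
  auto trackBuilder = std::make_shared<BoostTrackBuilding>(
      getDefaultLogger("TrackBuilding", Logging::INFO));

  // Assumed constructor order: graph constructor, edge classifiers,
  // track builder, logger (mirrors the Python binding arguments).
  ExaTrkXPipeline pipeline(graphConstructor, {edgeClassifier}, trackBuilder,
                           getDefaultLogger("Pipeline", Logging::INFO));

  // No deviceHint argument any more; hook and timing keep their defaults.
  return pipeline.run(features, spacepointIDs);
}
```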
@@ -13,6 +13,8 @@
 
 #include <memory>
 
+#include <torch/script.h>
+
 namespace Ort {
 class Env;
 class Session;
@@ -32,16 +34,18 @@ class OnnxEdgeClassifier final : public Acts::EdgeClassificationBase {
   ~OnnxEdgeClassifier();
 
   std::tuple<std::any, std::any, std::any> operator()(
-      std::any nodes, std::any edges, int deviceHint = -1) override;
+      std::any nodes, std::any edges,
+      torch::Device device = torch::Device(torch::kCPU)) override;
 
   Config config() const { return m_cfg; }
+  torch::Device device() const override { return m_device; };
 
  private:
   std::unique_ptr<const Acts::Logger> m_logger;
   const auto &logger() const { return *m_logger; }
 
   Config m_cfg;
-
+  torch::Device m_device;
   std::unique_ptr<Ort::Env> m_env;
   std::unique_ptr<Ort::Session> m_model;
 
@@ -13,6 +13,8 @@
 
 #include <memory>
 
+#include <torch/script.h>
+
 namespace Ort {
 class Env;
 class Session;
@@ -34,11 +36,12 @@ class OnnxMetricLearning final : public Acts::GraphConstructionBase {
   OnnxMetricLearning(const Config& cfg, std::unique_ptr<const Logger> logger);
   ~OnnxMetricLearning();
 
-  std::tuple<std::any, std::any> operator()(std::vector<float>& inputValues,
-                                            std::size_t numNodes,
-                                            int deviceHint = -1) override;
+  std::tuple<std::any, std::any> operator()(
+      std::vector<float>& inputValues, std::size_t numNodes,
+      torch::Device device = torch::Device(torch::kCPU)) override;
 
   Config config() const { return m_cfg; }
+  torch::Device device() const override { return m_device; };
 
  private:
   void buildEdgesWrapper(std::vector<float>& embedFeatures,
@@ -49,6 +52,7 @@ class OnnxMetricLearning final : public Acts::GraphConstructionBase {
   const auto& logger() const { return *m_logger; }
 
   Config m_cfg;
+  torch::Device m_device;
   std::unique_ptr<Ort::Env> m_env;
   std::unique_ptr<Ort::Session> m_model;
 };
20 changes: 14 additions & 6 deletions Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/Stages.hpp
@@ -11,6 +11,8 @@
 #include <any>
 #include <vector>
 
+#include <torch/script.h>
+
 namespace Acts {
 
 // TODO maybe replace std::any with some kind of variant<unique_ptr<torch>,
@@ -25,12 +27,14 @@ class GraphConstructionBase {
   /// @param inputValues Flattened input data
   /// @param numNodes Number of nodes. inputValues.size() / numNodes
   /// then gives the number of features
-  /// @param deviceHint Which GPU to pick. Not relevant for CPU-only builds
+  /// @param device Which GPU device to pick. Not relevant for CPU-only builds
   ///
   /// @return (node_tensor, edge_tensore)
   virtual std::tuple<std::any, std::any> operator()(
       std::vector<float> &inputValues, std::size_t numNodes,
-      int deviceHint = -1) = 0;
+      torch::Device device) = 0;
+
+  virtual torch::Device device() const = 0;
 
   virtual ~GraphConstructionBase() = default;
 };
@@ -41,11 +45,13 @@ class EdgeClassificationBase {
   ///
   /// @param nodes Node tensor with shape (n_nodes, n_node_features)
   /// @param edges Edge-index tensor with shape (2, n_edges)
-  /// @param deviceHint Which GPU to pick. Not relevant for CPU-only builds
+  /// @param device Which GPU device to pick. Not relevant for CPU-only builds
   ///
   /// @return (node_tensor, edge_tensor, score_tensor)
   virtual std::tuple<std::any, std::any, std::any> operator()(
-      std::any nodes, std::any edges, int deviceHint = -1) = 0;
+      std::any nodes, std::any edges, torch::Device device) = 0;
+
+  virtual torch::Device device() const = 0;
 
   virtual ~EdgeClassificationBase() = default;
 };
@@ -58,12 +64,14 @@ class TrackBuildingBase {
   /// @param edges Edge-index tensor with shape (2, n_edges)
   /// @param edgeWeights Edge-weights of the previous edge classification phase
   /// @param spacepointIDs IDs of the nodes (must have size=n_nodes)
-  /// @param deviceHint Which GPU to pick. Not relevant for CPU-only builds
+  /// @param device Which GPU device to pick. Not relevant for CPU-only builds
   ///
   /// @return tracks (as vectors of node-IDs)
   virtual std::vector<std::vector<int>> operator()(
       std::any nodes, std::any edges, std::any edgeWeights,
-      std::vector<int> &spacepointIDs, int deviceHint = -1) = 0;
+      std::vector<int> &spacepointIDs, torch::Device device) = 0;
+
+  virtual torch::Device device() const = 0;
 
   virtual ~TrackBuildingBase() = default;
 };
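Since `Stages.hpp` defines the contract every stage now has to satisfy, the following is a minimal, purely illustrative sketch of a CPU-only graph-construction stage written against the new interface. The class name and the brute-force all-to-all edge building are invented for the example and are not part of this PR.

```cpp
#include <Acts/Plugins/ExaTrkX/Stages.hpp>

#include <any>
#include <cstdint>
#include <vector>

#include <torch/torch.h>

// Hypothetical example stage: wraps the input features as a node tensor and
// connects every node to every other node.
class TrivialGraphConstructor final : public Acts::GraphConstructionBase {
 public:
  std::tuple<std::any, std::any> operator()(std::vector<float> &inputValues,
                                            std::size_t numNodes,
                                            torch::Device device) override {
    const auto numFeatures = inputValues.size() / numNodes;

    // View the flattened input as a (numNodes, numFeatures) float tensor and
    // move it to the device the pipeline handed us.
    auto nodes = torch::from_blob(inputValues.data(),
                                  {static_cast<std::int64_t>(numNodes),
                                   static_cast<std::int64_t>(numFeatures)},
                                  torch::kFloat32)
                     .clone()
                     .to(device);

    // Illustration only: a fully connected edge list (2, n_edges).
    std::vector<std::int64_t> src, dst;
    for (std::size_t i = 0; i < numNodes; ++i) {
      for (std::size_t j = 0; j < numNodes; ++j) {
        if (i != j) {
          src.push_back(i);
          dst.push_back(j);
        }
      }
    }
    auto edges =
        torch::stack({torch::tensor(src), torch::tensor(dst)}).to(device);

    return {nodes, edges};
  }

  // Each stage now reports the device it wants to run on.
  torch::Device device() const override { return m_device; }

 private:
  torch::Device m_device{torch::kCPU};
};
```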
@@ -13,6 +13,8 @@
 
 #include <memory>
 
+#include <torch/script.h>
+
 namespace torch::jit {
 class Module;
 }
@@ -31,22 +33,25 @@ class TorchEdgeClassifier final : public Acts::EdgeClassificationBase {
     float cut = 0.21;
     int nChunks = 1;  // NOTE for GNN use 1
     bool undirected = false;
+    int deviceID = 0;
   };
 
   TorchEdgeClassifier(const Config &cfg, std::unique_ptr<const Logger> logger);
   ~TorchEdgeClassifier();
 
   std::tuple<std::any, std::any, std::any> operator()(
-      std::any nodes, std::any edges, int deviceHint = -1) override;
+      std::any nodes, std::any edges, torch::Device device) override;
 
   Config config() const { return m_cfg; }
+  torch::Device device() const override { return m_device; };
 
  private:
   std::unique_ptr<const Acts::Logger> m_logger;
   const auto &logger() const { return *m_logger; }
 
   Config m_cfg;
   c10::DeviceType m_deviceType;
+  torch::Device m_device;
   std::unique_ptr<torch::jit::Module> m_model;
 };
 
@@ -32,23 +32,26 @@ class TorchMetricLearning final : public Acts::GraphConstructionBase {
     float rVal = 1.6;
     int knnVal = 500;
     bool shuffleDirections = false;
+    int deviceID = 0;  // default is the first GPU if available
   };
 
   TorchMetricLearning(const Config &cfg, std::unique_ptr<const Logger> logger);
   ~TorchMetricLearning();
 
   std::tuple<std::any, std::any> operator()(std::vector<float> &inputValues,
                                             std::size_t numNodes,
-                                            int deviceHint = -1) override;
+                                            torch::Device device) override;
 
   Config config() const { return m_cfg; }
+  torch::Device device() const override { return m_device; };
 
  private:
   std::unique_ptr<const Acts::Logger> m_logger;
   const auto &logger() const { return *m_logger; }
 
   Config m_cfg;
   c10::DeviceType m_deviceType;
+  torch::Device m_device;
   std::unique_ptr<torch::jit::Module> m_model;
 };
 
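The new `deviceID` option is what replaces the old per-call `deviceHint`. The Torch stage constructor bodies are not part of the hunks shown here, so the following is only a sketch of how `Config::deviceID` could be translated into the `torch::Device` that `device()` reports, assuming a CUDA-enabled build with a CPU fallback; the actual PR code may differ in detail.

```cpp
#include <cstddef>

#include <torch/torch.h>

// Hypothetical helper: pick the configured CUDA device if it exists,
// otherwise fall back to the CPU.
torch::Device selectDevice(int deviceID) {
  if (torch::cuda::is_available() && deviceID >= 0 &&
      static_cast<std::size_t>(deviceID) < torch::cuda::device_count()) {
    return torch::Device(torch::kCUDA,
                         static_cast<c10::DeviceIndex>(deviceID));
  }
  return torch::Device(torch::kCPU);
}
```

A stage constructor could then initialize `m_device = selectDevice(m_cfg.deviceID)` and log the choice, which is what lets `ExaTrkXPipeline::run` simply ask each stage for its device.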
2 changes: 1 addition & 1 deletion Plugins/ExaTrkX/src/BoostTrackBuilding.cpp
@@ -47,7 +47,7 @@ namespace Acts {
 
 std::vector<std::vector<int>> BoostTrackBuilding::operator()(
     std::any nodes, std::any edges, std::any weights,
-    std::vector<int>& spacepointIDs, int) {
+    std::vector<int>& spacepointIDs, torch::Device) {
   ACTS_DEBUG("Start track building");
   const auto edgeTensor = std::any_cast<torch::Tensor>(edges).to(torch::kCPU);
   const auto edgeWeightTensor =
2 changes: 1 addition & 1 deletion Plugins/ExaTrkX/src/CugraphTrackBuilding.cpp
@@ -18,7 +18,7 @@ namespace Acts {
 
 std::vector<std::vector<int>> CugraphTrackBuilding::operator()(
     std::any, std::any edges, std::any edge_weights,
-    std::vector<int> &spacepointIDs, int) {
+    std::vector<int> &spacepointIDs, torch::Device) {
   auto numSpacepoints = spacepointIDs.size();
   auto edgesAfterFiltering = std::any_cast<std::vector<int64_t>>(edges);
   auto numEdgesAfterF = edgesAfterFiltering.size() / 2;
20 changes: 11 additions & 9 deletions Plugins/ExaTrkX/src/ExaTrkXPipeline.cpp
@@ -34,10 +34,10 @@ ExaTrkXPipeline::ExaTrkXPipeline(
 
 std::vector<std::vector<int>> ExaTrkXPipeline::run(
     std::vector<float> &features, std::vector<int> &spacepointIDs,
-    int deviceHint, const ExaTrkXHook &hook, ExaTrkXTiming *timing) const {
+    const ExaTrkXHook &hook, ExaTrkXTiming *timing) const {
   auto t0 = std::chrono::high_resolution_clock::now();
-  auto [nodes, edges] =
-      (*m_graphConstructor)(features, spacepointIDs.size(), deviceHint);
+  auto [nodes, edges] = (*m_graphConstructor)(features, spacepointIDs.size(),
+                                              m_graphConstructor->device());
   auto t1 = std::chrono::high_resolution_clock::now();
 
   if (timing != nullptr) {
@@ -47,12 +47,14 @@ std::vector<std::vector<int>> ExaTrkXPipeline::run(
   hook(nodes, edges, {});
 
   std::any edge_weights;
-  timing->classifierTimes.clear();
+  if (timing != nullptr) {
+    timing->classifierTimes.clear();
+  }
 
   for (auto edgeClassifier : m_edgeClassifiers) {
     t0 = std::chrono::high_resolution_clock::now();
-    auto [newNodes, newEdges, newWeights] =
-        (*edgeClassifier)(std::move(nodes), std::move(edges), deviceHint);
+    auto [newNodes, newEdges, newWeights] = (*edgeClassifier)(
+        std::move(nodes), std::move(edges), edgeClassifier->device());
     t1 = std::chrono::high_resolution_clock::now();
 
     if (timing != nullptr) {
@@ -67,9 +69,9 @@ std::vector<std::vector<int>> ExaTrkXPipeline::run(
   }
 
   t0 = std::chrono::high_resolution_clock::now();
-  auto res =
-      (*m_trackBuilder)(std::move(nodes), std::move(edges),
-                        std::move(edge_weights), spacepointIDs, deviceHint);
+  auto res = (*m_trackBuilder)(std::move(nodes), std::move(edges),
+                               std::move(edge_weights), spacepointIDs,
+                               m_trackBuilder->device());
   auto t1 = std::chrono::high_resolution_clock::now();
 
   if (timing != nullptr) {
6 changes: 4 additions & 2 deletions Plugins/ExaTrkX/src/OnnxEdgeClassifier.cpp
@@ -19,7 +19,9 @@ namespace Acts {
 
 OnnxEdgeClassifier::OnnxEdgeClassifier(const Config &cfg,
                                        std::unique_ptr<const Logger> logger)
-    : m_logger(std::move(logger)), m_cfg(cfg) {
+    : m_logger(std::move(logger)),
+      m_cfg(cfg),
+      m_device(torch::Device(torch::kCPU)) {
   m_env = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING,
                                      "ExaTrkX - edge classifier");
 
@@ -44,7 +46,7 @@ OnnxEdgeClassifier::OnnxEdgeClassifier(const Config &cfg,
 OnnxEdgeClassifier::~OnnxEdgeClassifier() {}
 
 std::tuple<std::any, std::any, std::any> OnnxEdgeClassifier::operator()(
-    std::any inputNodes, std::any inputEdges, int) {
+    std::any inputNodes, std::any inputEdges, torch::Device) {
   Ort::AllocatorWithDefaultOptions allocator;
   auto memoryInfo = Ort::MemoryInfo::CreateCpu(
       OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault);
6 changes: 4 additions & 2 deletions Plugins/ExaTrkX/src/OnnxMetricLearning.cpp
@@ -19,7 +19,9 @@ namespace Acts {
 
 OnnxMetricLearning::OnnxMetricLearning(const Config& cfg,
                                        std::unique_ptr<const Logger> logger)
-    : m_logger(std::move(logger)), m_cfg(cfg) {
+    : m_logger(std::move(logger)),
+      m_cfg(cfg),
+      m_device(torch::Device(torch::kCPU)) {
   m_env = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING,
                                      "ExaTrkX - metric learning");
 
@@ -57,7 +59,7 @@ void OnnxMetricLearning::buildEdgesWrapper(std::vector<float>& embedFeatures,
 }
 
 std::tuple<std::any, std::any> OnnxMetricLearning::operator()(
-    std::vector<float>& inputValues, std::size_t, int) {
+    std::vector<float>& inputValues, std::size_t, torch::Device) {
   Ort::AllocatorWithDefaultOptions allocator;
   auto memoryInfo = Ort::MemoryInfo::CreateCpu(
       OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault);