From 93a8454522ed39cad4257396092e47c12117dbfb Mon Sep 17 00:00:00 2001 From: hrzhao76 Date: Tue, 23 Jan 2024 16:22:35 -0800 Subject: [PATCH 1/8] feat: ExaTrkX ensure the models are loaded to the same GPU as the input --- .../src/TrackFindingAlgorithmExaTrkX.cpp | 4 +--- Examples/Python/src/ExaTrkXTrackFinding.cpp | 4 +++- .../Plugins/ExaTrkX/BoostTrackBuilding.hpp | 8 ++++++-- .../Plugins/ExaTrkX/CugraphTrackBuilding.hpp | 8 ++++++-- .../Acts/Plugins/ExaTrkX/ExaTrkXPipeline.hpp | 1 - .../Plugins/ExaTrkX/OnnxEdgeClassifier.hpp | 4 +++- .../Plugins/ExaTrkX/OnnxMetricLearning.hpp | 4 +++- .../include/Acts/Plugins/ExaTrkX/Stages.hpp | 20 +++++++++++++------ .../Plugins/ExaTrkX/TorchEdgeClassifier.hpp | 7 ++++++- .../Plugins/ExaTrkX/TorchMetricLearning.hpp | 5 ++++- Plugins/ExaTrkX/src/BoostTrackBuilding.cpp | 2 +- Plugins/ExaTrkX/src/CugraphTrackBuilding.cpp | 2 +- Plugins/ExaTrkX/src/ExaTrkXPipeline.cpp | 20 ++++++++++--------- Plugins/ExaTrkX/src/OnnxMetricLearning.cpp | 2 +- Plugins/ExaTrkX/src/TorchEdgeClassifier.cpp | 17 ++++++++++++---- Plugins/ExaTrkX/src/TorchMetricLearning.cpp | 18 +++++++++++++---- 16 files changed, 87 insertions(+), 39 deletions(-) diff --git a/Examples/Algorithms/TrackFindingExaTrkX/src/TrackFindingAlgorithmExaTrkX.cpp b/Examples/Algorithms/TrackFindingExaTrkX/src/TrackFindingAlgorithmExaTrkX.cpp index f87d9714b6f..6eff9246ffc 100644 --- a/Examples/Algorithms/TrackFindingExaTrkX/src/TrackFindingAlgorithmExaTrkX.cpp +++ b/Examples/Algorithms/TrackFindingExaTrkX/src/TrackFindingAlgorithmExaTrkX.cpp @@ -250,12 +250,10 @@ ActsExamples::ProcessCode ActsExamples::TrackFindingAlgorithmExaTrkX::execute( // Run the pipeline const auto trackCandidates = [&]() { - const int deviceHint = -1; std::lock_guard lock(m_mutex); Acts::ExaTrkXTiming timing; - auto res = - m_pipeline.run(features, spacepointIDs, deviceHint, *hook, &timing); + auto res = m_pipeline.run(features, spacepointIDs, *hook, &timing);
m_timing.graphBuildingTime(timing.graphBuildingTime.count()); diff --git a/Examples/Python/src/ExaTrkXTrackFinding.cpp b/Examples/Python/src/ExaTrkXTrackFinding.cpp index c167ce54f67..bef42f04e9f 100644 --- a/Examples/Python/src/ExaTrkXTrackFinding.cpp +++ b/Examples/Python/src/ExaTrkXTrackFinding.cpp @@ -70,6 +70,7 @@ void addExaTrkXTrackFinding(Context &ctx) { ACTS_PYTHON_MEMBER(embeddingDim); ACTS_PYTHON_MEMBER(rVal); ACTS_PYTHON_MEMBER(knnVal); + ACTS_PYTHON_MEMBER(deviceID); ACTS_PYTHON_STRUCT_END(); } { @@ -93,6 +94,7 @@ void addExaTrkXTrackFinding(Context &ctx) { ACTS_PYTHON_MEMBER(cut); ACTS_PYTHON_MEMBER(nChunks); ACTS_PYTHON_MEMBER(undirected); + ACTS_PYTHON_MEMBER(deviceID); ACTS_PYTHON_STRUCT_END(); } { @@ -208,7 +210,7 @@ void addExaTrkXTrackFinding(Context &ctx) { py::arg("graphConstructor"), py::arg("edgeClassifiers"), py::arg("trackBuilder"), py::arg("level")) .def("run", &ExaTrkXPipeline::run, py::arg("features"), - py::arg("spacepoints"), py::arg("deviceHint") = -1, + py::arg("spacepoints"), py::arg("hook") = Acts::ExaTrkXHook{}, py::arg("timing") = nullptr); } diff --git a/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/BoostTrackBuilding.hpp b/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/BoostTrackBuilding.hpp index 6311f4e91cc..a82203c5d7c 100644 --- a/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/BoostTrackBuilding.hpp +++ b/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/BoostTrackBuilding.hpp @@ -13,20 +13,24 @@ #include +#include + namespace Acts { class BoostTrackBuilding final : public Acts::TrackBuildingBase { public: BoostTrackBuilding(std::unique_ptr logger) - : m_logger(std::move(logger)) {} + : m_logger(std::move(logger)), m_device(torch::Device(torch::kCPU)) {} std::vector> operator()(std::any nodes, std::any edges, std::any edge_weights, std::vector &spacepointIDs, - int deviceHint = -1) override; + torch::Device device) override; + torch::Device device() const override { return m_device; }; private: std::unique_ptr m_logger; + 
torch::Device m_device; const auto &logger() const { return *m_logger; } }; diff --git a/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/CugraphTrackBuilding.hpp b/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/CugraphTrackBuilding.hpp index 76779016549..40a29bcb4eb 100644 --- a/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/CugraphTrackBuilding.hpp +++ b/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/CugraphTrackBuilding.hpp @@ -13,20 +13,24 @@ #include +#include + namespace Acts { class CugraphTrackBuilding final : public Acts::TrackBuildingBase { public: CugraphTrackBuilding(std::unique_ptr logger) - : m_logger(std::move(logger)) {} + : m_logger(std::move(logger)), m_device(torch::Device(torch::kCPU)) {} std::vector> operator()(std::any nodes, std::any edges, std::any edge_weights, std::vector &spacepointIDs, - int deviceHint = -1) override; + torch::Device device) override; + torch::Device device() const override { return m_device; }; private: std::unique_ptr m_logger; + torch::Device m_device; const auto &logger() const { return *m_logger; } }; diff --git a/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/ExaTrkXPipeline.hpp b/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/ExaTrkXPipeline.hpp index 9a830bdf073..9196ccf9421 100644 --- a/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/ExaTrkXPipeline.hpp +++ b/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/ExaTrkXPipeline.hpp @@ -46,7 +46,6 @@ class ExaTrkXPipeline { std::vector> run(std::vector &features, std::vector &spacepointIDs, - int deviceHint = -1, const ExaTrkXHook &hook = {}, ExaTrkXTiming *timing = nullptr) const; diff --git a/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/OnnxEdgeClassifier.hpp b/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/OnnxEdgeClassifier.hpp index cb1d84f2713..1827698e85a 100644 --- a/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/OnnxEdgeClassifier.hpp +++ b/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/OnnxEdgeClassifier.hpp @@ -13,6 +13,8 @@ #include +#include + namespace Ort { class Env; class 
Session; @@ -32,7 +34,7 @@ class OnnxEdgeClassifier final : public Acts::EdgeClassificationBase { ~OnnxEdgeClassifier(); std::tuple operator()( - std::any nodes, std::any edges, int deviceHint = -1) override; + std::any nodes, std::any edges, torch::Device device) override; Config config() const { return m_cfg; } diff --git a/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/OnnxMetricLearning.hpp b/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/OnnxMetricLearning.hpp index 26af6fb3619..501c29acc72 100644 --- a/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/OnnxMetricLearning.hpp +++ b/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/OnnxMetricLearning.hpp @@ -13,6 +13,8 @@ #include +#include + namespace Ort { class Env; class Session; @@ -36,7 +38,7 @@ class OnnxMetricLearning final : public Acts::GraphConstructionBase { std::tuple operator()(std::vector& inputValues, std::size_t numNodes, - int deviceHint = -1) override; + torch::Device device) override; Config config() const { return m_cfg; } diff --git a/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/Stages.hpp b/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/Stages.hpp index 048f56bfd3c..f2ba7aeb02e 100644 --- a/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/Stages.hpp +++ b/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/Stages.hpp @@ -11,6 +11,8 @@ #include #include +#include + namespace Acts { // TODO maybe replace std::any with some kind of variant, @@ -25,12 +27,14 @@ class GraphConstructionBase { /// @param inputValues Flattened input data /// @param numNodes Number of nodes. inputValues.size() / numNodes /// then gives the number of features - /// @param deviceHint Which GPU to pick. Not relevant for CPU-only builds + /// @param device Which GPU device to pick. 
Not relevant for CPU-only builds /// /// @return (node_tensor, edge_tensore) virtual std::tuple operator()( std::vector &inputValues, std::size_t numNodes, - int deviceHint = -1) = 0; + torch::Device device) = 0; + + virtual torch::Device device() const = 0; virtual ~GraphConstructionBase() = default; }; @@ -41,11 +45,13 @@ class EdgeClassificationBase { /// /// @param nodes Node tensor with shape (n_nodes, n_node_features) /// @param edges Edge-index tensor with shape (2, n_edges) - /// @param deviceHint Which GPU to pick. Not relevant for CPU-only builds + /// @param device Which GPU device to pick. Not relevant for CPU-only builds /// /// @return (node_tensor, edge_tensor, score_tensor) virtual std::tuple operator()( - std::any nodes, std::any edges, int deviceHint = -1) = 0; + std::any nodes, std::any edges, torch::Device device) = 0; + + virtual torch::Device device() const = 0; virtual ~EdgeClassificationBase() = default; }; @@ -58,12 +64,14 @@ class TrackBuildingBase { /// @param edges Edge-index tensor with shape (2, n_edges) /// @param edgeWeights Edge-weights of the previous edge classification phase /// @param spacepointIDs IDs of the nodes (must have size=n_nodes) - /// @param deviceHint Which GPU to pick. Not relevant for CPU-only builds + /// @param device Which GPU device to pick. 
Not relevant for CPU-only builds /// /// @return tracks (as vectors of node-IDs) virtual std::vector> operator()( std::any nodes, std::any edges, std::any edgeWeights, - std::vector &spacepointIDs, int deviceHint = -1) = 0; + std::vector &spacepointIDs, torch::Device device) = 0; + + virtual torch::Device device() const = 0; virtual ~TrackBuildingBase() = default; }; diff --git a/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/TorchEdgeClassifier.hpp b/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/TorchEdgeClassifier.hpp index bfa7c9054f4..fa8dcc4f957 100644 --- a/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/TorchEdgeClassifier.hpp +++ b/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/TorchEdgeClassifier.hpp @@ -13,6 +13,8 @@ #include +#include + namespace torch::jit { class Module; } @@ -31,15 +33,17 @@ class TorchEdgeClassifier final : public Acts::EdgeClassificationBase { float cut = 0.21; int nChunks = 1; // NOTE for GNN use 1 bool undirected = false; + int deviceID = 0; }; TorchEdgeClassifier(const Config &cfg, std::unique_ptr logger); ~TorchEdgeClassifier(); std::tuple operator()( - std::any nodes, std::any edges, int deviceHint = -1) override; + std::any nodes, std::any edges, torch::Device device) override; Config config() const { return m_cfg; } + torch::Device device() const override { return m_device; }; private: std::unique_ptr m_logger; @@ -47,6 +51,7 @@ class TorchEdgeClassifier final : public Acts::EdgeClassificationBase { Config m_cfg; c10::DeviceType m_deviceType; + torch::Device m_device; std::unique_ptr m_model; }; diff --git a/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/TorchMetricLearning.hpp b/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/TorchMetricLearning.hpp index 886687787cf..44a9f121ecf 100644 --- a/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/TorchMetricLearning.hpp +++ b/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/TorchMetricLearning.hpp @@ -32,6 +32,7 @@ class TorchMetricLearning final : public Acts::GraphConstructionBase { float rVal = 
1.6; int knnVal = 500; bool shuffleDirections = false; + int deviceID = 0; }; TorchMetricLearning(const Config &cfg, std::unique_ptr logger); @@ -39,9 +40,10 @@ class TorchMetricLearning final : public Acts::GraphConstructionBase { std::tuple operator()(std::vector &inputValues, std::size_t numNodes, - int deviceHint = -1) override; + torch::Device device) override; Config config() const { return m_cfg; } + torch::Device device() const override { return m_device; }; private: std::unique_ptr m_logger; @@ -49,6 +51,7 @@ class TorchMetricLearning final : public Acts::GraphConstructionBase { Config m_cfg; c10::DeviceType m_deviceType; + torch::Device m_device; std::unique_ptr m_model; }; diff --git a/Plugins/ExaTrkX/src/BoostTrackBuilding.cpp b/Plugins/ExaTrkX/src/BoostTrackBuilding.cpp index 01e80930960..d67a8321fe7 100644 --- a/Plugins/ExaTrkX/src/BoostTrackBuilding.cpp +++ b/Plugins/ExaTrkX/src/BoostTrackBuilding.cpp @@ -47,7 +47,7 @@ namespace Acts { std::vector> BoostTrackBuilding::operator()( std::any nodes, std::any edges, std::any weights, - std::vector& spacepointIDs, int) { + std::vector& spacepointIDs, torch::Device) { ACTS_DEBUG("Start track building"); const auto edgeTensor = std::any_cast(edges).to(torch::kCPU); const auto edgeWeightTensor = diff --git a/Plugins/ExaTrkX/src/CugraphTrackBuilding.cpp b/Plugins/ExaTrkX/src/CugraphTrackBuilding.cpp index 062701341ba..a0eaafd2d81 100644 --- a/Plugins/ExaTrkX/src/CugraphTrackBuilding.cpp +++ b/Plugins/ExaTrkX/src/CugraphTrackBuilding.cpp @@ -18,7 +18,7 @@ namespace Acts { std::vector> CugraphTrackBuilding::operator()( std::any, std::any edges, std::any edge_weights, - std::vector &spacepointIDs, int) { + std::vector &spacepointIDs, torch::Device) { auto numSpacepoints = spacepointIDs.size(); auto edgesAfterFiltering = std::any_cast>(edges); auto numEdgesAfterF = edgesAfterFiltering.size() / 2; diff --git a/Plugins/ExaTrkX/src/ExaTrkXPipeline.cpp b/Plugins/ExaTrkX/src/ExaTrkXPipeline.cpp index 
3f8e88150f0..01b9af2b59e 100644 --- a/Plugins/ExaTrkX/src/ExaTrkXPipeline.cpp +++ b/Plugins/ExaTrkX/src/ExaTrkXPipeline.cpp @@ -34,10 +34,10 @@ ExaTrkXPipeline::ExaTrkXPipeline( std::vector> ExaTrkXPipeline::run( std::vector &features, std::vector &spacepointIDs, - int deviceHint, const ExaTrkXHook &hook, ExaTrkXTiming *timing) const { + const ExaTrkXHook &hook, ExaTrkXTiming *timing) const { auto t0 = std::chrono::high_resolution_clock::now(); - auto [nodes, edges] = - (*m_graphConstructor)(features, spacepointIDs.size(), deviceHint); + auto [nodes, edges] = (*m_graphConstructor)(features, spacepointIDs.size(), + m_graphConstructor->device()); auto t1 = std::chrono::high_resolution_clock::now(); if (timing != nullptr) { @@ -47,12 +47,14 @@ std::vector> ExaTrkXPipeline::run( hook(nodes, edges, {}); std::any edge_weights; - timing->classifierTimes.clear(); + if (timing != nullptr) { + timing->classifierTimes.clear(); + } for (auto edgeClassifier : m_edgeClassifiers) { t0 = std::chrono::high_resolution_clock::now(); - auto [newNodes, newEdges, newWeights] = - (*edgeClassifier)(std::move(nodes), std::move(edges), deviceHint); + auto [newNodes, newEdges, newWeights] = (*edgeClassifier)( + std::move(nodes), std::move(edges), edgeClassifier->device()); t1 = std::chrono::high_resolution_clock::now(); if (timing != nullptr) { @@ -67,9 +69,9 @@ std::vector> ExaTrkXPipeline::run( } t0 = std::chrono::high_resolution_clock::now(); - auto res = - (*m_trackBuilder)(std::move(nodes), std::move(edges), - std::move(edge_weights), spacepointIDs, deviceHint); + auto res = (*m_trackBuilder)(std::move(nodes), std::move(edges), + std::move(edge_weights), spacepointIDs, + m_trackBuilder->device()); t1 = std::chrono::high_resolution_clock::now(); if (timing != nullptr) { diff --git a/Plugins/ExaTrkX/src/OnnxMetricLearning.cpp b/Plugins/ExaTrkX/src/OnnxMetricLearning.cpp index e3b8ebb722a..bc6f28e3f7e 100644 --- a/Plugins/ExaTrkX/src/OnnxMetricLearning.cpp +++ 
b/Plugins/ExaTrkX/src/OnnxMetricLearning.cpp @@ -57,7 +57,7 @@ void OnnxMetricLearning::buildEdgesWrapper(std::vector& embedFeatures, } std::tuple OnnxMetricLearning::operator()( - std::vector& inputValues, std::size_t, int) { + std::vector& inputValues, std::size_t, torch::Device) { Ort::AllocatorWithDefaultOptions allocator; auto memoryInfo = Ort::MemoryInfo::CreateCpu( OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault); diff --git a/Plugins/ExaTrkX/src/TorchEdgeClassifier.cpp b/Plugins/ExaTrkX/src/TorchEdgeClassifier.cpp index d398fef998a..fe7068c87fb 100644 --- a/Plugins/ExaTrkX/src/TorchEdgeClassifier.cpp +++ b/Plugins/ExaTrkX/src/TorchEdgeClassifier.cpp @@ -19,9 +19,19 @@ namespace Acts { TorchEdgeClassifier::TorchEdgeClassifier(const Config& cfg, std::unique_ptr _logger) - : m_logger(std::move(_logger)), m_cfg(cfg) { + : m_logger(std::move(_logger)), + m_cfg(cfg), + m_device(torch::Device(torch::kCPU)) { c10::InferenceMode guard(true); m_deviceType = torch::cuda::is_available() ? torch::kCUDA : torch::kCPU; + if (m_deviceType == torch::kCUDA && cfg.deviceID >= 0 && + static_cast(cfg.deviceID) < torch::cuda::device_count()) { + ACTS_DEBUG("GPU device " << cfg.deviceID << " is being used."); + m_device = torch::Device(torch::kCUDA, cfg.deviceID); + } else { + ACTS_ERROR("GPU device " << cfg.deviceID + << " not available. Using CPU instead."); + } ACTS_DEBUG("Using torch version " << TORCH_VERSION_MAJOR << "." << TORCH_VERSION_MINOR << "." 
<< TORCH_VERSION_PATCH); @@ -33,7 +43,7 @@ TorchEdgeClassifier::TorchEdgeClassifier(const Config& cfg, try { m_model = std::make_unique(); - *m_model = torch::jit::load(m_cfg.modelPath.c_str(), m_deviceType); + *m_model = torch::jit::load(m_cfg.modelPath.c_str(), m_device); m_model->eval(); } catch (const c10::Error& e) { throw std::invalid_argument("Failed to load models: " + e.msg()); @@ -43,10 +53,9 @@ TorchEdgeClassifier::TorchEdgeClassifier(const Config& cfg, TorchEdgeClassifier::~TorchEdgeClassifier() {} std::tuple TorchEdgeClassifier::operator()( - std::any inputNodes, std::any inputEdges, int deviceHint) { + std::any inputNodes, std::any inputEdges, torch::Device device) { ACTS_DEBUG("Start edge classification"); c10::InferenceMode guard(true); - const torch::Device device(m_deviceType, deviceHint); auto nodes = std::any_cast(inputNodes).to(device); auto edgeList = std::any_cast(inputEdges).to(device); diff --git a/Plugins/ExaTrkX/src/TorchMetricLearning.cpp b/Plugins/ExaTrkX/src/TorchMetricLearning.cpp index 6ffda448587..1be8c0b1f86 100644 --- a/Plugins/ExaTrkX/src/TorchMetricLearning.cpp +++ b/Plugins/ExaTrkX/src/TorchMetricLearning.cpp @@ -22,9 +22,19 @@ namespace Acts { TorchMetricLearning::TorchMetricLearning(const Config &cfg, std::unique_ptr _logger) - : m_logger(std::move(_logger)), m_cfg(cfg) { + : m_logger(std::move(_logger)), + m_cfg(cfg), + m_device(torch::Device(torch::kCPU)) { c10::InferenceMode guard(true); m_deviceType = torch::cuda::is_available() ? torch::kCUDA : torch::kCPU; + if (m_deviceType == torch::kCUDA && cfg.deviceID >= 0 && + static_cast(cfg.deviceID) < torch::cuda::device_count()) { + ACTS_DEBUG("GPU device " << cfg.deviceID << " is being used."); + m_device = torch::Device(torch::kCUDA, cfg.deviceID); + } else { + ACTS_ERROR("GPU device " << cfg.deviceID + << " not available. Using CPU instead."); + } ACTS_DEBUG("Using torch version " << TORCH_VERSION_MAJOR << "." << TORCH_VERSION_MINOR << "." 
<< TORCH_VERSION_PATCH); @@ -36,7 +46,7 @@ TorchMetricLearning::TorchMetricLearning(const Config &cfg, try { m_model = std::make_unique(); - *m_model = torch::jit::load(m_cfg.modelPath, m_deviceType); + *m_model = torch::jit::load(m_cfg.modelPath, m_device); m_model->eval(); } catch (const c10::Error &e) { throw std::invalid_argument("Failed to load models: " + e.msg()); @@ -46,10 +56,10 @@ TorchMetricLearning::TorchMetricLearning(const Config &cfg, TorchMetricLearning::~TorchMetricLearning() {} std::tuple TorchMetricLearning::operator()( - std::vector &inputValues, std::size_t numNodes, int deviceHint) { + std::vector &inputValues, std::size_t numNodes, + torch::Device device) { ACTS_DEBUG("Start graph construction"); c10::InferenceMode guard(true); - const torch::Device device(m_deviceType, deviceHint); const int64_t numAllFeatures = inputValues.size() / numNodes; From 55c01e033251ea57dda96bc3c1d32365a09d12a8 Mon Sep 17 00:00:00 2001 From: hrzhao76 Date: Tue, 23 Jan 2024 17:24:37 -0800 Subject: [PATCH 2/8] add device guard; ci format --- Examples/Python/src/ExaTrkXTrackFinding.cpp | 3 +-- .../include/Acts/Plugins/ExaTrkX/TorchMetricLearning.hpp | 2 +- Plugins/ExaTrkX/src/TorchEdgeClassifier.cpp | 4 +++- Plugins/ExaTrkX/src/TorchMetricLearning.cpp | 6 ++++-- 4 files changed, 9 insertions(+), 6 deletions(-) diff --git a/Examples/Python/src/ExaTrkXTrackFinding.cpp b/Examples/Python/src/ExaTrkXTrackFinding.cpp index bef42f04e9f..7c2685b9bd9 100644 --- a/Examples/Python/src/ExaTrkXTrackFinding.cpp +++ b/Examples/Python/src/ExaTrkXTrackFinding.cpp @@ -210,8 +210,7 @@ void addExaTrkXTrackFinding(Context &ctx) { py::arg("graphConstructor"), py::arg("edgeClassifiers"), py::arg("trackBuilder"), py::arg("level")) .def("run", &ExaTrkXPipeline::run, py::arg("features"), - py::arg("spacepoints"), - py::arg("hook") = Acts::ExaTrkXHook{}, + py::arg("spacepoints"), py::arg("hook") = Acts::ExaTrkXHook{}, py::arg("timing") = nullptr); } diff --git 
a/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/TorchMetricLearning.hpp b/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/TorchMetricLearning.hpp index 44a9f121ecf..1acef0fda69 100644 --- a/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/TorchMetricLearning.hpp +++ b/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/TorchMetricLearning.hpp @@ -32,7 +32,7 @@ class TorchMetricLearning final : public Acts::GraphConstructionBase { float rVal = 1.6; int knnVal = 500; bool shuffleDirections = false; - int deviceID = 0; + int deviceID = 0; // default is the first GPU if available }; TorchMetricLearning(const Config &cfg, std::unique_ptr logger); diff --git a/Plugins/ExaTrkX/src/TorchEdgeClassifier.cpp b/Plugins/ExaTrkX/src/TorchEdgeClassifier.cpp index fe7068c87fb..df24dd6b4d1 100644 --- a/Plugins/ExaTrkX/src/TorchEdgeClassifier.cpp +++ b/Plugins/ExaTrkX/src/TorchEdgeClassifier.cpp @@ -8,6 +8,7 @@ #include "Acts/Plugins/ExaTrkX/TorchEdgeClassifier.hpp" +#include #include #include @@ -25,7 +26,7 @@ TorchEdgeClassifier::TorchEdgeClassifier(const Config& cfg, c10::InferenceMode guard(true); m_deviceType = torch::cuda::is_available() ? 
torch::kCUDA : torch::kCPU; if (m_deviceType == torch::kCUDA && cfg.deviceID >= 0 && - static_cast(cfg.deviceID) < torch::cuda::device_count()) { + static_cast(cfg.deviceID) < torch::cuda::device_count()) { ACTS_DEBUG("GPU device " << cfg.deviceID << " is being used."); m_device = torch::Device(torch::kCUDA, cfg.deviceID); } else { @@ -56,6 +57,7 @@ std::tuple TorchEdgeClassifier::operator()( std::any inputNodes, std::any inputEdges, torch::Device device) { ACTS_DEBUG("Start edge classification"); c10::InferenceMode guard(true); + c10::cuda::CUDAGuard device_guard(device.index()); auto nodes = std::any_cast(inputNodes).to(device); auto edgeList = std::any_cast(inputEdges).to(device); diff --git a/Plugins/ExaTrkX/src/TorchMetricLearning.cpp b/Plugins/ExaTrkX/src/TorchMetricLearning.cpp index 1be8c0b1f86..ef4dd5f62bf 100644 --- a/Plugins/ExaTrkX/src/TorchMetricLearning.cpp +++ b/Plugins/ExaTrkX/src/TorchMetricLearning.cpp @@ -11,6 +11,7 @@ #include "Acts/Plugins/ExaTrkX/detail/TensorVectorConversion.hpp" #include "Acts/Plugins/ExaTrkX/detail/buildEdges.hpp" +#include #include #include @@ -28,7 +29,7 @@ TorchMetricLearning::TorchMetricLearning(const Config &cfg, c10::InferenceMode guard(true); m_deviceType = torch::cuda::is_available() ? 
torch::kCUDA : torch::kCPU; if (m_deviceType == torch::kCUDA && cfg.deviceID >= 0 && - static_cast(cfg.deviceID) < torch::cuda::device_count()) { + static_cast(cfg.deviceID) < torch::cuda::device_count()) { ACTS_DEBUG("GPU device " << cfg.deviceID << " is being used."); m_device = torch::Device(torch::kCUDA, cfg.deviceID); } else { @@ -56,10 +57,11 @@ TorchMetricLearning::TorchMetricLearning(const Config &cfg, TorchMetricLearning::~TorchMetricLearning() {} std::tuple TorchMetricLearning::operator()( - std::vector &inputValues, std::size_t numNodes, + std::vector &inputValues, std::std::size_t numNodes, torch::Device device) { ACTS_DEBUG("Start graph construction"); c10::InferenceMode guard(true); + c10::cuda::CUDAGuard device_guard(device.index()); const int64_t numAllFeatures = inputValues.size() / numNodes; From 39ee6e57240c759b38ebca6a0380a6cf03cabe9f Mon Sep 17 00:00:00 2001 From: hrzhao76 Date: Mon, 5 Feb 2024 03:38:16 -0800 Subject: [PATCH 3/8] typo fixed; change frnn global declaration --- Plugins/ExaTrkX/src/TorchMetricLearning.cpp | 2 +- thirdparty/FRNN/CMakeLists.txt | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/Plugins/ExaTrkX/src/TorchMetricLearning.cpp b/Plugins/ExaTrkX/src/TorchMetricLearning.cpp index ef4dd5f62bf..6a609948c74 100644 --- a/Plugins/ExaTrkX/src/TorchMetricLearning.cpp +++ b/Plugins/ExaTrkX/src/TorchMetricLearning.cpp @@ -57,7 +57,7 @@ TorchMetricLearning::TorchMetricLearning(const Config &cfg, TorchMetricLearning::~TorchMetricLearning() {} std::tuple TorchMetricLearning::operator()( - std::vector &inputValues, std::std::size_t numNodes, + std::vector &inputValues, std::size_t numNodes, torch::Device device) { ACTS_DEBUG("Start graph construction"); c10::InferenceMode guard(true); diff --git a/thirdparty/FRNN/CMakeLists.txt b/thirdparty/FRNN/CMakeLists.txt index 9aa70a450aa..1f8269f8b25 100644 --- a/thirdparty/FRNN/CMakeLists.txt +++ b/thirdparty/FRNN/CMakeLists.txt @@ -11,9 +11,9 @@ include(FetchContent) 
message(STATUS "Building FRNN as part of the ACTS project") -set(ACTS_FRNN_GIT_REPOSITORY "https://github.com/lxxue/FRNN" +set(ACTS_FRNN_GIT_REPOSITORY "https://github.com/hrzhao76/FRNN/" CACHE STRING "Git repository to take FRNN from") -set(ACTS_FRNN_GIT_TAG "3e370d8d9073d4e130363faf87d2370598b5fbf2" +set(ACTS_FRNN_GIT_TAG "5f8a48b0022300cd2863119f5646a5f31373e0c8" CACHE STRING "Version of FRNN to build") mark_as_advanced(ACTS_FRNN_GIT_REPOSITORY ACTS_FRNN_GIT_TAG) From b6e52ef8c12911e13b1761539b9999e179b3a92e Mon Sep 17 00:00:00 2001 From: hrzhao76 Date: Tue, 20 Feb 2024 07:23:08 -0800 Subject: [PATCH 4/8] choose proper ACTS log for Torch; adapt ONNX to the new base class; fix a typo in README --- Plugins/ExaTrkX/README.md | 2 +- .../include/Acts/Plugins/ExaTrkX/OnnxEdgeClassifier.hpp | 6 ++++-- .../include/Acts/Plugins/ExaTrkX/OnnxMetricLearning.hpp | 8 +++++--- Plugins/ExaTrkX/src/OnnxEdgeClassifier.cpp | 6 ++++-- Plugins/ExaTrkX/src/OnnxMetricLearning.cpp | 4 +++- Plugins/ExaTrkX/src/TorchEdgeClassifier.cpp | 4 ++-- Plugins/ExaTrkX/src/TorchMetricLearning.cpp | 4 ++-- 7 files changed, 21 insertions(+), 13 deletions(-) diff --git a/Plugins/ExaTrkX/README.md b/Plugins/ExaTrkX/README.md index 9db6678f4ac..8a7944bd7f2 100644 --- a/Plugins/ExaTrkX/README.md +++ b/Plugins/ExaTrkX/README.md @@ -8,7 +8,7 @@ To build the plugin, enable the appropriate CMake options: ```bash cmake -B -S \ - -D ACTS_BUILD_EXATRKX_PLUGIN=ON \ + -D ACTS_BUILD_PLUGIN_EXATRKX=ON \ -D ACTS_EXATRKX_ENABLE_TORCH=ON/OFF \ -D ACTS_EXATRKX_ENABLE_ONNX=ON/OFF \ -D ACTS_BUILD_EXAMPLES_EXATRKX=ON \ diff --git a/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/OnnxEdgeClassifier.hpp b/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/OnnxEdgeClassifier.hpp index 1827698e85a..c506f3414a0 100644 --- a/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/OnnxEdgeClassifier.hpp +++ b/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/OnnxEdgeClassifier.hpp @@ -34,16 +34,18 @@ class OnnxEdgeClassifier final : public
Acts::EdgeClassificationBase { ~OnnxEdgeClassifier(); std::tuple operator()( - std::any nodes, std::any edges, torch::Device device) override; + std::any nodes, std::any edges, + torch::Device device = torch::Device(torch::kCPU)) override; Config config() const { return m_cfg; } + torch::Device device() const override { return m_device; }; private: std::unique_ptr m_logger; const auto &logger() const { return *m_logger; } Config m_cfg; - + torch::Device m_device; std::unique_ptr m_env; std::unique_ptr m_model; diff --git a/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/OnnxMetricLearning.hpp b/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/OnnxMetricLearning.hpp index 501c29acc72..b389da1181c 100644 --- a/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/OnnxMetricLearning.hpp +++ b/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/OnnxMetricLearning.hpp @@ -36,11 +36,12 @@ class OnnxMetricLearning final : public Acts::GraphConstructionBase { OnnxMetricLearning(const Config& cfg, std::unique_ptr logger); ~OnnxMetricLearning(); - std::tuple operator()(std::vector& inputValues, - std::size_t numNodes, - torch::Device device) override; + std::tuple operator()( + std::vector& inputValues, std::size_t numNodes, + torch::Device device = torch::Device(torch::kCPU)) override; Config config() const { return m_cfg; } + torch::Device device() const override { return m_device; }; private: void buildEdgesWrapper(std::vector& embedFeatures, @@ -51,6 +52,7 @@ class OnnxMetricLearning final : public Acts::GraphConstructionBase { const auto& logger() const { return *m_logger; } Config m_cfg; + torch::Device m_device; std::unique_ptr m_env; std::unique_ptr m_model; }; diff --git a/Plugins/ExaTrkX/src/OnnxEdgeClassifier.cpp b/Plugins/ExaTrkX/src/OnnxEdgeClassifier.cpp index ccdaff82fd9..e2326a06d2c 100644 --- a/Plugins/ExaTrkX/src/OnnxEdgeClassifier.cpp +++ b/Plugins/ExaTrkX/src/OnnxEdgeClassifier.cpp @@ -19,7 +19,9 @@ namespace Acts { OnnxEdgeClassifier::OnnxEdgeClassifier(const Config &cfg, 
std::unique_ptr logger) - : m_logger(std::move(logger)), m_cfg(cfg) { + : m_logger(std::move(logger)), + m_cfg(cfg), + m_device(torch::Device(torch::kCPU)) { m_env = std::make_unique(ORT_LOGGING_LEVEL_WARNING, "ExaTrkX - edge classifier"); @@ -44,7 +46,7 @@ OnnxEdgeClassifier::OnnxEdgeClassifier(const Config &cfg, OnnxEdgeClassifier::~OnnxEdgeClassifier() {} std::tuple OnnxEdgeClassifier::operator()( - std::any inputNodes, std::any inputEdges, int) { + std::any inputNodes, std::any inputEdges, torch::Device) { Ort::AllocatorWithDefaultOptions allocator; auto memoryInfo = Ort::MemoryInfo::CreateCpu( OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault); diff --git a/Plugins/ExaTrkX/src/OnnxMetricLearning.cpp b/Plugins/ExaTrkX/src/OnnxMetricLearning.cpp index bc6f28e3f7e..0f4addfe197 100644 --- a/Plugins/ExaTrkX/src/OnnxMetricLearning.cpp +++ b/Plugins/ExaTrkX/src/OnnxMetricLearning.cpp @@ -19,7 +19,9 @@ namespace Acts { OnnxMetricLearning::OnnxMetricLearning(const Config& cfg, std::unique_ptr logger) - : m_logger(std::move(logger)), m_cfg(cfg) { + : m_logger(std::move(logger)), + m_cfg(cfg), + m_device(torch::Device(torch::kCPU)) { m_env = std::make_unique(ORT_LOGGING_LEVEL_WARNING, "ExaTrkX - metric learning"); diff --git a/Plugins/ExaTrkX/src/TorchEdgeClassifier.cpp b/Plugins/ExaTrkX/src/TorchEdgeClassifier.cpp index df24dd6b4d1..add85298864 100644 --- a/Plugins/ExaTrkX/src/TorchEdgeClassifier.cpp +++ b/Plugins/ExaTrkX/src/TorchEdgeClassifier.cpp @@ -30,8 +30,8 @@ TorchEdgeClassifier::TorchEdgeClassifier(const Config& cfg, ACTS_DEBUG("GPU device " << cfg.deviceID << " is being used."); m_device = torch::Device(torch::kCUDA, cfg.deviceID); } else { - ACTS_ERROR("GPU device " << cfg.deviceID - << " not available. Using CPU instead."); + ACTS_WARNING("GPU device " << cfg.deviceID + << " not available. Using CPU instead."); } ACTS_DEBUG("Using torch version " << TORCH_VERSION_MAJOR << "." << TORCH_VERSION_MINOR << "." 
diff --git a/Plugins/ExaTrkX/src/TorchMetricLearning.cpp b/Plugins/ExaTrkX/src/TorchMetricLearning.cpp index 6a609948c74..7b69f3ef882 100644 --- a/Plugins/ExaTrkX/src/TorchMetricLearning.cpp +++ b/Plugins/ExaTrkX/src/TorchMetricLearning.cpp @@ -33,8 +33,8 @@ TorchMetricLearning::TorchMetricLearning(const Config &cfg, ACTS_DEBUG("GPU device " << cfg.deviceID << " is being used."); m_device = torch::Device(torch::kCUDA, cfg.deviceID); } else { - ACTS_ERROR("GPU device " << cfg.deviceID - << " not available. Using CPU instead."); + ACTS_WARNING("GPU device " << cfg.deviceID + << " not available. Using CPU instead."); } ACTS_DEBUG("Using torch version " << TORCH_VERSION_MAJOR << "." << TORCH_VERSION_MINOR << "." From 49a628a8f20eecb823d1428fadd6d6f61fb0307b Mon Sep 17 00:00:00 2001 From: hrzhao76 Date: Sun, 3 Mar 2024 23:58:00 -0800 Subject: [PATCH 5/8] set default params to fix ci bridge --- .../include/Acts/Plugins/ExaTrkX/BoostTrackBuilding.hpp | 8 ++++---- .../include/Acts/Plugins/ExaTrkX/CugraphTrackBuilding.hpp | 8 ++++---- Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/Stages.hpp | 8 +++++--- .../include/Acts/Plugins/ExaTrkX/TorchEdgeClassifier.hpp | 3 ++- .../include/Acts/Plugins/ExaTrkX/TorchMetricLearning.hpp | 6 +++--- 5 files changed, 18 insertions(+), 15 deletions(-) diff --git a/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/BoostTrackBuilding.hpp b/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/BoostTrackBuilding.hpp index a82203c5d7c..9157c58a385 100644 --- a/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/BoostTrackBuilding.hpp +++ b/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/BoostTrackBuilding.hpp @@ -22,10 +22,10 @@ class BoostTrackBuilding final : public Acts::TrackBuildingBase { BoostTrackBuilding(std::unique_ptr logger) : m_logger(std::move(logger)), m_device(torch::Device(torch::kCPU)) {} - std::vector> operator()(std::any nodes, std::any edges, - std::any edge_weights, - std::vector &spacepointIDs, - torch::Device device) override; + std::vector> 
operator()( + std::any nodes, std::any edges, std::any edge_weights, + std::vector &spacepointIDs, + torch::Device device = torch::Device(torch::kCPU)) override; torch::Device device() const override { return m_device; }; private: diff --git a/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/CugraphTrackBuilding.hpp b/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/CugraphTrackBuilding.hpp index 40a29bcb4eb..06d8d0bbeba 100644 --- a/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/CugraphTrackBuilding.hpp +++ b/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/CugraphTrackBuilding.hpp @@ -22,10 +22,10 @@ class CugraphTrackBuilding final : public Acts::TrackBuildingBase { CugraphTrackBuilding(std::unique_ptr logger) : m_logger(std::move(logger)), m_device(torch::Device(torch::kCPU)) {} - std::vector> operator()(std::any nodes, std::any edges, - std::any edge_weights, - std::vector &spacepointIDs, - torch::Device device) override; + std::vector> operator()( + std::any nodes, std::any edges, std::any edge_weights, + std::vector &spacepointIDs, + torch::Device device = torch::Device(torch::kCPU)) override; torch::Device device() const override { return m_device; }; private: diff --git a/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/Stages.hpp b/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/Stages.hpp index f2ba7aeb02e..274b8c0f494 100644 --- a/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/Stages.hpp +++ b/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/Stages.hpp @@ -32,7 +32,7 @@ class GraphConstructionBase { /// @return (node_tensor, edge_tensore) virtual std::tuple operator()( std::vector &inputValues, std::size_t numNodes, - torch::Device device) = 0; + torch::Device device = torch::Device(torch::kCPU)) = 0; virtual torch::Device device() const = 0; @@ -49,7 +49,8 @@ class EdgeClassificationBase { /// /// @return (node_tensor, edge_tensor, score_tensor) virtual std::tuple operator()( - std::any nodes, std::any edges, torch::Device device) = 0; + std::any nodes, std::any edges, + 
torch::Device device = torch::Device(torch::kCPU)) = 0; virtual torch::Device device() const = 0; @@ -69,7 +70,8 @@ class TrackBuildingBase { /// @return tracks (as vectors of node-IDs) virtual std::vector> operator()( std::any nodes, std::any edges, std::any edgeWeights, - std::vector &spacepointIDs, torch::Device device) = 0; + std::vector &spacepointIDs, + torch::Device device = torch::Device(torch::kCPU)) = 0; virtual torch::Device device() const = 0; diff --git a/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/TorchEdgeClassifier.hpp b/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/TorchEdgeClassifier.hpp index fa8dcc4f957..159e989a606 100644 --- a/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/TorchEdgeClassifier.hpp +++ b/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/TorchEdgeClassifier.hpp @@ -40,7 +40,8 @@ class TorchEdgeClassifier final : public Acts::EdgeClassificationBase { ~TorchEdgeClassifier(); std::tuple operator()( - std::any nodes, std::any edges, torch::Device device) override; + std::any nodes, std::any edges, + torch::Device device = torch::Device(torch::kCPU)) override; Config config() const { return m_cfg; } torch::Device device() const override { return m_device; }; diff --git a/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/TorchMetricLearning.hpp b/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/TorchMetricLearning.hpp index 1acef0fda69..779b744ce17 100644 --- a/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/TorchMetricLearning.hpp +++ b/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/TorchMetricLearning.hpp @@ -38,9 +38,9 @@ class TorchMetricLearning final : public Acts::GraphConstructionBase { TorchMetricLearning(const Config &cfg, std::unique_ptr logger); ~TorchMetricLearning(); - std::tuple operator()(std::vector &inputValues, - std::size_t numNodes, - torch::Device device) override; + std::tuple operator()( + std::vector &inputValues, std::size_t numNodes, + torch::Device device = torch::Device(torch::kCPU)) override; Config config() const { return 
m_cfg; } torch::Device device() const override { return m_device; }; From 235ca35043fac1aa18fedcc4300e4980346aa70c Mon Sep 17 00:00:00 2001 From: hrzhao76 Date: Fri, 21 Jun 2024 03:49:54 -0700 Subject: [PATCH 6/8] improve the device selection logic --- Plugins/ExaTrkX/src/TorchEdgeClassifier.cpp | 16 ++++++++++------ Plugins/ExaTrkX/src/TorchMetricLearning.cpp | 17 +++++++++++------ 2 files changed, 21 insertions(+), 12 deletions(-) diff --git a/Plugins/ExaTrkX/src/TorchEdgeClassifier.cpp b/Plugins/ExaTrkX/src/TorchEdgeClassifier.cpp index add85298864..9ada978ac66 100644 --- a/Plugins/ExaTrkX/src/TorchEdgeClassifier.cpp +++ b/Plugins/ExaTrkX/src/TorchEdgeClassifier.cpp @@ -25,14 +25,18 @@ TorchEdgeClassifier::TorchEdgeClassifier(const Config& cfg, m_device(torch::Device(torch::kCPU)) { c10::InferenceMode guard(true); m_deviceType = torch::cuda::is_available() ? torch::kCUDA : torch::kCPU; - if (m_deviceType == torch::kCUDA && cfg.deviceID >= 0 && - static_cast(cfg.deviceID) < torch::cuda::device_count()) { - ACTS_DEBUG("GPU device " << cfg.deviceID << " is being used."); - m_device = torch::Device(torch::kCUDA, cfg.deviceID); + if (m_deviceType == torch::kCPU) { + ACTS_INFO("Running on CPU..."); } else { - ACTS_WARNING("GPU device " << cfg.deviceID - << " not available. Using CPU instead."); + if (cfg.deviceID >= 0 && + static_cast(cfg.deviceID) < torch::cuda::device_count()) { + ACTS_DEBUG("GPU device " << cfg.deviceID << " is being used."); + m_device = torch::Device(torch::kCUDA, cfg.deviceID); + } else { + ACTS_FATAL("GPU device " << cfg.deviceID << " not available. "); + } } + ACTS_DEBUG("Using torch version " << TORCH_VERSION_MAJOR << "." << TORCH_VERSION_MINOR << "." 
<< TORCH_VERSION_PATCH); diff --git a/Plugins/ExaTrkX/src/TorchMetricLearning.cpp b/Plugins/ExaTrkX/src/TorchMetricLearning.cpp index 7b69f3ef882..c2a13bba0cb 100644 --- a/Plugins/ExaTrkX/src/TorchMetricLearning.cpp +++ b/Plugins/ExaTrkX/src/TorchMetricLearning.cpp @@ -28,14 +28,19 @@ TorchMetricLearning::TorchMetricLearning(const Config &cfg, m_device(torch::Device(torch::kCPU)) { c10::InferenceMode guard(true); m_deviceType = torch::cuda::is_available() ? torch::kCUDA : torch::kCPU; - if (m_deviceType == torch::kCUDA && cfg.deviceID >= 0 && - static_cast(cfg.deviceID) < torch::cuda::device_count()) { - ACTS_DEBUG("GPU device " << cfg.deviceID << " is being used."); - m_device = torch::Device(torch::kCUDA, cfg.deviceID); + + if (m_deviceType == torch::kCPU) { + ACTS_INFO("Running on CPU..."); } else { - ACTS_WARNING("GPU device " << cfg.deviceID - << " not available. Using CPU instead."); + if (cfg.deviceID >= 0 && + static_cast(cfg.deviceID) < torch::cuda::device_count()) { + ACTS_DEBUG("GPU device " << cfg.deviceID << " is being used."); + m_device = torch::Device(torch::kCUDA, cfg.deviceID); + } else { + ACTS_FATAL("GPU device " << cfg.deviceID << " not available. "); + } } + ACTS_DEBUG("Using torch version " << TORCH_VERSION_MAJOR << "." << TORCH_VERSION_MINOR << "." 
<< TORCH_VERSION_PATCH); From 021dea9ff0ac7aab61b21d97c4550b86dce21203 Mon Sep 17 00:00:00 2001 From: hrzhao76 Date: Fri, 21 Jun 2024 06:27:57 -0700 Subject: [PATCH 7/8] no CUDAGuard if device is kCPU --- Plugins/ExaTrkX/src/TorchEdgeClassifier.cpp | 11 ++++++++--- Plugins/ExaTrkX/src/TorchMetricLearning.cpp | 11 ++++++++--- 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/Plugins/ExaTrkX/src/TorchEdgeClassifier.cpp b/Plugins/ExaTrkX/src/TorchEdgeClassifier.cpp index 9ada978ac66..c5af9917cf0 100644 --- a/Plugins/ExaTrkX/src/TorchEdgeClassifier.cpp +++ b/Plugins/ExaTrkX/src/TorchEdgeClassifier.cpp @@ -26,14 +26,15 @@ TorchEdgeClassifier::TorchEdgeClassifier(const Config& cfg, c10::InferenceMode guard(true); m_deviceType = torch::cuda::is_available() ? torch::kCUDA : torch::kCPU; if (m_deviceType == torch::kCPU) { - ACTS_INFO("Running on CPU..."); + ACTS_DEBUG("Running on CPU..."); } else { if (cfg.deviceID >= 0 && static_cast(cfg.deviceID) < torch::cuda::device_count()) { ACTS_DEBUG("GPU device " << cfg.deviceID << " is being used."); m_device = torch::Device(torch::kCUDA, cfg.deviceID); } else { - ACTS_FATAL("GPU device " << cfg.deviceID << " not available. 
"); + ACTS_WARNING("GPU device " << cfg.deviceID + << " not available, falling back to CPU."); } } @@ -61,7 +62,11 @@ std::tuple TorchEdgeClassifier::operator()( std::any inputNodes, std::any inputEdges, torch::Device device) { ACTS_DEBUG("Start edge classification"); c10::InferenceMode guard(true); - c10::cuda::CUDAGuard device_guard(device.index()); + // Engage the CUDA guard only for GPU devices; it must outlive the if-block + c10::cuda::OptionalCUDAGuard device_guard; + if (device.is_cuda()) { + device_guard.set_index(device.index()); + } auto nodes = std::any_cast(inputNodes).to(device); auto edgeList = std::any_cast(inputEdges).to(device); diff --git a/Plugins/ExaTrkX/src/TorchMetricLearning.cpp b/Plugins/ExaTrkX/src/TorchMetricLearning.cpp index c2a13bba0cb..ec2e251f5ff 100644 --- a/Plugins/ExaTrkX/src/TorchMetricLearning.cpp +++ b/Plugins/ExaTrkX/src/TorchMetricLearning.cpp @@ -30,14 +30,15 @@ TorchMetricLearning::TorchMetricLearning(const Config &cfg, m_deviceType = torch::cuda::is_available() ? torch::kCUDA : torch::kCPU; if (m_deviceType == torch::kCPU) { - ACTS_INFO("Running on CPU..."); + ACTS_DEBUG("Running on CPU..."); } else { if (cfg.deviceID >= 0 && static_cast(cfg.deviceID) < torch::cuda::device_count()) { ACTS_DEBUG("GPU device " << cfg.deviceID << " is being used."); m_device = torch::Device(torch::kCUDA, cfg.deviceID); } else { - ACTS_FATAL("GPU device " << cfg.deviceID << " not available. 
"); + ACTS_WARNING("GPU device " << cfg.deviceID + << " not available, falling back to CPU."); } } @@ -66,7 +67,11 @@ std::tuple TorchMetricLearning::operator()( torch::Device device) { ACTS_DEBUG("Start graph construction"); c10::InferenceMode guard(true); - c10::cuda::CUDAGuard device_guard(device.index()); + // Engage the CUDA guard only for GPU devices; it must outlive the if-block + c10::cuda::OptionalCUDAGuard device_guard; + if (device.is_cuda()) { + device_guard.set_index(device.index()); + } const int64_t numAllFeatures = inputValues.size() / numNodes; From d3805488a5e383680b02cdf02fe5f514b248ece8 Mon Sep 17 00:00:00 2001 From: hrzhao76 Date: Fri, 21 Jun 2024 06:38:09 -0700 Subject: [PATCH 8/8] update the FRNN lib --- cmake/ActsExternSources.cmake | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmake/ActsExternSources.cmake b/cmake/ActsExternSources.cmake index 10bfbe13a3e..3aef0b52b25 100644 --- a/cmake/ActsExternSources.cmake +++ b/cmake/ActsExternSources.cmake @@ -27,7 +27,7 @@ set( ACTS_DFELIBS_SOURCE mark_as_advanced( ACTS_DFELIBS_SOURCE ) set( ACTS_FRNN_SOURCE - "GIT_REPOSITORY;https://github.com/lxxue/FRNN;GIT_TAG;3e370d8d9073d4e130363faf87d2370598b5fbf2" CACHE STRING "Source to take FRNN from") + "GIT_REPOSITORY;https://github.com/hrzhao76/FRNN/;GIT_TAG;5f8a48b0022300cd2863119f5646a5f31373e0c8" CACHE STRING "Source to take FRNN from") mark_as_advanced( ACTS_FRNN_SOURCE ) set( ACTS_GEOMODEL_SOURCE