From ea1aafbdb14a208a30a39dc0a03a5276c8f9d08a Mon Sep 17 00:00:00 2001 From: Corentin Allaire Date: Tue, 25 Jan 2022 14:39:55 +0100 Subject: [PATCH 1/2] small fix --- Examples/Python/src/ExaTrkXTrackFinding.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Examples/Python/src/ExaTrkXTrackFinding.cpp b/Examples/Python/src/ExaTrkXTrackFinding.cpp index b32f193d472..37e299c3a90 100644 --- a/Examples/Python/src/ExaTrkXTrackFinding.cpp +++ b/Examples/Python/src/ExaTrkXTrackFinding.cpp @@ -29,7 +29,7 @@ void addExaTrkXTrackFinding(Context& ctx) { auto [m, mex] = ctx.get("main", "examples"); { - using Alg = ActsExamples::TrackFindingAlgorithm; + using Alg = ActsExamples::TrackFindingMLBasedAlgorithm; using Config = Alg::Config; auto alg = From 58b2ba3232bd5cd06b0db8c7df9f9368060622dd Mon Sep 17 00:00:00 2001 From: Corentin Allaire Date: Tue, 25 Jan 2022 16:04:45 +0100 Subject: [PATCH 2/2] add ExaTrkXTrackFinding in the bindings --- .../TrackFindingMLBasedAlgorithm.hpp | 6 ++--- .../src/TrackFindingMLBasedAlgorithm.cpp | 5 +++- Examples/Python/src/ExaTrkXTrackFinding.cpp | 24 ++++++++++++++++++- Examples/Python/src/ModuleEntry.cpp | 1 - Examples/Scripts/Python/ExaTrkX.py | 17 ++++++++----- 5 files changed, 40 insertions(+), 13 deletions(-) diff --git a/Examples/Algorithms/TrackFindingMLBased/include/ActsExamples/TrackFindingMLBased/TrackFindingMLBasedAlgorithm.hpp b/Examples/Algorithms/TrackFindingMLBased/include/ActsExamples/TrackFindingMLBased/TrackFindingMLBasedAlgorithm.hpp index ca7710e11de..8c1cd532073 100644 --- a/Examples/Algorithms/TrackFindingMLBased/include/ActsExamples/TrackFindingMLBased/TrackFindingMLBasedAlgorithm.hpp +++ b/Examples/Algorithms/TrackFindingMLBased/include/ActsExamples/TrackFindingMLBased/TrackFindingMLBasedAlgorithm.hpp @@ -17,8 +17,8 @@ class TrackFindingMLBasedAlgorithm final : public BareAlgorithm { /// Output protoTracks collection. std::string outputProtoTracks; - /// Path to the onnx model - std::string onnxModelDir; + /// ML based track finder + std::shared_ptr trackFinderML; // NOTE the other config parameters for the Exa.TrkX class for now are just initialized as the defaults }; @@ -41,8 +41,6 @@ class TrackFindingMLBasedAlgorithm final : public BareAlgorithm { const Config& config() const { return m_cfg; } private: - ExaTrkXTrackFinding m_exaTrkx; - // configuration Config m_cfg; }; diff --git a/Examples/Algorithms/TrackFindingMLBased/src/TrackFindingMLBasedAlgorithm.cpp b/Examples/Algorithms/TrackFindingMLBased/src/TrackFindingMLBasedAlgorithm.cpp index e893a29a10d..f00c9cd5cca 100644 --- a/Examples/Algorithms/TrackFindingMLBased/src/TrackFindingMLBasedAlgorithm.cpp +++ b/Examples/Algorithms/TrackFindingMLBased/src/TrackFindingMLBasedAlgorithm.cpp @@ -32,6 +32,9 @@ ActsExamples::TrackFindingMLBasedAlgorithm::TrackFindingMLBasedAlgorithm( if (m_cfg.outputProtoTracks.empty()) { throw std::invalid_argument("Missing protoTrack output collection"); } + if (m_cfg.trackFinderML.empty()) { + throw std::invalid_argument("Missing track finder"); + } } ActsExamples::ProcessCode ActsExamples::TrackFindingMLBasedAlgorithm::execute( @@ -62,7 +65,7 @@ ActsExamples::ProcessCode ActsExamples::TrackFindingMLBasedAlgorithm::execute( // ProtoTrackContainer protoTracks; std::vector > trackCandidates; - m_exaTrkx.getTracks(inputValues, spacepointIDs, trackCandidates); + m_cfg.trackFinderML->getTracks(inputValues, spacepointIDs, trackCandidates); std::vector protoTracks; for(auto& x: trackCandidates){ diff --git a/Examples/Python/src/ExaTrkXTrackFinding.cpp b/Examples/Python/src/ExaTrkXTrackFinding.cpp index f97b117a2fc..7c0c81e9a70 100644 --- a/Examples/Python/src/ExaTrkXTrackFinding.cpp +++ b/Examples/Python/src/ExaTrkXTrackFinding.cpp @@ -29,6 +29,28 @@ namespace Acts::Python { void addExaTrkXTrackFinding(Context& ctx) { auto [m, mex] = ctx.get("main", "examples"); + { + using Alg = Acts::Plugin::ExaTrkXTrackFinding; + using Config = Acts::Plugin::ExaTrkXTrackFinding::Config; + + auto alg = + py::class_>( + mex, "ExaTrkXTrackFinding") + .def(py::init(), + py::arg("config")) + .def_property_readonly("config", &Alg::config); + + auto c = py::class_(alg, "Config").def(py::init<>()); + ACTS_PYTHON_STRUCT_BEGIN(c, Config); + ACTS_PYTHON_MEMBER(inputMLModuleDir); + ACTS_PYTHON_MEMBER(spacepointFeatures); + ACTS_PYTHON_MEMBER(embeddingDim); + ACTS_PYTHON_MEMBER(rVal); + ACTS_PYTHON_MEMBER(knnVal); + ACTS_PYTHON_MEMBER(filterCut); + ACTS_PYTHON_STRUCT_END(); + } + { using Alg = ActsExamples::TrackFindingMLBasedAlgorithm; using Config = Alg::Config; @@ -44,7 +66,7 @@ void addExaTrkXTrackFinding(Context& ctx) { ACTS_PYTHON_STRUCT_BEGIN(c, Config); ACTS_PYTHON_MEMBER(inputSpacePoints); ACTS_PYTHON_MEMBER(outputProtoTracks); - ACTS_PYTHON_MEMBER(onnxModelDir); + ACTS_PYTHON_MEMBER(trackFinderML); ACTS_PYTHON_STRUCT_END(); } diff --git a/Examples/Python/src/ModuleEntry.cpp b/Examples/Python/src/ModuleEntry.cpp index 86656712311..b276bdfbe97 100644 --- a/Examples/Python/src/ModuleEntry.cpp +++ b/Examples/Python/src/ModuleEntry.cpp @@ -84,7 +84,6 @@ void addGenerators(Context& ctx); void addTruthTracking(Context& ctx); void addTrackFitting(Context& ctx); void addTrackFinding(Context& ctx); -void addExaTrkXTrackFinding(Context& ctx); void addVertexing(Context& ctx); // Plugins diff --git a/Examples/Scripts/Python/ExaTrkX.py b/Examples/Scripts/Python/ExaTrkX.py index 4447262997b..65680f1c93d 100755 --- a/Examples/Scripts/Python/ExaTrkX.py +++ b/Examples/Scripts/Python/ExaTrkX.py @@ -135,17 +135,22 @@ def runExaTrkX( # It takes all the source links created from truth hit smearing, seeds from # truth particle smearing and source link selection config - onnx_model_dir="/home/xju/ocean/code/Tracking-ML-Exa.TrkX/Pipelines/TrackML_Example/onnx_models" - #ACTS_INFO("ML model dir: " << onnx_model_dir) - + exaTrkxFinding = acts.examples.ExaTrkXTrackFinding( + inputMLModuleDir="/home/xju/ocean/code/Tracking-ML-Exa.TrkX/Pipelines/TrackML_Example/onnx_models", + spacepointFeatures=3, + embeddingDim=8, + rVal=1.6, + knnVal=500, + filterCut=0.21 + ) - trackFinder = acts.examples.TrackFindingMLBasedAlgorithm( + trackFinderAlg = acts.examples.TrackFindingMLBasedAlgorithm( level=acts.logging.INFO, inputSpacePoints="spacepoints", outputProtoTracks="protoTracks", - onnxModelDir=onnx_model_dir + trackFinderML=exaTrkxFinding ) - s.addAlgorithm(trackFinder) + s.addAlgorithm(trackFinderAlg) # Write truth track finding / seeding performance