Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add corentin's changes #2

Merged
merged 3 commits into from
Jan 26, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ class TrackFindingMLBasedAlgorithm final : public BareAlgorithm {
/// Output protoTracks collection.
std::string outputProtoTracks;

/// Path to the onnx model
std::string onnxModelDir;
/// ML based track finder
std::shared_ptr<ExaTrkXTrackFinding> trackFinderML;

// NOTE the other config parameters for the Exa.TrkX class for now are just initialized as the defaults
};
Expand All @@ -41,8 +41,6 @@ class TrackFindingMLBasedAlgorithm final : public BareAlgorithm {
const Config& config() const { return m_cfg; }

private:
ExaTrkXTrackFinding m_exaTrkx;

// configuration
Config m_cfg;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ ActsExamples::TrackFindingMLBasedAlgorithm::TrackFindingMLBasedAlgorithm(
if (m_cfg.outputProtoTracks.empty()) {
throw std::invalid_argument("Missing protoTrack output collection");
}
if (m_cfg.trackFinderML.empty()) {
throw std::invalid_argument("Missing track finder");
}
}

ActsExamples::ProcessCode ActsExamples::TrackFindingMLBasedAlgorithm::execute(
Expand Down Expand Up @@ -62,7 +65,7 @@ ActsExamples::ProcessCode ActsExamples::TrackFindingMLBasedAlgorithm::execute(

// ProtoTrackContainer protoTracks;
std::vector<std::vector<uint32_t> > trackCandidates;
m_exaTrkx.getTracks(inputValues, spacepointIDs, trackCandidates);
m_cfg.trackFinderML->getTracks(inputValues, spacepointIDs, trackCandidates);

std::vector<ProtoTrack> protoTracks;
for(auto& x: trackCandidates){
Expand Down
24 changes: 23 additions & 1 deletion Examples/Python/src/ExaTrkXTrackFinding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,28 @@ namespace Acts::Python {
void addExaTrkXTrackFinding(Context& ctx) {
auto [m, mex] = ctx.get("main", "examples");

{
using Alg = Acts::Plugin::ExaTrkXTrackFinding;
using Config = Acts::Plugin::ExaTrkXTrackFinding::Config;

auto alg =
py::class_<Alg, std::shared_ptr<Alg>>(
mex, "ExaTrkXTrackFinding")
.def(py::init<const Config&>(),
py::arg("config"))
.def_property_readonly("config", &Alg::config);

auto c = py::class_<Config>(alg, "Config").def(py::init<>());
ACTS_PYTHON_STRUCT_BEGIN(c, Config);
ACTS_PYTHON_MEMBER(inputMLModuleDir);
ACTS_PYTHON_MEMBER(spacepointFeatures);
ACTS_PYTHON_MEMBER(embeddingDim);
ACTS_PYTHON_MEMBER(rVal);
ACTS_PYTHON_MEMBER(knnVal);
ACTS_PYTHON_MEMBER(filterCut);
ACTS_PYTHON_STRUCT_END();
}

{
using Alg = ActsExamples::TrackFindingMLBasedAlgorithm;
using Config = Alg::Config;
Expand All @@ -44,7 +66,7 @@ void addExaTrkXTrackFinding(Context& ctx) {
ACTS_PYTHON_STRUCT_BEGIN(c, Config);
ACTS_PYTHON_MEMBER(inputSpacePoints);
ACTS_PYTHON_MEMBER(outputProtoTracks);
ACTS_PYTHON_MEMBER(onnxModelDir);
ACTS_PYTHON_MEMBER(trackFinderML);
ACTS_PYTHON_STRUCT_END();
}

Expand Down
1 change: 0 additions & 1 deletion Examples/Python/src/ModuleEntry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,6 @@ void addGenerators(Context& ctx);
void addTruthTracking(Context& ctx);
void addTrackFitting(Context& ctx);
void addTrackFinding(Context& ctx);
void addExaTrkXTrackFinding(Context& ctx);
void addVertexing(Context& ctx);

// Plugins
Expand Down
17 changes: 11 additions & 6 deletions Examples/Scripts/Python/ExaTrkX.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,17 +135,22 @@ def runExaTrkX(
# It takes all the source links created from truth hit smearing, seeds from
# truth particle smearing and source link selection config

onnx_model_dir="/home/xju/ocean/code/Tracking-ML-Exa.TrkX/Pipelines/TrackML_Example/onnx_models"
#ACTS_INFO("ML model dir: " << onnx_model_dir)

exaTrkxFinding = acts.examples.ExaTrkXTrackFinding(
inputMLModuleDir="/home/xju/ocean/code/Tracking-ML-Exa.TrkX/Pipelines/TrackML_Example/onnx_models",
spacepointFeatures=3,
embeddingDim=8,
rVal=1.6,
knnVal=500,
filterCut=0.21
)

trackFinder = acts.examples.TrackFindingMLBasedAlgorithm(
trackFinderAlg = acts.examples.TrackFindingMLBasedAlgorithm(
level=acts.logging.INFO,
inputSpacePoints="spacepoints",
outputProtoTracks="protoTracks",
onnxModelDir=onnx_model_dir
trackFinderML=exaTrkxFinding
)
s.addAlgorithm(trackFinder)
s.addAlgorithm(trackFinderAlg)


# Write truth track finding / seeding performance
Expand Down