Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Deduplicate seeds during track finding in Examples #3088

Merged
merged 9 commits into from
Apr 11, 2024
5 changes: 5 additions & 0 deletions CI/physmon/workflows/physmon_ckf_tracking.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#!/usr/bin/env python3

import tempfile
from pathlib import Path
import shutil
Expand All @@ -22,6 +23,7 @@
SeedFinderOptionsArg,
SeedingAlgorithm,
TruthEstimatedSeedingAlgorithmConfigArg,
CkfConfig,
addCKFTracks,
addAmbiguityResolution,
AmbiguityResolutionConfig,
Expand Down Expand Up @@ -135,6 +137,9 @@ def run_ckf_tracking(truthSmearedSeeded, truthEstimatedSeeded, label):
loc0=(-4.0 * u.mm, 4.0 * u.mm),
nMeasurementsMin=6,
),
CkfConfig(
seedDeduplication=False if truthSmearedSeeded else True,
),
outputDirRoot=tp,
)

Expand Down
7 changes: 5 additions & 2 deletions CI/physmon/workflows/physmon_track_finding_ttbar.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#!/usr/bin/env python3

import tempfile
from pathlib import Path
import shutil
Expand All @@ -12,11 +13,10 @@
from acts.examples.reconstruction import (
addSeeding,
TruthSeedRanges,
ParticleSmearingSigmas,
SeedFinderConfigArg,
SeedFinderOptionsArg,
SeedingAlgorithm,
TruthEstimatedSeedingAlgorithmConfigArg,
CkfConfig,
addCKFTracks,
addAmbiguityResolution,
AmbiguityResolutionConfig,
Expand Down Expand Up @@ -103,6 +103,9 @@
loc0=(-4.0 * u.mm, 4.0 * u.mm),
nMeasurementsMin=6,
),
CkfConfig(
seedDeduplication=True,
),
outputDirRoot=tp,
)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// This file is part of the Acts project.
//
// Copyright (C) 2020 CERN for the benefit of the Acts project
// Copyright (C) 2020-2024 CERN for the benefit of the Acts project
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
Expand All @@ -24,6 +24,7 @@
#include "Acts/Utilities/TrackHelpers.hpp"
#include "ActsExamples/EventData/IndexSourceLink.hpp"
#include "ActsExamples/EventData/Measurement.hpp"
#include "ActsExamples/EventData/SimSeed.hpp"
#include "ActsExamples/EventData/Track.hpp"
#include "ActsExamples/Framework/DataHandle.hpp"
#include "ActsExamples/Framework/IAlgorithm.hpp"
Expand Down Expand Up @@ -87,6 +88,9 @@ class TrackFindingAlgorithm final : public IAlgorithm {
std::string inputSourceLinks;
/// Input initial track parameter estimates for for each proto track.
std::string inputInitialTrackParameters;
/// Input seeds. These are optional and allow for seed deduplication.
/// The seeds must match the initial track parameters.
std::string inputSeeds;
/// Output find trajectories collection.
std::string outputTracks;

Expand All @@ -99,8 +103,6 @@ class TrackFindingAlgorithm final : public IAlgorithm {
std::shared_ptr<TrackFinderFunction> findTracks;
/// CKF measurement selector config
Acts::MeasurementSelector::Config measurementSelectorCfg;
/// Compute shared hit information
bool computeSharedHits = false;
/// Track selector config
std::optional<std::variant<Acts::TrackSelector::Config,
Acts::TrackSelector::EtaBinnedConfig>>
Expand All @@ -113,6 +115,11 @@ class TrackFindingAlgorithm final : public IAlgorithm {
Acts::TrackExtrapolationStrategy::firstOrLast;
/// Run finding in two directions
bool twoWay = true;
/// Whether to use seed deduplication
/// This is only available if `inputSeeds` is set.
bool seedDeduplication = false;
/// Compute shared hit information
bool computeSharedHits = false;
};

/// Constructor of the track finding algorithm
Expand Down Expand Up @@ -146,13 +153,14 @@ class TrackFindingAlgorithm final : public IAlgorithm {
"InputMeasurements"};
ReadDataHandle<IndexSourceLinkContainer> m_inputSourceLinks{
this, "InputSourceLinks"};

ReadDataHandle<TrackParametersContainer> m_inputInitialTrackParameters{
this, "InputInitialTrackParameters"};
ReadDataHandle<SimSeedContainer> m_inputSeeds{this, "InputSeeds"};

WriteDataHandle<ConstTrackContainer> m_outputTracks{this, "OutputTracks"};

mutable std::atomic<std::size_t> m_nTotalSeeds{0};
mutable std::atomic<std::size_t> m_nDeduplicatedSeeds{0};
mutable std::atomic<std::size_t> m_nFailedSeeds{0};
mutable std::atomic<std::size_t> m_nFailedSmoothing{0};
mutable std::atomic<std::size_t> m_nFailedExtrapolation{0};
Expand Down
172 changes: 151 additions & 21 deletions Examples/Algorithms/TrackFinding/src/TrackFindingAlgorithm.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// This file is part of the Acts project.
//
// Copyright (C) 2020 CERN for the benefit of the Acts project
// Copyright (C) 2020-2024 CERN for the benefit of the Acts project
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
Expand All @@ -12,6 +12,7 @@
#include "Acts/Definitions/Direction.hpp"
#include "Acts/EventData/MultiTrajectory.hpp"
#include "Acts/EventData/ProxyAccessor.hpp"
#include "Acts/EventData/SourceLink.hpp"
#include "Acts/EventData/TrackContainer.hpp"
#include "Acts/EventData/TrackParameters.hpp"
#include "Acts/EventData/VectorMultiTrajectory.hpp"
Expand All @@ -30,10 +31,13 @@
#include "Acts/TrackFitting/GainMatrixUpdater.hpp"
#include "Acts/TrackFitting/KalmanFitter.hpp"
#include "Acts/Utilities/Delegate.hpp"
#include "Acts/Utilities/Enumerate.hpp"
#include "Acts/Utilities/Logger.hpp"
#include "Acts/Utilities/TrackHelpers.hpp"
#include "ActsExamples/EventData/IndexSourceLink.hpp"
#include "ActsExamples/EventData/Measurement.hpp"
#include "ActsExamples/EventData/MeasurementCalibration.hpp"
#include "ActsExamples/EventData/SimSeed.hpp"
#include "ActsExamples/EventData/Track.hpp"
#include "ActsExamples/Framework/AlgorithmContext.hpp"
#include "ActsExamples/Framework/ProcessCode.hpp"
Expand All @@ -45,13 +49,85 @@
#include <ostream>
#include <stdexcept>
#include <system_error>
#include <unordered_map>
#include <utility>

#include <boost/histogram.hpp>
#include <boost/functional/hash.hpp>

ActsExamples::TrackFindingAlgorithm::TrackFindingAlgorithm(
Config config, Acts::Logging::Level level)
: ActsExamples::IAlgorithm("TrackFindingAlgorithm", level),
namespace ActsExamples {
namespace {

/// Source link indices of the bottom, middle, top measurements.
/// In case of strip seeds only the first source link of the pair is used.
using SeedIdentifier = std::array<Index, 3>;

/// Build a seed identifier from a seed.
///
/// @param seed The seed to build the identifier from.
/// @return The seed identifier.
SeedIdentifier makeSeedIdentifier(const SimSeed& seed) {
SeedIdentifier result;

for (const auto& [i, sp] : Acts::enumerate(seed.sp())) {
const Acts::SourceLink& firstSourceLink = sp->sourceLinks().front();
result.at(i) = firstSourceLink.get<IndexSourceLink>().index();
}

return result;
}

/// Visit all possible seed identifiers of a track.
///
/// @param track The track to visit the seed identifiers of.
/// @param visitor The visitor to call for each seed identifier.
template <typename Visitor>
void visitSeedIdentifiers(const TrackProxy& track, Visitor visitor) {
// first we collect the source link indices of the track states
std::vector<Index> sourceLinkIndices;
sourceLinkIndices.reserve(track.nMeasurements());
for (const auto& trackState : track.trackStatesReversed()) {
if (!trackState.hasUncalibratedSourceLink()) {
continue;
}
const Acts::SourceLink& sourceLink = trackState.getUncalibratedSourceLink();
sourceLinkIndices.push_back(sourceLink.get<IndexSourceLink>().index());
}

// then we iterate over all possible triplets and form seed identifiers
for (std::size_t i = 0; i < sourceLinkIndices.size(); ++i) {
for (std::size_t j = i + 1; j < sourceLinkIndices.size(); ++j) {
for (std::size_t k = j + 1; k < sourceLinkIndices.size(); ++k) {
// Putting them into reverse order (k, j, i) to compensate for the
// `trackStatesReversed` above.
visitor({sourceLinkIndices.at(k), sourceLinkIndices.at(j),
sourceLinkIndices.at(i)});
}
}
}
}
andiwand marked this conversation as resolved.
Show resolved Hide resolved

} // namespace
} // namespace ActsExamples

// Specialize std::hash for SeedIdentifier
// This is required to use SeedIdentifier as a key in an `std::unordered_map`.
template <class T, std::size_t N>
struct std::hash<std::array<T, N>> {
std::size_t operator()(const std::array<T, N>& array) const {
std::hash<T> hasher;
std::size_t result = 0;
for (auto&& element : array) {
boost::hash_combine(result, hasher(element));
}
return result;
}
};
andiwand marked this conversation as resolved.
Show resolved Hide resolved

namespace ActsExamples {

TrackFindingAlgorithm::TrackFindingAlgorithm(Config config,
Acts::Logging::Level level)
: IAlgorithm("TrackFindingAlgorithm", level),
m_cfg(std::move(config)),
m_trackSelector(
m_cfg.trackSelectorCfg.has_value()
Expand All @@ -75,18 +151,34 @@ ActsExamples::TrackFindingAlgorithm::TrackFindingAlgorithm(
throw std::invalid_argument("Missing tracks output collection");
}

if (m_cfg.seedDeduplication && m_cfg.inputSeeds.empty()) {
throw std::invalid_argument(
"Missing seeds input collection. This is "
"required for seed deduplication.");
}

m_inputMeasurements.initialize(m_cfg.inputMeasurements);
m_inputSourceLinks.initialize(m_cfg.inputSourceLinks);
m_inputInitialTrackParameters.initialize(m_cfg.inputInitialTrackParameters);
m_inputSeeds.maybeInitialize(m_cfg.inputSeeds);
m_outputTracks.initialize(m_cfg.outputTracks);
}

ActsExamples::ProcessCode ActsExamples::TrackFindingAlgorithm::execute(
const ActsExamples::AlgorithmContext& ctx) const {
ProcessCode TrackFindingAlgorithm::execute(const AlgorithmContext& ctx) const {
// Read input data
const auto& measurements = m_inputMeasurements(ctx);
const auto& sourceLinks = m_inputSourceLinks(ctx);
const auto& initialParameters = m_inputInitialTrackParameters(ctx);
const SimSeedContainer* seeds = nullptr;

if (m_inputSeeds.isInitialized()) {
seeds = &m_inputSeeds(ctx);

if (initialParameters.size() != seeds->size()) {
ACTS_ERROR("Number of initial parameters and seeds do not match. "
<< initialParameters.size() << " != " << seeds->size());
}
}

// Construct a perigee surface as the target surface
auto pSurface = Acts::Surface::makeShared<Acts::PerigeeSurface>(
Expand Down Expand Up @@ -123,7 +215,7 @@ ActsExamples::ProcessCode ActsExamples::TrackFindingAlgorithm::execute(
secondPropOptions.direction = firstPropOptions.direction.invert();

// Set the CombinatorialKalmanFilter options
ActsExamples::TrackFindingAlgorithm::TrackFinderOptions firstOptions(
TrackFindingAlgorithm::TrackFinderOptions firstOptions(
ctx.geoContext, ctx.magFieldContext, ctx.calibContext, slAccessorDelegate,
extensions, firstPropOptions);

Expand Down Expand Up @@ -162,7 +254,50 @@ ActsExamples::ProcessCode ActsExamples::TrackFindingAlgorithm::execute(

unsigned int nSeed = 0;

// A map indicating whether a seed has been discovered already
std::unordered_map<SeedIdentifier, bool> discoveredSeeds;

auto addTrack = [&](const TrackProxy& track) {
// flag seeds which are covered by the track
visitSeedIdentifiers(track, [&](const SeedIdentifier& seedIdentifier) {
if (auto it = discoveredSeeds.find(seedIdentifier);
it != discoveredSeeds.end()) {
it->second = true;
}
});

if (m_trackSelector.has_value() && !m_trackSelector->isValidTrack(track)) {
return;
}

auto destProxy = tracks.makeTrack();
// make sure we copy track states!
destProxy.copyFrom(track, true);
};

if (seeds != nullptr && m_cfg.seedDeduplication) {
// Index the seeds for deduplication
for (const auto& seed : *seeds) {
SeedIdentifier seedIdentifier = makeSeedIdentifier(seed);
discoveredSeeds.emplace(seedIdentifier, false);
}
}

for (std::size_t iSeed = 0; iSeed < initialParameters.size(); ++iSeed) {
m_nTotalSeeds++;

if (seeds != nullptr && m_cfg.seedDeduplication) {
const SimSeed& seed = seeds->at(iSeed);
SeedIdentifier seedIdentifier = makeSeedIdentifier(seed);
// check if the seed has been discovered already
if (auto it = discoveredSeeds.find(seedIdentifier);
it != discoveredSeeds.end() && it->second) {
m_nDeduplicatedSeeds++;
ACTS_VERBOSE("Skipping seed " << iSeed << " due to deduplication.");
continue;
}
}

// Clear trackContainerTemp and trackStateContainerTemp
tracksTemp.clear();

Expand All @@ -171,7 +306,7 @@ ActsExamples::ProcessCode ActsExamples::TrackFindingAlgorithm::execute(

auto firstResult =
(*m_cfg.findTracks)(firstInitialParameters, firstOptions, tracksTemp);
m_nTotalSeeds++;

nSeed++;

if (!firstResult.ok()) {
Expand Down Expand Up @@ -283,11 +418,7 @@ ActsExamples::ProcessCode ActsExamples::TrackFindingAlgorithm::execute(
continue;
}

if (!m_trackSelector.has_value() ||
m_trackSelector->isValidTrack(trackCandidate)) {
auto destProxy = tracks.makeTrack();
destProxy.copyFrom(trackCandidate, true);
}
addTrack(trackCandidate);

++nSecond;
}
Expand Down Expand Up @@ -315,11 +446,7 @@ ActsExamples::ProcessCode ActsExamples::TrackFindingAlgorithm::execute(
continue;
}

if (!m_trackSelector.has_value() ||
m_trackSelector->isValidTrack(trackCandidate)) {
auto destProxy = tracks.makeTrack();
destProxy.copyFrom(trackCandidate, true);
}
addTrack(trackCandidate);
}
}
}
Expand All @@ -346,12 +473,13 @@ ActsExamples::ProcessCode ActsExamples::TrackFindingAlgorithm::execute(
constTrackStateContainer};

m_outputTracks(ctx, std::move(constTracks));
return ActsExamples::ProcessCode::SUCCESS;
return ProcessCode::SUCCESS;
}

ActsExamples::ProcessCode ActsExamples::TrackFindingAlgorithm::finalize() {
ProcessCode TrackFindingAlgorithm::finalize() {
ACTS_INFO("TrackFindingAlgorithm statistics:");
ACTS_INFO("- total seeds: " << m_nTotalSeeds);
ACTS_INFO("- deduplicated seeds: " << m_nDeduplicatedSeeds);
ACTS_INFO("- failed seeds: " << m_nFailedSeeds);
ACTS_INFO("- failed smoothing: " << m_nFailedSmoothing);
ACTS_INFO("- failed extrapolation: " << m_nFailedExtrapolation);
Expand All @@ -369,3 +497,5 @@ ActsExamples::ProcessCode ActsExamples::TrackFindingAlgorithm::finalize() {
ACTS_DEBUG("Track State memory statistics (averaged):\n" << ss.str());
return ProcessCode::SUCCESS;
}

} // namespace ActsExamples
Loading
Loading