Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor!: Refactor CKF branch stopper to allow stop and keep tracks #3102

Merged
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 26 additions & 12 deletions Core/include/Acts/TrackFinding/CombinatorialKalmanFilter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,14 +73,21 @@ struct CombinatorialKalmanFilterExtensions {
using candidate_container_t =
typename std::vector<typename traj_t::TrackStateProxy>;

enum class BranchStopperResult {
andiwand marked this conversation as resolved.
Show resolved Hide resolved
Continue,
StopAndDrop,
StopAndKeep,
};

using Calibrator = typename KalmanFitterExtensions<traj_t>::Calibrator;
using Updater = typename KalmanFitterExtensions<traj_t>::Updater;
using MeasurementSelector =
Delegate<Result<std::pair<typename candidate_container_t::iterator,
typename candidate_container_t::iterator>>(
candidate_container_t& trackStates, bool&, const Logger&)>;
using BranchStopper = Delegate<bool(const CombinatorialKalmanFilterTipState&,
typename traj_t::TrackStateProxy&)>;
using BranchStopper =
Delegate<BranchStopperResult(const CombinatorialKalmanFilterTipState&,
typename traj_t::TrackStateProxy&)>;

/// The Calibrator is a dedicated calibration algorithm that allows to
/// calibrate measurements using track information, this could be e.g. sagging
Expand Down Expand Up @@ -112,10 +119,10 @@ struct CombinatorialKalmanFilterExtensions {

/// Default branch stopper which will never stop
/// @return false
static bool voidBranchStopper(
static BranchStopperResult voidBranchStopper(
const CombinatorialKalmanFilterTipState& /*tipState*/,
typename traj_t::TrackStateProxy& /*trackState*/) {
return false;
return BranchStopperResult::Continue;
}
};

Expand Down Expand Up @@ -694,13 +701,23 @@ class CombinatorialKalmanFilter {
auto nonSourcelinkState =
result.fittedStates->getTrackState(currentTip);

using BranchStopperResult =
typename CombinatorialKalmanFilterExtensions<
traj_t>::BranchStopperResult;
BranchStopperResult branchStopperResult =
m_extensions.branchStopper(tipState, nonSourcelinkState);

// Check the branch
if (!m_extensions.branchStopper(tipState, nonSourcelinkState)) {
if (branchStopperResult == BranchStopperResult::Continue) {
// Remember the active tip and its state
result.activeTips.emplace_back(currentTip, tipState);
} else {
// No branch on this surface
nBranchesOnSurface = 0;

if (branchStopperResult == BranchStopperResult::StopAndKeep) {
result.lastMeasurementIndices.emplace_back(currentTip);
}
}
}

Expand Down Expand Up @@ -904,13 +921,10 @@ class CombinatorialKalmanFilter {
tipState.nMeasurements++;
}

// Check if need to stop this branch
if (!m_extensions.branchStopper(tipState, trackState)) {
andiwand marked this conversation as resolved.
Show resolved Hide resolved
// Put tipstate back into active tips to continue with it
result.activeTips.emplace_back(currentTip, tipState);
// Record the number of branches on surface
nBranchesOnSurface++;
}
// Put tipstate back into active tips to continue with it
result.activeTips.emplace_back(currentTip, tipState);
// Record the number of branches on surface
nBranchesOnSurface++;
}
return Result<void>::success();
}
Expand Down
38 changes: 13 additions & 25 deletions Examples/Algorithms/TrackFinding/src/TrackFindingAlgorithm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,16 +179,18 @@ class BranchStopper {
using Config =
std::optional<std::variant<Acts::TrackSelector::Config,
Acts::TrackSelector::EtaBinnedConfig>>;
using BranchStopperResult = Acts::CombinatorialKalmanFilterExtensions<
Acts::VectorMultiTrajectory>::BranchStopperResult;

mutable std::atomic<std::size_t> m_nStoppedBranches{0};

explicit BranchStopper(const Config& config) : m_config(config) {}

bool operator()(
BranchStopperResult operator()(
const Acts::CombinatorialKalmanFilterTipState& tipState,
Acts::VectorMultiTrajectory::TrackStateProxy& trackState) const {
if (!m_config.has_value()) {
return false;
return BranchStopperResult::Continue;
}

const Acts::TrackSelector::Config* singleConfig = std::visit(
Expand All @@ -207,35 +209,21 @@ class BranchStopper {

if (singleConfig == nullptr) {
++m_nStoppedBranches;
return true;
return BranchStopperResult::StopAndDrop;
}

// Continue if the number of holes is below the maximum
if (tipState.nHoles <= singleConfig->maxHoles) {
return false;
}

// Continue if the number of outliers is below the maximum
if (tipState.nOutliers <= singleConfig->maxOutliers) {
return false;
}
bool enoughMeasurements =
tipState.nMeasurements >= singleConfig->minMeasurements;
bool tooManyHoles = tipState.nHoles > singleConfig->maxHoles;
bool tooManyOutliers = tipState.nOutliers > singleConfig->maxOutliers;

// If there are not enough measurements but more holes than allowed we stop
if (tipState.nMeasurements < singleConfig->minMeasurements) {
if (tooManyHoles || tooManyOutliers) {
++m_nStoppedBranches;
return true;
return enoughMeasurements ? BranchStopperResult::StopAndKeep
: BranchStopperResult::StopAndDrop;
}

// Getting another measurement guarantees that the holes are in the middle
// of the track
if (trackState.typeFlags().test(Acts::TrackStateFlag::MeasurementFlag)) {
++m_nStoppedBranches;
return true;
}

// We cannot be sure if the holes are just at the end of the track so we
// have to keep going
return false;
return BranchStopperResult::Continue;
}

private:
Expand Down
Loading