Skip to content

Commit

Permalink
Fix crashes when errors occur at output timepoints (#2555)
Browse files Browse the repository at this point in the history
Fixes a bug that lead to program termination if a root-after-reinitialization error (potentially also others) occurred at an output timepoint, because an non-existing/invalid SimulationState for that timepoint was accessed. See #2491 for further details.

Fixes #2491.

Also avoid some unnecessary copying (during which previously the segfault occurred if this bug triggered in non-debug builds).
  • Loading branch information
dweindl authored Oct 20, 2024
1 parent a82f6d4 commit 9f56266
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 5 deletions.
17 changes: 15 additions & 2 deletions include/amici/forwardproblem.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define AMICI_FORWARDPROBLEM_H

#include "amici/defines.h"
#include "amici/edata.h"
#include "amici/misc.h"
#include "amici/model.h"
#include "amici/vector.h"
Expand Down Expand Up @@ -195,7 +196,7 @@ class ForwardProblem {
if (model->getTimepoint(it) == initial_state_.t)
return getInitialSimulationState();
auto map_iter = timepoint_states_.find(model->getTimepoint(it));
assert(map_iter != timepoint_states_.end());
Ensures(map_iter != timepoint_states_.end());
return map_iter->second;
};

Expand Down Expand Up @@ -441,8 +442,20 @@ class FinalStateStorer : public ContextManager {
* @brief destructor, stores simulation state
*/
~FinalStateStorer() {
if (fwd_)
if (fwd_) {
fwd_->final_state_ = fwd_->getSimulationState();
// if there is an associated output timepoint, also store it in
// timepoint_states if it's not present there.
// this may happen if there is an error just at
// (or indistinguishably before) an output timepoint
auto final_time = fwd_->getFinalTime();
auto const timepoints = fwd_->model->getTimepoints();
if (!fwd_->timepoint_states_.count(final_time)
&& std::find(timepoints.cbegin(), timepoints.cend(), final_time)
!= timepoints.cend()) {
fwd_->timepoint_states_[final_time] = fwd_->final_state_;
}
}
}

private:
Expand Down
6 changes: 3 additions & 3 deletions src/rdata.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ void ReturnData::processForwardProblem(
if (edata)
initializeObjectiveFunction(model.hasQuadraticLLH());

auto initialState = fwd.getInitialSimulationState();
auto const& initialState = fwd.getInitialSimulationState();
if (initialState.x.getLength() == 0 && model.nx_solver > 0)
return; // if x wasn't set forward problem failed during initialization

Expand All @@ -259,7 +259,7 @@ void ReturnData::processForwardProblem(
realtype tf = fwd.getFinalTime();
for (int it = 0; it < model.nt(); it++) {
if (model.getTimepoint(it) <= tf) {
auto simulation_state = fwd.getSimulationStateTimepoint(it);
auto const simulation_state = fwd.getSimulationStateTimepoint(it);
model.setModelState(simulation_state.state);
getDataOutput(it, model, simulation_state, edata);
} else {
Expand All @@ -273,7 +273,7 @@ void ReturnData::processForwardProblem(
if (nz > 0) {
auto rootidx = fwd.getRootIndexes();
for (int iroot = 0; iroot <= fwd.getEventCounter(); iroot++) {
auto simulation_state = fwd.getSimulationStateEvent(iroot);
auto const simulation_state = fwd.getSimulationStateEvent(iroot);
model.setModelState(simulation_state.state);
getEventOutput(
simulation_state.t, rootidx.at(iroot), model, simulation_state,
Expand Down

0 comments on commit 9f56266

Please sign in to comment.