Skip to content

Commit 0e33004

Browse files
authored
Refactor storage of discontinuities (#2763)
Store together what belongs together.
1 parent 841336b commit 0e33004

File tree

5 files changed

+76
-109
lines changed

5 files changed

+76
-109
lines changed

include/amici/backwardproblem.h

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#define AMICI_BACKWARDPROBLEM_H
33

44
#include "amici/defines.h"
5+
#include "amici/forwardproblem.h"
56
#include "amici/vector.h"
67

78
#include <vector>
@@ -80,8 +81,9 @@ class BackwardProblem {
8081
/**
8182
* @brief Execute everything necessary for the handling of events
8283
* for the backward problem
84+
* @param disc The discontinuity to handle
8385
*/
84-
void handleEventB();
86+
void handleEventB(Discontinuity const& disc);
8587

8688
/**
8789
* @brief Execute everything necessary for the handling of data
@@ -114,22 +116,14 @@ class BackwardProblem {
114116
AmiVector dxB_;
115117
/** quadrature state vector */
116118
AmiVector xQB_;
117-
/** array of state vectors at discontinuities*/
118-
std::vector<AmiVector> x_disc_;
119-
/** array of differential state vectors at discontinuities*/
120-
std::vector<AmiVector> xdot_disc_;
121-
/** array of old differential state vectors at discontinuities*/
122-
std::vector<AmiVector> xdot_old_disc_;
123119
/** sensitivity state vector array */
124120
AmiVectorArray sx0_;
125121
/** array of number of found roots for a certain event type */
126122
std::vector<int> nroots_;
127123
/** array containing the time-points of discontinuities*/
128-
std::vector<realtype> discs_;
124+
std::vector<Discontinuity> discs_;
129125
/** index of the backward problem */
130126
int which = 0;
131-
/** array of index which root has been found */
132-
std::vector<std::vector<int>> root_idx_;
133127

134128
/** state derivative of data likelihood */
135129
std::vector<realtype> dJydx_;

include/amici/forwardproblem.h

Lines changed: 53 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,52 @@ class Solver;
1616
class SteadystateProblem;
1717
class FinalStateStorer;
1818

19+
/**
20+
* @brief Data structure to store some state of a simulation at a discontinuity.
21+
*/
22+
struct Discontinuity {
23+
/**
24+
* @brief Constructor.
25+
* @param time
26+
* @param root_info
27+
* @param xdot_pre
28+
* @param x_post
29+
* @param xdot_post
30+
*/
31+
explicit Discontinuity(
32+
realtype time, std::vector<int> const& root_info = std::vector<int>(),
33+
AmiVector const& xdot_pre = AmiVector(),
34+
AmiVector const& x_post = AmiVector(),
35+
AmiVector const& xdot_post = AmiVector()
36+
)
37+
: time(time)
38+
, x_post(x_post)
39+
, xdot_post(xdot_post)
40+
, xdot_pre(xdot_pre)
41+
, root_info(root_info) {}
42+
43+
/** Time of discontinuity. */
44+
realtype time;
45+
46+
/** Post-event state vector (dimension nx). */
47+
AmiVector x_post;
48+
49+
/** Post-event differential state vectors (dimension nx). */
50+
AmiVector xdot_post;
51+
52+
/** Pre-event differential state vectors (dimension nx). */
53+
AmiVector xdot_pre;
54+
55+
/**
56+
* @brief Array of flags indicating which root has been found.
57+
*
58+
* Array of length nr (ne) with the indices of the user functions gi found
59+
* to have a root. For i = 0, . . . ,nr 1 or -1 if gi has a root, and = 0
60+
* if not. See CVodeGetRootInfo for details.
61+
*/
62+
std::vector<int> root_info;
63+
};
64+
1965
/**
2066
* @brief The ForwardProblem class groups all functions for solving the
2167
* forward problem.
@@ -80,48 +126,18 @@ class ForwardProblem {
80126
*/
81127
AmiVectorArray const& getStateSensitivity() const { return sx_; }
82128

83-
/**
84-
* @brief Accessor for x_disc
85-
* @return x_disc
86-
*/
87-
std::vector<AmiVector> const& getStatesAtDiscontinuities() const {
88-
return x_disc_;
89-
}
90-
91-
/**
92-
* @brief Accessor for xdot_disc
93-
* @return xdot_disc
94-
*/
95-
std::vector<AmiVector> const& getRHSAtDiscontinuities() const {
96-
return xdot_disc_;
97-
}
98-
99-
/**
100-
* @brief Accessor for xdot_old_disc
101-
* @return xdot_old_disc
102-
*/
103-
std::vector<AmiVector> const& getRHSBeforeDiscontinuities() const {
104-
return xdot_old_disc_;
105-
}
106-
107129
/**
108130
* @brief Accessor for nroots
109131
* @return nroots
110132
*/
111133
std::vector<int> const& getNumberOfRoots() const { return nroots_; }
112134

113135
/**
114-
* @brief Accessor for discs
115-
* @return discs
116-
*/
117-
std::vector<realtype> const& getDiscontinuities() const { return discs_; }
118-
119-
/**
120-
* @brief Accessor for rootidx
121-
* @return rootidx
136+
* @brief Get information on the discontinuities encountered so far.
137+
* @return The vector of discontinuities.
122138
*/
123-
std::vector<std::vector<int>> const& getRootIndexes() const {
124-
return root_idx_;
139+
std::vector<Discontinuity> const& getDiscontinuities() const {
140+
return discs_;
125141
}
126142

127143
/**
@@ -294,7 +310,7 @@ class ForwardProblem {
294310
*/
295311
void fillEvents(int nmaxevent) {
296312
if (checkEventsToFill(nmaxevent)) {
297-
discs_.push_back(t_);
313+
discs_.emplace_back(t_);
298314
storeEvent();
299315
}
300316
}
@@ -305,11 +321,6 @@ class ForwardProblem {
305321
*/
306322
SimulationState getSimulationState();
307323

308-
/** array of index vectors (dimension ne) indicating whether the respective
309-
* root has been detected for all so far encountered discontinuities,
310-
* extended as needed (dimension: dynamic) */
311-
std::vector<std::vector<int>> root_idx_;
312-
313324
/** array of number of found roots for a certain event type
314325
* (dimension: ne) */
315326
std::vector<int> nroots_;
@@ -321,30 +332,8 @@ class ForwardProblem {
321332
* (dimension: ne) */
322333
std::vector<realtype> rval_tmp_;
323334

324-
/** array containing the time-points of discontinuities
325-
* (dimension: dynamic) */
326-
std::vector<realtype> discs_;
327-
328-
/**
329-
* array of post-event state vectors (dimension nx) for all so far
330-
* encountered discontinuities, (extended as needed;
331-
* after event processing, same dimension as discs_)
332-
*/
333-
std::vector<AmiVector> x_disc_;
334-
335-
/**
336-
* array of post-event differential state vectors (dimension nx) for all so
337-
* far encountered discontinuities, (extended as needed; after event
338-
* processing, same dimension as discs_)
339-
*/
340-
std::vector<AmiVector> xdot_disc_;
341-
342-
/**
343-
* array of old (pre-event) differential state vectors (dimension nx) for
344-
* all so far encountered discontinuities, (extended as needed; after event
345-
* processing, same dimension as discs_)
346-
*/
347-
std::vector<AmiVector> xdot_old_disc_;
335+
/** Discontinuities encountered so far (dimension: dynamic) */
336+
std::vector<Discontinuity> discs_;
348337

349338
/** Events that are waiting to be handled at the current timepoint. */
350339
EventQueue pending_events_;

src/backwardproblem.cpp

Lines changed: 9 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,9 @@ BackwardProblem::BackwardProblem(
2020
, xB_(fwd.model->nx_solver, solver_->getSunContext())
2121
, dxB_(fwd.model->nx_solver, solver_->getSunContext())
2222
, xQB_(fwd.model->nJ * fwd.model->nplist(), solver_->getSunContext())
23-
, x_disc_(fwd.getStatesAtDiscontinuities())
24-
, xdot_disc_(fwd.getRHSAtDiscontinuities())
25-
, xdot_old_disc_(fwd.getRHSBeforeDiscontinuities())
2623
, sx0_(fwd.getStateSensitivity())
2724
, nroots_(fwd.getNumberOfRoots())
2825
, discs_(fwd.getDiscontinuities())
29-
, root_idx_(fwd.getRootIndexes())
3026
, dJydx_(fwd.getDJydx())
3127
, dJzdx_(fwd.getDJzdx()) {
3228
/* complement dJydx from postequilibration. This shouldn't overwrite
@@ -95,9 +91,9 @@ void BackwardProblem::workBackwardProblem() {
9591
}
9692

9793
/* handle discontinuity */
98-
if (!discs_.empty() && tnext == discs_.back()) {
94+
if (!discs_.empty() && tnext == discs_.back().time) {
95+
handleEventB(discs_.back());
9996
discs_.pop_back();
100-
handleEventB();
10197
}
10298

10399
/* handle data-point */
@@ -128,30 +124,18 @@ void BackwardProblem::workBackwardProblem() {
128124
}
129125
}
130126

131-
void BackwardProblem::handleEventB() {
132-
auto rootidx = root_idx_.back();
133-
this->root_idx_.pop_back();
134-
135-
auto x_disc = this->x_disc_.back();
136-
this->x_disc_.pop_back();
137-
138-
auto xdot_disc = this->xdot_disc_.back();
139-
this->xdot_disc_.pop_back();
140-
141-
auto xdot_old_disc = this->xdot_old_disc_.back();
142-
this->xdot_old_disc_.pop_back();
143-
127+
void BackwardProblem::handleEventB(Discontinuity const& disc) {
144128
for (int ie = 0; ie < model_->ne; ie++) {
145129

146-
if (rootidx[ie] == 0) {
130+
if (disc.root_info[ie] == 0) {
147131
continue;
148132
}
149133

150134
model_->addAdjointQuadratureEventUpdate(
151-
xQB_, ie, t_, x_disc, xB_, xdot_disc, xdot_old_disc
135+
xQB_, ie, t_, disc.x_post, xB_, disc.xdot_post, disc.xdot_pre
152136
);
153137
model_->addAdjointStateEventUpdate(
154-
xB_, ie, t_, x_disc, xdot_disc, xdot_old_disc
138+
xB_, ie, t_, disc.x_post, disc.xdot_post, disc.xdot_pre
155139
);
156140

157141
if (model_->nz > 0) {
@@ -167,7 +151,7 @@ void BackwardProblem::handleEventB() {
167151
nroots_[ie]--;
168152
}
169153

170-
model_->updateHeavisideB(rootidx.data());
154+
model_->updateHeavisideB(disc.root_info.data());
171155
}
172156

173157
void BackwardProblem::handleDataPointB(int const it) {
@@ -196,8 +180,8 @@ realtype BackwardProblem::getTnext(int const it) {
196180
}
197181

198182
if (!discs_.empty()
199-
&& (it < 0 || discs_.back() > model_->getTimepoint(it))) {
200-
double tdisc = discs_.back();
183+
&& (it < 0 || discs_.back().time > model_->getTimepoint(it))) {
184+
double tdisc = discs_.back().time;
201185
return tdisc;
202186
}
203187

src/forwardproblem.cpp

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -242,8 +242,7 @@ void ForwardProblem::handleEvent(
242242

243243
// store timepoint at which the event occurred, the root function
244244
// values, and the direction of any zero crossings of the root function
245-
discs_.push_back(t_);
246-
root_idx_.push_back(roots_found_);
245+
discs_.emplace_back(t_, roots_found_);
247246
rval_tmp_ = rootvals_;
248247

249248
if (model->nz > 0)
@@ -261,9 +260,9 @@ void ForwardProblem::handleEvent(
261260
// that did not trigger a secondary event
262261
auto store_post_event_info = [this]() {
263262
if (solver->computingASA()) {
264-
// store x to compute jump in discontinuity
265-
x_disc_.push_back(x_);
266-
xdot_disc_.push_back(xdot_);
263+
// store updated x to compute jump in discontinuity
264+
discs_.back().x_post = x_;
265+
discs_.back().xdot_post = xdot_;
267266
}
268267
};
269268

@@ -344,7 +343,7 @@ void ForwardProblem::storeEvent() {
344343
for (int ie = 0; ie < model->ne; ie++) {
345344
roots_found_.at(ie) = (nroots_.at(ie) < model->nMaxEvent()) ? 1 : 0;
346345
}
347-
root_idx_.push_back(roots_found_);
346+
discs_.back().root_info = roots_found_;
348347
}
349348

350349
if (getRootCounter() < getEventCounter()) {
@@ -409,7 +408,7 @@ void ForwardProblem::store_pre_event_state(bool seflag, bool initial_event) {
409408
std::ranges::fill(stau_, 0.0);
410409
}
411410
} else if (solver->computingASA()) {
412-
xdot_old_disc_.push_back(xdot_old_);
411+
discs_.back().xdot_pre = xdot_old_;
413412
}
414413
}
415414

src/rdata.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -291,13 +291,14 @@ void ReturnData::processForwardProblem(
291291

292292
// process event data
293293
if (nz > 0) {
294-
auto rootidx = fwd.getRootIndexes();
294+
auto const& discontinuities = fwd.getDiscontinuities();
295+
Expects(static_cast<int>(discontinuities.size()) == fwd.getEventCounter() + 1);
295296
for (int iroot = 0; iroot <= fwd.getEventCounter(); iroot++) {
296297
auto const simulation_state = fwd.getSimulationStateEvent(iroot);
297298
model.setModelState(simulation_state.state);
298299
getEventOutput(
299-
simulation_state.t, rootidx.at(iroot), model, simulation_state,
300-
edata
300+
simulation_state.t, discontinuities.at(iroot).root_info, model,
301+
simulation_state, edata
301302
);
302303
}
303304
}

0 commit comments

Comments
 (0)