Skip to content

Commit c385f81

Browse files
authored
Refactor BackwardProblem for presimulation (#2780)
Refactor BackwardProblem to implement presimulation including event-handling (#2775). Similar to #2777 for the ForwardProblem: * Extract some state from BackwardProblem * Extract a EventHandlingBwdSimulator that handles simulation and discontinuities * Those will be reused for presimulation shortly * Remove some unused getters from BackwardProblem
1 parent a9e3224 commit c385f81

File tree

4 files changed

+273
-148
lines changed

4 files changed

+273
-148
lines changed

include/amici/backwardproblem.h

Lines changed: 129 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -8,89 +8,108 @@
88
#include <vector>
99

1010
namespace amici {
11-
1211
class ExpData;
1312
class Solver;
1413
class Model;
1514
class ForwardProblem;
1615
class SteadystateProblem;
1716

18-
//! class to solve backwards problems.
19-
/*!
20-
solves the backwards problem for adjoint sensitivity analysis and handles
21-
events and data-points
22-
*/
23-
24-
class BackwardProblem {
25-
public:
26-
/**
27-
* @brief Construct backward problem from forward problem
28-
* @param fwd pointer to corresponding forward problem
29-
*/
30-
explicit BackwardProblem(ForwardProblem& fwd);
31-
17+
/**
18+
* @brief The BwdSimWorkspace class is used to store temporary simulation
19+
* state during backward simulations.
20+
*/
21+
struct BwdSimWorkspace {
3222
/**
33-
* @brief Solve the backward problem.
34-
*
35-
* If adjoint sensitivities are enabled this will also compute
36-
* sensitivities. workForwardProblem must be called before this is
37-
* function is called.
23+
* @brief Constructor
24+
* @param model The model for which to set up the workspace.
25+
* @param solver The solver for which to set up this workspace.
3826
*/
39-
void workBackwardProblem();
27+
BwdSimWorkspace(
28+
gsl::not_null<Model*> model, gsl::not_null<Solver const*> solver
29+
);
4030

41-
/**
42-
* @brief Accessor for current time t
43-
* @return t
44-
*/
45-
realtype gett() const { return t_; }
46-
47-
/**
48-
* @brief Accessor for which
49-
* @return which
50-
*/
51-
int getwhich() const { return which; }
31+
/** The model. */
32+
Model* model_;
5233

53-
/**
54-
* @brief Accessor for pointer to which
55-
* @return which
56-
*/
57-
int* getwhichptr() { return &which; }
34+
/** adjoint state vector */
35+
AmiVector xB_;
36+
/** differential adjoint state vector */
37+
AmiVector dxB_;
38+
/** quadrature state vector */
39+
AmiVector xQB_;
5840

59-
/**
60-
* @brief Accessor for dJydx
61-
* @return dJydx
62-
*/
63-
std::vector<realtype> const& getdJydx() const { return dJydx_; }
41+
/** array of number of found roots for a certain event type */
42+
std::vector<int> nroots_;
43+
/** array containing the time-points of discontinuities*/
44+
std::vector<Discontinuity> discs_;
45+
/** index of the backward problem */
46+
int which = 0;
47+
};
6448

49+
/**
50+
* @brief The EventHandlingBwdSimulator class runs a backward simulation
51+
* and processes events and measurements general.
52+
*/
53+
class EventHandlingBwdSimulator {
54+
public:
6555
/**
66-
* @brief Accessor for xB
67-
* @return xB
56+
* @brief EventHandlingBwdSimulator constructor.
57+
* @param model The model to simulate.
58+
* @param solver The solver to use for the simulation.
59+
* @param ws The workspace to use for the simulation.
6860
*/
69-
AmiVector const& getAdjointState() const { return xB_; }
61+
EventHandlingBwdSimulator(
62+
gsl::not_null<Model*> model, gsl::not_null<Solver*> solver,
63+
gsl::not_null<BwdSimWorkspace*> ws
64+
)
65+
: model_(model)
66+
, solver_(solver)
67+
, ws_(ws) {};
7068

7169
/**
72-
* @brief Accessor for xQB
73-
* @return xQB
70+
* @brief Run the simulation.
71+
*
72+
* It will run the backward simulation from the initial time of this period
73+
* to the final timepoint of this period, handling events
74+
* and data points as they occur.
75+
*
76+
* Expects the model and the solver to be set up, and `ws` to be initialized
77+
* for this period.
78+
*
79+
* @param t_start The initial time of this period.
80+
* @param t_end The final time of this period.
81+
* @param it The index of the timepoint in `timepoints` to start with.
82+
* @param timepoints The output timepoints or measurement timepoints of
83+
* this period. This must contain at least the final timepoint of this
84+
* period.
85+
* @param dJydx State-derivative of data likelihood. Must be non-null if
86+
* there are any data points in this period.
87+
* @param dJzdx State-derivative of event likelihood. Must be non-null if
88+
* the model has any event-observables.
7489
*/
75-
AmiVector const& getAdjointQuadrature() const { return xQB_; }
90+
void
91+
run(realtype t_start, realtype t_end, realtype it,
92+
std::vector<realtype> const& timepoints,
93+
std::vector<realtype> const* dJydx, std::vector<realtype> const* dJzdx);
7694

7795
private:
78-
void handlePostequilibration();
79-
8096
/**
8197
* @brief Execute everything necessary for the handling of events
8298
* for the backward problem
8399
* @param disc The discontinuity to handle
100+
* @param dJzdx State-derivative of event likelihood
84101
*/
85-
void handleEventB(Discontinuity const& disc);
102+
void
103+
handleEventB(Discontinuity const& disc, std::vector<realtype> const* dJzdx);
86104

87105
/**
88106
* @brief Execute everything necessary for the handling of data
89107
* points for the backward problems
90108
*
91109
* @param it index of data point
110+
* @param dJydx State-derivative of data likelihood
92111
*/
93-
void handleDataPointB(int it);
112+
void handleDataPointB(int it, std::vector<realtype> const* dJydx);
94113

95114
/**
96115
* @brief Compute the next timepoint to integrate to.
@@ -103,26 +122,69 @@ class BackwardProblem {
103122
*/
104123
realtype getTnext(int it);
105124

125+
/** The model to simulate. */
126+
Model* model_;
127+
128+
/** The solver to use for the simulation. */
129+
Solver* solver_;
130+
131+
/** The workspace to use for the simulation. */
132+
gsl::not_null<BwdSimWorkspace*> ws_;
133+
134+
/** current time */
135+
realtype t_{0};
136+
};
137+
138+
//! class to solve backwards problems.
139+
/*!
140+
solves the backwards problem for adjoint sensitivity analysis and handles
141+
events and data-points
142+
*/
143+
144+
class BackwardProblem {
145+
public:
146+
/**
147+
* @brief Construct backward problem from forward problem
148+
* @param fwd pointer to corresponding forward problem
149+
*/
150+
explicit BackwardProblem(ForwardProblem& fwd);
151+
152+
/**
153+
* @brief Solve the backward problem.
154+
*
155+
* If adjoint sensitivities are enabled, this will also compute
156+
* sensitivities. workForwardProblem must be called before this function is
157+
* called.
158+
*/
159+
void workBackwardProblem();
160+
161+
/**
162+
* @brief Accessor for xB
163+
* @return xB
164+
*/
165+
[[nodiscard]] AmiVector const& getAdjointState() const { return ws_.xB_; }
166+
167+
/**
168+
* @brief Accessor for xQB
169+
* @return xQB
170+
*/
171+
[[nodiscard]] AmiVector const& getAdjointQuadrature() const {
172+
return ws_.xQB_;
173+
}
174+
175+
private:
176+
void handlePostequilibration();
177+
106178
Model* model_;
107179
Solver* solver_;
108180
ExpData const* edata_;
109181

110182
/** current time */
111183
realtype t_;
112-
/** adjoint state vector */
113-
AmiVector xB_;
114-
/** differential adjoint state vector */
115-
AmiVector dxB_;
116-
/** quadrature state vector */
117-
AmiVector xQB_;
118184
/** sensitivity state vector array */
119185
AmiVectorArray sx0_;
120-
/** array of number of found roots for a certain event type */
121-
std::vector<int> nroots_;
122186
/** array containing the time-points of discontinuities*/
123187
std::vector<Discontinuity> discs_;
124-
/** index of the backward problem */
125-
int which = 0;
126188

127189
/** state derivative of data likelihood */
128190
std::vector<realtype> dJydx_;
@@ -134,8 +196,12 @@ class BackwardProblem {
134196

135197
/** The postequilibration steadystate problem from the forward problem. */
136198
SteadystateProblem* posteq_problem_;
199+
200+
BwdSimWorkspace ws_;
201+
202+
EventHandlingBwdSimulator simulator_;
137203
};
138204

139205
} // namespace amici
140206

141-
#endif // BACKWARDPROBLEM_H
207+
#endif // AMICI_BACKWARDPROBLEM_H

include/amici/forwardproblem.h

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,16 @@ struct Discontinuity {
6363
std::vector<int> root_info;
6464
};
6565

66+
/**
67+
* @brief Compute the number of roots for each root function from a vector of
68+
* discontinuities.
69+
* @param discs Encountered discontinuities.
70+
* @param ne Number of root functions (ne).
71+
* @param nmaxevents Maximum number of events to track (nmaxevents).
72+
* @return The number of roots for each root function.
73+
*/
74+
std::vector<int> compute_nroots(std::vector<Discontinuity> const& discs, int ne, int nmaxevents);
75+
6676
/**
6777
* @brief The ForwardProblem class groups all functions for solving the
6878
* forward problem.
@@ -127,12 +137,6 @@ class ForwardProblem {
127137
*/
128138
AmiVectorArray const& getStateSensitivity() const { return sx_; }
129139

130-
/**
131-
* @brief Accessor for nroots
132-
* @return nroots
133-
*/
134-
std::vector<int> const& getNumberOfRoots() const { return nroots_; }
135-
136140
/**
137141
* @brief Get information on the discontinuities encountered so far.
138142
* @return The vector of discontinuities.

0 commit comments

Comments
 (0)