8
8
#include < vector>
9
9
10
10
namespace amici {
11
-
12
11
class ExpData ;
13
12
class Solver ;
14
13
class Model ;
15
14
class ForwardProblem ;
16
15
class SteadystateProblem ;
17
16
18
- // ! class to solve backwards problems.
19
- /* !
20
- solves the backwards problem for adjoint sensitivity analysis and handles
21
- events and data-points
22
- */
23
-
24
- class BackwardProblem {
25
- public:
26
- /* *
27
- * @brief Construct backward problem from forward problem
28
- * @param fwd pointer to corresponding forward problem
29
- */
30
- explicit BackwardProblem (ForwardProblem& fwd);
31
-
17
+ /* *
18
+ * @brief The BwdSimWorkspace class is used to store temporary simulation
19
+ * state during backward simulations.
20
+ */
21
+ struct BwdSimWorkspace {
32
22
/* *
33
- * @brief Solve the backward problem.
34
- *
35
- * If adjoint sensitivities are enabled this will also compute
36
- * sensitivities. workForwardProblem must be called before this is
37
- * function is called.
23
+ * @brief Constructor
24
+ * @param model The model for which to set up the workspace.
25
+ * @param solver The solver for which to set up this workspace.
38
26
*/
39
- void workBackwardProblem ();
27
+ BwdSimWorkspace (
28
+ gsl::not_null<Model*> model, gsl::not_null<Solver const *> solver
29
+ );
40
30
41
- /* *
42
- * @brief Accessor for current time t
43
- * @return t
44
- */
45
- realtype gett () const { return t_; }
46
-
47
- /* *
48
- * @brief Accessor for which
49
- * @return which
50
- */
51
- int getwhich () const { return which; }
31
+ /* * The model. */
32
+ Model* model_;
52
33
53
- /* *
54
- * @brief Accessor for pointer to which
55
- * @return which
56
- */
57
- int * getwhichptr () { return &which; }
34
+ /* * adjoint state vector */
35
+ AmiVector xB_;
36
+ /* * differential adjoint state vector */
37
+ AmiVector dxB_;
38
+ /* * quadrature state vector */
39
+ AmiVector xQB_;
58
40
59
- /* *
60
- * @brief Accessor for dJydx
61
- * @return dJydx
62
- */
63
- std::vector<realtype> const & getdJydx () const { return dJydx_; }
41
+ /* * array of number of found roots for a certain event type */
42
+ std::vector<int > nroots_;
43
+ /* * array containing the time-points of discontinuities*/
44
+ std::vector<Discontinuity> discs_;
45
+ /* * index of the backward problem */
46
+ int which = 0 ;
47
+ };
64
48
49
+ /* *
50
+ * @brief The EventHandlingBwdSimulator class runs a backward simulation
51
+ * and processes events and measurements general.
52
+ */
53
+ class EventHandlingBwdSimulator {
54
+ public:
65
55
/* *
66
- * @brief Accessor for xB
67
- * @return xB
56
+ * @brief EventHandlingBwdSimulator constructor.
57
+ * @param model The model to simulate.
58
+ * @param solver The solver to use for the simulation.
59
+ * @param ws The workspace to use for the simulation.
68
60
*/
69
- AmiVector const & getAdjointState () const { return xB_; }
61
+ EventHandlingBwdSimulator (
62
+ gsl::not_null<Model*> model, gsl::not_null<Solver*> solver,
63
+ gsl::not_null<BwdSimWorkspace*> ws
64
+ )
65
+ : model_(model)
66
+ , solver_(solver)
67
+ , ws_(ws) {};
70
68
71
69
/* *
72
- * @brief Accessor for xQB
73
- * @return xQB
70
+ * @brief Run the simulation.
71
+ *
72
+ * It will run the backward simulation from the initial time of this period
73
+ * to the final timepoint of this period, handling events
74
+ * and data points as they occur.
75
+ *
76
+ * Expects the model and the solver to be set up, and `ws` to be initialized
77
+ * for this period.
78
+ *
79
+ * @param t_start The initial time of this period.
80
+ * @param t_end The final time of this period.
81
+ * @param it The index of the timepoint in `timepoints` to start with.
82
+ * @param timepoints The output timepoints or measurement timepoints of
83
+ * this period. This must contain at least the final timepoint of this
84
+ * period.
85
+ * @param dJydx State-derivative of data likelihood. Must be non-null if
86
+ * there are any data points in this period.
87
+ * @param dJzdx State-derivative of event likelihood. Must be non-null if
88
+ * the model has any event-observables.
74
89
*/
75
- AmiVector const & getAdjointQuadrature () const { return xQB_; }
90
+ void
91
+ run (realtype t_start, realtype t_end, realtype it,
92
+ std::vector<realtype> const & timepoints,
93
+ std::vector<realtype> const * dJydx, std::vector<realtype> const * dJzdx);
76
94
77
95
private:
78
- void handlePostequilibration ();
79
-
80
96
/* *
81
97
* @brief Execute everything necessary for the handling of events
82
98
* for the backward problem
83
99
* @param disc The discontinuity to handle
100
+ * @param dJzdx State-derivative of event likelihood
84
101
*/
85
- void handleEventB (Discontinuity const & disc);
102
+ void
103
+ handleEventB (Discontinuity const & disc, std::vector<realtype> const * dJzdx);
86
104
87
105
/* *
88
106
* @brief Execute everything necessary for the handling of data
89
107
* points for the backward problems
90
108
*
91
109
* @param it index of data point
110
+ * @param dJydx State-derivative of data likelihood
92
111
*/
93
- void handleDataPointB (int it);
112
+ void handleDataPointB (int it, std::vector<realtype> const * dJydx );
94
113
95
114
/* *
96
115
* @brief Compute the next timepoint to integrate to.
@@ -103,26 +122,69 @@ class BackwardProblem {
103
122
*/
104
123
realtype getTnext (int it);
105
124
125
+ /* * The model to simulate. */
126
+ Model* model_;
127
+
128
+ /* * The solver to use for the simulation. */
129
+ Solver* solver_;
130
+
131
+ /* * The workspace to use for the simulation. */
132
+ gsl::not_null<BwdSimWorkspace*> ws_;
133
+
134
+ /* * current time */
135
+ realtype t_{0 };
136
+ };
137
+
138
+ // ! class to solve backwards problems.
139
+ /* !
140
+ solves the backwards problem for adjoint sensitivity analysis and handles
141
+ events and data-points
142
+ */
143
+
144
+ class BackwardProblem {
145
+ public:
146
+ /* *
147
+ * @brief Construct backward problem from forward problem
148
+ * @param fwd pointer to corresponding forward problem
149
+ */
150
+ explicit BackwardProblem (ForwardProblem& fwd);
151
+
152
+ /* *
153
+ * @brief Solve the backward problem.
154
+ *
155
+ * If adjoint sensitivities are enabled, this will also compute
156
+ * sensitivities. workForwardProblem must be called before this function is
157
+ * called.
158
+ */
159
+ void workBackwardProblem ();
160
+
161
+ /* *
162
+ * @brief Accessor for xB
163
+ * @return xB
164
+ */
165
+ [[nodiscard]] AmiVector const & getAdjointState () const { return ws_.xB_ ; }
166
+
167
+ /* *
168
+ * @brief Accessor for xQB
169
+ * @return xQB
170
+ */
171
+ [[nodiscard]] AmiVector const & getAdjointQuadrature () const {
172
+ return ws_.xQB_ ;
173
+ }
174
+
175
+ private:
176
+ void handlePostequilibration ();
177
+
106
178
Model* model_;
107
179
Solver* solver_;
108
180
ExpData const * edata_;
109
181
110
182
/* * current time */
111
183
realtype t_;
112
- /* * adjoint state vector */
113
- AmiVector xB_;
114
- /* * differential adjoint state vector */
115
- AmiVector dxB_;
116
- /* * quadrature state vector */
117
- AmiVector xQB_;
118
184
/* * sensitivity state vector array */
119
185
AmiVectorArray sx0_;
120
- /* * array of number of found roots for a certain event type */
121
- std::vector<int > nroots_;
122
186
/* * array containing the time-points of discontinuities*/
123
187
std::vector<Discontinuity> discs_;
124
- /* * index of the backward problem */
125
- int which = 0 ;
126
188
127
189
/* * state derivative of data likelihood */
128
190
std::vector<realtype> dJydx_;
@@ -134,8 +196,12 @@ class BackwardProblem {
134
196
135
197
/* * The postequilibration steadystate problem from the forward problem. */
136
198
SteadystateProblem* posteq_problem_;
199
+
200
+ BwdSimWorkspace ws_;
201
+
202
+ EventHandlingBwdSimulator simulator_;
137
203
};
138
204
139
205
} // namespace amici
140
206
141
- #endif // BACKWARDPROBLEM_H
207
+ #endif // AMICI_BACKWARDPROBLEM_H
0 commit comments