Skip to content

Commit 42a7473

Browse files
authored
Enable presimulation with adjoints (#2786)
Enable presimulation with adjoint sensitivities. * Use previously introduced `EventHandlingBwdSimulator` for processing events also during presimulation * Make `EventHandlingBwdSimulator::run` work if there are no datapoints * Adapt setup/reinitialization logic
1 parent 0723885 commit 42a7473

File tree

6 files changed

+169
-96
lines changed

6 files changed

+169
-96
lines changed

include/amici/backwardproblem.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,9 @@ class BackwardProblem {
197197
/** The postequilibration steadystate problem from the forward problem. */
198198
SteadystateProblem* posteq_problem_;
199199

200+
/** Presimulation results */
201+
PeriodResult presim_result;
202+
200203
BwdSimWorkspace ws_;
201204

202205
EventHandlingBwdSimulator simulator_;

include/amici/forwardproblem.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -502,6 +502,14 @@ class ForwardProblem {
502502
return nullptr;
503503
}
504504

505+
/**
506+
* @brief Get the presimulation results.
507+
* @return Presimulation results.
508+
*/
509+
PeriodResult const& get_presimulation_result() const {
510+
return pre_simulator_.result;
511+
}
512+
505513
/** pointer to model instance */
506514
Model* model;
507515

python/tests/test_preequilibration.py

Lines changed: 0 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -299,44 +299,6 @@ def test_parameter_in_expdata(preeq_fixture):
299299
)
300300

301301

302-
def test_raise_presimulation_with_adjoints(preeq_fixture):
303-
"""Test simulation failures with adjoin+presimulation"""
304-
305-
(
306-
model,
307-
solver,
308-
edata,
309-
edata_preeq,
310-
edata_presim,
311-
edata_sim,
312-
pscales,
313-
plists,
314-
) = preeq_fixture
315-
316-
# preequilibration and presimulation with adjoints:
317-
# this needs to fail unless we remove presimulation
318-
solver.setSensitivityMethod(amici.SensitivityMethod.adjoint)
319-
320-
rdata = amici.runAmiciSimulation(model, solver, edata)
321-
assert rdata["status"] == amici.AMICI_ERROR
322-
323-
# add postequilibration
324-
y = edata.getObservedData()
325-
stdy = edata.getObservedDataStdDev()
326-
ts = np.hstack([*edata.getTimepoints(), np.inf])
327-
edata.setTimepoints(ts)
328-
edata.setObservedData(np.hstack([y, y[0]]))
329-
edata.setObservedDataStdDev(np.hstack([stdy, stdy[0]]))
330-
331-
# remove presimulation
332-
edata.t_presim = 0
333-
edata.fixedParametersPresimulation = ()
334-
335-
# no presim any more, this should work
336-
rdata = amici.runAmiciSimulation(model, solver, edata)
337-
assert rdata["status"] == amici.AMICI_SUCCESS
338-
339-
340302
def test_equilibration_methods_with_adjoints(preeq_fixture):
341303
"""Test different combinations of equilibration and simulation
342304
sensitivity methods"""

python/tests/test_sbml_import.py

Lines changed: 86 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,9 @@ def test_presimulation(sbml_example_presimulation_module):
251251
check_derivatives(model, solver, edata, epsilon=1e-4)
252252

253253

254+
@pytest.mark.filterwarnings(
255+
"ignore:Adjoint sensitivity analysis for models with discontinuous "
256+
)
254257
def test_presimulation_events(tempdir):
255258
"""Test that events are handled during presimulation."""
256259

@@ -306,7 +309,8 @@ def test_presimulation_events(tempdir):
306309

307310
for sensi_method in (
308311
amici.SensitivityMethod.forward,
309-
# TODO ASA
312+
# FIXME: test with adjoints. currently there is some CVodeF issue
313+
# that fails forward simulation with adjoint sensitivities
310314
# amici.SensitivityMethod.adjoint,
311315
):
312316
solver.setSensitivityMethod(sensi_method)
@@ -336,6 +340,87 @@ def test_presimulation_events(tempdir):
336340
)
337341

338342

343+
@pytest.mark.filterwarnings(
344+
"ignore:Adjoint sensitivity analysis for models with discontinuous "
345+
)
346+
def test_presimulation_events_and_sensitivities(tempdir):
347+
"""Test that presimulation with adjoint sensitivities works
348+
and test that events are handled during presimulation."""
349+
350+
from amici.antimony_import import antimony2amici
351+
352+
model_name = "test_presim_events2"
353+
antimony2amici(
354+
"""
355+
some_time = time
356+
some_time' = 1
357+
bolus = 1
358+
359+
k_pre = 3
360+
k_main = 2
361+
xx = 0
362+
xx' = piecewise(k_pre, time < 0, k_main)
363+
364+
# this will trigger twice, once in presimulation
365+
# and once in the main simulation
366+
at time >= -1 , t0=false: some_time = some_time + bolus
367+
""",
368+
model_name=model_name,
369+
output_dir=tempdir,
370+
)
371+
372+
model_module = import_model_module(model_name, tempdir)
373+
374+
model = model_module.get_model()
375+
model.setTimepoints([0, 1, 2])
376+
edata = amici.ExpData(model)
377+
edata.t_presim = 2
378+
solver = model.getSolver()
379+
380+
# generate artificial data
381+
rdata = amici.runAmiciSimulation(model, solver, edata)
382+
edata_tmp = amici.ExpData(rdata, 1, 0)
383+
edata.setTimepoints(np.array(edata_tmp.getTimepoints()) + 0.1)
384+
edata.setObservedData(edata_tmp.getObservedData())
385+
edata.setObservedDataStdDev(edata_tmp.getObservedDataStdDev())
386+
387+
solver.setSensitivityOrder(amici.SensitivityOrder.first)
388+
389+
for sensi_method in (
390+
amici.SensitivityMethod.forward,
391+
amici.SensitivityMethod.adjoint,
392+
):
393+
solver.setSensitivityMethod(sensi_method)
394+
rdata = amici.runAmiciSimulation(model, solver, edata)
395+
396+
assert rdata.status == amici.AMICI_SUCCESS
397+
assert_allclose(
398+
rdata.by_id("some_time"), np.array([0, 1, 2]) + 2.1, atol=1e-14
399+
)
400+
401+
if sensi_method == amici.SensitivityMethod.forward:
402+
model.requireSensitivitiesForAllParameters()
403+
else:
404+
# FIXME ASA with events:
405+
# https://github.com/AMICI-dev/AMICI/pull/1539
406+
model.setParameterList(
407+
[
408+
i
409+
for i, p in enumerate(model.getParameterIds())
410+
if p != "bolus"
411+
]
412+
)
413+
414+
check_derivatives(
415+
model,
416+
solver,
417+
edata=edata,
418+
atol=1e-6,
419+
rtol=1e-6,
420+
epsilon=1e-8,
421+
)
422+
423+
339424
def test_steadystate_simulation(model_steadystate_module):
340425
model = model_steadystate_module.getModel()
341426
model.setTimepoints(np.linspace(0, 60, 60))

src/backwardproblem.cpp

Lines changed: 41 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ BackwardProblem::BackwardProblem(ForwardProblem& fwd)
2020
, dJzdx_(fwd.getDJzdx())
2121
, preeq_problem_(fwd.getPreequilibrationProblem())
2222
, posteq_problem_(fwd.getPostequilibrationProblem())
23+
, presim_result(fwd.get_presimulation_result())
2324
, ws_(model_, solver_)
2425
, simulator_(model_, solver_, &ws_) {}
2526

@@ -34,6 +35,8 @@ void BackwardProblem::workBackwardProblem() {
3435

3536
handlePostequilibration();
3637

38+
// handle main simulation
39+
3740
// If we have posteq, infinity timepoints were already treated
3841
int it = model_->nt() - 1;
3942
while (it >= 0 && std::isinf(model_->getTimepoint(it))) {
@@ -53,8 +56,13 @@ void BackwardProblem::workBackwardProblem() {
5356
ConditionContext cc(
5457
model_, edata_, FixedParameterContext::presimulation
5558
);
56-
solver_->runB(model_->t0() - edata_->t_presim);
57-
solver_->writeSolutionB(&t_, ws_.xB_, ws_.dxB_, ws_.xQB_, ws_.which);
59+
ws_.discs_ = presim_result.discs;
60+
ws_.nroots_
61+
= compute_nroots(ws_.discs_, model_->ne, model_->nMaxEvent());
62+
simulator_.run(
63+
model_->t0(), model_->t0() - edata_->t_presim, -1, {}, &dJydx_,
64+
&dJzdx_
65+
);
5866
}
5967

6068
// handle pre-equilibration
@@ -178,7 +186,8 @@ void EventHandlingBwdSimulator::run(
178186

179187
t_ = t_start;
180188

181-
if ((it >= 0 || !ws_->discs_.empty()) && timepoints[it] > t_end) {
189+
// datapoint at t_start?
190+
if (it >= 0 && timepoints[it] == t_start) {
182191
handleDataPointB(it, dJydx);
183192
solver_->setupB(
184193
&ws_->which, timepoints[it], model_, ws_->xB_, ws_->dxB_, ws_->xQB_
@@ -187,34 +196,40 @@ void EventHandlingBwdSimulator::run(
187196
// as it is not called in handleDataPointB
188197
solver_->storeDiagnosisB(ws_->which);
189198
--it;
199+
} else {
200+
// no data points, only discontinuities, just set up the solver
201+
// (e.g., during presimulation)
202+
solver_->setupB(
203+
&ws_->which, t_start, model_, ws_->xB_, ws_->dxB_, ws_->xQB_
204+
);
205+
}
190206

191-
while (it >= 0 || !ws_->discs_.empty()) {
192-
// check if next timepoint is a discontinuity or a data-point
193-
double tnext = getTnext(it);
194-
195-
if (tnext < t_) {
196-
solver_->runB(tnext);
197-
solver_->writeSolutionB(
198-
&t_, ws_->xB_, ws_->dxB_, ws_->xQB_, ws_->which
199-
);
200-
}
207+
while (it >= 0 || !ws_->discs_.empty()) {
208+
// check if next timepoint is a discontinuity or a data-point
209+
double tnext = getTnext(it);
201210

202-
// handle discontinuity
203-
if (!ws_->discs_.empty() && tnext == ws_->discs_.back().time) {
204-
handleEventB(ws_->discs_.back(), dJzdx);
205-
ws_->discs_.pop_back();
206-
}
211+
if (tnext < t_) {
212+
solver_->runB(tnext);
213+
solver_->writeSolutionB(
214+
&t_, ws_->xB_, ws_->dxB_, ws_->xQB_, ws_->which
215+
);
216+
}
207217

208-
// handle data-point
209-
if (it >= 0 && tnext == timepoints[it]) {
210-
handleDataPointB(it, dJydx);
211-
it--;
212-
}
218+
// handle discontinuity
219+
if (!ws_->discs_.empty() && tnext == ws_->discs_.back().time) {
220+
handleEventB(ws_->discs_.back(), dJzdx);
221+
ws_->discs_.pop_back();
222+
}
213223

214-
// reinitialize state
215-
solver_->reInitB(ws_->which, t_, ws_->xB_, ws_->dxB_);
216-
solver_->quadReInitB(ws_->which, ws_->xQB_);
224+
// handle data-point
225+
if (it >= 0 && tnext == timepoints[it]) {
226+
handleDataPointB(it, dJydx);
227+
it--;
217228
}
229+
230+
// reinitialize state
231+
solver_->reInitB(ws_->which, t_, ws_->xB_, ws_->dxB_);
232+
solver_->quadReInitB(ws_->which, ws_->xQB_);
218233
}
219234

220235
// we still need to integrate from first datapoint to t_start

src/forwardproblem.cpp

Lines changed: 31 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -181,39 +181,35 @@ void ForwardProblem::handlePresimulation() {
181181
if (!uses_presimulation_)
182182
return;
183183

184-
if (solver->computingASA()) {
185-
throw AmiException(
186-
"Presimulation with adjoint sensitivities"
187-
" is currently not implemented."
188-
);
189-
}
184+
// Are there dedicated condition preequilibration parameters provided?
185+
ConditionContext cond(model, edata, FixedParameterContext::presimulation);
190186

191-
{
192-
// Are there dedicated condition preequilibration parameters provided?
193-
ConditionContext cond(
194-
model, edata, FixedParameterContext::presimulation
187+
// If we need to reinitialize solver states, this won't work yet.
188+
if (model->nx_reinit() > 0)
189+
throw AmiException(
190+
"Adjoint presimulation with reinitialization of "
191+
"non-constant states is not yet implemented. Stopping."
195192
);
196193

197-
// compute initial time and setup solver for (pre-)simulation
198-
t_ = model->t0() - edata->t_presim;
199-
200-
// if preequilibration was done, model was already initialized
201-
if (!preequilibrated_) {
202-
model->initialize(
203-
t_, ws_.x, ws_.dx, ws_.sx, ws_.sdx,
204-
solver->getSensitivityOrder() >= SensitivityOrder::first,
205-
ws_.roots_found
206-
);
207-
} else if (model->ne) {
208-
model->initEvents(t_, ws_.x, ws_.dx, ws_.roots_found);
209-
}
210-
solver->setup(t_, model, ws_.x, ws_.dx, ws_.sx, ws_.sdx);
211-
solver->updateAndReinitStatesAndSensitivities(model);
194+
// compute initial time and setup solver for (pre-)simulation
195+
t_ = model->t0() - edata->t_presim;
212196

213-
std::vector<realtype> const timepoints{model->t0()};
214-
pre_simulator_.run(t_, edata, timepoints);
215-
solver->writeSolution(&t_, ws_.x, ws_.dx, ws_.sx, ws_.dx);
197+
// if preequilibration was done, model was already initialized
198+
if (!preequilibrated_) {
199+
model->initialize(
200+
t_, ws_.x, ws_.dx, ws_.sx, ws_.sdx,
201+
solver->getSensitivityOrder() >= SensitivityOrder::first,
202+
ws_.roots_found
203+
);
204+
} else if (model->ne) {
205+
model->initEvents(t_, ws_.x, ws_.dx, ws_.roots_found);
216206
}
207+
solver->setup(t_, model, ws_.x, ws_.dx, ws_.sx, ws_.sdx);
208+
solver->updateAndReinitStatesAndSensitivities(model);
209+
210+
std::vector<realtype> const timepoints{model->t0()};
211+
pre_simulator_.run(t_, edata, timepoints);
212+
solver->writeSolution(&t_, ws_.x, ws_.dx, ws_.sx, ws_.dx);
217213
}
218214

219215
void ForwardProblem::handleMainSimulation() {
@@ -236,7 +232,10 @@ void ForwardProblem::handleMainSimulation() {
236232

237233
t_ = model->t0();
238234

239-
solver->setup(t_, model, ws_.x, ws_.dx, ws_.sx, ws_.sdx);
235+
// in case of presimulation, the solver was set up already
236+
if (!uses_presimulation_) {
237+
solver->setup(t_, model, ws_.x, ws_.dx, ws_.sx, ws_.sdx);
238+
}
240239

241240
if (preequilibrated_ || uses_presimulation_) {
242241
// Reset the time and re-initialize events for the main simulation
@@ -278,10 +277,11 @@ void EventHandlingSimulator::handle_event(
278277

279278
if (!initial_event && t_ == ws_->tlastroot) {
280279
throw AmiException(
281-
"AMICI is stuck in an event, as the initial "
280+
"AMICI is stuck in an event at time %g, as the initial "
282281
"step-size after the event is too small. "
283282
"To fix this, increase absolute and relative "
284-
"tolerances!"
283+
"tolerances!",
284+
t_
285285
);
286286
}
287287
ws_->tlastroot = t_;

0 commit comments

Comments
 (0)