Skip to content

Commit b8e3c80

Browse files
authored
Extract WRMS computation and buffers (#2761)
Also fix an issue caused by incorrectly using the state-mask for computing the WRMS for `xQB`, which should only be used for `x`.
1 parent 0e33004 commit b8e3c80

File tree

2 files changed

+88
-48
lines changed

2 files changed

+88
-48
lines changed

include/amici/steadystateproblem.h

Lines changed: 54 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,54 @@ class Solver;
1515
class Model;
1616
class BackwardProblem;
1717

18+
/**
19+
* @brief Computes the weighted root mean square norm.
20+
*
21+
* This class is used to compute the weighted root mean square of the residuals
22+
* and maintains its work space to avoid reallocation.
23+
*/
24+
class WRMSComputer {
25+
public:
26+
/**
27+
* @brief Constructor.
28+
* @param n The length of the vectors for which to compute the WRMS.
29+
* @param sunctx A SUNDIALS context for the NVector.
30+
* @param atol Absolute tolerance to compute error weights.
31+
* @param rtol Relative tolerance to compute error weights.
32+
* @param mask Mask for entries to include in the WRMS norm.
33+
* Positive value: include; non-positive value: exclude; empty: include all.
34+
*/
35+
WRMSComputer(
36+
int n, SUNContext sunctx, realtype atol, realtype rtol, AmiVector mask
37+
)
38+
: ewt_(n, sunctx)
39+
, rtol_(rtol)
40+
, atol_(atol)
41+
, mask_(mask) {}
42+
43+
/**
44+
* @brief Compute the weighted root mean square of the residuals.
45+
* @param x Vector to compute the WRMS for.
46+
* @param x_ref The reference vector from which to compute the error
47+
* weights.
48+
* @return The WRMS norm.
49+
*/
50+
realtype wrms(AmiVector const& x, AmiVector const& x_ref);
51+
52+
private:
53+
/** Error weights for the residuals. */
54+
AmiVector ewt_;
55+
/** Relative tolerance to compute error weights. */
56+
realtype rtol_;
57+
/** Absolute tolerance to compute error weights. */
58+
realtype atol_;
59+
/**
60+
* Mask for entries to include in the WRMS norm.
61+
* Positive value: include; non-positive value: exclude; empty: include all.
62+
*/
63+
AmiVector mask_;
64+
};
65+
1866
/**
1967
* @brief The SteadystateProblem class solves a steady-state problem using
2068
* Newton's method and falls back to integration on failure.
@@ -374,10 +422,12 @@ class SteadystateProblem {
374422
AmiVector delta_;
375423
/** previous newton step (size: nx_solver). */
376424
AmiVector delta_old_;
377-
/** error weights for solver state, dimension nx_solver */
378-
AmiVector ewt_;
379-
/** error weights for backward quadratures, dimension nplist() */
380-
AmiVector ewtQB_;
425+
/** WRMS computer for x */
426+
WRMSComputer wrms_computer_x_;
427+
/** WRMS computer for xQB */
428+
WRMSComputer wrms_computer_xQB_;
429+
/** WRMS computer for sx */
430+
WRMSComputer wrms_computer_sx_;
381431
/** old state vector */
382432
AmiVector x_old_;
383433
/** time derivative state vector */
@@ -430,19 +480,6 @@ class SteadystateProblem {
430480
*/
431481
std::vector<SteadyStateStatus> steady_state_status_;
432482

433-
/** absolute tolerance for convergence check (state)*/
434-
realtype atol_{NAN};
435-
/** relative tolerance for convergence check (state)*/
436-
realtype rtol_{NAN};
437-
/** absolute tolerance for convergence check (state sensi)*/
438-
realtype atol_sensi_{NAN};
439-
/** relative tolerance for convergence check (state sensi)*/
440-
realtype rtol_sensi_{NAN};
441-
/** absolute tolerance for convergence check (quadratures)*/
442-
realtype atol_quad_{NAN};
443-
/** relative tolerance for convergence check (quadratures)*/
444-
realtype rtol_quad_{NAN};
445-
446483
/** Newton solver */
447484
NewtonSolver newton_solver_;
448485

src/steadystateproblem.cpp

Lines changed: 34 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,6 @@ realtype getWrmsNorm(
6161
AmiVector const& x, AmiVector const& xdot, AmiVector const& mask,
6262
realtype atol, realtype rtol, AmiVector& ewt
6363
) {
64-
// Depending on what convergence we want to check (xdot, sxdot, xQBdot)
65-
// we need to pass ewt[QB], as xdot and xQBdot have different sizes.
66-
6764
// ewt = x
6865
N_VAbs(const_cast<N_Vector>(x.getNVector()), ewt.getNVector());
6966
// ewt *= rtol
@@ -85,6 +82,10 @@ realtype getWrmsNorm(
8582
);
8683
}
8784

85+
realtype WRMSComputer::wrms(AmiVector const& x, AmiVector const& x_ref) {
86+
return getWrmsNorm(x_ref, x, mask_, atol_, rtol_, ewt_);
87+
}
88+
8889
/**
8990
* @brief Compute the backward quadratures, which contribute to the
9091
* gradient (xQB) from the quadrature over the backward state itself (xQ)
@@ -118,8 +119,23 @@ void computeQBfromQ(
118119
SteadystateProblem::SteadystateProblem(Solver const& solver, Model const& model)
119120
: delta_(model.nx_solver, solver.getSunContext())
120121
, delta_old_(model.nx_solver, solver.getSunContext())
121-
, ewt_(model.nx_solver, solver.getSunContext())
122-
, ewtQB_(model.nplist(), solver.getSunContext())
122+
, wrms_computer_x_(
123+
model.nx_solver, solver.getSunContext(),
124+
solver.getAbsoluteToleranceSteadyState(),
125+
solver.getRelativeToleranceSteadyState(),
126+
AmiVector(model.get_steadystate_mask(), solver.getSunContext())
127+
)
128+
, wrms_computer_xQB_(
129+
model.nplist(), solver.getSunContext(),
130+
solver.getAbsoluteToleranceQuadratures(),
131+
solver.getRelativeToleranceQuadratures(), AmiVector()
132+
)
133+
, wrms_computer_sx_(
134+
model.nx_solver, solver.getSunContext(),
135+
solver.getAbsoluteToleranceSteadyStateSensi(),
136+
solver.getRelativeToleranceSteadyStateSensi(),
137+
AmiVector(model.get_steadystate_mask(), solver.getSunContext())
138+
)
123139
, x_old_(model.nx_solver, solver.getSunContext())
124140
, xdot_(model.nx_solver, solver.getSunContext())
125141
, sdx_(model.nx_solver, model.nplist(), solver.getSunContext())
@@ -141,12 +157,6 @@ SteadystateProblem::SteadystateProblem(Solver const& solver, Model const& model)
141157
),
142158
.state = model.getModelState()}
143159
)
144-
, atol_(solver.getAbsoluteToleranceSteadyState())
145-
, rtol_(solver.getRelativeToleranceSteadyState())
146-
, atol_sensi_(solver.getAbsoluteToleranceSteadyStateSensi())
147-
, rtol_sensi_(solver.getRelativeToleranceSteadyStateSensi())
148-
, atol_quad_(solver.getAbsoluteToleranceQuadratures())
149-
, rtol_quad_(solver.getRelativeToleranceQuadratures())
150160
, newton_solver_(
151161
NewtonSolver(model, solver.getLinearSolver(), solver.getSunContext())
152162
)
@@ -607,7 +617,6 @@ bool SteadystateProblem::requires_state_sensitivities(
607617

608618
realtype
609619
SteadystateProblem::getWrms(Model& model, SensitivityMethod sensi_method) {
610-
realtype wrms = INFINITY;
611620
if (sensi_method == SensitivityMethod::adjoint) {
612621
if (newton_step_conv_) {
613622
throw NewtonFailure(
@@ -622,22 +631,18 @@ SteadystateProblem::getWrms(Model& model, SensitivityMethod sensi_method) {
622631
// to zero at all. So we need xQBdot, hence compute xQB first.
623632
computeQBfromQ(model, xQ_, xQB_, state_);
624633
computeQBfromQ(model, xB_, xQBdot_, state_);
625-
wrms = getWrmsNorm(
626-
xQB_, xQBdot_, steadystate_mask_, atol_quad_, rtol_quad_, ewtQB_
627-
);
628-
} else {
629-
// If we're doing a forward simulation (with or without sensitivities),
630-
// get RHS and compute weighted error norm.
631-
if (newton_step_conv_)
632-
getNewtonStep(model);
633-
else
634-
updateRightHandSide(model);
635-
wrms = getWrmsNorm(
636-
state_.x, newton_step_conv_ ? delta_ : xdot_, steadystate_mask_,
637-
atol_, rtol_, ewt_
638-
);
634+
return wrms_computer_xQB_.wrms(xQBdot_, xQB_);
639635
}
640-
return wrms;
636+
637+
if (newton_step_conv_) {
638+
getNewtonStep(model);
639+
return wrms_computer_x_.wrms(delta_, state_.x);
640+
}
641+
642+
// If we're doing a forward simulation (with or without sensitivities),
643+
// get RHS and compute weighted error norm.
644+
updateRightHandSide(model);
645+
return wrms_computer_x_.wrms(xdot_, state_.x);
641646
}
642647

643648
realtype SteadystateProblem::getWrmsFSA(Model& model) {
@@ -655,10 +660,7 @@ realtype SteadystateProblem::getWrmsFSA(Model& model) {
655660
);
656661
if (newton_step_conv_)
657662
newton_solver_.solveLinearSystem(xdot_);
658-
wrms = getWrmsNorm(
659-
state_.sx[ip], xdot_, steadystate_mask_, atol_sensi_, rtol_sensi_,
660-
ewt_
661-
);
663+
wrms = wrms_computer_sx_.wrms(xdot_, state_.sx[ip]);
662664
// ideally this function would report the maximum of all wrms over
663665
// all ip, but for practical purposes we can just report the wrms for
664666
// the first ip where we know that the convergence threshold is not
@@ -939,4 +941,5 @@ void SteadystateProblem::getNewtonStep(Model& model) {
939941
newton_solver_.getStep(delta_, model, state_);
940942
delta_updated_ = true;
941943
}
944+
942945
} // namespace amici

0 commit comments

Comments
 (0)