@@ -61,9 +61,6 @@ realtype getWrmsNorm(
61
61
AmiVector const & x, AmiVector const & xdot, AmiVector const & mask,
62
62
realtype atol, realtype rtol, AmiVector& ewt
63
63
) {
64
- // Depending on what convergence we want to check (xdot, sxdot, xQBdot)
65
- // we need to pass ewt[QB], as xdot and xQBdot have different sizes.
66
-
67
64
// ewt = x
68
65
N_VAbs (const_cast <N_Vector>(x.getNVector ()), ewt.getNVector ());
69
66
// ewt *= rtol
@@ -85,6 +82,10 @@ realtype getWrmsNorm(
85
82
);
86
83
}
87
84
85
+ realtype WRMSComputer::wrms (AmiVector const & x, AmiVector const & x_ref) {
86
+ return getWrmsNorm (x_ref, x, mask_, atol_, rtol_, ewt_);
87
+ }
88
+
88
89
/* *
89
90
* @brief Compute the backward quadratures, which contribute to the
90
91
* gradient (xQB) from the quadrature over the backward state itself (xQ)
@@ -118,8 +119,23 @@ void computeQBfromQ(
118
119
SteadystateProblem::SteadystateProblem (Solver const & solver, Model const & model)
119
120
: delta_(model.nx_solver, solver.getSunContext())
120
121
, delta_old_(model.nx_solver, solver.getSunContext())
121
- , ewt_(model.nx_solver, solver.getSunContext())
122
- , ewtQB_(model.nplist(), solver.getSunContext())
122
+ , wrms_computer_x_(
123
+ model.nx_solver, solver.getSunContext(),
124
+ solver.getAbsoluteToleranceSteadyState(),
125
+ solver.getRelativeToleranceSteadyState(),
126
+ AmiVector (model.get_steadystate_mask(), solver.getSunContext())
127
+ )
128
+ , wrms_computer_xQB_(
129
+ model.nplist(), solver.getSunContext(),
130
+ solver.getAbsoluteToleranceQuadratures(),
131
+ solver.getRelativeToleranceQuadratures(), AmiVector()
132
+ )
133
+ , wrms_computer_sx_(
134
+ model.nx_solver, solver.getSunContext(),
135
+ solver.getAbsoluteToleranceSteadyStateSensi(),
136
+ solver.getRelativeToleranceSteadyStateSensi(),
137
+ AmiVector(model.get_steadystate_mask(), solver.getSunContext())
138
+ )
123
139
, x_old_(model.nx_solver, solver.getSunContext())
124
140
, xdot_(model.nx_solver, solver.getSunContext())
125
141
, sdx_(model.nx_solver, model.nplist(), solver.getSunContext())
@@ -141,12 +157,6 @@ SteadystateProblem::SteadystateProblem(Solver const& solver, Model const& model)
141
157
),
142
158
.state = model.getModelState ()}
143
159
)
144
- , atol_(solver.getAbsoluteToleranceSteadyState())
145
- , rtol_(solver.getRelativeToleranceSteadyState())
146
- , atol_sensi_(solver.getAbsoluteToleranceSteadyStateSensi())
147
- , rtol_sensi_(solver.getRelativeToleranceSteadyStateSensi())
148
- , atol_quad_(solver.getAbsoluteToleranceQuadratures())
149
- , rtol_quad_(solver.getRelativeToleranceQuadratures())
150
160
, newton_solver_(
151
161
NewtonSolver (model, solver.getLinearSolver(), solver.getSunContext())
152
162
)
@@ -607,7 +617,6 @@ bool SteadystateProblem::requires_state_sensitivities(
607
617
608
618
realtype
609
619
SteadystateProblem::getWrms (Model& model, SensitivityMethod sensi_method) {
610
- realtype wrms = INFINITY;
611
620
if (sensi_method == SensitivityMethod::adjoint) {
612
621
if (newton_step_conv_) {
613
622
throw NewtonFailure (
@@ -622,22 +631,18 @@ SteadystateProblem::getWrms(Model& model, SensitivityMethod sensi_method) {
622
631
// to zero at all. So we need xQBdot, hence compute xQB first.
623
632
computeQBfromQ (model, xQ_, xQB_, state_);
624
633
computeQBfromQ (model, xB_, xQBdot_, state_);
625
- wrms = getWrmsNorm (
626
- xQB_, xQBdot_, steadystate_mask_, atol_quad_, rtol_quad_, ewtQB_
627
- );
628
- } else {
629
- // If we're doing a forward simulation (with or without sensitivities),
630
- // get RHS and compute weighted error norm.
631
- if (newton_step_conv_)
632
- getNewtonStep (model);
633
- else
634
- updateRightHandSide (model);
635
- wrms = getWrmsNorm (
636
- state_.x , newton_step_conv_ ? delta_ : xdot_, steadystate_mask_,
637
- atol_, rtol_, ewt_
638
- );
634
+ return wrms_computer_xQB_.wrms (xQBdot_, xQB_);
639
635
}
640
- return wrms;
636
+
637
+ if (newton_step_conv_) {
638
+ getNewtonStep (model);
639
+ return wrms_computer_x_.wrms (delta_, state_.x );
640
+ }
641
+
642
+ // If we're doing a forward simulation (with or without sensitivities),
643
+ // get RHS and compute weighted error norm.
644
+ updateRightHandSide (model);
645
+ return wrms_computer_x_.wrms (xdot_, state_.x );
641
646
}
642
647
643
648
realtype SteadystateProblem::getWrmsFSA (Model& model) {
@@ -655,10 +660,7 @@ realtype SteadystateProblem::getWrmsFSA(Model& model) {
655
660
);
656
661
if (newton_step_conv_)
657
662
newton_solver_.solveLinearSystem (xdot_);
658
- wrms = getWrmsNorm (
659
- state_.sx [ip], xdot_, steadystate_mask_, atol_sensi_, rtol_sensi_,
660
- ewt_
661
- );
663
+ wrms = wrms_computer_sx_.wrms (xdot_, state_.sx [ip]);
662
664
// ideally this function would report the maximum of all wrms over
663
665
// all ip, but for practical purposes we can just report the wrms for
664
666
// the first ip where we know that the convergence threshold is not
@@ -939,4 +941,5 @@ void SteadystateProblem::getNewtonStep(Model& model) {
939
941
newton_solver_.getStep (delta_, model, state_);
940
942
delta_updated_ = true ;
941
943
}
944
+
942
945
} // namespace amici
0 commit comments