Skip to content

Commit c6944b6

Browse files
committed
SteadystateProblem: extract NewtonsMethod (cont.)
Move the remainders of Newton's method from `SteadystateProblem` to `NewtonsMethod`. Also, a slight performance improvement through avoiding checking for negative state in case there are no non-negativity constraints. Previously, the state vector checked even if there were no such constraints.
1 parent 6069b9d commit c6944b6

File tree

4 files changed

+310
-182
lines changed

4 files changed

+310
-182
lines changed

include/amici/model.h

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1459,19 +1459,19 @@ class Model : public AbstractModel, public ModelDimensions {
14591459
* constants / fixed parameters
14601460
* @return Those indices.
14611461
*/
1462-
std::vector<int> const& getReinitializationStateIdxs() const;
1462+
[[nodiscard]] std::vector<int> const& getReinitializationStateIdxs() const;
14631463

14641464
/**
14651465
* @brief getter for dxdotdp (matlab generated)
14661466
* @return dxdotdp
14671467
*/
1468-
AmiVectorArray const& get_dxdotdp() const;
1468+
[[nodiscard]] AmiVectorArray const& get_dxdotdp() const;
14691469

14701470
/**
14711471
* @brief getter for dxdotdp (python generated)
14721472
* @return dxdotdp
14731473
*/
1474-
SUNMatrixWrapper const& get_dxdotdp_full() const;
1474+
[[nodiscard]] SUNMatrixWrapper const& get_dxdotdp_full() const;
14751475

14761476
/**
14771477
* @brief Get trigger times for events that don't require root-finding.
@@ -1480,7 +1480,7 @@ class Model : public AbstractModel, public ModelDimensions {
14801480
* root-finding (i.e. that trigger at predetermined timepoints),
14811481
* in ascending order.
14821482
*/
1483-
virtual std::vector<double> get_trigger_timepoints() const;
1483+
[[nodiscard]] virtual std::vector<double> get_trigger_timepoints() const;
14841484

14851485
/**
14861486
* @brief Get steady-state mask as std::vector.
@@ -1489,7 +1489,7 @@ class Model : public AbstractModel, public ModelDimensions {
14891489
*
14901490
* @return Steady-state mask
14911491
*/
1492-
std::vector<realtype> get_steadystate_mask() const {
1492+
[[nodiscard]] std::vector<realtype> get_steadystate_mask() const {
14931493
return steadystate_mask_;
14941494
};
14951495

@@ -1511,7 +1511,18 @@ class Model : public AbstractModel, public ModelDimensions {
15111511
* @param ie event index
15121512
* @return The corresponding Event object.
15131513
*/
1514-
Event const& get_event(int ie) const { return events_.at(ie); }
1514+
[[nodiscard]] Event const& get_event(int ie) const {
1515+
return events_.at(ie);
1516+
}
1517+
1518+
/**
1519+
* @brief Whether there is at least one state variable for which
1520+
* non-negativity is to be enforced.
1521+
* @return Vector of all events.
1522+
*/
1523+
[[nodiscard]] bool get_any_state_nonnegative() const {
1524+
return any_state_non_negative_;
1525+
}
15151526

15161527
/**
15171528
* Flag indicating whether for

include/amici/steadystateproblem.h

Lines changed: 121 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@ class Model;
1616
class BackwardProblem;
1717

1818
/**
19-
* @brief Computes the weighted root mean square norm.
19+
* @brief Computes the weighted root-mean-square norm.
2020
*
21-
* This class is used to compute the weighted root mean square of the residuals
21+
* This class is used to compute the weighted root-mean-square of the residuals
2222
* and maintains its work space to avoid reallocation.
2323
*/
2424
class WRMSComputer {
@@ -41,7 +41,7 @@ class WRMSComputer {
4141
, mask_(mask) {}
4242

4343
/**
44-
* @brief Compute the weighted root mean square of the residuals.
44+
* @brief Compute the weighted root-mean-square of the residuals.
4545
* @param x Vector to compute the WRMS for.
4646
* @param x_ref The reference vector from which to compute the error
4747
* weights.
@@ -66,30 +66,67 @@ class WRMSComputer {
6666
/**
6767
* @brief Implements Newton's method for finding steady states.
6868
*
69-
* TODO: To be extended after further disentangling SteadyStateProblem.
69+
* See also:
70+
* Lines et al. (2019), IFAC-PapersOnLine 52 (26): 32–37.
71+
* https://doi.org/10.1016/j.ifacol.2019.12.232
7072
*/
7173
class NewtonsMethod {
7274
public:
7375
/**
7476
* @brief Constructor.
75-
* @param nx Number of solver states (nx_solver).
77+
* @param model Number of solver states (nx_solver).
78+
* @param solver NewtonSolver instance to compute the Newton step.
79+
* Expected to be correctly initialized.
7680
* @param sunctx A SUNDIALS context for the NVector.
7781
* @param max_steps
7882
* @param damping_factor_mode
7983
* @param damping_factor_lower_bound
84+
* @param check_delta
8085
*/
8186
NewtonsMethod(
82-
int nx, SUNContext sunctx,
87+
gsl::not_null<Model*> model, SUNContext sunctx,
88+
gsl::not_null<NewtonSolver*> solver,
8389
NewtonDampingFactorMode damping_factor_mode,
84-
realtype damping_factor_lower_bound,
85-
int max_steps
86-
)
87-
: max_steps_(max_steps)
88-
, delta_(nx, sunctx)
89-
, delta_old_(nx, sunctx)
90-
, damping_factor_mode_(damping_factor_mode)
91-
, damping_factor_lower_bound_(damping_factor_lower_bound) {}
90+
realtype damping_factor_lower_bound, int max_steps, bool check_delta
91+
);
92+
93+
/**
94+
* @brief Run the Newton solver iterations and checks for convergence
95+
* to steady state.
96+
* @param xdot Time derivative of the state vector `state.x`.
97+
* @param state SimulationState instance containing the current state.
98+
* @param wrms_computer WRMSComputer instance to compute the WRMS norm.
99+
*/
100+
void
101+
run(AmiVector& xdot, SimulationState& state, WRMSComputer& wrms_computer);
102+
103+
/**
104+
* @brief Compute the Newton step for the current state_.x and xdot and
105+
* store it in delta_.
106+
* @param xdot Time derivative of the state vector `state.x`.
107+
* @param state SimulationState instance containing the current state.
108+
*/
109+
void compute_step(AmiVector const& xdot, SimulationState const& state);
110+
111+
/**
112+
* @brief Get the last Newton step.
113+
* @return Newton step
114+
*/
115+
[[nodiscard]] AmiVector const& get_delta() const { return delta_; }
116+
117+
/**
118+
* @brief Get the number of steps taken in the current iteration.
119+
* @return Number of steps taken.
120+
*/
121+
[[nodiscard]] int get_num_steps() const { return i_step; }
122+
123+
/**
124+
* @brief Get the current WRMS norm.
125+
* @return The current WRMS norm.
126+
*/
127+
[[nodiscard]] realtype get_wrms() const { return wrms_; }
92128

129+
private:
93130
/**
94131
* @brief Update the damping factor gamma that determines step size.
95132
*
@@ -101,43 +138,76 @@ class NewtonsMethod {
101138
* dampening (false)
102139
*/
103140

104-
bool updateDampingFactor(bool step_successful, double& gamma) {
105-
if (damping_factor_mode_ != NewtonDampingFactorMode::on)
106-
return true;
107-
108-
if (step_successful) {
109-
gamma = fmin(1.0, 2.0 * gamma);
110-
} else {
111-
gamma /= 4.0;
112-
}
113-
114-
if (gamma < damping_factor_lower_bound_) {
115-
throw NewtonFailure(
116-
AMICI_DAMPING_FACTOR_ERROR,
117-
"Newton solver failed: the damping factor "
118-
"reached its lower bound"
119-
);
120-
}
121-
return step_successful;
122-
}
141+
bool update_damping_factor(bool step_successful, double& gamma);
142+
143+
/**
144+
* @brief Compute the weighted root-mean-square of the residuals.
145+
* @param xdot
146+
* @param state
147+
* @param wrms_computer
148+
* @return WRMS norm.
149+
*/
150+
realtype compute_wrms(
151+
AmiVector const& xdot, SimulationState const& state,
152+
WRMSComputer& wrms_computer
153+
);
154+
155+
/**
156+
* @brief Check for convergence.
157+
*
158+
* Check if NewtonsMethod::wrms_ is below the convergence threshold,
159+
* make the state non-negative if requested, and recompute and check
160+
* the WRMS norm again.
161+
*
162+
* @param xdot
163+
* @param state
164+
* @param wrms_computer
165+
* @return Whether convergence has been reached.
166+
*/
167+
bool has_converged(
168+
AmiVector& xdot, SimulationState& state, WRMSComputer& wrms_computer
169+
);
170+
171+
static constexpr realtype conv_thresh = 1.0;
123172

124-
// TODO: make private after further disentangling SteadyStateProblem
173+
/** Pointer to the model instance. */
174+
gsl::not_null<Model*> model_;
125175

126176
/** Maximum number of iterations. */
127177
int max_steps_{0};
128178

179+
/** damping factor flag */
180+
NewtonDampingFactorMode damping_factor_mode_{NewtonDampingFactorMode::on};
181+
182+
/** damping factor lower bound */
183+
realtype damping_factor_lower_bound_{1e-8};
184+
185+
/**
186+
* Whether to check the Newton step (delta) or the right-hand side (xdot)
187+
* during the convergence check.
188+
*/
189+
bool check_delta_;
190+
191+
/** Pointer to the Newton solver instance to compute the Newton step. */
192+
gsl::not_null<NewtonSolver*> solver_;
193+
129194
/** Newton step (size: nx_solver). */
130195
AmiVector delta_;
131196

132-
/** previous newton step (size: nx_solver). */
197+
/** Previous Newton step (size: nx_solver). */
133198
AmiVector delta_old_;
134199

135-
private:
136-
/** damping factor flag */
137-
NewtonDampingFactorMode damping_factor_mode_{NewtonDampingFactorMode::on};
200+
/** Newton step (size: nx_solver). */
201+
AmiVector x_old_;
138202

139-
/** damping factor lower bound */
140-
realtype damping_factor_lower_bound_{1e-8};
203+
/**
204+
* WRMS norm based on the current state and delta or xdot
205+
* (depending on `check_delta_`).
206+
*/
207+
realtype wrms_ = INFINITY;
208+
209+
/** The current number of Newton iterations. */
210+
int i_step = 0;
141211
};
142212

143213
/**
@@ -152,7 +222,7 @@ class SteadystateProblem {
152222
* @param solver Solver instance
153223
* @param model Model instance
154224
*/
155-
explicit SteadystateProblem(Solver const& solver, Model const& model);
225+
explicit SteadystateProblem(Solver const& solver, Model& model);
156226

157227
/**
158228
* @brief Compute the steady state in the forward case.
@@ -219,7 +289,7 @@ class SteadystateProblem {
219289
}
220290

221291
/**
222-
* @brief Get the CPU time taken to solvethe forward problem.
292+
* @brief Get the CPU time taken to solve the forward problem.
223293
* @return The CPU time in milliseconds.
224294
*/
225295
[[nodiscard]] double getCPUTime() const { return cpu_time_; }
@@ -248,7 +318,7 @@ class SteadystateProblem {
248318

249319
/**
250320
* @brief Get the weighted root mean square of the residuals.
251-
* @return The weighted root mean square of the residuals.
321+
* @return The weighted root-mean-square of the residuals.
252322
*/
253323
[[nodiscard]] realtype getResidualNorm() const { return wrms_; }
254324

@@ -305,7 +375,7 @@ class SteadystateProblem {
305375
/**
306376
* @brief Handle the computation of the steady state.
307377
*
308-
* Throws an AmiException, if no steady state was found.
378+
* Throws an AmiException if no steady state was found.
309379
*
310380
* @param solver Solver instance.
311381
* @param model Model instance.
@@ -398,15 +468,7 @@ class SteadystateProblem {
398468
realtype getWrmsFSA(Model& model);
399469

400470
/**
401-
* @brief Run the Newton solver iterations and checks for convergence
402-
* to steady state.
403-
* @param model Model instance.
404-
* @param newton_retry flag indicating if Newton solver is rerun
405-
*/
406-
void applyNewtonsMethod(Model& model, bool newton_retry);
407-
408-
/**
409-
* @brief Launch forward simulation if Newton solver or linear system solve
471+
* @brief Launch simulation if Newton solver or linear system solve
410472
* fail or are disabled.
411473
* @param solver Solver instance.
412474
* @param model Model instance.
@@ -428,7 +490,7 @@ class SteadystateProblem {
428490
* @param solver Solver instance
429491
* @param model Model instance.
430492
* @param forwardSensis flag switching on integration with FSA
431-
* @param backward flag switching on quadratures computation
493+
* @param backward flag switching on quadrature computation
432494
* @return A unique pointer to the created Solver instance.
433495
*/
434496
std::unique_ptr<Solver> createSteadystateSimSolver(
@@ -447,21 +509,13 @@ class SteadystateProblem {
447509
* @brief Initialize backward computation.
448510
* @param solver Solver instance
449511
* @param model Model instance.
450-
* @param bwd pointer to backward problem
512+
* @param bwd pointer to the backward problem
451513
* @return flag indicating whether backward computation to be carried out
452514
*/
453515
bool initializeBackwardProblem(
454516
Solver const& solver, Model& model, BackwardProblem const* bwd
455517
);
456518

457-
/**
458-
* @brief Ensure state positivity if requested, and repeat the convergence
459-
* check if necessary.
460-
* @param model Model instance.
461-
*/
462-
bool makePositiveAndCheckConvergence(Model& model);
463-
464-
465519
/**
466520
* @brief Update member variables to indicate that state_.x has been
467521
* updated and xdot_, delta_, etc. need to be recomputed.
@@ -482,13 +536,6 @@ class SteadystateProblem {
482536
*/
483537
void updateRightHandSide(Model& model);
484538

485-
/**
486-
* @brief Compute the Newton step for the current state_.x and set the
487-
* corresponding flag to indicate delta_ is up to date.
488-
* @param model Model instance
489-
*/
490-
void getNewtonStep(Model& model);
491-
492539
/** WRMS computer for x */
493540
WRMSComputer wrms_computer_x_;
494541
/** WRMS computer for xQB */
@@ -548,7 +595,10 @@ class SteadystateProblem {
548595
/** Newton's method for finding steady states */
549596
NewtonsMethod newtons_method_;
550597

551-
/** whether newton step should be used for convergence steps */
598+
/**
599+
* Whether the Newton step should be used instead of xdot for convergence
600+
* checks during simulation and Newton's method.
601+
*/
552602
bool newton_step_conv_{false};
553603
/**
554604
* whether sensitivities should be checked for convergence to steady state
@@ -557,10 +607,6 @@ class SteadystateProblem {
557607

558608
/** flag indicating whether xdot_ has been computed for the current state */
559609
bool xdot_updated_{false};
560-
/**
561-
* flag indicating whether delta_ has been computed for the current state
562-
*/
563-
bool delta_updated_{false};
564610
/**
565611
* flag indicating whether simulation sensitivities have been retrieved for
566612
* the current state

python/tests/test_preequilibration.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -508,7 +508,7 @@ def test_newton_steadystate_check(preeq_fixture):
508508
assert rdatas[newton_check]["status"] == amici.AMICI_SUCCESS
509509

510510
# assert correct results
511-
for variable in ["llh", "sllh", "sx0", "sx_ss", "x_ss"]:
511+
for variable in ["x_ss", "llh", "sx0", "sx_ss", "sllh"]:
512512
assert_allclose(
513513
rdatas[True][variable],
514514
rdatas[False][variable],

0 commit comments

Comments
 (0)