Skip to content

Commit fa9cf30

Browse files
Improve consistent initial conditions (#36)
1 parent 5d9149a commit fa9cf30

File tree

1 file changed

+29
-31
lines changed

1 file changed

+29
-31
lines changed

scipy_dae/integrate/_dae/common.py

Lines changed: 29 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -126,9 +126,6 @@ def __call__(self, t):
126126
return ys, yps
127127

128128

129-
# TODO: Compare this with
130-
# - ddassl.f by Petzold
131-
# - epsode.f by Bryne and Hindmarsh
132129
def select_initial_step(t0, y0, yp0, t_bound, rtol, atol, max_step):
133130
"""Empirically select a good initial step.
134131
@@ -160,18 +157,28 @@ def select_initial_step(t0, y0, yp0, t_bound, rtol, atol, max_step):
160157
161158
References
162159
----------
163-
.. [1] TODO: Find a reference.
160+
.. [1] L. F. Shampine, "Starting an ODE solver", November 1977.
164161
"""
165-
min_step = 0.0
162+
safety = 0.8
163+
min_step = 16 * EPS * abs(t0)
166164
threshold = atol / rtol
167165
hspan = abs(t_bound - t0)
168166

169-
# compute an initial step size h using yp = y'(t0)
167+
# compute scaling
170168
wt = np.maximum(np.abs(y0), threshold)
171-
rh = 1.25 * np.linalg.norm(yp0 / wt, np.inf) / np.sqrt(rtol)
169+
170+
# error
171+
e = np.linalg.norm(yp0 / wt, np.inf)
172+
173+
# reciprocal step size
174+
rh = e / np.sqrt(rtol) / safety
175+
176+
# compute an initial step size
172177
h_abs = min(max_step, hspan)
173178
if h_abs * rh > 1:
174179
h_abs = 1 / rh
180+
181+
# ensure h_abs >= min_step
175182
h_abs = max(h_abs, min_step)
176183
return h_abs
177184

@@ -184,7 +191,7 @@ def consistent_initial_conditions(fun, t0, y0, yp0, jac=None, fixed_y0=None,
184191
185192
References
186193
----------
187-
.. [1] L. F. Shampine, "Solving 0 = F(t, y(t), y(t)) in Matlab", Journal
194+
.. [1] L. F. Shampine, "Solving 0 = F(t, y(t), y'(t)) in Matlab", Journal
188195
of Numerical Mathematics, vol. 10, no. 4, 2002, pp. 291-310.
189196
"""
190197
n = len(y0)
@@ -220,8 +227,8 @@ def fun_composite(t, z):
220227
if not (isinstance(rtol, float) and rtol > 0):
221228
raise ValueError("Relative tolerance must be a positive scalar.")
222229

223-
if rtol < 100 * np.finfo(float).eps:
224-
rtol = 100 * np.finfo(float).eps
230+
if rtol < 100 * EPS:
231+
rtol = 100 * EPS
225232
print(f"Relative tolerance increased to {rtol}")
226233

227234
if np.any(np.array(atol) <= 0):
@@ -235,29 +242,13 @@ def fun_composite(t, z):
235242
Jy, Jyp = jac(t0, y0, yp0)
236243

237244
scale_f = atol + np.abs(f) * rtol
238-
# z0 = np.concatenate([y0, yp0])
239-
# scale_z = atol + np.abs(z0) * rtol
240-
# dz_norm_old = None
241-
# rate_z = None
242-
# tol = max(10 * EPS / rtol, min(0.03, rtol ** 0.5))
243245

244246
for _ in range(newton_maxiter):
245247
for _ in range(chord_iter):
246248
dy, dyp = solve_underdetermined_system(f, Jy, Jyp, free_y, free_yp)
247249
y0 += dy
248250
yp0 += dyp
249251

250-
# dz = np.concatenate([dy, dyp])
251-
# with np.errstate(divide='ignore'):
252-
# dz_norm = norm(dz / scale_z)
253-
# if dz_norm_old is not None:
254-
# rate_z = dz_norm / dz_norm_old
255-
256-
# if (dz_norm == 0 or (rate_z is not None and rate_z / (1 - rate_z) * dz_norm < safety * tol)):
257-
# return y0, yp0, f
258-
259-
# dz_norm_old = dz_norm
260-
261252
f = fun(t0, y0, yp0, *args)
262253
error = norm(f / scale_f)
263254
if error < safety:
@@ -271,7 +262,7 @@ def fun_composite(t, z):
271262
def qrank(A):
272263
"""Compute QR-decomposition with column pivoting of A and estimate the rank."""
273264
Q, R, p = qr(A, pivoting=True)
274-
tol = max(A.shape) * np.finfo(float).eps * abs(R[0, 0])
265+
tol = max(A.shape) * EPS * abs(R[0, 0])
275266
rank = np.sum(abs(np.diag(R)) > tol)
276267
return rank, Q, R, p
277268

@@ -353,10 +344,17 @@ def solve_underdetermined_system(f, Jy, Jyp, free_y, free_yp):
353344
# [S21, S22] [w1] = d2
354345
# [w2]
355346
# using column pivoting QR-decomposition
356-
w_ = np.zeros(RS.shape[1])
357-
w_[:rankS] = solve_triangular(RS[:rankS, :rankS], (QS.T @ d2[:rankS]))
358-
w = np.zeros_like(w_)
359-
w[pS] = w_
347+
# [RS11, RS12] [v1] = [c1]
348+
# [ 0, 0] [v2] [c2]
349+
# with v2 = 0 this gives
350+
# RS11 @ v1 = c1
351+
c = QS.T @ d2
352+
v = np.zeros(RS.shape[1])
353+
v[:rankS] = solve_triangular(RS[:rankS, :rankS], c[:rankS])
354+
355+
# apply permutation
356+
w = np.zeros_like(v)
357+
w[pS] = v
360358

361359
# set w2' = 0 and solve the remaining system
362360
# [R11] w1' = d1 - [S11, S12] [w1]

0 commit comments

Comments
 (0)