@@ -126,9 +126,6 @@ def __call__(self, t):
126
126
return ys , yps
127
127
128
128
129
- # TODO: Compare this with
130
- # - ddassl.f by Petzold
131
- # - epsode.f by Bryne and Hindmarsh
132
129
def select_initial_step (t0 , y0 , yp0 , t_bound , rtol , atol , max_step ):
133
130
"""Empirically select a good initial step.
134
131
@@ -160,18 +157,28 @@ def select_initial_step(t0, y0, yp0, t_bound, rtol, atol, max_step):
160
157
161
158
References
162
159
----------
163
- .. [1] TODO: Find a reference .
160
+ .. [1] L. F. Shampine, "Starting an ODE solver", November 1977 .
164
161
"""
165
- min_step = 0.0
162
+ safety = 0.8
163
+ min_step = 16 * EPS * abs (t0 )
166
164
threshold = atol / rtol
167
165
hspan = abs (t_bound - t0 )
168
166
169
- # compute an initial step size h using yp = y'(t0)
167
+ # compute scaling
170
168
wt = np .maximum (np .abs (y0 ), threshold )
171
- rh = 1.25 * np .linalg .norm (yp0 / wt , np .inf ) / np .sqrt (rtol )
169
+
170
+ # error
171
+ e = np .linalg .norm (yp0 / wt , np .inf )
172
+
173
+ # reciprocal step size
174
+ rh = e / np .sqrt (rtol ) / safety
175
+
176
+ # compute an initial step size
172
177
h_abs = min (max_step , hspan )
173
178
if h_abs * rh > 1 :
174
179
h_abs = 1 / rh
180
+
181
+ # ensure h_abs >= min_step
175
182
h_abs = max (h_abs , min_step )
176
183
return h_abs
177
184
@@ -184,7 +191,7 @@ def consistent_initial_conditions(fun, t0, y0, yp0, jac=None, fixed_y0=None,
184
191
185
192
References
186
193
----------
187
- .. [1] L. F. Shampine, "Solving 0 = F(t, y(t), y′ (t)) in Matlab", Journal
194
+ .. [1] L. F. Shampine, "Solving 0 = F(t, y(t), y' (t)) in Matlab", Journal
188
195
of Numerical Mathematics, vol. 10, no. 4, 2002, pp. 291-310.
189
196
"""
190
197
n = len (y0 )
@@ -220,8 +227,8 @@ def fun_composite(t, z):
220
227
if not (isinstance (rtol , float ) and rtol > 0 ):
221
228
raise ValueError ("Relative tolerance must be a positive scalar." )
222
229
223
- if rtol < 100 * np . finfo ( float ). eps :
224
- rtol = 100 * np . finfo ( float ). eps
230
+ if rtol < 100 * EPS :
231
+ rtol = 100 * EPS
225
232
print (f"Relative tolerance increased to { rtol } " )
226
233
227
234
if np .any (np .array (atol ) <= 0 ):
@@ -235,29 +242,13 @@ def fun_composite(t, z):
235
242
Jy , Jyp = jac (t0 , y0 , yp0 )
236
243
237
244
scale_f = atol + np .abs (f ) * rtol
238
- # z0 = np.concatenate([y0, yp0])
239
- # scale_z = atol + np.abs(z0) * rtol
240
- # dz_norm_old = None
241
- # rate_z = None
242
- # tol = max(10 * EPS / rtol, min(0.03, rtol ** 0.5))
243
245
244
246
for _ in range (newton_maxiter ):
245
247
for _ in range (chord_iter ):
246
248
dy , dyp = solve_underdetermined_system (f , Jy , Jyp , free_y , free_yp )
247
249
y0 += dy
248
250
yp0 += dyp
249
251
250
- # dz = np.concatenate([dy, dyp])
251
- # with np.errstate(divide='ignore'):
252
- # dz_norm = norm(dz / scale_z)
253
- # if dz_norm_old is not None:
254
- # rate_z = dz_norm / dz_norm_old
255
-
256
- # if (dz_norm == 0 or (rate_z is not None and rate_z / (1 - rate_z) * dz_norm < safety * tol)):
257
- # return y0, yp0, f
258
-
259
- # dz_norm_old = dz_norm
260
-
261
252
f = fun (t0 , y0 , yp0 , * args )
262
253
error = norm (f / scale_f )
263
254
if error < safety :
@@ -271,7 +262,7 @@ def fun_composite(t, z):
271
262
def qrank (A ):
272
263
"""Compute QR-decomposition with column pivoting of A and estimate the rank."""
273
264
Q , R , p = qr (A , pivoting = True )
274
- tol = max (A .shape ) * np . finfo ( float ). eps * abs (R [0 , 0 ])
265
+ tol = max (A .shape ) * EPS * abs (R [0 , 0 ])
275
266
rank = np .sum (abs (np .diag (R )) > tol )
276
267
return rank , Q , R , p
277
268
@@ -353,10 +344,17 @@ def solve_underdetermined_system(f, Jy, Jyp, free_y, free_yp):
353
344
# [S21, S22] [w1] = d2
354
345
# [w2]
355
346
# using column pivoting QR-decomposition
356
- w_ = np .zeros (RS .shape [1 ])
357
- w_ [:rankS ] = solve_triangular (RS [:rankS , :rankS ], (QS .T @ d2 [:rankS ]))
358
- w = np .zeros_like (w_ )
359
- w [pS ] = w_
347
+ # [RS11, RS12] [v1] = [c1]
348
+ # [ 0, 0] [v2] [c2]
349
+ # with v2 = 0 this gives
350
+ # RS11 @ v1 = c1
351
+ c = QS .T @ d2
352
+ v = np .zeros (RS .shape [1 ])
353
+ v [:rankS ] = solve_triangular (RS [:rankS , :rankS ], c [:rankS ])
354
+
355
+ # apply permutation
356
+ w = np .zeros_like (v )
357
+ w [pS ] = v
360
358
361
359
# set w2' = 0 and solve the remaining system
362
360
# [R11] w1' = d1 - [S11, S12] [w1]
0 commit comments