summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAnton Khirnov <anton@khirnov.net>2019-12-02 10:23:35 +0100
committerAnton Khirnov <anton@khirnov.net>2019-12-02 10:23:35 +0100
commitb8231f6626f0e1c9c01dfce95e242ea71937bfdb (patch)
treeeb827205ca8e479317af6474abf0416059a4d063
parent52c429db82eb32cea648295a741e49644f9b1b00 (diff)
nonlin_ode: solve for delta, not the new iterate
This is simpler.
-rw-r--r--nonlin_ode.py25
1 files changed, 12 insertions, 13 deletions
diff --git a/nonlin_ode.py b/nonlin_ode.py
index 4cb5c2e..7584882 100644
--- a/nonlin_ode.py
+++ b/nonlin_ode.py
@@ -3,7 +3,7 @@
import numpy as np
import sys
-import series_expansion
+import series_expansion as se
def _nonlin_solve_1d_iter(prev, grid, basis_vals, Fs, Fs_args):
order = grid.shape[0]
@@ -28,11 +28,9 @@ def _nonlin_solve_1d_iter(prev, grid, basis_vals, Fs, Fs_args):
for diff_order in xrange(N - 1):
mat[:, idx] -= basis_vals[diff_order][:, idx] * F_vals[diff_order + 1]
- rhs = F_vals[0]
- for diff_order in xrange(N - 1):
- rhs -= F_vals[diff_order + 1] * prev_vals[diff_order]
+ rhs = F_vals[0] - prev.eval(grid, N - 1)
- return series_expansion.SeriesExpansion(np.linalg.solve(mat, rhs), prev.basis)
+ return np.linalg.solve(mat, rhs)
def nonlin_solve_1d(initial_guess, Fs, args = None, maxiter = 100, atol = 1e-14, grid = None, verbose = True):
"""
@@ -81,18 +79,19 @@ def nonlin_solve_1d(initial_guess, Fs, args = None, maxiter = 100, atol = 1e-14,
basis_val[:, idx] = basis.eval(idx, grid, diff_order)
basis_vals.append(basis_val)
- solution_old = initial_guess
+ solution = initial_guess
for i in xrange(maxiter):
- solution_new = _nonlin_solve_1d_iter(solution_old, grid, basis_vals, Fs, args)
+ delta = _nonlin_solve_1d_iter(solution, grid, basis_vals, Fs, args)
- delta = np.max(np.abs(solution_new.coeffs - solution_old.coeffs))
- if np.isnan(delta):
+ err = np.max(np.abs(delta))
+ if np.isnan(err):
raise RuntimeError('nan')
+ if err < atol:
+ return solution
+ solution = se.SeriesExpansion(solution.coeffs + delta, solution.basis)
if verbose:
- sys.stderr.write('delta: %g, coeffs[0]: %g\n' % (delta, solution_new.coeffs[0]))
- if delta < atol:
- return solution_new
- solution_old = solution_new
+ sys.stderr.write('delta: %g, coeffs[0]: %g\n' % (err, solution.coeffs[0]))
+
raise RuntimeError('The horizon finder failed to converge')
def nonlin_residual(solution, N, grid, F, args):