From b8231f6626f0e1c9c01dfce95e242ea71937bfdb Mon Sep 17 00:00:00 2001 From: Anton Khirnov Date: Mon, 2 Dec 2019 10:23:35 +0100 Subject: nonlin_ode: solve for delta, not the new iterate This is simpler. --- nonlin_ode.py | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/nonlin_ode.py b/nonlin_ode.py index 4cb5c2e..7584882 100644 --- a/nonlin_ode.py +++ b/nonlin_ode.py @@ -3,7 +3,7 @@ import numpy as np import sys -import series_expansion +import series_expansion as se def _nonlin_solve_1d_iter(prev, grid, basis_vals, Fs, Fs_args): order = grid.shape[0] @@ -28,11 +28,9 @@ def _nonlin_solve_1d_iter(prev, grid, basis_vals, Fs, Fs_args): for diff_order in xrange(N - 1): mat[:, idx] -= basis_vals[diff_order][:, idx] * F_vals[diff_order + 1] - rhs = F_vals[0] - for diff_order in xrange(N - 1): - rhs -= F_vals[diff_order + 1] * prev_vals[diff_order] + rhs = F_vals[0] - prev.eval(grid, N - 1) - return series_expansion.SeriesExpansion(np.linalg.solve(mat, rhs), prev.basis) + return np.linalg.solve(mat, rhs) def nonlin_solve_1d(initial_guess, Fs, args = None, maxiter = 100, atol = 1e-14, grid = None, verbose = True): """ @@ -81,18 +79,19 @@ def nonlin_solve_1d(initial_guess, Fs, args = None, maxiter = 100, atol = 1e-14, basis_val[:, idx] = basis.eval(idx, grid, diff_order) basis_vals.append(basis_val) - solution_old = initial_guess + solution = initial_guess for i in xrange(maxiter): - solution_new = _nonlin_solve_1d_iter(solution_old, grid, basis_vals, Fs, args) + delta = _nonlin_solve_1d_iter(solution, grid, basis_vals, Fs, args) - delta = np.max(np.abs(solution_new.coeffs - solution_old.coeffs)) - if np.isnan(delta): + err = np.max(np.abs(delta)) + if np.isnan(err): raise RuntimeError('nan') + if err < atol: + return solution + solution = se.SeriesExpansion(solution.coeffs + delta, solution.basis) if verbose: - sys.stderr.write('delta: %g, coeffs[0]: %g\n' % (delta, solution_new.coeffs[0])) - if delta < atol: - return solution_new - solution_old = solution_new + sys.stderr.write('delta: %g, coeffs[0]: %g\n' % (err, solution.coeffs[0])) + raise RuntimeError('The horizon finder failed to converge') def nonlin_residual(solution, N, grid, F, args): -- cgit v1.2.3