From 208a3e985b9c82d23fa7c50059fd545b6a1a67f5 Mon Sep 17 00:00:00 2001 From: Anton Khirnov Date: Mon, 18 Oct 2021 12:22:47 +0200 Subject: nonlin_ode: raise more specific exceptions --- nonlin_ode.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/nonlin_ode.py b/nonlin_ode.py index d96c158..ab977a2 100644 --- a/nonlin_ode.py +++ b/nonlin_ode.py @@ -5,6 +5,14 @@ import sys from . import series_expansion as se +class ConvergenceError(Exception): + pass + +class MaxIterReached(ConvergenceError): + pass +class DivergenceException(ConvergenceError): + pass + def _nonlin_solve_1d_iter(prev, grid, basis_vals, Fs, Fs_args): order = grid.shape[0] N = len(Fs) @@ -61,7 +69,7 @@ def nonlin_solve_1d(initial_guess, Fs, args = None, maxiter = 100, atol = 1e-14, Returns: The solution to the equation (a SeriesExpansion object). Raises: - RuntimeError: If the iteration fails to converge. + MaxIterReached: If the iteration fails to converge. """ N = len(Fs) @@ -84,15 +92,15 @@ def nonlin_solve_1d(initial_guess, Fs, args = None, maxiter = 100, atol = 1e-14, delta = _nonlin_solve_1d_iter(solution, grid, basis_vals, Fs, args) err = np.max(np.abs(delta)) - if np.isnan(err): - raise RuntimeError('nan') + if not np.isfinite(err): + raise DivergenceException('nan') if err < atol: return solution solution = se.SeriesExpansion(solution.coeffs + delta, solution.basis) if verbose: sys.stderr.write('delta: %g, coeffs[0]: %g\n' % (err, solution.coeffs[0])) - raise RuntimeError('The horizon finder failed to converge') + raise MaxIterReached('The horizon finder failed to converge') def nonlin_residual(solution, N, grid, F, args): """ -- cgit v1.2.3