From 975abbee31abe7f8299c3cc4583f76214f699671 Mon Sep 17 00:00:00 2001 From: Anton Khirnov Date: Mon, 9 Apr 2018 10:13:44 +0200 Subject: nlsolve: faster abort on divergence --- init.c | 10 +++++----- nlsolve.c | 6 +++++- nlsolve.h | 2 +- 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/init.c b/init.c index b1495a6..e90bbb8 100644 --- a/init.c +++ b/init.c @@ -284,7 +284,7 @@ int td_solve(TDContext *td, double *coeffs_init[3]) if (td->solution_branch == 0 || coeffs_init) { // direct solve with default (flat space) or user-provided initial guess ret = tdi_nlsolve_solve(s->solver, teukolsky_constraint_eval, NULL, - &td->amplitude, s->coeffs); + &td->amplitude, s->coeffs, 0); if (ret < 0) { tdi_log(&s->logger, 0, "tdi_nlsolve_solve() failed: %d", ret); return ret; @@ -301,11 +301,11 @@ int td_solve(TDContext *td, double *coeffs_init[3]) int dir; ret = tdi_nlsolve_solve(s->solver, teukolsky_constraint_eval, NULL, - &a0, s->coeffs); + &a0, s->coeffs, 0); if (ret < 0) return ret; ret = tdi_nlsolve_solve(s->solver, teukolsky_constraint_eval, NULL, - &a1, s->coeffs_tmp); + &a1, s->coeffs_tmp, 0); if (ret < 0) return ret; @@ -314,7 +314,7 @@ int td_solve(TDContext *td, double *coeffs_init[3]) // obtain solution for a1 in the upper branch ret = tdi_nlsolve_solve(s->solver, teukolsky_constraint_eval, NULL, - &a1, s->coeffs); + &a1, s->coeffs, 0); if (ret < 0) { tdi_log(&s->logger, 0, "Failed to get into the upper branch\n"); return ret; @@ -335,7 +335,7 @@ int td_solve(TDContext *td, double *coeffs_init[3]) tdi_log(&s->logger, 2, "Trying amplitude %g\n", new_amplitude); memcpy(s->coeffs_tmp, s->coeffs, sizeof(*s->coeffs) * N); ret = tdi_nlsolve_solve(s->solver, teukolsky_constraint_eval, NULL, - &new_amplitude, s->coeffs_tmp); + &new_amplitude, s->coeffs_tmp, 1); if (ret == -EDOM) { inverse_step = 0.5 * inverse_step; if (fabs(inverse_step) < 1e-2) diff --git a/nlsolve.c b/nlsolve.c index 960a8e9..708cd08 100644 --- a/nlsolve.c +++ b/nlsolve.c @@ -100,7 +100,7 @@ struct NLSolvePriv { }; int tdi_nlsolve_solve(NLSolveContext *ctx, NLEqCallback eq_eval, - NLEqJacobianCallback eq_jac_eval, void *opaque, double *coeffs) + NLEqJacobianCallback eq_jac_eval, void *opaque, double *coeffs, int fast_abort) { NLSolvePriv *s = ctx->priv; int64_t start, totaltime_start; @@ -209,6 +209,10 @@ int tdi_nlsolve_solve(NLSolveContext *ctx, NLEqCallback eq_eval, it, s->delta[max_idx], ctx->atol); ret = 0; goto finish; + } else if ((fast_abort && fabs(s->delta[max_idx]) > 1e6) || + s->delta[max_idx] > 1e18) { + tdi_log(&ctx->logger, 2, "max(delta) %g, aborting\n", s->delta[max_idx]); + return -EDOM; } cblas_daxpy(s->solve_order, 1.0, s->delta, 1, coeffs, 1); diff --git a/nlsolve.h b/nlsolve.h index 1d3fae5..f9b69fe 100644 --- a/nlsolve.h +++ b/nlsolve.h @@ -101,7 +101,7 @@ void tdi_nlsolve_context_free(NLSolveContext **ctx); int tdi_nlsolve_solve(NLSolveContext *ctx, NLEqCallback eq_eval, NLEqJacobianCallback eq_jac_eval, - void *opaque, double *coeffs); + void *opaque, double *coeffs, int fast_abort); void tdi_nlsolve_print_stats(NLSolveContext *ctx); -- cgit v1.2.3