From b7de09482589a321d67a854f3a933e91c9306e4a Mon Sep 17 00:00:00 2001 From: Anton Khirnov Date: Wed, 2 May 2018 09:55:06 +0200 Subject: td solve: free the solver immediately after solving There is no need to keep it around for evaluating the data. --- init.c | 147 ++++++++++++++++++++++++++++++++++++++--------------------------- 1 file changed, 85 insertions(+), 62 deletions(-) diff --git a/init.c b/init.c index c512cca..33e5587 100644 --- a/init.c +++ b/init.c @@ -46,8 +46,6 @@ typedef struct TDPriv { unsigned int basis_order[NB_EQUATIONS][2]; BasisSetContext *basis[NB_EQUATIONS][2]; - NLSolveContext *solver; - ThreadPoolContext *tp; TDLogger logger; @@ -220,37 +218,12 @@ static int teukolsky_init_check_options(TDContext *td) return ret; } - - ret = tdi_nlsolve_context_alloc(&s->solver, ARRAY_ELEMS(basis_sets)); - if (ret < 0) { - tdi_log(&s->logger, 0, "Error allocating the non-linear solver\n"); - return ret; - } - - s->solver->logger = s->logger; - s->solver->tp = s->tp; - s->solver->maxiter = td->max_iter; - s->solver->atol = td->atol; - - memcpy(s->solver->basis, s->basis, sizeof(s->basis)); - memcpy(s->solver->solve_order, s->basis_order, sizeof(s->basis_order)); - -#if HAVE_OPENCL - s->solver->ocl_ctx = s->ocl_ctx; - s->solver->ocl_queue = s->ocl_queue; -#endif - - ret = tdi_nlsolve_context_init(s->solver); - if (ret < 0) { - tdi_log(&s->logger, 0, "Error initializing the non-linear solver\n"); - return ret; - } - return 0; } -static int constraint_eval_alloc(const TDContext *td, double amplitude, - TDConstraintEvalContext **pce) +static int constraint_eval_alloc(const TDContext *td, const unsigned int *nb_coords, + const double * const *coords, + double amplitude, TDConstraintEvalContext **pce) { TDPriv *priv = td->priv; TDConstraintEvalContext *ce; @@ -264,10 +237,10 @@ static int constraint_eval_alloc(const TDContext *td, double amplitude, ce->logger = priv->logger; ce->amplitude = amplitude; - ce->nb_coords[0] = td->nb_coeffs[0]; - ce->nb_coords[1] = td->nb_coeffs[1]; - ce->coords[0] = priv->solver->colloc_grid[0][0]; - ce->coords[1] = priv->solver->colloc_grid[0][1]; + ce->nb_coords[0] = nb_coords[0]; + ce->nb_coords[1] = nb_coords[1]; + ce->coords[0] = coords[0]; + ce->coords[1] = coords[1]; ret = tdi_constraint_eval_init(ce); if (ret < 0) { @@ -280,6 +253,44 @@ static int constraint_eval_alloc(const TDContext *td, double amplitude, return 0; } +static int nlsolve_alloc(const TDContext *td, NLSolveContext **pnl) +{ + TDPriv *s = td->priv; + NLSolveContext *nl; + int ret; + + ret = tdi_nlsolve_context_alloc(&nl, ARRAY_ELEMS(basis_sets)); + if (ret < 0) { + tdi_log(&s->logger, 0, "Error allocating the non-linear solver\n"); + return ret; + } + + nl->logger = s->logger; + nl->tp = s->tp; + nl->maxiter = td->max_iter; + nl->atol = td->atol; + + memcpy(nl->basis, s->basis, sizeof(s->basis)); + memcpy(nl->solve_order, s->basis_order, sizeof(s->basis_order)); + +#if HAVE_OPENCL + nl->ocl_ctx = s->ocl_ctx; + nl->ocl_queue = s->ocl_queue; +#endif + + ret = tdi_nlsolve_context_init(nl); + if (ret < 0) { + tdi_log(&s->logger, 0, "Error initializing the non-linear solver\n"); + goto fail; + } + + *pnl = nl; + return 0; +fail: + tdi_nlsolve_context_free(&nl); + return ret; +} + static int teukolsky_constraint_eval(void *opaque, unsigned int eq_idx, const unsigned int *colloc_grid_order, const double * const *colloc_grid, @@ -295,6 +306,7 @@ static int teukolsky_constraint_eval(void *opaque, unsigned int eq_idx, int td_solve(TDContext *td, double *coeffs_init[3]) { TDPriv *s = td->priv; + NLSolveContext *nl; TDConstraintEvalContext *ce; double a0; int ret; @@ -303,9 +315,13 @@ int td_solve(TDContext *td, double *coeffs_init[3]) if (ret < 0) return ret; - ret = constraint_eval_alloc(td, 0.0, &ce); + ret = nlsolve_alloc(td, &nl); if (ret < 0) - return ret; + goto fail; + + ret = constraint_eval_alloc(td, td->nb_coeffs, nl->colloc_grid[0], 0.0, &ce); + if (ret < 0) + goto fail; if (fabs(td->amplitude) >= ce->a_diverge) { tdi_log(&s->logger, 0, "Amplitude A=%16.16g is above the point A_{max}=%g, no solutions " @@ -325,16 +341,16 @@ int td_solve(TDContext *td, double *coeffs_init[3]) if (td->solution_branch == 0 || coeffs_init) { // direct solve with default (flat space) or user-provided initial guess - ret = constraint_eval_alloc(td, td->amplitude, &ce); + ret = constraint_eval_alloc(td, td->nb_coeffs, nl->colloc_grid[0], td->amplitude, &ce); if (ret < 0) - return ret; + goto fail; - ret = tdi_nlsolve_solve(s->solver, teukolsky_constraint_eval, NULL, + ret = tdi_nlsolve_solve(nl, teukolsky_constraint_eval, NULL, ce, s->coeffs, 0); tdi_constraint_eval_free(&ce); if (ret < 0) { tdi_log(&s->logger, 0, "tdi_nlsolve_solve() failed: %d", ret); - return ret; + goto fail; } } else { // second branch requested and no user-provided initial guess @@ -345,39 +361,39 @@ int td_solve(TDContext *td, double *coeffs_init[3]) double cur_amplitude, new_amplitude, step; - ret = constraint_eval_alloc(td, a0, &ce); + ret = constraint_eval_alloc(td, td->nb_coeffs, nl->colloc_grid[0], a0, &ce); if (ret < 0) - return ret; + goto fail; - ret = tdi_nlsolve_solve(s->solver, teukolsky_constraint_eval, NULL, + ret = tdi_nlsolve_solve(nl, teukolsky_constraint_eval, NULL, ce, s->coeffs, 0); tdi_constraint_eval_free(&ce); if (ret < 0) - return ret; + goto fail; - ret = constraint_eval_alloc(td, a1, &ce); + ret = constraint_eval_alloc(td, td->nb_coeffs, nl->colloc_grid[0], a1, &ce); if (ret < 0) - return ret; - ret = tdi_nlsolve_solve(s->solver, teukolsky_constraint_eval, NULL, + goto fail; + ret = tdi_nlsolve_solve(nl, teukolsky_constraint_eval, NULL, ce, s->coeffs_tmp, 0); tdi_constraint_eval_free(&ce); if (ret < 0) - return ret; + goto fail; cblas_daxpy(N, -1.0, s->coeffs, 1, s->coeffs_tmp, 1); cblas_daxpy(N, -1.0, s->coeffs_tmp, 1, s->coeffs, 1); // obtain solution for a1 in the upper branch - ret = constraint_eval_alloc(td, a1, &ce); + ret = constraint_eval_alloc(td, td->nb_coeffs, nl->colloc_grid[0], a1, &ce); if (ret < 0) - return ret; + goto fail; - ret = tdi_nlsolve_solve(s->solver, teukolsky_constraint_eval, NULL, + ret = tdi_nlsolve_solve(nl, teukolsky_constraint_eval, NULL, ce, s->coeffs, 0); tdi_constraint_eval_free(&ce); if (ret < 0) { tdi_log(&s->logger, 0, "Failed to get into the upper branch\n"); - return ret; + goto fail; } cur_amplitude = a1; @@ -389,23 +405,23 @@ int td_solve(TDContext *td, double *coeffs_init[3]) new_amplitude = td->amplitude; tdi_log(&s->logger, 2, "Trying amplitude %g\n", new_amplitude); - ret = constraint_eval_alloc(td, new_amplitude, &ce); + ret = constraint_eval_alloc(td, td->nb_coeffs, nl->colloc_grid[0], new_amplitude, &ce); if (ret < 0) - return ret; + goto fail; memcpy(s->coeffs_tmp, s->coeffs, sizeof(*s->coeffs) * N); - ret = tdi_nlsolve_solve(s->solver, teukolsky_constraint_eval, NULL, + ret = tdi_nlsolve_solve(nl, teukolsky_constraint_eval, NULL, ce, s->coeffs_tmp, 1); tdi_constraint_eval_free(&ce); if (ret == -EDOM) { step *= 0.5; if (fabs(step) < 1e-5) - return ret; + goto fail; continue; } else if (ret < 0) - return ret; + goto fail; - if (ret <= s->solver->maxiter / 2) + if (ret <= nl->maxiter / 2) step *= 1.75; cur_amplitude = new_amplitude; @@ -413,9 +429,12 @@ int td_solve(TDContext *td, double *coeffs_init[3]) } } finish: - tdi_nlsolve_print_stats(s->solver); + ret = 0; + tdi_nlsolve_print_stats(nl); +fail: + tdi_nlsolve_context_free(&nl); - return 0; + return ret; } static void log_default_callback(const TDContext *td, int level, const char *fmt, va_list vl) @@ -466,7 +485,6 @@ void td_context_free(TDContext **ptd) s = td->priv; - tdi_nlsolve_context_free(&s->solver); tdi_threadpool_free(&s->tp); #if HAVE_OPENCL @@ -577,13 +595,18 @@ int td_eval_krt(const TDContext *td, const unsigned int diff_order[2], double *out) { + static const double dummy_coord = 0.0; + static const double *dummy_coords[2] = { &dummy_coord, &dummy_coord }; + static const unsigned int nb_dummy_coords[2] = { 1, 1 }; + TDConstraintEvalContext *ce; int ret; if (diff_order[0] || diff_order[1]) return -ENOSYS; - ret = constraint_eval_alloc(td, td->amplitude, &ce); + ret = constraint_eval_alloc(td, nb_dummy_coords, dummy_coords, + td->amplitude, &ce); if (ret < 0) return ret; -- cgit v1.2.3