From 2e03749bb901494f857547cc518ad58eea76a098 Mon Sep 17 00:00:00 2001 From: Anton Khirnov Date: Tue, 2 Apr 2019 11:05:36 +0200 Subject: mg2d: add API for interpolating an initial guess from a provided grid --- mg2d.c | 93 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 82 insertions(+), 11 deletions(-) (limited to 'mg2d.c') diff --git a/mg2d.c b/mg2d.c index 03a1759..c9070fb 100644 --- a/mg2d.c +++ b/mg2d.c @@ -70,6 +70,8 @@ struct MG2DInternal { NDArray *rhs; NDArray *diff_coeffs[MG2D_DIFF_COEFF_NB]; + GridTransferContext *transfer_init; + int cpuflags; int64_t time_solve; @@ -319,10 +321,25 @@ static int mg_levels_init(MG2DContext *ctx) cur = priv->root; prev = NULL; - mg2di_ndarray_copy(cur->solver->u, priv->u); - mg2di_ndarray_copy(cur->solver->rhs, priv->rhs); - for (int i = 0; i < ARRAY_ELEMS(ctx->diff_coeffs); i++) - mg2di_ndarray_copy(cur->solver->diff_coeffs[i], priv->diff_coeffs[i]); + if (priv->u) { + mg2di_ndarray_copy(cur->solver->u, priv->u); + mg2di_ndarray_free(&priv->u); + ctx->u = cur->solver->u->data; + ctx->u_stride = cur->solver->u->stride[0]; + + mg2di_ndarray_copy(cur->solver->rhs, priv->rhs); + mg2di_ndarray_free(&priv->rhs); + ctx->rhs = cur->solver->rhs->data; + ctx->rhs_stride = cur->solver->rhs->stride[0]; + + for (int i = 0; i < ARRAY_ELEMS(ctx->diff_coeffs); i++) { + mg2di_ndarray_copy(cur->solver->diff_coeffs[i], priv->diff_coeffs[i]); + mg2di_ndarray_free(&priv->diff_coeffs[i]); + ctx->diff_coeffs[i] = cur->solver->diff_coeffs[i]->data; + } + ctx->diff_coeffs_stride = cur->solver->diff_coeffs[0]->stride[0]; + + } while (cur) { if (!prev) { @@ -356,6 +373,7 @@ static int mg_levels_init(MG2DContext *ctx) cur->solver->logger = priv->logger; cur->solver->cpuflags = priv->cpuflags; + cur->solver->tp = priv->tp; cur->solver->fd_stencil = ctx->fd_stencil; if (cur->solver->solver_type == EGS_SOLVER_RELAXATION) { @@ -449,18 +467,12 @@ static int mg_levels_init(MG2DContext *ctx) static int threadpool_init(MG2DContext *ctx) { MG2DInternal *priv = ctx->priv; - MG2DLevel *level = priv->root; int ret; ret = tp_init(&priv->tp, ctx->nb_threads); if (ret < 0) return ret; - while (level) { - level->solver->tp = priv->tp; - level = level->child; - } - return 0; } @@ -519,7 +531,7 @@ int mg2d_solve(MG2DContext *ctx) mg2di_log(&priv->logger, MG2D_LOG_INFO, "converged on iteration %d, residual %g\n", i, res_cur); - mg2di_ndarray_copy(priv->u, s_root->u); + //mg2di_ndarray_copy(priv->u, s_root->u); priv->time_solve += gettime() - time_start; priv->count_solve++; @@ -741,6 +753,8 @@ void mg2d_solver_free(MG2DContext **pctx) tp_free(&ctx->priv->tp); + mg2di_gt_free(&ctx->priv->transfer_init); + mg2di_ndarray_free(&ctx->priv->u); mg2di_ndarray_free(&ctx->priv->rhs); for (int i = 0; i < ARRAY_ELEMS(ctx->priv->diff_coeffs); i++) @@ -856,3 +870,60 @@ unsigned int mg2d_max_fd_stencil(void) { return FD_STENCIL_MAX; } + +int mg2d_init_guess(MG2DContext *ctx, const double *src, + ptrdiff_t src_stride, + const size_t src_size[2], + const double src_step[2]) +{ + MG2DInternal *priv = ctx->priv; + NDArray *a_src; + int ret; + + if (!priv->tp) { + ret = threadpool_init(ctx); + if (ret < 0) + return ret; + } + + if (priv->transfer_init && + (priv->transfer_init->src.size[0] != src_size[0] || + priv->transfer_init->src.size[1] != src_size[1] || + fabs(priv->transfer_init->src.step[0] - src_step[0]) > 1e-15 || + fabs(priv->transfer_init->src.step[1] - src_step[1]) > 1e-15)) { + mg2di_gt_free(&priv->transfer_init); + } + + if (!priv->transfer_init) { + priv->transfer_init = mg2di_gt_alloc(GRID_TRANSFER_LAGRANGE_3); + if (!priv->transfer_init) + return -ENOMEM; + + priv->transfer_init->tp = priv->tp; + priv->transfer_init->cpuflags = priv->cpuflags; + + priv->transfer_init->src.size[0] = src_size[0]; + priv->transfer_init->src.size[1] = src_size[1]; + priv->transfer_init->src.step[0] = src_step[0]; + priv->transfer_init->src.step[1] = src_step[1]; + + priv->transfer_init->dst.size[0] = ctx->domain_size; + priv->transfer_init->dst.size[1] = ctx->domain_size; + priv->transfer_init->dst.step[0] = ctx->step[0]; + priv->transfer_init->dst.step[1] = ctx->step[1]; + + ret = mg2di_gt_init(priv->transfer_init); + if (ret < 0) + return ret; + } + + ret = mg2di_ndarray_wrap(&a_src, 2, src_size, src, + (ptrdiff_t [2]){ src_stride, 1 }); + if (ret < 0) + return ret; + + ret = mg2di_gt_transfer(priv->transfer_init, priv->u ? priv->u : priv->root->solver->u, a_src); + + mg2di_ndarray_free(&a_src); + return ret; +} -- cgit v1.2.3