summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--mg2d.c53
1 files changed, 41 insertions, 12 deletions
diff --git a/mg2d.c b/mg2d.c
index 8d86e2d..8c9a15b 100644
--- a/mg2d.c
+++ b/mg2d.c
@@ -372,6 +372,14 @@ finish:
return 0;
}
+static void array_copy(double *dst, ptrdiff_t dst_stride,
+ const double *src, ptrdiff_t src_stride,
+ size_t linesize, size_t nb_lines)
+{
+ for (size_t i = 0; i < nb_lines; i++)
+ memcpy(dst + i * dst_stride, src + i * src_stride, linesize * sizeof(*dst));
+}
+
static void bnd_zero(MG2DBoundary *bdst, size_t nb_rows, size_t domain_size)
{
for (size_t i = 0; i < nb_rows; i++) {
@@ -435,6 +443,15 @@ static int mg_levels_init(MG2DContext *ctx)
cur = priv->root;
prev = NULL;
+
+ array_copy(cur->solver->u, cur->solver->u_stride, ctx->u, ctx->u_stride, ctx->domain_size, ctx->domain_size);
+ array_copy(cur->solver->rhs, cur->solver->rhs_stride, ctx->rhs, ctx->rhs_stride, ctx->domain_size, ctx->domain_size);
+ for (int i = 0; i < ARRAY_ELEMS(ctx->diff_coeffs); i++) {
+ array_copy(cur->solver->diff_coeffs[i], cur->solver->diff_coeffs_stride,
+ ctx->diff_coeffs[i], ctx->diff_coeffs_stride,
+ ctx->domain_size, ctx->domain_size);
+ }
+
while (cur) {
if (!prev) {
cur->solver->step[0] = ctx->step[0];
@@ -545,6 +562,8 @@ int mg2d_solve(MG2DContext *ctx)
mg2di_log(&priv->logger, MG2D_LOG_INFO, "converged on iteration %d, residual %g\n",
i, res_cur);
+ array_copy(ctx->u, ctx->u_stride, s_root->u, s_root->u_stride, ctx->domain_size, ctx->domain_size);
+
priv->time_solve += gettime() - time_start;
priv->count_solve++;
@@ -699,6 +718,23 @@ MG2DContext *mg2d_solver_alloc(size_t domain_size)
}
}
+ ctx->u = calloc(SQR(domain_size), sizeof(*ctx->u));
+ if (!ctx->u)
+ goto fail;
+ ctx->u_stride = domain_size;
+
+ ctx->rhs = calloc(SQR(domain_size), sizeof(*ctx->rhs));
+ if (!ctx->rhs)
+ goto fail;
+ ctx->rhs_stride = domain_size;
+
+ for (int i = 0; i < ARRAY_ELEMS(ctx->diff_coeffs); i++) {
+ ctx->diff_coeffs[i] = calloc(SQR(domain_size), sizeof(*ctx->diff_coeffs[i]));
+ if (!ctx->diff_coeffs[i])
+ goto fail;
+ }
+ ctx->diff_coeffs_stride = domain_size;
+
ret = mg_levels_alloc(ctx, domain_size);
if (ret < 0)
goto fail;
@@ -714,18 +750,6 @@ MG2DContext *mg2d_solver_alloc(size_t domain_size)
ctx->log_level = MG2D_LOG_INFO;
ctx->nb_threads = 1;
- ctx->u = priv->root->solver->u;
- ctx->u_stride = priv->root->solver->u_stride;
- /* initialize the initial guess to zero */
- memset(ctx->u, 0, sizeof(*ctx->u) * ctx->u_stride * ctx->domain_size);
-
- ctx->rhs = priv->root->solver->rhs;
- ctx->rhs_stride = priv->root->solver->rhs_stride;
-
- for (int i = 0; i < ARRAY_ELEMS(ctx->diff_coeffs); i++)
- ctx->diff_coeffs[i] = priv->root->solver->diff_coeffs[i];
- ctx->diff_coeffs_stride = priv->root->solver->diff_coeffs_stride;
-
return ctx;
fail:
mg2d_solver_free(&ctx);
@@ -751,6 +775,11 @@ void mg2d_solver_free(MG2DContext **pctx)
tp_free(&ctx->priv->tp);
free(ctx->priv);
+ free(ctx->u);
+ free(ctx->rhs);
+ for (int i = 0; i < ARRAY_ELEMS(ctx->diff_coeffs); i++)
+ free(ctx->diff_coeffs[i]);
+
free(ctx);
*pctx = NULL;
}