diff options
Diffstat (limited to 'mg2d.c')
-rw-r--r-- | mg2d.c | 53 |
1 files changed, 41 insertions, 12 deletions
@@ -372,6 +372,14 @@ finish: return 0; } +static void array_copy(double *dst, ptrdiff_t dst_stride, + const double *src, ptrdiff_t src_stride, + size_t linesize, size_t nb_lines) +{ + for (size_t i = 0; i < nb_lines; i++) + memcpy(dst + i * dst_stride, src + i * src_stride, linesize * sizeof(*dst)); +} + static void bnd_zero(MG2DBoundary *bdst, size_t nb_rows, size_t domain_size) { for (size_t i = 0; i < nb_rows; i++) { @@ -435,6 +443,15 @@ static int mg_levels_init(MG2DContext *ctx) cur = priv->root; prev = NULL; + + array_copy(cur->solver->u, cur->solver->u_stride, ctx->u, ctx->u_stride, ctx->domain_size, ctx->domain_size); + array_copy(cur->solver->rhs, cur->solver->rhs_stride, ctx->rhs, ctx->rhs_stride, ctx->domain_size, ctx->domain_size); + for (int i = 0; i < ARRAY_ELEMS(ctx->diff_coeffs); i++) { + array_copy(cur->solver->diff_coeffs[i], cur->solver->diff_coeffs_stride, + ctx->diff_coeffs[i], ctx->diff_coeffs_stride, + ctx->domain_size, ctx->domain_size); + } + while (cur) { if (!prev) { cur->solver->step[0] = ctx->step[0]; @@ -545,6 +562,8 @@ int mg2d_solve(MG2DContext *ctx) mg2di_log(&priv->logger, MG2D_LOG_INFO, "converged on iteration %d, residual %g\n", i, res_cur); + array_copy(ctx->u, ctx->u_stride, s_root->u, s_root->u_stride, ctx->domain_size, ctx->domain_size); + priv->time_solve += gettime() - time_start; priv->count_solve++; @@ -699,6 +718,23 @@ MG2DContext *mg2d_solver_alloc(size_t domain_size) } } + ctx->u = calloc(SQR(domain_size), sizeof(*ctx->u)); + if (!ctx->u) + goto fail; + ctx->u_stride = domain_size; + + ctx->rhs = calloc(SQR(domain_size), sizeof(*ctx->rhs)); + if (!ctx->rhs) + goto fail; + ctx->rhs_stride = domain_size; + + for (int i = 0; i < ARRAY_ELEMS(ctx->diff_coeffs); i++) { + ctx->diff_coeffs[i] = calloc(SQR(domain_size), sizeof(*ctx->diff_coeffs[i])); + if (!ctx->diff_coeffs[i]) + goto fail; + } + ctx->diff_coeffs_stride = domain_size; + ret = mg_levels_alloc(ctx, domain_size); if (ret < 0) goto fail; @@ -714,18 +750,6 @@ MG2DContext *mg2d_solver_alloc(size_t domain_size) ctx->log_level = MG2D_LOG_INFO; ctx->nb_threads = 1; - ctx->u = priv->root->solver->u; - ctx->u_stride = priv->root->solver->u_stride; - /* initialize the initial guess to zero */ - memset(ctx->u, 0, sizeof(*ctx->u) * ctx->u_stride * ctx->domain_size); - - ctx->rhs = priv->root->solver->rhs; - ctx->rhs_stride = priv->root->solver->rhs_stride; - - for (int i = 0; i < ARRAY_ELEMS(ctx->diff_coeffs); i++) - ctx->diff_coeffs[i] = priv->root->solver->diff_coeffs[i]; - ctx->diff_coeffs_stride = priv->root->solver->diff_coeffs_stride; - return ctx; fail: mg2d_solver_free(&ctx); @@ -751,6 +775,11 @@ void mg2d_solver_free(MG2DContext **pctx) tp_free(&ctx->priv->tp); free(ctx->priv); + free(ctx->u); + free(ctx->rhs); + for (int i = 0; i < ARRAY_ELEMS(ctx->diff_coeffs); i++) + free(ctx->diff_coeffs[i]); + free(ctx); *pctx = NULL; } |