From 1145d43ffef5568e04fb79a4f73ecd69f2f09ce2 Mon Sep 17 00:00:00 2001 From: Anton Khirnov Date: Sun, 5 May 2019 15:45:10 +0200 Subject: egs: allow the same context to be used for both relaxation and exact solves --- ell_grid_solve.c | 267 ++++++++++++++++++++++++++++--------------------------- 1 file changed, 136 insertions(+), 131 deletions(-) (limited to 'ell_grid_solve.c') diff --git a/ell_grid_solve.c b/ell_grid_solve.c index cefb231..ac5aa28 100644 --- a/ell_grid_solve.c +++ b/ell_grid_solve.c @@ -55,6 +55,7 @@ typedef struct EGSExactInternal { size_t N; size_t N_ghosts; + int arrays_alloced; double *mat; double *mat_f; double *rhs; @@ -93,10 +94,8 @@ struct EGSInternal { ResidualCalcContext *rescalc; - union { - EGSRelaxInternal r; - EGSExactInternal e; - }; + EGSRelaxInternal r; + EGSExactInternal e; TPContext *tp_internal; }; @@ -327,7 +326,7 @@ static int residual_add_task(void *arg, unsigned int job_idx, unsigned int threa static int solve_relax_step(EGSContext *ctx, int export_res) { EGSInternal *priv = ctx->priv; - EGSRelaxContext *r = ctx->solver_data; + EGSRelaxContext *r = ctx->relax; int u_next_valid = priv->u_next_valid; @@ -361,6 +360,69 @@ static int solve_relax_step(EGSContext *ctx, int export_res) return 0; } +static void exact_arrays_free(EGSContext *ctx) +{ + EGSInternal *priv = ctx->priv; + EGSExactInternal *e = &priv->e; + + free(ctx->priv->e.scratch_lines); + free(ctx->priv->e.mat); + free(ctx->priv->e.mat_f); + free(ctx->priv->e.rhs); + free(ctx->priv->e.rhs_mat); + free(ctx->priv->e.rhs_factor); + free(ctx->priv->e.x); + free(ctx->priv->e.ipiv); + + mg2di_bicgstab_context_free(&e->bicgstab); + + e->arrays_alloced = 0; +} + +static int exact_arrays_alloc(EGSContext *ctx) +{ + EGSInternal *priv = ctx->priv; + EGSExactInternal *e = &priv->e; + int ret; + + if (e->arrays_alloced) + return 0; + + e->N = ctx->domain_size[0] * ctx->domain_size[1]; + e->N_ghosts = (ctx->domain_size[0] + 2 * FD_STENCIL_MAX) * + (ctx->domain_size[1] + 2 * FD_STENCIL_MAX); + + e->mat = calloc(SQR(e->N), sizeof(*e->mat)); + e->mat_f = calloc(SQR(e->N), sizeof(*e->mat_f)); + e->rhs = calloc(e->N, sizeof(*e->rhs)); + e->rhs_mat = calloc(e->N, sizeof(*e->rhs_mat)); + e->rhs_factor = calloc(e->N, sizeof(*e->rhs_factor)); + e->x = calloc(e->N, sizeof(*e->x)); + e->ipiv = calloc(e->N, sizeof(*e->ipiv)); + if (!e->mat || !e->mat_f || !e->rhs || !e->rhs_mat || !e->rhs_factor || !e->x || !e->ipiv) { + ret = -ENOMEM; + goto fail; + } + + ret = mg2di_bicgstab_context_alloc(&e->bicgstab, e->N, 64); + if (ret < 0) + goto fail; + + for (int i = 0; i < e->N; i++) + e->mat[i * e->N + i] = 1.0; + + ret = mg2di_bicgstab_init(e->bicgstab, e->mat, e->x); + if (ret < 0) + goto fail; + + e->arrays_alloced = 1; + + return 0; +fail: + exact_arrays_free(ctx); + return ret; +} + static void fill_mat_s1(double *mat_row, ptrdiff_t mat_stride, ptrdiff_t row_stride, NDArray **diff_coeffs, ptrdiff_t idx_src) { @@ -610,10 +672,29 @@ static int mat_fill_row_task(void *arg, unsigned int job_idx, unsigned int threa static int solve_exact(EGSContext *ctx) { EGSInternal *priv = ctx->priv; - EGSExactContext *ec = ctx->solver_data; + EGSExactContext *ec = ctx->exact; EGSExactInternal *e = &priv->e; + unsigned int nb_threads = tp_get_nb_threads(ctx->tp); int ret; + if (!e->arrays_alloced) { + ret = exact_arrays_alloc(ctx); + if (ret < 0) + return ret; + } + + if (e->scratch_lines_allocated < nb_threads) { + free(e->scratch_lines); + e->scratch_lines = NULL; + e->scratch_lines_allocated = 0; + + e->scratch_lines = calloc(e->N_ghosts * nb_threads, + sizeof(*e->scratch_lines)); + if (!e->scratch_lines) + return -ENOMEM; + e->scratch_lines_allocated = nb_threads; + } + mg2di_timer_start(&ec->timer_mat_construct); if (!e->mat_filled) { @@ -673,22 +754,23 @@ static int solve_exact(EGSContext *ctx) mg2di_timer_stop(&ec->timer_export); + memset(priv->reflect_disable, 0, sizeof(priv->reflect_disable)); boundaries_apply(ctx, ctx->u, 0); residual_calc(ctx, 1); return 0; } -int mg2di_egs_solve(EGSContext *ctx, int export_res) +int mg2di_egs_solve(EGSContext *ctx, enum EGSType solve_type, int export_res) { int ret; mg2di_timer_start(&ctx->timer_solve); - switch (ctx->solver_type) { - case EGS_SOLVER_RELAXATION: ret = solve_relax_step(ctx, export_res); break; - case EGS_SOLVER_EXACT: ret = solve_exact(ctx); break; - default: ret = -EINVAL; + switch (solve_type) { + case EGS_SOLVE_RELAXATION: ret = solve_relax_step(ctx, export_res); break; + case EGS_SOLVE_EXACT: ret = solve_exact(ctx); break; + default: ret = -EINVAL; } mg2di_timer_stop(&ctx->timer_solve); @@ -714,6 +796,7 @@ static int init_diff_coeffs_task(void *arg, unsigned int job_idx, unsigned int t int mg2di_egs_init(EGSContext *ctx, int flags) { EGSInternal *priv = ctx->priv; + EGSRelaxContext *r = ctx->relax; int ret = 0; mg2di_timer_start(&ctx->timer_solve); @@ -725,23 +808,20 @@ int mg2di_egs_init(EGSContext *ctx, int flags) goto finish; } - if (ctx->solver_type == EGS_SOLVER_RELAXATION) { - EGSRelaxContext *r = ctx->solver_data; - if (r->relax_factor == 0.0) - priv->r.relax_factor = relax_factors[ctx->fd_stencil - 1]; - else - priv->r.relax_factor = r->relax_factor; - priv->r.relax_factor *= ctx->step[0] * ctx->step[0]; + if (r->relax_factor == 0.0) + priv->r.relax_factor = relax_factors[ctx->fd_stencil - 1]; + else + priv->r.relax_factor = r->relax_factor; + priv->r.relax_factor *= ctx->step[0] * ctx->step[0]; - if (r->relax_multiplier > 0.0) - priv->r.relax_factor *= r->relax_multiplier; + if (r->relax_multiplier > 0.0) + priv->r.relax_factor *= r->relax_multiplier; - priv->r.line_add = line_madd_c; + priv->r.line_add = line_madd_c; #if HAVE_EXTERNAL_ASM - if (ctx->cpuflags & MG2DI_CPU_FLAG_FMA3) - priv->r.line_add = mg2di_line_madd_fma3; + if (ctx->cpuflags & MG2DI_CPU_FLAG_FMA3) + priv->r.line_add = mg2di_line_madd_fma3; #endif - } if (!ctx->tp) { ret = tp_init(&priv->tp_internal, 1); @@ -773,36 +853,17 @@ int mg2di_egs_init(EGSContext *ctx, int flags) tp_execute(ctx->tp, ctx->domain_size[0], init_diff_coeffs_task, &arg); } - if (ctx->solver_type == EGS_SOLVER_EXACT) { - EGSExactInternal *e = &priv->e; - unsigned int nb_threads = tp_get_nb_threads(ctx->tp); - - if (!(flags & EGS_INIT_FLAG_SAME_DIFF_COEFFS)) - e->mat_filled = 0; - - switch (ctx->fd_stencil) { - case 1: e->fill_mat = fill_mat_s1; break; - case 2: e->fill_mat = fill_mat_s2; break; - default: - mg2di_log(&ctx->logger, 0, "Invalid finite difference stencil: %zd\n", - ctx->fd_stencil); - ret = -EINVAL; - goto finish; - } - - if (e->scratch_lines_allocated < nb_threads) { - free(e->scratch_lines); - e->scratch_lines = NULL; - e->scratch_lines_allocated = 0; + if (!(flags & EGS_INIT_FLAG_SAME_DIFF_COEFFS)) + priv->e.mat_filled = 0; - e->scratch_lines = calloc(e->N_ghosts * nb_threads, - sizeof(*e->scratch_lines)); - if (!e->scratch_lines) { - ret = -ENOMEM; - goto finish; - } - e->scratch_lines_allocated = nb_threads; - } + switch (ctx->fd_stencil) { + case 1: priv->e.fill_mat = fill_mat_s1; break; + case 2: priv->e.fill_mat = fill_mat_s2; break; + default: + mg2di_log(&ctx->logger, 0, "Invalid finite difference stencil: %zd\n", + ctx->fd_stencil); + ret = -EINVAL; + goto finish; } priv->residual_calc_size[0] = ctx->domain_size[0]; @@ -844,15 +905,12 @@ finish: if (ret >= 0) { memset(priv->reflect_disable, 0, sizeof(priv->reflect_disable)); - boundaries_apply(ctx, ctx->u, 1); - if (ctx->solver_type == EGS_SOLVER_RELAXATION) { - boundaries_apply(ctx, priv->u_next, 1); + boundaries_apply(ctx, ctx->u, 1); + boundaries_apply(ctx, priv->u_next, 1); - residual_calc(ctx, 0); - boundaries_apply(ctx, priv->u_next, 0); - priv->u_next_valid = 1; - } else - residual_calc(ctx, 1); + residual_calc(ctx, 0); + boundaries_apply(ctx, priv->u_next, 0); + priv->u_next_valid = 1; } mg2di_timer_stop(&ctx->timer_solve); @@ -947,38 +1005,10 @@ static int arrays_alloc(EGSContext *ctx, const size_t domain_size[2]) return -ENOMEM; } - if (ctx->solver_type == EGS_SOLVER_EXACT) { - EGSExactInternal *e = &priv->e; - - e->N = domain_size[0] * domain_size[1]; - e->N_ghosts = (domain_size[0] + 2 * FD_STENCIL_MAX) * (domain_size[1] + 2 * FD_STENCIL_MAX); - - e->mat = calloc(SQR(e->N), sizeof(*e->mat)); - e->mat_f = calloc(SQR(e->N), sizeof(*e->mat_f)); - e->rhs = calloc(e->N, sizeof(*e->rhs)); - e->rhs_mat = calloc(e->N, sizeof(*e->rhs_mat)); - e->rhs_factor = calloc(e->N, sizeof(*e->rhs_factor)); - e->x = calloc(e->N, sizeof(*e->x)); - e->ipiv = calloc(e->N, sizeof(*e->ipiv)); - if (!e->mat || !e->mat_f || !e->rhs || !e->rhs_mat || !e->rhs_factor || !e->x || !e->ipiv) - return -ENOMEM; - - ret = mg2di_bicgstab_context_alloc(&e->bicgstab, e->N, 64); - if (ret < 0) - return ret; - - for (int i = 0; i < e->N; i++) - e->mat[i * e->N + i] = 1.0; - - ret = mg2di_bicgstab_init(e->bicgstab, e->mat, e->x); - if (ret < 0) - return ret; - } - return 0; } -EGSContext *mg2di_egs_alloc(enum EGSType type, size_t domain_size[2]) +EGSContext *mg2di_egs_alloc(size_t domain_size[2]) { EGSContext *ctx; int ret; @@ -991,32 +1021,19 @@ EGSContext *mg2di_egs_alloc(enum EGSType type, size_t domain_size[2]) if (!ctx->priv) goto fail; - switch (type) { - case EGS_SOLVER_RELAXATION: { - EGSRelaxContext *r; - r = calloc(1, sizeof(EGSRelaxContext)); - if (!r) - goto fail; - ctx->solver_data = r; - mg2di_timer_init(&r->timer_correct); - break; - } - case EGS_SOLVER_EXACT: { - EGSExactContext *e; - e = calloc(1, sizeof(EGSExactContext)); - if (!e) - goto fail; - ctx->solver_data = e; - - mg2di_timer_init(&e->timer_mat_construct); - mg2di_timer_init(&e->timer_bicgstab); - mg2di_timer_init(&e->timer_lu_solve); - mg2di_timer_init(&e->timer_export); - break; - } - default: goto fail; - } - ctx->solver_type = type; + ctx->relax = calloc(1, sizeof(*ctx->relax)); + if (!ctx->relax) + goto fail; + mg2di_timer_init(&ctx->relax->timer_correct); + + ctx->exact = calloc(1, sizeof(*ctx->exact)); + if (!ctx->exact) + goto fail; + + mg2di_timer_init(&ctx->exact->timer_mat_construct); + mg2di_timer_init(&ctx->exact->timer_bicgstab); + mg2di_timer_init(&ctx->exact->timer_lu_solve); + mg2di_timer_init(&ctx->exact->timer_export); if (!domain_size[0] || !domain_size[1] || domain_size[0] > SIZE_MAX / domain_size[1]) @@ -1057,22 +1074,10 @@ void mg2di_egs_free(EGSContext **pctx) mg2di_residual_calc_free(&ctx->priv->rescalc); - free(ctx->solver_data); - - if (ctx->solver_type == EGS_SOLVER_EXACT) { - EGSExactInternal *e = &ctx->priv->e; + free(ctx->relax); + free(ctx->exact); - free(e->scratch_lines); - free(e->mat); - free(e->mat_f); - free(e->rhs); - free(e->rhs_mat); - free(e->rhs_factor); - free(e->x); - free(e->ipiv); - - mg2di_bicgstab_context_free(&e->bicgstab); - } + exact_arrays_free(ctx); mg2di_ndarray_free(&ctx->u); mg2di_ndarray_free(&ctx->priv->u_next); -- cgit v1.2.3