From 1145d43ffef5568e04fb79a4f73ecd69f2f09ce2 Mon Sep 17 00:00:00 2001 From: Anton Khirnov Date: Sun, 5 May 2019 15:45:10 +0200 Subject: egs: allow the same context to be used for both relaxation and exact solves --- mg2d.c | 132 ++++++++++++++++++++++++++++++++--------------------------------- 1 file changed, 64 insertions(+), 68 deletions(-) (limited to 'mg2d.c') diff --git a/mg2d.c b/mg2d.c index ed56eb2..5f2ce22 100644 --- a/mg2d.c +++ b/mg2d.c @@ -85,8 +85,8 @@ static void log_callback(MG2DLogger *log, int level, const char *fmt, va_list vl ctx->log_callback(ctx, level, fmt, vl); } -static void log_relax_step(MG2DContext *ctx, MG2DLevel *level, const char *step_desc, - double res_old, double res_new) +static void log_egs_step(MG2DContext *ctx, MG2DLevel *level, const char *step_desc, + double res_old, double res_new) { char prefix[32] = { 0 }; @@ -109,8 +109,8 @@ static int coarse_correct_task(void *arg, unsigned int job_idx, unsigned int thr return 0; } -static int mg_relax_step(MG2DContext *ctx, MG2DLevel *level, const char *step_desc, - int export_res) +static int mg_solve(MG2DContext *ctx, MG2DLevel *level, enum EGSType solve_type, + const char *step_desc, int export_res) { double res_old; int ret; @@ -118,21 +118,23 @@ static int mg_relax_step(MG2DContext *ctx, MG2DLevel *level, const char *step_de res_old = level->solver->residual_max; mg2di_timer_start(&level->timer_solve); - ret = mg2di_egs_solve(level->solver, export_res); + ret = mg2di_egs_solve(level->solver, solve_type, export_res); mg2di_timer_stop(&level->timer_solve); if (ret < 0) return ret; - log_relax_step(ctx, level, step_desc, res_old, level->solver->residual_max); + log_egs_step(ctx, level, step_desc, res_old, level->solver->residual_max); return 0; } -static int mg_solve_subgrid(MG2DContext *ctx, MG2DLevel *level) +static int mg_solve_subgrid(MG2DContext *ctx, MG2DLevel *level, int allow_exact) { + enum EGSType solve_type = allow_exact && level->solver->domain_size[0] <= ctx->max_exact_size ? + EGS_SOLVE_EXACT : EGS_SOLVE_RELAXATION; double res_old, res_new; - int ret; + int ret = 0; /* on the refined levels, use zero as the initial guess for the * solution (correction for the upper level) */ @@ -152,9 +154,9 @@ static int mg_solve_subgrid(MG2DContext *ctx, MG2DLevel *level) res_old = level->solver->residual_max; - /* handle coarsest grid */ - if (!level->child) { - ret = mg_relax_step(ctx, level, "coarse-step", 1); + /* handle exact solve */ + if (solve_type == EGS_SOLVE_EXACT) { + ret = mg_solve(ctx, level, solve_type, "coarse-step", 0); if (ret < 0) return ret; level->count_cycles++; @@ -167,52 +169,55 @@ static int mg_solve_subgrid(MG2DContext *ctx, MG2DLevel *level) /* pre-restrict relaxation */ for (int j = 0; j < ctx->nb_relax_pre; j++) { - ret = mg_relax_step(ctx, level, "pre-relax", j == ctx->nb_relax_pre - 1); + ret = mg_solve(ctx, level, solve_type, + "pre-relax", j == ctx->nb_relax_pre - 1 && level->child); if (ret < 0) return ret; } - /* restrict the residual as to the coarser-level rhs */ - mg2di_timer_start(&level->timer_restrict); - ret = mg2di_gt_transfer(level->transfer_restrict, level->child->solver->rhs, - level->solver->residual); - mg2di_timer_stop(&level->timer_restrict); - if (ret < 0) - return ret; + if (level->child) { + /* restrict the residual as to the coarser-level rhs */ + mg2di_timer_start(&level->timer_restrict); + ret = mg2di_gt_transfer(level->transfer_restrict, level->child->solver->rhs, + level->solver->residual); + mg2di_timer_stop(&level->timer_restrict); + if (ret < 0) + return ret; - /* solve on the coarser level */ - ret = mg_solve_subgrid(ctx, level->child); - if (ret < 0) - return ret; + /* solve on the coarser level */ + ret = mg_solve_subgrid(ctx, level->child, 1); + if (ret < 0) + return ret; - /* prolongate the coarser-level correction */ - mg2di_timer_start(&level->timer_prolong); - ret = mg2di_gt_transfer(level->transfer_prolong, level->prolong_tmp, - level->child->solver->u); - mg2di_timer_stop(&level->timer_prolong); - if (ret < 0) - return ret; + /* prolongate the coarser-level correction */ + mg2di_timer_start(&level->timer_prolong); + ret = mg2di_gt_transfer(level->transfer_prolong, level->prolong_tmp, + level->child->solver->u); + mg2di_timer_stop(&level->timer_prolong); + if (ret < 0) + return ret; - /* apply the correction */ - mg2di_timer_start(&level->timer_correct); - tp_execute(ctx->priv->tp, level->solver->domain_size[1], coarse_correct_task, level); - mg2di_timer_stop(&level->timer_correct); + /* apply the correction */ + mg2di_timer_start(&level->timer_correct); + tp_execute(ctx->priv->tp, level->solver->domain_size[1], coarse_correct_task, level); + mg2di_timer_stop(&level->timer_correct); - /* re-init the current-level solver (re-calc the residual) */ - res_prev = level->solver->residual_max; - mg2di_timer_start(&level->timer_reinit); - ret = mg2di_egs_init(level->solver, level->egs_init_flags); - mg2di_timer_stop(&level->timer_reinit); - if (ret < 0) - return ret; + /* re-init the current-level solver (re-calc the residual) */ + res_prev = level->solver->residual_max; + mg2di_timer_start(&level->timer_reinit); + ret = mg2di_egs_init(level->solver, level->egs_init_flags); + mg2di_timer_stop(&level->timer_reinit); + if (ret < 0) + return ret; - level->egs_init_flags |= EGS_INIT_FLAG_SAME_DIFF_COEFFS; + level->egs_init_flags |= EGS_INIT_FLAG_SAME_DIFF_COEFFS; - log_relax_step(ctx, level, "correct", res_prev, level->solver->residual_max); + log_egs_step(ctx, level, "correct", res_prev, level->solver->residual_max); + } /* post-correct relaxation */ for (int j = 0; j < ctx->nb_relax_post; j++) { - ret = mg_relax_step(ctx, level, "post-relax", 0); + ret = mg_solve(ctx, level, solve_type, "post-relax", 0); if (ret < 0) return ret; } @@ -375,13 +380,10 @@ static int mg_levels_init(MG2DContext *ctx) cur->solver->tp = priv->tp; cur->solver->fd_stencil = ctx->fd_stencil; - if (cur->solver->solver_type == EGS_SOLVER_RELAXATION) { - EGSRelaxContext *r = cur->solver->solver_data; - r->relax_factor = ctx->cfl_factor; - r->relax_multiplier = 1.0 / (diff2_max + cur->solver->step[0] * cur->solver->step[1] * - diff0_max / 8.0); - } + cur->solver->relax->relax_factor = ctx->cfl_factor; + cur->solver->relax->relax_multiplier = 1.0 / (diff2_max + cur->solver->step[0] * cur->solver->step[1] * + diff0_max / 8.0); prev = cur; cur = cur->child; @@ -521,7 +523,7 @@ int mg2d_solve(MG2DContext *ctx) res_prev = res_orig; for (int i = 0; i < ctx->maxiter; i++) { - ret = mg_solve_subgrid(ctx, root); + ret = mg_solve_subgrid(ctx, root, 1); if (ret < 0) goto fail; @@ -575,7 +577,7 @@ static void mg_level_free(MG2DLevel **plevel) *plevel = NULL; } -static MG2DLevel *mg_level_alloc(enum EGSType type, const size_t domain_size) +static MG2DLevel *mg_level_alloc(const size_t domain_size) { MG2DLevel *level; int ret; @@ -592,8 +594,7 @@ static MG2DLevel *mg_level_alloc(enum EGSType type, const size_t domain_size) if (ret < 0) goto fail; - level->solver = mg2di_egs_alloc(type, - (size_t [2]){domain_size, domain_size}); + level->solver = mg2di_egs_alloc((size_t [2]){domain_size, domain_size}); if (!level->solver) goto fail; @@ -630,7 +631,6 @@ static int mg_levels_alloc(MG2DContext *ctx, size_t domain_size) next_size = domain_size; for (int depth = 0; depth < ctx->max_levels; depth++) { - enum EGSType cur_type; size_t cur_size = next_size; // choose the domain size for the next child @@ -643,15 +643,11 @@ static int mg_levels_alloc(MG2DContext *ctx, size_t domain_size) } else next_size = (cur_size >> 1) + 1; - cur_type = (cur_size <= ctx->max_exact_size) ? EGS_SOLVER_EXACT : EGS_SOLVER_RELAXATION; - - *dst = mg_level_alloc(cur_type, cur_size); + *dst = mg_level_alloc(cur_size); if (!*dst) return -ENOMEM; (*dst)->depth = depth; - if (cur_type == EGS_SOLVER_EXACT) - break; if (next_size <= 4) break; @@ -794,16 +790,16 @@ void mg2d_print_stats(MG2DContext *ctx, const char *prefix) char buf[1024], *p; int ret; - EGSRelaxContext *r = NULL; - EGSExactContext *e = NULL; + EGSRelaxContext *r = level->solver->relax; + EGSExactContext *e = level->solver->exact; int64_t level_total = level->timer_solve.time_nsec + level->timer_prolong.time_nsec + level->timer_restrict.time_nsec + level->timer_correct.time_nsec + level->timer_reinit.time_nsec; - if (level->solver->solver_type == EGS_SOLVER_RELAXATION) - r = level->solver->solver_data; - else if (level->solver->solver_type == EGS_SOLVER_EXACT) - e = level->solver->solver_data; + if (!level->count_cycles) { + level = level->child; + continue; + } levels_total += level_total; @@ -845,13 +841,13 @@ void mg2d_print_stats(MG2DContext *ctx, const char *prefix) if (ret > 0) p += ret; - if (r) { + if (r->timer_correct.nb_runs) { ret = snprintf(p, sizeof(buf) - (p - buf), " %2.2f%% correct", r->timer_correct.time_nsec * 100.0 / level->solver->timer_solve.time_nsec); if (ret > 0) p += ret; - } else if (e) { + } else if (e->timer_mat_construct.nb_runs) { ret = snprintf(p, sizeof(buf) - (p - buf), " %2.2f%% const %2.2f%% bicgstab (%ld; %g it/slv) %2.2f%% lu (%ld) %2.2f%% export", e->timer_mat_construct.time_nsec * 100.0 / level->solver->timer_solve.time_nsec, -- cgit v1.2.3