aboutsummaryrefslogtreecommitdiff
path: root/mg2d.c
diff options
context:
space:
mode:
Diffstat (limited to 'mg2d.c')
-rw-r--r--mg2d.c132
1 files changed, 64 insertions, 68 deletions
diff --git a/mg2d.c b/mg2d.c
index ed56eb2..5f2ce22 100644
--- a/mg2d.c
+++ b/mg2d.c
@@ -85,8 +85,8 @@ static void log_callback(MG2DLogger *log, int level, const char *fmt, va_list vl
ctx->log_callback(ctx, level, fmt, vl);
}
-static void log_relax_step(MG2DContext *ctx, MG2DLevel *level, const char *step_desc,
- double res_old, double res_new)
+static void log_egs_step(MG2DContext *ctx, MG2DLevel *level, const char *step_desc,
+ double res_old, double res_new)
{
char prefix[32] = { 0 };
@@ -109,8 +109,8 @@ static int coarse_correct_task(void *arg, unsigned int job_idx, unsigned int thr
return 0;
}
-static int mg_relax_step(MG2DContext *ctx, MG2DLevel *level, const char *step_desc,
- int export_res)
+static int mg_solve(MG2DContext *ctx, MG2DLevel *level, enum EGSType solve_type,
+ const char *step_desc, int export_res)
{
double res_old;
int ret;
@@ -118,21 +118,23 @@ static int mg_relax_step(MG2DContext *ctx, MG2DLevel *level, const char *step_de
res_old = level->solver->residual_max;
mg2di_timer_start(&level->timer_solve);
- ret = mg2di_egs_solve(level->solver, export_res);
+ ret = mg2di_egs_solve(level->solver, solve_type, export_res);
mg2di_timer_stop(&level->timer_solve);
if (ret < 0)
return ret;
- log_relax_step(ctx, level, step_desc, res_old, level->solver->residual_max);
+ log_egs_step(ctx, level, step_desc, res_old, level->solver->residual_max);
return 0;
}
-static int mg_solve_subgrid(MG2DContext *ctx, MG2DLevel *level)
+static int mg_solve_subgrid(MG2DContext *ctx, MG2DLevel *level, int allow_exact)
{
+ enum EGSType solve_type = allow_exact && level->solver->domain_size[0] <= ctx->max_exact_size ?
+ EGS_SOLVE_EXACT : EGS_SOLVE_RELAXATION;
double res_old, res_new;
- int ret;
+ int ret = 0;
/* on the refined levels, use zero as the initial guess for the
* solution (correction for the upper level) */
@@ -152,9 +154,9 @@ static int mg_solve_subgrid(MG2DContext *ctx, MG2DLevel *level)
res_old = level->solver->residual_max;
- /* handle coarsest grid */
- if (!level->child) {
- ret = mg_relax_step(ctx, level, "coarse-step", 1);
+ /* handle exact solve */
+ if (solve_type == EGS_SOLVE_EXACT) {
+ ret = mg_solve(ctx, level, solve_type, "coarse-step", 0);
if (ret < 0)
return ret;
level->count_cycles++;
@@ -167,52 +169,55 @@ static int mg_solve_subgrid(MG2DContext *ctx, MG2DLevel *level)
/* pre-restrict relaxation */
for (int j = 0; j < ctx->nb_relax_pre; j++) {
- ret = mg_relax_step(ctx, level, "pre-relax", j == ctx->nb_relax_pre - 1);
+ ret = mg_solve(ctx, level, solve_type,
+ "pre-relax", j == ctx->nb_relax_pre - 1 && level->child);
if (ret < 0)
return ret;
}
- /* restrict the residual as to the coarser-level rhs */
- mg2di_timer_start(&level->timer_restrict);
- ret = mg2di_gt_transfer(level->transfer_restrict, level->child->solver->rhs,
- level->solver->residual);
- mg2di_timer_stop(&level->timer_restrict);
- if (ret < 0)
- return ret;
+ if (level->child) {
+ /* restrict the residual as to the coarser-level rhs */
+ mg2di_timer_start(&level->timer_restrict);
+ ret = mg2di_gt_transfer(level->transfer_restrict, level->child->solver->rhs,
+ level->solver->residual);
+ mg2di_timer_stop(&level->timer_restrict);
+ if (ret < 0)
+ return ret;
- /* solve on the coarser level */
- ret = mg_solve_subgrid(ctx, level->child);
- if (ret < 0)
- return ret;
+ /* solve on the coarser level */
+ ret = mg_solve_subgrid(ctx, level->child, 1);
+ if (ret < 0)
+ return ret;
- /* prolongate the coarser-level correction */
- mg2di_timer_start(&level->timer_prolong);
- ret = mg2di_gt_transfer(level->transfer_prolong, level->prolong_tmp,
- level->child->solver->u);
- mg2di_timer_stop(&level->timer_prolong);
- if (ret < 0)
- return ret;
+ /* prolongate the coarser-level correction */
+ mg2di_timer_start(&level->timer_prolong);
+ ret = mg2di_gt_transfer(level->transfer_prolong, level->prolong_tmp,
+ level->child->solver->u);
+ mg2di_timer_stop(&level->timer_prolong);
+ if (ret < 0)
+ return ret;
- /* apply the correction */
- mg2di_timer_start(&level->timer_correct);
- tp_execute(ctx->priv->tp, level->solver->domain_size[1], coarse_correct_task, level);
- mg2di_timer_stop(&level->timer_correct);
+ /* apply the correction */
+ mg2di_timer_start(&level->timer_correct);
+ tp_execute(ctx->priv->tp, level->solver->domain_size[1], coarse_correct_task, level);
+ mg2di_timer_stop(&level->timer_correct);
- /* re-init the current-level solver (re-calc the residual) */
- res_prev = level->solver->residual_max;
- mg2di_timer_start(&level->timer_reinit);
- ret = mg2di_egs_init(level->solver, level->egs_init_flags);
- mg2di_timer_stop(&level->timer_reinit);
- if (ret < 0)
- return ret;
+ /* re-init the current-level solver (re-calc the residual) */
+ res_prev = level->solver->residual_max;
+ mg2di_timer_start(&level->timer_reinit);
+ ret = mg2di_egs_init(level->solver, level->egs_init_flags);
+ mg2di_timer_stop(&level->timer_reinit);
+ if (ret < 0)
+ return ret;
- level->egs_init_flags |= EGS_INIT_FLAG_SAME_DIFF_COEFFS;
+ level->egs_init_flags |= EGS_INIT_FLAG_SAME_DIFF_COEFFS;
- log_relax_step(ctx, level, "correct", res_prev, level->solver->residual_max);
+ log_egs_step(ctx, level, "correct", res_prev, level->solver->residual_max);
+ }
/* post-correct relaxation */
for (int j = 0; j < ctx->nb_relax_post; j++) {
- ret = mg_relax_step(ctx, level, "post-relax", 0);
+ ret = mg_solve(ctx, level, solve_type, "post-relax", 0);
if (ret < 0)
return ret;
}
@@ -375,13 +380,10 @@ static int mg_levels_init(MG2DContext *ctx)
cur->solver->tp = priv->tp;
cur->solver->fd_stencil = ctx->fd_stencil;
- if (cur->solver->solver_type == EGS_SOLVER_RELAXATION) {
- EGSRelaxContext *r = cur->solver->solver_data;
- r->relax_factor = ctx->cfl_factor;
- r->relax_multiplier = 1.0 / (diff2_max + cur->solver->step[0] * cur->solver->step[1] *
- diff0_max / 8.0);
- }
+ cur->solver->relax->relax_factor = ctx->cfl_factor;
+ cur->solver->relax->relax_multiplier = 1.0 / (diff2_max + cur->solver->step[0] * cur->solver->step[1] *
+ diff0_max / 8.0);
prev = cur;
cur = cur->child;
@@ -521,7 +523,7 @@ int mg2d_solve(MG2DContext *ctx)
res_prev = res_orig;
for (int i = 0; i < ctx->maxiter; i++) {
- ret = mg_solve_subgrid(ctx, root);
+ ret = mg_solve_subgrid(ctx, root, 1);
if (ret < 0)
goto fail;
@@ -575,7 +577,7 @@ static void mg_level_free(MG2DLevel **plevel)
*plevel = NULL;
}
-static MG2DLevel *mg_level_alloc(enum EGSType type, const size_t domain_size)
+static MG2DLevel *mg_level_alloc(const size_t domain_size)
{
MG2DLevel *level;
int ret;
@@ -592,8 +594,7 @@ static MG2DLevel *mg_level_alloc(enum EGSType type, const size_t domain_size)
if (ret < 0)
goto fail;
- level->solver = mg2di_egs_alloc(type,
- (size_t [2]){domain_size, domain_size});
+ level->solver = mg2di_egs_alloc((size_t [2]){domain_size, domain_size});
if (!level->solver)
goto fail;
@@ -630,7 +631,6 @@ static int mg_levels_alloc(MG2DContext *ctx, size_t domain_size)
next_size = domain_size;
for (int depth = 0; depth < ctx->max_levels; depth++) {
- enum EGSType cur_type;
size_t cur_size = next_size;
// choose the domain size for the next child
@@ -643,15 +643,11 @@ static int mg_levels_alloc(MG2DContext *ctx, size_t domain_size)
} else
next_size = (cur_size >> 1) + 1;
- cur_type = (cur_size <= ctx->max_exact_size) ? EGS_SOLVER_EXACT : EGS_SOLVER_RELAXATION;
-
- *dst = mg_level_alloc(cur_type, cur_size);
+ *dst = mg_level_alloc(cur_size);
if (!*dst)
return -ENOMEM;
(*dst)->depth = depth;
- if (cur_type == EGS_SOLVER_EXACT)
- break;
if (next_size <= 4)
break;
@@ -794,16 +790,16 @@ void mg2d_print_stats(MG2DContext *ctx, const char *prefix)
char buf[1024], *p;
int ret;
- EGSRelaxContext *r = NULL;
- EGSExactContext *e = NULL;
+ EGSRelaxContext *r = level->solver->relax;
+ EGSExactContext *e = level->solver->exact;
int64_t level_total = level->timer_solve.time_nsec + level->timer_prolong.time_nsec + level->timer_restrict.time_nsec +
level->timer_correct.time_nsec + level->timer_reinit.time_nsec;
- if (level->solver->solver_type == EGS_SOLVER_RELAXATION)
- r = level->solver->solver_data;
- else if (level->solver->solver_type == EGS_SOLVER_EXACT)
- e = level->solver->solver_data;
+ if (!level->count_cycles) {
+ level = level->child;
+ continue;
+ }
levels_total += level_total;
@@ -845,13 +841,13 @@ void mg2d_print_stats(MG2DContext *ctx, const char *prefix)
if (ret > 0)
p += ret;
- if (r) {
+ if (r->timer_correct.nb_runs) {
ret = snprintf(p, sizeof(buf) - (p - buf),
" %2.2f%% correct",
r->timer_correct.time_nsec * 100.0 / level->solver->timer_solve.time_nsec);
if (ret > 0)
p += ret;
- } else if (e) {
+ } else if (e->timer_mat_construct.nb_runs) {
ret = snprintf(p, sizeof(buf) - (p - buf),
" %2.2f%% const %2.2f%% bicgstab (%ld; %g it/slv) %2.2f%% lu (%ld) %2.2f%% export",
e->timer_mat_construct.time_nsec * 100.0 / level->solver->timer_solve.time_nsec,