about | summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author Anton Khirnov <anton@khirnov.net> 2019-05-05 15:45:10 +0200
committer Anton Khirnov <anton@khirnov.net> 2019-05-06 18:23:59 +0200
commit 1145d43ffef5568e04fb79a4f73ecd69f2f09ce2 (patch)
tree c4c32f122e55497848b2d6fe0adc013fb1fc88af
parent e91409830cbb4c06c4035d7084aa20b0d4d259bc (diff)
egs: allow the same context to be used for both relaxation and exact solves
-rw-r--r-- ell_grid_solve.c 267
-rw-r--r-- ell_grid_solve.h 17
-rw-r--r-- mg2d.c 132
-rw-r--r-- relax_test.c 4
4 files changed, 208 insertions, 212 deletions
diff --git a/ell_grid_solve.c b/ell_grid_solve.c
index cefb231..ac5aa28 100644
--- a/ell_grid_solve.c
+++ b/ell_grid_solve.c
@@ -55,6 +55,7 @@ typedef struct EGSExactInternal {
size_t N;
size_t N_ghosts;
+ int arrays_alloced;
double *mat;
double *mat_f;
double *rhs;
@@ -93,10 +94,8 @@ struct EGSInternal {
ResidualCalcContext *rescalc;
- union {
- EGSRelaxInternal r;
- EGSExactInternal e;
- };
+ EGSRelaxInternal r;
+ EGSExactInternal e;
TPContext *tp_internal;
};
@@ -327,7 +326,7 @@ static int residual_add_task(void *arg, unsigned int job_idx, unsigned int threa
static int solve_relax_step(EGSContext *ctx, int export_res)
{
EGSInternal *priv = ctx->priv;
- EGSRelaxContext *r = ctx->solver_data;
+ EGSRelaxContext *r = ctx->relax;
int u_next_valid = priv->u_next_valid;
@@ -361,6 +360,69 @@ static int solve_relax_step(EGSContext *ctx, int export_res)
return 0;
}
+static void exact_arrays_free(EGSContext *ctx)
+{
+ EGSInternal *priv = ctx->priv;
+ EGSExactInternal *e = &priv->e;
+
+ free(ctx->priv->e.scratch_lines);
+ free(ctx->priv->e.mat);
+ free(ctx->priv->e.mat_f);
+ free(ctx->priv->e.rhs);
+ free(ctx->priv->e.rhs_mat);
+ free(ctx->priv->e.rhs_factor);
+ free(ctx->priv->e.x);
+ free(ctx->priv->e.ipiv);
+
+ mg2di_bicgstab_context_free(&e->bicgstab);
+
+ e->arrays_alloced = 0;
+}
+
+static int exact_arrays_alloc(EGSContext *ctx)
+{
+ EGSInternal *priv = ctx->priv;
+ EGSExactInternal *e = &priv->e;
+ int ret;
+
+ if (e->arrays_alloced)
+ return 0;
+
+ e->N = ctx->domain_size[0] * ctx->domain_size[1];
+ e->N_ghosts = (ctx->domain_size[0] + 2 * FD_STENCIL_MAX) *
+ (ctx->domain_size[1] + 2 * FD_STENCIL_MAX);
+
+ e->mat = calloc(SQR(e->N), sizeof(*e->mat));
+ e->mat_f = calloc(SQR(e->N), sizeof(*e->mat_f));
+ e->rhs = calloc(e->N, sizeof(*e->rhs));
+ e->rhs_mat = calloc(e->N, sizeof(*e->rhs_mat));
+ e->rhs_factor = calloc(e->N, sizeof(*e->rhs_factor));
+ e->x = calloc(e->N, sizeof(*e->x));
+ e->ipiv = calloc(e->N, sizeof(*e->ipiv));
+ if (!e->mat || !e->mat_f || !e->rhs || !e->rhs_mat || !e->rhs_factor || !e->x || !e->ipiv) {
+ ret = -ENOMEM;
+ goto fail;
+ }
+
+ ret = mg2di_bicgstab_context_alloc(&e->bicgstab, e->N, 64);
+ if (ret < 0)
+ goto fail;
+
+ for (int i = 0; i < e->N; i++)
+ e->mat[i * e->N + i] = 1.0;
+
+ ret = mg2di_bicgstab_init(e->bicgstab, e->mat, e->x);
+ if (ret < 0)
+ goto fail;
+
+ e->arrays_alloced = 1;
+
+ return 0;
+fail:
+ exact_arrays_free(ctx);
+ return ret;
+}
+
static void fill_mat_s1(double *mat_row, ptrdiff_t mat_stride, ptrdiff_t row_stride,
NDArray **diff_coeffs, ptrdiff_t idx_src)
{
@@ -610,10 +672,29 @@ static int mat_fill_row_task(void *arg, unsigned int job_idx, unsigned int threa
static int solve_exact(EGSContext *ctx)
{
EGSInternal *priv = ctx->priv;
- EGSExactContext *ec = ctx->solver_data;
+ EGSExactContext *ec = ctx->exact;
EGSExactInternal *e = &priv->e;
+ unsigned int nb_threads = tp_get_nb_threads(ctx->tp);
int ret;
+ if (!e->arrays_alloced) {
+ ret = exact_arrays_alloc(ctx);
+ if (ret < 0)
+ return ret;
+ }
+
+ if (e->scratch_lines_allocated < nb_threads) {
+ free(e->scratch_lines);
+ e->scratch_lines = NULL;
+ e->scratch_lines_allocated = 0;
+
+ e->scratch_lines = calloc(e->N_ghosts * nb_threads,
+ sizeof(*e->scratch_lines));
+ if (!e->scratch_lines)
+ return -ENOMEM;
+ e->scratch_lines_allocated = nb_threads;
+ }
+
mg2di_timer_start(&ec->timer_mat_construct);
if (!e->mat_filled) {
@@ -673,22 +754,23 @@ static int solve_exact(EGSContext *ctx)
mg2di_timer_stop(&ec->timer_export);
+ memset(priv->reflect_disable, 0, sizeof(priv->reflect_disable));
boundaries_apply(ctx, ctx->u, 0);
residual_calc(ctx, 1);
return 0;
}
-int mg2di_egs_solve(EGSContext *ctx, int export_res)
+int mg2di_egs_solve(EGSContext *ctx, enum EGSType solve_type, int export_res)
{
int ret;
mg2di_timer_start(&ctx->timer_solve);
- switch (ctx->solver_type) {
- case EGS_SOLVER_RELAXATION: ret = solve_relax_step(ctx, export_res); break;
- case EGS_SOLVER_EXACT: ret = solve_exact(ctx); break;
- default: ret = -EINVAL;
+ switch (solve_type) {
+ case EGS_SOLVE_RELAXATION: ret = solve_relax_step(ctx, export_res); break;
+ case EGS_SOLVE_EXACT: ret = solve_exact(ctx); break;
+ default: ret = -EINVAL;
}
mg2di_timer_stop(&ctx->timer_solve);
@@ -714,6 +796,7 @@ static int init_diff_coeffs_task(void *arg, unsigned int job_idx, unsigned int t
int mg2di_egs_init(EGSContext *ctx, int flags)
{
EGSInternal *priv = ctx->priv;
+ EGSRelaxContext *r = ctx->relax;
int ret = 0;
mg2di_timer_start(&ctx->timer_solve);
@@ -725,23 +808,20 @@ int mg2di_egs_init(EGSContext *ctx, int flags)
goto finish;
}
- if (ctx->solver_type == EGS_SOLVER_RELAXATION) {
- EGSRelaxContext *r = ctx->solver_data;
- if (r->relax_factor == 0.0)
- priv->r.relax_factor = relax_factors[ctx->fd_stencil - 1];
- else
- priv->r.relax_factor = r->relax_factor;
- priv->r.relax_factor *= ctx->step[0] * ctx->step[0];
+ if (r->relax_factor == 0.0)
+ priv->r.relax_factor = relax_factors[ctx->fd_stencil - 1];
+ else
+ priv->r.relax_factor = r->relax_factor;
+ priv->r.relax_factor *= ctx->step[0] * ctx->step[0];
- if (r->relax_multiplier > 0.0)
- priv->r.relax_factor *= r->relax_multiplier;
+ if (r->relax_multiplier > 0.0)
+ priv->r.relax_factor *= r->relax_multiplier;
- priv->r.line_add = line_madd_c;
+ priv->r.line_add = line_madd_c;
#if HAVE_EXTERNAL_ASM
- if (ctx->cpuflags & MG2DI_CPU_FLAG_FMA3)
- priv->r.line_add = mg2di_line_madd_fma3;
+ if (ctx->cpuflags & MG2DI_CPU_FLAG_FMA3)
+ priv->r.line_add = mg2di_line_madd_fma3;
#endif
- }
if (!ctx->tp) {
ret = tp_init(&priv->tp_internal, 1);
@@ -773,36 +853,17 @@ int mg2di_egs_init(EGSContext *ctx, int flags)
tp_execute(ctx->tp, ctx->domain_size[0], init_diff_coeffs_task, &arg);
}
- if (ctx->solver_type == EGS_SOLVER_EXACT) {
- EGSExactInternal *e = &priv->e;
- unsigned int nb_threads = tp_get_nb_threads(ctx->tp);
-
- if (!(flags & EGS_INIT_FLAG_SAME_DIFF_COEFFS))
- e->mat_filled = 0;
-
- switch (ctx->fd_stencil) {
- case 1: e->fill_mat = fill_mat_s1; break;
- case 2: e->fill_mat = fill_mat_s2; break;
- default:
- mg2di_log(&ctx->logger, 0, "Invalid finite difference stencil: %zd\n",
- ctx->fd_stencil);
- ret = -EINVAL;
- goto finish;
- }
-
- if (e->scratch_lines_allocated < nb_threads) {
- free(e->scratch_lines);
- e->scratch_lines = NULL;
- e->scratch_lines_allocated = 0;
+ if (!(flags & EGS_INIT_FLAG_SAME_DIFF_COEFFS))
+ priv->e.mat_filled = 0;
- e->scratch_lines = calloc(e->N_ghosts * nb_threads,
- sizeof(*e->scratch_lines));
- if (!e->scratch_lines) {
- ret = -ENOMEM;
- goto finish;
- }
- e->scratch_lines_allocated = nb_threads;
- }
+ switch (ctx->fd_stencil) {
+ case 1: priv->e.fill_mat = fill_mat_s1; break;
+ case 2: priv->e.fill_mat = fill_mat_s2; break;
+ default:
+ mg2di_log(&ctx->logger, 0, "Invalid finite difference stencil: %zd\n",
+ ctx->fd_stencil);
+ ret = -EINVAL;
+ goto finish;
}
priv->residual_calc_size[0] = ctx->domain_size[0];
@@ -844,15 +905,12 @@ finish:
if (ret >= 0) {
memset(priv->reflect_disable, 0, sizeof(priv->reflect_disable));
- boundaries_apply(ctx, ctx->u, 1);
- if (ctx->solver_type == EGS_SOLVER_RELAXATION) {
- boundaries_apply(ctx, priv->u_next, 1);
+ boundaries_apply(ctx, ctx->u, 1);
+ boundaries_apply(ctx, priv->u_next, 1);
- residual_calc(ctx, 0);
- boundaries_apply(ctx, priv->u_next, 0);
- priv->u_next_valid = 1;
- } else
- residual_calc(ctx, 1);
+ residual_calc(ctx, 0);
+ boundaries_apply(ctx, priv->u_next, 0);
+ priv->u_next_valid = 1;
}
mg2di_timer_stop(&ctx->timer_solve);
@@ -947,38 +1005,10 @@ static int arrays_alloc(EGSContext *ctx, const size_t domain_size[2])
return -ENOMEM;
}
- if (ctx->solver_type == EGS_SOLVER_EXACT) {
- EGSExactInternal *e = &priv->e;
-
- e->N = domain_size[0] * domain_size[1];
- e->N_ghosts = (domain_size[0] + 2 * FD_STENCIL_MAX) * (domain_size[1] + 2 * FD_STENCIL_MAX);
-
- e->mat = calloc(SQR(e->N), sizeof(*e->mat));
- e->mat_f = calloc(SQR(e->N), sizeof(*e->mat_f));
- e->rhs = calloc(e->N, sizeof(*e->rhs));
- e->rhs_mat = calloc(e->N, sizeof(*e->rhs_mat));
- e->rhs_factor = calloc(e->N, sizeof(*e->rhs_factor));
- e->x = calloc(e->N, sizeof(*e->x));
- e->ipiv = calloc(e->N, sizeof(*e->ipiv));
- if (!e->mat || !e->mat_f || !e->rhs || !e->rhs_mat || !e->rhs_factor || !e->x || !e->ipiv)
- return -ENOMEM;
-
- ret = mg2di_bicgstab_context_alloc(&e->bicgstab, e->N, 64);
- if (ret < 0)
- return ret;
-
- for (int i = 0; i < e->N; i++)
- e->mat[i * e->N + i] = 1.0;
-
- ret = mg2di_bicgstab_init(e->bicgstab, e->mat, e->x);
- if (ret < 0)
- return ret;
- }
-
return 0;
}
-EGSContext *mg2di_egs_alloc(enum EGSType type, size_t domain_size[2])
+EGSContext *mg2di_egs_alloc(size_t domain_size[2])
{
EGSContext *ctx;
int ret;
@@ -991,32 +1021,19 @@ EGSContext *mg2di_egs_alloc(enum EGSType type, size_t domain_size[2])
if (!ctx->priv)
goto fail;
- switch (type) {
- case EGS_SOLVER_RELAXATION: {
- EGSRelaxContext *r;
- r = calloc(1, sizeof(EGSRelaxContext));
- if (!r)
- goto fail;
- ctx->solver_data = r;
- mg2di_timer_init(&r->timer_correct);
- break;
- }
- case EGS_SOLVER_EXACT: {
- EGSExactContext *e;
- e = calloc(1, sizeof(EGSExactContext));
- if (!e)
- goto fail;
- ctx->solver_data = e;
-
- mg2di_timer_init(&e->timer_mat_construct);
- mg2di_timer_init(&e->timer_bicgstab);
- mg2di_timer_init(&e->timer_lu_solve);
- mg2di_timer_init(&e->timer_export);
- break;
- }
- default: goto fail;
- }
- ctx->solver_type = type;
+ ctx->relax = calloc(1, sizeof(*ctx->relax));
+ if (!ctx->relax)
+ goto fail;
+ mg2di_timer_init(&ctx->relax->timer_correct);
+
+ ctx->exact = calloc(1, sizeof(*ctx->exact));
+ if (!ctx->exact)
+ goto fail;
+
+ mg2di_timer_init(&ctx->exact->timer_mat_construct);
+ mg2di_timer_init(&ctx->exact->timer_bicgstab);
+ mg2di_timer_init(&ctx->exact->timer_lu_solve);
+ mg2di_timer_init(&ctx->exact->timer_export);
if (!domain_size[0] || !domain_size[1] ||
domain_size[0] > SIZE_MAX / domain_size[1])
@@ -1057,22 +1074,10 @@ void mg2di_egs_free(EGSContext **pctx)
mg2di_residual_calc_free(&ctx->priv->rescalc);
- free(ctx->solver_data);
-
- if (ctx->solver_type == EGS_SOLVER_EXACT) {
- EGSExactInternal *e = &ctx->priv->e;
+ free(ctx->relax);
+ free(ctx->exact);
- free(e->scratch_lines);
- free(e->mat);
- free(e->mat_f);
- free(e->rhs);
- free(e->rhs_mat);
- free(e->rhs_factor);
- free(e->x);
- free(e->ipiv);
-
- mg2di_bicgstab_context_free(&e->bicgstab);
- }
+ exact_arrays_free(ctx);
mg2di_ndarray_free(&ctx->u);
mg2di_ndarray_free(&ctx->priv->u_next);
diff --git a/ell_grid_solve.h b/ell_grid_solve.h
index 6e59bf0..601eac9 100644
--- a/ell_grid_solve.h
+++ b/ell_grid_solve.h
@@ -53,14 +53,14 @@ enum EGSType {
* solver_data is EGSRelaxContext
* mg2di_egs_solve() does a single relaxation step
*/
- EGS_SOLVER_RELAXATION,
+ EGS_SOLVE_RELAXATION,
/**
* Solve the equation exactly by contructing a linear system and solving it with LAPACK.
*
* solver_data is EGSExactContext
* mg2di_egs_solve() solves the discretized system exactly (up to roundoff error)
*/
- EGS_SOLVER_EXACT,
+ EGS_SOLVE_EXACT,
};
typedef struct EGSInternal EGSInternal;
@@ -97,15 +97,10 @@ typedef struct EGSExactContext {
} EGSExactContext;
typedef struct EGSContext {
- enum EGSType solver_type;
+ EGSRelaxContext *relax;
+ EGSExactContext *exact;
/**
- * Solver type-specific data.
- *
- * Should be cast into the struct specified in documentation for this type.
- */
- void *solver_data;
- /**
* Solver private data, not to be accessed in any way by the caller.
*/
EGSInternal *priv;
@@ -210,7 +205,7 @@ typedef struct EGSContext {
*
* @return The solver context on success, NULL on failure.
*/
-EGSContext *mg2di_egs_alloc(enum EGSType solver_type, size_t domain_size[2]);
+EGSContext *mg2di_egs_alloc(size_t domain_size[2]);
/**
* Initialize the solver for use, after all the required fields are filled by
* the caller.
@@ -233,6 +228,6 @@ void mg2di_egs_free(EGSContext **ctx);
*
* @return 0 on success, a negative error code on failure.
*/
-int mg2di_egs_solve(EGSContext *ctx, int export_res);
+int mg2di_egs_solve(EGSContext *ctx, enum EGSType solve_type, int export_res);
#endif /* MG2D_ELL_GRID_SOLVE_H */
diff --git a/mg2d.c b/mg2d.c
index ed56eb2..5f2ce22 100644
--- a/mg2d.c
+++ b/mg2d.c
@@ -85,8 +85,8 @@ static void log_callback(MG2DLogger *log, int level, const char *fmt, va_list vl
ctx->log_callback(ctx, level, fmt, vl);
}
-static void log_relax_step(MG2DContext *ctx, MG2DLevel *level, const char *step_desc,
- double res_old, double res_new)
+static void log_egs_step(MG2DContext *ctx, MG2DLevel *level, const char *step_desc,
+ double res_old, double res_new)
{
char prefix[32] = { 0 };
@@ -109,8 +109,8 @@ static int coarse_correct_task(void *arg, unsigned int job_idx, unsigned int thr
return 0;
}
-static int mg_relax_step(MG2DContext *ctx, MG2DLevel *level, const char *step_desc,
- int export_res)
+static int mg_solve(MG2DContext *ctx, MG2DLevel *level, enum EGSType solve_type,
+ const char *step_desc, int export_res)
{
double res_old;
int ret;
@@ -118,21 +118,23 @@ static int mg_relax_step(MG2DContext *ctx, MG2DLevel *level, const char *step_de
res_old = level->solver->residual_max;
mg2di_timer_start(&level->timer_solve);
- ret = mg2di_egs_solve(level->solver, export_res);
+ ret = mg2di_egs_solve(level->solver, solve_type, export_res);
mg2di_timer_stop(&level->timer_solve);
if (ret < 0)
return ret;
- log_relax_step(ctx, level, step_desc, res_old, level->solver->residual_max);
+ log_egs_step(ctx, level, step_desc, res_old, level->solver->residual_max);
return 0;
}
-static int mg_solve_subgrid(MG2DContext *ctx, MG2DLevel *level)
+static int mg_solve_subgrid(MG2DContext *ctx, MG2DLevel *level, int allow_exact)
{
+ enum EGSType solve_type = allow_exact && level->solver->domain_size[0] <= ctx->max_exact_size ?
+ EGS_SOLVE_EXACT : EGS_SOLVE_RELAXATION;
double res_old, res_new;
- int ret;
+ int ret = 0;
/* on the refined levels, use zero as the initial guess for the
* solution (correction for the upper level) */
@@ -152,9 +154,9 @@ static int mg_solve_subgrid(MG2DContext *ctx, MG2DLevel *level)
res_old = level->solver->residual_max;
- /* handle coarsest grid */
- if (!level->child) {
- ret = mg_relax_step(ctx, level, "coarse-step", 1);
+ /* handle exact solve */
+ if (solve_type == EGS_SOLVE_EXACT) {
+ ret = mg_solve(ctx, level, solve_type, "coarse-step", 0);
if (ret < 0)
return ret;
level->count_cycles++;
@@ -167,52 +169,55 @@ static int mg_solve_subgrid(MG2DContext *ctx, MG2DLevel *level)
/* pre-restrict relaxation */
for (int j = 0; j < ctx->nb_relax_pre; j++) {
- ret = mg_relax_step(ctx, level, "pre-relax", j == ctx->nb_relax_pre - 1);
+ ret = mg_solve(ctx, level, solve_type,
+ "pre-relax", j == ctx->nb_relax_pre - 1 && level->child);
if (ret < 0)
return ret;
}
- /* restrict the residual as to the coarser-level rhs */
- mg2di_timer_start(&level->timer_restrict);
- ret = mg2di_gt_transfer(level->transfer_restrict, level->child->solver->rhs,
- level->solver->residual);
- mg2di_timer_stop(&level->timer_restrict);
- if (ret < 0)
- return ret;
+ if (level->child) {
+ /* restrict the residual as to the coarser-level rhs */
+ mg2di_timer_start(&level->timer_restrict);
+ ret = mg2di_gt_transfer(level->transfer_restrict, level->child->solver->rhs,
+ level->solver->residual);
+ mg2di_timer_stop(&level->timer_restrict);
+ if (ret < 0)
+ return ret;
- /* solve on the coarser level */
- ret = mg_solve_subgrid(ctx, level->child);
- if (ret < 0)
- return ret;
+ /* solve on the coarser level */
+ ret = mg_solve_subgrid(ctx, level->child, 1);
+ if (ret < 0)
+ return ret;
- /* prolongate the coarser-level correction */
- mg2di_timer_start(&level->timer_prolong);
- ret = mg2di_gt_transfer(level->transfer_prolong, level->prolong_tmp,
- level->child->solver->u);
- mg2di_timer_stop(&level->timer_prolong);
- if (ret < 0)
- return ret;
+ /* prolongate the coarser-level correction */
+ mg2di_timer_start(&level->timer_prolong);
+ ret = mg2di_gt_transfer(level->transfer_prolong, level->prolong_tmp,
+ level->child->solver->u);
+ mg2di_timer_stop(&level->timer_prolong);
+ if (ret < 0)
+ return ret;
- /* apply the correction */
- mg2di_timer_start(&level->timer_correct);
- tp_execute(ctx->priv->tp, level->solver->domain_size[1], coarse_correct_task, level);
- mg2di_timer_stop(&level->timer_correct);
+ /* apply the correction */
+ mg2di_timer_start(&level->timer_correct);
+ tp_execute(ctx->priv->tp, level->solver->domain_size[1], coarse_correct_task, level);
+ mg2di_timer_stop(&level->timer_correct);
- /* re-init the current-level solver (re-calc the residual) */
- res_prev = level->solver->residual_max;
- mg2di_timer_start(&level->timer_reinit);
- ret = mg2di_egs_init(level->solver, level->egs_init_flags);
- mg2di_timer_stop(&level->timer_reinit);
- if (ret < 0)
- return ret;
+ /* re-init the current-level solver (re-calc the residual) */
+ res_prev = level->solver->residual_max;
+ mg2di_timer_start(&level->timer_reinit);
+ ret = mg2di_egs_init(level->solver, level->egs_init_flags);
+ mg2di_timer_stop(&level->timer_reinit);
+ if (ret < 0)
+ return ret;
- level->egs_init_flags |= EGS_INIT_FLAG_SAME_DIFF_COEFFS;
+ level->egs_init_flags |= EGS_INIT_FLAG_SAME_DIFF_COEFFS;
- log_relax_step(ctx, level, "correct", res_prev, level->solver->residual_max);
+ log_egs_step(ctx, level, "correct", res_prev, level->solver->residual_max);
+ }
/* post-correct relaxation */
for (int j = 0; j < ctx->nb_relax_post; j++) {
- ret = mg_relax_step(ctx, level, "post-relax", 0);
+ ret = mg_solve(ctx, level, solve_type, "post-relax", 0);
if (ret < 0)
return ret;
}
@@ -375,13 +380,10 @@ static int mg_levels_init(MG2DContext *ctx)
cur->solver->tp = priv->tp;
cur->solver->fd_stencil = ctx->fd_stencil;
- if (cur->solver->solver_type == EGS_SOLVER_RELAXATION) {
- EGSRelaxContext *r = cur->solver->solver_data;
- r->relax_factor = ctx->cfl_factor;
- r->relax_multiplier = 1.0 / (diff2_max + cur->solver->step[0] * cur->solver->step[1] *
- diff0_max / 8.0);
- }
+ cur->solver->relax->relax_factor = ctx->cfl_factor;
+ cur->solver->relax->relax_multiplier = 1.0 / (diff2_max + cur->solver->step[0] * cur->solver->step[1] *
+ diff0_max / 8.0);
prev = cur;
cur = cur->child;
@@ -521,7 +523,7 @@ int mg2d_solve(MG2DContext *ctx)
res_prev = res_orig;
for (int i = 0; i < ctx->maxiter; i++) {
- ret = mg_solve_subgrid(ctx, root);
+ ret = mg_solve_subgrid(ctx, root, 1);
if (ret < 0)
goto fail;
@@ -575,7 +577,7 @@ static void mg_level_free(MG2DLevel **plevel)
*plevel = NULL;
}
-static MG2DLevel *mg_level_alloc(enum EGSType type, const size_t domain_size)
+static MG2DLevel *mg_level_alloc(const size_t domain_size)
{
MG2DLevel *level;
int ret;
@@ -592,8 +594,7 @@ static MG2DLevel *mg_level_alloc(enum EGSType type, const size_t domain_size)
if (ret < 0)
goto fail;
- level->solver = mg2di_egs_alloc(type,
- (size_t [2]){domain_size, domain_size});
+ level->solver = mg2di_egs_alloc((size_t [2]){domain_size, domain_size});
if (!level->solver)
goto fail;
@@ -630,7 +631,6 @@ static int mg_levels_alloc(MG2DContext *ctx, size_t domain_size)
next_size = domain_size;
for (int depth = 0; depth < ctx->max_levels; depth++) {
- enum EGSType cur_type;
size_t cur_size = next_size;
// choose the domain size for the next child
@@ -643,15 +643,11 @@ static int mg_levels_alloc(MG2DContext *ctx, size_t domain_size)
} else
next_size = (cur_size >> 1) + 1;
- cur_type = (cur_size <= ctx->max_exact_size) ? EGS_SOLVER_EXACT : EGS_SOLVER_RELAXATION;
-
- *dst = mg_level_alloc(cur_type, cur_size);
+ *dst = mg_level_alloc(cur_size);
if (!*dst)
return -ENOMEM;
(*dst)->depth = depth;
- if (cur_type == EGS_SOLVER_EXACT)
- break;
if (next_size <= 4)
break;
@@ -794,16 +790,16 @@ void mg2d_print_stats(MG2DContext *ctx, const char *prefix)
char buf[1024], *p;
int ret;
- EGSRelaxContext *r = NULL;
- EGSExactContext *e = NULL;
+ EGSRelaxContext *r = level->solver->relax;
+ EGSExactContext *e = level->solver->exact;
int64_t level_total = level->timer_solve.time_nsec + level->timer_prolong.time_nsec + level->timer_restrict.time_nsec +
level->timer_correct.time_nsec + level->timer_reinit.time_nsec;
- if (level->solver->solver_type == EGS_SOLVER_RELAXATION)
- r = level->solver->solver_data;
- else if (level->solver->solver_type == EGS_SOLVER_EXACT)
- e = level->solver->solver_data;
+ if (!level->count_cycles) {
+ level = level->child;
+ continue;
+ }
levels_total += level_total;
@@ -845,13 +841,13 @@ void mg2d_print_stats(MG2DContext *ctx, const char *prefix)
if (ret > 0)
p += ret;
- if (r) {
+ if (r->timer_correct.nb_runs) {
ret = snprintf(p, sizeof(buf) - (p - buf),
" %2.2f%% correct",
r->timer_correct.time_nsec * 100.0 / level->solver->timer_solve.time_nsec);
if (ret > 0)
p += ret;
- } else if (e) {
+ } else if (e->timer_mat_construct.nb_runs) {
ret = snprintf(p, sizeof(buf) - (p - buf),
" %2.2f%% const %2.2f%% bicgstab (%ld; %g it/slv) %2.2f%% lu (%ld) %2.2f%% export",
e->timer_mat_construct.time_nsec * 100.0 / level->solver->timer_solve.time_nsec,
diff --git a/relax_test.c b/relax_test.c
index 50ad04c..599fb4d 100644
--- a/relax_test.c
+++ b/relax_test.c
@@ -75,7 +75,7 @@ int main(int argc, char **argv)
N = (1L << log2N) + 1;
maxiter = 1L << log2maxiter;
- ctx = mg2di_egs_alloc(EGS_SOLVER_RELAXATION, (size_t [2]){N, N});
+ ctx = mg2di_egs_alloc((size_t [2]){N, N});
if (!ctx) {
fprintf(stderr, "Error allocating the solver context\n");
return 1;
@@ -144,7 +144,7 @@ int main(int argc, char **argv)
res_old = findmax(ctx->residual->data, ctx->residual->stride[0] * ctx->domain_size[1]);
for (int i = 0; i < maxiter; i++) {
- ret = mg2di_egs_solve(ctx, 0);
+ ret = mg2di_egs_solve(ctx, EGS_SOLVE_RELAXATION, 0);
if (ret < 0) {
fprintf(stderr, "Error during relaxation\n");
ret = 1;