aboutsummaryrefslogtreecommitdiff
path: root/ell_grid_solve.c
diff options
context:
space:
mode:
Diffstat (limited to 'ell_grid_solve.c')
-rw-r--r--ell_grid_solve.c267
1 files changed, 136 insertions, 131 deletions
diff --git a/ell_grid_solve.c b/ell_grid_solve.c
index cefb231..ac5aa28 100644
--- a/ell_grid_solve.c
+++ b/ell_grid_solve.c
@@ -55,6 +55,7 @@ typedef struct EGSExactInternal {
size_t N;
size_t N_ghosts;
+ int arrays_alloced;
double *mat;
double *mat_f;
double *rhs;
@@ -93,10 +94,8 @@ struct EGSInternal {
ResidualCalcContext *rescalc;
- union {
- EGSRelaxInternal r;
- EGSExactInternal e;
- };
+ EGSRelaxInternal r;
+ EGSExactInternal e;
TPContext *tp_internal;
};
@@ -327,7 +326,7 @@ static int residual_add_task(void *arg, unsigned int job_idx, unsigned int threa
static int solve_relax_step(EGSContext *ctx, int export_res)
{
EGSInternal *priv = ctx->priv;
- EGSRelaxContext *r = ctx->solver_data;
+ EGSRelaxContext *r = ctx->relax;
int u_next_valid = priv->u_next_valid;
@@ -361,6 +360,69 @@ static int solve_relax_step(EGSContext *ctx, int export_res)
return 0;
}
+static void exact_arrays_free(EGSContext *ctx)
+{
+ EGSInternal *priv = ctx->priv;
+ EGSExactInternal *e = &priv->e;
+
+ free(ctx->priv->e.scratch_lines);
+ free(ctx->priv->e.mat);
+ free(ctx->priv->e.mat_f);
+ free(ctx->priv->e.rhs);
+ free(ctx->priv->e.rhs_mat);
+ free(ctx->priv->e.rhs_factor);
+ free(ctx->priv->e.x);
+ free(ctx->priv->e.ipiv);
+
+ mg2di_bicgstab_context_free(&e->bicgstab);
+
+ e->arrays_alloced = 0;
+}
+
+static int exact_arrays_alloc(EGSContext *ctx)
+{
+ EGSInternal *priv = ctx->priv;
+ EGSExactInternal *e = &priv->e;
+ int ret;
+
+ if (e->arrays_alloced)
+ return 0;
+
+ e->N = ctx->domain_size[0] * ctx->domain_size[1];
+ e->N_ghosts = (ctx->domain_size[0] + 2 * FD_STENCIL_MAX) *
+ (ctx->domain_size[1] + 2 * FD_STENCIL_MAX);
+
+ e->mat = calloc(SQR(e->N), sizeof(*e->mat));
+ e->mat_f = calloc(SQR(e->N), sizeof(*e->mat_f));
+ e->rhs = calloc(e->N, sizeof(*e->rhs));
+ e->rhs_mat = calloc(e->N, sizeof(*e->rhs_mat));
+ e->rhs_factor = calloc(e->N, sizeof(*e->rhs_factor));
+ e->x = calloc(e->N, sizeof(*e->x));
+ e->ipiv = calloc(e->N, sizeof(*e->ipiv));
+ if (!e->mat || !e->mat_f || !e->rhs || !e->rhs_mat || !e->rhs_factor || !e->x || !e->ipiv) {
+ ret = -ENOMEM;
+ goto fail;
+ }
+
+ ret = mg2di_bicgstab_context_alloc(&e->bicgstab, e->N, 64);
+ if (ret < 0)
+ goto fail;
+
+ for (int i = 0; i < e->N; i++)
+ e->mat[i * e->N + i] = 1.0;
+
+ ret = mg2di_bicgstab_init(e->bicgstab, e->mat, e->x);
+ if (ret < 0)
+ goto fail;
+
+ e->arrays_alloced = 1;
+
+ return 0;
+fail:
+ exact_arrays_free(ctx);
+ return ret;
+}
+
static void fill_mat_s1(double *mat_row, ptrdiff_t mat_stride, ptrdiff_t row_stride,
NDArray **diff_coeffs, ptrdiff_t idx_src)
{
@@ -610,10 +672,29 @@ static int mat_fill_row_task(void *arg, unsigned int job_idx, unsigned int threa
static int solve_exact(EGSContext *ctx)
{
EGSInternal *priv = ctx->priv;
- EGSExactContext *ec = ctx->solver_data;
+ EGSExactContext *ec = ctx->exact;
EGSExactInternal *e = &priv->e;
+ unsigned int nb_threads = tp_get_nb_threads(ctx->tp);
int ret;
+ if (!e->arrays_alloced) {
+ ret = exact_arrays_alloc(ctx);
+ if (ret < 0)
+ return ret;
+ }
+
+ if (e->scratch_lines_allocated < nb_threads) {
+ free(e->scratch_lines);
+ e->scratch_lines = NULL;
+ e->scratch_lines_allocated = 0;
+
+ e->scratch_lines = calloc(e->N_ghosts * nb_threads,
+ sizeof(*e->scratch_lines));
+ if (!e->scratch_lines)
+ return -ENOMEM;
+ e->scratch_lines_allocated = nb_threads;
+ }
+
mg2di_timer_start(&ec->timer_mat_construct);
if (!e->mat_filled) {
@@ -673,22 +754,23 @@ static int solve_exact(EGSContext *ctx)
mg2di_timer_stop(&ec->timer_export);
+ memset(priv->reflect_disable, 0, sizeof(priv->reflect_disable));
boundaries_apply(ctx, ctx->u, 0);
residual_calc(ctx, 1);
return 0;
}
-int mg2di_egs_solve(EGSContext *ctx, int export_res)
+int mg2di_egs_solve(EGSContext *ctx, enum EGSType solve_type, int export_res)
{
int ret;
mg2di_timer_start(&ctx->timer_solve);
- switch (ctx->solver_type) {
- case EGS_SOLVER_RELAXATION: ret = solve_relax_step(ctx, export_res); break;
- case EGS_SOLVER_EXACT: ret = solve_exact(ctx); break;
- default: ret = -EINVAL;
+ switch (solve_type) {
+ case EGS_SOLVE_RELAXATION: ret = solve_relax_step(ctx, export_res); break;
+ case EGS_SOLVE_EXACT: ret = solve_exact(ctx); break;
+ default: ret = -EINVAL;
}
mg2di_timer_stop(&ctx->timer_solve);
@@ -714,6 +796,7 @@ static int init_diff_coeffs_task(void *arg, unsigned int job_idx, unsigned int t
int mg2di_egs_init(EGSContext *ctx, int flags)
{
EGSInternal *priv = ctx->priv;
+ EGSRelaxContext *r = ctx->relax;
int ret = 0;
mg2di_timer_start(&ctx->timer_solve);
@@ -725,23 +808,20 @@ int mg2di_egs_init(EGSContext *ctx, int flags)
goto finish;
}
- if (ctx->solver_type == EGS_SOLVER_RELAXATION) {
- EGSRelaxContext *r = ctx->solver_data;
- if (r->relax_factor == 0.0)
- priv->r.relax_factor = relax_factors[ctx->fd_stencil - 1];
- else
- priv->r.relax_factor = r->relax_factor;
- priv->r.relax_factor *= ctx->step[0] * ctx->step[0];
+ if (r->relax_factor == 0.0)
+ priv->r.relax_factor = relax_factors[ctx->fd_stencil - 1];
+ else
+ priv->r.relax_factor = r->relax_factor;
+ priv->r.relax_factor *= ctx->step[0] * ctx->step[0];
- if (r->relax_multiplier > 0.0)
- priv->r.relax_factor *= r->relax_multiplier;
+ if (r->relax_multiplier > 0.0)
+ priv->r.relax_factor *= r->relax_multiplier;
- priv->r.line_add = line_madd_c;
+ priv->r.line_add = line_madd_c;
#if HAVE_EXTERNAL_ASM
- if (ctx->cpuflags & MG2DI_CPU_FLAG_FMA3)
- priv->r.line_add = mg2di_line_madd_fma3;
+ if (ctx->cpuflags & MG2DI_CPU_FLAG_FMA3)
+ priv->r.line_add = mg2di_line_madd_fma3;
#endif
- }
if (!ctx->tp) {
ret = tp_init(&priv->tp_internal, 1);
@@ -773,36 +853,17 @@ int mg2di_egs_init(EGSContext *ctx, int flags)
tp_execute(ctx->tp, ctx->domain_size[0], init_diff_coeffs_task, &arg);
}
- if (ctx->solver_type == EGS_SOLVER_EXACT) {
- EGSExactInternal *e = &priv->e;
- unsigned int nb_threads = tp_get_nb_threads(ctx->tp);
-
- if (!(flags & EGS_INIT_FLAG_SAME_DIFF_COEFFS))
- e->mat_filled = 0;
-
- switch (ctx->fd_stencil) {
- case 1: e->fill_mat = fill_mat_s1; break;
- case 2: e->fill_mat = fill_mat_s2; break;
- default:
- mg2di_log(&ctx->logger, 0, "Invalid finite difference stencil: %zd\n",
- ctx->fd_stencil);
- ret = -EINVAL;
- goto finish;
- }
-
- if (e->scratch_lines_allocated < nb_threads) {
- free(e->scratch_lines);
- e->scratch_lines = NULL;
- e->scratch_lines_allocated = 0;
+ if (!(flags & EGS_INIT_FLAG_SAME_DIFF_COEFFS))
+ priv->e.mat_filled = 0;
- e->scratch_lines = calloc(e->N_ghosts * nb_threads,
- sizeof(*e->scratch_lines));
- if (!e->scratch_lines) {
- ret = -ENOMEM;
- goto finish;
- }
- e->scratch_lines_allocated = nb_threads;
- }
+ switch (ctx->fd_stencil) {
+ case 1: priv->e.fill_mat = fill_mat_s1; break;
+ case 2: priv->e.fill_mat = fill_mat_s2; break;
+ default:
+ mg2di_log(&ctx->logger, 0, "Invalid finite difference stencil: %zd\n",
+ ctx->fd_stencil);
+ ret = -EINVAL;
+ goto finish;
}
priv->residual_calc_size[0] = ctx->domain_size[0];
@@ -844,15 +905,12 @@ finish:
if (ret >= 0) {
memset(priv->reflect_disable, 0, sizeof(priv->reflect_disable));
- boundaries_apply(ctx, ctx->u, 1);
- if (ctx->solver_type == EGS_SOLVER_RELAXATION) {
- boundaries_apply(ctx, priv->u_next, 1);
+ boundaries_apply(ctx, ctx->u, 1);
+ boundaries_apply(ctx, priv->u_next, 1);
- residual_calc(ctx, 0);
- boundaries_apply(ctx, priv->u_next, 0);
- priv->u_next_valid = 1;
- } else
- residual_calc(ctx, 1);
+ residual_calc(ctx, 0);
+ boundaries_apply(ctx, priv->u_next, 0);
+ priv->u_next_valid = 1;
}
mg2di_timer_stop(&ctx->timer_solve);
@@ -947,38 +1005,10 @@ static int arrays_alloc(EGSContext *ctx, const size_t domain_size[2])
return -ENOMEM;
}
- if (ctx->solver_type == EGS_SOLVER_EXACT) {
- EGSExactInternal *e = &priv->e;
-
- e->N = domain_size[0] * domain_size[1];
- e->N_ghosts = (domain_size[0] + 2 * FD_STENCIL_MAX) * (domain_size[1] + 2 * FD_STENCIL_MAX);
-
- e->mat = calloc(SQR(e->N), sizeof(*e->mat));
- e->mat_f = calloc(SQR(e->N), sizeof(*e->mat_f));
- e->rhs = calloc(e->N, sizeof(*e->rhs));
- e->rhs_mat = calloc(e->N, sizeof(*e->rhs_mat));
- e->rhs_factor = calloc(e->N, sizeof(*e->rhs_factor));
- e->x = calloc(e->N, sizeof(*e->x));
- e->ipiv = calloc(e->N, sizeof(*e->ipiv));
- if (!e->mat || !e->mat_f || !e->rhs || !e->rhs_mat || !e->rhs_factor || !e->x || !e->ipiv)
- return -ENOMEM;
-
- ret = mg2di_bicgstab_context_alloc(&e->bicgstab, e->N, 64);
- if (ret < 0)
- return ret;
-
- for (int i = 0; i < e->N; i++)
- e->mat[i * e->N + i] = 1.0;
-
- ret = mg2di_bicgstab_init(e->bicgstab, e->mat, e->x);
- if (ret < 0)
- return ret;
- }
-
return 0;
}
-EGSContext *mg2di_egs_alloc(enum EGSType type, size_t domain_size[2])
+EGSContext *mg2di_egs_alloc(size_t domain_size[2])
{
EGSContext *ctx;
int ret;
@@ -991,32 +1021,19 @@ EGSContext *mg2di_egs_alloc(enum EGSType type, size_t domain_size[2])
if (!ctx->priv)
goto fail;
- switch (type) {
- case EGS_SOLVER_RELAXATION: {
- EGSRelaxContext *r;
- r = calloc(1, sizeof(EGSRelaxContext));
- if (!r)
- goto fail;
- ctx->solver_data = r;
- mg2di_timer_init(&r->timer_correct);
- break;
- }
- case EGS_SOLVER_EXACT: {
- EGSExactContext *e;
- e = calloc(1, sizeof(EGSExactContext));
- if (!e)
- goto fail;
- ctx->solver_data = e;
-
- mg2di_timer_init(&e->timer_mat_construct);
- mg2di_timer_init(&e->timer_bicgstab);
- mg2di_timer_init(&e->timer_lu_solve);
- mg2di_timer_init(&e->timer_export);
- break;
- }
- default: goto fail;
- }
- ctx->solver_type = type;
+ ctx->relax = calloc(1, sizeof(*ctx->relax));
+ if (!ctx->relax)
+ goto fail;
+ mg2di_timer_init(&ctx->relax->timer_correct);
+
+ ctx->exact = calloc(1, sizeof(*ctx->exact));
+ if (!ctx->exact)
+ goto fail;
+
+ mg2di_timer_init(&ctx->exact->timer_mat_construct);
+ mg2di_timer_init(&ctx->exact->timer_bicgstab);
+ mg2di_timer_init(&ctx->exact->timer_lu_solve);
+ mg2di_timer_init(&ctx->exact->timer_export);
if (!domain_size[0] || !domain_size[1] ||
domain_size[0] > SIZE_MAX / domain_size[1])
@@ -1057,22 +1074,10 @@ void mg2di_egs_free(EGSContext **pctx)
mg2di_residual_calc_free(&ctx->priv->rescalc);
- free(ctx->solver_data);
-
- if (ctx->solver_type == EGS_SOLVER_EXACT) {
- EGSExactInternal *e = &ctx->priv->e;
+ free(ctx->relax);
+ free(ctx->exact);
- free(e->scratch_lines);
- free(e->mat);
- free(e->mat_f);
- free(e->rhs);
- free(e->rhs_mat);
- free(e->rhs_factor);
- free(e->x);
- free(e->ipiv);
-
- mg2di_bicgstab_context_free(&e->bicgstab);
- }
+ exact_arrays_free(ctx);
mg2di_ndarray_free(&ctx->u);
mg2di_ndarray_free(&ctx->priv->u_next);