aboutsummaryrefslogtreecommitdiff
path: root/ell_grid_solve.c
diff options
context:
space:
mode:
Diffstat (limited to 'ell_grid_solve.c')
-rw-r--r--ell_grid_solve.c115
1 files changed, 88 insertions, 27 deletions
diff --git a/ell_grid_solve.c b/ell_grid_solve.c
index 4cb87ed..cefb231 100644
--- a/ell_grid_solve.c
+++ b/ell_grid_solve.c
@@ -76,6 +76,9 @@ struct EGSInternal {
NDArray *rhs_base;
NDArray *residual_base;
+ NDArray *u_next;
+ int u_next_valid;
+
/* all the diff coeffs concatenated */
NDArray *diff_coeffs_public_base;
@@ -86,6 +89,7 @@ struct EGSInternal {
size_t residual_calc_size[2];
int bnd_zero[4];
+ int reflect_disable[4];
ResidualCalcContext *rescalc;
@@ -116,7 +120,7 @@ static const double fd_denoms[][MG2D_DIFF_COEFF_NB] = {
},
};
-static void residual_calc(EGSContext *ctx)
+static void residual_calc(EGSContext *ctx, int export_res)
{
EGSInternal *priv = ctx->priv;
const double *diff_coeffs[MG2D_DIFF_COEFF_NB];
@@ -127,15 +131,26 @@ static void residual_calc(EGSContext *ctx)
diff_coeffs[i] = NDPTR2D(priv->diff_coeffs[i], priv->residual_calc_offset[1], priv->residual_calc_offset[0]);
mg2di_residual_calc(priv->rescalc, priv->residual_calc_size, &ctx->residual_max,
- NDPTR2D(ctx->residual, priv->residual_calc_offset[1], priv->residual_calc_offset[0]),
- ctx->residual->stride[0],
+ NDPTR2D(export_res ? ctx->residual : priv->u_next, priv->residual_calc_offset[1], priv->residual_calc_offset[0]),
+ export_res ? ctx->residual->stride[0] : priv->u_next->stride[0],
NDPTR2D(ctx->u, priv->residual_calc_offset[1], priv->residual_calc_offset[0]),
ctx->u->stride[0],
NDPTR2D(ctx->rhs, priv->residual_calc_offset[1], priv->residual_calc_offset[0]),
ctx->rhs->stride[0],
- diff_coeffs, priv->diff_coeffs[0]->stride[0]);
+ diff_coeffs, priv->diff_coeffs[0]->stride[0],
+ export_res ? 0.0 : 1.0, export_res ? 1.0 : priv->r.relax_factor,
+ export_res ? 0 :
+ ((ctx->boundaries[MG2D_BOUNDARY_0L]->type == MG2D_BC_TYPE_REFLECT) << 0) |
+ ((ctx->boundaries[MG2D_BOUNDARY_0U]->type == MG2D_BC_TYPE_REFLECT) << 1) |
+ ((ctx->boundaries[MG2D_BOUNDARY_1L]->type == MG2D_BC_TYPE_REFLECT) << 2));
mg2di_timer_stop(&ctx->timer_res_calc);
+
+ if (!export_res) {
+ priv->reflect_disable[MG2D_BOUNDARY_0L] = 1;
+ priv->reflect_disable[MG2D_BOUNDARY_0U] = 1;
+ priv->reflect_disable[MG2D_BOUNDARY_1L] = 1;
+ }
}
static void boundaries_apply_fixval(double *dst, const ptrdiff_t dst_stride[2],
@@ -195,11 +210,11 @@ static void boundaries_apply_falloff(double *dst, const ptrdiff_t dst_stride[2],
}
}
-static void boundaries_apply(EGSContext *ctx, int init)
+static void boundaries_apply(EGSContext *ctx, NDArray *a_dst, int init)
{
static const enum MG2DBCType bnd_type_order[] = { MG2D_BC_TYPE_FIXVAL, MG2D_BC_TYPE_REFLECT, MG2D_BC_TYPE_FALLOFF };
EGSInternal *priv = ctx->priv;
- const ptrdiff_t strides[2] = { 1, ctx->u->stride[0] };
+ const ptrdiff_t strides[2] = { 1, a_dst->stride[0] };
mg2di_timer_start(&ctx->timer_bnd);
@@ -213,7 +228,7 @@ static void boundaries_apply(EGSContext *ctx, int init)
const size_t size_boundary = ctx->domain_size[!ci];
const size_t size_offset = ctx->domain_size[ci];
- double *dst = ctx->u->data + mg2d_bnd_is_upper(i) * ((size_offset - 1) * stride_offset);
+ double *dst = a_dst->data + mg2d_bnd_is_upper(i) * ((size_offset - 1) * stride_offset);
const ptrdiff_t dst_strides[] = { stride_boundary, mg2d_bnd_out_dir(i) * stride_offset };
if (bnd->type != bnd_type_order[order_idx])
@@ -229,9 +244,11 @@ static void boundaries_apply(EGSContext *ctx, int init)
}
break;
case MG2D_BC_TYPE_REFLECT:
- mg2di_timer_start(&ctx->timer_bnd_reflect);
- boundaries_apply_reflect(dst, dst_strides, size_boundary);
- mg2di_timer_stop(&ctx->timer_bnd_reflect);
+ if (!priv->reflect_disable[i]) {
+ mg2di_timer_start(&ctx->timer_bnd_reflect);
+ boundaries_apply_reflect(dst, dst_strides, size_boundary);
+ mg2di_timer_stop(&ctx->timer_bnd_reflect);
+ }
break;
case MG2D_BC_TYPE_FALLOFF:
mg2di_timer_start(&ctx->timer_bnd_falloff);
@@ -255,8 +272,8 @@ static void boundaries_apply(EGSContext *ctx, int init)
const int dir_x = mg2d_bnd_out_dir(loc_x);
int fact_x = dir_x, fact_y = dir_y;
- double *dst = ctx->u->data
- + mg2d_bnd_is_upper(loc_y) * ((ctx->domain_size[1] - 1) * ctx->u->stride[0])
+ double *dst = a_dst->data
+ + mg2d_bnd_is_upper(loc_y) * ((ctx->domain_size[1] - 1) * a_dst->stride[0])
+ mg2d_bnd_is_upper(loc_x) * (ctx->domain_size[0] - 1);
if (bnd_y->type == MG2D_BC_TYPE_REFLECT)
@@ -307,17 +324,39 @@ static int residual_add_task(void *arg, unsigned int job_idx, unsigned int threa
return 0;
}
-static int solve_relax_step(EGSContext *ctx)
+static int solve_relax_step(EGSContext *ctx, int export_res)
{
+ EGSInternal *priv = ctx->priv;
EGSRelaxContext *r = ctx->solver_data;
+ int u_next_valid = priv->u_next_valid;
- mg2di_timer_start(&r->timer_correct);
- tp_execute(ctx->tp, ctx->domain_size[1], residual_add_task, ctx);
- mg2di_timer_stop(&r->timer_correct);
+ if (u_next_valid) {
+ NDArray *tmp = ctx->u;
+ ctx->u = priv->u_next;
+ priv->u_next = tmp;
+ priv->u_next_valid = 0;
+ }
+
+ if (export_res) {
+ if (u_next_valid)
+ residual_calc(ctx, 1);
+ else {
+ mg2di_timer_start(&r->timer_correct);
+ tp_execute(ctx->tp, ctx->domain_size[1], residual_add_task, ctx);
+ mg2di_timer_stop(&r->timer_correct);
- boundaries_apply(ctx, 0);
- residual_calc(ctx);
+ memset(priv->reflect_disable, 0, sizeof(priv->reflect_disable));
+
+ boundaries_apply(ctx, ctx->u, 0);
+ residual_calc(ctx, 1);
+ }
+ } else {
+ mg2di_assert(u_next_valid);
+ residual_calc(ctx, 0);
+ boundaries_apply(ctx, priv->u_next, 0);
+ priv->u_next_valid = 1;
+ }
return 0;
}
@@ -634,21 +673,21 @@ static int solve_exact(EGSContext *ctx)
mg2di_timer_stop(&ec->timer_export);
- boundaries_apply(ctx, 0);
- residual_calc(ctx);
+ boundaries_apply(ctx, ctx->u, 0);
+ residual_calc(ctx, 1);
return 0;
}
-int mg2di_egs_solve(EGSContext *ctx)
+int mg2di_egs_solve(EGSContext *ctx, int export_res)
{
int ret;
mg2di_timer_start(&ctx->timer_solve);
switch (ctx->solver_type) {
- case EGS_SOLVER_RELAXATION: ret = solve_relax_step(ctx); break;
- case EGS_SOLVER_EXACT: ret = solve_exact(ctx); break;
+ case EGS_SOLVER_RELAXATION: ret = solve_relax_step(ctx, export_res); break;
+ case EGS_SOLVER_EXACT: ret = solve_exact(ctx); break;
default: ret = -EINVAL;
}
@@ -803,8 +842,17 @@ finish:
mg2di_timer_stop(&ctx->timer_init);
if (ret >= 0) {
- boundaries_apply(ctx, 1);
- residual_calc(ctx);
+ memset(priv->reflect_disable, 0, sizeof(priv->reflect_disable));
+
+ boundaries_apply(ctx, ctx->u, 1);
+ if (ctx->solver_type == EGS_SOLVER_RELAXATION) {
+ boundaries_apply(ctx, priv->u_next, 1);
+
+ residual_calc(ctx, 0);
+ boundaries_apply(ctx, priv->u_next, 0);
+ priv->u_next_valid = 1;
+ } else
+ residual_calc(ctx, 1);
}
mg2di_timer_stop(&ctx->timer_solve);
@@ -825,15 +873,27 @@ static int arrays_alloc(EGSContext *ctx, const size_t domain_size[2])
size_padded[0] * MG2D_DIFF_COEFF_NB,
size_padded[1],
};
+ const size_t size_u[2] = {
+ size_padded[0] * 2,
+ size_padded[1],
+ };
const Slice slice[2] = { SLICE(ghosts, -ghosts, 1),
SLICE(ghosts, -ghosts, 1) };
int ret;
- ret = mg2di_ndarray_alloc(&priv->u_base, 2, size_padded, NDARRAY_ALLOC_ZERO);
+ ret = mg2di_ndarray_alloc(&priv->u_base, 2, size_u, NDARRAY_ALLOC_ZERO);
+ if (ret < 0)
+ return ret;
+
+ ret = mg2di_ndarray_slice(&ctx->u, priv->u_base,
+ (Slice [2]){ SLICE(ghosts, size_padded[0] - ghosts, 1),
+ SLICE(ghosts, -ghosts, 1) });
if (ret < 0)
return ret;
- ret = mg2di_ndarray_slice(&ctx->u, priv->u_base, slice);
+ ret = mg2di_ndarray_slice(&priv->u_next, priv->u_base,
+ (Slice [2]){ SLICE(size_padded[0] + ghosts, size_u[0] - ghosts, 1),
+ SLICE(ghosts, -ghosts, 1) });
if (ret < 0)
return ret;
@@ -1015,6 +1075,7 @@ void mg2di_egs_free(EGSContext **pctx)
}
mg2di_ndarray_free(&ctx->u);
+ mg2di_ndarray_free(&ctx->priv->u_next);
mg2di_ndarray_free(&ctx->priv->u_base);
mg2di_ndarray_free(&ctx->rhs);