aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAnton Khirnov <anton@khirnov.net>2019-04-23 15:36:34 +0200
committerAnton Khirnov <anton@khirnov.net>2019-04-24 14:44:36 +0200
commit86ad823b9ade211bfa9361b61571933aff1c9d24 (patch)
tree7a44acde9b11a9958b953f50927441318c2beb5e
parent580740356c44658620bff6f9ddd8a006f04c31fc (diff)
egs: merge residual calc and correct when possible
Also, merge the reflect boundary condition into residual calc+add. Improves performance due to better locality.
-rw-r--r--ell_grid_solve.c115
-rw-r--r--ell_grid_solve.h2
-rw-r--r--mg2d.c11
-rw-r--r--relax_test.c2
-rw-r--r--residual_calc.asm27
-rw-r--r--residual_calc.c124
-rw-r--r--residual_calc.h3
7 files changed, 233 insertions, 51 deletions
diff --git a/ell_grid_solve.c b/ell_grid_solve.c
index 4cb87ed..cefb231 100644
--- a/ell_grid_solve.c
+++ b/ell_grid_solve.c
@@ -76,6 +76,9 @@ struct EGSInternal {
NDArray *rhs_base;
NDArray *residual_base;
+ NDArray *u_next;
+ int u_next_valid;
+
/* all the diff coeffs concatenated */
NDArray *diff_coeffs_public_base;
@@ -86,6 +89,7 @@ struct EGSInternal {
size_t residual_calc_size[2];
int bnd_zero[4];
+ int reflect_disable[4];
ResidualCalcContext *rescalc;
@@ -116,7 +120,7 @@ static const double fd_denoms[][MG2D_DIFF_COEFF_NB] = {
},
};
-static void residual_calc(EGSContext *ctx)
+static void residual_calc(EGSContext *ctx, int export_res)
{
EGSInternal *priv = ctx->priv;
const double *diff_coeffs[MG2D_DIFF_COEFF_NB];
@@ -127,15 +131,26 @@ static void residual_calc(EGSContext *ctx)
diff_coeffs[i] = NDPTR2D(priv->diff_coeffs[i], priv->residual_calc_offset[1], priv->residual_calc_offset[0]);
mg2di_residual_calc(priv->rescalc, priv->residual_calc_size, &ctx->residual_max,
- NDPTR2D(ctx->residual, priv->residual_calc_offset[1], priv->residual_calc_offset[0]),
- ctx->residual->stride[0],
+ NDPTR2D(export_res ? ctx->residual : priv->u_next, priv->residual_calc_offset[1], priv->residual_calc_offset[0]),
+ export_res ? ctx->residual->stride[0] : priv->u_next->stride[0],
NDPTR2D(ctx->u, priv->residual_calc_offset[1], priv->residual_calc_offset[0]),
ctx->u->stride[0],
NDPTR2D(ctx->rhs, priv->residual_calc_offset[1], priv->residual_calc_offset[0]),
ctx->rhs->stride[0],
- diff_coeffs, priv->diff_coeffs[0]->stride[0]);
+ diff_coeffs, priv->diff_coeffs[0]->stride[0],
+ export_res ? 0.0 : 1.0, export_res ? 1.0 : priv->r.relax_factor,
+ export_res ? 0 :
+ ((ctx->boundaries[MG2D_BOUNDARY_0L]->type == MG2D_BC_TYPE_REFLECT) << 0) |
+ ((ctx->boundaries[MG2D_BOUNDARY_0U]->type == MG2D_BC_TYPE_REFLECT) << 1) |
+ ((ctx->boundaries[MG2D_BOUNDARY_1L]->type == MG2D_BC_TYPE_REFLECT) << 2));
mg2di_timer_stop(&ctx->timer_res_calc);
+
+ if (!export_res) {
+ priv->reflect_disable[MG2D_BOUNDARY_0L] = 1;
+ priv->reflect_disable[MG2D_BOUNDARY_0U] = 1;
+ priv->reflect_disable[MG2D_BOUNDARY_1L] = 1;
+ }
}
static void boundaries_apply_fixval(double *dst, const ptrdiff_t dst_stride[2],
@@ -195,11 +210,11 @@ static void boundaries_apply_falloff(double *dst, const ptrdiff_t dst_stride[2],
}
}
-static void boundaries_apply(EGSContext *ctx, int init)
+static void boundaries_apply(EGSContext *ctx, NDArray *a_dst, int init)
{
static const enum MG2DBCType bnd_type_order[] = { MG2D_BC_TYPE_FIXVAL, MG2D_BC_TYPE_REFLECT, MG2D_BC_TYPE_FALLOFF };
EGSInternal *priv = ctx->priv;
- const ptrdiff_t strides[2] = { 1, ctx->u->stride[0] };
+ const ptrdiff_t strides[2] = { 1, a_dst->stride[0] };
mg2di_timer_start(&ctx->timer_bnd);
@@ -213,7 +228,7 @@ static void boundaries_apply(EGSContext *ctx, int init)
const size_t size_boundary = ctx->domain_size[!ci];
const size_t size_offset = ctx->domain_size[ci];
- double *dst = ctx->u->data + mg2d_bnd_is_upper(i) * ((size_offset - 1) * stride_offset);
+ double *dst = a_dst->data + mg2d_bnd_is_upper(i) * ((size_offset - 1) * stride_offset);
const ptrdiff_t dst_strides[] = { stride_boundary, mg2d_bnd_out_dir(i) * stride_offset };
if (bnd->type != bnd_type_order[order_idx])
@@ -229,9 +244,11 @@ static void boundaries_apply(EGSContext *ctx, int init)
}
break;
case MG2D_BC_TYPE_REFLECT:
- mg2di_timer_start(&ctx->timer_bnd_reflect);
- boundaries_apply_reflect(dst, dst_strides, size_boundary);
- mg2di_timer_stop(&ctx->timer_bnd_reflect);
+ if (!priv->reflect_disable[i]) {
+ mg2di_timer_start(&ctx->timer_bnd_reflect);
+ boundaries_apply_reflect(dst, dst_strides, size_boundary);
+ mg2di_timer_stop(&ctx->timer_bnd_reflect);
+ }
break;
case MG2D_BC_TYPE_FALLOFF:
mg2di_timer_start(&ctx->timer_bnd_falloff);
@@ -255,8 +272,8 @@ static void boundaries_apply(EGSContext *ctx, int init)
const int dir_x = mg2d_bnd_out_dir(loc_x);
int fact_x = dir_x, fact_y = dir_y;
- double *dst = ctx->u->data
- + mg2d_bnd_is_upper(loc_y) * ((ctx->domain_size[1] - 1) * ctx->u->stride[0])
+ double *dst = a_dst->data
+ + mg2d_bnd_is_upper(loc_y) * ((ctx->domain_size[1] - 1) * a_dst->stride[0])
+ mg2d_bnd_is_upper(loc_x) * (ctx->domain_size[0] - 1);
if (bnd_y->type == MG2D_BC_TYPE_REFLECT)
@@ -307,17 +324,39 @@ static int residual_add_task(void *arg, unsigned int job_idx, unsigned int threa
return 0;
}
-static int solve_relax_step(EGSContext *ctx)
+static int solve_relax_step(EGSContext *ctx, int export_res)
{
+ EGSInternal *priv = ctx->priv;
EGSRelaxContext *r = ctx->solver_data;
+ int u_next_valid = priv->u_next_valid;
- mg2di_timer_start(&r->timer_correct);
- tp_execute(ctx->tp, ctx->domain_size[1], residual_add_task, ctx);
- mg2di_timer_stop(&r->timer_correct);
+ if (u_next_valid) {
+ NDArray *tmp = ctx->u;
+ ctx->u = priv->u_next;
+ priv->u_next = tmp;
+ priv->u_next_valid = 0;
+ }
+
+ if (export_res) {
+ if (u_next_valid)
+ residual_calc(ctx, 1);
+ else {
+ mg2di_timer_start(&r->timer_correct);
+ tp_execute(ctx->tp, ctx->domain_size[1], residual_add_task, ctx);
+ mg2di_timer_stop(&r->timer_correct);
- boundaries_apply(ctx, 0);
- residual_calc(ctx);
+ memset(priv->reflect_disable, 0, sizeof(priv->reflect_disable));
+
+ boundaries_apply(ctx, ctx->u, 0);
+ residual_calc(ctx, 1);
+ }
+ } else {
+ mg2di_assert(u_next_valid);
+ residual_calc(ctx, 0);
+ boundaries_apply(ctx, priv->u_next, 0);
+ priv->u_next_valid = 1;
+ }
return 0;
}
@@ -634,21 +673,21 @@ static int solve_exact(EGSContext *ctx)
mg2di_timer_stop(&ec->timer_export);
- boundaries_apply(ctx, 0);
- residual_calc(ctx);
+ boundaries_apply(ctx, ctx->u, 0);
+ residual_calc(ctx, 1);
return 0;
}
-int mg2di_egs_solve(EGSContext *ctx)
+int mg2di_egs_solve(EGSContext *ctx, int export_res)
{
int ret;
mg2di_timer_start(&ctx->timer_solve);
switch (ctx->solver_type) {
- case EGS_SOLVER_RELAXATION: ret = solve_relax_step(ctx); break;
- case EGS_SOLVER_EXACT: ret = solve_exact(ctx); break;
+ case EGS_SOLVER_RELAXATION: ret = solve_relax_step(ctx, export_res); break;
+ case EGS_SOLVER_EXACT: ret = solve_exact(ctx); break;
default: ret = -EINVAL;
}
@@ -803,8 +842,17 @@ finish:
mg2di_timer_stop(&ctx->timer_init);
if (ret >= 0) {
- boundaries_apply(ctx, 1);
- residual_calc(ctx);
+ memset(priv->reflect_disable, 0, sizeof(priv->reflect_disable));
+
+ boundaries_apply(ctx, ctx->u, 1);
+ if (ctx->solver_type == EGS_SOLVER_RELAXATION) {
+ boundaries_apply(ctx, priv->u_next, 1);
+
+ residual_calc(ctx, 0);
+ boundaries_apply(ctx, priv->u_next, 0);
+ priv->u_next_valid = 1;
+ } else
+ residual_calc(ctx, 1);
}
mg2di_timer_stop(&ctx->timer_solve);
@@ -825,15 +873,27 @@ static int arrays_alloc(EGSContext *ctx, const size_t domain_size[2])
size_padded[0] * MG2D_DIFF_COEFF_NB,
size_padded[1],
};
+ const size_t size_u[2] = {
+ size_padded[0] * 2,
+ size_padded[1],
+ };
const Slice slice[2] = { SLICE(ghosts, -ghosts, 1),
SLICE(ghosts, -ghosts, 1) };
int ret;
- ret = mg2di_ndarray_alloc(&priv->u_base, 2, size_padded, NDARRAY_ALLOC_ZERO);
+ ret = mg2di_ndarray_alloc(&priv->u_base, 2, size_u, NDARRAY_ALLOC_ZERO);
+ if (ret < 0)
+ return ret;
+
+ ret = mg2di_ndarray_slice(&ctx->u, priv->u_base,
+ (Slice [2]){ SLICE(ghosts, size_padded[0] - ghosts, 1),
+ SLICE(ghosts, -ghosts, 1) });
if (ret < 0)
return ret;
- ret = mg2di_ndarray_slice(&ctx->u, priv->u_base, slice);
+ ret = mg2di_ndarray_slice(&priv->u_next, priv->u_base,
+ (Slice [2]){ SLICE(size_padded[0] + ghosts, size_u[0] - ghosts, 1),
+ SLICE(ghosts, -ghosts, 1) });
if (ret < 0)
return ret;
@@ -1015,6 +1075,7 @@ void mg2di_egs_free(EGSContext **pctx)
}
mg2di_ndarray_free(&ctx->u);
+ mg2di_ndarray_free(&ctx->priv->u_next);
mg2di_ndarray_free(&ctx->priv->u_base);
mg2di_ndarray_free(&ctx->rhs);
diff --git a/ell_grid_solve.h b/ell_grid_solve.h
index 8f984ce..6e59bf0 100644
--- a/ell_grid_solve.h
+++ b/ell_grid_solve.h
@@ -233,6 +233,6 @@ void mg2di_egs_free(EGSContext **ctx);
*
* @return 0 on success, a negative error code on failure.
*/
-int mg2di_egs_solve(EGSContext *ctx);
+int mg2di_egs_solve(EGSContext *ctx, int export_res);
#endif /* MG2D_ELL_GRID_SOLVE_H */
diff --git a/mg2d.c b/mg2d.c
index 6639879..59df5c1 100644
--- a/mg2d.c
+++ b/mg2d.c
@@ -109,7 +109,8 @@ static int coarse_correct_task(void *arg, unsigned int job_idx, unsigned int thr
return 0;
}
-static int mg_relax_step(MG2DContext *ctx, MG2DLevel *level, const char *step_desc)
+static int mg_relax_step(MG2DContext *ctx, MG2DLevel *level, const char *step_desc,
+ int export_res)
{
double res_old;
int ret;
@@ -117,7 +118,7 @@ static int mg_relax_step(MG2DContext *ctx, MG2DLevel *level, const char *step_de
res_old = level->solver->residual_max;
mg2di_timer_start(&level->timer_solve);
- ret = mg2di_egs_solve(level->solver);
+ ret = mg2di_egs_solve(level->solver, export_res);
mg2di_timer_stop(&level->timer_solve);
if (ret < 0)
@@ -153,7 +154,7 @@ static int mg_solve_subgrid(MG2DContext *ctx, MG2DLevel *level)
/* handle coarsest grid */
if (!level->child) {
- ret = mg_relax_step(ctx, level, "coarse-step");
+ ret = mg_relax_step(ctx, level, "coarse-step", 1);
if (ret < 0)
return ret;
level->count_cycles++;
@@ -166,7 +167,7 @@ static int mg_solve_subgrid(MG2DContext *ctx, MG2DLevel *level)
/* pre-restrict relaxation */
for (int j = 0; j < ctx->nb_relax_pre; j++) {
- ret = mg_relax_step(ctx, level, "pre-relax");
+ ret = mg_relax_step(ctx, level, "pre-relax", j == ctx->nb_relax_pre - 1);
if (ret < 0)
return ret;
}
@@ -209,7 +210,7 @@ static int mg_solve_subgrid(MG2DContext *ctx, MG2DLevel *level)
/* post-correct relaxation */
for (int j = 0; j < ctx->nb_relax_post; j++) {
- ret = mg_relax_step(ctx, level, "post-relax");
+ ret = mg_relax_step(ctx, level, "post-relax", 0);
if (ret < 0)
return ret;
}
diff --git a/relax_test.c b/relax_test.c
index 5c07107..50ad04c 100644
--- a/relax_test.c
+++ b/relax_test.c
@@ -144,7 +144,7 @@ int main(int argc, char **argv)
res_old = findmax(ctx->residual->data, ctx->residual->stride[0] * ctx->domain_size[1]);
for (int i = 0; i < maxiter; i++) {
- ret = mg2di_egs_solve(ctx);
+ ret = mg2di_egs_solve(ctx, 0);
if (ret < 0) {
fprintf(stderr, "Error during relaxation\n");
ret = 1;
diff --git a/residual_calc.asm b/residual_calc.asm
index 3a5b800..5eea31c 100644
--- a/residual_calc.asm
+++ b/residual_calc.asm
@@ -39,6 +39,8 @@ SECTION .text
; mm register allocation (both s1 and s2)
; m0: accumulator for the residual
+; m1: dst mult factor
+; m2: res mult factor
; m6-m11: working registers
; m12: max(fabs(residual))
; m13: mask for computing absolute values
@@ -140,9 +142,14 @@ SECTION .text
%endmacro
; %1: stencil
-%macro RESIDUAL_CALC 1
+; %2: 0 - calc; 1 - add
+%macro RESIDUAL_CALC 2
%define stencil %1
+%if %2
+ vpermq m2, m1, 0
+%endif
+ vpermq m1, m0, 0
; compute the mask for absolute value
pcmpeqq m13, m13
@@ -195,6 +202,9 @@ SECTION .text
; plain value
movu m6, [uq + offsetq] ; m6 = u[x]
vfmadd231pd m0, m6, [diff_coeffs00q + offsetq] ; res += u * diff_coeffs00
+%if %2
+ mulpd m3, m6, m2
+%endif
%if stencil == 1
addpd m6, m6 ; m6 = 2 * u[x]
@@ -207,6 +217,10 @@ SECTION .text
RES_ADD_DIFF_MIXED stencil
andpd m6, m0, m13 ; m6 = abs(res)
+ mulpd m0, m1
+%if %2
+ addpd m0, m3
+%endif
; store the result
add offsetq, mmsize
@@ -255,9 +269,16 @@ SECTION .text
INIT_YMM fma3
cglobal residual_calc_line_s1, 7, 14, 14, linesize, dst, res_max, stride, u, rhs, diff_coeffs,\
diff_coeffs00, diff_coeffs01, diff_coeffs10, diff_coeffs11, diff_coeffs02, u_down, u_up
-RESIDUAL_CALC 1
+RESIDUAL_CALC 1, 0
+cglobal residual_add_line_s1, 7, 14, 14, linesize, dst, res_max, stride, u, rhs, diff_coeffs,\
+ diff_coeffs00, diff_coeffs01, diff_coeffs10, diff_coeffs11, diff_coeffs02, u_down, u_up
+RESIDUAL_CALC 1, 1
INIT_YMM fma3
cglobal residual_calc_line_s2, 7, 15, 16, linesize, dst, res_max, stride, u, rhs, diff_coeffs,\
diff_coeffs00, diff_coeffs01, diff_coeffs10, diff_coeffs11, diff_coeffs02, u_down, u_up, u_up2
-RESIDUAL_CALC 2
+RESIDUAL_CALC 2, 0
+
+cglobal residual_add_line_s2, 7, 15, 16, linesize, dst, res_max, stride, u, rhs, diff_coeffs,\
+ diff_coeffs00, diff_coeffs01, diff_coeffs10, diff_coeffs11, diff_coeffs02, u_down, u_up, u_up2
+RESIDUAL_CALC 2, 1
diff --git a/residual_calc.c b/residual_calc.c
index bfbb9bf..3b83c63 100644
--- a/residual_calc.c
+++ b/residual_calc.c
@@ -44,6 +44,11 @@ typedef struct ResidualCalcTask {
const double * const *diff_coeffs;
ptrdiff_t diff_coeffs_stride;
+
+ double u_mult;
+ double res_mult;
+
+ int mirror_line;
} ResidualCalcTask;
struct ResidualCalcInternal {
@@ -52,7 +57,12 @@ struct ResidualCalcInternal {
void (*residual_calc_line)(size_t linesize, double *dst, double *dst_max,
ptrdiff_t stride, const double *u, const double *rhs,
- const double * const diff_coeffs[MG2D_DIFF_COEFF_NB]);
+ const double * const diff_coeffs[MG2D_DIFF_COEFF_NB],
+ double res_mult);
+ void (*residual_add_line)(size_t linesize, double *dst, double *dst_max,
+ ptrdiff_t stride, const double *u, const double *rhs,
+ const double * const diff_coeffs[MG2D_DIFF_COEFF_NB],
+ double res_mult, double u_mult);
size_t calc_blocksize;
ResidualCalcTask task;
@@ -61,10 +71,20 @@ struct ResidualCalcInternal {
#if HAVE_EXTERNAL_ASM
void mg2di_residual_calc_line_s1_fma3(size_t linesize, double *dst, double *dst_max,
ptrdiff_t stride, const double *u, const double *rhs,
- const double * const diff_coeffs[MG2D_DIFF_COEFF_NB]);
+ const double * const diff_coeffs[MG2D_DIFF_COEFF_NB],
+ double res_mult);
+void mg2di_residual_add_line_s1_fma3(size_t linesize, double *dst, double *dst_max,
+ ptrdiff_t stride, const double *u, const double *rhs,
+ const double * const diff_coeffs[MG2D_DIFF_COEFF_NB],
+ double res_mult, double u_mult);
void mg2di_residual_calc_line_s2_fma3(size_t linesize, double *dst, double *dst_max,
ptrdiff_t stride, const double *u, const double *rhs,
- const double * const diff_coeffs[MG2D_DIFF_COEFF_NB]);
+ const double * const diff_coeffs[MG2D_DIFF_COEFF_NB],
+ double res_mult);
+void mg2di_residual_add_line_s2_fma3(size_t linesize, double *dst, double *dst_max,
+ ptrdiff_t stride, const double *u, const double *rhs,
+ const double * const diff_coeffs[MG2D_DIFF_COEFF_NB],
+ double res_mult, double u_mult);
#endif
static void
@@ -129,7 +149,8 @@ derivatives_calc_s2(double *dst, const double *u, ptrdiff_t stride)
static void residual_calc_line_s1_c(size_t linesize, double *dst, double *dst_max,
ptrdiff_t stride, const double *u, const double *rhs,
- const double * const diff_coeffs[MG2D_DIFF_COEFF_NB])
+ const double * const diff_coeffs[MG2D_DIFF_COEFF_NB],
+ double res_mult)
{
double res_max = 0.0, res_abs;
for (size_t i = 0; i < linesize; i++) {
@@ -141,7 +162,31 @@ static void residual_calc_line_s1_c(size_t linesize, double *dst, double *dst_ma
res = -rhs[i];
for (int j = 0; j < ARRAY_ELEMS(u_vals); j++)
res += u_vals[j] * diff_coeffs[j][i];
- dst[i] = res;
+ dst[i] = res_mult * res;
+
+ res_abs = fabs(res);
+ res_max = MAX(res_max, res_abs);
+ }
+
+ *dst_max = MAX(*dst_max, res_max);
+}
+
+static void residual_add_line_s1_c(size_t linesize, double *dst, double *dst_max,
+ ptrdiff_t stride, const double *u, const double *rhs,
+ const double * const diff_coeffs[MG2D_DIFF_COEFF_NB],
+ double res_mult, double u_mult)
+{
+ double res_max = 0.0, res_abs;
+ for (size_t i = 0; i < linesize; i++) {
+ double u_vals[MG2D_DIFF_COEFF_NB];
+ double res;
+
+ derivatives_calc_s1(u_vals, u + i, stride);
+
+ res = -rhs[i];
+ for (int j = 0; j < ARRAY_ELEMS(u_vals); j++)
+ res += u_vals[j] * diff_coeffs[j][i];
+ dst[i] = u_mult * u[i] + res_mult * res;
res_abs = fabs(res);
res_max = MAX(res_max, res_abs);
@@ -152,7 +197,8 @@ static void residual_calc_line_s1_c(size_t linesize, double *dst, double *dst_ma
static void residual_calc_line_s2_c(size_t linesize, double *dst, double *dst_max,
ptrdiff_t stride, const double *u, const double *rhs,
- const double * const diff_coeffs[MG2D_DIFF_COEFF_NB])
+ const double * const diff_coeffs[MG2D_DIFF_COEFF_NB],
+ double res_mult)
{
double res_max = 0.0, res_abs;
for (size_t i = 0; i < linesize; i++) {
@@ -164,7 +210,31 @@ static void residual_calc_line_s2_c(size_t linesize, double *dst, double *dst_ma
res = -rhs[i];
for (int j = 0; j < ARRAY_ELEMS(u_vals); j++)
res += u_vals[j] * diff_coeffs[j][i];
- dst[i] = res;
+ dst[i] = res_mult * res;
+
+ res_abs = fabs(res);
+ res_max = MAX(res_max, res_abs);
+ }
+
+ *dst_max = MAX(*dst_max, res_max);
+}
+
+static void residual_add_line_s2_c(size_t linesize, double *dst, double *dst_max,
+ ptrdiff_t stride, const double *u, const double *rhs,
+ const double * const diff_coeffs[MG2D_DIFF_COEFF_NB],
+ double res_mult, double u_mult)
+{
+ double res_max = 0.0, res_abs;
+ for (size_t i = 0; i < linesize; i++) {
+ double u_vals[MG2D_DIFF_COEFF_NB];
+ double res;
+
+ derivatives_calc_s2(u_vals, u + i, stride);
+
+ res = -rhs[i];
+ for (int j = 0; j < ARRAY_ELEMS(u_vals); j++)
+ res += u_vals[j] * diff_coeffs[j][i];
+ dst[i] = u_mult * u[i] + res_mult * res;
res_abs = fabs(res);
res_max = MAX(res_max, res_abs);
@@ -179,15 +249,35 @@ static int residual_calc_task(void *arg, unsigned int job_idx, unsigned int thre
ResidualCalcTask *task = &priv->task;
const double *diff_coeffs[MG2D_DIFF_COEFF_NB];
+ double *dst = task->dst + job_idx * task->dst_stride;
for (int i = 0; i < ARRAY_ELEMS(diff_coeffs); i++)
diff_coeffs[i] = task->diff_coeffs[i] + job_idx * task->diff_coeffs_stride;
- priv->residual_calc_line(task->line_size, task->dst + job_idx * task->dst_stride,
- priv->residual_max + thread_idx * priv->calc_blocksize,
- task->u_stride, task->u + job_idx * task->u_stride,
- task->rhs + job_idx * task->rhs_stride,
- diff_coeffs);
+ if (task->u_mult == 0.0) {
+ priv->residual_calc_line(task->line_size, dst,
+ priv->residual_max + thread_idx * priv->calc_blocksize,
+ task->u_stride, task->u + job_idx * task->u_stride,
+ task->rhs + job_idx * task->rhs_stride,
+ diff_coeffs, task->res_mult);
+ } else {
+ priv->residual_add_line(task->line_size, dst,
+ priv->residual_max + thread_idx * priv->calc_blocksize,
+ task->u_stride, task->u + job_idx * task->u_stride,
+ task->rhs + job_idx * task->rhs_stride,
+ diff_coeffs, task->res_mult, task->u_mult);
+ }
+ if (task->mirror_line & (1 << 0)) {
+ for (int i = 1; i <= FD_STENCIL_MAX; i++)
+ dst[-i] = dst[i];
+ }
+ if (task->mirror_line & (1 << 1)) {
+ for (int i = 1; i <= FD_STENCIL_MAX; i++)
+ dst[task->line_size - 1 + i] = dst[task->line_size - 1 - i];
+ }
+ if ((task->mirror_line & (1 << 2)) && job_idx > 0 && job_idx <= FD_STENCIL_MAX) {
+ memcpy(task->dst - job_idx * task->dst_stride, dst, sizeof(*dst) * task->line_size);
+ }
return 0;
}
@@ -198,7 +288,8 @@ int mg2di_residual_calc(ResidualCalcContext *ctx, size_t size[2],
const double *u, ptrdiff_t u_stride,
const double *rhs, ptrdiff_t rhs_stride,
const double * const diff_coeffs[MG2D_DIFF_COEFF_NB],
- ptrdiff_t diff_coeffs_stride)
+ ptrdiff_t diff_coeffs_stride,
+ double u_mult, double res_mult, int mirror_line)
{
ResidualCalcInternal *priv = ctx->priv;
ResidualCalcTask *task = &priv->task;
@@ -215,6 +306,9 @@ int mg2di_residual_calc(ResidualCalcContext *ctx, size_t size[2],
task->rhs_stride = rhs_stride;
task->diff_coeffs = diff_coeffs;
task->diff_coeffs_stride = diff_coeffs_stride;
+ task->u_mult = u_mult;
+ task->res_mult = res_mult;
+ task->mirror_line = mirror_line;
tp_execute(ctx->tp, size[1], residual_calc_task, priv);
@@ -234,18 +328,22 @@ int mg2di_residual_calc_init(ResidualCalcContext *ctx)
switch (ctx->fd_stencil) {
case 1:
priv->residual_calc_line = residual_calc_line_s1_c;
+ priv->residual_add_line = residual_add_line_s1_c;
#if HAVE_EXTERNAL_ASM
if (ctx->cpuflags & MG2DI_CPU_FLAG_FMA3) {
priv->residual_calc_line = mg2di_residual_calc_line_s1_fma3;
+ priv->residual_add_line = mg2di_residual_add_line_s1_fma3;
priv->calc_blocksize = 4;
}
#endif
break;
case 2:
priv->residual_calc_line = residual_calc_line_s2_c;
+ priv->residual_add_line = residual_add_line_s2_c;
#if HAVE_EXTERNAL_ASM
if (ctx->cpuflags & MG2DI_CPU_FLAG_FMA3) {
priv->residual_calc_line = mg2di_residual_calc_line_s2_fma3;
+ priv->residual_add_line = mg2di_residual_add_line_s2_fma3;
priv->calc_blocksize = 4;
}
#endif
diff --git a/residual_calc.h b/residual_calc.h
index 31c1909..8c9b628 100644
--- a/residual_calc.h
+++ b/residual_calc.h
@@ -74,6 +74,7 @@ int mg2di_residual_calc(ResidualCalcContext *ctx, size_t size[2],
const double *u, ptrdiff_t u_stride,
const double *rhs, ptrdiff_t rhs_stride,
const double * const diff_coeffs[MG2D_DIFF_COEFF_NB],
- ptrdiff_t diff_coeffs_stride);
+ ptrdiff_t diff_coeffs_stride,
+ double u_mult, double res_mult, int mirror_line);
#endif // MG2D_RESIDUAL_CALC_H