aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAnton Khirnov <anton@khirnov.net>2019-01-30 11:36:34 +0100
committerAnton Khirnov <anton@khirnov.net>2019-01-30 11:36:34 +0100
commitb584bfe20168ac6208154b1eef395b3805b35e77 (patch)
tree1882c7708e474adfc864bc3c1a1bd90a4b83d7fd
parent783d260e0d47d6adb4388fea9ed8e35122d4f6c2 (diff)
ell_grid_solve: split residual computation into its own file
-rw-r--r--ell_grid_solve.c223
-rw-r--r--meson.build1
-rw-r--r--residual_calc.c292
-rw-r--r--residual_calc.h77
4 files changed, 401 insertions, 192 deletions
diff --git a/ell_grid_solve.c b/ell_grid_solve.c
index 20de6bc..1893214 100644
--- a/ell_grid_solve.c
+++ b/ell_grid_solve.c
@@ -28,12 +28,12 @@
#include <lapacke.h>
#include "common.h"
-#include "cpu.h"
#include "ell_grid_solve.h"
#include "log.h"
#include "mg2d_boundary.h"
#include "mg2d_boundary_internal.h"
#include "mg2d_constants.h"
+#include "residual_calc.h"
static const double relax_factors[FD_STENCIL_MAX] = {
[0] = 1.0 / 5,
@@ -63,18 +63,12 @@ struct EGSInternal {
double *residual_base;
double *diff_coeffs_base[MG2D_DIFF_COEFF_NB];
- double *residual_max;
- size_t residual_max_size;
-
- void (*residual_calc_line)(size_t linesize, double *dst, double *dst_max,
- ptrdiff_t stride, const double *u, const double *rhs,
- const double * const diff_coeffs[MG2D_DIFF_COEFF_NB],
- const double *fd_factors);
- size_t calc_blocksize;
ptrdiff_t residual_calc_offset;
size_t residual_calc_size[2];
double fd_factors[MG2D_DIFF_COEFF_NB];
+ ResidualCalcContext *rescalc;
+
union {
EGSRelaxInternal r;
EGSExactInternal e;
@@ -124,158 +118,21 @@ static const double fd_denoms[][MG2D_DIFF_COEFF_NB] = {
},
};
-#if HAVE_EXTERNAL_ASM
-void mg2di_residual_calc_line_s1_fma3(size_t linesize, double *dst, double *dst_max,
- ptrdiff_t stride, const double *u, const double *rhs,
- const double * const diff_coeffs[MG2D_DIFF_COEFF_NB],
- const double *fd_factors);
-void mg2di_residual_calc_line_s2_fma3(size_t linesize, double *dst, double *dst_max,
- ptrdiff_t stride, const double *u, const double *rhs,
- const double * const diff_coeffs[MG2D_DIFF_COEFF_NB],
- const double *fd_factors);
-#endif
-
-static void
-derivatives_calc_s1(double *dst, const double *u, const double *fd_factors, ptrdiff_t stride)
-{
- dst[MG2D_DIFF_COEFF_00] = u[0];
- dst[MG2D_DIFF_COEFF_10] = (u[1] - u[-1]) * fd_factors[MG2D_DIFF_COEFF_10];
- dst[MG2D_DIFF_COEFF_01] = (u[stride] - u[-stride]) * fd_factors[MG2D_DIFF_COEFF_01];
-
- dst[MG2D_DIFF_COEFF_20] = (u[1] - 2.0 * u[0] + u[-1]) * fd_factors[MG2D_DIFF_COEFF_20];
- dst[MG2D_DIFF_COEFF_02] = (u[stride] - 2.0 * u[0] + u[-stride]) * fd_factors[MG2D_DIFF_COEFF_02];
-
- dst[MG2D_DIFF_COEFF_11] = (u[1 + stride] - u[stride - 1] - u[-stride + 1] + u[-stride - 1]) * fd_factors[MG2D_DIFF_COEFF_11];
-}
-
-static void
-derivatives_calc_s2(double *dst, const double *u, const double *fd_factors, ptrdiff_t stride)
-{
- const double val = u[0];
-
- const double valxp1 = u[ 1];
- const double valxp2 = u[ 2];
- const double valxm1 = u[-1];
- const double valxm2 = u[-2];
- const double valyp1 = u[ 1 * stride];
- const double valyp2 = u[ 2 * stride];
- const double valym1 = u[-1 * stride];
- const double valym2 = u[-2 * stride];
-
- const double valxp1yp1 = u[ 1 + 1 * stride];
- const double valxp1yp2 = u[ 1 + 2 * stride];
- const double valxp1ym1 = u[ 1 - 1 * stride];
- const double valxp1ym2 = u[ 1 - 2 * stride];
-
- const double valxp2yp1 = u[ 2 + 1 * stride];
- const double valxp2yp2 = u[ 2 + 2 * stride];
- const double valxp2ym1 = u[ 2 - 1 * stride];
- const double valxp2ym2 = u[ 2 - 2 * stride];
-
- const double valxm1yp1 = u[-1 + 1 * stride];
- const double valxm1yp2 = u[-1 + 2 * stride];
- const double valxm1ym1 = u[-1 - 1 * stride];
- const double valxm1ym2 = u[-1 - 2 * stride];
-
- const double valxm2yp1 = u[-2 + 1 * stride];
- const double valxm2yp2 = u[-2 + 2 * stride];
- const double valxm2ym1 = u[-2 - 1 * stride];
- const double valxm2ym2 = u[-2 - 2 * stride];
-
- dst[MG2D_DIFF_COEFF_00] = val;
- dst[MG2D_DIFF_COEFF_10] = (-1.0 * valxp2 + 8.0 * valxp1 - 8.0 * valxm1 + 1.0 * valxm2) * fd_factors[MG2D_DIFF_COEFF_10];
- dst[MG2D_DIFF_COEFF_01] = (-1.0 * valyp2 + 8.0 * valyp1 - 8.0 * valym1 + 1.0 * valym2) * fd_factors[MG2D_DIFF_COEFF_01];
-
- dst[MG2D_DIFF_COEFF_20] = (-1.0 * valxp2 + 16.0 * valxp1 - 30.0 * val + 16.0 * valxm1 - 1.0 * valxm2) * fd_factors[MG2D_DIFF_COEFF_20];
- dst[MG2D_DIFF_COEFF_02] = (-1.0 * valyp2 + 16.0 * valyp1 - 30.0 * val + 16.0 * valym1 - 1.0 * valym2) * fd_factors[MG2D_DIFF_COEFF_02];
-
- dst[MG2D_DIFF_COEFF_11] = ( 1.0 * valxp2yp2 - 8.0 * valxp2yp1 + 8.0 * valxp2ym1 - 1.0 * valxp2ym2
- -8.0 * valxp1yp2 + 64.0 * valxp1yp1 - 64.0 * valxp1ym1 + 8.0 * valxp1ym2
- +8.0 * valxm1yp2 - 64.0 * valxm1yp1 + 64.0 * valxm1ym1 - 8.0 * valxm1ym2
- -1.0 * valxm2yp2 + 8.0 * valxm2yp1 - 8.0 * valxm2ym1 + 1.0 * valxm2ym2) * fd_factors[MG2D_DIFF_COEFF_11];
-}
-
-static void residual_calc_line_s1_c(size_t linesize, double *dst, double *dst_max,
- ptrdiff_t stride, const double *u, const double *rhs,
- const double * const diff_coeffs[MG2D_DIFF_COEFF_NB],
- const double *fd_factors)
-{
- double res_max = 0.0, res_abs;
- for (size_t i = 0; i < linesize; i++) {
- double u_vals[MG2D_DIFF_COEFF_NB];
- double res;
-
- derivatives_calc_s1(u_vals, u + i, fd_factors, stride);
-
- res = -rhs[i];
- for (int j = 0; j < ARRAY_ELEMS(u_vals); j++)
- res += u_vals[j] * diff_coeffs[j][i];
- dst[i] = res;
-
- res_abs = fabs(res);
- res_max = MAX(res_max, res_abs);
- }
-
- *dst_max = MAX(*dst_max, res_max);
-}
-
-static void residual_calc_line_s2_c(size_t linesize, double *dst, double *dst_max,
- ptrdiff_t stride, const double *u, const double *rhs,
- const double * const diff_coeffs[MG2D_DIFF_COEFF_NB],
- const double *fd_factors)
-{
- double res_max = 0.0, res_abs;
- for (size_t i = 0; i < linesize; i++) {
- double u_vals[MG2D_DIFF_COEFF_NB];
- double res;
-
- derivatives_calc_s2(u_vals, u + i, fd_factors, stride);
-
- res = -rhs[i];
- for (int j = 0; j < ARRAY_ELEMS(u_vals); j++)
- res += u_vals[j] * diff_coeffs[j][i];
- dst[i] = res;
-
- res_abs = fabs(res);
- res_max = MAX(res_max, res_abs);
- }
-
- *dst_max = MAX(*dst_max, res_max);
-}
-
-static int residual_calc_task(void *arg, unsigned int job_idx, unsigned int thread_idx)
-{
- EGSContext *ctx = arg;
- EGSInternal *priv = ctx->priv;
- const ptrdiff_t offset = priv->residual_calc_offset + job_idx * priv->stride;
- const double *diff_coeffs[MG2D_DIFF_COEFF_NB];
-
- for (int i = 0; i < ARRAY_ELEMS(diff_coeffs); i++)
- diff_coeffs[i] = ctx->diff_coeffs[i] + offset;
-
- priv->residual_calc_line(priv->residual_calc_size[0], ctx->residual + offset,
- priv->residual_max + thread_idx * priv->calc_blocksize,
- priv->stride, ctx->u + offset, ctx->rhs + offset,
- diff_coeffs, priv->fd_factors);
-
- return 0;
-}
-
static void residual_calc(EGSContext *ctx)
{
EGSInternal *priv = ctx->priv;
- double res_max = 0.0;
+ const double *diff_coeffs[MG2D_DIFF_COEFF_NB];
int64_t start;
- memset(priv->residual_max, 0, sizeof(*priv->residual_max) * priv->residual_max_size);
-
start = gettime();
- tp_execute(ctx->tp, priv->residual_calc_size[1], residual_calc_task, ctx);
+ for (int i = 0; i < ARRAY_ELEMS(diff_coeffs); i++)
+ diff_coeffs[i] = ctx->diff_coeffs[i] + priv->residual_calc_offset;
- for (size_t i = 0; i < priv->residual_max_size; i++)
- res_max = MAX(res_max, priv->residual_max[i]);
- ctx->residual_max = res_max;
+ mg2di_residual_calc(priv->rescalc, priv->residual_calc_size, priv->stride,
+ &ctx->residual_max, ctx->residual + priv->residual_calc_offset,
+ ctx->u + priv->residual_calc_offset, ctx->rhs + priv->residual_calc_offset,
+ diff_coeffs, priv->fd_factors);
ctx->time_res_calc += gettime() - start;
ctx->count_res++;
@@ -704,39 +561,17 @@ int mg2di_egs_solve(EGSContext *ctx)
int mg2di_egs_init(EGSContext *ctx)
{
EGSInternal *priv = ctx->priv;
- double *tmp;
int ret;
- priv->calc_blocksize = 1;
- switch (ctx->fd_stencil) {
- case 1:
- if (ctx->solver_type == EGS_SOLVER_EXACT)
- priv->e.fill_mat = fill_mat_s1;
-
- priv->residual_calc_line = residual_calc_line_s1_c;
-#if HAVE_EXTERNAL_ASM
- if (ctx->cpuflags & MG2DI_CPU_FLAG_FMA3) {
- priv->residual_calc_line = mg2di_residual_calc_line_s1_fma3;
- priv->calc_blocksize = 4;
- }
-#endif
- break;
- case 2:
- if (ctx->solver_type == EGS_SOLVER_EXACT)
- priv->e.fill_mat = fill_mat_s2;
-
- priv->residual_calc_line = residual_calc_line_s2_c;
-#if HAVE_EXTERNAL_ASM
- if (ctx->cpuflags & MG2DI_CPU_FLAG_FMA3) {
- priv->residual_calc_line = mg2di_residual_calc_line_s2_fma3;
- priv->calc_blocksize = 4;
+ if (ctx->solver_type == EGS_SOLVER_EXACT) {
+ switch (ctx->fd_stencil) {
+ case 1: priv->e.fill_mat = fill_mat_s1; break;
+ case 2: priv->e.fill_mat = fill_mat_s2; break;
+ default:
+ mg2di_log(&ctx->logger, 0, "Invalid finite difference stencil: %zd\n",
+ ctx->fd_stencil);
+ return -EINVAL;
}
-#endif
- break;
- default:
- mg2di_log(&ctx->logger, 0, "Invalid finite difference stencil: %zd\n",
- ctx->fd_stencil);
- return -EINVAL;
}
if (ctx->step[0] <= DBL_EPSILON || ctx->step[1] <= DBL_EPSILON) {
@@ -791,14 +626,13 @@ int mg2di_egs_init(EGSContext *ctx)
}
}
- priv->residual_max_size = tp_get_nb_threads(ctx->tp) * priv->calc_blocksize;
- tmp = realloc(priv->residual_max,
- sizeof(*priv->residual_max) * priv->residual_max_size);
- if (!tmp) {
- priv->residual_max_size = 0;
- return -ENOMEM;
- }
- priv->residual_max = tmp;
+ priv->rescalc->tp = ctx->tp;
+ priv->rescalc->fd_stencil = ctx->fd_stencil;
+ priv->rescalc->cpuflags = ctx->cpuflags;
+
+ ret = mg2di_residual_calc_init(priv->rescalc);
+ if (ret < 0)
+ return ret;
boundaries_apply(ctx);
residual_calc(ctx);
@@ -909,6 +743,10 @@ EGSContext *mg2di_egs_alloc(enum EGSType type, size_t domain_size[2])
ctx->domain_size[0] = domain_size[0];
ctx->domain_size[1] = domain_size[1];
+ ctx->priv->rescalc = mg2di_residual_calc_alloc();
+ if (!ctx->priv->rescalc)
+ goto fail;
+
return ctx;
fail:
mg2di_egs_free(&ctx);
@@ -922,6 +760,8 @@ void mg2di_egs_free(EGSContext **pctx)
if (!ctx)
return;
+ mg2di_residual_calc_free(&ctx->priv->rescalc);
+
free(ctx->solver_data);
if (ctx->solver_type == EGS_SOLVER_EXACT) {
@@ -936,7 +776,6 @@ void mg2di_egs_free(EGSContext **pctx)
free(ctx->priv->u_base);
free(ctx->priv->rhs_base);
free(ctx->priv->residual_base);
- free(ctx->priv->residual_max);
for (int i = 0; i < ARRAY_ELEMS(ctx->priv->diff_coeffs_base); i++)
free(ctx->priv->diff_coeffs_base[i]);
for (int i = 0; i < ARRAY_ELEMS(ctx->boundaries); i++)
diff --git a/meson.build b/meson.build
index c672c28..fb67c79 100644
--- a/meson.build
+++ b/meson.build
@@ -9,6 +9,7 @@ lib_src = [
'ell_grid_solve.c',
'log.c',
'mg2d.c',
+ 'residual_calc.c',
]
lib_obj = []
diff --git a/residual_calc.c b/residual_calc.c
new file mode 100644
index 0000000..2fc1a66
--- /dev/null
+++ b/residual_calc.c
@@ -0,0 +1,292 @@
+/*
+ * Copyright 2019 Anton Khirnov <anton@khirnov.net>
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program. If not, see <http://www.gnu.org/licenses/>.
+ */
+
+#include "config.h"
+
+#include <errno.h>
+#include <math.h>
+#include <stddef.h>
+#include <stdlib.h>
+#include <string.h>
+
+#include <threadpool.h>
+
+#include "common.h"
+#include "cpu.h"
+#include "mg2d_constants.h"
+#include "residual_calc.h"
+
+typedef struct ResidualCalcTask {
+ size_t line_size;
+ ptrdiff_t stride;
+
+ double *dst;
+ const double *u;
+ const double *rhs;
+ const double * const *diff_coeffs;
+ const double *fd_factors;
+} ResidualCalcTask;
+
+struct ResidualCalcInternal {
+ double *residual_max;
+ size_t residual_max_size;
+
+ void (*residual_calc_line)(size_t linesize, double *dst, double *dst_max,
+ ptrdiff_t stride, const double *u, const double *rhs,
+ const double * const diff_coeffs[MG2D_DIFF_COEFF_NB],
+ const double *fd_factors);
+ size_t calc_blocksize;
+
+ ResidualCalcTask task;
+};
+
+#if HAVE_EXTERNAL_ASM
+void mg2di_residual_calc_line_s1_fma3(size_t linesize, double *dst, double *dst_max,
+ ptrdiff_t stride, const double *u, const double *rhs,
+ const double * const diff_coeffs[MG2D_DIFF_COEFF_NB],
+ const double *fd_factors);
+void mg2di_residual_calc_line_s2_fma3(size_t linesize, double *dst, double *dst_max,
+ ptrdiff_t stride, const double *u, const double *rhs,
+ const double * const diff_coeffs[MG2D_DIFF_COEFF_NB],
+ const double *fd_factors);
+#endif
+
+static void
+derivatives_calc_s1(double *dst, const double *u, const double *fd_factors, ptrdiff_t stride)
+{
+ dst[MG2D_DIFF_COEFF_00] = u[0];
+ dst[MG2D_DIFF_COEFF_10] = (u[1] - u[-1]) * fd_factors[MG2D_DIFF_COEFF_10];
+ dst[MG2D_DIFF_COEFF_01] = (u[stride] - u[-stride]) * fd_factors[MG2D_DIFF_COEFF_01];
+
+ dst[MG2D_DIFF_COEFF_20] = (u[1] - 2.0 * u[0] + u[-1]) * fd_factors[MG2D_DIFF_COEFF_20];
+ dst[MG2D_DIFF_COEFF_02] = (u[stride] - 2.0 * u[0] + u[-stride]) * fd_factors[MG2D_DIFF_COEFF_02];
+
+ dst[MG2D_DIFF_COEFF_11] = (u[1 + stride] - u[stride - 1] - u[-stride + 1] + u[-stride - 1]) * fd_factors[MG2D_DIFF_COEFF_11];
+}
+
+static void
+derivatives_calc_s2(double *dst, const double *u, const double *fd_factors, ptrdiff_t stride)
+{
+ const double val = u[0];
+
+ const double valxp1 = u[ 1];
+ const double valxp2 = u[ 2];
+ const double valxm1 = u[-1];
+ const double valxm2 = u[-2];
+ const double valyp1 = u[ 1 * stride];
+ const double valyp2 = u[ 2 * stride];
+ const double valym1 = u[-1 * stride];
+ const double valym2 = u[-2 * stride];
+
+ const double valxp1yp1 = u[ 1 + 1 * stride];
+ const double valxp1yp2 = u[ 1 + 2 * stride];
+ const double valxp1ym1 = u[ 1 - 1 * stride];
+ const double valxp1ym2 = u[ 1 - 2 * stride];
+
+ const double valxp2yp1 = u[ 2 + 1 * stride];
+ const double valxp2yp2 = u[ 2 + 2 * stride];
+ const double valxp2ym1 = u[ 2 - 1 * stride];
+ const double valxp2ym2 = u[ 2 - 2 * stride];
+
+ const double valxm1yp1 = u[-1 + 1 * stride];
+ const double valxm1yp2 = u[-1 + 2 * stride];
+ const double valxm1ym1 = u[-1 - 1 * stride];
+ const double valxm1ym2 = u[-1 - 2 * stride];
+
+ const double valxm2yp1 = u[-2 + 1 * stride];
+ const double valxm2yp2 = u[-2 + 2 * stride];
+ const double valxm2ym1 = u[-2 - 1 * stride];
+ const double valxm2ym2 = u[-2 - 2 * stride];
+
+ dst[MG2D_DIFF_COEFF_00] = val;
+ dst[MG2D_DIFF_COEFF_10] = (-1.0 * valxp2 + 8.0 * valxp1 - 8.0 * valxm1 + 1.0 * valxm2) * fd_factors[MG2D_DIFF_COEFF_10];
+ dst[MG2D_DIFF_COEFF_01] = (-1.0 * valyp2 + 8.0 * valyp1 - 8.0 * valym1 + 1.0 * valym2) * fd_factors[MG2D_DIFF_COEFF_01];
+
+ dst[MG2D_DIFF_COEFF_20] = (-1.0 * valxp2 + 16.0 * valxp1 - 30.0 * val + 16.0 * valxm1 - 1.0 * valxm2) * fd_factors[MG2D_DIFF_COEFF_20];
+ dst[MG2D_DIFF_COEFF_02] = (-1.0 * valyp2 + 16.0 * valyp1 - 30.0 * val + 16.0 * valym1 - 1.0 * valym2) * fd_factors[MG2D_DIFF_COEFF_02];
+
+ dst[MG2D_DIFF_COEFF_11] = ( 1.0 * valxp2yp2 - 8.0 * valxp2yp1 + 8.0 * valxp2ym1 - 1.0 * valxp2ym2
+ -8.0 * valxp1yp2 + 64.0 * valxp1yp1 - 64.0 * valxp1ym1 + 8.0 * valxp1ym2
+ +8.0 * valxm1yp2 - 64.0 * valxm1yp1 + 64.0 * valxm1ym1 - 8.0 * valxm1ym2
+ -1.0 * valxm2yp2 + 8.0 * valxm2yp1 - 8.0 * valxm2ym1 + 1.0 * valxm2ym2) * fd_factors[MG2D_DIFF_COEFF_11];
+}
+
+static void residual_calc_line_s1_c(size_t linesize, double *dst, double *dst_max,
+ ptrdiff_t stride, const double *u, const double *rhs,
+ const double * const diff_coeffs[MG2D_DIFF_COEFF_NB],
+ const double *fd_factors)
+{
+ double res_max = 0.0, res_abs;
+ for (size_t i = 0; i < linesize; i++) {
+ double u_vals[MG2D_DIFF_COEFF_NB];
+ double res;
+
+ derivatives_calc_s1(u_vals, u + i, fd_factors, stride);
+
+ res = -rhs[i];
+ for (int j = 0; j < ARRAY_ELEMS(u_vals); j++)
+ res += u_vals[j] * diff_coeffs[j][i];
+ dst[i] = res;
+
+ res_abs = fabs(res);
+ res_max = MAX(res_max, res_abs);
+ }
+
+ *dst_max = MAX(*dst_max, res_max);
+}
+
+static void residual_calc_line_s2_c(size_t linesize, double *dst, double *dst_max,
+ ptrdiff_t stride, const double *u, const double *rhs,
+ const double * const diff_coeffs[MG2D_DIFF_COEFF_NB],
+ const double *fd_factors)
+{
+ double res_max = 0.0, res_abs;
+ for (size_t i = 0; i < linesize; i++) {
+ double u_vals[MG2D_DIFF_COEFF_NB];
+ double res;
+
+ derivatives_calc_s2(u_vals, u + i, fd_factors, stride);
+
+ res = -rhs[i];
+ for (int j = 0; j < ARRAY_ELEMS(u_vals); j++)
+ res += u_vals[j] * diff_coeffs[j][i];
+ dst[i] = res;
+
+ res_abs = fabs(res);
+ res_max = MAX(res_max, res_abs);
+ }
+
+ *dst_max = MAX(*dst_max, res_max);
+}
+
+static int residual_calc_task(void *arg, unsigned int job_idx, unsigned int thread_idx)
+{
+ ResidualCalcInternal *priv = arg;
+ ResidualCalcTask *task = &priv->task;
+
+ const ptrdiff_t offset = job_idx * task->stride;
+ const double *diff_coeffs[MG2D_DIFF_COEFF_NB];
+
+ for (int i = 0; i < ARRAY_ELEMS(diff_coeffs); i++)
+ diff_coeffs[i] = task->diff_coeffs[i] + offset;
+
+ priv->residual_calc_line(task->line_size, task->dst + offset,
+ priv->residual_max + thread_idx * priv->calc_blocksize,
+ task->stride, task->u + offset, task->rhs + offset,
+ diff_coeffs, task->fd_factors);
+
+ return 0;
+}
+
+int mg2di_residual_calc(ResidualCalcContext *ctx, size_t size[2], ptrdiff_t stride,
+ double *residual_max,
+ double *dst, const double *u, const double *rhs,
+ const double * const diff_coeffs[MG2D_DIFF_COEFF_NB],
+ const double *fd_factors)
+{
+ ResidualCalcInternal *priv = ctx->priv;
+ ResidualCalcTask *task = &priv->task;
+ double res_max = 0.0;
+
+ memset(priv->residual_max, 0, sizeof(*priv->residual_max) * priv->residual_max_size);
+
+ task->line_size = size[0];
+ task->stride = stride;
+ task->dst = dst;
+ task->u = u;
+ task->rhs = rhs;
+ task->diff_coeffs = diff_coeffs;
+ task->fd_factors = fd_factors;
+
+ tp_execute(ctx->tp, size[1], residual_calc_task, priv);
+
+ for (size_t i = 0; i < priv->residual_max_size; i++)
+ res_max = MAX(res_max, priv->residual_max[i]);
+ *residual_max = res_max;
+
+ return 0;
+}
+
+int mg2di_residual_calc_init(ResidualCalcContext *ctx)
+{
+ ResidualCalcInternal *priv = ctx->priv;
+ double *tmp;
+
+ priv->calc_blocksize = 1;
+ switch (ctx->fd_stencil) {
+ case 1:
+ priv->residual_calc_line = residual_calc_line_s1_c;
+#if HAVE_EXTERNAL_ASM
+ if (ctx->cpuflags & MG2DI_CPU_FLAG_FMA3) {
+ priv->residual_calc_line = mg2di_residual_calc_line_s1_fma3;
+ priv->calc_blocksize = 4;
+ }
+#endif
+ break;
+ case 2:
+ priv->residual_calc_line = residual_calc_line_s2_c;
+#if HAVE_EXTERNAL_ASM
+ if (ctx->cpuflags & MG2DI_CPU_FLAG_FMA3) {
+ priv->residual_calc_line = mg2di_residual_calc_line_s2_fma3;
+ priv->calc_blocksize = 4;
+ }
+#endif
+ break;
+ }
+
+ priv->residual_max_size = tp_get_nb_threads(ctx->tp) * priv->calc_blocksize;
+ tmp = realloc(priv->residual_max,
+ sizeof(*priv->residual_max) * priv->residual_max_size);
+ if (!tmp) {
+ priv->residual_max_size = 0;
+ return -ENOMEM;
+ }
+ priv->residual_max = tmp;
+
+ return 0;
+}
+
+ResidualCalcContext *mg2di_residual_calc_alloc(void)
+{
+ ResidualCalcContext *ctx;
+
+ ctx = calloc(1, sizeof(*ctx));
+ if (!ctx)
+ return NULL;
+
+ ctx->priv = calloc(1, sizeof(*ctx->priv));
+ if (!ctx->priv) {
+ free(ctx);
+ return NULL;
+ }
+
+ return ctx;
+}
+
+void mg2di_residual_calc_free(ResidualCalcContext **pctx)
+{
+ ResidualCalcContext *ctx = *pctx;
+
+ if (!ctx)
+ return;
+
+ free(ctx->priv->residual_max);
+ free(ctx->priv);
+
+ free(ctx);
+ *pctx = NULL;
+}
diff --git a/residual_calc.h b/residual_calc.h
new file mode 100644
index 0000000..af1bf39
--- /dev/null
+++ b/residual_calc.h
@@ -0,0 +1,77 @@
+/*
+ * Copyright 2019 Anton Khirnov <anton@khirnov.net>
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program. If not, see <http://www.gnu.org/licenses/>.
+ */
+
+#ifndef MG2D_RESIDUAL_CALC_H
+#define MG2D_RESIDUAL_CALC_H
+
+#include <stddef.h>
+
+#include <threadpool.h>
+
+typedef struct ResidualCalcInternal ResidualCalcInternal;
+
+/**
+ * Context for computing the residual, allocated by mg2di_residual_calc_alloc(),
+ * freed by mg2di_residual_calc_free()
+ */
+typedef struct ResidualCalcContext {
+ ResidualCalcInternal *priv;
+
+ /**
+ * The thread pool, must be set by the caller before
+ * mg2di_residual_calc_init().
+ */
+ TPContext *tp;
+
+ /**
+ * FD stencil size, must be set by the caller before
+ * mg2di_residual_calc_init().
+ */
+ size_t fd_stencil;
+
+ /**
+ * Flags indicating supported CPU features, must be set by the caller
+ * before mg2di_residual_calc_init().
+ */
+ int cpuflags;
+} ResidualCalcContext;
+
+/**
+ * Allocate and retur a new ResidualCalcContext.
+ */
+ResidualCalcContext *mg2di_residual_calc_alloc(void);
+/**
+ * Free a ResidualCalcContext and write NULL into the supplied pointer.
+ */
+void mg2di_residual_calc_free(ResidualCalcContext **ctx);
+
+/**
+ * Reinitialize the context for updated parameters. Must be called at least once
+ * before mg2di_residual_calc().
+ */
+int mg2di_residual_calc_init(ResidualCalcContext *ctx);
+
+/**
+ * Calculate the residual.
+ */
+int mg2di_residual_calc(ResidualCalcContext *ctx, size_t size[2], ptrdiff_t stride,
+ double *res_max,
+ double *dst, const double *u, const double *rhs,
+ const double * const diff_coeffs[MG2D_DIFF_COEFF_NB],
+ const double *fd_factors);
+
+#endif // MG2D_RESIDUAL_CALC_H