summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author Anton Khirnov <anton@khirnov.net> 2019-04-10 14:43:16 +0200
committer Anton Khirnov <anton@khirnov.net> 2019-04-10 14:43:16 +0200
commit b38434e2c0e43ba0eff35a561d5642e7959a6162 (patch)
tree b2fb13ded4825a9f05b3f34efaea85f1db5b1a97
parent 552bcf4c906522c3ef7695654052f61e12260049 (diff)
transfer: implement transfer_addtransfer_add
-rw-r--r-- mg2d.c 50
-rw-r--r-- transfer.c 138
-rw-r--r-- transfer.h 2
-rw-r--r-- transfer_interp.asm 24
-rw-r--r-- transfer_interp_template.c 22
5 files changed, 169 insertions, 67 deletions
diff --git a/mg2d.c b/mg2d.c
index 11a8158..05bc10f 100644
--- a/mg2d.c
+++ b/mg2d.c
@@ -45,9 +45,6 @@ typedef struct MG2DLevel {
GridTransferContext *transfer_restrict;
GridTransferContext *transfer_prolong;
- NDArray *prolong_tmp_base;
- NDArray *prolong_tmp;
-
struct MG2DLevel *child;
/* timings */
@@ -55,7 +52,6 @@ typedef struct MG2DLevel {
int64_t time_relax;
int64_t time_prolong;
int64_t time_restrict;
- int64_t time_correct;
int64_t time_reinit;
} MG2DLevel;
@@ -98,18 +94,6 @@ static void log_relax_step(MG2DContext *ctx, MG2DLevel *level, const char *step_
prefix, level->depth, step_desc, res_old, res_new, res_old / res_new);
}
-static int coarse_correct_task(void *arg, unsigned int job_idx, unsigned int thread_idx)
-{
- MG2DLevel *level = arg;
-
- for (size_t idx0 = 0; idx0 < level->solver->domain_size[0]; idx0++) {
- const ptrdiff_t idx_dst = job_idx * level->solver->u->stride[0] + idx0;
- const ptrdiff_t idx_src = job_idx * level->prolong_tmp->stride[0] + idx0;
- level->solver->u->data[idx_dst] -= level->prolong_tmp->data[idx_src];
- }
- return 0;
-}
-
static int mg_relax_step(MG2DContext *ctx, MG2DLevel *level, const char *step_desc)
{
double res_old;
@@ -176,7 +160,7 @@ static int mg_solve_subgrid(MG2DContext *ctx, MG2DLevel *level)
/* restrict the residual as to the coarser-level rhs */
start = gettime();
ret = mg2di_gt_transfer(level->transfer_restrict, level->child->solver->rhs,
- level->solver->residual);
+ level->solver->residual, 0);
if (ret < 0)
return ret;
level->time_restrict += gettime() - start;
@@ -188,19 +172,12 @@ static int mg_solve_subgrid(MG2DContext *ctx, MG2DLevel *level)
/* prolongate the coarser-level correction */
start = gettime();
- ret = mg2di_gt_transfer(level->transfer_prolong, level->prolong_tmp,
- level->child->solver->u);
+ ret = mg2di_gt_transfer(level->transfer_prolong, level->solver->u,
+ level->child->solver->u, -1.0);
if (ret < 0)
return ret;
level->time_prolong += gettime() - start;
- /* apply the correction */
- start = gettime();
-
- tp_execute(ctx->priv->tp, level->solver->domain_size[1], coarse_correct_task, level);
-
- level->time_correct += gettime() - start;
-
/* re-init the current-level solver (re-calc the residual) */
res_prev = level->solver->residual_max;
start = gettime();
@@ -280,7 +257,7 @@ static int restrict_diff_coeffs(MG2DContext *ctx, enum GridTransferOperator op,
goto finish;
for (int i = 0; i < MG2D_DIFF_COEFF_NB; i++) {
- ret = mg2di_gt_transfer(tc, dst->diff_coeffs[i], src->diff_coeffs[i]);
+ ret = mg2di_gt_transfer(tc, dst->diff_coeffs[i], src->diff_coeffs[i], 0);
if (ret < 0)
goto finish;
}
@@ -566,8 +543,6 @@ static void mg_level_free(MG2DLevel **plevel)
if (!level)
return;
- mg2di_ndarray_free(&level->prolong_tmp);
- mg2di_ndarray_free(&level->prolong_tmp_base);
mg2di_egs_free(&level->solver);
mg2di_gt_free(&level->transfer_restrict);
@@ -586,14 +561,6 @@ static MG2DLevel *mg_level_alloc(enum EGSType type, const size_t domain_size)
if (!level)
return NULL;
- ret = mg2di_ndarray_alloc(&level->prolong_tmp_base, 2, (size_t [2]){domain_size + 1, domain_size + 1}, 0);
- if (ret < 0)
- goto fail;
-
- ret = mg2di_ndarray_slice(&level->prolong_tmp, level->prolong_tmp_base, (Slice [2]){ SLICE(0, -1, 1), SLICE(0, -1, 1) });
- if (ret < 0)
- goto fail;
-
level->solver = mg2di_egs_alloc(type,
(size_t [2]){domain_size, domain_size});
if (!level->solver)
@@ -791,7 +758,7 @@ void mg2d_print_stats(MG2DContext *ctx, const char *prefix)
EGSExactContext *e = NULL;
int64_t level_total = level->time_relax + level->time_prolong + level->time_restrict +
- level->time_correct + level->time_reinit;
+ level->time_reinit;
if (level->solver->solver_type == EGS_SOLVER_RELAXATION)
r = level->solver->solver_data;
@@ -799,7 +766,7 @@ void mg2d_print_stats(MG2DContext *ctx, const char *prefix)
e = level->solver->solver_data;
level_total = level->time_relax + level->time_prolong + level->time_restrict +
- level->time_correct + level->time_reinit;
+ level->time_reinit;
levels_total += level_total;
@@ -814,11 +781,10 @@ void mg2d_print_stats(MG2DContext *ctx, const char *prefix)
if (level->child) {
ret = snprintf(p, sizeof(buf) - (p - buf),
- "||%2.2f%% relax %2.2f%% prolong %2.2f%% restrict %2.2f%% correct %2.2f%% reinit",
+ "||%2.2f%% relax %2.2f%% prolong %2.2f%% restrict %2.2f%% reinit",
level->time_relax * 100.0 / level_total,
level->time_prolong * 100.0 / level_total,
level->time_restrict * 100.0 / level_total,
- level->time_correct * 100.0 / level_total,
level->time_reinit * 100.0 / level_total);
if (ret > 0)
p += ret;
@@ -926,7 +892,7 @@ int mg2d_init_guess(MG2DContext *ctx, const double *src,
if (ret < 0)
return ret;
- ret = mg2di_gt_transfer(priv->transfer_init, priv->u ? priv->u : priv->root->solver->u, a_src);
+ ret = mg2di_gt_transfer(priv->transfer_init, priv->u ? priv->u : priv->root->solver->u, a_src, 0);
mg2di_ndarray_free(&a_src);
return ret;
diff --git a/transfer.c b/transfer.c
index db3cecb..7a7f1c9 100644
--- a/transfer.c
+++ b/transfer.c
@@ -35,7 +35,7 @@ typedef struct GridTransfer {
size_t priv_data_size;
int (*init)(GridTransferContext *ctx);
void (*free)(GridTransferContext *ctx);
- int (*transfer)(GridTransferContext *ctx, NDArray *dst, const NDArray *src);
+ int (*transfer)(GridTransferContext *ctx, NDArray *dst, const NDArray *src, double fact);
} GridTransfer;
typedef struct GridTransferLagrange {
@@ -48,9 +48,17 @@ typedef struct GridTransferLagrange {
const double *src, ptrdiff_t src_stride,
const ptrdiff_t *idx_x, const double *fact_x, const double *fact_y,
ptrdiff_t dst_stride0, ptrdiff_t src_stride0);
+ void (*transfer_add_cont) (double *dst, ptrdiff_t dst_len,
+ const double *src, ptrdiff_t src_stride,
+ const ptrdiff_t *idx_x, const double *fact_x, const double *fact_y, double fact);
+ void (*transfer_add_generic)(double *dst, ptrdiff_t dst_len,
+ const double *src, ptrdiff_t src_stride,
+ const ptrdiff_t *idx_x, const double *fact_x, const double *fact_y, double fact,
+ ptrdiff_t dst_stride0, ptrdiff_t src_stride0);
const NDArray *src;
NDArray *dst;
+ double fact_res;
ptrdiff_t *idx[2];
double *fact[2];
@@ -65,9 +73,14 @@ void mg2di_transfer_interp_line_cont_6_fma3(double *dst, ptrdiff_t dst_len,
const double *src, ptrdiff_t src_stride,
const ptrdiff_t *idx_x,
const double *fact_x, const double *fact_y);
+void mg2di_transfer_interp_line_add_cont_6_fma3(double *dst, ptrdiff_t dst_len,
+ const double *src, ptrdiff_t src_stride,
+ const ptrdiff_t *idx_x,
+ const double *fact_x, const double *fact_y, double fact);
#endif
#define STENCIL 2
+#define ADD 0
#define CONTIGUOUS 0
#include "transfer_interp_template.c"
#undef CONTIGUOUS
@@ -75,9 +88,21 @@ void mg2di_transfer_interp_line_cont_6_fma3(double *dst, ptrdiff_t dst_len,
#define CONTIGUOUS 1
#include "transfer_interp_template.c"
#undef CONTIGUOUS
+#undef ADD
+
+#define ADD 1
+#define CONTIGUOUS 0
+#include "transfer_interp_template.c"
+#undef CONTIGUOUS
+
+#define CONTIGUOUS 1
+#include "transfer_interp_template.c"
+#undef CONTIGUOUS
+#undef ADD
#undef STENCIL
#define STENCIL 4
+#define ADD 0
#define CONTIGUOUS 0
#include "transfer_interp_template.c"
#undef CONTIGUOUS
@@ -85,9 +110,21 @@ void mg2di_transfer_interp_line_cont_6_fma3(double *dst, ptrdiff_t dst_len,
#define CONTIGUOUS 1
#include "transfer_interp_template.c"
#undef CONTIGUOUS
+#undef ADD
+
+#define ADD 1
+#define CONTIGUOUS 0
+#include "transfer_interp_template.c"
+#undef CONTIGUOUS
+
+#define CONTIGUOUS 1
+#include "transfer_interp_template.c"
+#undef CONTIGUOUS
+#undef ADD
#undef STENCIL
#define STENCIL 6
+#define ADD 0
#define CONTIGUOUS 0
#include "transfer_interp_template.c"
#undef CONTIGUOUS
@@ -95,9 +132,21 @@ void mg2di_transfer_interp_line_cont_6_fma3(double *dst, ptrdiff_t dst_len,
#define CONTIGUOUS 1
#include "transfer_interp_template.c"
#undef CONTIGUOUS
+#undef ADD
+
+#define ADD 1
+#define CONTIGUOUS 0
+#include "transfer_interp_template.c"
+#undef CONTIGUOUS
+
+#define CONTIGUOUS 1
+#include "transfer_interp_template.c"
+#undef CONTIGUOUS
+#undef ADD
#undef STENCIL
#define STENCIL 8
+#define ADD 0
#define CONTIGUOUS 0
#include "transfer_interp_template.c"
#undef CONTIGUOUS
@@ -105,6 +154,17 @@ void mg2di_transfer_interp_line_cont_6_fma3(double *dst, ptrdiff_t dst_len,
#define CONTIGUOUS 1
#include "transfer_interp_template.c"
#undef CONTIGUOUS
+#undef ADD
+
+#define ADD 1
+#define CONTIGUOUS 0
+#include "transfer_interp_template.c"
+#undef CONTIGUOUS
+
+#define CONTIGUOUS 1
+#include "transfer_interp_template.c"
+#undef CONTIGUOUS
+#undef ADD
#undef STENCIL
// generate the interpolation source indices and weights
@@ -156,12 +216,16 @@ static int transfer_lagrange_init(GridTransferContext *ctx)
switch (ctx->op) {
case GRID_TRANSFER_LAGRANGE_1:
priv->stencil = 2;
- priv->transfer_cont = interp_transfer_line_cont_2;
- priv->transfer_generic = interp_transfer_line_generic_2;
+ priv->transfer_cont = interp_transfer_line_store_cont_2;
+ priv->transfer_generic = interp_transfer_line_store_generic_2;
+ priv->transfer_add_cont = interp_transfer_line_add_cont_2;
+ priv->transfer_add_generic = interp_transfer_line_add_generic_2;
break;
case GRID_TRANSFER_LAGRANGE_3:
- priv->transfer_cont = interp_transfer_line_cont_4;
- priv->transfer_generic = interp_transfer_line_generic_4;
+ priv->transfer_cont = interp_transfer_line_store_cont_4;
+ priv->transfer_generic = interp_transfer_line_store_generic_4;
+ priv->transfer_add_cont = interp_transfer_line_add_cont_4;
+ priv->transfer_add_generic = interp_transfer_line_add_generic_4;
priv->stencil = 4;
#if HAVE_EXTERNAL_ASM
@@ -171,18 +235,23 @@ static int transfer_lagrange_init(GridTransferContext *ctx)
#endif
break;
case GRID_TRANSFER_LAGRANGE_5:
- priv->transfer_cont = interp_transfer_line_cont_6;
- priv->transfer_generic = interp_transfer_line_generic_6;
+ priv->transfer_cont = interp_transfer_line_store_cont_6;
+ priv->transfer_generic = interp_transfer_line_store_generic_6;
+ priv->transfer_add_cont = interp_transfer_line_add_cont_6;
+ priv->transfer_add_generic = interp_transfer_line_add_generic_6;
priv->stencil = 6;
#if HAVE_EXTERNAL_ASM
if (ctx->cpuflags & MG2DI_CPU_FLAG_FMA3) {
- priv->transfer_cont = mg2di_transfer_interp_line_cont_6_fma3;
+ priv->transfer_cont = mg2di_transfer_interp_line_cont_6_fma3;
+ priv->transfer_add_cont = mg2di_transfer_interp_line_add_cont_6_fma3;
}
#endif
break;
case GRID_TRANSFER_LAGRANGE_7:
- priv->transfer_cont = interp_transfer_line_cont_8;
- priv->transfer_generic = interp_transfer_line_generic_8;
+ priv->transfer_cont = interp_transfer_line_store_cont_8;
+ priv->transfer_generic = interp_transfer_line_store_generic_8;
+ priv->transfer_add_cont = interp_transfer_line_add_cont_8;
+ priv->transfer_add_generic = interp_transfer_line_add_generic_8;
priv->stencil = 8;
break;
default:
@@ -238,15 +307,41 @@ static int transfer_interp_task(void *arg, unsigned int job_idx, unsigned int th
return 0;
}
+static int transfer_interp_task_add(void *arg, unsigned int job_idx, unsigned int thread_idx)
+{
+ GridTransferLagrange *priv = arg;
+ const NDArray *src = priv->src;
+ NDArray *dst = priv->dst;
+ /* Thread-pool job: adds the interpolation of src, scaled by priv->fact_res, onto row job_idx of dst. */
+ const ptrdiff_t idx_y = priv->idx[0][job_idx]; /* source row offset used for this destination row */
+ const double *fact_y = priv->fact[0] + priv->stencil * job_idx; /* this row's block of priv->stencil weights */
+ /* Unit stride along dimension 1 in both arrays selects the contiguous kernel; otherwise pass the strides explicitly. */
+ if (dst->stride[1] == 1 && src->stride[1] == 1) {
+ priv->transfer_add_cont(dst->data + job_idx * dst->stride[0], dst->shape[1],
+ src->data + idx_y * src->stride[0], src->stride[0],
+ priv->idx[1], priv->fact[1], fact_y, priv->fact_res);
+ } else {
+ priv->transfer_add_generic(dst->data + job_idx * dst->stride[0], dst->shape[1],
+ src->data + idx_y * src->stride[0], src->stride[0],
+ priv->idx[1], priv->fact[1], fact_y, priv->fact_res, dst->stride[1], src->stride[1]);
+ }
+
+ return 0; /* always reports success; the kernels return void */
+}
+
static int transfer_lagrange_transfer(GridTransferContext *ctx,
- NDArray *dst, const NDArray *src)
+ NDArray *dst, const NDArray *src, double fact)
{
GridTransferLagrange *priv = ctx->priv;
priv->src = src;
priv->dst = dst;
+ priv->fact_res = fact;
- tp_execute(ctx->tp, ctx->dst.size[0], transfer_interp_task, priv);
+ if (fact == 0.0)
+ tp_execute(ctx->tp, ctx->dst.size[0], transfer_interp_task, priv);
+ else
+ tp_execute(ctx->tp, ctx->dst.size[0], transfer_interp_task_add, priv);
priv->src = NULL;
priv->dst = NULL;
@@ -295,10 +390,13 @@ static int fw_1_transfer_task(void *arg, unsigned int job_idx, unsigned int thre
return 0;
}
static int transfer_fw_1_transfer(GridTransferContext *ctx,
- NDArray *dst, const NDArray *src)
+ NDArray *dst, const NDArray *src, double fact)
{
FWThreadData td = { dst, src };
+ if (fact != 0.0)
+ return -EINVAL;
+
tp_execute(ctx->tp, ctx->dst.size[0], fw_1_transfer_task, &td);
return 0;
@@ -337,10 +435,13 @@ static int fw_3_transfer_task(void *arg, unsigned int job_idx, unsigned int thre
return 0;
}
static int transfer_fw_3_transfer(GridTransferContext *ctx,
- NDArray *dst, const NDArray *src)
+ NDArray *dst, const NDArray *src, double fact)
{
FWThreadData td = { dst, src };
+ if (fact != 0.0)
+ return -EINVAL;
+
tp_execute(ctx->tp, ctx->dst.size[0], fw_3_transfer_task, &td);
return 0;
@@ -379,10 +480,13 @@ static int fw_5_transfer_task(void *arg, unsigned int job_idx, unsigned int thre
return 0;
}
static int transfer_fw_5_transfer(GridTransferContext *ctx,
- NDArray *dst, const NDArray *src)
+ NDArray *dst, const NDArray *src, double fact)
{
FWThreadData td = { dst, src };
+ if (fact != 0.0)
+ return -EINVAL;
+
tp_execute(ctx->tp, ctx->dst.size[0], fw_5_transfer_task, &td);
return 0;
@@ -404,7 +508,7 @@ static const GridTransfer *transfer_ops[] = {
};
int mg2di_gt_transfer(GridTransferContext *ctx,
- NDArray *dst, const NDArray *src)
+ NDArray *dst, const NDArray *src, double fact)
{
const GridTransfer *t = transfer_ops[ctx->op];
@@ -413,7 +517,7 @@ int mg2di_gt_transfer(GridTransferContext *ctx,
src->shape[0] != ctx->src.size[0] || src->shape[1] != ctx->src.size[1])
return -EINVAL;
- return t->transfer(ctx, dst, src);
+ return t->transfer(ctx, dst, src, fact);
}
GridTransferContext *mg2di_gt_alloc(enum GridTransferOperator op)
diff --git a/transfer.h b/transfer.h
index b2133f4..249272e 100644
--- a/transfer.h
+++ b/transfer.h
@@ -57,6 +57,6 @@ int mg2di_gt_init(GridTransferContext *ctx);
void mg2di_gt_free(GridTransferContext **ctx);
int mg2di_gt_transfer(GridTransferContext *ctx,
- NDArray *dst, const NDArray *src);
+ NDArray *dst, const NDArray *src, double fact);
#endif // MG2D_TRANSFER_H
diff --git a/transfer_interp.asm b/transfer_interp.asm
index 1b1fe7d..982b8aa 100644
--- a/transfer_interp.asm
+++ b/transfer_interp.asm
@@ -73,9 +73,11 @@ cglobal transfer_interp_line_cont_4, 7, 8, 6, dst, dst_len, src, src_stride, idx
RET
-INIT_YMM fma3
-cglobal transfer_interp_line_cont_6, 7, 9, 11, dst, dst_len, src, src_stride, idx_x, fact_x, fact_y,\
- idx_x_val, offset6
+%macro INTERP6 1
+%if %1
+ mova m12, m0
+%endif
+
shl src_strideq, 3
shl dst_lenq, 3
@@ -145,9 +147,25 @@ cglobal transfer_interp_line_cont_6, 7, 9, 11, dst, dst_len, src, src_stride, id
haddpd xm8, xm8
addpd m8, m9
+%if %1
+ movq xm13, [dstq + offsetq]
+ vfmadd213pd xm8, xm12, xm13
+ movq [dstq + offsetq], xm8
+%else
movq [dstq + offsetq], xm8
+%endif
add offsetq, 8
add offset6q, 8 * 6
js .loop
RET
+%endmacro
+
+INIT_YMM fma3
+cglobal transfer_interp_line_cont_6, 7, 9, 11, dst, dst_len, src, src_stride, idx_x, fact_x, fact_y,\
+ idx_x_val, offset6
+INTERP6 0
+
+cglobal transfer_interp_line_add_cont_6, 7, 9, 14, dst, dst_len, src, src_stride, idx_x, fact_x, fact_y,\
+ idx_x_val, offset6
+INTERP6 1 ; add variant: m12 holds the copy of m0, m13 the loaded dst value, hence 14 vector registers declared
diff --git a/transfer_interp_template.c b/transfer_interp_template.c
index 65ae98b..f97f816 100644
--- a/transfer_interp_template.c
+++ b/transfer_interp_template.c
@@ -22,14 +22,23 @@
# define CONT generic
#endif
-#define JOIN3(a, b, c) a ## _ ## b ## _ ## c
-#define FUNC2(name, cont, stencil) JOIN3(name, cont, stencil)
-#define FUNC(name) FUNC2(name, CONT, STENCIL)
+#if ADD
+# define ACT add
+#else
+# define ACT store
+#endif
+
+#define JOIN4(a, b, c, d) a ## _ ## b ## _ ## c ## _ ## d
+#define FUNC2(name, act, cont, stencil) JOIN4(name, act, cont, stencil)
+#define FUNC(name) FUNC2(name, ACT, CONT, STENCIL)
static void
FUNC(interp_transfer_line)(double *dst, ptrdiff_t dst_len,
const double *src, ptrdiff_t src_stride,
const ptrdiff_t *idx_x, const double *fact_x, const double *fact_y
+#if ADD
+ , double fact
+#endif
#if !CONTIGUOUS
, ptrdiff_t dst_stride0, ptrdiff_t src_stride0
# define SSTRIDE1 src_stride0
@@ -52,13 +61,18 @@ FUNC(interp_transfer_line)(double *dst, ptrdiff_t dst_len,
val += tmp * fact_y[idx0];
}
+#if ADD
+ dst[x * DSTRIDE1] += fact * val;
+#else
dst[x * DSTRIDE1] = val;
+#endif
}
}
#undef SSTRIDE1
#undef DSTRIDE1
#undef CONT
-#undef JOIN3
+#undef ACT
+#undef JOIN4
#undef FUNC2
#undef FUNC