summaryrefslogtreecommitdiff
path: root/transfer.c
diff options
context:
space:
mode:
authorAnton Khirnov <anton@khirnov.net>2019-04-10 14:43:16 +0200
committerAnton Khirnov <anton@khirnov.net>2019-04-10 14:43:16 +0200
commitb38434e2c0e43ba0eff35a561d5642e7959a6162 (patch)
treeb2fb13ded4825a9f05b3f34efaea85f1db5b1a97 /transfer.c
parent552bcf4c906522c3ef7695654052f61e12260049 (diff)
transfer: implement transfer_addtransfer_add
Diffstat (limited to 'transfer.c')
-rw-r--r--transfer.c138
1 files changed, 121 insertions, 17 deletions
diff --git a/transfer.c b/transfer.c
index db3cecb..7a7f1c9 100644
--- a/transfer.c
+++ b/transfer.c
@@ -35,7 +35,7 @@ typedef struct GridTransfer {
size_t priv_data_size;
int (*init)(GridTransferContext *ctx);
void (*free)(GridTransferContext *ctx);
- int (*transfer)(GridTransferContext *ctx, NDArray *dst, const NDArray *src);
+ int (*transfer)(GridTransferContext *ctx, NDArray *dst, const NDArray *src, double fact);
} GridTransfer;
typedef struct GridTransferLagrange {
@@ -48,9 +48,17 @@ typedef struct GridTransferLagrange {
const double *src, ptrdiff_t src_stride,
const ptrdiff_t *idx_x, const double *fact_x, const double *fact_y,
ptrdiff_t dst_stride0, ptrdiff_t src_stride0);
+ void (*transfer_add_cont) (double *dst, ptrdiff_t dst_len,
+ const double *src, ptrdiff_t src_stride,
+ const ptrdiff_t *idx_x, const double *fact_x, const double *fact_y, double fact);
+ void (*transfer_add_generic)(double *dst, ptrdiff_t dst_len,
+ const double *src, ptrdiff_t src_stride,
+ const ptrdiff_t *idx_x, const double *fact_x, const double *fact_y, double fact,
+ ptrdiff_t dst_stride0, ptrdiff_t src_stride0);
const NDArray *src;
NDArray *dst;
+ double fact_res;
ptrdiff_t *idx[2];
double *fact[2];
@@ -65,9 +73,14 @@ void mg2di_transfer_interp_line_cont_6_fma3(double *dst, ptrdiff_t dst_len,
const double *src, ptrdiff_t src_stride,
const ptrdiff_t *idx_x,
const double *fact_x, const double *fact_y);
+void mg2di_transfer_interp_line_add_cont_6_fma3(double *dst, ptrdiff_t dst_len,
+ const double *src, ptrdiff_t src_stride,
+ const ptrdiff_t *idx_x,
+ const double *fact_x, const double *fact_y, double fact);
#endif
#define STENCIL 2
+#define ADD 0
#define CONTIGUOUS 0
#include "transfer_interp_template.c"
#undef CONTIGUOUS
@@ -75,9 +88,21 @@ void mg2di_transfer_interp_line_cont_6_fma3(double *dst, ptrdiff_t dst_len,
#define CONTIGUOUS 1
#include "transfer_interp_template.c"
#undef CONTIGUOUS
+#undef ADD
+
+#define ADD 1
+#define CONTIGUOUS 0
+#include "transfer_interp_template.c"
+#undef CONTIGUOUS
+
+#define CONTIGUOUS 1
+#include "transfer_interp_template.c"
+#undef CONTIGUOUS
+#undef ADD
#undef STENCIL
#define STENCIL 4
+#define ADD 0
#define CONTIGUOUS 0
#include "transfer_interp_template.c"
#undef CONTIGUOUS
@@ -85,9 +110,21 @@ void mg2di_transfer_interp_line_cont_6_fma3(double *dst, ptrdiff_t dst_len,
#define CONTIGUOUS 1
#include "transfer_interp_template.c"
#undef CONTIGUOUS
+#undef ADD
+
+#define ADD 1
+#define CONTIGUOUS 0
+#include "transfer_interp_template.c"
+#undef CONTIGUOUS
+
+#define CONTIGUOUS 1
+#include "transfer_interp_template.c"
+#undef CONTIGUOUS
+#undef ADD
#undef STENCIL
#define STENCIL 6
+#define ADD 0
#define CONTIGUOUS 0
#include "transfer_interp_template.c"
#undef CONTIGUOUS
@@ -95,9 +132,21 @@ void mg2di_transfer_interp_line_cont_6_fma3(double *dst, ptrdiff_t dst_len,
#define CONTIGUOUS 1
#include "transfer_interp_template.c"
#undef CONTIGUOUS
+#undef ADD
+
+#define ADD 1
+#define CONTIGUOUS 0
+#include "transfer_interp_template.c"
+#undef CONTIGUOUS
+
+#define CONTIGUOUS 1
+#include "transfer_interp_template.c"
+#undef CONTIGUOUS
+#undef ADD
#undef STENCIL
#define STENCIL 8
+#define ADD 0
#define CONTIGUOUS 0
#include "transfer_interp_template.c"
#undef CONTIGUOUS
@@ -105,6 +154,17 @@ void mg2di_transfer_interp_line_cont_6_fma3(double *dst, ptrdiff_t dst_len,
#define CONTIGUOUS 1
#include "transfer_interp_template.c"
#undef CONTIGUOUS
+#undef ADD
+
+#define ADD 1
+#define CONTIGUOUS 0
+#include "transfer_interp_template.c"
+#undef CONTIGUOUS
+
+#define CONTIGUOUS 1
+#include "transfer_interp_template.c"
+#undef CONTIGUOUS
+#undef ADD
#undef STENCIL
// generate the interpolation source indices and weights
@@ -156,12 +216,16 @@ static int transfer_lagrange_init(GridTransferContext *ctx)
switch (ctx->op) {
case GRID_TRANSFER_LAGRANGE_1:
priv->stencil = 2;
- priv->transfer_cont = interp_transfer_line_cont_2;
- priv->transfer_generic = interp_transfer_line_generic_2;
+ priv->transfer_cont = interp_transfer_line_store_cont_2;
+ priv->transfer_generic = interp_transfer_line_store_generic_2;
+ priv->transfer_add_cont = interp_transfer_line_add_cont_2;
+ priv->transfer_add_generic = interp_transfer_line_add_generic_2;
break;
case GRID_TRANSFER_LAGRANGE_3:
- priv->transfer_cont = interp_transfer_line_cont_4;
- priv->transfer_generic = interp_transfer_line_generic_4;
+ priv->transfer_cont = interp_transfer_line_store_cont_4;
+ priv->transfer_generic = interp_transfer_line_store_generic_4;
+ priv->transfer_add_cont = interp_transfer_line_add_cont_4;
+ priv->transfer_add_generic = interp_transfer_line_add_generic_4;
priv->stencil = 4;
#if HAVE_EXTERNAL_ASM
@@ -171,18 +235,23 @@ static int transfer_lagrange_init(GridTransferContext *ctx)
#endif
break;
case GRID_TRANSFER_LAGRANGE_5:
- priv->transfer_cont = interp_transfer_line_cont_6;
- priv->transfer_generic = interp_transfer_line_generic_6;
+ priv->transfer_cont = interp_transfer_line_store_cont_6;
+ priv->transfer_generic = interp_transfer_line_store_generic_6;
+ priv->transfer_add_cont = interp_transfer_line_add_cont_6;
+ priv->transfer_add_generic = interp_transfer_line_add_generic_6;
priv->stencil = 6;
#if HAVE_EXTERNAL_ASM
if (ctx->cpuflags & MG2DI_CPU_FLAG_FMA3) {
- priv->transfer_cont = mg2di_transfer_interp_line_cont_6_fma3;
+ priv->transfer_cont = mg2di_transfer_interp_line_cont_6_fma3;
+ priv->transfer_add_cont = mg2di_transfer_interp_line_add_cont_6_fma3;
}
#endif
break;
case GRID_TRANSFER_LAGRANGE_7:
- priv->transfer_cont = interp_transfer_line_cont_8;
- priv->transfer_generic = interp_transfer_line_generic_8;
+ priv->transfer_cont = interp_transfer_line_store_cont_8;
+ priv->transfer_generic = interp_transfer_line_store_generic_8;
+ priv->transfer_add_cont = interp_transfer_line_add_cont_8;
+ priv->transfer_add_generic = interp_transfer_line_add_generic_8;
priv->stencil = 8;
break;
default:
@@ -238,15 +307,41 @@ static int transfer_interp_task(void *arg, unsigned int job_idx, unsigned int th
return 0;
}
+static int transfer_interp_task_add(void *arg, unsigned int job_idx, unsigned int thread_idx)
+{
+ GridTransferLagrange *priv = arg;
+ const NDArray *src = priv->src;
+ NDArray *dst = priv->dst;
+
+ const ptrdiff_t idx_y = priv->idx[0][job_idx];
+ const double *fact_y = priv->fact[0] + priv->stencil * job_idx;
+
+ if (dst->stride[1] == 1 && src->stride[1] == 1) {
+ priv->transfer_add_cont(dst->data + job_idx * dst->stride[0], dst->shape[1],
+ src->data + idx_y * src->stride[0], src->stride[0],
+ priv->idx[1], priv->fact[1], fact_y, priv->fact_res);
+ } else {
+ priv->transfer_add_generic(dst->data + job_idx * dst->stride[0], dst->shape[1],
+ src->data + idx_y * src->stride[0], src->stride[0],
+ priv->idx[1], priv->fact[1], fact_y, priv->fact_res, dst->stride[1], src->stride[1]);
+ }
+
+ return 0;
+}
+
static int transfer_lagrange_transfer(GridTransferContext *ctx,
- NDArray *dst, const NDArray *src)
+ NDArray *dst, const NDArray *src, double fact)
{
GridTransferLagrange *priv = ctx->priv;
priv->src = src;
priv->dst = dst;
+ priv->fact_res = fact;
- tp_execute(ctx->tp, ctx->dst.size[0], transfer_interp_task, priv);
+ if (fact == 0.0)
+ tp_execute(ctx->tp, ctx->dst.size[0], transfer_interp_task, priv);
+ else
+ tp_execute(ctx->tp, ctx->dst.size[0], transfer_interp_task_add, priv);
priv->src = NULL;
priv->dst = NULL;
@@ -295,10 +390,13 @@ static int fw_1_transfer_task(void *arg, unsigned int job_idx, unsigned int thre
return 0;
}
static int transfer_fw_1_transfer(GridTransferContext *ctx,
- NDArray *dst, const NDArray *src)
+ NDArray *dst, const NDArray *src, double fact)
{
FWThreadData td = { dst, src };
+ if (fact != 0.0)
+ return -EINVAL;
+
tp_execute(ctx->tp, ctx->dst.size[0], fw_1_transfer_task, &td);
return 0;
@@ -337,10 +435,13 @@ static int fw_3_transfer_task(void *arg, unsigned int job_idx, unsigned int thre
return 0;
}
static int transfer_fw_3_transfer(GridTransferContext *ctx,
- NDArray *dst, const NDArray *src)
+ NDArray *dst, const NDArray *src, double fact)
{
FWThreadData td = { dst, src };
+ if (fact != 0.0)
+ return -EINVAL;
+
tp_execute(ctx->tp, ctx->dst.size[0], fw_3_transfer_task, &td);
return 0;
@@ -379,10 +480,13 @@ static int fw_5_transfer_task(void *arg, unsigned int job_idx, unsigned int thre
return 0;
}
static int transfer_fw_5_transfer(GridTransferContext *ctx,
- NDArray *dst, const NDArray *src)
+ NDArray *dst, const NDArray *src, double fact)
{
FWThreadData td = { dst, src };
+ if (fact != 0.0)
+ return -EINVAL;
+
tp_execute(ctx->tp, ctx->dst.size[0], fw_5_transfer_task, &td);
return 0;
@@ -404,7 +508,7 @@ static const GridTransfer *transfer_ops[] = {
};
int mg2di_gt_transfer(GridTransferContext *ctx,
- NDArray *dst, const NDArray *src)
+ NDArray *dst, const NDArray *src, double fact)
{
const GridTransfer *t = transfer_ops[ctx->op];
@@ -413,7 +517,7 @@ int mg2di_gt_transfer(GridTransferContext *ctx,
src->shape[0] != ctx->src.size[0] || src->shape[1] != ctx->src.size[1])
return -EINVAL;
- return t->transfer(ctx, dst, src);
+ return t->transfer(ctx, dst, src, fact);
}
GridTransferContext *mg2di_gt_alloc(enum GridTransferOperator op)