From b38434e2c0e43ba0eff35a561d5642e7959a6162 Mon Sep 17 00:00:00 2001 From: Anton Khirnov Date: Wed, 10 Apr 2019 14:43:16 +0200 Subject: transfer: implement transfer_add --- mg2d.c | 50 +++------------- transfer.c | 138 +++++++++++++++++++++++++++++++++++++++------ transfer.h | 2 +- transfer_interp.asm | 24 +++++++- transfer_interp_template.c | 22 ++++++-- 5 files changed, 169 insertions(+), 67 deletions(-) diff --git a/mg2d.c b/mg2d.c index 11a8158..05bc10f 100644 --- a/mg2d.c +++ b/mg2d.c @@ -45,9 +45,6 @@ typedef struct MG2DLevel { GridTransferContext *transfer_restrict; GridTransferContext *transfer_prolong; - NDArray *prolong_tmp_base; - NDArray *prolong_tmp; - struct MG2DLevel *child; /* timings */ @@ -55,7 +52,6 @@ typedef struct MG2DLevel { int64_t time_relax; int64_t time_prolong; int64_t time_restrict; - int64_t time_correct; int64_t time_reinit; } MG2DLevel; @@ -98,18 +94,6 @@ static void log_relax_step(MG2DContext *ctx, MG2DLevel *level, const char *step_ prefix, level->depth, step_desc, res_old, res_new, res_old / res_new); } -static int coarse_correct_task(void *arg, unsigned int job_idx, unsigned int thread_idx) -{ - MG2DLevel *level = arg; - - for (size_t idx0 = 0; idx0 < level->solver->domain_size[0]; idx0++) { - const ptrdiff_t idx_dst = job_idx * level->solver->u->stride[0] + idx0; - const ptrdiff_t idx_src = job_idx * level->prolong_tmp->stride[0] + idx0; - level->solver->u->data[idx_dst] -= level->prolong_tmp->data[idx_src]; - } - return 0; -} - static int mg_relax_step(MG2DContext *ctx, MG2DLevel *level, const char *step_desc) { double res_old; @@ -176,7 +160,7 @@ static int mg_solve_subgrid(MG2DContext *ctx, MG2DLevel *level) /* restrict the residual as to the coarser-level rhs */ start = gettime(); ret = mg2di_gt_transfer(level->transfer_restrict, level->child->solver->rhs, - level->solver->residual); + level->solver->residual, 0); if (ret < 0) return ret; level->time_restrict += gettime() - start; @@ -188,19 +172,12 @@ static int mg_solve_subgrid(MG2DContext *ctx, MG2DLevel *level) /* prolongate the coarser-level correction */ start = gettime(); - ret = mg2di_gt_transfer(level->transfer_prolong, level->prolong_tmp, - level->child->solver->u); + ret = mg2di_gt_transfer(level->transfer_prolong, level->solver->u, + level->child->solver->u, -1.0); if (ret < 0) return ret; level->time_prolong += gettime() - start; - /* apply the correction */ - start = gettime(); - - tp_execute(ctx->priv->tp, level->solver->domain_size[1], coarse_correct_task, level); - - level->time_correct += gettime() - start; - /* re-init the current-level solver (re-calc the residual) */ res_prev = level->solver->residual_max; start = gettime(); @@ -280,7 +257,7 @@ static int restrict_diff_coeffs(MG2DContext *ctx, enum GridTransferOperator op, goto finish; for (int i = 0; i < MG2D_DIFF_COEFF_NB; i++) { - ret = mg2di_gt_transfer(tc, dst->diff_coeffs[i], src->diff_coeffs[i]); + ret = mg2di_gt_transfer(tc, dst->diff_coeffs[i], src->diff_coeffs[i], 0); if (ret < 0) goto finish; } @@ -566,8 +543,6 @@ static void mg_level_free(MG2DLevel **plevel) if (!level) return; - mg2di_ndarray_free(&level->prolong_tmp); - mg2di_ndarray_free(&level->prolong_tmp_base); mg2di_egs_free(&level->solver); mg2di_gt_free(&level->transfer_restrict); @@ -586,14 +561,6 @@ static MG2DLevel *mg_level_alloc(enum EGSType type, const size_t domain_size) if (!level) return NULL; - ret = mg2di_ndarray_alloc(&level->prolong_tmp_base, 2, (size_t [2]){domain_size + 1, domain_size + 1}, 0); - if (ret < 0) - goto fail; - - ret = mg2di_ndarray_slice(&level->prolong_tmp, level->prolong_tmp_base, (Slice [2]){ SLICE(0, -1, 1), SLICE(0, -1, 1) }); - if (ret < 0) - goto fail; - level->solver = mg2di_egs_alloc(type, (size_t [2]){domain_size, domain_size}); if (!level->solver) @@ -791,7 +758,7 @@ void mg2d_print_stats(MG2DContext *ctx, const char *prefix) EGSExactContext *e = NULL; int64_t level_total = level->time_relax + level->time_prolong + level->time_restrict + - level->time_correct + level->time_reinit; + level->time_reinit; if (level->solver->solver_type == EGS_SOLVER_RELAXATION) r = level->solver->solver_data; @@ -799,7 +766,7 @@ void mg2d_print_stats(MG2DContext *ctx, const char *prefix) e = level->solver->solver_data; level_total = level->time_relax + level->time_prolong + level->time_restrict + - level->time_correct + level->time_reinit; + level->time_reinit; levels_total += level_total; @@ -814,11 +781,10 @@ void mg2d_print_stats(MG2DContext *ctx, const char *prefix) if (level->child) { ret = snprintf(p, sizeof(buf) - (p - buf), - "||%2.2f%% relax %2.2f%% prolong %2.2f%% restrict %2.2f%% correct %2.2f%% reinit", + "||%2.2f%% relax %2.2f%% prolong %2.2f%% restrict %2.2f%% reinit", level->time_relax * 100.0 / level_total, level->time_prolong * 100.0 / level_total, level->time_restrict * 100.0 / level_total, - level->time_correct * 100.0 / level_total, level->time_reinit * 100.0 / level_total); if (ret > 0) p += ret; @@ -926,7 +892,7 @@ int mg2d_init_guess(MG2DContext *ctx, const double *src, if (ret < 0) return ret; - ret = mg2di_gt_transfer(priv->transfer_init, priv->u ? priv->u : priv->root->solver->u, a_src); + ret = mg2di_gt_transfer(priv->transfer_init, priv->u ? priv->u : priv->root->solver->u, a_src, 0); mg2di_ndarray_free(&a_src); return ret; diff --git a/transfer.c b/transfer.c index db3cecb..7a7f1c9 100644 --- a/transfer.c +++ b/transfer.c @@ -35,7 +35,7 @@ typedef struct GridTransfer { size_t priv_data_size; int (*init)(GridTransferContext *ctx); void (*free)(GridTransferContext *ctx); - int (*transfer)(GridTransferContext *ctx, NDArray *dst, const NDArray *src); + int (*transfer)(GridTransferContext *ctx, NDArray *dst, const NDArray *src, double fact); } GridTransfer; typedef struct GridTransferLagrange { @@ -48,9 +48,17 @@ typedef struct GridTransferLagrange { const double *src, ptrdiff_t src_stride, const ptrdiff_t *idx_x, const double *fact_x, const double *fact_y, ptrdiff_t dst_stride0, ptrdiff_t src_stride0); + void (*transfer_add_cont) (double *dst, ptrdiff_t dst_len, + const double *src, ptrdiff_t src_stride, + const ptrdiff_t *idx_x, const double *fact_x, const double *fact_y, double fact); + void (*transfer_add_generic)(double *dst, ptrdiff_t dst_len, + const double *src, ptrdiff_t src_stride, + const ptrdiff_t *idx_x, const double *fact_x, const double *fact_y, double fact, + ptrdiff_t dst_stride0, ptrdiff_t src_stride0); const NDArray *src; NDArray *dst; + double fact_res; ptrdiff_t *idx[2]; double *fact[2]; @@ -65,9 +73,14 @@ void mg2di_transfer_interp_line_cont_6_fma3(double *dst, ptrdiff_t dst_len, const double *src, ptrdiff_t src_stride, const ptrdiff_t *idx_x, const double *fact_x, const double *fact_y); +void mg2di_transfer_interp_line_add_cont_6_fma3(double *dst, ptrdiff_t dst_len, + const double *src, ptrdiff_t src_stride, + const ptrdiff_t *idx_x, + const double *fact_x, const double *fact_y, double fact); #endif #define STENCIL 2 +#define ADD 0 #define CONTIGUOUS 0 #include "transfer_interp_template.c" #undef CONTIGUOUS @@ -75,9 +88,21 @@ void mg2di_transfer_interp_line_cont_6_fma3(double *dst, ptrdiff_t dst_len, #define CONTIGUOUS 1 #include "transfer_interp_template.c" #undef CONTIGUOUS +#undef ADD + +#define ADD 1 +#define CONTIGUOUS 0 +#include "transfer_interp_template.c" +#undef CONTIGUOUS + +#define CONTIGUOUS 1 +#include "transfer_interp_template.c" +#undef CONTIGUOUS +#undef ADD #undef STENCIL #define STENCIL 4 +#define ADD 0 #define CONTIGUOUS 0 #include "transfer_interp_template.c" #undef CONTIGUOUS @@ -85,9 +110,21 @@ void mg2di_transfer_interp_line_cont_6_fma3(double *dst, ptrdiff_t dst_len, #define CONTIGUOUS 1 #include "transfer_interp_template.c" #undef CONTIGUOUS +#undef ADD + +#define ADD 1 +#define CONTIGUOUS 0 +#include "transfer_interp_template.c" +#undef CONTIGUOUS + +#define CONTIGUOUS 1 +#include "transfer_interp_template.c" +#undef CONTIGUOUS +#undef ADD #undef STENCIL #define STENCIL 6 +#define ADD 0 #define CONTIGUOUS 0 #include "transfer_interp_template.c" #undef CONTIGUOUS @@ -95,9 +132,21 @@ void mg2di_transfer_interp_line_cont_6_fma3(double *dst, ptrdiff_t dst_len, #define CONTIGUOUS 1 #include "transfer_interp_template.c" #undef CONTIGUOUS +#undef ADD + +#define ADD 1 +#define CONTIGUOUS 0 +#include "transfer_interp_template.c" +#undef CONTIGUOUS + +#define CONTIGUOUS 1 +#include "transfer_interp_template.c" +#undef CONTIGUOUS +#undef ADD #undef STENCIL #define STENCIL 8 +#define ADD 0 #define CONTIGUOUS 0 #include "transfer_interp_template.c" #undef CONTIGUOUS @@ -105,6 +154,17 @@ void mg2di_transfer_interp_line_cont_6_fma3(double *dst, ptrdiff_t dst_len, #define CONTIGUOUS 1 #include "transfer_interp_template.c" #undef CONTIGUOUS +#undef ADD + +#define ADD 1 +#define CONTIGUOUS 0 +#include "transfer_interp_template.c" +#undef CONTIGUOUS + +#define CONTIGUOUS 1 +#include "transfer_interp_template.c" +#undef CONTIGUOUS +#undef ADD #undef STENCIL // generate the interpolation source indices and weights @@ -156,12 +216,16 @@ static int transfer_lagrange_init(GridTransferContext *ctx) switch (ctx->op) { case GRID_TRANSFER_LAGRANGE_1: priv->stencil = 2; - priv->transfer_cont = interp_transfer_line_cont_2; - priv->transfer_generic = interp_transfer_line_generic_2; + priv->transfer_cont = interp_transfer_line_store_cont_2; + priv->transfer_generic = interp_transfer_line_store_generic_2; + priv->transfer_add_cont = interp_transfer_line_add_cont_2; + priv->transfer_add_generic = interp_transfer_line_add_generic_2; break; case GRID_TRANSFER_LAGRANGE_3: - priv->transfer_cont = interp_transfer_line_cont_4; - priv->transfer_generic = interp_transfer_line_generic_4; + priv->transfer_cont = interp_transfer_line_store_cont_4; + priv->transfer_generic = interp_transfer_line_store_generic_4; + priv->transfer_add_cont = interp_transfer_line_add_cont_4; + priv->transfer_add_generic = interp_transfer_line_add_generic_4; priv->stencil = 4; #if HAVE_EXTERNAL_ASM @@ -171,18 +235,23 @@ static int transfer_lagrange_init(GridTransferContext *ctx) #endif break; case GRID_TRANSFER_LAGRANGE_5: - priv->transfer_cont = interp_transfer_line_cont_6; - priv->transfer_generic = interp_transfer_line_generic_6; + priv->transfer_cont = interp_transfer_line_store_cont_6; + priv->transfer_generic = interp_transfer_line_store_generic_6; + priv->transfer_add_cont = interp_transfer_line_add_cont_6; + priv->transfer_add_generic = interp_transfer_line_add_generic_6; priv->stencil = 6; #if HAVE_EXTERNAL_ASM if (ctx->cpuflags & MG2DI_CPU_FLAG_FMA3) { - priv->transfer_cont = mg2di_transfer_interp_line_cont_6_fma3; + priv->transfer_cont = mg2di_transfer_interp_line_cont_6_fma3; + priv->transfer_add_cont = mg2di_transfer_interp_line_add_cont_6_fma3; } #endif break; case GRID_TRANSFER_LAGRANGE_7: - priv->transfer_cont = interp_transfer_line_cont_8; - priv->transfer_generic = interp_transfer_line_generic_8; + priv->transfer_cont = interp_transfer_line_store_cont_8; + priv->transfer_generic = interp_transfer_line_store_generic_8; + priv->transfer_add_cont = interp_transfer_line_add_cont_8; + priv->transfer_add_generic = interp_transfer_line_add_generic_8; priv->stencil = 8; break; default: @@ -238,15 +307,41 @@ static int transfer_interp_task(void *arg, unsigned int job_idx, unsigned int th return 0; } +static int transfer_interp_task_add(void *arg, unsigned int job_idx, unsigned int thread_idx) +{ + GridTransferLagrange *priv = arg; + const NDArray *src = priv->src; + NDArray *dst = priv->dst; + + const ptrdiff_t idx_y = priv->idx[0][job_idx]; + const double *fact_y = priv->fact[0] + priv->stencil * job_idx; + + if (dst->stride[1] == 1 && src->stride[1] == 1) { + priv->transfer_add_cont(dst->data + job_idx * dst->stride[0], dst->shape[1], + src->data + idx_y * src->stride[0], src->stride[0], + priv->idx[1], priv->fact[1], fact_y, priv->fact_res); + } else { + priv->transfer_add_generic(dst->data + job_idx * dst->stride[0], dst->shape[1], + src->data + idx_y * src->stride[0], src->stride[0], + priv->idx[1], priv->fact[1], fact_y, priv->fact_res, dst->stride[1], src->stride[1]); + } + + return 0; +} + static int transfer_lagrange_transfer(GridTransferContext *ctx, - NDArray *dst, const NDArray *src) + NDArray *dst, const NDArray *src, double fact) { GridTransferLagrange *priv = ctx->priv; priv->src = src; priv->dst = dst; + priv->fact_res = fact; - tp_execute(ctx->tp, ctx->dst.size[0], transfer_interp_task, priv); + if (fact == 0.0) + tp_execute(ctx->tp, ctx->dst.size[0], transfer_interp_task, priv); + else + tp_execute(ctx->tp, ctx->dst.size[0], transfer_interp_task_add, priv); priv->src = NULL; priv->dst = NULL; @@ -295,10 +390,13 @@ static int fw_1_transfer_task(void *arg, unsigned int job_idx, unsigned int thre return 0; } static int transfer_fw_1_transfer(GridTransferContext *ctx, - NDArray *dst, const NDArray *src) + NDArray *dst, const NDArray *src, double fact) { FWThreadData td = { dst, src }; + if (fact != 0.0) + return -EINVAL; + tp_execute(ctx->tp, ctx->dst.size[0], fw_1_transfer_task, &td); return 0; @@ -337,10 +435,13 @@ static int fw_3_transfer_task(void *arg, unsigned int job_idx, unsigned int thre return 0; } static int transfer_fw_3_transfer(GridTransferContext *ctx, - NDArray *dst, const NDArray *src) + NDArray *dst, const NDArray *src, double fact) { FWThreadData td = { dst, src }; + if (fact != 0.0) + return -EINVAL; + tp_execute(ctx->tp, ctx->dst.size[0], fw_3_transfer_task, &td); return 0; @@ -379,10 +480,13 @@ static int fw_5_transfer_task(void *arg, unsigned int job_idx, unsigned int thre return 0; } static int transfer_fw_5_transfer(GridTransferContext *ctx, - NDArray *dst, const NDArray *src) + NDArray *dst, const NDArray *src, double fact) { FWThreadData td = { dst, src }; + if (fact != 0.0) + return -EINVAL; + tp_execute(ctx->tp, ctx->dst.size[0], fw_5_transfer_task, &td); return 0; @@ -404,7 +508,7 @@ static const GridTransfer *transfer_ops[] = { }; int mg2di_gt_transfer(GridTransferContext *ctx, - NDArray *dst, const NDArray *src) + NDArray *dst, const NDArray *src, double fact) { const GridTransfer *t = transfer_ops[ctx->op]; @@ -413,7 +517,7 @@ int mg2di_gt_transfer(GridTransferContext *ctx, src->shape[0] != ctx->src.size[0] || src->shape[1] != ctx->src.size[1]) return -EINVAL; - return t->transfer(ctx, dst, src); + return t->transfer(ctx, dst, src, fact); } GridTransferContext *mg2di_gt_alloc(enum GridTransferOperator op) diff --git a/transfer.h b/transfer.h index b2133f4..249272e 100644 --- a/transfer.h +++ b/transfer.h @@ -57,6 +57,6 @@ int mg2di_gt_init(GridTransferContext *ctx); void mg2di_gt_free(GridTransferContext **ctx); int mg2di_gt_transfer(GridTransferContext *ctx, - NDArray *dst, const NDArray *src); + NDArray *dst, const NDArray *src, double fact); #endif // MG2D_TRANSFER_H diff --git a/transfer_interp.asm b/transfer_interp.asm index 1b1fe7d..982b8aa 100644 --- a/transfer_interp.asm +++ b/transfer_interp.asm @@ -73,9 +73,11 @@ cglobal transfer_interp_line_cont_4, 7, 8, 6, dst, dst_len, src, src_stride, idx RET -INIT_YMM fma3 -cglobal transfer_interp_line_cont_6, 7, 9, 11, dst, dst_len, src, src_stride, idx_x, fact_x, fact_y,\ - idx_x_val, offset6 +%macro INTERP6 1 +%if %1 + mova m12, m0 +%endif + shl src_strideq, 3 shl dst_lenq, 3 @@ -145,9 +147,25 @@ cglobal transfer_interp_line_cont_6, 7, 9, 11, dst, dst_len, src, src_stride, id haddpd xm8, xm8 addpd m8, m9 +%if %1 + movq xm13, [dstq + offsetq] + vfmadd213pd xm8, xm12, xm13 + movq [dstq + offsetq], xm8 +%else movq [dstq + offsetq], xm8 +%endif add offsetq, 8 add offset6q, 8 * 6 js .loop RET +%endmacro + +INIT_YMM fma3 +cglobal transfer_interp_line_cont_6, 7, 9, 11, dst, dst_len, src, src_stride, idx_x, fact_x, fact_y,\ + idx_x_val, offset6 +INTERP6 0 + +cglobal transfer_interp_line_add_cont_6, 7, 9, 13, dst, dst_len, src, src_stride, idx_x, fact_x, fact_y,\ + idx_x_val, offset6 +INTERP6 1 diff --git a/transfer_interp_template.c b/transfer_interp_template.c index 65ae98b..f97f816 100644 --- a/transfer_interp_template.c +++ b/transfer_interp_template.c @@ -22,14 +22,23 @@ # define CONT generic #endif -#define JOIN3(a, b, c) a ## _ ## b ## _ ## c -#define FUNC2(name, cont, stencil) JOIN3(name, cont, stencil) -#define FUNC(name) FUNC2(name, CONT, STENCIL) +#if ADD +# define ACT add +#else +# define ACT store +#endif + +#define JOIN4(a, b, c, d) a ## _ ## b ## _ ## c ## _ ## d +#define FUNC2(name, act, cont, stencil) JOIN4(name, act, cont, stencil) +#define FUNC(name) FUNC2(name, ACT, CONT, STENCIL) static void FUNC(interp_transfer_line)(double *dst, ptrdiff_t dst_len, const double *src, ptrdiff_t src_stride, const ptrdiff_t *idx_x, const double *fact_x, const double *fact_y +#if ADD + , double fact +#endif #if !CONTIGUOUS , ptrdiff_t dst_stride0, ptrdiff_t src_stride0 # define SSTRIDE1 src_stride0 @@ -52,13 +61,18 @@ FUNC(interp_transfer_line)(double *dst, ptrdiff_t dst_len, val += tmp * fact_y[idx0]; } +#if ADD + dst[x * DSTRIDE1] += fact * val; +#else dst[x * DSTRIDE1] = val; +#endif } } #undef SSTRIDE1 #undef DSTRIDE1 #undef CONT -#undef JOIN3 +#undef ACT +#undef JOIN4 #undef FUNC2 #undef FUNC -- cgit v1.2.3