From bd178d67da6a8c30b3ccbd020be1b00f42eceb53 Mon Sep 17 00:00:00 2001 From: Anton Khirnov Date: Fri, 22 Mar 2019 19:02:28 +0100 Subject: mg2d: use ndarray for internal arrays --- mg2d.c | 61 +++++++++++++++++++++++++++++++------------------------------ 1 file changed, 31 insertions(+), 30 deletions(-) (limited to 'mg2d.c') diff --git a/mg2d.c b/mg2d.c index a9b98a3..531b061 100644 --- a/mg2d.c +++ b/mg2d.c @@ -61,6 +61,10 @@ struct MG2DInternal { MG2DLevel *root; + NDArray *u; + NDArray *rhs; + NDArray *diff_coeffs[MG2D_DIFF_COEFF_NB]; + int cpuflags; int64_t time_solve; @@ -509,14 +513,6 @@ finish: return 0; } -static void array_copy(double *dst, ptrdiff_t dst_stride, - const double *src, ptrdiff_t src_stride, - size_t linesize, size_t nb_lines) -{ - for (size_t i = 0; i < nb_lines; i++) - memcpy(dst + i * dst_stride, src + i * src_stride, linesize * sizeof(*dst)); -} - static void bnd_zero(MG2DBoundary *bdst, size_t nb_rows, size_t domain_size) { for (size_t i = 0; i < nb_rows; i++) { @@ -581,13 +577,10 @@ static int mg_levels_init(MG2DContext *ctx) cur = priv->root; prev = NULL; - array_copy(cur->solver->u->data, cur->solver->u->stride[0], ctx->u, ctx->u_stride, ctx->domain_size, ctx->domain_size); - array_copy(cur->solver->rhs->data, cur->solver->rhs->stride[0], ctx->rhs, ctx->rhs_stride, ctx->domain_size, ctx->domain_size); - for (int i = 0; i < ARRAY_ELEMS(ctx->diff_coeffs); i++) { - array_copy(cur->solver->diff_coeffs[i]->data, cur->solver->diff_coeffs[i]->stride[0], - ctx->diff_coeffs[i], ctx->diff_coeffs_stride, - ctx->domain_size, ctx->domain_size); - } + mg2di_ndarray_copy(cur->solver->u, priv->u); + mg2di_ndarray_copy(cur->solver->rhs, priv->rhs); + for (int i = 0; i < ARRAY_ELEMS(ctx->diff_coeffs); i++) + mg2di_ndarray_copy(cur->solver->diff_coeffs[i], priv->diff_coeffs[i]); while (cur) { if (!prev) { @@ -710,7 +703,7 @@ int mg2d_solve(MG2DContext *ctx) mg2di_log(&priv->logger, MG2D_LOG_INFO, "converged on iteration %d, residual %g\n", i, res_cur); - array_copy(ctx->u, ctx->u_stride, s_root->u->data, s_root->u->stride[0], ctx->domain_size, ctx->domain_size); + mg2di_ndarray_copy(priv->u, s_root->u); priv->time_solve += gettime() - time_start; priv->count_solve++; @@ -839,6 +832,7 @@ MG2DContext *mg2d_solver_alloc(size_t domain_size) { MG2DContext *ctx; MG2DInternal *priv; + int ret; if (domain_size < 3 || SIZE_MAX / domain_size < domain_size) return NULL; @@ -864,22 +858,28 @@ MG2DContext *mg2d_solver_alloc(size_t domain_size) } } - ctx->u = calloc(SQR(domain_size), sizeof(*ctx->u)); - if (!ctx->u) + ret = mg2di_ndarray_alloc(&priv->u, 2, (size_t [2]){ domain_size, domain_size }, + NDARRAY_ALLOC_ZERO); + if (ret < 0) goto fail; - ctx->u_stride = domain_size; + ctx->u = priv->u->data; + ctx->u_stride = priv->u->stride[0]; - ctx->rhs = calloc(SQR(domain_size), sizeof(*ctx->rhs)); - if (!ctx->rhs) + ret = mg2di_ndarray_alloc(&priv->rhs, 2, (size_t [2]){ domain_size, domain_size }, + NDARRAY_ALLOC_ZERO); + if (ret < 0) goto fail; - ctx->rhs_stride = domain_size; + ctx->rhs = priv->rhs->data; + ctx->rhs_stride = priv->rhs->stride[0]; for (int i = 0; i < ARRAY_ELEMS(ctx->diff_coeffs); i++) { - ctx->diff_coeffs[i] = calloc(SQR(domain_size), sizeof(*ctx->diff_coeffs[i])); - if (!ctx->diff_coeffs[i]) + ret = mg2di_ndarray_alloc(&priv->diff_coeffs[i], 2, (size_t [2]){ domain_size, domain_size }, + NDARRAY_ALLOC_ZERO); + if (ret < 0) goto fail; + ctx->diff_coeffs[i] = priv->diff_coeffs[i]->data; } - ctx->diff_coeffs_stride = domain_size; + ctx->diff_coeffs_stride = priv->diff_coeffs[0]->stride[0]; ctx->domain_size = domain_size; @@ -917,12 +917,13 @@ void mg2d_solver_free(MG2DContext **pctx) mg2di_bc_free(&ctx->boundaries[i]); tp_free(&ctx->priv->tp); - free(ctx->priv); - free(ctx->u); - free(ctx->rhs); - for (int i = 0; i < ARRAY_ELEMS(ctx->diff_coeffs); i++) - free(ctx->diff_coeffs[i]); + mg2di_ndarray_free(&ctx->priv->u); + mg2di_ndarray_free(&ctx->priv->rhs); + for (int i = 0; i < ARRAY_ELEMS(ctx->priv->diff_coeffs); i++) + mg2di_ndarray_free(&ctx->priv->diff_coeffs[i]); + + free(ctx->priv); free(ctx); *pctx = NULL; -- cgit v1.2.3