From d519d82a3e4b32944b77b1ae26cfefa45ec29d71 Mon Sep 17 00:00:00 2001 From: Anton Khirnov Date: Sat, 26 Jan 2019 10:19:30 +0100 Subject: ell_relax -> ell_grid_solve Generalize the API to allow for multiple solver types. This is done in preparation for the exact linear system inversion solver. --- ell_grid_solve.c | 647 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 647 insertions(+) create mode 100644 ell_grid_solve.c (limited to 'ell_grid_solve.c') diff --git a/ell_grid_solve.c b/ell_grid_solve.c new file mode 100644 index 0000000..191fc8f --- /dev/null +++ b/ell_grid_solve.c @@ -0,0 +1,647 @@ +/* + * Copyright 2018 Anton Khirnov + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +#include "config.h" + +#include +#include +#include +#include +#include +#include + +#include "common.h" +#include "cpu.h" +#include "ell_grid_solve.h" +#include "log.h" +#include "mg2d_boundary.h" +#include "mg2d_boundary_internal.h" +#include "mg2d_constants.h" + +static const double relax_factors[FD_STENCIL_MAX] = { + [0] = 1.0 / 5, + [1] = 1.0 / 7, +}; + +typedef struct EGSRelaxInternal { + double relax_factor; +} EGSRelaxInternal; + +struct EGSInternal { + ptrdiff_t stride; + double *u_base; + double *rhs_base; + double *residual_base; + double *diff_coeffs_base[MG2D_DIFF_COEFF_NB]; + + double *residual_max; + size_t residual_max_size; + + void (*residual_calc_line)(size_t linesize, double *dst, double *dst_max, + ptrdiff_t stride, const double *u, const double *rhs, + const double * const diff_coeffs[MG2D_DIFF_COEFF_NB], + const double *fd_factors); + size_t calc_blocksize; + ptrdiff_t residual_calc_offset; + size_t residual_calc_size[2]; + double fd_factors[MG2D_DIFF_COEFF_NB]; + + union { + EGSRelaxInternal r; + }; + + TPContext *tp_internal; +}; + +static const struct { + unsigned int stride_idx; + unsigned int is_upper; +} boundary_def[] = { + [MG2D_BOUNDARY_0L] = { + .stride_idx = 0, + .is_upper = 0 + }, + [MG2D_BOUNDARY_0U] = { + .stride_idx = 0, + .is_upper = 1, + }, + [MG2D_BOUNDARY_1L] = { + .stride_idx = 1, + .is_upper = 0, + }, + [MG2D_BOUNDARY_1U] = { + .stride_idx = 1, + .is_upper = 1, + }, +}; + +static const double fd_denoms[][MG2D_DIFF_COEFF_NB] = { + { + [MG2D_DIFF_COEFF_00] = 1.0, + [MG2D_DIFF_COEFF_10] = 2.0, + [MG2D_DIFF_COEFF_01] = 2.0, + [MG2D_DIFF_COEFF_20] = 1.0, + [MG2D_DIFF_COEFF_02] = 1.0, + [MG2D_DIFF_COEFF_11] = 4.0, + }, + { + [MG2D_DIFF_COEFF_00] = 1.0, + [MG2D_DIFF_COEFF_10] = 12.0, + [MG2D_DIFF_COEFF_01] = 12.0, + [MG2D_DIFF_COEFF_20] = 12.0, + [MG2D_DIFF_COEFF_02] = 12.0, + [MG2D_DIFF_COEFF_11] = 144.0, + }, +}; + +#if HAVE_EXTERNAL_ASM +void mg2di_residual_calc_line_s1_fma3(size_t linesize, double *dst, double *dst_max, + ptrdiff_t stride, const double *u, const double *rhs, + const double * const diff_coeffs[MG2D_DIFF_COEFF_NB], + const double *fd_factors); +void mg2di_residual_calc_line_s2_fma3(size_t linesize, double *dst, double *dst_max, + ptrdiff_t stride, const double *u, const double *rhs, + const double * const diff_coeffs[MG2D_DIFF_COEFF_NB], + const double *fd_factors); +#endif + +static void +derivatives_calc_s1(double *dst, const double *u, const double *fd_factors, ptrdiff_t stride) +{ + dst[MG2D_DIFF_COEFF_00] = u[0]; + dst[MG2D_DIFF_COEFF_10] = (u[1] - u[-1]) * fd_factors[MG2D_DIFF_COEFF_10]; + dst[MG2D_DIFF_COEFF_01] = (u[stride] - u[-stride]) * fd_factors[MG2D_DIFF_COEFF_01]; + + dst[MG2D_DIFF_COEFF_20] = (u[1] - 2.0 * u[0] + u[-1]) * fd_factors[MG2D_DIFF_COEFF_20]; + dst[MG2D_DIFF_COEFF_02] = (u[stride] - 2.0 * u[0] + u[-stride]) * fd_factors[MG2D_DIFF_COEFF_02]; + + dst[MG2D_DIFF_COEFF_11] = (u[1 + stride] - u[stride - 1] - u[-stride + 1] + u[-stride - 1]) * fd_factors[MG2D_DIFF_COEFF_11]; +} + +static void +derivatives_calc_s2(double *dst, const double *u, const double *fd_factors, ptrdiff_t stride) +{ + const double val = u[0]; + + const double valxp1 = u[ 1]; + const double valxp2 = u[ 2]; + const double valxm1 = u[-1]; + const double valxm2 = u[-2]; + const double valyp1 = u[ 1 * stride]; + const double valyp2 = u[ 2 * stride]; + const double valym1 = u[-1 * stride]; + const double valym2 = u[-2 * stride]; + + const double valxp1yp1 = u[ 1 + 1 * stride]; + const double valxp1yp2 = u[ 1 + 2 * stride]; + const double valxp1ym1 = u[ 1 - 1 * stride]; + const double valxp1ym2 = u[ 1 - 2 * stride]; + + const double valxp2yp1 = u[ 2 + 1 * stride]; + const double valxp2yp2 = u[ 2 + 2 * stride]; + const double valxp2ym1 = u[ 2 - 1 * stride]; + const double valxp2ym2 = u[ 2 - 2 * stride]; + + const double valxm1yp1 = u[-1 + 1 * stride]; + const double valxm1yp2 = u[-1 + 2 * stride]; + const double valxm1ym1 = u[-1 - 1 * stride]; + const double valxm1ym2 = u[-1 - 2 * stride]; + + const double valxm2yp1 = u[-2 + 1 * stride]; + const double valxm2yp2 = u[-2 + 2 * stride]; + const double valxm2ym1 = u[-2 - 1 * stride]; + const double valxm2ym2 = u[-2 - 2 * stride]; + + dst[MG2D_DIFF_COEFF_00] = val; + dst[MG2D_DIFF_COEFF_10] = (-1.0 * valxp2 + 8.0 * valxp1 - 8.0 * valxm1 + 1.0 * valxm2) * fd_factors[MG2D_DIFF_COEFF_10]; + dst[MG2D_DIFF_COEFF_01] = (-1.0 * valyp2 + 8.0 * valyp1 - 8.0 * valym1 + 1.0 * valym2) * fd_factors[MG2D_DIFF_COEFF_01]; + + dst[MG2D_DIFF_COEFF_20] = (-1.0 * valxp2 + 16.0 * valxp1 - 30.0 * val + 16.0 * valxm1 - 1.0 * valxm2) * fd_factors[MG2D_DIFF_COEFF_20]; + dst[MG2D_DIFF_COEFF_02] = (-1.0 * valyp2 + 16.0 * valyp1 - 30.0 * val + 16.0 * valym1 - 1.0 * valym2) * fd_factors[MG2D_DIFF_COEFF_02]; + + dst[MG2D_DIFF_COEFF_11] = ( 1.0 * valxp2yp2 - 8.0 * valxp2yp1 + 8.0 * valxp2ym1 - 1.0 * valxp2ym2 + -8.0 * valxp1yp2 + 64.0 * valxp1yp1 - 64.0 * valxp1ym1 + 8.0 * valxp1ym2 + +8.0 * valxm1yp2 - 64.0 * valxm1yp1 + 64.0 * valxm1ym1 - 8.0 * valxm1ym2 + -1.0 * valxm2yp2 + 8.0 * valxm2yp1 - 8.0 * valxm2ym1 + 1.0 * valxm2ym2) * fd_factors[MG2D_DIFF_COEFF_11]; +} + +static void residual_calc_line_s1_c(size_t linesize, double *dst, double *dst_max, + ptrdiff_t stride, const double *u, const double *rhs, + const double * const diff_coeffs[MG2D_DIFF_COEFF_NB], + const double *fd_factors) +{ + double res_max = 0.0, res_abs; + for (size_t i = 0; i < linesize; i++) { + double u_vals[MG2D_DIFF_COEFF_NB]; + double res; + + derivatives_calc_s1(u_vals, u + i, fd_factors, stride); + + res = -rhs[i]; + for (int j = 0; j < ARRAY_ELEMS(u_vals); j++) + res += u_vals[j] * diff_coeffs[j][i]; + dst[i] = res; + + res_abs = fabs(res); + res_max = MAX(res_max, res_abs); + } + + *dst_max = MAX(*dst_max, res_max); +} + +static void residual_calc_line_s2_c(size_t linesize, double *dst, double *dst_max, + ptrdiff_t stride, const double *u, const double *rhs, + const double * const diff_coeffs[MG2D_DIFF_COEFF_NB], + const double *fd_factors) +{ + double res_max = 0.0, res_abs; + for (size_t i = 0; i < linesize; i++) { + double u_vals[MG2D_DIFF_COEFF_NB]; + double res; + + derivatives_calc_s2(u_vals, u + i, fd_factors, stride); + + res = -rhs[i]; + for (int j = 0; j < ARRAY_ELEMS(u_vals); j++) + res += u_vals[j] * diff_coeffs[j][i]; + dst[i] = res; + + res_abs = fabs(res); + res_max = MAX(res_max, res_abs); + } + + *dst_max = MAX(*dst_max, res_max); +} + +static void residual_calc_task(void *arg, unsigned int job_idx, unsigned int thread_idx) +{ + EGSContext *ctx = arg; + EGSInternal *priv = ctx->priv; + const ptrdiff_t offset = priv->residual_calc_offset + job_idx * priv->stride; + const double *diff_coeffs[MG2D_DIFF_COEFF_NB]; + + for (int i = 0; i < ARRAY_ELEMS(diff_coeffs); i++) + diff_coeffs[i] = ctx->diff_coeffs[i] + offset; + + priv->residual_calc_line(priv->residual_calc_size[0], ctx->residual + offset, + priv->residual_max + thread_idx * priv->calc_blocksize, + priv->stride, ctx->u + offset, ctx->rhs + offset, + diff_coeffs, priv->fd_factors); +} + +static void residual_calc(EGSContext *ctx) +{ + EGSInternal *priv = ctx->priv; + double res_max = 0.0; + int64_t start; + + memset(priv->residual_max, 0, sizeof(*priv->residual_max) * priv->residual_max_size); + + start = gettime(); + + tp_execute(ctx->tp, priv->residual_calc_size[1], residual_calc_task, ctx); + + for (size_t i = 0; i < priv->residual_max_size; i++) + res_max = MAX(res_max, priv->residual_max[i]); + ctx->residual_max = res_max; + + ctx->time_res_calc += gettime() - start; + ctx->count_res++; +} + +static void boundaries_apply_fixval(double *dst, const ptrdiff_t dst_stride[2], + const double *src, ptrdiff_t src_stride, + size_t boundary_size) +{ + for (int j = 0; j < FD_STENCIL_MAX; j++) { + for (ptrdiff_t i = -j; i < (ptrdiff_t)boundary_size + j; i++) + dst[i * dst_stride[0]] = src[i]; + dst += dst_stride[1]; + src += src_stride; + } +} + +static void boundaries_apply_fixdiff(double *dst, const ptrdiff_t dst_stride[2], + const double *src, ptrdiff_t src_stride, + size_t boundary_size) +{ + for (size_t i = 0; i < boundary_size; i++) { + for (int j = 1; j <= FD_STENCIL_MAX; j++) + dst[dst_stride[1] * j] = dst[-dst_stride[1] * j]; + + dst += dst_stride[0]; + } +} + +static void boundaries_apply(EGSContext *ctx) +{ + EGSInternal *priv = ctx->priv; + const ptrdiff_t strides[2] = { 1, priv->stride }; + int64_t start; + + start = gettime(); + for (int i = 0; i < ARRAY_ELEMS(ctx->boundaries); i++) { + const int si = boundary_def[i].stride_idx; + const ptrdiff_t stride_boundary = strides[si]; + const ptrdiff_t stride_offset = strides[!si]; + const size_t size_boundary = ctx->domain_size[si]; + const size_t size_offset = ctx->domain_size[!si]; + + double *dst = ctx->u + boundary_def[i].is_upper * ((size_offset - 1) * stride_offset); + const ptrdiff_t dst_strides[] = { stride_boundary, + (boundary_def[i].is_upper ? 1 : -1) * stride_offset }; + + switch (ctx->boundaries[i]->type) { + case MG2D_BC_TYPE_FIXVAL: + boundaries_apply_fixval(dst, dst_strides, ctx->boundaries[i]->val, + ctx->boundaries[i]->val_stride, size_boundary); + break; + case MG2D_BC_TYPE_FIXDIFF: + boundaries_apply_fixdiff(dst, dst_strides, ctx->boundaries[i]->val, + ctx->boundaries[i]->val_stride, size_boundary); + break; + } + } + + /* fill in the corner ghosts */ + if (ctx->boundaries[MG2D_BOUNDARY_0L]->type == MG2D_BC_TYPE_FIXDIFF || + ctx->boundaries[MG2D_BOUNDARY_1L]->type == MG2D_BC_TYPE_FIXDIFF) { + double *dst = ctx->u; + int fact_x = -1, fact_z = -1; + + if (ctx->boundaries[MG2D_BOUNDARY_0L]->type == MG2D_BC_TYPE_FIXDIFF) + fact_z *= -1; + else + fact_x *= -1; + + for (int i = 1; i <= FD_STENCIL_MAX; i++) + for (int j = 1; j <= FD_STENCIL_MAX; j++) { + const ptrdiff_t idx_dst = -j * strides[1] - i; + const ptrdiff_t idx_src = fact_z * j * strides[1] + fact_x * i; + dst[idx_dst] = dst[idx_src]; + } + } + if (ctx->boundaries[MG2D_BOUNDARY_0L]->type == MG2D_BC_TYPE_FIXDIFF || + ctx->boundaries[MG2D_BOUNDARY_1U]->type == MG2D_BC_TYPE_FIXDIFF) { + double *dst = ctx->u + ctx->domain_size[0] - 1; + int fact_x = 1, fact_z = -1; + + if (ctx->boundaries[MG2D_BOUNDARY_0L]->type == MG2D_BC_TYPE_FIXDIFF) + fact_z *= -1; + else + fact_x *= -1; + + for (int i = 1; i <= FD_STENCIL_MAX; i++) + for (int j = 1; j <= FD_STENCIL_MAX; j++) { + const ptrdiff_t idx_dst = -j * strides[1] + i; + const ptrdiff_t idx_src = fact_z * j * strides[1] + fact_x * i; + dst[idx_dst] = dst[idx_src]; + } + } + if (ctx->boundaries[MG2D_BOUNDARY_1L]->type == MG2D_BC_TYPE_FIXDIFF || + ctx->boundaries[MG2D_BOUNDARY_0U]->type == MG2D_BC_TYPE_FIXDIFF) { + double *dst = ctx->u + strides[1] * (ctx->domain_size[1] - 1); + int fact_x = -1, fact_z = 1; + + if (ctx->boundaries[MG2D_BOUNDARY_0U]->type == MG2D_BC_TYPE_FIXDIFF) + fact_z *= -1; + else + fact_x *= -1; + + for (int i = 1; i <= FD_STENCIL_MAX; i++) + for (int j = 1; j <= FD_STENCIL_MAX; j++) { + const ptrdiff_t idx_dst = j * strides[1] - i; + const ptrdiff_t idx_src = fact_z * j * strides[1] + fact_x * i; + dst[idx_dst] = dst[idx_src]; + } + } + if (ctx->boundaries[MG2D_BOUNDARY_0U]->type == MG2D_BC_TYPE_FIXDIFF || + ctx->boundaries[MG2D_BOUNDARY_1U]->type == MG2D_BC_TYPE_FIXDIFF) { + double *dst = ctx->u + strides[1] * (ctx->domain_size[1] - 1) + ctx->domain_size[0] - 1; + int fact_x = 1, fact_z = 1; + + if (ctx->boundaries[MG2D_BOUNDARY_0U]->type == MG2D_BC_TYPE_FIXDIFF) + fact_z *= -1; + else + fact_x *= -1; + + for (int i = 1; i <= FD_STENCIL_MAX; i++) + for (int j = 1; j <= FD_STENCIL_MAX; j++) { + const ptrdiff_t idx_dst = j * strides[1] + i; + const ptrdiff_t idx_src = fact_z * j * strides[1] + fact_x * i; + dst[idx_dst] = dst[idx_src]; + } + } + ctx->time_boundaries += gettime() - start; + ctx->count_boundaries++; +} + +static void residual_add_task(void *arg, unsigned int job_idx, unsigned int thread_idx) +{ + EGSContext *ctx = arg; + EGSInternal *priv = ctx->priv; + ptrdiff_t offset = job_idx * priv->stride; + + for (int idx0 = 0; idx0 < ctx->domain_size[0]; idx0++) { + ptrdiff_t idx = job_idx * ctx->u_stride + idx0; + + ctx->u[idx] += priv->r.relax_factor * ctx->residual[idx]; + } +} + +static int solve_relax_step(EGSContext *ctx) +{ + EGSRelaxContext *r = ctx->solver_data; + EGSInternal *priv = ctx->priv; + int64_t start; + + start = gettime(); + + tp_execute(ctx->tp, ctx->domain_size[1], residual_add_task, ctx); + + r->time_correct += gettime() - start; + r->count_correct++; + + boundaries_apply(ctx); + residual_calc(ctx); + + return 0; +} +int mg2di_egs_solve(EGSContext *ctx) +{ + switch (ctx->solver_type) { + case EGS_SOLVER_RELAXATION: return solve_relax_step(ctx); + } + + return -EINVAL; +} + +int mg2di_egs_init(EGSContext *ctx) +{ + EGSInternal *priv = ctx->priv; + double *tmp; + int ret; + + priv->calc_blocksize = 1; + switch (ctx->fd_stencil) { + case 1: + priv->residual_calc_line = residual_calc_line_s1_c; +#if HAVE_EXTERNAL_ASM + if (ctx->cpuflags & MG2DI_CPU_FLAG_FMA3) { + priv->residual_calc_line = mg2di_residual_calc_line_s1_fma3; + priv->calc_blocksize = 4; + } +#endif + break; + case 2: + priv->residual_calc_line = residual_calc_line_s2_c; +#if HAVE_EXTERNAL_ASM + if (ctx->cpuflags & MG2DI_CPU_FLAG_FMA3) { + priv->residual_calc_line = mg2di_residual_calc_line_s2_fma3; + priv->calc_blocksize = 4; + } +#endif + break; + default: + mg2di_log(&ctx->logger, 0, "Invalid finite difference stencil: %zd\n", + ctx->fd_stencil); + return -EINVAL; + } + + if (ctx->step[0] <= DBL_EPSILON || ctx->step[1] <= DBL_EPSILON) { + mg2di_log(&ctx->logger, 0, "Spatial step size too small\n"); + return -EINVAL; + } + + if (ctx->solver_type == EGS_SOLVER_RELAXATION) { + EGSRelaxContext *r = ctx->solver_data; + if (r->relax_factor == 0.0) + priv->r.relax_factor = relax_factors[ctx->fd_stencil - 1]; + else + priv->r.relax_factor = r->relax_factor; + priv->r.relax_factor *= ctx->step[0] * ctx->step[0]; + } + + priv->fd_factors[MG2D_DIFF_COEFF_00] = 1.0 / fd_denoms[ctx->fd_stencil - 1][MG2D_DIFF_COEFF_00]; + priv->fd_factors[MG2D_DIFF_COEFF_10] = 1.0 / (fd_denoms[ctx->fd_stencil - 1][MG2D_DIFF_COEFF_10] * ctx->step[0]); + priv->fd_factors[MG2D_DIFF_COEFF_01] = 1.0 / (fd_denoms[ctx->fd_stencil - 1][MG2D_DIFF_COEFF_01] * ctx->step[1]); + priv->fd_factors[MG2D_DIFF_COEFF_20] = 1.0 / (fd_denoms[ctx->fd_stencil - 1][MG2D_DIFF_COEFF_20] * SQR(ctx->step[0])); + priv->fd_factors[MG2D_DIFF_COEFF_02] = 1.0 / (fd_denoms[ctx->fd_stencil - 1][MG2D_DIFF_COEFF_02] * SQR(ctx->step[1])); + priv->fd_factors[MG2D_DIFF_COEFF_11] = 1.0 / (fd_denoms[ctx->fd_stencil - 1][MG2D_DIFF_COEFF_11] * ctx->step[0] * ctx->step[1]); + + if (!ctx->tp) { + ret = tp_init(&priv->tp_internal, 1); + if (ret < 0) + return ret; + ctx->tp = priv->tp_internal; + } + + priv->residual_calc_offset = 0; + priv->residual_calc_size[0] = ctx->domain_size[0]; + priv->residual_calc_size[1] = ctx->domain_size[1]; + if (ctx->boundaries[MG2D_BOUNDARY_0L]->type == MG2D_BC_TYPE_FIXVAL) + priv->residual_calc_offset += ctx->residual_stride; + if (ctx->boundaries[MG2D_BOUNDARY_1L]->type == MG2D_BC_TYPE_FIXVAL) + priv->residual_calc_offset++; + + for (int i = 0; i < ARRAY_ELEMS(ctx->boundaries); i++) { + MG2DBoundary *bnd = ctx->boundaries[i]; + if (bnd->type == MG2D_BC_TYPE_FIXDIFF) { + for (int k = 0; k < ctx->domain_size[boundary_def[i].stride_idx]; k++) + if (bnd->val[k] != 0.0) { + mg2di_log(&ctx->logger, 0, "Only zero boundary derivative supported\n"); + return -ENOSYS; + } + } else if (bnd->type == MG2D_BC_TYPE_FIXVAL) { + priv->residual_calc_size[!boundary_def[i].stride_idx]--; + } + } + + priv->residual_max_size = tp_get_nb_threads(ctx->tp) * priv->calc_blocksize; + tmp = realloc(priv->residual_max, + sizeof(*priv->residual_max) * priv->residual_max_size); + if (!tmp) { + priv->residual_max_size = 0; + return -ENOMEM; + } + priv->residual_max = tmp; + + boundaries_apply(ctx); + residual_calc(ctx); + + return 0; +} + +static int arrays_alloc(EGSContext *ctx, const size_t domain_size[2]) +{ + EGSInternal *priv = ctx->priv; + + const size_t ghosts = FD_STENCIL_MAX; + const size_t size_padded[2] = { + domain_size[0] + 2 * ghosts, + domain_size[1] + 2 * ghosts, + }; + const size_t stride = size_padded[0]; + const size_t start_offset = ghosts * stride + ghosts; + const size_t arr_size = size_padded[0] * size_padded[1]; + int ret; + + ret = posix_memalign((void**)&priv->u_base, 32, sizeof(*priv->u_base) * arr_size); + if (ret != 0) + return -ret; + ctx->u = priv->u_base + start_offset; + ctx->u_stride = stride; + + ret = posix_memalign((void**)&priv->rhs_base, 32, sizeof(*priv->rhs_base) * arr_size); + if (ret != 0) + return -ret; + ctx->rhs = priv->rhs_base + start_offset; + ctx->rhs_stride = stride; + + ret = posix_memalign((void**)&priv->residual_base, 32, sizeof(*priv->residual_base) * arr_size); + if (ret != 0) + return -ret; + memset(priv->residual_base, 0, sizeof(*priv->residual_base) * arr_size); + ctx->residual = priv->residual_base + start_offset; + ctx->residual_stride = stride; + + for (int i = 0; i < ARRAY_ELEMS(ctx->diff_coeffs); i++) { + ret = posix_memalign((void**)&priv->diff_coeffs_base[i], 32, + sizeof(*priv->diff_coeffs_base[i]) * arr_size); + if (ret != 0) + return -ret; + ctx->diff_coeffs[i] = priv->diff_coeffs_base[i] + start_offset; + } + ctx->diff_coeffs_stride = stride; + + priv->stride = stride; + + for (int i = 0; i < ARRAY_ELEMS(ctx->boundaries); i++) { + ctx->boundaries[i] = mg2di_bc_alloc(domain_size[boundary_def[i].stride_idx]); + if (!ctx->boundaries[i]) + return -ENOMEM; + } + + return 0; +} + +EGSContext *mg2di_egs_alloc(enum EGSType type, size_t domain_size[2]) +{ + EGSContext *ctx; + EGSInternal *priv; + int ret; + + ctx = calloc(1, sizeof(*ctx)); + if (!ctx) + return NULL; + + ctx->priv = calloc(1, sizeof(*ctx->priv)); + if (!ctx->priv) + goto fail; + priv = ctx->priv; + + switch (type) { + case EGS_SOLVER_RELAXATION: + ctx->solver_data = calloc(1, sizeof(EGSRelaxContext)); + if (!ctx->solver_data) + goto fail; + break; + default: goto fail; + } + ctx->solver_type = type; + + if (!domain_size[0] || !domain_size[1] || + domain_size[0] > SIZE_MAX / domain_size[1]) + goto fail; + + ret = arrays_alloc(ctx, domain_size); + if (ret < 0) + goto fail; + + ctx->domain_size[0] = domain_size[0]; + ctx->domain_size[1] = domain_size[1]; + + return ctx; +fail: + mg2di_egs_free(&ctx); + return NULL; +} + +void mg2di_egs_free(EGSContext **pctx) +{ + EGSContext *ctx = *pctx; + + if (!ctx) + return; + + free(ctx->solver_data); + + free(ctx->priv->u_base); + free(ctx->priv->rhs_base); + free(ctx->priv->residual_base); + free(ctx->priv->residual_max); + for (int i = 0; i < ARRAY_ELEMS(ctx->priv->diff_coeffs_base); i++) + free(ctx->priv->diff_coeffs_base[i]); + for (int i = 0; i < ARRAY_ELEMS(ctx->boundaries); i++) + mg2di_bc_free(&ctx->boundaries[i]); + + tp_free(&ctx->priv->tp_internal); + free(ctx->priv); + + free(ctx); + *pctx = NULL; +} -- cgit v1.2.3