From 457d8274fde5d20dc5694576b99c5541e55b4f40 Mon Sep 17 00:00:00 2001 From: Anton Khirnov Date: Mon, 18 Mar 2019 10:33:46 +0100 Subject: ell_grid_solve: use BiCGSTAB to speed up exact solves --- bicgstab.c | 182 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ bicgstab.h | 47 ++++++++++++++ ell_grid_solve.c | 63 ++++++++++++++++--- meson.build | 4 +- 4 files changed, 287 insertions(+), 9 deletions(-) create mode 100644 bicgstab.c create mode 100644 bicgstab.h diff --git a/bicgstab.c b/bicgstab.c new file mode 100644 index 0000000..66d3b0b --- /dev/null +++ b/bicgstab.c @@ -0,0 +1,182 @@ +/* + * BiCGStab iterative linear system solver + * Copyright (C) 2016 Anton Khirnov + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . 
+ */ + +#include +#include +#include +#include + +#include "bicgstab.h" + +#define BICGSTAB_TOL (1e-15) + +struct BiCGStabContext { + size_t N; + unsigned int maxiter; + + double *x; + double *p, *v, *y, *z, *t; + double *res, *res0; + double *k; +}; + +// based on the wikipedia article +// and http://www.netlib.org/templates/matlab/bicgstab.m +static int solve_sw(BiCGStabContext *ctx, + const double *mat, const double *rhs, double *x) +{ + const int N = ctx->N; + const double rhs_norm = cblas_dnrm2(N, rhs, 1); + + double rho, rho_prev = 1.0; + double omega = 1.0; + double alpha = 1.0; + + double err; + int i; + + double *k = ctx->k; + double *p = ctx->p, *v = ctx->v, *y = ctx->y, *z = ctx->z, *t = ctx->t; + double *res = ctx->res, *res0 = ctx->res0; + + // initialize the residual + memcpy(res, rhs, N * sizeof(*res)); + cblas_dgemv(CblasColMajor, CblasNoTrans, N, N, -1.0, + mat, N, ctx->x, 1, 1.0, res, 1); + + memcpy(res0, res, N * sizeof(*res0)); + memcpy(p, res, N * sizeof(*p)); + + for (i = 0; i < ctx->maxiter; i++) { + rho = cblas_ddot(N, res, 1, res0, 1); + + if (i) { + double beta = (rho / rho_prev) * (alpha / omega); + + cblas_daxpy(N, -omega, v, 1, p, 1); + cblas_dscal(N, beta, p, 1); + cblas_daxpy(N, 1, res, 1, p, 1); + } + + cblas_dgemv(CblasColMajor, CblasNoTrans, N, N, 1.0, + k, N, p, 1, 0.0, y, 1); + + cblas_dgemv(CblasColMajor, CblasNoTrans, N, N, 1.0, + mat, N, y, 1, 0.0, v, 1); + + alpha = rho / cblas_ddot(N, res0, 1, v, 1); + + cblas_daxpy(N, -alpha, v, 1, res, 1); + + cblas_dgemv(CblasColMajor, CblasNoTrans, N, N, 1.0, + k, N, res, 1, 0.0, z, 1); + cblas_dgemv(CblasColMajor, CblasNoTrans, N, N, 1.0, + mat, N, z, 1, 0.0, t, 1); + + omega = cblas_ddot(N, t, 1, res, 1) / cblas_ddot(N, t, 1, t, 1); + + cblas_daxpy(N, alpha, y, 1, ctx->x, 1); + cblas_daxpy(N, omega, z, 1, ctx->x, 1); + + cblas_daxpy(N, -omega, t, 1, res, 1); + + err = cblas_dnrm2(N, res, 1) / rhs_norm; + if (err < BICGSTAB_TOL) + break; + + rho_prev = rho; + } + if (i == ctx->maxiter) 
+ return -1; + + memcpy(x, ctx->x, sizeof(*x) * ctx->N); + + return i; +} + +int mg2di_bicgstab_solve(BiCGStabContext *ctx, const double *mat, const double *rhs, double *x) +{ + int ret; + + ret = solve_sw(ctx, mat, rhs, x); + if (ret < 0) + return ret; + + return ret; +} + +int mg2di_bicgstab_init(BiCGStabContext *ctx, const double *k, const double *x0) +{ + memcpy(ctx->x, x0, ctx->N * sizeof(*x0)); + memcpy(ctx->k, k, ctx->N * ctx->N * sizeof(*k)); + + return 0; +} + +int mg2di_bicgstab_context_alloc(BiCGStabContext **pctx, size_t N, unsigned int maxiter) +{ + BiCGStabContext *ctx; + int ret = 0; + + ctx = calloc(1, sizeof(*ctx)); + if (!ctx) + return -ENOMEM; + + ctx->N = N; + ctx->maxiter = maxiter; + + ret |= posix_memalign((void**)&ctx->x, 32, sizeof(double) * N); + ret |= posix_memalign((void**)&ctx->p, 32, sizeof(double) * N); + ret |= posix_memalign((void**)&ctx->v, 32, sizeof(double) * N); + ret |= posix_memalign((void**)&ctx->y, 32, sizeof(double) * N); + ret |= posix_memalign((void**)&ctx->z, 32, sizeof(double) * N); + ret |= posix_memalign((void**)&ctx->t, 32, sizeof(double) * N); + ret |= posix_memalign((void**)&ctx->res, 32, sizeof(double) * N); + ret |= posix_memalign((void**)&ctx->res0, 32, sizeof(double) * N); + ret |= posix_memalign((void**)&ctx->k, 32, sizeof(double) * N * N); + +fail: + if (ret) { + mg2di_bicgstab_context_free(&ctx); + return -ENOMEM; + } + + *pctx = ctx; + return 0; +} + +void mg2di_bicgstab_context_free(BiCGStabContext **pctx) +{ + BiCGStabContext *ctx = *pctx; + + if (!ctx) + return; + + free(ctx->x); + free(ctx->p); + free(ctx->v); + free(ctx->y); + free(ctx->z); + free(ctx->t); + free(ctx->res); + free(ctx->res0); + free(ctx->k); + + free(ctx); + *pctx = NULL; +} diff --git a/bicgstab.h b/bicgstab.h new file mode 100644 index 0000000..f25cb28 --- /dev/null +++ b/bicgstab.h @@ -0,0 +1,47 @@ +/* + * BiCGStab iterative linear system solver + * Copyright (C) 2016 Anton Khirnov + * + * This program is free software: you can 
redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see <http://www.gnu.org/licenses/>. + */ + +#ifndef MG2D_BICGSTAB_H +#define MG2D_BICGSTAB_H + +typedef struct BiCGStabContext BiCGStabContext; + +/** + * Allocate and initialize the solver for the NxN system. + */ +int mg2di_bicgstab_context_alloc(BiCGStabContext **ctx, size_t N, unsigned int maxiter); + +/** + * Free the solver and all its internal state. + */ +void mg2di_bicgstab_context_free(BiCGStabContext **ctx); + +/** + * Initialise the solver with the given preconditioner matrix. This function + * may be called any number of times on a given solver context. + */ +int mg2di_bicgstab_init(BiCGStabContext *ctx, const double *k, const double *x0); + +/** + * Solve the linear system + * mat · x = rhs + * The result is written into x. 
+ */ +int mg2di_bicgstab_solve(BiCGStabContext *ctx, const double *mat, const double *rhs, double *x); + +#endif /* MG2D_BICGSTAB_H */ diff --git a/ell_grid_solve.c b/ell_grid_solve.c index cbb7d68..550f305 100644 --- a/ell_grid_solve.c +++ b/ell_grid_solve.c @@ -27,6 +27,7 @@ #include +#include "bicgstab.h" #include "common.h" #include "ell_grid_solve.h" #include "log.h" @@ -51,9 +52,13 @@ typedef struct EGSExactInternal { size_t N_ghosts; double *mat; + double *mat_f; double *rhs; + double *x; double *scratch_line; int *ipiv; + + BiCGStabContext *bicgstab; } EGSExactInternal; struct EGSInternal { @@ -511,23 +516,48 @@ static int solve_exact(EGSContext *ctx) } } + for (int i = 0; i < e->N; i++) + for (int j = i + 1; j < e->N; j++) { + double tmp = e->mat[j * e->N + i]; + e->mat[j * e->N + i] = e->mat[i * e->N + j]; + e->mat[i * e->N + j] = tmp; + } ec->time_mat_construct += gettime() - start; ec->count_mat_construct++; start = gettime(); - ret = LAPACKE_dgesv(LAPACK_ROW_MAJOR, e->N, 1, e->mat, e->N, e->ipiv, e->rhs, 1); - if (ret != 0) { - mg2di_log(&ctx->logger, MG2D_LOG_ERROR, - "Error solving the linear system: %d\n", ret); - return -EDOM; + + ret = mg2di_bicgstab_solve(e->bicgstab, e->mat, e->rhs, e->x); + if (ret < 0) { + char equed = 'N'; + double cond, ferr, berr, rpivot; + + ret = LAPACKE_dgesvx(LAPACK_COL_MAJOR, 'N', 'N', e->N, 1, + e->mat, e->N, e->mat_f, e->N, e->ipiv, &equed, NULL, NULL, + e->rhs, e->N, e->x, e->N, &cond, &ferr, &berr, &rpivot); + if (ret == 0) + ret = LAPACKE_dgetri(LAPACK_COL_MAJOR, e->N, e->mat_f, e->N, e->ipiv); + if (ret != 0) { + mg2di_log(&ctx->logger, MG2D_LOG_ERROR, + "Error solving the linear system: %d\n", ret); + return -EDOM; + } + mg2di_log(&ctx->logger, MG2D_LOG_DEBUG, + "LU factorization solution to a %zdx%zd matrix: " + "condition number %16.16g; forward error %16.16g backward error %16.16g\n", + e->N, e->N, cond, ferr, berr); + + ret = mg2di_bicgstab_init(e->bicgstab, e->mat_f, e->x); + if (ret < 0) + return ret; } + 
for (size_t idx1 = 0; idx1 < ctx->domain_size[1]; idx1++) + memcpy(ctx->u + idx1 * ctx->u_stride, e->x + idx1 * ctx->domain_size[0], ctx->domain_size[0] * sizeof(*e->x)); + ec->time_lin_solve += gettime() - start; ec->count_lin_solve++; - for (size_t idx1 = 0; idx1 < ctx->domain_size[1]; idx1++) - memcpy(ctx->u + idx1 * ctx->u_stride, e->rhs + idx1 * ctx->domain_size[0], ctx->domain_size[0] * sizeof(*e->rhs)); - boundaries_apply(ctx); residual_calc(ctx); @@ -692,10 +722,23 @@ static int arrays_alloc(EGSContext *ctx, const size_t domain_size[2]) e->scratch_line = calloc(e->N_ghosts, sizeof(*e->scratch_line)); e->mat = calloc(SQR(e->N), sizeof(*e->mat)); + e->mat_f = calloc(SQR(e->N), sizeof(*e->mat_f)); e->rhs = calloc(e->N, sizeof(*e->rhs)); + e->x = calloc(e->N, sizeof(*e->x)); e->ipiv = calloc(e->N, sizeof(*e->ipiv)); if (!e->scratch_line || !e->mat || !e->rhs || !e->ipiv) return -ENOMEM; + + ret = mg2di_bicgstab_context_alloc(&e->bicgstab, e->N, 64); + if (ret < 0) + return ret; + + for (int i = 0; i < e->N; i++) + e->mat[i * e->N + i] = 1.0; + + ret = mg2di_bicgstab_init(e->bicgstab, e->mat, e->x); + if (ret < 0) + return ret; } return 0; @@ -766,8 +809,12 @@ void mg2di_egs_free(EGSContext **pctx) free(e->scratch_line); free(e->mat); + free(e->mat_f); free(e->rhs); + free(e->x); free(e->ipiv); + + mg2di_bicgstab_context_free(&e->bicgstab); } free(ctx->priv->u_base); diff --git a/meson.build b/meson.build index 48d80ef..73dfee9 100644 --- a/meson.build +++ b/meson.build @@ -4,6 +4,7 @@ project('libmg2d', 'c', add_project_arguments('-D_XOPEN_SOURCE=700', language : 'c') lib_src = [ + 'bicgstab.c', 'boundary.c', 'cpu.c', 'ell_grid_solve.c', @@ -18,11 +19,12 @@ ver_flag = '-Wl,--version-script,@0@/@1@'.format(meson.current_source_dir(), ve cc = meson.get_compiler('c') libm = cc.find_library('m', required : false) +libcblas = cc.find_library('blas') liblapacke = cc.find_library('lapacke') dep_tp = declare_dependency(link_args : '-lthreadpool') -deps = [dep_tp, libm, 
liblapacke] +deps = [dep_tp, libm, libcblas, liblapacke] cdata = configuration_data() cdata.set10('ARCH_X86', false) -- cgit v1.2.3