summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAnton Khirnov <anton@khirnov.net>2019-03-18 10:33:46 +0100
committerAnton Khirnov <anton@khirnov.net>2019-03-18 10:33:46 +0100
commit457d8274fde5d20dc5694576b99c5541e55b4f40 (patch)
treeca1df761a0340bcc1bd5324895d8067d3b8528e6
parent53e7613a1111702bb62708d1e5aff8b18fa9c9cb (diff)
ell_grid_solve: use BiCGSTAB to speed up exact solves
-rw-r--r--bicgstab.c182
-rw-r--r--bicgstab.h47
-rw-r--r--ell_grid_solve.c63
-rw-r--r--meson.build4
4 files changed, 287 insertions, 9 deletions
diff --git a/bicgstab.c b/bicgstab.c
new file mode 100644
index 0000000..66d3b0b
--- /dev/null
+++ b/bicgstab.c
@@ -0,0 +1,182 @@
+/*
+ * BiCGStab iterative linear system solver
+ * Copyright (C) 2016 Anton Khirnov <anton@khirnov.net>
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program. If not, see <http://www.gnu.org/licenses/>.
+ */
+
+#include <cblas.h>
+#include <errno.h>
+#include <stdlib.h>
+#include <string.h>
+
+#include "bicgstab.h"
+
+#define BICGSTAB_TOL (1e-15)
+
+struct BiCGStabContext {
+ size_t N;
+ unsigned int maxiter;
+
+ double *x;
+ double *p, *v, *y, *z, *t;
+ double *res, *res0;
+ double *k;
+};
+
+// based on the wikipedia article
+// and http://www.netlib.org/templates/matlab/bicgstab.m
+static int solve_sw(BiCGStabContext *ctx,
+ const double *mat, const double *rhs, double *x)
+{
+ const int N = ctx->N;
+ const double rhs_norm = cblas_dnrm2(N, rhs, 1);
+
+ double rho, rho_prev = 1.0;
+ double omega = 1.0;
+ double alpha = 1.0;
+
+ double err;
+ int i;
+
+ double *k = ctx->k;
+ double *p = ctx->p, *v = ctx->v, *y = ctx->y, *z = ctx->z, *t = ctx->t;
+ double *res = ctx->res, *res0 = ctx->res0;
+
+ // initialize the residual
+ memcpy(res, rhs, N * sizeof(*res));
+ cblas_dgemv(CblasColMajor, CblasNoTrans, N, N, -1.0,
+ mat, N, ctx->x, 1, 1.0, res, 1);
+
+ memcpy(res0, res, N * sizeof(*res0));
+ memcpy(p, res, N * sizeof(*p));
+
+ for (i = 0; i < ctx->maxiter; i++) {
+ rho = cblas_ddot(N, res, 1, res0, 1);
+
+ if (i) {
+ double beta = (rho / rho_prev) * (alpha / omega);
+
+ cblas_daxpy(N, -omega, v, 1, p, 1);
+ cblas_dscal(N, beta, p, 1);
+ cblas_daxpy(N, 1, res, 1, p, 1);
+ }
+
+ cblas_dgemv(CblasColMajor, CblasNoTrans, N, N, 1.0,
+ k, N, p, 1, 0.0, y, 1);
+
+ cblas_dgemv(CblasColMajor, CblasNoTrans, N, N, 1.0,
+ mat, N, y, 1, 0.0, v, 1);
+
+ alpha = rho / cblas_ddot(N, res0, 1, v, 1);
+
+ cblas_daxpy(N, -alpha, v, 1, res, 1);
+
+ cblas_dgemv(CblasColMajor, CblasNoTrans, N, N, 1.0,
+ k, N, res, 1, 0.0, z, 1);
+ cblas_dgemv(CblasColMajor, CblasNoTrans, N, N, 1.0,
+ mat, N, z, 1, 0.0, t, 1);
+
+ omega = cblas_ddot(N, t, 1, res, 1) / cblas_ddot(N, t, 1, t, 1);
+
+ cblas_daxpy(N, alpha, y, 1, ctx->x, 1);
+ cblas_daxpy(N, omega, z, 1, ctx->x, 1);
+
+ cblas_daxpy(N, -omega, t, 1, res, 1);
+
+ err = cblas_dnrm2(N, res, 1) / rhs_norm;
+ if (err < BICGSTAB_TOL)
+ break;
+
+ rho_prev = rho;
+ }
+ if (i == ctx->maxiter)
+ return -1;
+
+ memcpy(x, ctx->x, sizeof(*x) * ctx->N);
+
+ return i;
+}
+
+int mg2di_bicgstab_solve(BiCGStabContext *ctx, const double *mat, const double *rhs, double *x)
+{
+ int ret;
+
+ ret = solve_sw(ctx, mat, rhs, x);
+ if (ret < 0)
+ return ret;
+
+ return ret;
+}
+
+int mg2di_bicgstab_init(BiCGStabContext *ctx, const double *k, const double *x0)
+{
+ memcpy(ctx->x, x0, ctx->N * sizeof(*x0));
+ memcpy(ctx->k, k, ctx->N * ctx->N * sizeof(*k));
+
+ return 0;
+}
+
+int mg2di_bicgstab_context_alloc(BiCGStabContext **pctx, size_t N, unsigned int maxiter)
+{
+ BiCGStabContext *ctx;
+ int ret = 0;
+
+ ctx = calloc(1, sizeof(*ctx));
+ if (!ctx)
+ return -ENOMEM;
+
+ ctx->N = N;
+ ctx->maxiter = maxiter;
+
+ ret |= posix_memalign((void**)&ctx->x, 32, sizeof(double) * N);
+ ret |= posix_memalign((void**)&ctx->p, 32, sizeof(double) * N);
+ ret |= posix_memalign((void**)&ctx->v, 32, sizeof(double) * N);
+ ret |= posix_memalign((void**)&ctx->y, 32, sizeof(double) * N);
+ ret |= posix_memalign((void**)&ctx->z, 32, sizeof(double) * N);
+ ret |= posix_memalign((void**)&ctx->t, 32, sizeof(double) * N);
+ ret |= posix_memalign((void**)&ctx->res, 32, sizeof(double) * N);
+ ret |= posix_memalign((void**)&ctx->res0, 32, sizeof(double) * N);
+ ret |= posix_memalign((void**)&ctx->k, 32, sizeof(double) * N * N);
+
+fail:
+ if (ret) {
+ mg2di_bicgstab_context_free(&ctx);
+ return -ENOMEM;
+ }
+
+ *pctx = ctx;
+ return 0;
+}
+
+void mg2di_bicgstab_context_free(BiCGStabContext **pctx)
+{
+ BiCGStabContext *ctx = *pctx;
+
+ if (!ctx)
+ return;
+
+ free(ctx->x);
+ free(ctx->p);
+ free(ctx->v);
+ free(ctx->y);
+ free(ctx->z);
+ free(ctx->t);
+ free(ctx->res);
+ free(ctx->res0);
+ free(ctx->k);
+
+ free(ctx);
+ *pctx = NULL;
+}
diff --git a/bicgstab.h b/bicgstab.h
new file mode 100644
index 0000000..f25cb28
--- /dev/null
+++ b/bicgstab.h
@@ -0,0 +1,47 @@
+/*
+ * BiCGStab iterative linear system solver
+ * Copyright (C) 2016 Anton Khirnov <anton@khirnov.net>
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program. If not, see <http://www.gnu.org/licenses/>.
+ */
+
+#ifndef MG2D_BICGSTAB_H
+#define MG2D_BICGSTAB_H
+
+typedef struct BiCGStabContext BiCGStabContext;
+
+/**
+ * Allocate and initialize the solver for the NxN system.
+ */
+int mg2di_bicgstab_context_alloc(BiCGStabContext **ctx, size_t N, unsigned int maxiter);
+
+/**
+ * Free the solver and all its internal state.
+ */
+void mg2di_bicgstab_context_free(BiCGStabContext **ctx);
+
+/**
+ * Initialise the solver with the given preconditioner matrix. This function
+ * may be any number of times on a given solver context.
+ */
+int mg2di_bicgstab_init(BiCGStabContext *ctx, const double *k, const double *x0);
+
+/**
+ * Solve the linear system
+ * mat ยท x = rhs
+ * The result is written into x.
+ */
+int mg2di_bicgstab_solve(BiCGStabContext *ctx, const double *mat, const double *rhs, double *x);
+
+#endif /* MG2D_BICGSTAB_H */
diff --git a/ell_grid_solve.c b/ell_grid_solve.c
index cbb7d68..550f305 100644
--- a/ell_grid_solve.c
+++ b/ell_grid_solve.c
@@ -27,6 +27,7 @@
#include <lapacke.h>
+#include "bicgstab.h"
#include "common.h"
#include "ell_grid_solve.h"
#include "log.h"
@@ -51,9 +52,13 @@ typedef struct EGSExactInternal {
size_t N_ghosts;
double *mat;
+ double *mat_f;
double *rhs;
+ double *x;
double *scratch_line;
int *ipiv;
+
+ BiCGStabContext *bicgstab;
} EGSExactInternal;
struct EGSInternal {
@@ -511,23 +516,48 @@ static int solve_exact(EGSContext *ctx)
}
}
+ for (int i = 0; i < e->N; i++)
+ for (int j = i + 1; j < e->N; j++) {
+ double tmp = e->mat[j * e->N + i];
+ e->mat[j * e->N + i] = e->mat[i * e->N + j];
+ e->mat[i * e->N + j] = tmp;
+ }
ec->time_mat_construct += gettime() - start;
ec->count_mat_construct++;
start = gettime();
- ret = LAPACKE_dgesv(LAPACK_ROW_MAJOR, e->N, 1, e->mat, e->N, e->ipiv, e->rhs, 1);
- if (ret != 0) {
- mg2di_log(&ctx->logger, MG2D_LOG_ERROR,
- "Error solving the linear system: %d\n", ret);
- return -EDOM;
+
+ ret = mg2di_bicgstab_solve(e->bicgstab, e->mat, e->rhs, e->x);
+ if (ret < 0) {
+ char equed = 'N';
+ double cond, ferr, berr, rpivot;
+
+ ret = LAPACKE_dgesvx(LAPACK_COL_MAJOR, 'N', 'N', e->N, 1,
+ e->mat, e->N, e->mat_f, e->N, e->ipiv, &equed, NULL, NULL,
+ e->rhs, e->N, e->x, e->N, &cond, &ferr, &berr, &rpivot);
+ if (ret == 0)
+ ret = LAPACKE_dgetri(LAPACK_COL_MAJOR, e->N, e->mat_f, e->N, e->ipiv);
+ if (ret != 0) {
+ mg2di_log(&ctx->logger, MG2D_LOG_ERROR,
+ "Error solving the linear system: %d\n", ret);
+ return -EDOM;
+ }
+ mg2di_log(&ctx->logger, MG2D_LOG_DEBUG,
+ "LU factorization solution to a %zdx%zd matrix: "
+ "condition number %16.16g; forward error %16.16g backward error %16.16g\n",
+ e->N, e->N, cond, ferr, berr);
+
+ ret = mg2di_bicgstab_init(e->bicgstab, e->mat_f, e->x);
+ if (ret < 0)
+ return ret;
}
+ for (size_t idx1 = 0; idx1 < ctx->domain_size[1]; idx1++)
+ memcpy(ctx->u + idx1 * ctx->u_stride, e->x + idx1 * ctx->domain_size[0], ctx->domain_size[0] * sizeof(*e->x));
+
ec->time_lin_solve += gettime() - start;
ec->count_lin_solve++;
- for (size_t idx1 = 0; idx1 < ctx->domain_size[1]; idx1++)
- memcpy(ctx->u + idx1 * ctx->u_stride, e->rhs + idx1 * ctx->domain_size[0], ctx->domain_size[0] * sizeof(*e->rhs));
-
boundaries_apply(ctx);
residual_calc(ctx);
@@ -692,10 +722,23 @@ static int arrays_alloc(EGSContext *ctx, const size_t domain_size[2])
e->scratch_line = calloc(e->N_ghosts, sizeof(*e->scratch_line));
e->mat = calloc(SQR(e->N), sizeof(*e->mat));
+ e->mat_f = calloc(SQR(e->N), sizeof(*e->mat_f));
e->rhs = calloc(e->N, sizeof(*e->rhs));
+ e->x = calloc(e->N, sizeof(*e->x));
e->ipiv = calloc(e->N, sizeof(*e->ipiv));
if (!e->scratch_line || !e->mat || !e->rhs || !e->ipiv)
return -ENOMEM;
+
+ ret = mg2di_bicgstab_context_alloc(&e->bicgstab, e->N, 64);
+ if (ret < 0)
+ return ret;
+
+ for (int i = 0; i < e->N; i++)
+ e->mat[i * e->N + i] = 1.0;
+
+ ret = mg2di_bicgstab_init(e->bicgstab, e->mat, e->x);
+ if (ret < 0)
+ return ret;
}
return 0;
@@ -766,8 +809,12 @@ void mg2di_egs_free(EGSContext **pctx)
free(e->scratch_line);
free(e->mat);
+ free(e->mat_f);
free(e->rhs);
+ free(e->x);
free(e->ipiv);
+
+ mg2di_bicgstab_context_free(&e->bicgstab);
}
free(ctx->priv->u_base);
diff --git a/meson.build b/meson.build
index 48d80ef..73dfee9 100644
--- a/meson.build
+++ b/meson.build
@@ -4,6 +4,7 @@ project('libmg2d', 'c',
add_project_arguments('-D_XOPEN_SOURCE=700', language : 'c')
lib_src = [
+ 'bicgstab.c',
'boundary.c',
'cpu.c',
'ell_grid_solve.c',
@@ -18,11 +19,12 @@ ver_flag = '-Wl,--version-script,@0@/@1@'.format(meson.current_source_dir(), ve
cc = meson.get_compiler('c')
libm = cc.find_library('m', required : false)
+libcblas = cc.find_library('blas')
liblapacke = cc.find_library('lapacke')
dep_tp = declare_dependency(link_args : '-lthreadpool')
-deps = [dep_tp, libm, liblapacke]
+deps = [dep_tp, libm, libcblas, liblapacke]
cdata = configuration_data()
cdata.set10('ARCH_X86', false)