aboutsummaryrefslogtreecommitdiff
path: root/bicgstab.c
diff options
context:
space:
mode:
authorAnton Khirnov <anton@khirnov.net>2019-03-18 10:33:46 +0100
committerAnton Khirnov <anton@khirnov.net>2019-03-18 10:33:46 +0100
commit457d8274fde5d20dc5694576b99c5541e55b4f40 (patch)
treeca1df761a0340bcc1bd5324895d8067d3b8528e6 /bicgstab.c
parent53e7613a1111702bb62708d1e5aff8b18fa9c9cb (diff)
ell_grid_solve: use BiCGSTAB to speed up exact solves
Diffstat (limited to 'bicgstab.c')
-rw-r--r--bicgstab.c182
1 files changed, 182 insertions, 0 deletions
diff --git a/bicgstab.c b/bicgstab.c
new file mode 100644
index 0000000..66d3b0b
--- /dev/null
+++ b/bicgstab.c
@@ -0,0 +1,182 @@
+/*
+ * BiCGStab iterative linear system solver
+ * Copyright (C) 2016 Anton Khirnov <anton@khirnov.net>
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program. If not, see <http://www.gnu.org/licenses/>.
+ */
+
+#include <cblas.h>
+#include <errno.h>
+#include <stdlib.h>
+#include <string.h>
+
+#include "bicgstab.h"
+
+#define BICGSTAB_TOL (1e-15)
+
+struct BiCGStabContext {
+ size_t N;
+ unsigned int maxiter;
+
+ double *x;
+ double *p, *v, *y, *z, *t;
+ double *res, *res0;
+ double *k;
+};
+
+// based on the wikipedia article
+// and http://www.netlib.org/templates/matlab/bicgstab.m
+static int solve_sw(BiCGStabContext *ctx,
+ const double *mat, const double *rhs, double *x)
+{
+ const int N = ctx->N;
+ const double rhs_norm = cblas_dnrm2(N, rhs, 1);
+
+ double rho, rho_prev = 1.0;
+ double omega = 1.0;
+ double alpha = 1.0;
+
+ double err;
+ int i;
+
+ double *k = ctx->k;
+ double *p = ctx->p, *v = ctx->v, *y = ctx->y, *z = ctx->z, *t = ctx->t;
+ double *res = ctx->res, *res0 = ctx->res0;
+
+ // initialize the residual
+ memcpy(res, rhs, N * sizeof(*res));
+ cblas_dgemv(CblasColMajor, CblasNoTrans, N, N, -1.0,
+ mat, N, ctx->x, 1, 1.0, res, 1);
+
+ memcpy(res0, res, N * sizeof(*res0));
+ memcpy(p, res, N * sizeof(*p));
+
+ for (i = 0; i < ctx->maxiter; i++) {
+ rho = cblas_ddot(N, res, 1, res0, 1);
+
+ if (i) {
+ double beta = (rho / rho_prev) * (alpha / omega);
+
+ cblas_daxpy(N, -omega, v, 1, p, 1);
+ cblas_dscal(N, beta, p, 1);
+ cblas_daxpy(N, 1, res, 1, p, 1);
+ }
+
+ cblas_dgemv(CblasColMajor, CblasNoTrans, N, N, 1.0,
+ k, N, p, 1, 0.0, y, 1);
+
+ cblas_dgemv(CblasColMajor, CblasNoTrans, N, N, 1.0,
+ mat, N, y, 1, 0.0, v, 1);
+
+ alpha = rho / cblas_ddot(N, res0, 1, v, 1);
+
+ cblas_daxpy(N, -alpha, v, 1, res, 1);
+
+ cblas_dgemv(CblasColMajor, CblasNoTrans, N, N, 1.0,
+ k, N, res, 1, 0.0, z, 1);
+ cblas_dgemv(CblasColMajor, CblasNoTrans, N, N, 1.0,
+ mat, N, z, 1, 0.0, t, 1);
+
+ omega = cblas_ddot(N, t, 1, res, 1) / cblas_ddot(N, t, 1, t, 1);
+
+ cblas_daxpy(N, alpha, y, 1, ctx->x, 1);
+ cblas_daxpy(N, omega, z, 1, ctx->x, 1);
+
+ cblas_daxpy(N, -omega, t, 1, res, 1);
+
+ err = cblas_dnrm2(N, res, 1) / rhs_norm;
+ if (err < BICGSTAB_TOL)
+ break;
+
+ rho_prev = rho;
+ }
+ if (i == ctx->maxiter)
+ return -1;
+
+ memcpy(x, ctx->x, sizeof(*x) * ctx->N);
+
+ return i;
+}
+
+int mg2di_bicgstab_solve(BiCGStabContext *ctx, const double *mat, const double *rhs, double *x)
+{
+ int ret;
+
+ ret = solve_sw(ctx, mat, rhs, x);
+ if (ret < 0)
+ return ret;
+
+ return ret;
+}
+
+int mg2di_bicgstab_init(BiCGStabContext *ctx, const double *k, const double *x0)
+{
+ memcpy(ctx->x, x0, ctx->N * sizeof(*x0));
+ memcpy(ctx->k, k, ctx->N * ctx->N * sizeof(*k));
+
+ return 0;
+}
+
+int mg2di_bicgstab_context_alloc(BiCGStabContext **pctx, size_t N, unsigned int maxiter)
+{
+ BiCGStabContext *ctx;
+ int ret = 0;
+
+ ctx = calloc(1, sizeof(*ctx));
+ if (!ctx)
+ return -ENOMEM;
+
+ ctx->N = N;
+ ctx->maxiter = maxiter;
+
+ ret |= posix_memalign((void**)&ctx->x, 32, sizeof(double) * N);
+ ret |= posix_memalign((void**)&ctx->p, 32, sizeof(double) * N);
+ ret |= posix_memalign((void**)&ctx->v, 32, sizeof(double) * N);
+ ret |= posix_memalign((void**)&ctx->y, 32, sizeof(double) * N);
+ ret |= posix_memalign((void**)&ctx->z, 32, sizeof(double) * N);
+ ret |= posix_memalign((void**)&ctx->t, 32, sizeof(double) * N);
+ ret |= posix_memalign((void**)&ctx->res, 32, sizeof(double) * N);
+ ret |= posix_memalign((void**)&ctx->res0, 32, sizeof(double) * N);
+ ret |= posix_memalign((void**)&ctx->k, 32, sizeof(double) * N * N);
+
+fail:
+ if (ret) {
+ mg2di_bicgstab_context_free(&ctx);
+ return -ENOMEM;
+ }
+
+ *pctx = ctx;
+ return 0;
+}
+
+void mg2di_bicgstab_context_free(BiCGStabContext **pctx)
+{
+ BiCGStabContext *ctx = *pctx;
+
+ if (!ctx)
+ return;
+
+ free(ctx->x);
+ free(ctx->p);
+ free(ctx->v);
+ free(ctx->y);
+ free(ctx->z);
+ free(ctx->t);
+ free(ctx->res);
+ free(ctx->res0);
+ free(ctx->k);
+
+ free(ctx);
+ *pctx = NULL;
+}