From 457d8274fde5d20dc5694576b99c5541e55b4f40 Mon Sep 17 00:00:00 2001 From: Anton Khirnov Date: Mon, 18 Mar 2019 10:33:46 +0100 Subject: ell_grid_solve: use BiCGSTAB to speed up exact solves --- bicgstab.c | 182 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 182 insertions(+) create mode 100644 bicgstab.c (limited to 'bicgstab.c') diff --git a/bicgstab.c b/bicgstab.c new file mode 100644 index 0000000..66d3b0b --- /dev/null +++ b/bicgstab.c @@ -0,0 +1,182 @@ +/* + * BiCGStab iterative linear system solver + * Copyright (C) 2016 Anton Khirnov + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +#include +#include +#include +#include + +#include "bicgstab.h" + +#define BICGSTAB_TOL (1e-15) + +struct BiCGStabContext { + size_t N; + unsigned int maxiter; + + double *x; + double *p, *v, *y, *z, *t; + double *res, *res0; + double *k; +}; + +// based on the wikipedia article +// and http://www.netlib.org/templates/matlab/bicgstab.m +static int solve_sw(BiCGStabContext *ctx, + const double *mat, const double *rhs, double *x) +{ + const int N = ctx->N; + const double rhs_norm = cblas_dnrm2(N, rhs, 1); + + double rho, rho_prev = 1.0; + double omega = 1.0; + double alpha = 1.0; + + double err; + int i; + + double *k = ctx->k; + double *p = ctx->p, *v = ctx->v, *y = ctx->y, *z = ctx->z, *t = ctx->t; + double *res = ctx->res, *res0 = ctx->res0; + + // initialize the residual + memcpy(res, rhs, N * sizeof(*res)); + cblas_dgemv(CblasColMajor, CblasNoTrans, N, N, -1.0, + mat, N, ctx->x, 1, 1.0, res, 1); + + memcpy(res0, res, N * sizeof(*res0)); + memcpy(p, res, N * sizeof(*p)); + + for (i = 0; i < ctx->maxiter; i++) { + rho = cblas_ddot(N, res, 1, res0, 1); + + if (i) { + double beta = (rho / rho_prev) * (alpha / omega); + + cblas_daxpy(N, -omega, v, 1, p, 1); + cblas_dscal(N, beta, p, 1); + cblas_daxpy(N, 1, res, 1, p, 1); + } + + cblas_dgemv(CblasColMajor, CblasNoTrans, N, N, 1.0, + k, N, p, 1, 0.0, y, 1); + + cblas_dgemv(CblasColMajor, CblasNoTrans, N, N, 1.0, + mat, N, y, 1, 0.0, v, 1); + + alpha = rho / cblas_ddot(N, res0, 1, v, 1); + + cblas_daxpy(N, -alpha, v, 1, res, 1); + + cblas_dgemv(CblasColMajor, CblasNoTrans, N, N, 1.0, + k, N, res, 1, 0.0, z, 1); + cblas_dgemv(CblasColMajor, CblasNoTrans, N, N, 1.0, + mat, N, z, 1, 0.0, t, 1); + + omega = cblas_ddot(N, t, 1, res, 1) / cblas_ddot(N, t, 1, t, 1); + + cblas_daxpy(N, alpha, y, 1, ctx->x, 1); + cblas_daxpy(N, omega, z, 1, ctx->x, 1); + + cblas_daxpy(N, -omega, t, 1, res, 1); + + err = cblas_dnrm2(N, res, 1) / rhs_norm; + if (err < BICGSTAB_TOL) + break; + + rho_prev = rho; + } + if (i == ctx->maxiter) + return -1; + + memcpy(x, ctx->x, sizeof(*x) * ctx->N); + + return i; +} + +int mg2di_bicgstab_solve(BiCGStabContext *ctx, const double *mat, const double *rhs, double *x) +{ + int ret; + + ret = solve_sw(ctx, mat, rhs, x); + if (ret < 0) + return ret; + + return ret; +} + +int mg2di_bicgstab_init(BiCGStabContext *ctx, const double *k, const double *x0) +{ + memcpy(ctx->x, x0, ctx->N * sizeof(*x0)); + memcpy(ctx->k, k, ctx->N * ctx->N * sizeof(*k)); + + return 0; +} + +int mg2di_bicgstab_context_alloc(BiCGStabContext **pctx, size_t N, unsigned int maxiter) +{ + BiCGStabContext *ctx; + int ret = 0; + + ctx = calloc(1, sizeof(*ctx)); + if (!ctx) + return -ENOMEM; + + ctx->N = N; + ctx->maxiter = maxiter; + + ret |= posix_memalign((void**)&ctx->x, 32, sizeof(double) * N); + ret |= posix_memalign((void**)&ctx->p, 32, sizeof(double) * N); + ret |= posix_memalign((void**)&ctx->v, 32, sizeof(double) * N); + ret |= posix_memalign((void**)&ctx->y, 32, sizeof(double) * N); + ret |= posix_memalign((void**)&ctx->z, 32, sizeof(double) * N); + ret |= posix_memalign((void**)&ctx->t, 32, sizeof(double) * N); + ret |= posix_memalign((void**)&ctx->res, 32, sizeof(double) * N); + ret |= posix_memalign((void**)&ctx->res0, 32, sizeof(double) * N); + ret |= posix_memalign((void**)&ctx->k, 32, sizeof(double) * N * N); + +fail: + if (ret) { + mg2di_bicgstab_context_free(&ctx); + return -ENOMEM; + } + + *pctx = ctx; + return 0; +} + +void mg2di_bicgstab_context_free(BiCGStabContext **pctx) +{ + BiCGStabContext *ctx = *pctx; + + if (!ctx) + return; + + free(ctx->x); + free(ctx->p); + free(ctx->v); + free(ctx->y); + free(ctx->z); + free(ctx->t); + free(ctx->res); + free(ctx->res0); + free(ctx->k); + + free(ctx); + *pctx = NULL; +} -- cgit v1.2.3