aboutsummaryrefslogtreecommitdiff
path: root/ell_grid_solve.c
diff options
context:
space:
mode:
authorAnton Khirnov <anton@khirnov.net>2019-03-18 10:33:46 +0100
committerAnton Khirnov <anton@khirnov.net>2019-03-18 10:33:46 +0100
commit457d8274fde5d20dc5694576b99c5541e55b4f40 (patch)
treeca1df761a0340bcc1bd5324895d8067d3b8528e6 /ell_grid_solve.c
parent53e7613a1111702bb62708d1e5aff8b18fa9c9cb (diff)
ell_grid_solve: use BiCGSTAB to speed up exact solves
Diffstat (limited to 'ell_grid_solve.c')
-rw-r--r--ell_grid_solve.c63
1 files changed, 55 insertions, 8 deletions
diff --git a/ell_grid_solve.c b/ell_grid_solve.c
index cbb7d68..550f305 100644
--- a/ell_grid_solve.c
+++ b/ell_grid_solve.c
@@ -27,6 +27,7 @@
#include <lapacke.h>
+#include "bicgstab.h"
#include "common.h"
#include "ell_grid_solve.h"
#include "log.h"
@@ -51,9 +52,13 @@ typedef struct EGSExactInternal {
size_t N_ghosts;
double *mat;
+ double *mat_f;
double *rhs;
+ double *x;
double *scratch_line;
int *ipiv;
+
+ BiCGStabContext *bicgstab;
} EGSExactInternal;
struct EGSInternal {
@@ -511,23 +516,48 @@ static int solve_exact(EGSContext *ctx)
}
}
+ for (int i = 0; i < e->N; i++)
+ for (int j = i + 1; j < e->N; j++) {
+ double tmp = e->mat[j * e->N + i];
+ e->mat[j * e->N + i] = e->mat[i * e->N + j];
+ e->mat[i * e->N + j] = tmp;
+ }
ec->time_mat_construct += gettime() - start;
ec->count_mat_construct++;
start = gettime();
- ret = LAPACKE_dgesv(LAPACK_ROW_MAJOR, e->N, 1, e->mat, e->N, e->ipiv, e->rhs, 1);
- if (ret != 0) {
- mg2di_log(&ctx->logger, MG2D_LOG_ERROR,
- "Error solving the linear system: %d\n", ret);
- return -EDOM;
+
+ ret = mg2di_bicgstab_solve(e->bicgstab, e->mat, e->rhs, e->x);
+ if (ret < 0) {
+ char equed = 'N';
+ double cond, ferr, berr, rpivot;
+
+ ret = LAPACKE_dgesvx(LAPACK_COL_MAJOR, 'N', 'N', e->N, 1,
+ e->mat, e->N, e->mat_f, e->N, e->ipiv, &equed, NULL, NULL,
+ e->rhs, e->N, e->x, e->N, &cond, &ferr, &berr, &rpivot);
+ if (ret == 0)
+ ret = LAPACKE_dgetri(LAPACK_COL_MAJOR, e->N, e->mat_f, e->N, e->ipiv);
+ if (ret != 0) {
+ mg2di_log(&ctx->logger, MG2D_LOG_ERROR,
+ "Error solving the linear system: %d\n", ret);
+ return -EDOM;
+ }
+ mg2di_log(&ctx->logger, MG2D_LOG_DEBUG,
+ "LU factorization solution to a %zdx%zd matrix: "
+ "condition number %16.16g; forward error %16.16g backward error %16.16g\n",
+ e->N, e->N, cond, ferr, berr);
+
+ ret = mg2di_bicgstab_init(e->bicgstab, e->mat_f, e->x);
+ if (ret < 0)
+ return ret;
}
+ for (size_t idx1 = 0; idx1 < ctx->domain_size[1]; idx1++)
+ memcpy(ctx->u + idx1 * ctx->u_stride, e->x + idx1 * ctx->domain_size[0], ctx->domain_size[0] * sizeof(*e->x));
+
ec->time_lin_solve += gettime() - start;
ec->count_lin_solve++;
- for (size_t idx1 = 0; idx1 < ctx->domain_size[1]; idx1++)
- memcpy(ctx->u + idx1 * ctx->u_stride, e->rhs + idx1 * ctx->domain_size[0], ctx->domain_size[0] * sizeof(*e->rhs));
-
boundaries_apply(ctx);
residual_calc(ctx);
@@ -692,10 +722,23 @@ static int arrays_alloc(EGSContext *ctx, const size_t domain_size[2])
e->scratch_line = calloc(e->N_ghosts, sizeof(*e->scratch_line));
e->mat = calloc(SQR(e->N), sizeof(*e->mat));
+ e->mat_f = calloc(SQR(e->N), sizeof(*e->mat_f));
e->rhs = calloc(e->N, sizeof(*e->rhs));
+ e->x = calloc(e->N, sizeof(*e->x));
e->ipiv = calloc(e->N, sizeof(*e->ipiv));
if (!e->scratch_line || !e->mat || !e->rhs || !e->ipiv)
return -ENOMEM;
+
+ ret = mg2di_bicgstab_context_alloc(&e->bicgstab, e->N, 64);
+ if (ret < 0)
+ return ret;
+
+ for (int i = 0; i < e->N; i++)
+ e->mat[i * e->N + i] = 1.0;
+
+ ret = mg2di_bicgstab_init(e->bicgstab, e->mat, e->x);
+ if (ret < 0)
+ return ret;
}
return 0;
@@ -766,8 +809,12 @@ void mg2di_egs_free(EGSContext **pctx)
free(e->scratch_line);
free(e->mat);
+ free(e->mat_f);
free(e->rhs);
+ free(e->x);
free(e->ipiv);
+
+ mg2di_bicgstab_context_free(&e->bicgstab);
}
free(ctx->priv->u_base);