From 457d8274fde5d20dc5694576b99c5541e55b4f40 Mon Sep 17 00:00:00 2001 From: Anton Khirnov Date: Mon, 18 Mar 2019 10:33:46 +0100 Subject: ell_grid_solve: use BiCGSTAB to speed up exact solves --- ell_grid_solve.c | 63 +++++++++++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 55 insertions(+), 8 deletions(-) (limited to 'ell_grid_solve.c') diff --git a/ell_grid_solve.c b/ell_grid_solve.c index cbb7d68..550f305 100644 --- a/ell_grid_solve.c +++ b/ell_grid_solve.c @@ -27,6 +27,7 @@ #include +#include "bicgstab.h" #include "common.h" #include "ell_grid_solve.h" #include "log.h" @@ -51,9 +52,13 @@ typedef struct EGSExactInternal { size_t N_ghosts; double *mat; + double *mat_f; double *rhs; + double *x; double *scratch_line; int *ipiv; + + BiCGStabContext *bicgstab; } EGSExactInternal; struct EGSInternal { @@ -511,23 +516,48 @@ static int solve_exact(EGSContext *ctx) } } + for (int i = 0; i < e->N; i++) + for (int j = i + 1; j < e->N; j++) { + double tmp = e->mat[j * e->N + i]; + e->mat[j * e->N + i] = e->mat[i * e->N + j]; + e->mat[i * e->N + j] = tmp; + } ec->time_mat_construct += gettime() - start; ec->count_mat_construct++; start = gettime(); - ret = LAPACKE_dgesv(LAPACK_ROW_MAJOR, e->N, 1, e->mat, e->N, e->ipiv, e->rhs, 1); - if (ret != 0) { - mg2di_log(&ctx->logger, MG2D_LOG_ERROR, - "Error solving the linear system: %d\n", ret); - return -EDOM; + + ret = mg2di_bicgstab_solve(e->bicgstab, e->mat, e->rhs, e->x); + if (ret < 0) { + char equed = 'N'; + double cond, ferr, berr, rpivot; + + ret = LAPACKE_dgesvx(LAPACK_COL_MAJOR, 'N', 'N', e->N, 1, + e->mat, e->N, e->mat_f, e->N, e->ipiv, &equed, NULL, NULL, + e->rhs, e->N, e->x, e->N, &cond, &ferr, &berr, &rpivot); + if (ret == 0) + ret = LAPACKE_dgetri(LAPACK_COL_MAJOR, e->N, e->mat_f, e->N, e->ipiv); + if (ret != 0) { + mg2di_log(&ctx->logger, MG2D_LOG_ERROR, + "Error solving the linear system: %d\n", ret); + return -EDOM; + } + mg2di_log(&ctx->logger, MG2D_LOG_DEBUG, + "LU factorization solution to a %zdx%zd matrix: " + "condition number %16.16g; forward error %16.16g backward error %16.16g\n", + e->N, e->N, cond, ferr, berr); + + ret = mg2di_bicgstab_init(e->bicgstab, e->mat_f, e->x); + if (ret < 0) + return ret; } + for (size_t idx1 = 0; idx1 < ctx->domain_size[1]; idx1++) + memcpy(ctx->u + idx1 * ctx->u_stride, e->x + idx1 * ctx->domain_size[0], ctx->domain_size[0] * sizeof(*e->x)); + ec->time_lin_solve += gettime() - start; ec->count_lin_solve++; - for (size_t idx1 = 0; idx1 < ctx->domain_size[1]; idx1++) - memcpy(ctx->u + idx1 * ctx->u_stride, e->rhs + idx1 * ctx->domain_size[0], ctx->domain_size[0] * sizeof(*e->rhs)); - boundaries_apply(ctx); residual_calc(ctx); @@ -692,10 +722,23 @@ static int arrays_alloc(EGSContext *ctx, const size_t domain_size[2]) e->scratch_line = calloc(e->N_ghosts, sizeof(*e->scratch_line)); e->mat = calloc(SQR(e->N), sizeof(*e->mat)); + e->mat_f = calloc(SQR(e->N), sizeof(*e->mat_f)); e->rhs = calloc(e->N, sizeof(*e->rhs)); + e->x = calloc(e->N, sizeof(*e->x)); e->ipiv = calloc(e->N, sizeof(*e->ipiv)); if (!e->scratch_line || !e->mat || !e->rhs || !e->ipiv) return -ENOMEM; + + ret = mg2di_bicgstab_context_alloc(&e->bicgstab, e->N, 64); + if (ret < 0) + return ret; + + for (int i = 0; i < e->N; i++) + e->mat[i * e->N + i] = 1.0; + + ret = mg2di_bicgstab_init(e->bicgstab, e->mat, e->x); + if (ret < 0) + return ret; } return 0; @@ -766,8 +809,12 @@ void mg2di_egs_free(EGSContext **pctx) free(e->scratch_line); free(e->mat); + free(e->mat_f); free(e->rhs); + free(e->x); free(e->ipiv); + + mg2di_bicgstab_context_free(&e->bicgstab); } free(ctx->priv->u_base); -- cgit v1.2.3