From 7dcabe1f8d406b62df4141883e503cabb39b8d45 Mon Sep 17 00:00:00 2001 From: Anton Khirnov Date: Wed, 30 Jan 2019 12:51:18 +0100 Subject: mg2d: print stats properly for exact solves --- ell_grid_solve.c | 33 ++++++++++++++++++++++++--- ell_grid_solve.h | 10 ++++++++- mg2d.c | 68 +++++++++++++++++++++++++++++++++++++++++--------------- 3 files changed, 89 insertions(+), 22 deletions(-) diff --git a/ell_grid_solve.c b/ell_grid_solve.c index 1893214..13d1854 100644 --- a/ell_grid_solve.c +++ b/ell_grid_solve.c @@ -387,9 +387,13 @@ static void fill_mat_s2(double *mat_row, double **diff_coeffs, double *fd_factor static int solve_exact(EGSContext *ctx) { EGSInternal *priv = ctx->priv; + EGSExactContext *ec = ctx->solver_data; EGSExactInternal *e = &priv->e; + int64_t start; int ret; + start = gettime(); + memset(e->mat, 0, SQR(e->N) * sizeof(*e->mat)); for (ptrdiff_t idx1 = 0; idx1 < ctx->domain_size[1]; idx1++) { @@ -532,6 +536,10 @@ static int solve_exact(EGSContext *ctx) } } + ec->time_mat_construct += gettime() - start; + ec->count_mat_construct++; + + start = gettime(); ret = LAPACKE_dgesv(LAPACK_ROW_MAJOR, e->N, 1, e->mat, e->N, e->ipiv, e->rhs, 1); if (ret != 0) { mg2di_log(&ctx->logger, MG2D_LOG_ERROR, @@ -539,6 +547,9 @@ static int solve_exact(EGSContext *ctx) return -EDOM; } + ec->time_lin_solve += gettime() - start; + ec->count_lin_solve++; + for (size_t idx1 = 0; idx1 < ctx->domain_size[1]; idx1++) memcpy(ctx->u + idx1 * ctx->u_stride, e->rhs + idx1 * ctx->domain_size[0], ctx->domain_size[0] * sizeof(*e->rhs)); @@ -550,19 +561,30 @@ static int solve_exact(EGSContext *ctx) int mg2di_egs_solve(EGSContext *ctx) { + int64_t start; + int ret; + + start = gettime(); + switch (ctx->solver_type) { - case EGS_SOLVER_RELAXATION: return solve_relax_step(ctx); - case EGS_SOLVER_EXACT: return solve_exact(ctx); + case EGS_SOLVER_RELAXATION: ret = solve_relax_step(ctx); break; + case EGS_SOLVER_EXACT: ret = solve_exact(ctx); break; + default: ret = -EINVAL; } - return -EINVAL; + ctx->time_total += gettime() - start; + + return ret; } int mg2di_egs_init(EGSContext *ctx) { EGSInternal *priv = ctx->priv; + int64_t start; int ret; + start = gettime(); + if (ctx->solver_type == EGS_SOLVER_EXACT) { switch (ctx->fd_stencil) { case 1: priv->e.fill_mat = fill_mat_s1; break; @@ -637,6 +659,8 @@ int mg2di_egs_init(EGSContext *ctx) boundaries_apply(ctx); residual_calc(ctx); + ctx->time_total += gettime() - start; + return 0; } @@ -727,6 +751,9 @@ EGSContext *mg2di_egs_alloc(enum EGSType type, size_t domain_size[2]) goto fail; break; case EGS_SOLVER_EXACT: + ctx->solver_data = calloc(1, sizeof(EGSExactContext)); + if (!ctx->solver_data) + goto fail; break; default: goto fail; } diff --git a/ell_grid_solve.h b/ell_grid_solve.h index 4ee6cb4..5455c13 100644 --- a/ell_grid_solve.h +++ b/ell_grid_solve.h @@ -55,7 +55,7 @@ enum EGSType { /** * Solve the equation exactly by contructing a linear system and solving it with LAPACK. * - * solver_data is NULL + * solver_data is EGSExactContext * mg2di_egs_solve() solves the discretized system exactly (up to roundoff error) */ EGS_SOLVER_EXACT, @@ -87,6 +87,13 @@ typedef struct EGSRelaxContext { int64_t time_correct; } EGSRelaxContext; +typedef struct EGSExactContext { + int64_t count_mat_construct; + int64_t time_mat_construct; + int64_t count_lin_solve; + int64_t time_lin_solve; +} EGSExactContext; + typedef struct EGSContext { enum EGSType solver_type; @@ -202,6 +209,7 @@ typedef struct EGSContext { int64_t count_boundaries; int64_t time_res_calc; int64_t count_res; + int64_t time_total; } EGSContext; /** diff --git a/mg2d.c b/mg2d.c index cbb2797..453a9c4 100644 --- a/mg2d.c +++ b/mg2d.c @@ -764,37 +764,69 @@ void mg2d_print_stats(MG2DContext *ctx, const char *prefix) prefix, priv->count_solve, priv->time_solve / 1e6, priv->time_solve / 1e3 / priv->count_solve); while (level) { + char buf[1024], *p; + int ret; + EGSRelaxContext *r = NULL; + EGSExactContext *e = NULL; + int64_t level_total = level->time_relax + level->time_prolong + level->time_restrict + level->time_correct + level->time_reinit; - int64_t relax_total = (r ? r->time_correct : 0) + level->solver->time_res_calc + level->solver->time_boundaries; if (level->solver->solver_type == EGS_SOLVER_RELAXATION) r = level->solver->solver_data; + else if (level->solver->solver_type == EGS_SOLVER_EXACT) + e = level->solver->solver_data; level_total = level->time_relax + level->time_prolong + level->time_restrict + level->time_correct + level->time_reinit; - relax_total = (r ? r->time_correct : 0) + level->solver->time_res_calc + level->solver->time_boundaries; levels_total += level_total; - mg2di_log(&priv->logger, MG2D_LOG_VERBOSE, - "%s%2.2f%% level %d: %ld cycles %g s total time %g ms avg per call || " - "%2.2f%% relax %2.2f%% prolong %2.2f%% restrict %2.2f%% correct %2.2f%% reinit || " - "%2.2f%% residual %2.2f%% correct %2.2f%% boundaries ||" - "\n", - prefix, level_total * 100.0 / priv->time_solve, level->depth, level->count_cycles, - level_total / 1e6, level_total / 1e3 / level->count_cycles, - level->time_relax * 100.0 / level_total, - level->time_prolong * 100.0 / level_total, - level->time_restrict * 100.0 / level_total, - level->time_correct * 100.0 / level_total, - level->time_reinit * 100.0 / level_total, - level->solver->time_res_calc * 100.0 / relax_total, - (r ? r->time_correct : 0) * 100.0 / relax_total, - level->solver->time_boundaries * 100.0 / relax_total - ); + p = buf; + + ret = snprintf(p, sizeof(buf) - (p - buf), + "%2.2f%% level %d: %ld cycles %g s total time %g ms avg per call", + level_total * 100.0 / priv->time_solve, level->depth, level->count_cycles, + level_total / 1e6, level_total / 1e3 / level->count_cycles); + if (ret > 0) + p += ret; + + if (level->child) { + ret = snprintf(p, sizeof(buf) - (p - buf), + "||%2.2f%% relax %2.2f%% prolong %2.2f%% restrict %2.2f%% correct %2.2f%% reinit", + level->time_relax * 100.0 / level_total, + level->time_prolong * 100.0 / level_total, + level->time_restrict * 100.0 / level_total, + level->time_correct * 100.0 / level_total, + level->time_reinit * 100.0 / level_total); + if (ret > 0) + p += ret; + } + + ret = snprintf(p, sizeof(buf) - (p - buf), + "||%2.2f%% residual %2.2f%% boundaries", + level->solver->time_res_calc * 100.0 / level->solver->time_total, + level->solver->time_boundaries * 100.0 / level->solver->time_total); + if (ret > 0) + p += ret; + + if (r) { + ret = snprintf(p, sizeof(buf) - (p - buf), + " %2.2f%% correct", + r->time_correct * 100.0 / level->solver->time_total); + if (ret > 0) + p += ret; + } else if (e) { + ret = snprintf(p, sizeof(buf) - (p - buf), + " %2.2f%% matrix construct %2.2f%% linear solve", + e->time_mat_construct * 100.0 / level->solver->time_total, + e->time_lin_solve * 100.0 / level->solver->time_total); + if (ret > 0) + p += ret; + } + mg2di_log(&priv->logger, MG2D_LOG_VERBOSE, "%s%s\n", prefix, buf); level = level->child; } -- cgit v1.2.3