summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAnton Khirnov <anton@khirnov.net>2018-12-28 10:11:47 +0100
committerAnton Khirnov <anton@khirnov.net>2018-12-28 10:11:47 +0100
commit271c35c5dbc234cc1cadb8ad8658ce085500afda (patch)
tree11f58377d52ac73aa5ede8ca6dc026bb8da99a08
parent47eb01846991ea22e023fdca5e9c0d0068e714ea (diff)
mg2d: factor out restriction/prolongation calls
Also generalize the check for full interpolation vs special-cased restrict/prolong functions.
-rw-r--r--mg2d.c80
1 files changed, 50 insertions, 30 deletions
diff --git a/mg2d.c b/mg2d.c
index e77c3d1..692d4c1 100644
--- a/mg2d.c
+++ b/mg2d.c
@@ -77,11 +77,6 @@ static double findmax(double *arr, size_t len)
return ret;
}
-static int is_power2(int n)
-{
- return !(n & (n - 1));
-}
-
static void log_callback(MG2DLogger *log, int level, const char *fmt, va_list vl)
{
MG2DContext *ctx = log->opaque;
@@ -228,6 +223,34 @@ static void coarse_correct_task(void *arg, unsigned int job_idx, unsigned int th
}
}
+static void restrict_residual(MG2DContext *ctx, MG2DLevel *dst, MG2DLevel *src)
+{
+ EllRelaxContext *s_src = src->solver;
+ EllRelaxContext *s_dst = dst->solver;
+ if (s_src->domain_size[0] == 2 * (s_dst->domain_size[0] - 1) + 1) {
+ mg_restrict_fw(s_dst, s_dst->rhs, s_dst->rhs_stride,
+ s_src, s_src->residual, s_src->residual_stride);
+ } else {
+ mg_interp_bilinear(ctx->priv->tp,
+ s_dst, s_dst->rhs, s_dst->rhs_stride,
+ s_src, s_src->residual, s_src->residual_stride);
+ }
+}
+
+static void prolong_solution(MG2DContext *ctx, MG2DLevel *dst, MG2DLevel *src)
+{
+ EllRelaxContext *s_src = src->solver;
+ EllRelaxContext *s_dst = dst->solver;
+ if (s_dst->domain_size[0] == 2 * (s_src->domain_size[0] - 1) + 1) {
+ mg_prolongate(s_dst, dst->prolong_tmp, dst->prolong_tmp_stride,
+ s_src, s_src->u, s_src->u_stride);
+ } else {
+ mg_interp_bilinear(ctx->priv->tp,
+ s_dst, dst->prolong_tmp, dst->prolong_tmp_stride,
+ s_src, s_src->u, s_src->u_stride);
+ }
+}
+
static int mg_solve_subgrid(MG2DContext *ctx, MG2DLevel *level)
{
int ret;
@@ -274,13 +297,7 @@ static int mg_solve_subgrid(MG2DContext *ctx, MG2DLevel *level)
/* restrict the residual as to the coarser-level rhs */
start = gettime();
- if (!is_power2(level->solver->domain_size[0] - 1)) {
- mg_interp_bilinear(ctx->priv->tp, level->child->solver, level->child->solver->rhs, level->child->solver->rhs_stride,
- level->solver, level->solver->residual, level->solver->residual_stride);
- } else {
- mg_restrict_fw(level->child->solver, level->child->solver->rhs, level->child->solver->rhs_stride,
- level->solver, level->solver->residual, level->solver->residual_stride);
- }
+ restrict_residual(ctx, level->child, level);
level->time_restrict += gettime() - start;
/* solve on the coarser level */
@@ -290,13 +307,7 @@ static int mg_solve_subgrid(MG2DContext *ctx, MG2DLevel *level)
/* prolongate the coarser-level correction */
start = gettime();
- if (!is_power2(level->solver->domain_size[0] - 1)) {
- mg_interp_bilinear(ctx->priv->tp, level->solver, level->prolong_tmp, level->prolong_tmp_stride,
- level->child->solver, level->child->solver->u, level->child->solver->u_stride);
- } else {
- mg_prolongate(level->solver, level->prolong_tmp, level->prolong_tmp_stride,
- level->child->solver, level->child->solver->u, level->child->solver->u_stride);
- }
+ prolong_solution(ctx, level, level->child);
level->time_prolong += gettime() - start;
/* apply the correction */
@@ -346,6 +357,24 @@ static void bnd_copy(EllRelaxBoundary *bdst, const double *src, ptrdiff_t src_st
}
}
+static void restrict_diff_coeffs(MG2DContext *ctx, MG2DLevel *dst, MG2DLevel *src)
+{
+ EllRelaxContext *s_src = src->solver;
+ EllRelaxContext *s_dst = dst->solver;
+ if (s_src->domain_size[0] == 2 * (s_dst->domain_size[0] - 1) + 1) {
+ for (int i = 0; i < ARRAY_ELEMS(s_src->diff_coeffs); i++) {
+ mg_restrict_inject(s_dst, s_dst->diff_coeffs[i], s_dst->diff_coeffs_stride,
+ s_src, s_src->diff_coeffs[i], s_src->diff_coeffs_stride);
+ }
+ } else {
+ for (int i = 0; i < ARRAY_ELEMS(s_src->diff_coeffs); i++) {
+ mg_interp_bilinear(ctx->priv->tp,
+ s_dst, s_dst->diff_coeffs[i], s_dst->diff_coeffs_stride,
+ s_src, s_src->diff_coeffs[i], s_src->diff_coeffs_stride);
+ }
+ }
+}
+
static int mg_levels_init(MG2DContext *ctx)
{
MG2DInternal *priv = ctx->priv;
@@ -363,17 +392,8 @@ static int mg_levels_init(MG2DContext *ctx)
}
/* Set the equation coefficients. */
- if (prev) {
- for (int i = 0; i < ARRAY_ELEMS(prev->solver->diff_coeffs); i++) {
- if (!is_power2(prev->solver->domain_size[0] - 1)) {
- mg_interp_bilinear(priv->tp, cur->solver, cur->solver->diff_coeffs[i], cur->solver->diff_coeffs_stride,
- prev->solver, prev->solver->diff_coeffs[i], prev->solver->diff_coeffs_stride);
- } else {
- mg_restrict_inject(cur->solver, cur->solver->diff_coeffs[i], cur->solver->diff_coeffs_stride,
- prev->solver, prev->solver->diff_coeffs[i], prev->solver->diff_coeffs_stride);
- }
- }
- }
+ if (prev)
+ restrict_diff_coeffs(ctx, cur, prev);
for (int i = 0; i < ARRAY_ELEMS(cur->solver->boundaries); i++) {
MG2DBoundary *bsrc = ctx->boundaries[i];