aboutsummaryrefslogtreecommitdiff
path: root/ell_grid_solve.c
diff options
context:
space:
mode:
Diffstat (limited to 'ell_grid_solve.c')
-rw-r--r--ell_grid_solve.c253
1 files changed, 224 insertions, 29 deletions
diff --git a/ell_grid_solve.c b/ell_grid_solve.c
index 7dcd41e..8c37799 100644
--- a/ell_grid_solve.c
+++ b/ell_grid_solve.c
@@ -26,9 +26,11 @@
#include <threadpool.h>
#include <lapacke.h>
+#include <mpi.h>
#include "bicgstab.h"
#include "common.h"
+#include "components.h"
#include "cpu.h"
#include "ell_grid_solve.h"
#include "log.h"
@@ -76,7 +78,10 @@ struct EGSInternal {
NDArray *rhs_base;
NDArray *residual_base;
+ NDArray *u_base;
+
NDArray *u_next_base;
+ NDArray *u_next_exterior;
NDArray *u_next;
int u_next_valid;
@@ -98,6 +103,17 @@ struct EGSInternal {
EGSExactInternal e;
TPContext *tp_internal;
+
+ MPI_Comm comm;
+ DomainGeometry *dg;
+ unsigned int local_component;
+
+ int *sync_sendcounts;
+ int *sync_senddispl;
+ MPI_Datatype *sync_sendtypes;
+ int *sync_recvcounts;
+ int *sync_recvdispl;
+ MPI_Datatype *sync_recvtypes;
};
static const double fd_denoms[][MG2D_DIFF_COEFF_NB] = {
@@ -119,6 +135,19 @@ static const double fd_denoms[][MG2D_DIFF_COEFF_NB] = {
},
};
+static void boundaries_sync(EGSContext *ctx, NDArray *a_dst)
+{
+ EGSInternal *priv = ctx->priv;
+
+ mg2di_timer_start(&ctx->timer_mpi_sync);
+
+ MPI_Alltoallw(a_dst->data, priv->sync_sendcounts, priv->sync_senddispl, priv->sync_sendtypes,
+ a_dst->data, priv->sync_recvcounts, priv->sync_recvdispl, priv->sync_recvtypes,
+ priv->comm);
+
+ mg2di_timer_stop(&ctx->timer_mpi_sync);
+}
+
static void residual_calc(EGSContext *ctx, int export_res)
{
EGSInternal *priv = ctx->priv;
@@ -230,7 +259,8 @@ static void boundaries_apply(EGSContext *ctx, NDArray *a_dst, int init)
double *dst = a_dst->data + mg2d_bnd_is_upper(i) * ((size_offset - 1) * stride_offset);
const ptrdiff_t dst_strides[] = { stride_boundary, mg2d_bnd_out_dir(i) * stride_offset };
- if (bnd->type != bnd_type_order[order_idx])
+ if (bnd->type != bnd_type_order[order_idx] ||
+ !priv->dg->components[priv->local_component].bnd_is_outer[i])
continue;
switch (bnd->type) {
@@ -291,6 +321,10 @@ static void boundaries_apply(EGSContext *ctx, NDArray *a_dst, int init)
}
}
mg2di_timer_stop(&ctx->timer_bnd_corners);
+
+ if (priv->dg->nb_components > 1)
+ boundaries_sync(ctx, a_dst);
+
mg2di_timer_stop(&ctx->timer_bnd);
}
@@ -331,12 +365,15 @@ static int solve_relax_step(EGSContext *ctx, int export_res)
int u_next_valid = priv->u_next_valid;
if (u_next_valid) {
- NDArray *tmp = ctx->u;
- NDArray *tmp_base = ctx->u_base;
- ctx->u = priv->u_next;
- ctx->u_base = priv->u_next_base;
- priv->u_next = tmp;
- priv->u_next_base = tmp_base;
+ NDArray *tmp = ctx->u;
+ NDArray *tmp_ext = ctx->u_exterior;
+ NDArray *tmp_base = priv->u_base;
+ ctx->u = priv->u_next;
+ ctx->u_exterior = priv->u_next_exterior;
+ priv->u_base = priv->u_next_base;
+ priv->u_next = tmp;
+ priv->u_next_exterior = tmp_ext;
+ priv->u_next_base = tmp_base;
priv->u_next_valid = 0;
}
@@ -352,6 +389,8 @@ static int solve_relax_step(EGSContext *ctx, int export_res)
}
residual_calc(ctx, 1);
+ if (priv->dg->nb_components > 1)
+ boundaries_sync(ctx, ctx->residual);
} else {
mg2di_assert(u_next_valid);
residual_calc(ctx, 0);
@@ -799,6 +838,7 @@ int mg2di_egs_init(EGSContext *ctx, int flags)
{
EGSInternal *priv = ctx->priv;
EGSRelaxContext *r = ctx->relax;
+ DomainComponent *dc = &priv->dg->components[priv->local_component];
int ret = 0;
mg2di_timer_start(&ctx->timer_solve);
@@ -837,22 +877,22 @@ int mg2di_egs_init(EGSContext *ctx, int flags)
arg.dc = MG2D_DIFF_COEFF_00;
arg.fact = 1.0 / fd_denoms[ctx->fd_stencil - 1][MG2D_DIFF_COEFF_00];
- tp_execute(ctx->tp, ctx->domain_size[0], init_diff_coeffs_task, &arg);
+ tp_execute(ctx->tp, ctx->domain_size[1], init_diff_coeffs_task, &arg);
arg.dc = MG2D_DIFF_COEFF_10;
arg.fact = 1.0 / (fd_denoms[ctx->fd_stencil - 1][MG2D_DIFF_COEFF_10] * ctx->step[0]);
- tp_execute(ctx->tp, ctx->domain_size[0], init_diff_coeffs_task, &arg);
+ tp_execute(ctx->tp, ctx->domain_size[1], init_diff_coeffs_task, &arg);
arg.dc = MG2D_DIFF_COEFF_01;
arg.fact = 1.0 / (fd_denoms[ctx->fd_stencil - 1][MG2D_DIFF_COEFF_01] * ctx->step[1]);
- tp_execute(ctx->tp, ctx->domain_size[0], init_diff_coeffs_task, &arg);
+ tp_execute(ctx->tp, ctx->domain_size[1], init_diff_coeffs_task, &arg);
arg.dc = MG2D_DIFF_COEFF_20;
arg.fact = 1.0 / (fd_denoms[ctx->fd_stencil - 1][MG2D_DIFF_COEFF_20] * SQR(ctx->step[0]));
- tp_execute(ctx->tp, ctx->domain_size[0], init_diff_coeffs_task, &arg);
+ tp_execute(ctx->tp, ctx->domain_size[1], init_diff_coeffs_task, &arg);
arg.dc = MG2D_DIFF_COEFF_02;
arg.fact = 1.0 / (fd_denoms[ctx->fd_stencil - 1][MG2D_DIFF_COEFF_02] * SQR(ctx->step[1]));
- tp_execute(ctx->tp, ctx->domain_size[0], init_diff_coeffs_task, &arg);
+ tp_execute(ctx->tp, ctx->domain_size[1], init_diff_coeffs_task, &arg);
arg.dc = MG2D_DIFF_COEFF_11;
arg.fact = 1.0 / (fd_denoms[ctx->fd_stencil - 1][MG2D_DIFF_COEFF_11] * ctx->step[0] * ctx->step[1]);
- tp_execute(ctx->tp, ctx->domain_size[0], init_diff_coeffs_task, &arg);
+ tp_execute(ctx->tp, ctx->domain_size[1], init_diff_coeffs_task, &arg);
}
if (!(flags & EGS_INIT_FLAG_SAME_DIFF_COEFFS))
@@ -870,13 +910,18 @@ int mg2di_egs_init(EGSContext *ctx, int flags)
priv->residual_calc_size[0] = ctx->domain_size[0];
priv->residual_calc_size[1] = ctx->domain_size[1];
- priv->residual_calc_offset[0] = ctx->boundaries[MG2D_BOUNDARY_1L]->type == MG2D_BC_TYPE_FIXVAL;
- priv->residual_calc_offset[1] = ctx->boundaries[MG2D_BOUNDARY_0L]->type == MG2D_BC_TYPE_FIXVAL;
+ priv->residual_calc_offset[0] = ctx->boundaries[MG2D_BOUNDARY_1L]->type == MG2D_BC_TYPE_FIXVAL &&
+ dc->bnd_is_outer[MG2D_BOUNDARY_1L];
+ priv->residual_calc_offset[1] = ctx->boundaries[MG2D_BOUNDARY_0L]->type == MG2D_BC_TYPE_FIXVAL &&
+ dc->bnd_is_outer[MG2D_BOUNDARY_0L];
for (int bnd_idx = 0; bnd_idx < ARRAY_ELEMS(ctx->boundaries); bnd_idx++) {
MG2DBoundary *bnd = ctx->boundaries[bnd_idx];
const int ci = mg2d_bnd_coord_idx(bnd_idx);
+ if (!dc->bnd_is_outer[bnd_idx])
+ continue;
+
if (bnd->type == MG2D_BC_TYPE_FIXVAL) {
double maxval = 0.0;
@@ -923,6 +968,7 @@ finish:
static int arrays_alloc(EGSContext *ctx, const size_t domain_size[2])
{
EGSInternal *priv = ctx->priv;
+ const DomainComponent *dc = &priv->dg->components[priv->local_component];
const size_t ghosts = FD_STENCIL_MAX;
const size_t size_padded[2] = {
@@ -935,15 +981,23 @@ static int arrays_alloc(EGSContext *ctx, const size_t domain_size[2])
};
const Slice slice[2] = { SLICE(ghosts, -ghosts, 1),
SLICE(ghosts, -ghosts, 1) };
+ Slice slice_exterior[2] = {
+ SLICE(dc->bnd_is_outer[MG2D_BOUNDARY_1L] ? 0 : ghosts,
+ dc->bnd_is_outer[MG2D_BOUNDARY_1U] ? 0 : -ghosts, 1),
+ SLICE(dc->bnd_is_outer[MG2D_BOUNDARY_0L] ? 0 : ghosts,
+ dc->bnd_is_outer[MG2D_BOUNDARY_0U] ? 0 : -ghosts, 1),
+ };
int ret;
- ret = mg2di_ndarray_alloc(&ctx->u_base, 2, size_padded, NDARRAY_ALLOC_ZERO);
+ ret = mg2di_ndarray_alloc(&priv->u_base, 2, size_padded, NDARRAY_ALLOC_ZERO);
if (ret < 0)
return ret;
- ret = mg2di_ndarray_slice(&ctx->u, ctx->u_base,
- (Slice [2]){ SLICE(ghosts, size_padded[0] - ghosts, 1),
- SLICE(ghosts, -ghosts, 1) });
+ ret = mg2di_ndarray_slice(&ctx->u, priv->u_base, slice);
+ if (ret < 0)
+ return ret;
+
+ ret = mg2di_ndarray_slice(&ctx->u_exterior, priv->u_base, slice_exterior);
if (ret < 0)
return ret;
@@ -951,9 +1005,11 @@ static int arrays_alloc(EGSContext *ctx, const size_t domain_size[2])
if (ret < 0)
return ret;
- ret = mg2di_ndarray_slice(&priv->u_next, priv->u_next_base,
- (Slice [2]){ SLICE(ghosts, size_padded[0] - ghosts, 1),
- SLICE(ghosts, -ghosts, 1) });
+ ret = mg2di_ndarray_slice(&priv->u_next, priv->u_next_base, slice);
+ if (ret < 0)
+ return ret;
+
+ ret = mg2di_ndarray_slice(&priv->u_next_exterior, priv->u_next_base, slice_exterior);
if (ret < 0)
return ret;
@@ -1010,7 +1066,7 @@ static int arrays_alloc(EGSContext *ctx, const size_t domain_size[2])
return 0;
}
-EGSContext *mg2di_egs_alloc(size_t domain_size[2])
+static EGSContext *egs_alloc(const DomainGeometry *dg, unsigned int local_component)
{
EGSContext *ctx;
int ret;
@@ -1032,21 +1088,28 @@ EGSContext *mg2di_egs_alloc(size_t domain_size[2])
if (!ctx->exact)
goto fail;
+ ret = mg2di_dg_copy(&ctx->priv->dg, dg);
+ if (ret < 0)
+ goto fail;
+
+ ctx->priv->local_component = local_component;
+ ctx->priv->comm = MPI_COMM_NULL;
+
mg2di_timer_init(&ctx->exact->timer_mat_construct);
mg2di_timer_init(&ctx->exact->timer_bicgstab);
mg2di_timer_init(&ctx->exact->timer_lu_solve);
mg2di_timer_init(&ctx->exact->timer_export);
- if (!domain_size[0] || !domain_size[1] ||
- domain_size[0] > SIZE_MAX / domain_size[1])
+ if (!dg->domain_size[0] || !dg->domain_size[1] ||
+ dg->domain_size[0] > SIZE_MAX / dg->domain_size[1])
goto fail;
- ret = arrays_alloc(ctx, domain_size);
+ ret = arrays_alloc(ctx, dg->components[local_component].interior.size);
if (ret < 0)
goto fail;
- ctx->domain_size[0] = domain_size[0];
- ctx->domain_size[1] = domain_size[1];
+ ctx->domain_size[0] = dg->components[local_component].interior.size[0];
+ ctx->domain_size[1] = dg->components[local_component].interior.size[1];
ctx->priv->rescalc = mg2di_residual_calc_alloc();
if (!ctx->priv->rescalc)
@@ -1060,6 +1123,7 @@ EGSContext *mg2di_egs_alloc(size_t domain_size[2])
mg2di_timer_init(&ctx->timer_res_calc);
mg2di_timer_init(&ctx->timer_init);
mg2di_timer_init(&ctx->timer_solve);
+ mg2di_timer_init(&ctx->timer_mpi_sync);
return ctx;
fail:
@@ -1067,6 +1131,113 @@ fail:
return NULL;
}
+EGSContext *mg2di_egs_alloc(size_t domain_size[2])
+{
+ EGSContext *ctx;
+ DomainGeometry *dg;
+
+ dg = mg2di_dg_alloc(1);
+ if (!dg)
+ return NULL;
+
+ dg->domain_size[0] = domain_size[0];
+ dg->domain_size[1] = domain_size[1];
+
+ dg->components[0].interior.start[0] = 0;
+ dg->components[0].interior.start[1] = 0;
+ dg->components[0].interior.size[0] = domain_size[0];
+ dg->components[0].interior.size[1] = domain_size[1];
+
+ dg->components[0].exterior.start[0] = -FD_STENCIL_MAX;
+ dg->components[0].exterior.start[1] = -FD_STENCIL_MAX;
+ dg->components[0].exterior.size[0] = domain_size[0] + 2 * FD_STENCIL_MAX;
+ dg->components[0].exterior.size[1] = domain_size[1] + 2 * FD_STENCIL_MAX;
+
+ for (int i = 0; i < ARRAY_ELEMS(dg->components[0].bnd_is_outer); i++)
+ dg->components[0].bnd_is_outer[i] = 1;
+
+ ctx = egs_alloc(dg, 0);
+ mg2di_dg_free(&dg);
+ if (!ctx)
+ return NULL;
+
+ return ctx;
+}
+
+EGSContext *mg2di_egs_alloc_mpi(MPI_Comm comm, const DomainGeometry *dg)
+{
+ EGSContext *ctx = NULL;
+ Rect *overlaps_recv = NULL, *overlaps_send = NULL;
+ ptrdiff_t *lo;
+ int local_component;
+ int ret;
+
+ MPI_Comm_rank(comm, &local_component);
+
+ overlaps_recv = calloc(dg->nb_components, sizeof(*overlaps_recv));
+ overlaps_send = calloc(dg->nb_components, sizeof(*overlaps_send));
+ if (!overlaps_recv || !overlaps_send)
+ goto fail;
+
+ ret = mg2di_dg_edge_overlaps(overlaps_recv, overlaps_send,
+ dg, local_component, FD_STENCIL_MAX);
+ if (ret < 0)
+ goto fail;
+
+ ctx = egs_alloc(dg, local_component);
+
+ ctx->priv->comm = comm;
+
+ ctx->priv->sync_sendtypes = calloc(dg->nb_components, sizeof(*ctx->priv->sync_sendtypes));
+ ctx->priv->sync_senddispl = calloc(dg->nb_components, sizeof(*ctx->priv->sync_senddispl));
+ ctx->priv->sync_sendcounts = calloc(dg->nb_components, sizeof(*ctx->priv->sync_sendcounts));
+ ctx->priv->sync_recvtypes = calloc(dg->nb_components, sizeof(*ctx->priv->sync_recvtypes));
+ ctx->priv->sync_recvdispl = calloc(dg->nb_components, sizeof(*ctx->priv->sync_recvdispl));
+ ctx->priv->sync_recvcounts = calloc(dg->nb_components, sizeof(*ctx->priv->sync_recvcounts));
+ if (!ctx->priv->sync_sendtypes || !ctx->priv->sync_senddispl || !ctx->priv->sync_sendcounts ||
+ !ctx->priv->sync_recvtypes || !ctx->priv->sync_recvdispl || !ctx->priv->sync_recvcounts)
+ goto fail;
+
+ lo = dg->components[local_component].interior.start;
+
+ /* construct the send/receive parameters */
+ for (unsigned int i = 0; i < dg->nb_components; i++) {
+ if (i == local_component) {
+ MPI_Type_dup(MPI_INT, &ctx->priv->sync_sendtypes[i]);
+ MPI_Type_dup(MPI_INT, &ctx->priv->sync_recvtypes[i]);
+ ctx->priv->sync_sendcounts[i] = 0;
+ ctx->priv->sync_recvcounts[i] = 0;
+ ctx->priv->sync_senddispl[i] = 0;
+ ctx->priv->sync_recvdispl[i] = 0;
+
+ continue;
+ }
+
+ /* receive */
+ MPI_Type_vector(overlaps_recv[i].size[1], overlaps_recv[i].size[0],
+ ctx->u->stride[0], MPI_DOUBLE, &ctx->priv->sync_recvtypes[i]);
+ MPI_Type_commit(&ctx->priv->sync_recvtypes[i]);
+ ctx->priv->sync_recvcounts[i] = 1;
+ ctx->priv->sync_recvdispl[i] = ((overlaps_recv[i].start[1] - lo[1]) * ctx->u->stride[0] +
+ (overlaps_recv[i].start[0] - lo[0])) * sizeof(*ctx->u->data);
+
+ /* send */
+ MPI_Type_vector(overlaps_send[i].size[1], overlaps_send[i].size[0],
+ ctx->u->stride[0], MPI_DOUBLE, &ctx->priv->sync_sendtypes[i]);
+ MPI_Type_commit(&ctx->priv->sync_sendtypes[i]);
+ ctx->priv->sync_sendcounts[i] = 1;
+ ctx->priv->sync_senddispl[i] = ((overlaps_send[i].start[1] - lo[1]) * ctx->u->stride[0] +
+ (overlaps_send[i].start[0] - lo[0])) * sizeof(*ctx->u->data);
+ }
+
+ return ctx;
+fail:
+ free(overlaps_recv);
+ free(overlaps_send);
+ mg2di_egs_free(&ctx);
+ return NULL;
+}
+
void mg2di_egs_free(EGSContext **pctx)
{
EGSContext *ctx = *pctx;
@@ -1082,8 +1253,10 @@ void mg2di_egs_free(EGSContext **pctx)
exact_arrays_free(ctx);
mg2di_ndarray_free(&ctx->u);
- mg2di_ndarray_free(&ctx->u_base);
+ mg2di_ndarray_free(&ctx->u_exterior);
+ mg2di_ndarray_free(&ctx->priv->u_base);
mg2di_ndarray_free(&ctx->priv->u_next);
+ mg2di_ndarray_free(&ctx->priv->u_next_exterior);
mg2di_ndarray_free(&ctx->priv->u_next_base);
mg2di_ndarray_free(&ctx->rhs);
@@ -1105,6 +1278,28 @@ void mg2di_egs_free(EGSContext **pctx)
for (int i = 0; i < ARRAY_ELEMS(ctx->boundaries); i++)
mg2di_bc_free(&ctx->boundaries[i]);
+ if (ctx->priv->sync_sendtypes) {
+ for (int i = 0; i < ctx->priv->dg->nb_components; i++) {
+ if (ctx->priv->sync_sendtypes[i])
+ MPI_Type_free(&ctx->priv->sync_sendtypes[i]);
+ }
+ }
+ if (ctx->priv->sync_recvtypes) {
+ for (int i = 0; i < ctx->priv->dg->nb_components; i++) {
+ if (ctx->priv->sync_recvtypes[i])
+ MPI_Type_free(&ctx->priv->sync_recvtypes[i]);
+ }
+ }
+ free(ctx->priv->sync_sendtypes);
+ free(ctx->priv->sync_recvtypes);
+
+ free(ctx->priv->sync_sendcounts);
+ free(ctx->priv->sync_senddispl);
+ free(ctx->priv->sync_recvcounts);
+ free(ctx->priv->sync_recvdispl);
+
+ mg2di_dg_free(&ctx->priv->dg);
+
tp_free(&ctx->priv->tp_internal);
free(ctx->priv);