aboutsummaryrefslogtreecommitdiff
path: root/Carpet/LoopControl/src/loopcontrol.c
diff options
context:
space:
mode:
authorErik Schnetter <schnetter@cct.lsu.edu>2009-09-21 11:28:26 -0500
committerBarry Wardell <barry.wardell@gmail.com>2011-12-14 16:45:09 +0000
commit29e373ad99f97175fd6443dd7e9307e10cc125f2 (patch)
tree1b424d49d8e2dd079b78821c5fa842e559c91732 /Carpet/LoopControl/src/loopcontrol.c
parentf5b0376823ed3658847fdf2f3447c87fcdec15ff (diff)
LoopControl: Implement cache-collaborative multi-threading
Ignore-this: 5169757c7749834ae595d4d73b39220 Add a new, additional feature to LoopControl: different threads can work on small regions that are likely to use the same cache entries as other threads, trying to reduce cache pressure. This makes sense mostly when the regions are still expensive although they are small, e.g. for the BSSN RHS.
Diffstat (limited to 'Carpet/LoopControl/src/loopcontrol.c')
-rw-r--r--Carpet/LoopControl/src/loopcontrol.c187
1 files changed, 138 insertions, 49 deletions
diff --git a/Carpet/LoopControl/src/loopcontrol.c b/Carpet/LoopControl/src/loopcontrol.c
index afc19fec3..34bfe4d6b 100644
--- a/Carpet/LoopControl/src/loopcontrol.c
+++ b/Carpet/LoopControl/src/loopcontrol.c
@@ -63,11 +63,11 @@ lc_statmap_t * lc_statmap_list = NULL;
/* Find all possible thread topologies */
/* This finds all possible thread topologies which can be expressed as
- NIxNJxNK. More complex topologies, e.g. based on a recursive
- subdiviston, are not considered (and cannot be expressed in the
- data structures used in LoopControl). I think more complex
- topologies are not necessary, since the number of treads is usually
- quite small and contains many small factors in its prime
+ NIxNJxNK x NIIxNJJxNKK. More complex topologies, e.g. based on a
+ recursive subdiviston, are not considered (and cannot be expressed
+ with the data structures used in LoopControl). I expect that more
+ complex topologies are not necessary, since the number of treads is
+ usually quite small and contains many small factors in its prime
decomposition. */
static
void
@@ -77,17 +77,36 @@ find_thread_topologies (lc_topology_t * restrict const topologies,
int const nthreads)
{
* ntopologies = 0;
+
for (int nk=1; nk<=nthreads; ++nk) {
if (nthreads % nk == 0) {
for (int nj=1; nj<=nthreads/nk; ++nj) {
if (nthreads % (nj*nk) == 0) {
- int const ni = nthreads/(nj*nk);
- if (nthreads == ni*nj*nk) {
- assert (* ntopologies < maxntopologies);
- topologies[* ntopologies].nthreads[0] = ni;
- topologies[* ntopologies].nthreads[1] = nj;
- topologies[* ntopologies].nthreads[2] = nk;
- ++ * ntopologies;
+ for (int ni=1; ni<=nthreads/(nj*nk); ++ni) {
+ if (nthreads % (ni*nj*nk) == 0) {
+
+ int const nithreads = nthreads/(ni*nj*nk);
+ for (int nkk=1; nkk<=nithreads; ++nkk) {
+ if (nithreads % nkk == 0) {
+ for (int njj=1; njj<=nithreads/nkk; ++njj) {
+ if (nithreads % (njj*nkk) == 0) {
+ int const nii = nithreads/(njj*nkk);
+
+ assert (* ntopologies < maxntopologies);
+ topologies[* ntopologies].nthreads[0][0] = ni;
+ topologies[* ntopologies].nthreads[0][1] = nj;
+ topologies[* ntopologies].nthreads[0][2] = nk;
+ topologies[* ntopologies].nthreads[1][0] = nii;
+ topologies[* ntopologies].nthreads[1][1] = njj;
+ topologies[* ntopologies].nthreads[1][2] = nkk;
+ ++ * ntopologies;
+
+ }
+ }
+ }
+ }
+
+ }
}
}
}
@@ -222,20 +241,28 @@ lc_stattime_init (lc_stattime_t * restrict const lt,
lt->inthreads = -1;
lt->jnthreads = -1;
lt->knthreads = -1;
+ lt->inithreads = -1;
+ lt->jnithreads = -1;
+ lt->knithreads = -1;
} else {
assert (state->topology >= 0 && state->topology < ls->ntopologies);
- lt->inthreads = ls->topologies[lt->state.topology].nthreads[0];
- lt->jnthreads = ls->topologies[lt->state.topology].nthreads[1];
- lt->knthreads = ls->topologies[lt->state.topology].nthreads[2];
+ lt->inthreads = ls->topologies[lt->state.topology].nthreads[0][0];
+ lt->jnthreads = ls->topologies[lt->state.topology].nthreads[0][1];
+ lt->knthreads = ls->topologies[lt->state.topology].nthreads[0][2];
+ lt->inithreads = ls->topologies[lt->state.topology].nthreads[1][0];
+ lt->jnithreads = ls->topologies[lt->state.topology].nthreads[1][1];
+ lt->knithreads = ls->topologies[lt->state.topology].nthreads[1][2];
}
if (debug) {
- printf ("Thread topology #%d [%d,%d,%d]\n",
- lt->state.topology, lt->inthreads, lt->jnthreads, lt->knthreads);
+ printf ("Thread topology #%d [%d,%d,%d]x[%d,%d,%d]\n",
+ lt->state.topology,
+ lt->inthreads, lt->jnthreads, lt->knthreads,
+ lt->inithreads, lt->jnithreads, lt->knithreads);
}
/* Assert thread topology consistency */
@@ -243,7 +270,12 @@ lc_stattime_init (lc_stattime_t * restrict const lt,
assert (lt->inthreads >= 1);
assert (lt->jnthreads >= 1);
assert (lt->knthreads >= 1);
- assert (lt->inthreads * lt->jnthreads * lt->knthreads == ls->num_threads);
+ assert (lt->inithreads >= 1);
+ assert (lt->jnithreads >= 1);
+ assert (lt->knithreads >= 1);
+ assert (lt->inthreads * lt->jnthreads * lt->knthreads *
+ lt->inithreads * lt->jnithreads * lt->knithreads ==
+ ls->num_threads);
}
/*** Tilings ****************************************************************/
@@ -355,37 +387,56 @@ lc_statset_init (lc_statset_t * restrict const ls,
assert (ls);
assert (lm);
assert (num_threads >= 1);
+ int total_npoints = 1;
for (int d=0; d<3; ++d) {
assert (npoints[d] >= 0);
+ assert (npoints[d] < 1000000000);
+ assert (npoints[d] < 1000000000 / total_npoints);
+ total_npoints *= npoints[d];
}
/*** Threads ****************************************************************/
- ls->num_threads = num_threads;
+ static int saved_num_threads = -1;
+ static lc_topology_t * restrict saved_topologies;
+ static int saved_ntopologies;
- /* For up to 1024 threads, there are at most 270 possible
- topologies */
- int const maxntopologies = 1000;
- if (debug) {
- printf ("Running on %d threads\n", ls->num_threads);
- }
- ls->topologies = malloc (maxntopologies * sizeof * ls->topologies);
- find_thread_topologies
- (ls->topologies, maxntopologies, & ls->ntopologies, ls->num_threads);
-#if 0
- ls->topologies =
- realloc (ls->topologies, ls->ntopologies * sizeof * ls->topologies);
-#endif
- if (debug) {
- printf ("Found %d possible thread topologies\n", ls->ntopologies);
- for (int n = 0; n < ls->ntopologies; ++n) {
- printf (" %2d: %2d %2d %2d\n",
- n,
- ls->topologies[n].nthreads[0],
- ls->topologies[n].nthreads[1],
- ls->topologies[n].nthreads[2]);
+ if (saved_num_threads == -1) {
+ saved_num_threads = num_threads;
+
+ /* For up to 1024 threads, there are at most 611556 possible
+ topologies */
+ int const maxntopologies = 1000000;
+ if (debug) {
+ printf ("Running on %d threads\n", num_threads);
+ }
+
+ saved_topologies = malloc (maxntopologies * sizeof * saved_topologies);
+ find_thread_topologies
+ (saved_topologies, maxntopologies, & saved_ntopologies,
+ saved_num_threads);
+ saved_topologies =
+ realloc (saved_topologies, saved_ntopologies * sizeof * saved_topologies);
+
+ if (debug) {
+ printf ("Found %d possible thread topologies\n", saved_ntopologies);
+ for (int n = 0; n < saved_ntopologies; ++n) {
+ printf (" %2d: %2d %2d %2d %2d %2d %2d\n",
+ n,
+ saved_topologies[n].nthreads[0][0],
+ saved_topologies[n].nthreads[0][1],
+ saved_topologies[n].nthreads[0][2],
+ saved_topologies[n].nthreads[1][0],
+ saved_topologies[n].nthreads[1][1],
+ saved_topologies[n].nthreads[1][2]);
+ }
}
}
+
+ assert (saved_num_threads == num_threads);
+ ls->num_threads = saved_num_threads;
+ ls->topologies = saved_topologies;
+ ls->ntopologies = saved_ntopologies;
assert (ls->ntopologies > 0);
/*** Tilings ****************************************************************/
@@ -409,7 +460,9 @@ lc_statset_init (lc_statset_t * restrict const ls,
for (int n = 0; n < ls->ntopologies; ++n) {
int tiling;
for (tiling = 1; tiling < ls->ntilings[d]; ++tiling) {
- if (ls->tilings[d][tiling].npoints * ls->topologies[n].nthreads[d] >
+ if (ls->tilings[d][tiling].npoints *
+ ls->topologies[n].nthreads[0][d] *
+ ls->topologies[n].nthreads[1][d] >
ls->npoints[d])
{
break;
@@ -606,12 +659,12 @@ lc_control_init (lc_control_t * restrict const lc,
if (lc_inthreads == -1 || lc_jnthreads == -1 || lc_knthreads == -1) {
CCTK_VWarn (CCTK_WARN_ABORT, __LINE__, __FILE__, CCTK_THORNSTRING,
- "Illegal thread topology [%d,%d,%d] specified",
+ "Illegal thread topology [%d,%d,%d]x[1,1,1] specified",
(int)lc_inthreads, (int)lc_jnthreads, (int)lc_knthreads);
}
if (lc_inthreads * lc_jnthreads * lc_knthreads != ls->num_threads) {
CCTK_VWarn (CCTK_WARN_ABORT, __LINE__, __FILE__, CCTK_THORNSTRING,
- "Specified thread topology [%d,%d,%d] is not compatible with the number of threads %d",
+ "Specified thread topology [%d,%d,%d]x[1,1,1] is not compatible with the number of threads %d",
(int)lc_inthreads, (int)lc_jnthreads, (int)lc_knthreads,
ls->num_threads);
}
@@ -698,13 +751,21 @@ lc_control_init (lc_control_t * restrict const lc,
lt->inthreads = lc_inthreads;
lt->jnthreads = lc_jnthreads;
lt->knthreads = lc_knthreads;
+ lt->inithreads = 1;
+ lt->jnithreads = 1;
+ lt->knithreads = 1;
}
/* Assert thread topology consistency */
assert (lt->inthreads >= 1);
assert (lt->jnthreads >= 1);
assert (lt->knthreads >= 1);
- assert (lt->inthreads * lt->jnthreads * lt->knthreads == ls->num_threads);
+ assert (lt->inithreads >= 1);
+ assert (lt->jnithreads >= 1);
+ assert (lt->knithreads >= 1);
+ assert (lt->inthreads * lt->jnthreads * lt->knthreads *
+ lt->inithreads * lt->jnithreads * lt->knithreads ==
+ ls->num_threads);
/* Tilings */
@@ -738,7 +799,6 @@ lc_control_init (lc_control_t * restrict const lc,
/*** Threads ****************************************************************/
-
/* Thread loop settings */
lc->iiimin = imin;
lc->jjjmin = jmin;
@@ -762,7 +822,7 @@ lc_control_init (lc_control_t * restrict const lc,
int const ci = c % lt->inthreads; c /= lt->inthreads;
int const cj = c % lt->jnthreads; c /= lt->jnthreads;
int const ck = c % lt->knthreads; c /= lt->knthreads;
- assert (c == 0);
+ /* see below where c is continued to be used */
lc->iii = lc->iiimin + ci * lc->iiistep;
lc->jjj = lc->jjjmin + cj * lc->jjjstep;
lc->kkk = lc->kkkmin + ck * lc->kkkstep;
@@ -789,6 +849,29 @@ lc_control_init (lc_control_t * restrict const lc,
+ /*** Inner threads **********************************************************/
+
+ /* Inner loop thread parallelism */
+ /* (No thread parallelism yet) */
+ int const cii = c % lt->inithreads; c /= lt->inithreads;
+ int const cjj = c % lt->jnithreads; c /= lt->jnithreads;
+ int const ckk = c % lt->knithreads; c /= lt->knithreads;
+ assert (c == 0);
+ lc->iiii = cii;
+ lc->jjjj = cjj;
+ lc->kkkk = ckk;
+ lc->iiiistep = (lc->iistep + lt->inithreads - 1) / lt->inithreads;
+ lc->jjjjstep = (lc->jjstep + lt->jnithreads - 1) / lt->jnithreads;
+ lc->kkkkstep = (lc->kkstep + lt->knithreads - 1) / lt->knithreads;
+ lc->iiiimin = lc->iiii * lc->iiiistep;
+ lc->jjjjmin = lc->jjjj * lc->jjjjstep;
+ lc->kkkkmin = lc->kkkk * lc->kkkkstep;
+ lc->iiiimax = lc->iiiimin + lc->iiiistep;
+ lc->jjjjmax = lc->jjjjmin + lc->jjjjstep;
+ lc->kkkkmax = lc->kkkkmin + lc->kkkkstep;
+
+
+
/****************************************************************************/
/* Timer */
@@ -810,6 +893,10 @@ lc_control_finish (lc_control_t * restrict const lc)
ignore_iteration = ignore_initial_overhead && lt->time_count == 0.0;
}
+ /* Add a barrier to catch load imbalances */
+#pragma omp barrier
+ ;
+
/* Timer */
double const time_calc_end = omp_get_wtime();
double const time_calc_begin = lc->time_calc_begin;
@@ -849,7 +936,7 @@ lc_control_finish (lc_control_t * restrict const lc)
lt->last_updated = time_calc_end;
}
-#pragma omp barrier
+/* #pragma omp barrier */
{
DECLARE_CCTK_PARAMETERS;
@@ -910,9 +997,11 @@ lc_printstats (CCTK_ARGUMENTS)
int imin_calc = -1;
int ntimes = 0;
for (lc_stattime_t * lt = ls->stattime_list; lt; lt = lt->next) {
- printf (" stattime #%d topology=%d [%d,%d,%d] tiling=[%d,%d,%d]\n",
+ printf (" stattime #%d topology=%d [%d,%d,%d]x[%d,%d,%d] tiling=[%d,%d,%d]\n",
ntimes,
- lt->state.topology, lt->inthreads, lt->jnthreads, lt->knthreads,
+ lt->state.topology,
+ lt->inthreads, lt->jnthreads, lt->knthreads,
+ lt->inithreads, lt->jnithreads, lt->knithreads,
lt->inpoints, lt->jnpoints, lt->knpoints);
double const count = lt->time_count;
double const setup = lt->time_setup_sum / count;