aboutsummaryrefslogtreecommitdiff
path: root/Carpet/CarpetLib/src/dist.cc
diff options
context:
space:
mode:
authorErik Schnetter <schnetter@cct.lsu.edu>2011-02-05 19:19:25 -0500
committerBarry Wardell <barry.wardell@gmail.com>2011-12-14 18:25:57 +0000
commitf7600fa527bbfa205bc885861c7b2d78e3075077 (patch)
treee2350e9d5d0b8d0b0419f4e71b0e5b97ea780165 /Carpet/CarpetLib/src/dist.cc
parentb1993d72f3c4e11ac8c59d998dffbfcf06dbb350 (diff)
CarpetLib: Recalculate total number of threads after changing it
Diffstat (limited to 'Carpet/CarpetLib/src/dist.cc')
-rw-r--r--Carpet/CarpetLib/src/dist.cc29
1 files changed, 25 insertions, 4 deletions
diff --git a/Carpet/CarpetLib/src/dist.cc b/Carpet/CarpetLib/src/dist.cc
index eb9adff32..42e8af322 100644
--- a/Carpet/CarpetLib/src/dist.cc
+++ b/Carpet/CarpetLib/src/dist.cc
@@ -31,6 +31,7 @@ namespace dist {
MPI_Datatype mpi_complex16 = MPI_DATATYPE_NULL;
MPI_Datatype mpi_complex32 = MPI_DATATYPE_NULL;
+ int num_threads_ = -1;
int total_num_threads_ = -1;
void init (int& argc, char**& argv) {
@@ -265,7 +266,11 @@ namespace dist {
if (num_threads > 0) {
// Set number of threads which should be used
// TODO: do this at startup, not in this routine
+ CCTK_VInfo (CCTK_THORNSTRING,
+ "Setting number of OpenMP threads per process to %d",
+ num_threads);
omp_set_num_threads (num_threads);
+ collect_total_num_threads ();
}
#else
if (num_threads > 0 and num_threads != 1) {
@@ -278,12 +283,28 @@ namespace dist {
// Global number of threads
void collect_total_num_threads ()
{
- int const mynthreads = num_threads();
- // cerr << "QQQ: collect_total_num_threads[1]" << endl;
+#ifdef _OPENMP
+# pragma omp parallel
+ {
+# pragma omp single
+ {
+ num_threads_ = omp_get_num_threads();
+ }
+ }
+ int const max_threads = omp_get_max_threads();
+ if (max_threads != num_threads_) {
+ CCTK_VWarn (CCTK_WARN_ALERT, __LINE__, __FILE__, CCTK_THORNSTRING,
+ "Unexpected OpenMP setup: omp_get_max_threads=%d, omp_get_num_threads=%d",
+ max_threads, num_threads_);
+ }
+#else
+ num_threads_ = 1;
+#endif
+ assert (num_threads_ >= 1);
+
MPI_Allreduce
- (const_cast <int *> (& mynthreads), & total_num_threads_, 1, MPI_INT,
+ (const_cast <int *> (& num_threads_), & total_num_threads_, 1, MPI_INT,
MPI_SUM, comm());
- // cerr << "QQQ: collect_total_num_threads[2]" << endl;
assert (total_num_threads_ >= size());
}