aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAnton Khirnov <anton@khirnov.net>2020-07-11 08:12:48 +0200
committerAnton Khirnov <anton@khirnov.net>2020-07-13 10:58:17 +0200
commitc578b25b4b45570ed8d5729613c0c064b72ae06e (patch)
tree10b0b120b81257a4480e5764cc572a356c42ab5c
parentb487acb4843a8413ddcecce398a44d0c7c050c61 (diff)
Add a module for adaptive step control.
-rw-r--r--common.h1
-rw-r--r--meson.build3
-rw-r--r--step_control.c449
-rw-r--r--step_control.h35
4 files changed, 488 insertions, 0 deletions
diff --git a/common.h b/common.h
index 9c36a33..d6f55f0 100644
--- a/common.h
+++ b/common.h
@@ -22,6 +22,7 @@
#define SGN(x) ((x) >= 0.0 ? 1.0 : -1.0)
#define MAX(x, y) ((x) > (y) ? (x) : (y))
#define MIN(x, y) ((x) > (y) ? (y) : (x))
+#define CLIP(val, min, max) (MIN(MAX(val, min), max))
#define ARRAY_ELEMS(arr) (sizeof(arr) / sizeof(*arr))
#include <stdio.h>
diff --git a/meson.build b/meson.build
index 3134837..7521cbd 100644
--- a/meson.build
+++ b/meson.build
@@ -2,6 +2,8 @@ project('libmg2d', 'c',
default_options : ['c_std=c11'])
add_project_arguments('-D_XOPEN_SOURCE=700', language : 'c')
+# for random_r
+add_project_arguments('-D_DEFAULT_SOURCE=1', language : 'c')
lib_src = [
'bicgstab.c',
@@ -12,6 +14,7 @@ lib_src = [
'log.c',
'mg2d.c',
'residual_calc.c',
+ 'step_control.c',
'timer.c',
'transfer.c',
]
diff --git a/step_control.c b/step_control.c
new file mode 100644
index 0000000..5dba741
--- /dev/null
+++ b/step_control.c
@@ -0,0 +1,449 @@
+/*
+ * Copyright 2020 Anton Khirnov <anton@khirnov.net>
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program. If not, see <http://www.gnu.org/licenses/>.
+ */
+
+#include <errno.h>
+#include <float.h>
+#include <math.h>
+#include <stdint.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include <time.h>
+
+#include "common.h"
+#include "log.h"
+#include "mg2d_constants.h"
+#include "step_control.h"
+
+#define SC_HIST_SIZE 16
+
+struct StepControl {
+ /* estimated value for a good step */
+ double hint;
+ /* minimum step size, fail if this value is reached */
+ double step_min;
+
+ uint64_t step_counter;
+
+ double step[SC_HIST_SIZE];
+ double fact[SC_HIST_SIZE];
+
+ double conv_epsilon;
+
+ int hist_len;
+ int idx_bracket;
+ int idx_diverge;
+
+ struct random_data rd;
+ char rd_buf[16];
+
+ MG2DLogger logger;
+};
+
+static void sci_find_bracket(StepControl *sc)
+{
+ int idx_max = -1;
+ double val_max = 0.0;
+
+ for (int i = 0; i < sc->hist_len; i++) {
+ if (sc->fact[i] > val_max) {
+ val_max = sc->fact[i];
+ idx_max = i;
+ }
+ }
+
+ if (val_max > 1.0 && idx_max > 0 && idx_max < sc->hist_len - 1) {
+ mg2di_assert(sc->fact[idx_max - 1] < val_max &&
+ sc->fact[idx_max + 1] < val_max);
+ sc->idx_bracket = idx_max - 1;
+ }
+}
+
+static void sci_recheck_bracket(StepControl *sc)
+{
+ if (sc->idx_bracket >= 0 &&
+ (sc->fact[sc->idx_bracket + 1] <= sc->fact[sc->idx_bracket] ||
+ sc->fact[sc->idx_bracket + 1] <= sc->fact[sc->idx_bracket + 2]))
+ sc->idx_bracket = -1;
+ sci_find_bracket(sc);
+}
+
+static void sci_hist_clear(StepControl *sc)
+{
+ for (int i = 0; i < SC_HIST_SIZE; i++) {
+ sc->step[i] = DBL_MAX;
+ sc->fact[i] = DBL_MAX;
+ }
+
+ sc->hist_len = 0;
+ sc->idx_bracket = -1;
+ sc->idx_diverge = -1;
+}
+
+static void sci_hist_append(StepControl *sc, double step, double fact)
+{
+ mg2di_assert(sc->hist_len < SC_HIST_SIZE);
+ sc->step[sc->hist_len] = step;
+ sc->fact[sc->hist_len] = fact;
+ sc->hist_len++;
+}
+
+static void sci_hist_insert(StepControl *sc, double step, double fact, int idx)
+{
+ mg2di_assert(idx < SC_HIST_SIZE);
+ mg2di_assert(sc->hist_len < SC_HIST_SIZE);
+
+ memmove(sc->step + idx + 1, sc->step + idx,
+ sizeof(*sc->step) * MAX(sc->hist_len - idx, 0));
+ memmove(sc->fact + idx + 1, sc->fact + idx,
+ sizeof(*sc->fact) * MAX(sc->hist_len - idx, 0));
+
+ sc->step[idx] = step;
+ sc->fact[idx] = fact;
+
+ if (sc->idx_diverge >= idx)
+ sc->idx_diverge++;
+ if (sc->idx_bracket >= 0) {
+ if (sc->idx_bracket >= idx)
+ sc->idx_bracket++;
+ else if (sc->idx_bracket >= idx - 2)
+ sc->idx_bracket = -1;
+ }
+
+ sc->hist_len++;
+ if (sc->idx_bracket == -1)
+ sci_find_bracket(sc);
+}
+
+static void sci_hist_drop(StepControl *sc, int idx_start, int idx_end)
+{
+ idx_start = CLIP(idx_start, 0, SC_HIST_SIZE);
+ idx_end = CLIP(idx_end, idx_start, SC_HIST_SIZE);
+
+ memmove(sc->step + idx_start, sc->step + idx_end,
+ sizeof(*sc->step) * MAX(sc->hist_len - idx_end, 0));
+ memmove(sc->fact + idx_start, sc->fact + idx_end,
+ sizeof(*sc->fact) * MAX(sc->hist_len - idx_end, 0));
+
+ if (sc->idx_diverge >= idx_start) {
+ if (sc->idx_diverge >= idx_end)
+ sc->idx_diverge -= idx_end - idx_start;
+ else
+ sc->idx_diverge = -1;
+ }
+ if (sc->idx_bracket >= 0) {
+ if (sc->idx_bracket >= idx_end)
+ sc->idx_bracket -= idx_end - idx_start;
+ else if (sc->idx_bracket >= idx_start - 2)
+ sc->idx_bracket = -1;
+ }
+ sc->hist_len -= idx_end - idx_start;
+ if (sc->idx_bracket == -1)
+ sci_find_bracket(sc);
+}
+
+int sc_alloc(StepControl **psc, MG2DLogger logger)
+{
+ StepControl *sc;
+ struct timespec tv;
+
+ sc = calloc(1, sizeof(*sc));
+ if (!sc)
+ return -ENOMEM;
+
+ sci_hist_clear(sc);
+ sc->hint = DBL_MAX;
+ sc->step_min = DBL_MAX;
+ sc->step_counter = 0;
+
+ clock_gettime(CLOCK_REALTIME, &tv);
+ initstate_r(tv.tv_nsec, sc->rd_buf, sizeof(sc->rd_buf), &sc->rd);
+
+ sc->logger = logger;
+
+ *psc = sc;
+
+ return 0;
+}
+
+void sc_free(StepControl **psc)
+{
+ StepControl *sc = *psc;
+
+ if (!sc)
+ return;
+
+ free(sc);
+ *psc = NULL;
+}
+
+void sc_init(StepControl *sc, double hint, double step_min)
+{
+ sc->hint = hint > 0.0 ? hint : 0.25;
+ sc->conv_epsilon = sc->hint * 1e-3;
+ sc->step_min = step_min >= 0.0 ? step_min : sc->conv_epsilon;
+ sc->step_counter = 0;
+
+ mg2di_log(&sc->logger, MG2D_LOG_DEBUG, "stepcontrol init: hint %g step_min %g\n", hint, step_min);
+}
+
+double sc_step_get(StepControl *sc)
+{
+#define RETURN_STEP(sc, x, desc) \
+do { \
+ mg2di_log(&(sc)->logger, MG2D_LOG_DEBUG, "step get: %g (%s)\n", x, desc); \
+ return x; \
+} while (0)
+
+ sc->step_counter++;
+
+ // no history present, try the hint
+ if (!sc->hist_len)
+ RETURN_STEP(sc, sc->hint, "no history");
+
+ // don't have the bracket yet
+ if (sc->idx_bracket < 0) {
+ const double high_step = sc->step[sc->hist_len - 1];
+ const double high_fact = sc->fact[sc->hist_len - 1];
+
+ // don't know where divergence happens
+ // -> try a higher step
+ if (high_fact > 1.0)
+ RETURN_STEP(sc, high_step * 1.2, "no bracket, step high");
+
+ // convergence factor should decrease towards smaller timesteps
+ // so just try lower ones until we get a bracket
+ RETURN_STEP(sc, MAX(sc->step_min, sc->step[0] * 0.8), "no bracket, step low");
+ }
+
+ // got a bracket, golden-section-search it until convergence
+ mg2di_assert(sc->hist_len - sc->idx_bracket >= 3);
+ mg2di_assert(sc->fact[sc->idx_bracket] < sc->fact[sc->idx_bracket + 1]);
+ mg2di_assert(sc->fact[sc->idx_bracket + 2] < sc->fact[sc->idx_bracket + 1]);
+ {
+ const double dist0 = sc->step[sc->idx_bracket + 1] - sc->step[sc->idx_bracket];
+ const double dist1 = sc->step[sc->idx_bracket + 2] - sc->step[sc->idx_bracket + 1];
+
+ // try a random step upwards once in a while to escape local minima
+ if (!(sc->step_counter & ((1 << 6) - 1))) {
+ // scale such that half the steps fall within 1.25 * upper bracket
+ // with logarithmic growth
+ const double scale = log(1.5) / (0.25 * sc->step[sc->idx_bracket + 2]);
+ double offset;
+ int32_t rval;
+
+ random_r(&sc->rd, &rval);
+ mg2di_assert(rval >= 0);
+
+ offset = log(1.0 + (double)rval/RAND_MAX) / scale;
+
+ RETURN_STEP(sc, sc->step[sc->idx_bracket + 2] + offset, "random step");
+ }
+
+ // converged
+ if (dist0 < sc->conv_epsilon && dist1 < sc->conv_epsilon)
+ RETURN_STEP(sc, sc->step[sc->idx_bracket + 1], "converged");
+
+ if (dist0 > dist1)
+ RETURN_STEP(sc, sc->step[sc->idx_bracket] + 0.61803 * dist0, "bisect low");
+ else
+ RETURN_STEP(sc, sc->step[sc->idx_bracket + 2] - 0.61803 * dist1, "bisect high");
+ }
+}
+
+static void sci_log(StepControl *sc)
+{
+ char buf[1024], *p = buf;
+ for (int i = 0; i < sc->hist_len; i++) {
+ if (sc->idx_bracket == i)
+ p += snprintf(p, sizeof(buf) - (p - buf), "[");
+ if (sc->idx_diverge == i)
+ p += snprintf(p, sizeof(buf) - (p - buf), "|");
+ p += snprintf(p, sizeof(buf) - (p - buf), "%10.8g", sc->step[i]);
+ if (sc->idx_bracket >= 0 && sc->idx_bracket == i - 2)
+ p += snprintf(p, sizeof(buf) - (p - buf), "]");
+ p += snprintf(p, sizeof(buf) - (p - buf), "\t");
+ }
+ if (sc->hist_len)
+ mg2di_log(&sc->logger, MG2D_LOG_DEBUG, "%s\n", buf);
+
+ p = buf;
+ for (int i = 0; i < sc->hist_len; i++) {
+ if (sc->idx_bracket == i)
+ p += snprintf(p, sizeof(buf) - (p - buf), "[");
+ if (sc->idx_diverge == i)
+ p += snprintf(p, sizeof(buf) - (p - buf), "|");
+ p += snprintf(p, sizeof(buf) - (p - buf), "%10.8g", sc->fact[i]);
+ if (sc->idx_bracket >= 0 && sc->idx_bracket == i - 2)
+ p += snprintf(p, sizeof(buf) - (p - buf), "]");
+ p += snprintf(p, sizeof(buf) - (p - buf), "\t");
+ }
+ if (sc->hist_len)
+ mg2di_log(&sc->logger, MG2D_LOG_DEBUG, "%s\n", buf);
+}
+
+int sc_step_confirm(StepControl *sc, const double step,
+ const double norm_old, const double norm_new)
+{
+ const double conv_fact = norm_old / norm_new;
+
+ // signal failure if we reached minimum step
+ if (step <= sc->step_min) {
+ mg2di_log(&sc->logger, MG2D_LOG_ERROR, "Minimum step reached\n");
+ return -EINVAL;
+ }
+
+ // divergence
+ if (conv_fact <= 1.0) {
+ // if the step is not larger then the largest known diverging step,
+ // add it into history
+ if (sc->idx_diverge < 0 ||
+ sc->step[sc->idx_diverge] > step) {
+ const int idx_converge = sc->idx_diverge >= 0 ?
+ sc->idx_diverge - 1 : sc->hist_len - 1;
+
+ // largest known coverging step doesn't converge anymore,
+ // clear history
+ if (idx_converge >= 0 && sc->step[idx_converge] >= step)
+ sci_hist_clear(sc);
+ // drop larger diverging steps
+ if (sc->idx_diverge >= 0)
+ sci_hist_drop(sc, sc->idx_diverge, sc->hist_len);
+ sci_hist_append(sc, step, conv_fact);
+ sc->idx_diverge = sc->hist_len - 1;
+ sci_find_bracket(sc);
+ }
+
+ mg2di_log(&sc->logger, MG2D_LOG_DEBUG, "step reject: %g %g\n", step, conv_fact);
+ sci_log(sc);
+
+ return 1;
+ }
+
+ // convergence
+#define FINISH(sc, reason) \
+do { \
+ mg2di_log(&sc->logger, MG2D_LOG_DEBUG, "step confirm: %s %g %g\n", \
+ reason, step, conv_fact); \
+ goto finish; \
+} while (0);
+
+ // history empty, add first element
+ if (!sc->hist_len) {
+ sci_hist_append(sc, step, conv_fact);
+ FINISH(sc, "history empty");
+ }
+
+ // previously diverging step is now converging
+ // drop all diverging steps from history
+ if (sc->idx_diverge >= 0 && step >= sc->step[sc->idx_diverge]) {
+ sci_hist_drop(sc, sc->idx_diverge, sc->hist_len);
+ sci_hist_append(sc, step, conv_fact);
+ sci_recheck_bracket(sc);
+ FINISH(sc, "diverge->converge");
+ }
+
+ // no bracket
+ if (sc->idx_bracket < 0) {
+ int idx_floor = -1;
+
+ for (int i = 0; i < sc->hist_len; i++) {
+ if (sc->step[i] < step)
+ idx_floor = i;
+ else
+ break;
+ }
+
+ // insert the new step where it belongs and check if we now have a
+ // bracket
+ sci_hist_insert(sc, step, conv_fact, idx_floor + 1);
+ sci_find_bracket(sc);
+
+ // garbage-collect history
+ if (sc->idx_bracket >= 0) {
+ // bracket found: drop all converging steps beside the bracket
+ sci_hist_drop(sc, 0, sc->idx_bracket);
+ sci_hist_drop(sc, sc->idx_bracket + 3,
+ sc->idx_diverge >= 0 ? sc->idx_diverge : sc->hist_len);
+ } else {
+ // bracket still not found -> keep at most:
+ // - lowest converging step
+ // - highest converging step
+ sci_hist_drop(sc, 1, sc->idx_diverge >= 0 ? sc->idx_diverge - 1 : sc->hist_len - 1);
+ }
+
+ FINISH(sc, "no bracket");
+ }
+
+ // have bracket
+
+ // got a step outside of the bracket
+ if (step <= sc->step[sc->idx_bracket] ||
+ step >= sc->step[sc->idx_bracket + 2]) {
+ if (conv_fact >= sc->fact[sc->idx_bracket + 1]) {
+ sci_hist_clear(sc);
+ sci_hist_append(sc, step, conv_fact);
+ }
+ FINISH(sc, "outside bracket");
+ }
+
+ // step within epsilon of the center
+ // update central value and check if that breaks the bracket
+ if (fabs(step - sc->step[sc->idx_bracket + 1]) <= sc->conv_epsilon) {
+ sc->fact[sc->idx_bracket + 1] = conv_fact;
+ sci_recheck_bracket(sc);
+ FINISH(sc, "step center");
+ }
+
+ if (step < sc->step[sc->idx_bracket + 1]) {
+ // step inside lower interval
+ if (conv_fact > sc->fact[sc->idx_bracket + 1]) {
+ // replace upper bound
+ sc->step[sc->idx_bracket + 2] = sc->step[sc->idx_bracket + 1];
+ sc->fact[sc->idx_bracket + 2] = sc->fact[sc->idx_bracket + 1];
+ if (sc->idx_diverge == sc->idx_bracket + 2)
+ sc->idx_diverge = -1;
+
+ sc->step[sc->idx_bracket + 1] = step;
+ sc->fact[sc->idx_bracket + 1] = conv_fact;
+ } else {
+ // replace lower bound
+ sc->step[sc->idx_bracket] = step;
+ sc->fact[sc->idx_bracket] = conv_fact;
+ }
+ } else {
+ // step inside upper interval
+ if (conv_fact > sc->fact[sc->idx_bracket + 1]) {
+ // replace lower bound
+ sc->step[sc->idx_bracket] = sc->step[sc->idx_bracket + 1];
+ sc->fact[sc->idx_bracket] = sc->fact[sc->idx_bracket + 1];
+ sc->step[sc->idx_bracket + 1] = step;
+ sc->fact[sc->idx_bracket + 1] = conv_fact;
+ } else {
+ sc->step[sc->idx_bracket + 2] = step;
+ sc->fact[sc->idx_bracket + 2] = conv_fact;
+ if (sc->idx_diverge == sc->idx_bracket + 2)
+ sc->idx_diverge = -1;
+ }
+ }
+ FINISH(sc, "bisect");
+
+finish:
+ mg2di_log(&sc->logger, MG2D_LOG_DEBUG, "converge: %g %g\n", step, conv_fact);
+ sci_log(sc);
+ return 0;
+}
diff --git a/step_control.h b/step_control.h
new file mode 100644
index 0000000..98c4744
--- /dev/null
+++ b/step_control.h
@@ -0,0 +1,35 @@
+/*
+ * Copyright 2020 Anton Khirnov <anton@khirnov.net>
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program. If not, see <http://www.gnu.org/licenses/>.
+ */
+
+#ifndef MG2D_STEP_CONTROL_H
+#define MG2D_STEP_CONTROL_H
+
+#include "log.h"
+
+typedef struct StepControl StepControl;
+
+int sc_alloc(StepControl **sc, MG2DLogger logger);
+void sc_free(StepControl **psc);
+
+void sc_init(StepControl *sc, double hint, double step_min);
+
+double sc_step_get(StepControl *sc);
+
+int sc_step_confirm(StepControl *sc, const double step,
+ const double norm_old, const double norm_new);
+
+#endif // MG2D_STEP_CONTROL_H