aboutsummaryrefslogtreecommitdiff
path: root/ell_grid_solve.h
blob: 1e00393be5fb8ae36ef20bfa1aa0ffdfe198a165 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
/*
 * Solver for a 2nd order 2D linear PDE using finite differencing on a grid.
 * Copyright 2018 Anton Khirnov <anton@khirnov.net>
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

#ifndef MG2D_ELL_GRID_SOLVE_H
#define MG2D_ELL_GRID_SOLVE_H

/**
 * The problem being solved is a linear partial differential
 * equation
 *
 *   ∑ C_{ab} ∂_a ∂_b u  + ∑ C_{a} ∂_a u + C u = rhs
 *  a,b                     a
 *
 *  where
 *      * a and b identify spatial directions and run from 0 to 1
 *      * u = u(x_a) is the unknown function
 *      * C_{ab}, C_{a} and C are the coefficients in front of
 *        the corresponding derivatives of the unknown function
 *      * rhs is the right-hand side of the equation
 * C_{*} and rhs are all (known) functions of space and define the
 * equation to be solved.
 */

#include <stddef.h>
#include <stdint.h>
#include <threadpool.h>

#include <mpi.h>
#include <ndarray.h>

#include "components.h"
#include "log.h"
#include "mg2d_boundary.h"
#include "mg2d_constants.h"
#include "timer.h"

/**
 * Solution method requested from mg2di_egs_solve().
 */
enum EGSType {
    /**
     * Solve the equation using relaxation.
     *
     * The solver-specific parameters and statistics for this method are
     * stored in EGSContext.relax (an EGSRelaxContext).
     * mg2di_egs_solve() does a single relaxation step.
     */
    EGS_SOLVE_RELAXATION,
    /**
     * Solve the equation exactly by constructing a linear system and solving
     * it with LAPACK.
     *
     * The solver-specific statistics for this method are stored in
     * EGSContext.exact (an EGSExactContext).
     * mg2di_egs_solve() solves the discretized system exactly (up to roundoff
     * error).
     */
    EGS_SOLVE_EXACT,
};

/* Opaque solver-private state; its layout is not visible to callers. */
struct EGSInternal;
typedef struct EGSInternal EGSInternal;

/**
 * Parameters and statistics of the relaxation solver (EGS_SOLVE_RELAXATION),
 * reachable through EGSContext.relax.
 *
 * The relaxation "time step" is computed as
 *     Δt = r_m * r_f * step[0] * step[1]
 * where
 *  - r_m is relax_multiplier if specified, 1.0 otherwise;
 *  - r_f is relax_factor if specified, otherwise a default value that
 *    depends on fd_stencil.
 */
struct EGSRelaxContext {
    /** Time stepping factor of the relaxation (r_f above). */
    double relax_factor;

    /** Multiplier applied to the time stepping factor (r_m above). */
    double relax_multiplier;

    /* timing statistics */
    Timer timer_correct;
};
typedef struct EGSRelaxContext EGSRelaxContext;

/**
 * Statistics of the exact solver (EGS_SOLVE_EXACT), reachable through
 * EGSContext.exact.
 */
struct EGSExactContext {
    Timer timer_mat_construct;
    Timer timer_bicgstab;
    /* NOTE(review): presumably a running count of BiCGSTAB iterations — confirm */
    int64_t bicgstab_iterations;
    Timer timer_lu_solve;
    Timer timer_export;
};
typedef struct EGSExactContext EGSExactContext;

typedef struct EGSContext {
    /**
     * Parameters and statistics specific to the relaxation solver
     * (EGS_SOLVE_RELAXATION).
     */
    EGSRelaxContext *relax;
    /**
     * Statistics specific to the exact solver (EGS_SOLVE_EXACT).
     */
    EGSExactContext *exact;

    /**
     * Solver private data, not to be accessed in any way by the caller.
     */
    EGSInternal *priv;

    /**
     * The logging context, to be filled by the caller before mg2di_egs_init().
     */
    MG2DLogger logger;

    /**
     * The thread pool used for execution. May be set by the caller before
     * mg2di_egs_init().
     */
    TPContext *tp;

    /**
     * Flags indicating supported CPU features.
     */
    int cpuflags;

    /**
     * Size of the solver grid, set by mg2di_egs_alloc[_mpi](). For
     * multi-component runs, this contains the size of this component only.
     * Read-only for the caller.
     */
    size_t domain_size[2];

    /**
     * Distance between the neighbouring grid points.
     * Must be set by the caller before mg2di_egs_init().
     */
    double step[2];

    /**
     * Order of the finite difference operators used for approximating derivatives.
     * Must be set by the caller before mg2di_egs_init().
     */
    size_t fd_stencil;

    /**
     * Boundary specification, indexed by MG2DBoundaryLoc.
     * To be filled by the caller before mg2di_egs_init().
     *
     * For multi-component runs, only the outer (not inter-component) boundary
     * specifications are accessed by the solver.
     */
    MG2DBoundary *boundaries[4];

    /**
     * Values of the unknown function.
     *
     * Allocated by the solver in mg2di_egs_alloc(), owned by the solver.
     * Must be filled by the caller before mg2di_egs_init() to set the
     * initial guess.
     * Afterwards updated by mg2di_egs_solve().
     */
    NDArray *u;

    /**
     * u including the outer boundary ghost zones.
     */
    NDArray *u_exterior;

    /**
     * Values of the right-hand side.
     *
     * Allocated by the solver in mg2di_egs_alloc(), owned by the solver.
     * Must be filled by the caller before mg2di_egs_init().
     */
    NDArray *rhs;

    /**
     * Values of the residual.
     *
     * Allocated by the solver in mg2di_egs_alloc(), owned by the solver.
     * Read-only for the caller. Initialized by mg2di_egs_init(),
     * afterwards updated by mg2di_egs_solve().
     * NOTE(review): possibly only when its export_res argument is nonzero —
     * confirm against the implementation.
     */
    NDArray *residual;

    /**
     * Maximum of the absolute value of residual.
     */
    double residual_max;

    /**
     * Coefficients C_{*} that define the differential equation.
     *
     * Allocated by the solver in mg2di_egs_alloc(), owned by the solver.
     * Must be filled by the caller before mg2di_egs_init().
     */
    NDArray *diff_coeffs[MG2D_DIFF_COEFF_NB];

    /* timing statistics */
    Timer   timer_bnd;
    Timer   timer_bnd_fixval;
    Timer   timer_bnd_falloff;
    Timer   timer_bnd_reflect;
    Timer   timer_bnd_corners;
    Timer   timer_res_calc;
    Timer   timer_init;
    Timer   timer_solve;
    Timer   timer_mpi_sync;
} EGSContext;

/*
 * Flag for the flags argument of mg2di_egs_init() (bit 0).
 * NOTE(review): presumably signals that diff_coeffs are unchanged since the
 * previous mg2di_egs_init() call — confirm against the implementation.
 */
#define EGS_INIT_FLAG_SAME_DIFF_COEFFS (0x1)

/**
 * Allocate a solver for a grid of the given size. For a component of a
 * multi-component MPI run use mg2di_egs_alloc_mpi() instead.
 *
 * @param domain_size number of grid points along each of the two directions;
 *                    the value is reflected in EGSContext.domain_size.
 *
 * @return a newly allocated solver context, or NULL on failure. Release it
 *         with mg2di_egs_free().
 */
extern EGSContext *mg2di_egs_alloc(const size_t domain_size[2]);

/**
 * Allocate a solver component in a multi-component MPI-based solve.
 *
 * @param comm The MPI communicator used to communicate with the other
 *             components.
 * @param dg   The geometry of the full computational domain. This component
 *             is indexed by its MPI rank in this geometry.
 *
 * @return The solver context on success, NULL on failure. Its domain_size
 *         holds the size of this component only.
 */
extern EGSContext *mg2di_egs_alloc_mpi(MPI_Comm comm, const DomainGeometry *dg);
/**
 * Prepare the solver for use once the caller has filled in every required
 * field of the context.
 *
 * Calling it again is allowed: this re-initializes the solver after some of
 * its inputs (e.g. the right-hand side or the equation coefficients) have
 * changed.
 *
 * @param ctx   the solver context
 * @param flags bitwise OR of EGS_INIT_FLAG_* constants
 *
 * @return 0 on success, a negative error code on failure.
 */
extern int mg2di_egs_init(EGSContext *ctx, int flags);
/**
 * Free the solver context.
 *
 * @param ctx address of the context pointer; NULL is written to *ctx.
 */
extern void mg2di_egs_free(EGSContext **ctx);

/**
 * Run the solver on the problem described by the data filled in the context.
 * The exact effect depends on solve_type (see enum EGSType): a single
 * relaxation step for EGS_SOLVE_RELAXATION, an exact solve of the discretized
 * system for EGS_SOLVE_EXACT.
 *
 * @param ctx        the solver context, prepared with mg2di_egs_init()
 * @param solve_type the solution method to use
 * @param export_res NOTE(review): presumably nonzero to have the residual
 *                   computed into ctx->residual — confirm
 *
 * @return 0 on success, a negative error code on failure.
 */
extern int mg2di_egs_solve(EGSContext *ctx, enum EGSType solve_type, int export_res);

#endif /* MG2D_ELL_GRID_SOLVE_H */