aboutsummaryrefslogtreecommitdiff
path: root/brill_data.py
blob: 460b0d29ff539a2b3dd9432ed20ebcbc8b6bdc1d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
#
# Copyright 2014 Anton Khirnov <anton@khirnov.net>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#


import brill_data
import ctypes
import numpy as np


class BrillData(object):
    """
    Python wrapper around the libbrilldata C library, accessed through ctypes.

    Constructing an instance loads libbrilldata.so, allocates a C context,
    stores the keyword arguments into the same-named context fields, solves
    the equation with bd_solve() and keeps a copy of the resulting
    psi_minus1_coeffs in self.coeffs as an array of shape
    (nb_coeffs_rho, nb_coeffs_z). The C context is freed in __del__().
    """

    # copy of psi_minus1_coeffs, shape (nb_coeffs_rho, nb_coeffs_z)
    coeffs = None

    # ctypes handle to libbrilldata.so
    _libbd = None
    # POINTER(_BDContext) from bd_context_alloc(); None when nothing is allocated
    _bdctx = None

    class _BDContext(ctypes.Structure):
        # Mirror of the library's context struct. Field order and C types must
        # match the C header exactly (NOTE(review): keep in sync with libbrilldata).
        _fields_ = [("priv",                   ctypes.c_void_p),
                    ("q_func_type",            ctypes.c_int),
                    ("amplitude",              ctypes.c_double),
                    ("eppley_n",               ctypes.c_uint),
                    ("nb_coeffs_rho",          ctypes.c_uint),
                    ("nb_coeffs_z",            ctypes.c_uint),
                    ("overdet_rho",            ctypes.c_int),
                    ("overdet_z",              ctypes.c_int),
                    ("colloc_grid_offset_rho", ctypes.c_uint),
                    ("colloc_grid_offset_z",   ctypes.c_uint),
                    ("basis_scale_factor_rho", ctypes.c_double),
                    ("basis_scale_factor_z",   ctypes.c_double),
                    ("psi_minus1_coeffs",      ctypes.POINTER(ctypes.c_double)),
                    ("stride",                 ctypes.c_long)]

    def __init__(self, **kwargs):
        """
        Allocate a solver context, configure it from kwargs and solve.

        Every keyword argument must name a _BDContext field; its value is
        stored into that field before bd_solve() is called.

        Raises:
            OSError: libbrilldata.so could not be loaded.
            MemoryError: bd_context_alloc() returned NULL.
            TypeError: a keyword argument does not name a context field.
            RuntimeError: bd_solve() returned a negative value.
        """
        libbd = ctypes.CDLL('libbrilldata.so')
        # Without an explicit restype ctypes assumes a C int return value,
        # which truncates the returned pointer on 64-bit platforms.
        libbd.bd_context_alloc.restype = ctypes.POINTER(self._BDContext)
        libbd.bd_context_alloc.argtypes = []
        self._libbd = libbd

        bdctx = libbd.bd_context_alloc()
        if not bdctx:
            raise MemoryError('Error allocating the BrillData context')
        # Stored before anything else can fail so that __del__ frees it.
        self._bdctx = bdctx

        self._set_params(bdctx.contents, kwargs)

        ret = libbd.bd_solve(bdctx)
        if ret < 0:
            raise RuntimeError('Error solving the equation')

        ctx = bdctx.contents
        # Copy, so self.coeffs stays valid after the C context is freed.
        self.coeffs = np.copy(np.ctypeslib.as_array(
            ctx.psi_minus1_coeffs, (ctx.nb_coeffs_rho, ctx.nb_coeffs_z)))

    def __del__(self):
        """Free the C context, if one was allocated."""
        if self._bdctx:
            # bd_context_free() receives the address of the pointer variable
            # itself (NOTE(review): presumably a BDContext ** parameter --
            # confirm against the C header).
            addr_bdctx = ctypes.c_void_p(ctypes.addressof(self._bdctx))
            self._libbd.bd_context_free(addr_bdctx)
            self._bdctx = None

    @classmethod
    def _set_params(cls, ctx, params):
        """
        Store each item of the params dict into the same-named field of ctx.

        Raises:
            TypeError: a key is not a _BDContext field. ctypes structures
                accept arbitrary extra Python attributes, so without this
                check a misspelled parameter would never reach the C struct.
        """
        fields = set(name for name, _ in cls._BDContext._fields_)
        for name, value in params.items():
            if name not in fields:
                raise TypeError('Unknown BrillData parameter: %s' % name)
            setattr(ctx, name, value)

    @staticmethod
    def _as_coords(rho, z):
        """
        Validate coordinate inputs and convert them for np.ctypeslib.as_ctypes().

        Returns (rho, z) as 1-D, C-contiguous, aligned, writeable float64
        arrays. Copies are made only when the input does not already qualify.

        Raises:
            TypeError: rho or z is not 1-dimensional.
        """
        rho = np.asarray(rho)
        z = np.asarray(z)
        if rho.ndim != 1 or z.ndim != 1:
            raise TypeError('rho and z must be 1-dimensional NumPy arrays')
        # as_ctypes() rejects strided and read-only arrays, and the array dtype
        # decides the ctypes element type. float64 matches the output buffers
        # (NOTE(review): assumes the C side reads the coordinates as double).
        reqs = ['C_CONTIGUOUS', 'ALIGNED', 'WRITEABLE']
        return np.require(rho, np.float64, reqs), np.require(z, np.float64, reqs)

    @staticmethod
    def _uint_pair(values):
        """Pack values[0] and values[1] into a ctypes c_uint[2] array."""
        pair = (ctypes.c_uint * 2)()
        pair[0] = values[0]
        pair[1] = values[1]
        return pair

    def eval_psi(self, rho, z, diff_order = None):
        """
        Evaluate psi through bd_eval_psi() for every (rho[i], z[j]) pair.

        Args:
            rho, z: 1-D array-likes of coordinates; converted to float64.
            diff_order: pair of non-negative ints passed to the library as
                unsigned ints (presumably derivative orders in rho and z);
                defaults to [0, 0].

        Returns:
            float64 array of shape (len(rho), len(z)) filled by the library.

        Raises:
            TypeError: rho or z is not 1-dimensional.
            RuntimeError: bd_eval_psi() returned a negative value.
        """
        rho, z = self._as_coords(rho, z)

        if diff_order is None:
            diff_order = [0, 0]
        c_diff_order = self._uint_pair(diff_order)

        psi = np.empty((rho.shape[0], z.shape[0]))

        # These ctypes arrays share memory with the NumPy arrays, so the
        # library writes its results straight into psi.
        c_psi = np.ctypeslib.as_ctypes(psi)
        c_rho = np.ctypeslib.as_ctypes(rho)
        c_z   = np.ctypeslib.as_ctypes(z)

        ret = self._libbd.bd_eval_psi(self._bdctx, c_rho, len(c_rho),
                                      c_z, len(c_z), c_diff_order, c_psi)
        if ret < 0:
            raise RuntimeError('Error evaluating psi')

        return psi

    def eval_metric(self, rho, z, component, diff_order = None):
        """
        Evaluate a metric component through bd_eval_metric() for every
        (rho[i], z[j]) pair.

        Args:
            rho, z: 1-D array-likes of coordinates; converted to float64.
            component: pair of non-negative ints passed to the library as
                unsigned ints (presumably the metric component indices).
            diff_order: pair of non-negative ints passed as unsigned ints
                (presumably derivative orders in rho and z); defaults to [0, 0].

        Returns:
            float64 array of shape (len(rho), len(z)) filled by the library.

        Raises:
            TypeError: rho or z is not 1-dimensional.
            RuntimeError: bd_eval_metric() returned a negative value.
        """
        rho, z = self._as_coords(rho, z)

        if diff_order is None:
            diff_order = [0, 0]
        c_diff_order = self._uint_pair(diff_order)
        c_component  = self._uint_pair(component)

        metric = np.empty((rho.shape[0], z.shape[0]))

        # Shared-memory ctypes views; the library fills metric in place.
        c_metric = np.ctypeslib.as_ctypes(metric)
        c_rho    = np.ctypeslib.as_ctypes(rho)
        c_z      = np.ctypeslib.as_ctypes(z)

        ret = self._libbd.bd_eval_metric(self._bdctx, c_rho, len(c_rho),
                                         c_z, len(c_z), c_component, c_diff_order, c_metric)
        if ret < 0:
            raise RuntimeError('Error evaluating the metric')

        return metric