aboutsummaryrefslogtreecommitdiff
path: root/teukolsky_data.py
diff options
context:
space:
mode:
Diffstat (limited to 'teukolsky_data.py')
-rw-r--r--teukolsky_data.py128
1 files changed, 128 insertions, 0 deletions
diff --git a/teukolsky_data.py b/teukolsky_data.py
new file mode 100644
index 0000000..5bc9364
--- /dev/null
+++ b/teukolsky_data.py
@@ -0,0 +1,128 @@
+#
+# Copyright 2014 Anton Khirnov <anton@khirnov.net>
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <http://www.gnu.org/licenses/>.
+#
+
+
+import ctypes
+import numpy as np
+
+class TeukolskyData(object):
+
+ coeffs = None
+
+ _libtd = None
+ _tdctx = None
+
+ class _TDContext(ctypes.Structure):
+ _fields_ = [("priv", ctypes.c_void_p),
+ ("log_callback", ctypes.c_void_p),
+ ("opaque", ctypes.c_void_p),
+ ("amplitude", ctypes.c_double),
+ ("nb_coeffs", ctypes.c_uint * 2),
+ ("basis_scale_factor", ctypes.c_double * 2),
+ ("max_iter", ctypes.c_uint),
+ ("atol", ctypes.c_double),
+ ("nb_threads", ctypes.c_uint),
+ ("coeffs", ctypes.POINTER(ctypes.c_double) * 3),
+ ]
+
+ def __init__(self, **kwargs):
+ self._libtd = ctypes.CDLL('libteukolskydata.so')
+ tdctx_alloc = self._libtd.td_context_alloc
+ tdctx_alloc.restype = ctypes.POINTER(self._TDContext)
+ self._tdctx = self._libtd.td_context_alloc()
+
+ coeffs_init = ctypes.c_void_p()
+
+ for arg, value in kwargs.iteritems():
+ if arg == 'coeffs_init':
+ coeffs_init = (ctypes.POINTER(ctypes.c_double) * 3)()
+ coeffs_init[0] = ctypes.cast(np.ctypeslib.as_ctypes(value[0]), ctypes.POINTER(ctypes.c_double))
+ coeffs_init[1] = ctypes.cast(np.ctypeslib.as_ctypes(value[1]), ctypes.POINTER(ctypes.c_double))
+ coeffs_init[2] = ctypes.cast(np.ctypeslib.as_ctypes(value[2]), ctypes.POINTER(ctypes.c_double))
+ continue
+
+ try:
+ self._tdctx.contents.__setattr__(arg, value)
+ except TypeError as e:
+ # try assigning items of an iterable
+ try:
+ for i, it in enumerate(value):
+ self._tdctx.contents.__getattribute__(arg)[i] = it
+ except:
+ raise e
+
+ ret = self._libtd.td_solve(self._tdctx, coeffs_init)
+ if ret < 0:
+ raise RuntimeError('Error solving the equation')
+
+ self.coeffs = [None] * 3
+ for i in xrange(3):
+ self.coeffs[i] = np.copy(np.ctypeslib.as_array(self._tdctx.contents.coeffs[i], (self._tdctx.contents.nb_coeffs[1], self._tdctx.contents.nb_coeffs[0])))
+
+ def __del__(self):
+ if self._tdctx:
+ addr_tdctx = ctypes.c_void_p(ctypes.addressof(self._tdctx))
+ self._libtd.td_context_free(addr_tdctx)
+ self._tdctx = None
+
+ def _eval_var(self, eval_func, r, theta, diff_order = None):
+ if diff_order is None:
+ diff_order = [0, 0]
+
+ c_diff_order = (ctypes.c_uint * 2)()
+ c_diff_order[0] = diff_order[0]
+ c_diff_order[1] = diff_order[1]
+
+ if r.ndim == 2:
+ if r.shape != theta.shape:
+ raise TypeError('r and theta must be identically-shaped 2-dimensional arrays')
+ R, Theta = r.view(), theta.view()
+ elif r.ndim == 1:
+ if theta.ndim != 1:
+ raise TypeError('r and theta must both be 1-dimensional NumPy arrays')
+ R, Theta = np.meshgrid(r, theta)
+ else:
+ raise TypeError('invalid r/theta parameters')
+
+ out = np.empty(R.shape[0] * R.shape[1])
+
+ R.shape = out.shape
+ Theta.shape = out.shape
+
+ c_out = np.ctypeslib.as_ctypes(out)
+ c_r = np.ctypeslib.as_ctypes(R)
+ c_theta = np.ctypeslib.as_ctypes(Theta)
+
+ ret = eval_func(self._tdctx, out.shape[0], c_r, c_theta,
+ c_diff_order, c_out, ctypes.c_long(r.shape[0]))
+ if ret < 0:
+ raise RuntimeError('Error evaluating the variable')
+
+ out.shape = (theta.shape[0], r.shape[0])
+ return out
+
+ def eval_psi(self, r, theta, diff_order = None):
+ return self._eval_var(self._libtd.td_eval_psi, r, theta, diff_order)
+ def eval_krr(self, r, theta, diff_order = None):
+ return self._eval_var(self._libtd.td_eval_krr, r, theta, diff_order)
+ def eval_kpp(self, r, theta, diff_order = None):
+ return self._eval_var(self._libtd.td_eval_kpp, r, theta, diff_order)
+
+
+ @property
+ def amplitude(self):
+ return self._tdctx.contents.__getattribute__('amplitude')