From f0460fe03fff8c5c567756149e39ff25a7663b1c Mon Sep 17 00:00:00 2001 From: Anton Khirnov Date: Thu, 1 Oct 2015 12:10:55 +0200 Subject: Prepare for radial basis. --- brill_data.py | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) (limited to 'brill_data.py') diff --git a/brill_data.py b/brill_data.py index 398a8d8..9bfe25d 100644 --- a/brill_data.py +++ b/brill_data.py @@ -39,17 +39,13 @@ class BrillData(object): class _BDContext(ctypes.Structure): _fields_ = [("priv", ctypes.c_void_p), + ("log_callback", ctypes.c_void_p), + ("opaque", ctypes.c_void_p), ("q_func_type", ctypes.c_int), ("amplitude", ctypes.c_double), ("eppley_n", ctypes.c_uint), - ("nb_coeffs_rho", ctypes.c_uint), - ("nb_coeffs_z", ctypes.c_uint), - ("overdet_rho", ctypes.c_int), - ("overdet_z", ctypes.c_int), - ("basis_scale_factor_rho", ctypes.c_double), - ("basis_scale_factor_z", ctypes.c_double), - ("log_callback", ctypes.c_void_p), - ("opaque", ctypes.c_void_p), + ("nb_coeffs", ctypes.c_uint * 2), + ("basis_scale_factor", ctypes.c_double * 2), ("psi_minus1_coeffs", ctypes.POINTER(ctypes.c_double)), ("stride", ctypes.c_long)] @@ -57,13 +53,21 @@ class BrillData(object): self._libbd = ctypes.CDLL('libbrilldata.so') self._bdctx = ctypes.cast(self._libbd.bd_context_alloc(), ctypes.POINTER(self._BDContext)) for arg, value in kwargs.iteritems(): - self._bdctx.contents.__setattr__(arg, value) + try: + self._bdctx.contents.__setattr__(arg, value) + except TypeError as e: + # try assigning items of an iterable + try: + for i, it in enumerate(value): + self._bdctx.contents.__getattribute__(arg)[i] = it + except: + raise e ret = self._libbd.bd_solve(self._bdctx) if ret < 0: raise RuntimeError('Error solving the equation') - self.coeffs = np.copy(np.ctypeslib.as_array(self._bdctx.contents.psi_minus1_coeffs, (self._bdctx.contents.nb_coeffs_z, self._bdctx.contents.nb_coeffs_rho))) + self.coeffs = np.copy(np.ctypeslib.as_array(self._bdctx.contents.psi_minus1_coeffs, (self._bdctx.contents.nb_coeffs[1], self._bdctx.contents.nb_coeffs[0]))) def __del__(self): if self._bdctx: -- cgit v1.2.3