diff options
Diffstat (limited to 'series_expansion.py')
-rw-r--r-- | series_expansion.py | 65 |
1 files changed, 65 insertions, 0 deletions
diff --git a/series_expansion.py b/series_expansion.py new file mode 100644 index 0000000..9cd7671 --- /dev/null +++ b/series_expansion.py @@ -0,0 +1,65 @@ +import numpy as np + +class SeriesExpansion(object): + """ + An N-dimensional function approximated as a series expansion in some chosen + basis + """ + + _dim = None + _coeffs = None + _basis = None + + def __init__(self, coeffs, basis): + self._dim = coeffs.ndim + self._coeffs = coeffs + + # for 1D, allow passing just the basis object itself + # without being wrapped in an iterable + if self._dim == 1: + try: + basis[0] + except TypeError: + basis = [basis] + + if len(basis) != coeffs.ndim: + raise ValueError('Mismatching number of coefficient and basis functions') + + self._basis = basis + + @property + def coeffs(self): + return self._coeffs + + @property + def basis(self): + return self._basis + + def eval(self, coords, diff_order = None): + # for 1D, allow passing just the plain array of coords + # without being wrapped in an iterable + if self._dim == 1: + try: + coords[0][0] + except TypeError: + coords = [coords] + + if diff_order is None: + diff_order = [0] * len(coords) + + shape = [len(c) for c in coords] + ret = np.zeros(shape) + + basis_vals = [] + for i, (b, c, d) in enumerate(zip(self._basis, coords, diff_order)): + val = [] + for idx in xrange(self._coeffs.shape[i]): + val.append(b.eval(idx, c, d)) + basis_vals.append(val) + + if self._dim == 1: + for c, val in zip(self._coeffs, basis_vals[0]): + ret += val * c + return ret + else: + raise NotImplementedError('Unsupported number of dimensions') |