summaryrefslogtreecommitdiff
path: root/series_expansion.py
diff options
context:
space:
mode:
Diffstat (limited to 'series_expansion.py')
-rw-r--r--series_expansion.py65
1 files changed, 65 insertions, 0 deletions
diff --git a/series_expansion.py b/series_expansion.py
new file mode 100644
index 0000000..9cd7671
--- /dev/null
+++ b/series_expansion.py
@@ -0,0 +1,65 @@
+import numpy as np
+
+class SeriesExpansion(object):
+ """
+ An N-dimensional function approximated as a series expansion in some chosen
+ basis
+ """
+
+ _dim = None
+ _coeffs = None
+ _basis = None
+
+ def __init__(self, coeffs, basis):
+ self._dim = coeffs.ndim
+ self._coeffs = coeffs
+
+ # for 1D, allow passing just the basis object itself
+ # without being wrapped in an iterable
+ if self._dim == 1:
+ try:
+ basis[0]
+ except TypeError:
+ basis = [basis]
+
+ if len(basis) != coeffs.ndim:
+ raise ValueError('Mismatching number of coefficient and basis functions')
+
+ self._basis = basis
+
+ @property
+ def coeffs(self):
+ return self._coeffs
+
+ @property
+ def basis(self):
+ return self._basis
+
+ def eval(self, coords, diff_order = None):
+ # for 1D, allow passing just the plain array of coords
+ # without being wrapped in an iterable
+ if self._dim == 1:
+ try:
+ coords[0][0]
+ except TypeError:
+ coords = [coords]
+
+ if diff_order is None:
+ diff_order = [0] * len(coords)
+
+ shape = [len(c) for c in coords]
+ ret = np.zeros(shape)
+
+ basis_vals = []
+ for i, (b, c, d) in enumerate(zip(self._basis, coords, diff_order)):
+ val = []
+ for idx in xrange(self._coeffs.shape[i]):
+ val.append(b.eval(idx, c, d))
+ basis_vals.append(val)
+
+ if self._dim == 1:
+ for c, val in zip(self._coeffs, basis_vals[0]):
+ ret += val * c
+ return ret
+ else:
+ raise NotImplementedError('Unsupported number of dimensions')