summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAnton Khirnov <anton@khirnov.net>2014-10-21 16:41:42 +0200
committerAnton Khirnov <anton@khirnov.net>2014-10-21 16:41:42 +0200
commitfb540f6436d1116a00b01c62b77b2a21a439cfbc (patch)
tree18b1a763605b87c66b2fb9efb2e094a29f3e79e1
Initial commit.
-rw-r--r--__init__.py0
-rw-r--r--basis.py55
-rw-r--r--series_expansion.py65
3 files changed, 120 insertions, 0 deletions
diff --git a/__init__.py b/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/__init__.py
diff --git a/basis.py b/basis.py
new file mode 100644
index 0000000..759ebfe
--- /dev/null
+++ b/basis.py
@@ -0,0 +1,55 @@
+import abc
+import math
+import numpy as np
+
+class ExpansionBasis1D(object):
+ """
+ A family of one-dimensional functions that make up a basis.
+ """
+
+ __metaclass__ = abc.ABCMeta
+
+ @abc.abstractmethod
+ def eval(self, idx, coord, diff_order = 0):
+ """
+ Evaluate the diff_order-th derivative of the idx-th basis function at
+ the specified coordinates (where zeroth derivative means evaluating the
+ function itself).
+ """
+ pass
+
+ @abc.abstractmethod
+ def colloc_grid(self, order):
+ """
+ Get the coordinates of the optimal collocation grid of the specified
+ order (i.e. exactly order coordinates will be returned).
+ """
+ pass
+
+class CosBasis(ExpansionBasis1D):
+ PARITY_NONE = 0
+ PARITY_EVEN = 1
+ PARITY_ODD = 2
+
+ _diff_fact = [(1, np.cos), (-1, np.sin), (-1, np.cos), (1, np.sin)]
+ _parity = None
+
+ def __init__(self, parity = PARITY_NONE):
+ self._parity = parity
+
+ def eval(self, idx, coord, diff_order = 0):
+ fact, f = self._diff_fact[diff_order % 4]
+
+ if self._parity == self.PARITY_EVEN:
+ idx *= 2
+ elif self._parity == self.PARITY_ODD:
+ idx = 2 * idx + 1
+
+ fact *= idx ** diff_order
+ return fact * f(idx * coord)
+
+ def colloc_grid(self, order):
+ if self._parity == self.PARITY_NONE:
+ return (np.array(range(0, order)) + 1) * 2 * np.pi / (order + 1)
+ else:
+ return (np.array(range(0, order)) + 1) * np.pi / (2 * (order + 1))
diff --git a/series_expansion.py b/series_expansion.py
new file mode 100644
index 0000000..9cd7671
--- /dev/null
+++ b/series_expansion.py
@@ -0,0 +1,65 @@
+import numpy as np
+
+class SeriesExpansion(object):
+ """
+ An N-dimensional function approximated as a series expansion in some chosen
+ basis
+ """
+
+ _dim = None
+ _coeffs = None
+ _basis = None
+
+ def __init__(self, coeffs, basis):
+ self._dim = coeffs.ndim
+ self._coeffs = coeffs
+
+ # for 1D, allow passing just the basis object itself
+ # without being wrapped in an iterable
+ if self._dim == 1:
+ try:
+ basis[0]
+ except TypeError:
+ basis = [basis]
+
+ if len(basis) != coeffs.ndim:
+ raise ValueError('Mismatching number of coefficient and basis functions')
+
+ self._basis = basis
+
+ @property
+ def coeffs(self):
+ return self._coeffs
+
+ @property
+ def basis(self):
+ return self._basis
+
+ def eval(self, coords, diff_order = None):
+ # for 1D, allow passing just the plain array of coords
+ # without being wrapped in an iterable
+ if self._dim == 1:
+ try:
+ coords[0][0]
+ except TypeError:
+ coords = [coords]
+
+ if diff_order is None:
+ diff_order = [0] * len(coords)
+
+ shape = [len(c) for c in coords]
+ ret = np.zeros(shape)
+
+ basis_vals = []
+ for i, (b, c, d) in enumerate(zip(self._basis, coords, diff_order)):
+ val = []
+ for idx in xrange(self._coeffs.shape[i]):
+ val.append(b.eval(idx, c, d))
+ basis_vals.append(val)
+
+ if self._dim == 1:
+ for c, val in zip(self._coeffs, basis_vals[0]):
+ ret += val * c
+ return ret
+ else:
+ raise NotImplementedError('Unsupported number of dimensions')