summaryrefslogtreecommitdiff
path: root/series_expansion.py
blob: ca06b40ad11a3fc17ede3e3f0274bc7087ea8d7f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import numpy as np

class SeriesExpansion(object):
    """
    An N-dimensional function approximated as a series expansion in some chosen
    basis.

    The function is stored as an N-dimensional array of coefficients plus one
    basis object per dimension. On a tensor-product grid its value is

        f(x_0, ..., x_{N-1}) = sum over (i_0, ..., i_{N-1}) of
            coeffs[i_0, ..., i_{N-1}] * B_0[i_0](x_0) * ... * B_{N-1}[i_{N-1}](x_{N-1})

    where ``B_d[i](x)`` is ``basis[d].eval(i, x, diff_order[d])``.
    """

    _dim    = None
    _coeffs = None
    _basis  = None

    def __init__(self, coeffs, basis):
        """
        :param coeffs: numpy array of expansion coefficients; ``coeffs.ndim``
            is the dimensionality of the function and ``coeffs.shape[d]`` the
            number of basis functions used along dimension ``d``.
        :param basis: sequence of basis objects, one per dimension of
            ``coeffs``; each must provide ``eval(idx, coords, diff_order)``.
            For 1D a single basis object may be passed without wrapping.
        :raises ValueError: if the number of basis objects does not match
            ``coeffs.ndim``.
        """
        self._dim    = coeffs.ndim
        self._coeffs = coeffs

        # for 1D, allow passing just the basis object itself
        # without being wrapped in an iterable
        if self._dim == 1:
            try:
                basis[0]
            except (TypeError, IndexError):
                basis = [basis]

        if len(basis) != coeffs.ndim:
            raise ValueError('Mismatching number of coefficient and basis functions')

        self._basis  = basis

    @property
    def coeffs(self):
        """The array of expansion coefficients passed to the constructor."""
        return self._coeffs

    @property
    def basis(self):
        """The per-dimension sequence of basis objects."""
        return self._basis

    def eval(self, coords, diff_order = None):
        """
        Evaluate the expansion (or one of its partial derivatives) on a grid.

        :param coords: sequence of N coordinate arrays, one per dimension; the
            result is evaluated on their tensor-product grid. For 1D a plain
            array of coordinates may be passed without wrapping.
        :param diff_order: sequence of N derivative orders forwarded to the
            basis objects' ``eval``, one per dimension; defaults to all zeros.
            For 1D a single value may be passed without wrapping.
        :return: numpy array of shape ``[len(c) for c in coords]``. It is
            float64 unless the coefficients or basis values promote it
            (e.g. to complex).
        :raises ValueError: if the number of coordinate arrays or derivative
            orders does not match the number of dimensions.
        :raises NotImplementedError: for a zero-dimensional expansion.
        """
        # for 1D, allow passing just the plain array of coords
        # without being wrapped in an iterable
        if self._dim == 1:
            try:
                coords[0][0]
            except (TypeError, IndexError):
                coords = [coords]

            if diff_order is not None:
                try:
                    diff_order[0]
                except (TypeError, IndexError):
                    diff_order = [diff_order]

        if self._dim < 1:
            raise NotImplementedError('Unsupported number of dimensions')

        if diff_order is None:
            diff_order = [0] * len(coords)

        # zip() below would silently drop surplus entries (or leave dimensions
        # without basis values), so reject mismatched lengths explicitly
        if len(coords) != self._dim:
            raise ValueError('Mismatching number of coordinate arrays and dimensions')
        if len(diff_order) != self._dim:
            raise ValueError('Mismatching number of derivative orders and dimensions')

        shape = [len(c) for c in coords]

        # basis_vals[d][i]: i-th basis function of dimension d, evaluated at
        # coords[d] with derivative order diff_order[d]
        basis_vals = [
            [b.eval(idx, c, d) for idx in range(n)]
            for b, c, d, n in zip(self._basis, coords, diff_order, self._coeffs.shape)
        ]

        # Accumulate out of place: an in-place "+=" into the float64 zeros
        # array raises when coefficients or basis values are complex, while
        # "ret + term" promotes the dtype as needed.
        ret = np.zeros(shape)
        for idx in np.ndindex(*self._coeffs.shape):
            # Tensor (outer) product of the 1D basis values for this multi-index.
            # ravel() mirrors np.outer(), so scalar basis values broadcast
            # over their whole axis.
            term = np.ravel(basis_vals[0][idx[0]])
            for d in range(1, self._dim):
                term = np.multiply.outer(term, np.ravel(basis_vals[d][idx[d]]))
            ret = ret + self._coeffs[idx] * term
        return ret