summaryrefslogtreecommitdiff
path: root/array_utils.py
blob: e16dca1a2633b3e45134753ff327637c70e3cd4d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import numpy as np

def matrix_invert(mat):
    """
    Pointwise matrix inversion.

    Given an array that stores many square matrices (e.g. the values of the
    metric tensor at spacetime points), compute the array of corresponding
    inverses.

    The matrix indices are the *leading* two axes: each ``mat[:, :, i0, ...]``
    is one square matrix, and any remaining axes index the points.

    :param array_like mat: N-D array with N>=2 whose first two axes have equal
                           length; converted with :func:`numpy.asarray`.
    :return: Array (same shape as mat) whose ``[:, :, i0, ...]`` slice is the
             inverse of ``mat[:, :, i0, ...]``.
    :raises ValueError: if mat has fewer than two dimensions.
    :raises numpy.linalg.LinAlgError: if the matrices are not square or any
                                       of them is singular.
    """
    mat = np.asarray(mat)
    if mat.ndim < 2:
        raise ValueError("mat must have at least two dimensions")

    # np.linalg.inv works on stacks of matrices held in the *last* two axes,
    # so move the matrix axes there instead of flattening the point axes.
    # This avoids np.product (removed in NumPy 2.0) and also handles N == 2.
    stacked = np.moveaxis(mat, (0, 1), (-2, -1))
    inv = np.linalg.inv(stacked)
    # Restore the original layout: matrix indices first, point indices after.
    return np.moveaxis(inv, (-2, -1), (0, 1))

def matrix_det(mat):
    """
    Pointwise matrix determinant.

    Given an array that stores many square matrices (e.g. the values of the
    metric tensor at spacetime points), compute the array of corresponding
    determinants.

    The matrix indices are the *leading* two axes: each ``mat[:, :, i0, ...]``
    is one square matrix, and any remaining axes index the points.

    :param array_like mat: N-D array with N>=2 whose first two axes have equal
                           length; converted with :func:`numpy.asarray`.
    :return: Array of determinants with shape ``mat.shape[2:]`` (a NumPy
             scalar when mat is 2-D).
    :raises ValueError: if mat has fewer than two dimensions.
    :raises numpy.linalg.LinAlgError: if the matrices are not square.
    """
    mat = np.asarray(mat)
    if mat.ndim < 2:
        raise ValueError("mat must have at least two dimensions")

    # np.linalg.det works on stacks of matrices held in the *last* two axes
    # and returns an array shaped like the remaining (point) axes, i.e.
    # mat.shape[2:]. Moving the axes avoids np.product (removed in NumPy 2.0).
    return np.linalg.det(np.moveaxis(mat, (0, 1), (-2, -1)))

def array_reflect(data, parity=1.0, axis=-1):
    """
    Reflect an N-D array with respect to the specified axis. E.g. input
    [0, 1, 2, 3] becomes [3, 2, 1, 0, 1, 2, 3] with parity=1.0 and
    [-3, -2, -1, 0, 1, 2, 3] with parity=-1.0.

    The first sample along ``axis`` is the reflection point and appears only
    once in the output. It is taken from the reflected portion, so it is
    multiplied by ``parity`` too: the centre value is ``parity`` times the
    first sample, not the first sample itself.

    :param array_like data: The array to reflect; converted with
                            :func:`numpy.asarray`.
    :param float parity: The reflected portion is multiplied by this factor,
                         typically 1.0 or -1.0.
    :param int axis: Index of the axis to reflect along.
    :return: Reflected array. Its length along ``axis`` is ``2 * n - 1``,
             where ``n`` is the input length along that axis (0 if n == 0).
    """
    data = np.asarray(data)

    # Index tuples selecting, along `axis` only, the whole axis reversed and
    # everything after the first sample (so the pivot is not duplicated).
    reversed_idx = [slice(None)] * data.ndim
    tail_idx = [slice(None)] * data.ndim
    reversed_idx[axis] = slice(None, None, -1)
    tail_idx[axis] = slice(1, None)

    return np.concatenate(
        (parity * data[tuple(reversed_idx)], data[tuple(tail_idx)]), axis
    )