summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--array_utils.py57
1 files changed, 57 insertions, 0 deletions
diff --git a/array_utils.py b/array_utils.py
new file mode 100644
index 0000000..e16dca1
--- /dev/null
+++ b/array_utils.py
@@ -0,0 +1,57 @@
+import numpy as np
+
+def matrix_invert(mat):
+ """
+ Pointwise matrix inversion.
+
+ Given an array that stores many square matrices (e.g. the values of the
+ metric tensor at spacetime points), compute the array of corresponding
+ inverses.
+
+ :param array_like mat: N-D array with N>2 and each mat[i0, ..., :, :] a
+ square matrix to be inverted.
+ :return: Array (same shape as mat) of inverses.
+ """
+ oldshape = mat.shape
+ newshape = oldshape[:2] + (np.product(mat.shape[2:]),)
+
+ mat = mat.reshape(newshape).transpose((2, 0, 1))
+ inv = np.linalg.inv(mat)
+ return inv.transpose((1, 2, 0)).reshape(oldshape)
+
+def matrix_det(mat):
+ """
+ Pointwise matrix determinant.
+
+ Given an array that stores many square matrices (e.g. the values of the
+ metric tensor at spacetime points), compute the array of corresponding
+ determinants.
+
+ :param array_like mat: N-D array with N>2 and each mat[i0, ..., :, :] a
+ square matrix.
+ :return: Array of determinants.
+ """
+ oldshape = mat.shape
+ newshape = oldshape[:2] + (np.product(mat.shape[2:]),)
+
+ mat = mat.reshape(newshape).transpose((2, 0, 1))
+ return np.linalg.det(mat).reshape(oldshape[2:])
+
+def array_reflect(data, parity = 1.0, axis = -1):
+ """
+ Reflect an N-D array with respect to the specified axis. E.g. input
+ [0, 1, 2, 3] becomes [3, 2, 1, 0, 1, 2, 3] with parity=1.0 and
+ [-3, -2, -1, 0, 1, 2, 3] with parity=-1.0.
+
+ :param array_like data: The array to reflect.
+ :param float parity: The reflected portion is multiplied by this factor,
+ typically 1.0 or -1.0.
+ :param int axis: Index of the axis to reflect along.
+ :return: Reflected array.
+ """
+ slices0 = [slice(None) for _ in data.shape]
+ slices1 = [slice(None) for _ in data.shape]
+ slices0[axis] = slice(None, None, -1)
+ slices1[axis] = slice(1, None)
+
+ return np.concatenate((parity * data[tuple(slices0)], data[tuple(slices1)]), axis)