diff options
author | Anton Khirnov <anton@khirnov.net> | 2022-08-19 15:18:43 +0200 |
---|---|---|
committer | Anton Khirnov <anton@khirnov.net> | 2022-08-19 15:18:43 +0200 |
commit | 0ef7a659a13b4d357aa29f9741258eb28df2b6f4 (patch) | |
tree | 3b7cfa79f4c2bd7b03cad0e569ace84192391dfd | |
parent | 208a3e985b9c82d23fa7c50059fd545b6a1a67f5 (diff) |
Contains array utility functions moved from nr_analysis_axi package:
- matrix_invert()
- matrix_det()
- array_reflect()
-rw-r--r-- | array_utils.py | 57 |
1 files changed, 57 insertions, 0 deletions
diff --git a/array_utils.py b/array_utils.py new file mode 100644 index 0000000..e16dca1 --- /dev/null +++ b/array_utils.py @@ -0,0 +1,57 @@ +import numpy as np + +def matrix_invert(mat): + """ + Pointwise matrix inversion. + + Given an array that stores many square matrices (e.g. the values of the + metric tensor at spacetime points), compute the array of corresponding + inverses. + + :param array_like mat: N-D array with N>2 and each mat[i0, ..., :, :] a + square matrix to be inverted. + :return: Array (same shape as mat) of inverses. + """ + oldshape = mat.shape + newshape = oldshape[:2] + (np.product(mat.shape[2:]),) + + mat = mat.reshape(newshape).transpose((2, 0, 1)) + inv = np.linalg.inv(mat) + return inv.transpose((1, 2, 0)).reshape(oldshape) + +def matrix_det(mat): + """ + Pointwise matrix determinant. + + Given an array that stores many square matrices (e.g. the values of the + metric tensor at spacetime points), compute the array of corresponding + determinants. + + :param array_like mat: N-D array with N>2 and each mat[i0, ..., :, :] a + square matrix. + :return: Array of determinants. + """ + oldshape = mat.shape + newshape = oldshape[:2] + (np.product(mat.shape[2:]),) + + mat = mat.reshape(newshape).transpose((2, 0, 1)) + return np.linalg.det(mat).reshape(oldshape[2:]) + +def array_reflect(data, parity = 1.0, axis = -1): + """ + Reflect an N-D array with respect to the specified axis. E.g. input + [0, 1, 2, 3] becomes [3, 2, 1, 0, 1, 2, 3] with parity=1.0 and + [-3, -2, -1, 0, 1, 2, 3] with parity=-1.0. + + :param array_like data: The array to reflect. + :param float parity: The reflected portion is multiplied by this factor, + typically 1.0 or -1.0. + :param int axis: Index of the axis to reflect along. + :return: Reflected array. + """ + slices0 = [slice(None) for _ in data.shape] + slices1 = [slice(None) for _ in data.shape] + slices0[axis] = slice(None, None, -1) + slices1[axis] = slice(1, None) + + return np.concatenate((parity * data[tuple(slices0)], data[tuple(slices1)]), axis) |