summaryrefslogtreecommitdiff
path: root/doublenull.py
blob: e90349e313e3e249a12b87325aee549f9600fe85 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
from enum import Enum, auto

import numpy as np
from scipy.interpolate import interp1d
from scipy.integrate import solve_ivp, OdeSolution

from . import interp

def _photon_dXdt(t, coord, sign, gXX, gtt, gXt):
    """
    Null curve equation RHS (not a geodesic - not affinely parametrized).

    A curve X(t) is null when

        gXX dX^2 + 2 gXt dX dt + gtt dt^2 = 0

    whose two roots are

        dX / dt = (-gXt ± √(gXt^2 - gXX gtt)) / gXX

    :param t: coordinate time, as passed in by solve_ivp
    :param coord: current solver state (passed unchanged to the metric callables)
    :param float sign: +1.0 or -1.0, selects which root is returned
    :param gXX: callable (t, coord) -> value of the XX metric component
    :param gtt: callable (t, coord) -> value of the tt metric component
    :param gXt: callable (t, coord) -> value of the Xt metric component
                (note: gtt comes before gXt in this signature)
    :return: dX/dt as a scalar
    """
    gXt_val = gXt(t, coord)
    gXX_val = gXX(t, coord)
    gtt_val = gtt(t, coord)

    discriminant = gXt_val ** 2 - gtt_val * gXX_val
    rhs = (-gXt_val + sign * np.sqrt(discriminant)) / gXX_val

    # np.ravel (unlike the ndarray .flatten() method) also accepts plain Python
    # floats, so the callables may return either arrays or scalars
    return np.ravel(rhs)[0]

# terminate integration on reaching the outer boundaries
def _events_bnd(X_span):
    """
    Build the solve_ivp event functions that stop the integration once the
    ray reaches the outer boundaries X_span[0] / X_span[1].

    Each event takes (t, x, sign, *args), i.e. it receives the solver's extra
    ``args`` and ignores them. Both events are terminal. The upper event
    triggers only when x[0] - X_span[1] goes from negative to positive along
    the integration direction. The lower event triggers only when
    x[0] - X_span[0] goes from positive to negative. X_span is read each time
    an event is evaluated, not copied.

    :return: list [upper_event, lower_event]
    """
    def _boundary_event(index, crossing):
        def _event(t, x, sign, *args):
            return x[0] - X_span[index]
        _event.terminal  = True
        _event.direction = crossing
        return _event

    return [_boundary_event(1, 1.0), _boundary_event(0, -1.0)]

def _kernel_time(t_span, dt, origin, sign, gXX, gXt, gtt, events):
    """
    Integrate one null curve starting at (t=origin, X=0) towards both ends
    of t_span.

    A direction is integrated only if it spans more than dt. The forward and
    backward pieces are then merged into a single OdeSolution covering
    [t_span[0], t_span[1]] (or as much of it as was integrated).

    :return: scipy OdeSolution, or None when neither direction spans more
             than dt
    """
    def _integrate_to(t_end):
        result = solve_ivp(_photon_dXdt, (origin, t_end), (0,),
                           method = 'RK45',
                           args = (sign, gXX, gtt, gXt),
                           dense_output = True, events = events,
                           rtol = 1e-6, atol = 1e-8)
        return result.sol

    forward  = _integrate_to(t_span[1]) if t_span[1] - origin > dt else None
    backward = _integrate_to(t_span[0]) if origin - t_span[0] > dt else None

    if forward is None:
        return backward
    if backward is None:
        return forward

    # backward.ts runs from origin down to t_span[0]: reverse it and drop its
    # last element (origin), which is also the first element of forward.ts
    merged_ts = np.concatenate((backward.ts[::-1][:-1], forward.ts))
    merged_interpolants = backward.interpolants[::-1] + forward.interpolants
    return OdeSolution(merged_ts, merged_interpolants)

def _kernel_space(t_span, dt, origin, sign, gXX, gXt, gtt, events):
    """
    Integrate one null curve from (t=t_span[0], X=origin) to t=t_span[1].

    ``dt`` is unused; it is accepted so that this function can be called
    interchangeably with _kernel_time.

    :return: scipy OdeSolution of the curve X(t)
    """
    solver_options = dict(method = 'RK45', dense_output = True,
                          events = events, rtol = 1e-6, atol = 1e-8)
    integration = solve_ivp(_photon_dXdt, t_span, (origin,),
                            args = (sign, gXX, gtt, gXt), **solver_options)
    return integration.sol

class Curves(Enum):
    """Selects the points from which the null curves are integrated."""

    #: photons are traced forward in time from (t=times[0], X=spatial_coords[i])
    SPATIAL_FORWARD = auto()

    #: photons are traced backward in time from (t=times[-1], X=spatial_coords[i])
    SPATIAL_BACK = auto()

    #: photons are traced both forward and backward in time from
    #: (t=times[i], X=0); the forward and backward segment are combined
    #: to form the resulting curve
    TEMPORAL = auto()

def _null_curves(t_span, dt, X_span, dX, origins, gXX, gXt, gtt,
                 kind, interp_order):
    """
    Integrate one null curve per origin for each root (sign +1 and -1).

    The metric components are wrapped in interp.Interp2D_C interpolants
    anchored at (min(t_span), X_span[0]) with spacing (dt, dX). A missing gXt
    (None) is replaced by a function returning 0.0. Curves are integrated with
    _kernel_time for Curves.TEMPORAL and with _kernel_space otherwise, and
    stop at the X_span boundaries.

    :return: tuple (rays_pos, rays_neg) of lists with one kernel result per
             origin, in the order of ``origins``
    """
    def _interpolant(values):
        return interp.Interp2D_C([min(t_span), X_span[0]], [dt, dX],
                                 values, interp_order)

    gXX_f = _interpolant(gXX)
    if gXt is None:
        gXt_f = lambda t, X: 0.0
    else:
        gXt_f = _interpolant(gXt)
    gtt_f = _interpolant(gtt)

    kernel = _kernel_time if kind == Curves.TEMPORAL else _kernel_space
    boundary_events = _events_bnd(X_span)

    def _trace_all(sign):
        return [kernel(t_span, dt, origin, sign, gXX_f, gXt_f, gtt_f,
                       boundary_events)
                for origin in origins]

    return (_trace_all(1.0), _trace_all(-1.0))

def null_curves(times, spatial_coords, gXX, gXt, gtt,
                kind = Curves.SPATIAL_FORWARD, interp_order = 6):
    """
    Compute null curves along a given axis.

    From every origin point, shoot a null ray for each root of the null
    condition (the "+" and "-" root, see _photon_dXdt) and compute its
    trajectory X(t). The origin points are selected by ``kind``:

    * Curves.SPATIAL_FORWARD: (t=times[0],  X=spatial_coords[i])
    * Curves.SPATIAL_BACK:    (t=times[-1], X=spatial_coords[i])
    * Curves.TEMPORAL:        (t=times[i],  X=0)

    Rays are stopped when they leave [spatial_coords[0], spatial_coords[-1]].

    :param array_like times: 1D array of coordinate times at which the spacetime
                             curvature is provided. Only times[1] - times[0]
                             is used as the grid step, so the grid is presumably
                             expected to be uniform.
    :param array_like spatial_coords: 1D array of spatial coordinates (same
                                      remark about the spacing as for times)
    :param array_like gXX: 2D array containing the values of the XX component
                           of the spacetime metric, where X is the spatial
                           coordinate along which the rays are traced.
                           gXX[i, j] is the value at spacetime point
                           (t=times[i], X=spatial_coords[j]).
    :param array_like gXt: same as gXX, but for the Xt component of the metric;
                           may be None, in which case it is taken to be zero
    :param array_like gtt: same as gXX, but for the tt component of the metric
    :param Curves kind:    specifies where to integrate the curves from
    :param int interp_order: Order of interpolation used for metric quantities.

    :return: Tuple of (ray_times, rays_pos, rays_neg). ray_times is ``times``,
             reversed for Curves.SPATIAL_BACK. rays_*[i, j] contains the
             X-coordinate at t=ray_times[j] of the ray shot from the i-th
             origin point (rays_pos: "+" root, rays_neg: "-" root). Entries
             are NaN where the ray was not integrated, e.g. after it left
             the spatial range.
    """
    # accept any array_like as documented; ndarrays pass through unchanged
    times          = np.asarray(times)
    spatial_coords = np.asarray(spatial_coords)

    origins   = times       if kind == Curves.TEMPORAL     else spatial_coords
    ray_times = times[::-1] if kind == Curves.SPATIAL_BACK else times

    t_span = [ray_times[0], ray_times[-1]]
    X_span = [spatial_coords[0], spatial_coords[-1]]

    dt = times[1]          - times[0]
    dX = spatial_coords[1] - spatial_coords[0]

    # integrate the null curves, as lists of OdeSolution
    pos, neg = _null_curves(t_span, dt, X_span, dX, origins, gXX, gXt, gtt,
                            kind, interp_order)

    # evaluate the null curves on the provided coordinates
    rays_pos = np.full((origins.shape[0], times.shape[0]), np.nan)
    rays_neg = np.full_like(rays_pos, np.nan)
    for tgt, curves in ((rays_pos, pos), (rays_neg, neg)):
        for i, c in enumerate(curves):
            if c is None:
                # _kernel_time integrated nothing (the origin lies within dt
                # of both ends of the time range): leave the row as NaN
                continue

            tgt[i] = c(ray_times)

            # do not extrapolate beyond solution range
            tgt[i, ray_times < c.t_min] = np.nan
            tgt[i, ray_times > c.t_max] = np.nan

    return (ray_times, rays_pos, rays_neg)

def null_coordinates(times, spatial_coords, u_rays, v_rays,
                     gXX, gXt, gtt, kind = Curves.SPATIAL_FORWARD,
                     interp_order = 6):
    """
    Compute double-null coordinates (u, v) as functions of
    position and time.

    u is constant along the rays of null_curves' rays_pos and v along those
    of rays_neg. At each time, the ray labels are linearly interpolated onto
    spatial_coords.

    :param array_like times: 1D array of coordinate times at which the spacetime
                             curvature is provided
    :param array_like spatial_coords: 1D array of spatial coordinates
    :param array_like u_rays: 1D array of u values labelling the rays, one per
                              ray origin (see null_curves): u_rays[i] is the
                              value of u at X=spatial_coords[i] on the initial
                              (SPATIAL_FORWARD) or final (SPATIAL_BACK) time
                              slice, or at (t=times[i], X=0) for TEMPORAL.
    :param array_like v_rays: same as u_rays, but for v.
    :param array_like gXX: 2D array containing the values of the XX component
                           of the spacetime metric, where X is the spatial
                           coordinate along which the rays are traced.
                           gXX[i, j] is the value at spacetime point
                           (t=times[i], X=spatial_coords[j]).
    :param array_like gXt: same as gXX, but for the Xt component of the metric
    :param array_like gtt: same as gXX, but for the tt component of the metric
    :param Curves kind:    where the rays start, see null_curves
    :param int interp_order: Order of interpolation used for metric quantities.
    :return: tuple containing two 2D arrays with, respectively, values of u and
             v as functions of t and X. u/v[i, j] is the value of u/v at
             t=times[i], X=spatial_coords[j]; NaN where that point is not
             bracketed by two rays.
    """
    times          = np.asarray(times)
    spatial_coords = np.asarray(spatial_coords)
    u_rays         = np.asarray(u_rays)
    v_rays         = np.asarray(v_rays)

    _, X_of_ut, X_of_vt = null_curves(times, spatial_coords, gXX, gXt, gtt,
                                      kind = kind, interp_order = interp_order)

    def _labels_on_grid(X_rays, labels):
        # rays that are not defined at this time (e.g. they left the domain)
        # are NaN; interp1d has undefined behaviour with NaN inputs, so
        # interpolate only through the finite ray positions
        finite = np.isfinite(X_rays)
        if np.count_nonzero(finite) < 2:
            return np.nan
        return interp1d(X_rays[finite], labels[finite],
                        bounds_error = False)(spatial_coords)

    u_of_tx = np.empty((times.shape[0], spatial_coords.shape[0]))
    v_of_tx = np.empty_like(u_of_tx)
    for i in range(times.shape[0]):
        # column i holds the positions of all rays at t = times[i]
        u_of_tx[i] = _labels_on_grid(X_of_ut[:, i], u_rays)
        v_of_tx[i] = _labels_on_grid(X_of_vt[:, i], v_rays)

    return (u_of_tx, v_of_tx)

def null_coordinates_inv(t, X, uv, gXX, gXt, gtt, *,
                         uv_times = None,
                         interp_order = 6):
    """
    Compute values of (t, X) on a grid in double-null coordinates (U, V).

    The same values ``uv`` are used for both U and V, so the output grid is
    uv x uv. U is constant along the "+" root null curves and V along the
    "-" root curves, both starting on the axis X = 0 (see null_curves with
    Curves.TEMPORAL).

    :param array_like t: 1D array of coordinate times at which the spacetime
                         curvature is provided; must be strictly increasing
                         and uniformly spaced
    :param array_like X: 1D array of spatial coordinates; must be strictly
                         increasing and uniformly spaced
    :param array_like uv: 1D array of U/V values on the symmetry axis.
                          Every value corresponds to the appropriate element in
                          uv_times, i.e. uv[i] = U(t = uv_times[i], X = 0).
    :param array_like gXX: 2D array containing the values of the XX component
                           of the spacetime metric, where X is the spatial
                           coordinate along which the rays are traced.
                           gXX[i, j] is the value at spacetime point
                           (t=t[i], X=X[j]).
    :param array_like gXt: same as gXX, but for the Xt component of the metric;
                           may be None, in which case it is taken to be zero
    :param array_like gtt: same as gXX, but for the tt component of the metric
    :param array_like uv_times: 1D array of coordinate times corresponding to
                                uv; may be None, in which case t is used instead
    :param int interp_order: Order of interpolation used for the metric
                             quantities and for V(t, X).

    :return: tuple containing two 2D arrays with, respectively, values of T and
             X as functions of U and V. Tuv[i, j] / Xuv[i, j] is the value of
             t / X at U=uv[i], V=uv[j]; NaN where that point is not reached.
    :raises ValueError: if t or X is not strictly increasing and uniform, or if
                        uv and uv_times have different shapes
    """
    t  = np.asarray(t)
    X  = np.asarray(X)
    uv = np.asarray(uv)

    t_span = [t[0], t[-1]]
    dt     = t[1] - t[0]
    if dt <= 0:
        raise ValueError("t-grid must be strictly increasing")
    if np.any(np.abs((t[1:] - t[:-1]) - dt) > 1e-6 * dt):
        raise ValueError("Non-uniform t-grid")

    X_span = [X[0], X[-1]]
    dX     = X[1] - X[0]
    if dX <= 0:
        raise ValueError("X-grid must be strictly increasing")
    if np.any(np.abs((X[1:] - X[:-1]) - dX) > 1e-6 * dX):
        raise ValueError("Non-uniform X-grid")

    uv_times = t if uv_times is None else np.asarray(uv_times)
    if uv.shape != uv_times.shape:
        raise ValueError("Shape of uv must match uv_times: %s != %s" % (uv.shape, uv_times.shape))

    # X(t) along the curves of constant U (pos) / V (neg) through
    # (t=uv_times[i], X=0)
    pos, neg = _null_curves(t_span, dt, X_span, dX, uv_times,
                            gXX, gXt, gtt,
                            Curves.TEMPORAL, interp_order)

    def _interp_finite(x, y, x_new):
        # interp1d has undefined behaviour with NaN inputs; curves that are
        # not defined at a given time show up here as NaN, so interpolate
        # only through the finite abscissae
        x = np.asarray(x)
        finite = np.isfinite(x)
        if np.count_nonzero(finite) < 2:
            return np.nan
        return interp1d(x[finite], y[finite], bounds_error = False)(x_new)

    # interpolator for V(t, X)
    # FIXME: integrates the curves again at different times
    # (uniform in U/V vs uniform in X/t)
    if uv_times is t:
        uv_t = uv
    else:
        uv_t = interp1d(uv_times, uv, bounds_error = False)(t)
    # null_curves is called directly (rather than through null_coordinates)
    # so that interp_order is honoured here as well
    _, _, X_of_vt = null_curves(t, X, gXX, gXt, gtt,
                                kind = Curves.TEMPORAL,
                                interp_order = interp_order)
    vtx_vals = np.empty((t.shape[0], X.shape[0]))
    for k in range(t.shape[0]):
        # column k holds the positions of all constant-V rays at t = t[k]
        vtx_vals[k] = _interp_finite(X_of_vt[:, k], uv_t, X)
    Vtx = interp.Interp2D_C([t[0], X[0]], [dt, dX], vtx_vals, interp_order)

    Xuv = np.empty(uv.shape * 2)
    Tuv = np.empty_like(Xuv)
    for i, xt_u in enumerate(pos):
        # xt_u: X(t) at constant U=uv[i]
        if xt_u is None:
            # no curve could be integrated from this origin
            Tuv[i] = np.nan
            Xuv[i] = np.nan
            continue

        # uniform t grid along the curve
        t_uniform = np.linspace(xt_u.t_min, xt_u.t_max, 2 * uv.shape[0])

        # values of V(t, X(t)) at this t-grid along the curve
        v_vals = Vtx(t_uniform, xt_u(t_uniform))

        # finally invert V(t) into t(V) at the requested V values
        Tuv[i] = _interp_finite(v_vals, t_uniform, uv)
        Xuv[i] = xt_u(Tuv[i])

    return Tuv, Xuv