From 19b3a9b733e6b0636a3041d26a8cb9c03cb1103d Mon Sep 17 00:00:00 2001 From: Anton Khirnov Date: Mon, 15 Jan 2024 22:52:34 +0100 Subject: doublenull: split off integration into a separate function --- doublenull.py | 78 +++++++++++++++++++++++++++++++++-------------------------- 1 file changed, 44 insertions(+), 34 deletions(-) diff --git a/doublenull.py b/doublenull.py index c0f79ab..6fe84ef 100644 --- a/doublenull.py +++ b/doublenull.py @@ -19,25 +19,23 @@ def _photon_dXdt(t, coord, sign, gXX, gtt, gXt): return ((-gXt_val + sign * np.sqrt((gXt_val ** 2) - gtt_val * gXX_val)) / gXX_val).flatten()[0] # terminate integration on reaching the outer boundaries -def _events_bnd(spatial_coords): +def _events_bnd(X_span): def event_x_bound_upper(t, x, sign, *args): - return x[0] - spatial_coords[-1] + return x[0] - X_span[1] event_x_bound_upper.terminal = True event_x_bound_upper.direction = 1.0 def event_x_bound_lower(t, x, sign, *args): - return x[0] - spatial_coords[0] + return x[0] - X_span[0] event_x_bound_lower.terminal = True event_x_bound_lower.direction = -1.0 return [event_x_bound_upper, event_x_bound_lower] -def _kernel_time(times, origin, sign, gXX, gXt, gtt, events): - dt = times[1] - times[0] - +def _kernel_time(t_span, dt, origin, sign, gXX, gXt, gtt, events): sol_fwd = None - if times[-1] - origin > dt: - sol = solve_ivp(_photon_dXdt, (origin, times[-1]), (0,), + if t_span[1] - origin > dt: + sol = solve_ivp(_photon_dXdt, (origin, t_span[1]), (0,), method = 'RK45', args = (sign, gXX, gtt, gXt), dense_output = True, events = events, @@ -45,8 +43,8 @@ def _kernel_time(times, origin, sign, gXX, gXt, gtt, events): sol_fwd = sol.sol sol_back = None - if origin - times[0] > dt: - sol = solve_ivp(_photon_dXdt, (origin, times[0]), (0,), + if origin - t_span[0] > dt: + sol = solve_ivp(_photon_dXdt, (origin, t_span[0]), (0,), method = 'RK45', args = (sign, gXX, gtt, gXt), dense_output = True, events = events, @@ -61,8 +59,8 @@ def _kernel_time(times, origin, sign, gXX, gXt, gtt, events): return sol_fwd return sol_back -def _kernel_space(times, origin, sign, gXX, gXt, gtt, events): - sol = solve_ivp(_photon_dXdt, (times[0], times[-1]), (origin,), +def _kernel_space(t_span, dt, origin, sign, gXX, gXt, gtt, events): + sol = solve_ivp(_photon_dXdt, t_span, (origin,), method = 'RK45', args = (sign, gXX, gtt, gXt), dense_output = True, events = events, rtol = 1e-6, atol = 1e-8) return sol.sol @@ -81,6 +79,26 @@ class Curves(Enum): to form the resulting curve """ +def _null_curves(t_span, dt, X_span, dX, origins, gXX, gXt, gtt, + kind, interp_order): + gXX_interp = interp.Interp2D_C([min(t_span), X_span[0]], [dt, dX], gXX, interp_order) + gXt_interp = interp.Interp2D_C([min(t_span), X_span[0]], [dt, dX], gXt, interp_order) \ + if gXt is not None else lambda t, X: 0.0 + gtt_interp = interp.Interp2D_C([min(t_span), X_span[0]], [dt, dX], gtt, interp_order) + + kernel = _kernel_time if kind == Curves.TEMPORAL else _kernel_space + events = _events_bnd(X_span) + + rays_pos = [] + rays_neg = [] + for tgt, sign in ((rays_pos, 1.0), (rays_neg, -1.0)): + for i, origin in enumerate(origins): + tgt.append(kernel(t_span, dt, origin, sign, + gXX_interp, gXt_interp, gtt_interp, + events)) + + return (rays_pos, rays_neg) + def null_curves(times, spatial_coords, gXX, gXt, gtt, kind = Curves.SPATIAL_FORWARD, interp_order = 6): """ @@ -106,37 +124,29 @@ def null_curves(times, spatial_coords, gXX, gXt, gtt, X-coordinate of the ray shot from (t=ray_times[0], X=spatial_coords[i]) at time t=ray_times[j]. """ + origins = times if kind == Curves.TEMPORAL else spatial_coords + ray_times = times[::-1] if kind == Curves.SPATIAL_BACK else times + + t_span = [ray_times[0], ray_times[-1]] + X_span = [spatial_coords[0], spatial_coords[-1]] dt = times[1] - times[0] dX = spatial_coords[1] - spatial_coords[0] - gXX_interp = interp.Interp2D_C([times[0], spatial_coords[0]], [dt, dX], gXX, interp_order) - gXt_interp = interp.Interp2D_C([times[0], spatial_coords[0]], [dt, dX], gXt, interp_order) \ - if gXt is not None else lambda t, X: 0.0 - gtt_interp = interp.Interp2D_C([times[0], spatial_coords[0]], [dt, dX], gtt, interp_order) - - if kind == Curves.TEMPORAL: - kernel = _kernel_time - origins = times - else: - kernel = _kernel_space - origins = spatial_coords - - ray_times = times[::-1] if kind == Curves.SPATIAL_BACK else times - events = _events_bnd(spatial_coords) + # integrate the null curves, as lists of OdeSolution + pos, neg = _null_curves(t_span, dt, X_span, dX, origins, gXX, gXt, gtt, + kind, interp_order) + # evaluate the null curves on the provided coordinates rays_pos = np.empty((origins.shape[0], times.shape[0])) rays_neg = np.empty_like(rays_pos) - for tgt, sign in ((rays_pos, 1.0), (rays_neg, -1.0)): - for i, origin in enumerate(origins): - sol = kernel(ray_times, origin, sign, - gXX_interp, gXt_interp, gtt_interp, - events) - tgt[i] = sol(ray_times) + for tgt, curves in ((rays_pos, pos), (rays_neg, neg)): + for i, c in enumerate(curves): + tgt[i] = c(ray_times) # do not extrapolate beyond solution range - tgt[i, ray_times < sol.t_min] = np.nan - tgt[i, ray_times > sol.t_max] = np.nan + tgt[i, ray_times < c.t_min] = np.nan + tgt[i, ray_times > c.t_max] = np.nan return (ray_times, rays_pos, rays_neg) -- cgit v1.2.3