From ca0a119c693ee4105738e87ddcf4769c0b2f83fd Mon Sep 17 00:00:00 2001 From: Anton Khirnov Date: Mon, 15 Jan 2024 15:53:55 +0100 Subject: doublenull: returns solutions as OdeSolution rather than X(t) array Similar to what is done in null, allows working with results more conveniently. --- doublenull.py | 62 ++++++++++++++++++++++------------------------- test/test_doublenull.npz | Bin 103064 -> 103064 bytes 2 files changed, 29 insertions(+), 33 deletions(-) diff --git a/doublenull.py b/doublenull.py index 1892c79..c0f79ab 100644 --- a/doublenull.py +++ b/doublenull.py @@ -2,7 +2,7 @@ from enum import Enum, auto import numpy as np from scipy.interpolate import RectBivariateSpline, interp1d -from scipy.integrate import solve_ivp +from scipy.integrate import solve_ivp, OdeSolution from . import interp @@ -33,48 +33,39 @@ def _events_bnd(spatial_coords): return [event_x_bound_upper, event_x_bound_lower] def _kernel_time(times, origin, sign, gXX, gXt, gtt, events): - idx = np.where(times == origin)[0][0] + dt = times[1] - times[0] - t_fwd = times[idx:] - t_back = times[:idx + 1][::-1] - - ray_fwd = [0.0] - if len(t_fwd) > 1: - sol = solve_ivp(_photon_dXdt, (t_fwd[0], t_fwd[-1]), (0,), - method = 'RK45', t_eval = t_fwd, + sol_fwd = None + if times[-1] - origin > dt: + sol = solve_ivp(_photon_dXdt, (origin, times[-1]), (0,), + method = 'RK45', args = (sign, gXX, gtt, gXt), dense_output = True, events = events, rtol = 1e-6, atol = 1e-8) - ray_fwd = sol.y[0] - ray_fwd = np.concatenate((ray_fwd, [ray_fwd[-1]] * (len(t_fwd) - len(ray_fwd)))) + sol_fwd = sol.sol - ray_back = [] - if len(t_back) > 1: - sol = solve_ivp(_photon_dXdt, (t_back[0], t_back[-1]), (0,), - method = 'RK45', t_eval = t_back, + sol_back = None + if origin - times[0] > dt: + sol = solve_ivp(_photon_dXdt, (origin, times[0]), (0,), + method = 'RK45', args = (sign, gXX, gtt, gXt), dense_output = True, events = events, rtol = 1e-6, atol = 1e-8) - ray_back = sol.y[0] - ray_back = np.concatenate((ray_back, [ray_back[-1]] * (len(t_back) - len(ray_back)))) + sol_back = sol.sol - return np.concatenate((ray_back[::-1][:-1], ray_fwd)) + if sol_fwd is not None and sol_back is not None: + # combine forward and backward solutions + return OdeSolution(np.concatenate((sol_back.ts[::-1][:-1], sol_fwd.ts)), + sol_back.interpolants[::-1] + sol_fwd.interpolants) + elif sol_fwd is not None: + return sol_fwd + return sol_back def _kernel_space(times, origin, sign, gXX, gXt, gtt, events): sol = solve_ivp(_photon_dXdt, (times[0], times[-1]), (origin,), - method = 'RK45', t_eval = times, args = (sign, gXX, gtt, gXt), + method = 'RK45', args = (sign, gXX, gtt, gXt), dense_output = True, events = events, rtol = 1e-6, atol = 1e-8) - - t, x = sol.t, sol.y[0] - - if len(t) < len(times): - x_ext = np.empty_like(times) - x_ext[:x.shape[0]] = x - x_ext[x.shape[0]:] = x[-1] + sign * (times[x.shape[0]:] - t[-1]) - - x = x_ext - - return x + return sol.sol class Curves(Enum): SPATIAL_FORWARD = auto() @@ -138,9 +129,14 @@ def null_curves(times, spatial_coords, gXX, gXt, gtt, rays_neg = np.empty_like(rays_pos) for tgt, sign in ((rays_pos, 1.0), (rays_neg, -1.0)): for i, origin in enumerate(origins): - tgt[i] = kernel(ray_times, origin, sign, - gXX_interp, gXt_interp, gtt_interp, - events) + sol = kernel(ray_times, origin, sign, + gXX_interp, gXt_interp, gtt_interp, + events) + tgt[i] = sol(ray_times) + + # do not extrapolate beyond solution range + tgt[i, ray_times < sol.t_min] = np.nan + tgt[i, ray_times > sol.t_max] = np.nan return (ray_times, rays_pos, rays_neg) diff --git a/test/test_doublenull.npz b/test/test_doublenull.npz index c668973..82aa642 100644 Binary files a/test/test_doublenull.npz and b/test/test_doublenull.npz differ -- cgit v1.2.3