summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAnton Khirnov <anton@khirnov.net>2024-01-15 15:53:55 +0100
committerAnton Khirnov <anton@khirnov.net>2024-01-15 15:53:55 +0100
commitca0a119c693ee4105738e87ddcf4769c0b2f83fd (patch)
treed98a9fe888ee949bb3eebed7e814b41e49c45503
parent29aa09f7a08ae5fc44fb6b99e3eca1d9c7143e21 (diff)
doublenull: returns solutions as OdeSolution rather than X(t) array
Similar to what is done in null, allows working with results more conveniently.
-rw-r--r--doublenull.py62
-rw-r--r--test/test_doublenull.npzbin103064 -> 103064 bytes
2 files changed, 29 insertions, 33 deletions
diff --git a/doublenull.py b/doublenull.py
index 1892c79..c0f79ab 100644
--- a/doublenull.py
+++ b/doublenull.py
@@ -2,7 +2,7 @@ from enum import Enum, auto
import numpy as np
from scipy.interpolate import RectBivariateSpline, interp1d
-from scipy.integrate import solve_ivp
+from scipy.integrate import solve_ivp, OdeSolution
from . import interp
@@ -33,48 +33,39 @@ def _events_bnd(spatial_coords):
return [event_x_bound_upper, event_x_bound_lower]
def _kernel_time(times, origin, sign, gXX, gXt, gtt, events):
- idx = np.where(times == origin)[0][0]
+ dt = times[1] - times[0]
- t_fwd = times[idx:]
- t_back = times[:idx + 1][::-1]
-
- ray_fwd = [0.0]
- if len(t_fwd) > 1:
- sol = solve_ivp(_photon_dXdt, (t_fwd[0], t_fwd[-1]), (0,),
- method = 'RK45', t_eval = t_fwd,
+ sol_fwd = None
+ if times[-1] - origin > dt:
+ sol = solve_ivp(_photon_dXdt, (origin, times[-1]), (0,),
+ method = 'RK45',
args = (sign, gXX, gtt, gXt),
dense_output = True, events = events,
rtol = 1e-6, atol = 1e-8)
- ray_fwd = sol.y[0]
- ray_fwd = np.concatenate((ray_fwd, [ray_fwd[-1]] * (len(t_fwd) - len(ray_fwd))))
+ sol_fwd = sol.sol
- ray_back = []
- if len(t_back) > 1:
- sol = solve_ivp(_photon_dXdt, (t_back[0], t_back[-1]), (0,),
- method = 'RK45', t_eval = t_back,
+ sol_back = None
+ if origin - times[0] > dt:
+ sol = solve_ivp(_photon_dXdt, (origin, times[0]), (0,),
+ method = 'RK45',
args = (sign, gXX, gtt, gXt),
dense_output = True, events = events,
rtol = 1e-6, atol = 1e-8)
- ray_back = sol.y[0]
- ray_back = np.concatenate((ray_back, [ray_back[-1]] * (len(t_back) - len(ray_back))))
+ sol_back = sol.sol
- return np.concatenate((ray_back[::-1][:-1], ray_fwd))
+ if sol_fwd is not None and sol_back is not None:
+ # combine forward and backward solutions
+ return OdeSolution(np.concatenate((sol_back.ts[::-1][:-1], sol_fwd.ts)),
+ sol_back.interpolants[::-1] + sol_fwd.interpolants)
+ elif sol_fwd is not None:
+ return sol_fwd
+ return sol_back
def _kernel_space(times, origin, sign, gXX, gXt, gtt, events):
sol = solve_ivp(_photon_dXdt, (times[0], times[-1]), (origin,),
- method = 'RK45', t_eval = times, args = (sign, gXX, gtt, gXt),
+ method = 'RK45', args = (sign, gXX, gtt, gXt),
dense_output = True, events = events, rtol = 1e-6, atol = 1e-8)
-
- t, x = sol.t, sol.y[0]
-
- if len(t) < len(times):
- x_ext = np.empty_like(times)
- x_ext[:x.shape[0]] = x
- x_ext[x.shape[0]:] = x[-1] + sign * (times[x.shape[0]:] - t[-1])
-
- x = x_ext
-
- return x
+ return sol.sol
class Curves(Enum):
SPATIAL_FORWARD = auto()
@@ -138,9 +129,14 @@ def null_curves(times, spatial_coords, gXX, gXt, gtt,
rays_neg = np.empty_like(rays_pos)
for tgt, sign in ((rays_pos, 1.0), (rays_neg, -1.0)):
for i, origin in enumerate(origins):
- tgt[i] = kernel(ray_times, origin, sign,
- gXX_interp, gXt_interp, gtt_interp,
- events)
+ sol = kernel(ray_times, origin, sign,
+ gXX_interp, gXt_interp, gtt_interp,
+ events)
+ tgt[i] = sol(ray_times)
+
+ # do not extrapolate beyond solution range
+ tgt[i, ray_times < sol.t_min] = np.nan
+ tgt[i, ray_times > sol.t_max] = np.nan
return (ray_times, rays_pos, rays_neg)
diff --git a/test/test_doublenull.npz b/test/test_doublenull.npz
index c668973..82aa642 100644
--- a/test/test_doublenull.npz
+++ b/test/test_doublenull.npz
Binary files differ