summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAnton Khirnov <anton@khirnov.net>2022-08-15 16:21:08 +0200
committerAnton Khirnov <anton@khirnov.net>2022-08-15 16:21:08 +0200
commit163c031d84003b055a4c541338e98f61741098ad (patch)
tree863326426a2193f78dae2f8e0a9c83710cb753a2
parent859b256f6fc99829753ad41cbef8c7c12fb178db (diff)
doublenull: switch from odeint to newer solve_ivp
Use its events feature to terminate integration on grid boundaries.
-rw-r--r--doublenull.py57
1 files changed, 33 insertions, 24 deletions
diff --git a/doublenull.py b/doublenull.py
index 01264cf..a3f5caf 100644
--- a/doublenull.py
+++ b/doublenull.py
@@ -1,6 +1,6 @@
import numpy as np
import scipy.interpolate as interp
-from scipy.integrate import odeint
+from scipy.integrate import solve_ivp
def calc_null_curves(times, spatial_coords, gXX, gXt, gtt):
"""
@@ -29,11 +29,24 @@ def calc_null_curves(times, spatial_coords, gXX, gXt, gtt):
gXt_interp = interp.RectBivariateSpline(times, spatial_coords, gXt)
gtt_interp = interp.RectBivariateSpline(times, spatial_coords, gtt)
- def _dXdt(coord, t, sign,
- coord_min = spatial_coords[0], coord_max = spatial_coords[-1],
- gXX = gXX_interp, gtt = gtt_interp, gXt = gXt_interp):
- if coord <= coord_min or coord >= coord_max:
- return float('nan')
+
+ # terminate integration on reaching the outer boundaries
+ def event_x_bound_upper(t, x, sign):
+ return x[0] - spatial_coords[-1]
+ event_x_bound_upper.terminal = True
+ event_x_bound_upper.direction = 1.0
+
+ def event_x_bound_lower(t, x, sign):
+ return x[0] - spatial_coords[0]
+ event_x_bound_lower.terminal = True
+ event_x_bound_lower.direction = -1.0
+
+ events = [event_x_bound_upper, event_x_bound_lower]
+
+ # null geodesic equation RHS:
+ # gXX dX^2 + 2 gXt dX dt - gtt dt^2 = 0
+ # => dx / dt = (-gXt ± √(gXt^2 - gXX gtt)) / gXX
+ def dXdt(t, coord, sign, gXX = gXX_interp, gtt = gtt_interp, gXt = gXt_interp):
gXt_val = gXt(t, coord)
gXX_val = gXX(t, coord)
gtt_val = gtt(t, coord)
@@ -43,24 +56,20 @@ def calc_null_curves(times, spatial_coords, gXX, gXt, gtt):
rays_neg = np.empty_like(rays_pos)
for j, X0 in enumerate(spatial_coords):
for tgt, sign in ((rays_pos, 1.0), (rays_neg, -1.0)):
- ray = odeint(_dXdt, X0, times, (sign,))[:, 0]
-
- # clip the rays to the integration area
- k = 0
- while np.isnan(ray[k]):
- k += 1
- while k > 0:
- ray[k - 1] = ray[k]
- k -= 1
-
- k = ray.shape[0] - 1
- while np.isnan(ray[k]):
- k -= 1
- while k < ray.shape[0] - 1:
- ray[k + 1] = ray[k]
- k += 1
-
- tgt[j] = ray
+ ret = solve_ivp(dXdt, (times[0], times[-1]), (X0,),
+ method = 'RK45', t_eval = times, args = (sign,),
+ dense_output = True, events = events, rtol = 1e-6, atol = 1e-8)
+
+ t, x = ret.t, ret.y[0]
+
+ if len(t) < len(times):
+ x_ext = np.empty_like(times)
+ x_ext[:x.shape[0]] = x
+ x_ext[x.shape[0]:] = x[-1] + sign * (times[x.shape[0]:] - t[-1])
+
+ x = x_ext
+
+ tgt[j] = x
return (rays_pos, rays_neg)