summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAnton Khirnov <anton@khirnov.net>2024-01-15 22:52:34 +0100
committerAnton Khirnov <anton@khirnov.net>2024-01-15 22:52:34 +0100
commit19b3a9b733e6b0636a3041d26a8cb9c03cb1103d (patch)
tree0d408ba11cefa422fa5faffe60ae65dc27c9924c
parentca0a119c693ee4105738e87ddcf4769c0b2f83fd (diff)
doublenull: split off integration into a separate function
-rw-r--r--doublenull.py78
1 files changed, 44 insertions, 34 deletions
diff --git a/doublenull.py b/doublenull.py
index c0f79ab..6fe84ef 100644
--- a/doublenull.py
+++ b/doublenull.py
@@ -19,25 +19,23 @@ def _photon_dXdt(t, coord, sign, gXX, gtt, gXt):
return ((-gXt_val + sign * np.sqrt((gXt_val ** 2) - gtt_val * gXX_val)) / gXX_val).flatten()[0]
# terminate integration on reaching the outer boundaries
-def _events_bnd(spatial_coords):
+def _events_bnd(X_span):
def event_x_bound_upper(t, x, sign, *args):
- return x[0] - spatial_coords[-1]
+ return x[0] - X_span[1]
event_x_bound_upper.terminal = True
event_x_bound_upper.direction = 1.0
def event_x_bound_lower(t, x, sign, *args):
- return x[0] - spatial_coords[0]
+ return x[0] - X_span[0]
event_x_bound_lower.terminal = True
event_x_bound_lower.direction = -1.0
return [event_x_bound_upper, event_x_bound_lower]
-def _kernel_time(times, origin, sign, gXX, gXt, gtt, events):
- dt = times[1] - times[0]
-
+def _kernel_time(t_span, dt, origin, sign, gXX, gXt, gtt, events):
sol_fwd = None
- if times[-1] - origin > dt:
- sol = solve_ivp(_photon_dXdt, (origin, times[-1]), (0,),
+ if t_span[1] - origin > dt:
+ sol = solve_ivp(_photon_dXdt, (origin, t_span[1]), (0,),
method = 'RK45',
args = (sign, gXX, gtt, gXt),
dense_output = True, events = events,
@@ -45,8 +43,8 @@ def _kernel_time(times, origin, sign, gXX, gXt, gtt, events):
sol_fwd = sol.sol
sol_back = None
- if origin - times[0] > dt:
- sol = solve_ivp(_photon_dXdt, (origin, times[0]), (0,),
+ if origin - t_span[0] > dt:
+ sol = solve_ivp(_photon_dXdt, (origin, t_span[0]), (0,),
method = 'RK45',
args = (sign, gXX, gtt, gXt),
dense_output = True, events = events,
@@ -61,8 +59,8 @@ def _kernel_time(times, origin, sign, gXX, gXt, gtt, events):
return sol_fwd
return sol_back
-def _kernel_space(times, origin, sign, gXX, gXt, gtt, events):
- sol = solve_ivp(_photon_dXdt, (times[0], times[-1]), (origin,),
+def _kernel_space(t_span, dt, origin, sign, gXX, gXt, gtt, events):
+ sol = solve_ivp(_photon_dXdt, t_span, (origin,),
method = 'RK45', args = (sign, gXX, gtt, gXt),
dense_output = True, events = events, rtol = 1e-6, atol = 1e-8)
return sol.sol
@@ -81,6 +79,26 @@ class Curves(Enum):
to form the resulting curve
"""
+def _null_curves(t_span, dt, X_span, dX, origins, gXX, gXt, gtt,
+ kind, interp_order):
+ gXX_interp = interp.Interp2D_C([min(t_span), X_span[0]], [dt, dX], gXX, interp_order)
+ gXt_interp = interp.Interp2D_C([min(t_span), X_span[0]], [dt, dX], gXt, interp_order) \
+ if gXt is not None else lambda t, X: 0.0
+ gtt_interp = interp.Interp2D_C([min(t_span), X_span[0]], [dt, dX], gtt, interp_order)
+
+ kernel = _kernel_time if kind == Curves.TEMPORAL else _kernel_space
+ events = _events_bnd(X_span)
+
+ rays_pos = []
+ rays_neg = []
+ for tgt, sign in ((rays_pos, 1.0), (rays_neg, -1.0)):
+ for i, origin in enumerate(origins):
+ tgt.append(kernel(t_span, dt, origin, sign,
+ gXX_interp, gXt_interp, gtt_interp,
+ events))
+
+ return (rays_pos, rays_neg)
+
def null_curves(times, spatial_coords, gXX, gXt, gtt,
kind = Curves.SPATIAL_FORWARD, interp_order = 6):
"""
@@ -106,37 +124,29 @@ def null_curves(times, spatial_coords, gXX, gXt, gtt,
X-coordinate of the ray shot from (t=ray_times[0],
X=spatial_coords[i]) at time t=ray_times[j].
"""
+ origins = times if kind == Curves.TEMPORAL else spatial_coords
+ ray_times = times[::-1] if kind == Curves.SPATIAL_BACK else times
+
+ t_span = [ray_times[0], ray_times[-1]]
+ X_span = [spatial_coords[0], spatial_coords[-1]]
dt = times[1] - times[0]
dX = spatial_coords[1] - spatial_coords[0]
- gXX_interp = interp.Interp2D_C([times[0], spatial_coords[0]], [dt, dX], gXX, interp_order)
- gXt_interp = interp.Interp2D_C([times[0], spatial_coords[0]], [dt, dX], gXt, interp_order) \
- if gXt is not None else lambda t, X: 0.0
- gtt_interp = interp.Interp2D_C([times[0], spatial_coords[0]], [dt, dX], gtt, interp_order)
-
- if kind == Curves.TEMPORAL:
- kernel = _kernel_time
- origins = times
- else:
- kernel = _kernel_space
- origins = spatial_coords
-
- ray_times = times[::-1] if kind == Curves.SPATIAL_BACK else times
- events = _events_bnd(spatial_coords)
+ # integrate the null curves, as lists of OdeSolution
+ pos, neg = _null_curves(t_span, dt, X_span, dX, origins, gXX, gXt, gtt,
+ kind, interp_order)
+ # evaluate the null curves on the provided coordinates
rays_pos = np.empty((origins.shape[0], times.shape[0]))
rays_neg = np.empty_like(rays_pos)
- for tgt, sign in ((rays_pos, 1.0), (rays_neg, -1.0)):
- for i, origin in enumerate(origins):
- sol = kernel(ray_times, origin, sign,
- gXX_interp, gXt_interp, gtt_interp,
- events)
- tgt[i] = sol(ray_times)
+ for tgt, curves in ((rays_pos, pos), (rays_neg, neg)):
+ for i, c in enumerate(curves):
+ tgt[i] = c(ray_times)
# do not extrapolate beyond solution range
- tgt[i, ray_times < sol.t_min] = np.nan
- tgt[i, ray_times > sol.t_max] = np.nan
+ tgt[i, ray_times < c.t_min] = np.nan
+ tgt[i, ray_times > c.t_max] = np.nan
return (ray_times, rays_pos, rays_neg)