
beamax.gb.gb_solvers

ODE integrators for Gaussian beam state evolution.

Key Objects

  • SolverConfig: tolerances/method controls for ODE integration.
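  • Solver functions: general ODE integrators such as solve_ODE_base and solve_ODE_QP_base, the per-beam time-grid variant solve_ODE_batch_t, surface-hit variants, closed-form homogeneous solvers, and the time-reversal (TR) solver solve_hom_TR, which accepts per-beam or shared time grids.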

API Reference


SolverFn

Bases: Protocol

Protocol for Gaussian beam ODE integrators.

Implementations integrate beam initial data and return beam positions, momenta, Hessians, and amplitudes over time.

Notes

Expected signature is (x0, p0, M0, a0, mode, ts, c, *args, **kwargs) returning (xt, pt, Mt, At). Standard shapes are:

  • x0, p0: (b, d)
  • M0: (b, d, d)
  • a0, mode: (b,)
  • ts: (Nt,)
  • xt, pt: (b, Nt, d)
  • Mt: (b, Nt, d, d)
  • At: (b, Nt, 1)
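As an illustration only, the call signature above can be written as a typing.Protocol. The names SolverFnSketch and frozen_solver are hypothetical, and frozen_solver is a toy that keeps every beam at its initial state: it integrates nothing and only demonstrates the documented input and output shapes.

```python
from typing import Callable, Protocol, Tuple

import jax.numpy as jnp

Array = jnp.ndarray


class SolverFnSketch(Protocol):
    """Hypothetical restatement of the documented SolverFn call signature."""

    def __call__(
        self,
        x0: Array,    # (b, d)
        p0: Array,    # (b, d)
        M0: Array,    # (b, d, d)
        a0: Array,    # (b,)
        mode: Array,  # (b,)
        ts: Array,    # (Nt,)
        c: Callable,
        *args,
        **kwargs,
    ) -> Tuple[Array, Array, Array, Array]: ...


def frozen_solver(x0, p0, M0, a0, mode, ts, c, *args, **kwargs):
    """Toy solver that keeps every beam at its initial state; only shapes matter."""
    nt = ts.shape[0]
    xt = jnp.repeat(x0[:, None, :], nt, axis=1)       # (b, Nt, d)
    pt = jnp.repeat(p0[:, None, :], nt, axis=1)       # (b, Nt, d)
    Mt = jnp.repeat(M0[:, None, :, :], nt, axis=1)    # (b, Nt, d, d)
    At = jnp.repeat(a0[:, None, None], nt, axis=1)    # (b, Nt, 1)
    return xt, pt, Mt, At


solver: SolverFnSketch = frozen_solver
b, d, nt = 4, 2, 5
xt, pt, Mt, At = solver(
    jnp.zeros((b, d)),
    jnp.ones((b, d)),
    jnp.tile(jnp.eye(d), (b, 1, 1)),
    jnp.ones((b,)),
    jnp.ones((b,)),
    jnp.linspace(0.0, 1.0, nt),
    lambda x: 1.0,
)
```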
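
SolverConfig(solver: diffrax.AbstractSolver = diffrax.Tsit5(), max_steps: int = 4096, rtol: float = 1e-07, atol: float = 1e-09, pcoeff: float = 0.0, icoeff: float = 1.0, dcoeff: float = 0.0, dt0: float | None = None) dataclass

Configuration for ODE solver settings: the diffrax solver, the step budget (max_steps), the relative and absolute tolerances (rtol, atol), the PID step-size controller coefficients (pcoeff, icoeff, dcoeff), and an optional initial step size (dt0).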

from_precision(use_x64: Optional[bool] = None, solver: Optional[diffrax.AbstractSolver] = None, **overrides) classmethod

Create a config with precision-appropriate tolerances.

Parameters:

  • use_x64 (bool | None, default None): If None, auto-detects from jax.config.x64_enabled.
  • solver (AbstractSolver | None, default None): Overrides the default solver (Tsit5).
  • **overrides (default {}): Override any other config field (max_steps, rtol, etc.).

Examples:

Auto-detect precision, use defaults:

config = SolverConfig.from_precision()

Force float32 tolerances, but increase max_steps:

config = SolverConfig.from_precision(use_x64=False, max_steps=8192)

Auto precision, custom solver and tolerances:

config = SolverConfig.from_precision(solver=diffrax.Dopri5(), rtol=1e-5, max_steps=10000)

Everything custom:

config = SolverConfig.from_precision(use_x64=False, solver=diffrax.Dopri8(), max_steps=2048, rtol=1e-3, atol=1e-5, pcoeff=0.3)

create_p_perp(p0: jnp.ndarray, normp_sq: jnp.ndarray, eye: jnp.ndarray) -> jnp.ndarray

Orthogonal projector onto the subspace perpendicular to p0.

Parameters:

  • p0 (ndarray, shape (..., d, 1), required): Momentum as a column vector.
  • normp_sq (ndarray, shape (..., 1, 1), required): Squared momentum norm ||p||².
  • eye (ndarray, shape (..., d, d), required): Identity matrix.

Returns:

  • (ndarray, shape (..., d, d)): I - p pᵀ / ||p||².
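A minimal jax.numpy sketch of the projector under the documented shapes; p_perp_sketch is a hypothetical stand-in, not the library's create_p_perp.

```python
import jax.numpy as jnp


def p_perp_sketch(p0, normp_sq, eye):
    """I - p p^T / ||p||^2 for p0: (..., d, 1), normp_sq: (..., 1, 1), eye: (..., d, d)."""
    outer = p0 @ jnp.swapaxes(p0, -1, -2)   # (..., d, d): the outer product p p^T
    return eye - outer / normp_sq


p = jnp.array([[3.0], [4.0]])               # one momentum column vector, d = 2
P = p_perp_sketch(p, jnp.sum(p**2, axis=-2, keepdims=True), jnp.eye(2))
```

P annihilates p and is idempotent (P P = P), which is what makes it an orthogonal projector; the same code works for batched inputs of shape (b, d, 1).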

compute_amp_hom_gen(p0: jnp.ndarray, m0: jnp.ndarray, c0: jnp.ndarray, ts: jnp.ndarray, a0: jnp.ndarray) -> jnp.ndarray

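Amplitude for a general homogeneous GB (no diagonal assumption on M0).

Parameters:

  • p0 (ndarray, shape (b, d), required): Initial momenta.
  • m0 (ndarray, shape (b, d, d), required): Initial Hessians.
  • c0 (ndarray, shape (b, 1), required): Signed homogeneous sound speed.
  • ts (ndarray, shape (Nt,), required): Time grid.
  • a0 (ndarray, shape (b,), required): Initial amplitudes.

Returns:

  • (ndarray, shape (b, Nt, 1)): a(t) = a0 / sqrt(det(I + c0 t P_perp M0 / ||p||)), where P_perp = I - p pᵀ / ||p||².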

compute_m_hom_gen(p0: jnp.ndarray, m0: jnp.ndarray, c0: jnp.ndarray, ts: jnp.ndarray) -> jnp.ndarray

M(t) for a general homogeneous GB.

Parameters:

  • p0 (ndarray, shape (b, d), required): Initial momenta.
  • m0 (ndarray, shape (b, d, d), required): Initial Hessians.
  • c0 (ndarray, shape (b, 1), required): Signed homogeneous sound speed.
  • ts (ndarray, shape (Nt,), required): Time grid.

Returns:

  • (ndarray, shape (b, Nt, d, d)): M(t) = M0 (I + c0 t P_perp M0 / ||p||)^(-1).
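The closed form for M(t) can be checked against the Riccati equation. Assuming the homogeneous Hamiltonian G = c0 ||p|| (so Gxx = Gxp = 0 and Gpp = c0 P_perp / ||p||), the Riccati equation reduces to Ṁ = -M Gpp M. The sketch below uses a hypothetical single-beam helper, m_hom_gen_sketch, differentiates it in t with jax.jacfwd, and compares with that right-hand side.

```python
import jax
import jax.numpy as jnp


def m_hom_gen_sketch(p0, m0, c0, t):
    """M(t) = M0 (I + c0 t P_perp M0 / ||p||)^(-1) for one beam (p0 has shape (d,))."""
    d = p0.shape[0]
    normp = jnp.linalg.norm(p0)
    p_perp = jnp.eye(d) - jnp.outer(p0, p0) / normp**2
    B = jnp.eye(d) + (c0 * t / normp) * (p_perp @ m0)
    # M0 @ inv(B) computed as solve(B^T, M0^T)^T to avoid an explicit inverse
    return jnp.linalg.solve(B.T, m0.T).T


p0 = jnp.array([1.0, 2.0])
m0 = jnp.diag(jnp.array([1.0 + 2.0j, 0.5 + 1.0j]))   # complex M0, as usual for Gaussian beams
c0 = 1.5
t = jnp.asarray(0.7)

Mt = m_hom_gen_sketch(p0, m0, c0, t)
dM_dt = jax.jacfwd(lambda s: m_hom_gen_sketch(p0, m0, c0, s))(t)

# Homogeneous Riccati right-hand side: only the -M Gpp M term survives
normp = jnp.linalg.norm(p0)
Gpp = c0 * (jnp.eye(2) - jnp.outer(p0, p0) / normp**2) / normp
residual = dM_dt + Mt @ Gpp @ Mt
```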

compute_amp_hom_diag_2d(p0: jnp.ndarray, normp: jnp.ndarray, alpha0: jnp.ndarray, c0: jnp.ndarray, ts: jnp.ndarray, a0: jnp.ndarray) -> jnp.ndarray

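Amplitude for a diagonal (possibly anisotropic) M0 in 2D.

Parameters:

  • p0 (ndarray, shape (b, 2), required): Initial momenta.
  • normp (ndarray, shape (b, 1), required): Momentum norms.
  • alpha0 (ndarray, shape (b, 2), required): Diagonal entries of M0.
  • c0 (ndarray, shape (b, 1), required): Signed homogeneous sound speed.
  • ts (ndarray, shape (Nt,), required): Time grid.
  • a0 (ndarray, shape (b,), required): Initial amplitudes.

Returns:

  • (ndarray, shape (b, Nt)): a(t) = a0 / (1 + c0 t (α₁ p₂² + α₂ p₁²) / ||p||³)^(1/2), with (α₁, α₂) = alpha0 and (p₁, p₂) = p0; the exponent is (d-1)/2 = 1/2 for d = 2.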

compute_amp_hom_diag_3d(p0: jnp.ndarray, normp: jnp.ndarray, alpha0: jnp.ndarray, c0: jnp.ndarray, ts: jnp.ndarray, a0: jnp.ndarray) -> jnp.ndarray

Amplitude for a diagonal (possibly anisotropic) M0 in 3D.

Parameters:

  • p0 (ndarray, shape (b, 3), required): Initial momenta.
  • normp, alpha0, c0, ts, a0 (required): As in compute_amp_hom_diag_2d, with d = 3.

Returns:

  • (ndarray, shape (b, Nt)): Closed-form 3D expression combining per-axis terms (see source).
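For diagonal M0 = diag(α) and unit ray direction n = p/||p||, the matrix determinant lemma gives det(I + s P_perp M0) = Σ_i n_i² Π_{j≠i} (1 + s α_j) with s = c0 t / ||p||, in any dimension. For d = 2 this reduces to 1 + s (α₁ n₂² + α₂ n₁²). For d = 3 it is one way to combine per-axis terms, but we have not checked it against the library's exact 3D expression. The helpers below are hypothetical; the sketch compares the closed form with a directly computed determinant.

```python
import jax.numpy as jnp


def det_factor_diag(p0, alpha0, s):
    """det(I + s P_perp diag(alpha0)) = sum_i n_i^2 prod_{j != i} (1 + s alpha0_j).

    p0: (d,) momentum, alpha0: (d,) diagonal of M0, s = c0 t / ||p0|| (scalar).
    """
    d = p0.shape[0]
    n2 = p0**2 / jnp.sum(p0**2)              # squared components of n = p0 / ||p0||
    factors = 1.0 + s * alpha0               # (d,) per-axis factors 1 + s alpha_j
    # Row i keeps every factor except the i-th, which is replaced by 1.
    others = jnp.where(jnp.eye(d, dtype=bool), 1.0, factors[None, :])
    return jnp.sum(n2 * jnp.prod(others, axis=1))


def det_factor_direct(p0, alpha0, s):
    """Reference value: form the matrix and take its determinant."""
    d = p0.shape[0]
    p_perp = jnp.eye(d) - jnp.outer(p0, p0) / jnp.sum(p0**2)
    return jnp.linalg.det(jnp.eye(d) + s * p_perp @ jnp.diag(alpha0))


p2, a2 = jnp.array([3.0, 4.0]), jnp.array([1.0, 2.0])
p3, a3 = jnp.array([1.0, -2.0, 0.5]), jnp.array([1.0, 2.0, 0.5])
val2 = det_factor_diag(p2, a2, 0.5)   # 1 + 0.5 * (1*16 + 2*9) / 25 = 1.68
```

The amplitude then follows as a0 divided by the square root of this factor, matching the general formula a(t) = a0 / sqrt(det(I + c0 t P_perp M0 / ||p||)).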

compute_amp_hom_diag(p0: jnp.ndarray, normp: jnp.ndarray, alpha0: jnp.ndarray, c0: jnp.ndarray, ts: jnp.ndarray, a0: jnp.ndarray) -> jnp.ndarray

Dispatch to the 2D or 3D amplitude formula for diagonal M0.

Parameters:

  • p0 (ndarray, shape (b, d), required): Initial momenta.
  • normp (ndarray, shape (b, 1), required): Momentum norms.
  • alpha0 (ndarray, shape (b, d), required): Diagonal entries of M0.
  • c0 (ndarray, shape (b, 1), required): Signed homogeneous sound speed.
  • ts (ndarray, shape (Nt,), required): Time grid.
  • a0 (ndarray, shape (b,), required): Initial amplitudes.

Returns:

  • (ndarray, shape (b, Nt))

compute_m_hom_diag(p0: jnp.ndarray, normp: jnp.ndarray, alpha0: jnp.ndarray, c0: jnp.ndarray, ts: jnp.ndarray) -> jnp.ndarray

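M(t) with diagonal M0 via the Sherman–Morrison formula.

Parameters:

  • p0 (ndarray, shape (b, d), required): Initial momenta.
  • normp (ndarray, shape (b, 1), required): Momentum norms.
  • alpha0 (ndarray, shape (b, d), required): Diagonal entries of M0.
  • c0 (ndarray, shape (b, 1), required): Signed homogeneous sound speed.
  • ts (ndarray, shape (Nt,), required): Time grid.

Returns:

  • (ndarray, shape (b, Nt, d, d)): M0 (A + u vᵀ)^(-1), with A diagonal and u vᵀ a rank-1 update from the ray direction.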
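One natural reading of this rank-1 structure, which we assume here rather than take from the source: with s = c0 t / ||p|| and n = p/||p||, I + s P_perp M0 = A + u vᵀ where A = I + s M0 is diagonal, u = -s n and v = M0 n. Sherman–Morrison then inverts it in O(d²) without a general solve. m_hom_diag_sketch is hypothetical; the check compares it with a direct inverse for a single beam.

```python
import jax.numpy as jnp


def m_hom_diag_sketch(p0, alpha0, c0, t):
    """M(t) = M0 (A + u v^T)^(-1) with M0 = diag(alpha0), via Sherman-Morrison."""
    normp = jnp.linalg.norm(p0)
    n = p0 / normp                       # unit ray direction
    s = c0 * t / normp
    a_inv = 1.0 / (1.0 + s * alpha0)     # A = I + s diag(alpha0) is diagonal
    u = -s * n
    v = alpha0 * n                       # M0 n for diagonal M0
    Ainv_u = a_inv * u                   # A^{-1} u
    vT_Ainv = v * a_inv                  # v^T A^{-1}
    # (A + u v^T)^{-1} = A^{-1} - A^{-1} u v^T A^{-1} / (1 + v^T A^{-1} u)
    inv = jnp.diag(a_inv) - jnp.outer(Ainv_u, vT_Ainv) / (1.0 + vT_Ainv @ u)
    return jnp.diag(alpha0) @ inv


p0 = jnp.array([1.0, -2.0, 0.5])
alpha0 = jnp.array([1.0 + 1.0j, 2.0 + 0.5j, 0.5 + 2.0j])
c0, t = 1.2, 0.8
M_sm = m_hom_diag_sketch(p0, alpha0, c0, t)

# Reference: M0 (I + s P_perp M0)^(-1) with an explicit inverse
normp = jnp.linalg.norm(p0)
p_perp = jnp.eye(3) - jnp.outer(p0, p0) / normp**2
B = jnp.eye(3) + (c0 * t / normp) * p_perp @ jnp.diag(alpha0)
M_direct = jnp.diag(alpha0) @ jnp.linalg.inv(B)
```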

solve_hom_diag(x0: jnp.ndarray, p0: jnp.ndarray, M0: jnp.ndarray, a0: jnp.ndarray, mode: jnp.ndarray, ts: jnp.ndarray, c: Callable, lam=None, config=None) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]

Solver for homogeneous media with diagonal M0, using simplified closed-form equations.

Parameters:

  • x0 (ndarray, shape (b, d), required): Initial beam positions.
  • p0 (ndarray, shape (b, d), required): Initial momenta.
  • M0 (ndarray, shape (b, d, d), required): Initial Hessian matrices. Only diagonal entries are used.
  • a0 (ndarray, shape (b,), required): Initial amplitudes.
  • mode (ndarray, shape (b,), required): Hamiltonian branch signs.
  • ts (ndarray, shape (Nt,), required): Time grid.
  • c (Callable, required): Homogeneous sound-speed function.
  • lam (Any, default None): Ignored; accepted for signature compatibility.
  • config (Any, default None): Ignored; accepted for signature compatibility.

Returns:

  • xt (ndarray, shape (b, Nt, d)): Beam positions over time.
  • pt (ndarray, shape (b, Nt, d)): Beam momenta over time.
  • Mt (ndarray, shape (b, Nt, d, d)): Beam Hessians over time.
  • At (ndarray, shape (b, Nt, 1)): Beam amplitudes over time.

Notes

Assumes c(x) is homogeneous, M0 is diagonal, and d is 1, 2, or 3. solve_hom_general relaxes the diagonal and dimensionality assumptions.

solve_hom_general(x0: jnp.ndarray, p0: jnp.ndarray, m0: jnp.ndarray, a0: jnp.ndarray, mode: jnp.ndarray, ts: jnp.ndarray, c: Callable, lam=None, config=None) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]

Solver for homogeneous media with a general (not necessarily diagonal) M0, using simplified closed-form equations.

Parameters:

  • x0 (ndarray, shape (b, d), required): Initial beam positions.
  • p0 (ndarray, shape (b, d), required): Initial momenta.
  • m0 (ndarray, shape (b, d, d), required): Initial Hessian matrices.
  • a0 (ndarray, shape (b,), required): Initial amplitudes.
  • mode (ndarray, shape (b,), required): Hamiltonian branch signs.
  • ts (ndarray, shape (Nt,), required): Time grid.
  • c (Callable, required): Homogeneous sound-speed function.
  • lam (Any, default None): Ignored; accepted for signature compatibility.
  • config (Any, default None): Ignored; accepted for signature compatibility.

Returns:

  • xt (ndarray, shape (b, Nt, d)): Beam positions over time.
  • pt (ndarray, shape (b, Nt, d)): Beam momenta over time.
  • Mt (ndarray, shape (b, Nt, d, d)): Beam Hessians over time.
  • At (ndarray, shape (b, Nt, 1)): Beam amplitudes over time.
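Assuming the Hamiltonian G(x, p) = mode c ||p|| with constant c (our assumption, consistent with the signed speed c0 = mode c used above), Hamilton's equations give ẋ = c0 p/||p|| and ṗ = 0, so beam centers move on straight lines at speed |c0|. The hypothetical hom_rays_sketch evaluates this on a shared time grid with the documented shapes; for simplicity it takes a constant speed value instead of the Callable c.

```python
import jax.numpy as jnp


def hom_rays_sketch(x0, p0, mode, ts, c_value):
    """x(t) = x0 + c0 t p0 / ||p0||, p(t) = p0, with c0 = mode * c_value.

    x0, p0: (b, d); mode: (b,); ts: (Nt,). Returns xt, pt with shape (b, Nt, d).
    """
    c0 = (mode * c_value)[:, None]                          # (b, 1) signed speed
    normp = jnp.linalg.norm(p0, axis=-1, keepdims=True)     # (b, 1)
    velocity = c0 * p0 / normp                              # (b, d) constant ray velocity
    xt = x0[:, None, :] + ts[None, :, None] * velocity[:, None, :]
    pt = jnp.broadcast_to(p0[:, None, :], xt.shape)         # momentum is conserved
    return xt, pt


x0 = jnp.zeros((2, 2))
p0 = jnp.array([[3.0, 4.0], [0.0, 2.0]])
mode = jnp.array([1.0, -1.0])
ts = jnp.linspace(0.0, 2.0, 3)
xt, pt = hom_rays_sketch(x0, p0, mode, ts, 2.0)
```

The negative branch (mode = -1) travels against p, which is why the second beam moves in the -y direction.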

solve_hom_TR(xT: jnp.ndarray, pT: jnp.ndarray, mT: jnp.ndarray, aT: jnp.ndarray, mode: jnp.ndarray, ts: jnp.ndarray, c: Callable, lam=None, config=None) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]

Time-reversal solver for homogeneous media.

Parameters:

  • xT (ndarray, shape (b, d), required): Beam positions at the reference final time.
  • pT (ndarray, shape (b, d), required): Beam momenta at the reference final time.
  • mT (ndarray, shape (b, d, d), required): Beam Hessians at the reference final time.
  • aT (ndarray, shape (b,) or (b, 1), required): Beam amplitudes at the reference final time.
  • mode (ndarray, shape (b,) or (b, 1), required): Hamiltonian branch signs.
  • ts (ndarray, required): Per-beam or shared time grid.
  • c (Callable, required): Homogeneous sound-speed function.
  • lam (Any, default None): Ignored; accepted for signature compatibility.
  • config (Any, default None): Ignored; accepted for signature compatibility.

Returns:

  • x0 (ndarray, shape (b, Nt, d)): Time-reversed beam positions.
  • p0_time (ndarray, shape (b, Nt, d)): Time-reversed beam momenta.
  • m0 (ndarray, shape (b, Nt, d, d)): Time-reversed Hessians.
  • a0 (ndarray, shape (b, Nt, 1)): Time-reversed amplitudes.

ode_solver_setup(coupled_rhs: Callable, y0: jnp.ndarray, t0, t1, dt0, ts: jnp.ndarray, args: Tuple, config: Optional[SolverConfig] = None, cond_fn: Optional[Callable] = None, saveat: Optional[diffrax.SaveAt] = None)

Set up and solve the coupled system of ODEs for the GB motion.

Parameters:

  • coupled_rhs (Callable, required): Right-hand-side function passed to diffrax.ODETerm.
  • y0 (ndarray, required): Initial state vector.
  • t0 (float, required): Initial time.
  • t1 (float, required): Final time.
  • dt0 (float, required): Initial time step.
  • ts (ndarray, required): Save times.
  • args (Tuple, required): Extra ODE arguments.
  • config (SolverConfig, default None): Numerical solver configuration.
  • cond_fn (Callable, default None): Event condition function.
  • saveat (SaveAt, default None): Custom save specification. Defaults to SaveAt(ts=ts).

Returns:

  • Solution: Diffrax solution object.

riccati_rhs(M, x, p, mode, c)

Riccati equation for the Hessian evolution Ṁ.

Parameters:

  • M (ndarray, shape (d, d), required): Current Hessian.
  • x (ndarray, shape (d,), required): Beam position.
  • p (ndarray, shape (d,), required): Beam momentum.
  • mode (scalar, required): Hamiltonian branch sign.
  • c (Callable, required): Sound-speed function.

Returns:

  • (ndarray, shape (d, d)): Ṁ = -(Gxx + Gxp M + M Gxpᵀ + M Gpp M).

With (Gxp)_ij = ∂²G/(∂x_i ∂p_j), this is the standard textbook form (Berra–de Hoop–Romero 2017, eq. 2.13; Červený 2007, eq. 66).
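A hedged sketch of how this right-hand side can be assembled with JAX autodiff, assuming the Hamiltonian G(x, p) = mode c(x) ||p||; riccati_rhs_sketch is hypothetical, not the library's riccati_rhs. Note the transpose: jax.jacfwd of the p-gradient returns ∂²G/(∂x_i ∂p_j) indexed [j, i], so it is transposed to match the (Gxp)_ij convention above.

```python
import jax
import jax.numpy as jnp


def riccati_rhs_sketch(M, x, p, mode, c):
    """Mdot = -(Gxx + Gxp M + M Gxp^T + M Gpp M) with G(x, p) = mode * c(x) * ||p||."""
    G = lambda x_, p_: mode * c(x_) * jnp.linalg.norm(p_)
    Gxx = jax.hessian(G, argnums=0)(x, p)
    Gpp = jax.hessian(G, argnums=1)(x, p)
    # jacfwd of dG/dp w.r.t. x is indexed [j, i]; transpose to (Gxp)_ij = d2G/(dx_i dp_j)
    Gxp = jax.jacfwd(jax.grad(G, argnums=1), argnums=0)(x, p).T
    return -(Gxx + Gxp @ M + M @ Gxp.T + M @ Gpp @ M)


x = jnp.array([0.3, -0.2])
p = jnp.array([1.0, 2.0])
M = jnp.array([[1.0, 0.2], [0.2, 0.5]])

# Constant speed: Gxx = Gxp = 0, so only -M Gpp M survives
out_hom = riccati_rhs_sketch(M, x, p, 1.0, lambda x_: 2.0)
normp = jnp.linalg.norm(p)
Gpp_hom = 2.0 * (jnp.eye(2) - jnp.outer(p, p) / normp**2) / normp
expected_hom = -(M @ Gpp_hom @ M)

# Linearly varying speed: a symmetric M still yields a symmetric right-hand side
out_lin = riccati_rhs_sketch(M, x, p, -1.0, lambda x_: 1.0 + 0.1 * x_[0])
```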

coupled_rhs_absorption(t, y, args) -> jnp.ndarray

Full GB ODE system with absorption coefficient lam.

State layout

y = concat(x (d), p (d), vec(M) (d²), A (1))

Parameters:

  • t (float, required)
  • y (jnp.ndarray, shape (d+d+d²+1,), required)
  • args (Tuple[mode, c, d, lam], required)

Returns:

  • jnp.ndarray, same shape as `y`

coupled_rhs(t, y, args) -> jnp.ndarray

GB ODE system without absorption.

Parameters:

  • t (float, required)
  • y (jnp.ndarray, shape (d+d+d²+1,), required)
  • args (Tuple[mode, c, d], required)

Returns:

  • jnp.ndarray, same shape as `y`

format_solution(ys, d)

Format the flat ODE solution into (x, p, M, A).

Parameters:

  • ys (ndarray, shape (Nt, d + d + d**2 + 1), required): Flat ODE state trajectory.
  • d (int, required): Spatial dimension.

Returns:

  • xt (ndarray, shape (Nt, d)): Beam positions.
  • pt (ndarray, shape (Nt, d)): Beam momenta.
  • Mt (ndarray, shape (Nt, d, d)): Beam Hessians.
  • At (ndarray, shape (Nt, 1)): Beam amplitudes.
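The state layout y = concat(x, p, vec(M), A) can be unpacked as below. format_solution_sketch is hypothetical, and it assumes vec(M) is the row-major flattening M.reshape(-1); the library may order the entries differently.

```python
import jax.numpy as jnp


def format_solution_sketch(ys, d):
    """Split a flat (Nt, d + d + d**2 + 1) trajectory into (xt, pt, Mt, At)."""
    nt = ys.shape[0]
    xt = ys[:, :d]                                      # (Nt, d)
    pt = ys[:, d:2 * d]                                 # (Nt, d)
    Mt = ys[:, 2 * d:2 * d + d**2].reshape(nt, d, d)    # (Nt, d, d), row-major vec(M) assumed
    At = ys[:, 2 * d + d**2:]                           # (Nt, 1)
    return xt, pt, Mt, At


d, nt = 2, 3
x = jnp.array([0.0, 1.0])
p = jnp.array([1.0, 0.0])
M = jnp.array([[1.0, 2.0], [3.0, 4.0]])
A = jnp.array([0.5])
y = jnp.concatenate([x, p, M.reshape(-1), A])   # one packed state of length 2d + d^2 + 1
ys = jnp.tile(y, (nt, 1))                        # fake trajectory: same state at every save time
xt, pt, Mt, At = format_solution_sketch(ys, d)
```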

solve_ODE_base(x0: jnp.ndarray, p0: jnp.ndarray, M0: jnp.ndarray, a0: jnp.ndarray, mode: jnp.ndarray, ts: jnp.ndarray, c: Callable, lam: float = 0.0, solver_config: Optional[SolverConfig] = None) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]

Solve the coupled GB ODE system for a single beam (intended to be vmapped over beams), with configurable solver settings.

Parameters:

  • x0 (ndarray, shape (d,), required): Initial beam position for one vmapped beam.
  • p0 (ndarray, shape (d,), required): Initial momentum for one vmapped beam.
  • M0 (ndarray, shape (d, d), required): Initial Hessian for one vmapped beam.
  • a0 (ndarray, shape (1,) or scalar, required): Initial amplitude.
  • mode (ndarray, required): Hamiltonian branch sign.
  • ts (ndarray, shape (Nt,), required): Time grid.
  • c (Callable, required): Sound-speed function.
  • lam (float, default 0.0): Absorption coefficient.
  • solver_config (SolverConfig, default None): Numerical solver configuration.

Returns:

  • xt (ndarray, shape (Nt, d)): Beam positions.
  • pt (ndarray, shape (Nt, d)): Beam momenta.
  • Mt (ndarray, shape (Nt, d, d)): Beam Hessians.
  • At (ndarray, shape (Nt, 1)): Beam amplitudes.

solve_ODE_batch_t(x0: jnp.ndarray, p0: jnp.ndarray, M0: jnp.ndarray, A0: jnp.ndarray, mode: jnp.ndarray, ts: jnp.ndarray, c: Callable, lam: Optional[float] = None, solver_config: Optional[SolverConfig] = None) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]

Solve the coupled GB ODE system with a separate time grid for each beam.

Parameters:

  • x0 (ndarray, shape (b, d), required): Initial beam positions.
  • p0 (ndarray, shape (b, d), required): Initial momenta.
  • M0 (ndarray, shape (b, d, d), required): Initial Hessian matrices.
  • A0 (ndarray, shape (b,), required): Initial amplitudes.
  • mode (ndarray, shape (b,), required): Hamiltonian branch signs.
  • ts (ndarray, shape (b, Nt), required): Per-beam time grids, commonly (t0, t1) intervals.
  • c (Callable, required): Sound-speed function.
  • lam (float, default None): Absorption coefficient. Currently unused, because this solver uses the no-absorption RHS.
  • solver_config (SolverConfig, default None): Numerical solver configuration.

Returns:

  • xt (ndarray, shape (b, Nt, d)): Beam positions.
  • pt (ndarray, shape (b, Nt, d)): Beam momenta.
  • Mt (ndarray, shape (b, Nt, d, d)): Beam Hessians.
  • At (ndarray, shape (b, Nt, 1)): Beam amplitudes.

solve_ODE_intersection(x0: jnp.ndarray, p0: jnp.ndarray, M0: jnp.ndarray, a0: jnp.ndarray, mode: jnp.ndarray, ts: jnp.ndarray, c: Callable, lam: float, surface: Callable, solver_config: Optional[SolverConfig] = None) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]

Solve the ODE for one Gaussian beam and find the time at which the beam center x(t) meets the surface.

Parameters:

  • x0 (ndarray, shape (d,), required): Initial beam position for one vmapped beam.
  • p0 (ndarray, shape (d,), required): Initial momentum for one vmapped beam.
  • M0 (ndarray, shape (d, d), required): Initial Hessian.
  • a0 (ndarray, required): Initial amplitude.
  • mode (ndarray, required): Hamiltonian branch sign.
  • ts (ndarray, shape (Nt,), required): Time grid.
  • c (Callable, required): Sound-speed function.
  • lam (float, required): Absorption coefficient.
  • surface (Callable, required): Implicit surface function whose zero level set defines the target surface.
  • solver_config (SolverConfig, default None): Numerical solver configuration.

Returns:

  • xt (ndarray, shape (Nt, d)): Beam positions.
  • pt (ndarray, shape (Nt, d)): Beam momenta.
  • Mt (ndarray, shape (Nt, d, d)): Beam Hessians.
  • At (ndarray, shape (Nt, 1)): Beam amplitudes.
  • t_int (ndarray): Intersection time, or inf when the root solve fails.

Notes

First solves the beam ODE with dense output, then solves the scalar root problem surface(x(t)) = 0.
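The Notes describe a scalar root solve for surface(x(t)) on the dense ODE output; the exact root solver is not specified here. As a stand-in technique, the sketch below brackets the first sign change on the save grid and refines it by bisection inside jax.lax.fori_loop. x_of_t plays the role of a dense-output evaluator (for example diffrax's Solution.evaluate), and first_sign_change_root is a hypothetical name. Like t_int above, it returns inf when no crossing is found.

```python
import jax
import jax.numpy as jnp


def first_sign_change_root(x_of_t, surface, ts, n_bisect=50):
    """Root of surface(x(t)) in the first grid interval where its sign flips, else inf."""
    vals = jax.vmap(lambda t: surface(x_of_t(t)))(ts)
    flips = jnp.sign(vals[:-1]) != jnp.sign(vals[1:])
    has_root = jnp.any(flips)
    k = jnp.argmax(flips)                      # index of the first flipping interval
    lo, hi = ts[k], ts[k + 1]

    def body(_, bracket):
        lo, hi = bracket
        mid = 0.5 * (lo + hi)
        same = jnp.sign(surface(x_of_t(mid))) == jnp.sign(surface(x_of_t(lo)))
        # Keep the half-interval whose endpoints still have opposite signs
        return jnp.where(same, mid, lo), jnp.where(same, hi, mid)

    lo, hi = jax.lax.fori_loop(0, n_bisect, body, (lo, hi))
    return jnp.where(has_root, 0.5 * (lo + hi), jnp.inf)


x_of_t = lambda t: jnp.array([0.0, 0.0]) + t * jnp.array([1.0, 0.5])   # straight ray
ts = jnp.linspace(0.0, 5.0, 11)
t_hit = first_sign_change_root(x_of_t, lambda x: x[1] - 0.9, ts)     # plane y = 0.9
t_miss = first_sign_change_root(x_of_t, lambda x: x[1] + 10.0, ts)   # never crossed
```

A grid bracket can miss a surface that is entered and left between two save times, so a finer grid or an event-based stop (as in solve_ODE_first_hit) is safer when that matters.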

solve_ODE_first_hit(x0: jnp.ndarray, p0: jnp.ndarray, M0: jnp.ndarray, a0: jnp.ndarray, mode: jnp.ndarray, ts: jnp.ndarray, c: Callable, lam: float, surface: Callable[[jnp.ndarray], float], solver_config: Optional[SolverConfig] = None) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, float, bool]

Integrate a single beam until it hits surface(x) = 0 (or reaches t1 = ts[-1]).

Returns the state at the first hit time, with the beam axis first and the time axis second.

Parameters:

  • x0 (ndarray, required): Initial position for one beam.
  • p0 (ndarray, required): Initial momentum for one beam.
  • M0 (ndarray, required): Initial Hessian for one beam.
  • a0 (ndarray, required): Initial amplitude for one beam.
  • mode (ndarray, required): Polarisation (+/-1) for this beam, shape (1,) or scalar.
  • ts (ndarray, required): Global time grid, assumed uniform. Integration stops at ts[-1] if there is no hit.
  • c (Callable, required): Sound-speed function.
  • lam (float, required): Absorption parameter.
  • surface (Callable[[ndarray], float], required): Implicit surface function; a root at zero triggers a hit.
  • solver_config (SolverConfig | None, default None)

Returns:

  • (xt, pt, Mt, At, t_hit, hit), where xt and pt have shape (1, 1, d), Mt has shape (1, 1, d, d), At has shape (1, 1, 1), t_hit is a float, and hit is a bool (True if the event occurred before ts[-1]).

coupled_rhs_QP_absorption(t, y, args) -> jnp.ndarray

GB ODE system in (x, p, Q, P, A) coordinates with absorption.

State layout

y = concat(x (d), p (d), vec(Q) (d²), vec(P) (d²), A (1))

Parameters:

  • t (float, required)
  • y (jnp.ndarray, shape (d + d + d² + d² + 1,), required)
  • args (Tuple[mode, c, d, lam], required)

Returns:

  • jnp.ndarray, same shape as `y`

format_solution_QP(ys, d)

Format the solution of the ODEs in (x, p, Q, P, A) into (x, p, M, A).

Parameters:

  • ys (ndarray, shape (Nt, d + d + d**2 + d**2 + 1), required): Flat ODE state trajectory.
  • d (int, required): Spatial dimension.

Returns:

  • xt (ndarray, shape (Nt, d)): Beam positions.
  • pt (ndarray, shape (Nt, d)): Beam momenta.
  • Mt (ndarray, shape (Nt, d, d)): Hessians reconstructed as P @ inv(Q).
  • At (ndarray, shape (Nt, 1)): Beam amplitudes.
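M = P Q^(-1) can be formed without an explicit inverse by solving the transposed system Qᵀ Mᵀ = Pᵀ, which is generally more accurate than multiplying by inv(Q). This batched sketch (m_from_qp is a hypothetical name) shows one way to do it; the source does not state whether format_solution_QP does the same.

```python
import jax
import jax.numpy as jnp


def m_from_qp(Qt, Pt):
    """M = P Q^{-1} for stacked (Nt, d, d) arrays, without forming inv(Q)."""
    # X Q = P  <=>  Q^T X^T = P^T, solved batch-wise over the leading Nt axis
    Xt_T = jnp.linalg.solve(jnp.swapaxes(Qt, -1, -2), jnp.swapaxes(Pt, -1, -2))
    return jnp.swapaxes(Xt_T, -1, -2)


k1, k2 = jax.random.split(jax.random.PRNGKey(0))
Qt = jnp.eye(3) + 0.1 * jax.random.normal(k1, (4, 3, 3))   # well-conditioned (Nt, d, d)
Pt = jax.random.normal(k2, (4, 3, 3))
Mt = m_from_qp(Qt, Pt)
```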

solve_ODE_QP_base(x0: jnp.ndarray, p0: jnp.ndarray, M0: jnp.ndarray, a0: jnp.ndarray, mode: jnp.ndarray, ts: jnp.ndarray, c: Callable, lam: float = 0.0, solver_config: Optional[SolverConfig] = None) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]

Solve the GB ODEs by evolving (Q, P) instead of M directly; M is reconstructed as P Q^(-1).

Parameters:

  • x0 (ndarray, shape (d,), required): Initial beam position for one vmapped beam.
  • p0 (ndarray, shape (d,), required): Initial beam momentum.
  • M0 (ndarray, shape (d, d), required): Initial Hessian matrix.
  • a0 (ndarray, required): Initial amplitude.
  • mode (ndarray, required): Hamiltonian branch sign.
  • ts (ndarray, shape (Nt,), required): Time grid.
  • c (Callable, required): Sound-speed function.
  • lam (float, default 0.0): Absorption coefficient.
  • solver_config (SolverConfig, default None): Numerical solver configuration.

Returns:

  • xt (ndarray, shape (Nt, d)): Beam positions.
  • pt (ndarray, shape (Nt, d)): Beam momenta.
  • Mt (ndarray, shape (Nt, d, d)): Reconstructed Hessian matrices.
  • At (ndarray, shape (Nt, 1)): Beam amplitudes.

Notes

Uses the initial conditions Q(0) = I and P(0) = M0, so that M(0) = M0.
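For reference, the standard variational (linearized Hamiltonian) equations for Q and P, written with the same G and (Gxp)_ij = ∂²G/(∂x_i ∂p_j) convention as riccati_rhs, recover the Riccati equation when M = P Q^(-1). We have not verified that coupled_rhs_QP_absorption implements exactly this form, and the absorption and amplitude terms are omitted.

```latex
\dot Q = G_{xp}^{\mathsf T} Q + G_{pp} P, \qquad
\dot P = -G_{xx} Q - G_{xp} P, \qquad
Q(0) = I,\; P(0) = M_0,

\dot M = \dot P\, Q^{-1} - P Q^{-1} \dot Q\, Q^{-1}
       = -\left( G_{xx} + G_{xp} M + M G_{xp}^{\mathsf T} + M G_{pp} M \right).
```

Because Q and P obey a linear system, they stay bounded in finite time even when the nonlinear Riccati equation for M is stiff, which is the usual motivation for this formulation.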