
beamax.solvers.msgb_solvers.msgb_solver

Primary high-level Multiscale Gaussian Beam solver interface.

Key Objects

  • MSGBSolver: forward/time-reversal/adjoint entry points using wave-packet coefficients and beam propagation.
  • ShardingStrategy: optional device-mesh strategy for distributing beam parameters.

API Reference

ShardingStrategy(mesh: Mesh, beam_axis: str = 'x') dataclass

Strategy for sharding beam parameters across devices.

Attributes:

  • mesh (Mesh): JAX device mesh for multi-device parallelization.
  • beam_axis (str): Mesh axis used to shard beams.

shard_beam_params(p0: jnp.ndarray, M0: jnp.ndarray, x0: jnp.ndarray, omega: jnp.ndarray, a0: jnp.ndarray, modes: jnp.ndarray) -> Tuple[jnp.ndarray, ...]

Shard forward beam parameters along the beam dimension.

Parameters:

  • p0 (ndarray): Beam momenta. Required.
  • M0 (ndarray): Beam Hessian matrices. Required.
  • x0 (ndarray): Beam positions. Required.
  • omega (ndarray): Beam frequencies. Required.
  • a0 (ndarray): Beam amplitudes. Required.
  • modes (ndarray): Beam branch signs. Required.

Returns:

  • Tuple[ndarray, ...]: Device-placed arrays with sharding specifications applied.
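A minimal sketch of what sharding along the beam dimension can look like in plain JAX, assuming every beam array carries the beam index on axis 0 and the mesh has a single axis named "x" (the beam_axis default). shard_leading_axis is a hypothetical helper, not part of beamax, and this is not the library's implementation of shard_beam_params or shard_tr_params.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def shard_leading_axis(arrays, mesh, beam_axis="x"):
    """Hypothetical helper (not beamax API): split each array along axis 0 over `mesh`."""
    placed = []
    for a in arrays:
        # Shard the leading (beam) dimension; all trailing dimensions stay replicated.
        spec = P(beam_axis, *([None] * (a.ndim - 1)))
        placed.append(jax.device_put(a, NamedSharding(mesh, spec)))
    return tuple(placed)


# One-dimensional mesh over all visible devices, with the axis named "x".
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("x",))

ndim = 2
n_beams = 4 * devices.size  # keep the beam count divisible by the mesh size
p0 = jnp.ones((n_beams, ndim))           # momenta
M0 = jnp.zeros((n_beams, ndim, ndim))    # Hessians
omega = jnp.linspace(1.0, 2.0, n_beams)  # frequencies

p0_s, M0_s, omega_s = shard_leading_axis((p0, M0, omega), mesh)
```

The beam count here is deliberately a multiple of the device count; JAX may refuse a NamedSharding that would split an axis unevenly, so real beam arrays might need padding first.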

shard_tr_params(pts: jnp.ndarray, Mts: jnp.ndarray, xts: jnp.ndarray, omega_ts: jnp.ndarray, ats: jnp.ndarray, signum: jnp.ndarray, ts: jnp.ndarray) -> Tuple[jnp.ndarray, ...]

Shard time-reversal beam parameters along the beam dimension.

Parameters:

  • pts (ndarray): Beam momenta at the boundary. Required.
  • Mts (ndarray): Beam Hessians at the boundary. Required.
  • xts (ndarray): Boundary positions. Required.
  • omega_ts (ndarray): Beam frequencies. Required.
  • ats (ndarray): Beam amplitudes. Required.
  • signum (ndarray): Beam branch signs. Required.
  • ts (ndarray): Per-beam time intervals. Required.

Returns:

  • Tuple[ndarray, ...]: Device-placed arrays with sharding specifications applied.

MSGBSolver(thr: Union[int, float], thr_strat: str, batch_size: int, input_type: str, ode_solver: SolverFn, sum_method: str, tr_ode_solver: Optional[SolverFn] = None, sharding: Optional[ShardingStrategy] = None, ode_config: Optional[SolverConfig] = None)

Bases: Module

Multiscale Gaussian Beam solver for the linear wave equation.

Implements forward, time-reversal, and adjoint operators by:

  1. Decomposing the initial pressure into wave-packet coefficients via beamax.transforms.MSWPT.
  2. Thresholding to retain only significant coefficients.
  3. Integrating a small Hamiltonian ODE per retained beam.
  4. Summing (or scanning) the beam contributions at the sensor positions.
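Step 2 can be illustrated with a small NumPy sketch of three of the strategy names accepted by thr_strat. The semantics used here (an absolute-magnitude cutoff for "hard", keep the n largest magnitudes for "top_n", a percentile of the magnitudes for "percentile") are assumptions read off the strategy names and the thr description; select_coefficients is a hypothetical helper, and the remaining strategies ("hard_reassign", "bao_energy", "perc_max_abs") are not modelled.

```python
import numpy as np


def select_coefficients(coeffs, thr, thr_strat):
    """Hypothetical sketch (not beamax API): boolean mask of coefficients to keep."""
    mag = np.abs(coeffs)  # works for real and complex coefficients
    if thr_strat == "hard":
        # Keep every coefficient whose magnitude reaches the absolute threshold.
        return mag >= thr
    if thr_strat == "top_n":
        # Keep the int(thr) largest-magnitude coefficients.
        n = int(thr)
        mask = np.zeros(mag.shape, dtype=bool)
        if n > 0:
            idx = np.argsort(mag, axis=None)[-n:]  # indices into the flattened array
            mask.flat[idx] = True
        return mask
    if thr_strat == "percentile":
        # Keep coefficients at or above the thr-th percentile of all magnitudes.
        return mag >= np.percentile(mag, thr)
    raise ValueError(f"strategy not covered by this sketch: {thr_strat!r}")
```

Only the retained coefficients go on to seed beams in step 3, so the mask directly controls the number of ODE solves.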

Parameters:

  • thr (int or float): Threshold value for coefficient selection. Its meaning depends on thr_strat (e.g. an absolute magnitude, a percentile, or a top-k count). Required.
  • thr_strat (str): Thresholding strategy; one of "hard", "top_n", "percentile", "hard_reassign", "bao_energy", or "perc_max_abs". Required.
  • batch_size (int): Batch size along the beam axis for ODE integration. Tune it to fit device memory; larger values amortise kernel-launch overhead. Required.
  • input_type ("spatial" or "fourier"): Domain in which the caller supplies p0. Required (the signature gives it no default).
  • ode_solver (SolverFn): Forward-time ODE integrator for the beam dynamics (typically one of the integrators in beamax.gb.gb_solvers). Required.
  • sum_method (str): Method for summing beam contributions; one of "all_real", "scan_real", "vmap_real", "all_complex", "scan_complex", or "vmap_complex". Required.
  • tr_ode_solver (SolverFn, optional): ODE integrator for the time-reversal dynamics. Falls back to ode_solver when None. Default: None.
  • sharding (ShardingStrategy, optional): Multi-device sharding strategy. None runs on a single device. Default: None.
  • ode_config (SolverConfig, optional): Numerical configuration passed through to the ODE integrator. Falls back to SolverConfig.from_precision() when None. Default: None.
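The difference between the "all_*" and "scan_*" aggregation modes, and the role of batch_size, can be pictured with a NumPy analogy: either materialise every (beam, sensor) contribution at once, or accumulate batch by batch so that peak memory scales with batch_size rather than with the beam count. The Gaussian kernel is a toy stand-in for a real beam, beam_field, sum_all and sum_batched are hypothetical names, and a JAX implementation would more likely use jax.lax.scan for the batched loop (or jax.vmap); none of this is taken from beamax.

```python
import numpy as np


def beam_field(x_sensors, centers, amps, widths):
    """Toy real Gaussian contribution of each beam at each sensor, shape (Nb, Ns)."""
    d2 = np.sum((x_sensors[None, :, :] - centers[:, None, :]) ** 2, axis=-1)
    return amps[:, None] * np.exp(-d2 / widths[:, None] ** 2)


def sum_all(x_sensors, centers, amps, widths):
    # "all"-style: build the full (Nb, Ns) contribution array, then reduce over beams.
    return beam_field(x_sensors, centers, amps, widths).sum(axis=0)


def sum_batched(x_sensors, centers, amps, widths, batch_size):
    # "scan"-style: accumulate one batch of beams at a time into a running total.
    total = np.zeros(x_sensors.shape[0])
    for start in range(0, centers.shape[0], batch_size):
        sl = slice(start, start + batch_size)
        total += beam_field(x_sensors, centers[sl], amps[sl], widths[sl]).sum(axis=0)
    return total


rng = np.random.default_rng(0)
x_sensors = rng.uniform(-1.0, 1.0, size=(16, 2))
centers = rng.uniform(-1.0, 1.0, size=(50, 2))
amps = rng.normal(size=50)
widths = rng.uniform(0.1, 0.5, size=50)

full = sum_all(x_sensors, centers, amps, widths)
batched = sum_batched(x_sensors, centers, amps, widths, batch_size=8)
```

Both paths give the same sum up to floating-point rounding; the batched one trades some parallelism for a bounded working set.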

Initialize the MSGB solver.

Parameters:

  • thr (int or float): Threshold value for coefficient selection. Required.
  • thr_strat (str): Thresholding strategy name. Required.
  • batch_size (int): Number of beams per batch for scan/vmap aggregation. Required.
  • input_type ("spatial" or "fourier"): Domain of the inputs supplied to the solver. Required.
  • ode_solver (SolverFn): Forward ODE solver. Required.
  • sum_method (str): Aggregation mode; must be one of the values listed in the class-level parameter documentation. Required.
  • tr_ode_solver (SolverFn, optional): ODE solver for time reversal. Default: None (falls back to ode_solver).
  • sharding (ShardingStrategy, optional): Multi-device sharding strategy. Default: None.
  • ode_config (SolverConfig, optional): Numerical ODE solver configuration. Default: None (falls back to SolverConfig.from_precision()).
forward(p0: jnp.ndarray, domain: Domain, sensors: Union[Sensor, jnp.ndarray], ts: jnp.ndarray, wpt: MSWPT, *, dpdt: Optional[jnp.ndarray] = None) -> jnp.ndarray

Solve the forward wave equation u_tt - c²∇²u = 0 with MSGB.

Initial conditions are u(0, x) = p0 and u_t(0, x) = dpdt (zero for standard photoacoustic tomography).
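For sanity-checking forward or solve_ivp output, an analytic reference is handy. In one space dimension with constant c, d'Alembert's formula gives the exact solution u(t, x) = ½[p0(x − ct) + p0(x + ct)] + (1/2c) ∫ dpdt(s) ds over [x − ct, x + ct]. The sketch evaluates it for Gaussian initial data; the Gaussian choice, the 1D setting and the helper name dalembert_1d are assumptions for illustration, and nothing here calls beamax.

```python
import math

import numpy as np


def dalembert_1d(x, t, c, sigma, beta=0.0):
    """Exact 1D solution of u_tt - c^2 u_xx = 0 with Gaussian Cauchy data.

    u(0, x)   = exp(-x^2 / sigma^2)
    u_t(0, x) = beta * exp(-x^2 / sigma^2)   (beta = 0 is the standard PAT case)
    """
    erf = np.vectorize(math.erf, otypes=[float])

    def f(s):
        return np.exp(-((s / sigma) ** 2))

    def G(s):
        # Antiderivative of beta * exp(-s^2/sigma^2): beta * sigma * sqrt(pi)/2 * erf(s/sigma).
        return beta * sigma * math.sqrt(math.pi) / 2.0 * erf(s / sigma)

    left, right = x - c * t, x + c * t
    # Two counter-propagating half-amplitude pulses plus the velocity term.
    return 0.5 * (f(left) + f(right)) + (G(right) - G(left)) / (2.0 * c)
```

With beta = 0 the initial pulse splits into two half-amplitude copies travelling at ±c, which is the kind of qualitative behaviour a forward result can be checked against.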

Parameters:

  • p0 (ndarray, shape (*N)): Initial pressure field. Real or complex; the dtype selects the underlying beam formulation. Required.
  • domain (Domain): Computational domain and medium. Required.
  • sensors (Sensor or ndarray): Sensor geometry, either a Sensor or an array of positions in physical units with shape (Ns, ndim). Required.
  • ts (ndarray, shape (Nt,)): Time grid. Required.
  • wpt (MSWPT): Wave-packet transform used to build the beam decomposition. Required.
  • dpdt (ndarray, optional, keyword-only): Initial time derivative. Default: None (treated as zeros, the standard PAT case).

Returns:

  • ndarray, shape (Nt, Ns): Pressure at each sensor over time.
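When sensors is passed as a plain array rather than a Sensor, it must hold positions in physical units with shape (Ns, ndim). A minimal NumPy sketch of two common layouts, a ring of detectors and every node of a tensor-product grid, follows; ring_sensors and grid_points are hypothetical helpers, and whether a grid flattened this way matches the ordering of a beamax Domain's own grid is an assumption to verify.

```python
import numpy as np


def ring_sensors(n_sensors, radius, center=(0.0, 0.0)):
    """Positions on a circle, shape (Ns, 2), in the same physical units as the domain."""
    theta = np.linspace(0.0, 2.0 * np.pi, n_sensors, endpoint=False)
    return np.stack(
        [center[0] + radius * np.cos(theta), center[1] + radius * np.sin(theta)],
        axis=-1,
    )


def grid_points(x_coords, y_coords):
    """Every node of a tensor-product grid as an (Nx * Ny, 2) array, x varying slowest."""
    X, Y = np.meshgrid(x_coords, y_coords, indexing="ij")
    return np.stack([X.ravel(), Y.ravel()], axis=-1)
```

Either array can be converted with jnp.asarray before being passed as sensors.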

forward_with_params(p0: jnp.ndarray, domain: Domain, sensors: Union[Sensor, jnp.ndarray], ts: jnp.ndarray, wpt: MSWPT, *, dpdt: Optional[jnp.ndarray] = None) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, ...]]

Forward MSGB solve plus diagnostic beam parameters.

This is the explicit diagnostic variant of forward. Most users should call forward, which returns only sensor data.

Returns:

  • sensor_data (ndarray, shape (Nt, Ns)): Pressure at each sensor over time.
  • params (tuple of jnp.ndarray): Beam parameters used in the solve: (p0s, M0s, x0s, omegas, a0s, modes).

time_reversal(data: jnp.ndarray, domain: Domain, sensors: Sensor, sources: Sensor, ts, data_domain: Domain, data_wpt: MSWPT) -> jnp.ndarray

MSGB time-reversal reconstruction.

Parameters:

  • data (ndarray, shape (Nt, Ns)): Sensor time series to time-reverse. Required.
  • domain (Domain): Reconstruction domain. Its periodic flag must be False on every axis, because time reversal here assumes free-space boundaries. Required.
  • sensors (Sensor): Sensor geometry corresponding to data. Required.
  • sources (Sensor): Source positions used to seed the time-reversal beams (often the same boundary as sensors). Required.
  • ts (ndarray, shape (Nt,)): Time grid corresponding to data. Required.
  • data_domain (Domain): Domain on which data was acquired (may differ from domain under downsampling). Required.
  • data_wpt (MSWPT): Wave-packet transform on data_domain, used to analyse data. Required.

Returns:

  • ndarray, shape (*N): Reconstructed initial pressure, scaled by 2 to match the standard full-field time-reversal convention.

Raises:

  • ValueError: If any axis of domain is periodic.

Notes

Time reversal here uses per-beam time intervals (via beamax.gb.solve_ODE_batch_t) regardless of the forward integrator, to accommodate the variable emission time of each beam.

time_reversal_with_params(data: jnp.ndarray, domain: Domain, sensors: Sensor, sources: Sensor, ts, data_domain: Domain, data_wpt: MSWPT) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, ...]]

Time-reversal MSGB solve plus diagnostic beam parameters.

This is the explicit diagnostic variant of time_reversal. Most users should call time_reversal, which returns only the reconstructed field.

Returns:

  • p0_recon (ndarray, shape (*N)): Reconstructed initial pressure.
  • params (tuple of jnp.ndarray): Beam parameters used in the solve.

solve_ivp(p0: jnp.ndarray, dpdt: jnp.ndarray, domain: Domain, wpt: MSWPT, sensors: Union[Sensor, jnp.ndarray], ts: jnp.ndarray) -> jnp.ndarray

Solve the wave-equation IVP with non-zero initial velocity.

Equivalent to forward, but requires an explicit dpdt. Use this when u_t(0, x) ≠ 0 (Cauchy data); use forward for standard photoacoustic settings where dpdt = 0.

Parameters:

  • p0 (ndarray, shape (*N)): Initial pressure field. Required.
  • dpdt (ndarray, shape (*N)): Initial time derivative of the pressure. Required.
  • domain (Domain): Computational domain. Required.
  • wpt (MSWPT): Wave-packet transform for the beam decomposition. Required.
  • sensors (Sensor or ndarray): Sensor geometry. Required.
  • ts (ndarray, shape (Nt,)): Time grid. Required.

Returns:

  • ndarray: Sensor time series; the same result as forward with an explicit dpdt.

solve_ivp_with_params(p0: jnp.ndarray, dpdt: jnp.ndarray, domain: Domain, wpt: MSWPT, sensors: Union[Sensor, jnp.ndarray], ts: jnp.ndarray) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, ...]]

IVP solve plus diagnostic beam parameters.

This is the explicit diagnostic variant of solve_ivp.

adjoint(data: jnp.ndarray, domain: Domain, sensors: Union[Sensor, jnp.ndarray], sources: Sensor, ts: jnp.ndarray, data_domain: Domain, data_wpt: MSWPT) -> jnp.ndarray

MSGB adjoint solve (Arridge-style): F = w ∂_t r(T - t), then B^{-1}F + TR.

Parameters:

  • data (ndarray): Boundary measurement r(t, x_s) on Gamma or, if use_raw_source=True, an already-formed adjoint source F(t, x_s). Shape (Nt, Ns) or (Nt,), with time along axis 0. Required.
  • domain (Domain): Reconstruction (image) domain on which q(T, x) is computed. Required.
  • sensors (Sensor or ndarray): Locations at which to evaluate the adjoint field. For image reconstruction this is typically domain.grid, so that q(T, x) is obtained on the full grid. Required.
  • sources (Sensor): Source geometry on Gamma, used to construct the boundary beam parameters (same role as in time reversal). Required.
  • ts (ndarray): Time grid, shape (Nt,). Currently not used directly; kept for interface symmetry and possible future extensions. Required.
  • data_domain (Domain): Domain describing the (t, x_s) grid of the boundary data. Its dx[0] is used as the time step dt. Required.
  • data_wpt (MSWPT): MSWPT instance for analysing the boundary data or source. Required.
Keyword Parameters:

  • use_raw_source (bool, default False): If False (the default), data is interpreted as a boundary measurement r(s, x_s) in the original acquisition time s, and the adjoint source

        F(t, x_s) = w(x_s) ∂_t r(T − t, x_s)

    is formed internally. Because ∂_t r(T − t, x_s) = −∂_s r(s, x_s) at s = T − t, it is represented on the original-time grid as −w(x_s) ∂_s r(s, x_s), computed with a simple finite difference in time and unit weights w ≡ 1. If True, data is assumed to already be F(t, x_s) and is passed to _prepare_adj_params unchanged.

Returns:

  • ndarray: Adjoint field q(T, x) on the reconstruction domain (same shape as a forward initial condition).
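The default use_raw_source=False path forms the adjoint source from boundary data with a finite difference in time. A minimal NumPy sketch of that step, assuming np.gradient's central differences stand in for the unspecified "simple finite-difference" and taking dt from data_domain.dx[0] as described for data_domain; adjoint_source is a hypothetical helper, not a beamax function.

```python
import numpy as np


def adjoint_source(r, dt, w=None):
    """Hypothetical sketch: F on the original-time grid, F[k] = -w * dr/ds at s_k.

    r  : boundary data r(s, x_s), shape (Nt, Ns) or (Nt,), time along axis 0.
    dt : time step (assumed to come from data_domain.dx[0]).
    w  : per-sensor weights of shape (Ns,); None means unit weights w = 1.
    """
    r = np.asarray(r, dtype=float)
    # Central differences in the interior, one-sided differences at the two ends.
    dr_ds = np.gradient(r, dt, axis=0)
    weights = 1.0 if w is None else np.asarray(w, dtype=float)
    return -weights * dr_ds
```

Reversing the result along axis 0 (F[::-1]) gives F on a grid running forward in the adjoint time t = T − s, because index k on a uniform original grid corresponds to t = T − s_k.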

adjoint_with_params(data: jnp.ndarray, domain: Domain, sensors: Union[Sensor, jnp.ndarray], sources: Sensor, ts: jnp.ndarray, data_domain: Domain, data_wpt: MSWPT) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, ...]]

Adjoint MSGB solve plus diagnostic beam parameters.

This is the explicit diagnostic variant of adjoint. Most users should call adjoint, which returns only the adjoint field.

Returns:

  • q_T (ndarray): Adjoint field q(T, x) on the reconstruction domain.
  • params (tuple of jnp.ndarray): Beam parameters used internally.