beamax.solvers.msgb_solvers.msgb_solver
Primary high-level Multiscale Gaussian Beam solver interface.
Key Objects
- `MSGBSolver`: forward, time-reversal, and adjoint entry points using wave-packet coefficients and beam propagation.
- `ShardingStrategy`: optional device-mesh strategy for distributing beam parameters.
API Reference
ShardingStrategy(mesh: Mesh, beam_axis: str = 'x')
dataclass
Strategy for sharding beam parameters across devices.
Attributes:
| Name | Type | Description |
|---|---|---|
| `mesh` | `Mesh` | JAX device mesh for multi-device parallelization. |
| `beam_axis` | `str` | Mesh axis used to shard beams. |
shard_beam_params(p0: jnp.ndarray, M0: jnp.ndarray, x0: jnp.ndarray, omega: jnp.ndarray, a0: jnp.ndarray, modes: jnp.ndarray) -> Tuple[jnp.ndarray, ...]
Shard forward beam parameters along the beam dimension.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `p0` | `ndarray` | Beam momenta. | required |
| `M0` | `ndarray` | Beam Hessian matrices. | required |
| `x0` | `ndarray` | Beam positions. | required |
| `omega` | `ndarray` | Beam frequencies. | required |
| `a0` | `ndarray` | Beam amplitudes. | required |
| `modes` | `ndarray` | Beam branch signs. | required |
Returns:
| Type | Description |
|---|---|
| `Tuple[ndarray, ...]` | Device-placed arrays with sharding specifications applied. |
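Sharding along the beam dimension can be sketched with JAX's public sharding API. This is a minimal illustration, not the beamax implementation: the helper `shard_along_beams` is hypothetical, and it assumes that the beam dimension is the leading axis of every array and that its length is divisible by the number of devices.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# One-dimensional device mesh whose single axis is named "x",
# matching the default beam_axis in the signature above.
mesh = Mesh(np.array(jax.devices()), axis_names=("x",))


def shard_along_beams(arr, mesh, beam_axis="x"):
    """Hypothetical helper: split axis 0 across `beam_axis`, replicate the rest."""
    spec = PartitionSpec(beam_axis, *([None] * (arr.ndim - 1)))
    return jax.device_put(arr, NamedSharding(mesh, spec))


# Toy parameters; the beam count is a multiple of the device count by construction.
n_beams = 4 * mesh.devices.size
p0 = jnp.ones((n_beams, 2))       # momenta, shape assumed to be (n_beams, d)
M0 = jnp.ones((n_beams, 2, 2))    # Hessians, shape assumed to be (n_beams, d, d)

p0_sharded, M0_sharded = (shard_along_beams(a, mesh) for a in (p0, M0))
```

On a single device the mesh has one entry and the arrays are simply placed on that device, so the same code runs unchanged on a laptop and on a multi-device host.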
shard_tr_params(pts: jnp.ndarray, Mts: jnp.ndarray, xts: jnp.ndarray, omega_ts: jnp.ndarray, ats: jnp.ndarray, signum: jnp.ndarray, ts: jnp.ndarray) -> Tuple[jnp.ndarray, ...]
Shard time-reversal beam parameters along the beam dimension.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `pts` | `ndarray` | Beam momenta at the boundary. | required |
| `Mts` | `ndarray` | Beam Hessians at the boundary. | required |
| `xts` | `ndarray` | Boundary positions. | required |
| `omega_ts` | `ndarray` | Beam frequencies. | required |
| `ats` | `ndarray` | Beam amplitudes. | required |
| `signum` | `ndarray` | Beam branch signs. | required |
| `ts` | `ndarray` | Per-beam time intervals. | required |
Returns:
| Type | Description |
|---|---|
| `Tuple[ndarray, ...]` | Device-placed arrays with sharding specifications applied. |
MSGBSolver(thr: Union[int, float], thr_strat: str, batch_size: int, input_type: str, ode_solver: SolverFn, sum_method: str, tr_ode_solver: Optional[SolverFn] = None, sharding: Optional[ShardingStrategy] = None, ode_config: Optional[SolverConfig] = None)
Bases: Module
Multiscale Gaussian Beam solver for the linear wave equation.
Implements forward, time-reversal, and adjoint operators by:
- Decomposing the initial pressure into wave-packet coefficients via `beamax.transforms.MSWPT`.
- Thresholding to retain only significant coefficients.
- Integrating a small Hamiltonian ODE per retained beam.
- Summing (or scanning) the beam contributions at the sensor positions.
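The thresholding step in the list above can be illustrated with a small NumPy sketch. The selection rule here is an assumption made for illustration only: an integer `thr` keeps that many largest-magnitude coefficients, and a float `thr` keeps coefficients whose magnitude is at least `thr` times the maximum. The function name `threshold_coefficients` is hypothetical, and the strategies actually accepted by `thr_strat` are not reproduced here.

```python
import numpy as np


def threshold_coefficients(coeffs, thr):
    """Return sorted flat indices of the coefficients that are kept.

    Assumed rule (not taken from beamax):
      int thr   -> keep the `thr` largest magnitudes
      float thr -> keep magnitudes >= thr * max magnitude
    """
    mags = np.abs(np.ravel(coeffs))
    if isinstance(thr, int):
        k = min(thr, mags.size)
        return np.sort(np.argsort(mags)[::-1][:k])
    return np.flatnonzero(mags >= thr * mags.max())


coeffs = np.array([0.05, -2.0, 0.5, 1.0, -0.01])
keep_topk = threshold_coefficients(coeffs, 2)    # count-based selection
keep_rel = threshold_coefficients(coeffs, 0.2)   # relative-magnitude selection
```

Each retained index would then seed one beam, so the threshold directly controls how many ODE systems are integrated.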
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `thr` | `int` or `float` | Threshold value for coefficient selection. Semantics depend on | required |
| `thr_strat` | `str` | Thresholding strategy; one of | required |
| `batch_size` | `int` | Batch size along the beam axis for ODE integration. Tune to fit device memory; larger values amortise kernel launches. | required |
| `input_type` | `{"spatial", "fourier"}` | Domain the caller provides | `"spatial"` |
| `ode_solver` | `SolverFn` | Forward-time ODE integrator for beam dynamics (typically one of :mod: | required |
| `sum_method` | `str` | Method for summing beam contributions. One of | required |
| `tr_ode_solver` | `SolverFn` | ODE integrator for the time-reversal dynamics. Falls back to | `None` |
| `sharding` | `ShardingStrategy` | Multi-device sharding strategy. | `None` |
| `ode_config` | `SolverConfig` | Numerical configuration passed through to the ODE integrator. Falls back to | `None` |
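The role of `batch_size` can be pictured as chunking the beam axis and accumulating one partial sum per chunk. The sketch below is a NumPy stand-in under stated assumptions: `beam_contributions` is a toy Gaussian bump rather than a real beam evaluation, and `batched_sum` is a hypothetical name, not a beamax function.

```python
import numpy as np


def beam_contributions(centres, sensors, width=0.1):
    """Toy stand-in for a beam evaluation: one Gaussian bump per beam.

    Returns an array of shape (n_beams, n_sensors).
    """
    diff = sensors[None, :] - centres[:, None]
    return np.exp(-0.5 * (diff / width) ** 2)


def batched_sum(centres, sensors, batch_size):
    """Accumulate sensor data batch by batch along the beam axis."""
    total = np.zeros(sensors.shape[0])
    for start in range(0, centres.shape[0], batch_size):
        batch = centres[start:start + batch_size]
        total += beam_contributions(batch, sensors).sum(axis=0)
    return total


centres = np.linspace(-1.0, 1.0, 10)
sensors = np.linspace(-2.0, 2.0, 7)
full = beam_contributions(centres, sensors).sum(axis=0)
batched = batched_sum(centres, sensors, batch_size=4)  # last batch holds 2 beams
```

Only one `(batch_size, n_sensors)` block is alive at a time, which is why a smaller batch lowers peak memory while the result stays the same up to floating-point summation order.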
Initialize the MSGB solver.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `thr` | `int` or `float` | Threshold value for coefficient selection. | required |
| `thr_strat` | `str` | Thresholding strategy name. | required |
| `batch_size` | `int` | Number of beams per batch for scan/vmap aggregation. | required |
| `input_type` | `{"spatial", "fourier"}` | Domain of inputs supplied to the solver. | `"spatial"` |
| `ode_solver` | `SolverFn` | Forward ODE solver. | required |
| `sum_method` | `str` | Aggregation mode string. Must be one of the values listed in the class-level parameter documentation. | required |
| `tr_ode_solver` | `SolverFn` | ODE solver for time reversal. Defaults to | `None` |
| `sharding` | `ShardingStrategy` | Multi-device sharding strategy. | `None` |
| `ode_config` | `SolverConfig` | Numerical ODE solver configuration. Defaults to | `None` |
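As background for the `ode_solver` argument: Gaussian beam methods for the wave equation conventionally trace rays of the Hamiltonian H(x, p) = ±c(x)|p|, with further equations for the Hessian and amplitude carried along each ray. The sketch below integrates only the ray part with a fixed-step RK4 scheme. It assumes a constant sound speed and the `+` branch, and it does not reproduce the `SolverFn` signature or the equations beamax actually integrates; `ray_rhs` and `rk4` are hypothetical names.

```python
import numpy as np


def ray_rhs(y, c=1.0):
    """Ray equations for H(x, p) = c |p| with constant c (assumed).

    y packs position x (first two entries) and momentum p (last two).
    """
    p = y[2:]
    dx = c * p / np.linalg.norm(p)   # dx/dt = dH/dp
    dp = np.zeros_like(p)            # dp/dt = -dH/dx = 0 for constant c
    return np.concatenate([dx, dp])


def rk4(f, y0, t_end, n_steps):
    """Classical fourth-order Runge-Kutta with a fixed step."""
    y = np.asarray(y0, dtype=float)
    h = t_end / n_steps
    for _ in range(n_steps):
        k1 = f(y)
        k2 = f(y + 0.5 * h * k1)
        k3 = f(y + 0.5 * h * k2)
        k4 = f(y + h * k3)
        y = y + (h / 6.0) * (k1 + 2 * k2 + 2 * k3 + k4)
    return y


# Beam starting at the origin with momentum (3, 4): unit direction (0.6, 0.8).
y_end = rk4(ray_rhs, [0.0, 0.0, 3.0, 4.0], t_end=2.0, n_steps=20)
```

With constant speed the ray is a straight line, so the final position is the start plus `c * t_end` times the unit momentum direction, and the momentum is unchanged.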
forward(p0: jnp.ndarray, domain: Domain, sensors: Union[Sensor, jnp.ndarray], ts: jnp.ndarray, wpt: MSWPT, *, dpdt: Optional[jnp.ndarray] = None) -> jnp.ndarray
Solve the forward wave equation u_tt - c²∇²u = 0 with MSGB.
Initial conditions are u(0, x) = p0 and u_t(0, x) = dpdt (zero
for standard photoacoustic tomography).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `p0` | `ndarray`, shape `(*N)` | Initial pressure field. Real or complex; dtype selects the underlying beam formulation. | required |
| `domain` | `Domain` | Computational domain and medium. | required |
| `sensors` | `Sensor` or `ndarray` | Sensor geometry. Either a :class: | required |
| `ts` | `ndarray`, shape `(Nt)` | Time grid. | required |
| `wpt` | `MSWPT` | Wave-packet transform used to build the beam decomposition. | required |
| `dpdt` | `ndarray` | Initial time derivative. Defaults to zeros (standard PAT). | `None` |
Returns:
| Type | Description |
|---|---|
| `ndarray`, shape `(Nt, Ns)` | Pressure at each sensor over time. |
forward_with_params(p0: jnp.ndarray, domain: Domain, sensors: Union[Sensor, jnp.ndarray], ts: jnp.ndarray, wpt: MSWPT, *, dpdt: Optional[jnp.ndarray] = None) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, ...]]
Forward MSGB solve plus diagnostic beam parameters.
This is the explicit diagnostic variant of `forward`. Most users
should call `forward`, which returns only sensor data.
Returns:
| Name | Type | Description |
|---|---|---|
| `sensor_data` | `ndarray`, shape `(Nt, Ns)` | Pressure at each sensor over time. |
| `params` | tuple of `jnp.ndarray` | Beam parameters used in the solve: |
time_reversal(data: jnp.ndarray, domain: Domain, sensors: Sensor, sources: Sensor, ts, data_domain: Domain, data_wpt: MSWPT) -> jnp.ndarray
MSGB time-reversal reconstruction.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `data` | `ndarray`, shape `(Nt, Ns)` | Sensor time series to time-reverse. | required |
| `domain` | `Domain` | Reconstruction domain. Must have | required |
| `sensors` | `Sensor` | Sensor geometry corresponding to | required |
| `sources` | `Sensor` | Source positions used to seed the TR beams (often the same boundary as | required |
| `ts` | `ndarray`, shape `(Nt)` | Time grid corresponding to | required |
| `data_domain` | `Domain` | Domain on which | required |
| `data_wpt` | `MSWPT` | Wave-packet transform on | required |
Returns:
| Type | Description |
|---|---|
| `ndarray`, shape `(*N)` | Reconstructed initial pressure. Scaled by 2 to match the standard full-field time-reversal convention. |
Raises:
| Type | Description |
|---|---|
| `ValueError` | If any axis of |
Notes
Time reversal here uses per-beam time intervals (via
`beamax.gb.solve_ODE_batch_t`) regardless of the forward
integrator, to accommodate the variable emission time of each beam.
time_reversal_with_params(data: jnp.ndarray, domain: Domain, sensors: Sensor, sources: Sensor, ts, data_domain: Domain, data_wpt: MSWPT) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, ...]]
Time-reversal MSGB solve plus diagnostic beam parameters.
This is the explicit diagnostic variant of `time_reversal`. Most
users should call `time_reversal`, which returns only the
reconstructed field.
Returns:
| Name | Type | Description |
|---|---|---|
| `p0_recon` | `ndarray`, shape `(*N)` | Reconstructed initial pressure. |
| `params` | tuple of `jnp.ndarray` | Beam parameters used in the solve. |
solve_ivp(p0: jnp.ndarray, dpdt: jnp.ndarray, domain: Domain, wpt: MSWPT, sensors: Union[Sensor, jnp.ndarray], ts: jnp.ndarray) -> jnp.ndarray
Solve the wave-equation IVP with non-zero initial velocity.
Equivalent to `forward` but requires an explicit `dpdt`.
Use this when u_t(0, x) ≠ 0 (Cauchy data); use `forward`
for standard photoacoustic settings where `dpdt = 0`.
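The effect of a non-zero initial velocity is easiest to see in one dimension with constant sound speed, where d'Alembert's formula gives the exact solution. The sketch below is a reference calculation only and does not call beamax; `dalembert` is a hypothetical helper, and the initial velocity is supplied through its antiderivative `G` so that no numerical quadrature is needed.

```python
import numpy as np


def dalembert(f, G, x, t, c):
    """1D solution of u_tt = c^2 u_xx with u(0, x) = f(x), u_t(0, x) = g(x).

    G is an antiderivative of g, so the velocity integral over
    [x - c t, x + c t] equals G(x + c t) - G(x - c t).
    """
    return (0.5 * (f(x - c * t) + f(x + c * t))
            + (G(x + c * t) - G(x - c * t)) / (2.0 * c))


c = 1.5
t = 1.0
x = np.linspace(-5.0, 5.0, 201)

f = lambda z: np.exp(-z**2)   # initial pressure
G = lambda z: -c * f(z)       # antiderivative of g = -c f'
G_zero = lambda z: 0.0 * z    # antiderivative of g = 0

u_cauchy = dalembert(f, G, x, t, c)     # non-zero initial velocity
u_pat = dalembert(f, G_zero, x, t, c)   # zero initial velocity
```

With `g = -c f'` the left-going half cancels and the pulse travels only to the right, whereas with zero initial velocity it splits into two half-amplitude pulses. A check of this kind is one way to validate a solver's handling of `dpdt` in a simple setting.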
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `p0` | `ndarray`, shape `(*N)` | Initial pressure field. | required |
| `dpdt` | `ndarray`, shape `(*N)` | Initial time derivative of the pressure. | required |
| `domain` | `Domain` | Computational domain. | required |
| `wpt` | `MSWPT` | Wave-packet transform for the beam decomposition. | required |
| `sensors` | `Sensor` or `ndarray` | Sensor geometry. | required |
| `ts` | `ndarray`, shape `(Nt)` | Time grid. | required |
Returns:
| Type | Description |
|---|---|
| `ndarray` | Sensor time series. Equivalent to :meth: |
solve_ivp_with_params(p0: jnp.ndarray, dpdt: jnp.ndarray, domain: Domain, wpt: MSWPT, sensors: Union[Sensor, jnp.ndarray], ts: jnp.ndarray) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, ...]]
IVP solve plus diagnostic beam parameters.
This is the explicit diagnostic variant of `solve_ivp`.
adjoint(data: jnp.ndarray, domain: Domain, sensors: Union[Sensor, jnp.ndarray], sources: Sensor, ts: jnp.ndarray, data_domain: Domain, data_wpt: MSWPT) -> jnp.ndarray
MSGB adjoint solve (Arridge-style): F = w ∂_t r(T - t), then B^{-1}F + TR.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `data` | `ndarray` | Boundary measurement r(t, x_s) on Gamma, or (if `use_raw_source=True`) an already-formed adjoint source F(t, x_s). Shape (Nt, Ns) or (Nt,) with time along axis 0. | required |
| `domain` | `Domain` | Reconstruction (image) domain where we want q(T, x). | required |
| `sensors` | `Sensor` or `ndarray` | Locations at which to evaluate the adjoint field. For image reconstruction this is typically | required |
| `sources` | `Sensor` | Source geometry on Gamma, used to construct the boundary beam parameters (same role as in time reversal). | required |
| `ts` | `ndarray` | Time grid, shape (Nt,). Currently not used directly, but kept for interface symmetry and possible future extensions. | required |
| `data_domain` | `Domain` | Domain describing the (t, x_s) grid of the boundary data. Its | required |
| `data_wpt` | `MSWPT` | MSWPT instance for analysing the boundary data / source. | required |
Keyword Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `use_raw_source` | `bool` | If `False`, `data` is interpreted as a boundary measurement r(s, x_s) in the original acquisition variable, and the adjoint source F(t, x_s) = w(x_s) ∂_t r(T - t, x_s) is formed internally. On the original-time grid this is -w(x_s) ∂_s r(s, x_s), computed with a simple finite difference in time and unit weights w ≡ 1. If `True`, `data` is assumed to already be F(t, x_s) and is passed to `_prepare_adj_params` unchanged. | `False` |
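Forming the adjoint source from a boundary measurement can be sketched as follows. Since d/dt r(T - t) = -(∂_s r)(s) at s = T - t, the source on the original-time grid is the negated time derivative of the data. The finite-difference stencil beamax uses is not specified here, so `np.gradient` (central differences in the interior, one-sided at the ends) stands in for it, and `adjoint_source` is a hypothetical name.

```python
import numpy as np


def adjoint_source(r, dt):
    """Form -d r / d s along axis 0 (time), with unit weights w = 1."""
    return -np.gradient(r, dt, axis=0)


dt = 0.1
s = np.arange(0.0, 1.0 + dt / 2, dt)     # acquisition times, shape (Nt,)
r = np.stack([s**2, 3.0 * s], axis=1)    # toy data, shape (Nt, Ns) with Ns = 2
F = adjoint_source(r, dt)                # same shape as r
```

For the quadratic column the central difference is exact at interior times, giving `-2 s`; the two end points use one-sided differences and are only first-order accurate, which is worth keeping in mind when the data does not vanish at the ends of the time window.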
Returns:
| Type | Description |
|---|---|
| `ndarray` | Adjoint field q(T, x) on the reconstruction domain (same shape as a forward initial condition). |
adjoint_with_params(data: jnp.ndarray, domain: Domain, sensors: Union[Sensor, jnp.ndarray], sources: Sensor, ts: jnp.ndarray, data_domain: Domain, data_wpt: MSWPT) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, ...]]
Adjoint MSGB solve plus diagnostic beam parameters.
This is the explicit diagnostic variant of `adjoint`. Most users
should call `adjoint`, which returns only the adjoint field.
Returns:
| Name | Type | Description |
|---|---|---|
| `q_T` | `ndarray` | Adjoint field q(T, x) on the reconstruction domain. |
| `params` | tuple of `jnp.ndarray` | Beam parameters used internally. |