
beamax.solvers.msgb_solvers.forward_solver_utils

Internal utilities for MSGB forward solves.

Scope

  • Coefficient computation and thresholding.
  • Beam parameter extraction from MSWPT indices.
  • Aggregation strategies (all / vmap / scan) and memory-aware batching support.

API Reference


threshold_coefficients(coeffs: Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]], val: float, strategy: str = 'hard', wpt: MSWPT = None)

Apply thresholding to wavelet coefficients.

Parameters:

  • coeffs (ndarray or Tuple[ndarray, ndarray], required): Coefficients to threshold.
  • val (float, required): Threshold value.
  • strategy (str, default "hard"): Thresholding strategy.
  • wpt (MSWPT, default None): Wave-packet transform, required by strategies that depend on the dyadic layout.

Returns:

  • Tuple[ndarray, ndarray] or Tuple[Tuple[ndarray, ndarray], ...]: Selected indices and values. If coeffs is a tuple, one (idx, values) pair is returned for each coefficient vector.

Raises:

  • ValueError: If strategy is unknown.
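The "hard" strategy is not spelled out above. The sketch below shows one common reading of hard thresholding, under stated assumptions. It keeps the entries whose magnitude strictly exceeds val and returns their flat indices and values. A tuple input yields one (idx, values) pair per vector. The strict ">" comparison, the flat indexing, and the helper names hard_threshold and threshold_sketch are assumptions, not beamax's implementation. The wpt argument and layout-dependent strategies are not modelled.

```python
from typing import Tuple, Union

import jax.numpy as jnp


def hard_threshold(c: jnp.ndarray, val: float) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Hypothetical helper: keep entries with |c| > val (assumed strict comparison)."""
    flat = jnp.ravel(c)
    # jnp.nonzero has a data-dependent output shape, so this runs eagerly, not under jit
    (idx,) = jnp.nonzero(jnp.abs(flat) > val)
    return idx, flat[idx]


def threshold_sketch(
    coeffs: Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]], val: float
):
    """Hypothetical stand-in for threshold_coefficients(..., strategy="hard")."""
    if isinstance(coeffs, tuple):
        # e.g. (cpos, cneg): one (idx, values) pair per coefficient vector
        return tuple(hard_threshold(c, val) for c in coeffs)
    return hard_threshold(coeffs, val)


cpos = jnp.array([0.1, -2.0, 0.5, 3.0])
cneg = jnp.array([1.5, 0.0, -0.2, 0.0])
(ipos, vpos), (ineg, vneg) = threshold_sketch((cpos, cneg), 1.0)
```

Because the number of surviving coefficients depends on the data, a jit-compatible variant would need a fixed size (for example, the size argument of jnp.nonzero) and padding.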

compute_forward_parameters(significant_coeffs: Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]], wpt: MSWPT, domain: Domain) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]

Compute Gaussian beam parameters from wavelet coefficients.

Parameters:

  • significant_coeffs (ndarray or Tuple[ndarray, ndarray], required): Indices of the significant coefficients, given either for the positive mode alone or as a (positive, negative) pair.
  • wpt (MSWPT, required): Wave-packet transform.
  • domain (Domain, required): Physical domain.

Returns:

  • p0s (ndarray): Initial beam momenta.
  • M0s (ndarray): Initial beam Hessians.
  • x0s (ndarray): Initial beam positions.
  • ωs (ndarray): Beam frequencies.
  • a0s (ndarray): Initial beam amplitudes.
  • modes (ndarray): Beam branch signs.
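To show what the returned quantities mean, the sketch below evaluates a common form of the Gaussian beam ansatz at t = 0 and superposes several beams with jax.vmap. The form is u(x) = a0 · exp(i ω [p0·(x − x0) + ½ (x − x0)ᵀ M0 (x − x0)]). Whether beamax uses exactly this phase normalization (for example, where ω or a 2π factor enters) is an assumption. The branch sign (modes) is omitted because it affects only the time evolution, not the initial profile. beam_at_t0 and superpose are hypothetical names.

```python
import jax
import jax.numpy as jnp


def beam_at_t0(x, p0, M0, x0, omega, a0):
    """Assumed ansatz: one Gaussian beam evaluated at points x of shape (n, d)."""
    dx = x - x0  # (n, d)
    # Quadratic form (x - x0)^T M0 (x - x0) at every point; Im(M0) > 0 gives decay
    quad = jnp.einsum("ni,ij,nj->n", dx, M0, dx)
    phase = dx @ p0 + 0.5 * quad
    return a0 * jnp.exp(1j * omega * phase)


def superpose(x, p0s, M0s, x0s, omegas, a0s):
    """Sum all beams at the same evaluation points (beam axis first)."""
    per_beam = jax.vmap(beam_at_t0, in_axes=(None, 0, 0, 0, 0, 0))(
        x, p0s, M0s, x0s, omegas, a0s
    )  # (b, n)
    return per_beam.sum(axis=0)


x = jnp.array([[0.5, 0.5], [0.6, 0.5]])  # two evaluation points
p0s = jnp.array([[1.0, 0.0]])            # one beam with momentum along +x
M0s = (1j * jnp.eye(2))[None]            # purely imaginary Hessian: isotropic Gaussian
x0s = jnp.array([[0.5, 0.5]])
omegas = jnp.array([20.0])
a0s = jnp.array([1.0 + 0j])
u = superpose(x, p0s, M0s, x0s, omegas, a0s)
```

At the beam center, the field equals a0. A step of 0.1 along x reduces the modulus by exp(−ω · ½ · 0.1²) = exp(−0.1).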

compute_memory_requirements(b: int, N: Tuple, Nt: int) -> str

Estimate memory requirements for Gaussian beam computation.

Parameters:

  • b (int, required): Number of beams.
  • N (Tuple[int, ...], required): Grid dimensions.
  • Nt (int, required): Number of time points.

Returns:

  • str: Human-readable memory estimate.
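The estimation formula is not documented. The sketch below assumes that the dominant cost is one dense complex64 array of shape (b, *N, Nt), at 8 bytes per entry. It formats the result in binary units and derives the largest beam batch that fits a given byte budget, which is one simple way to support memory-aware batching. The array shape, the dtype, and the helpers estimate_bytes, human_readable, and max_batch are hypothetical.

```python
import math
from typing import Tuple


def estimate_bytes(b: int, N: Tuple[int, ...], Nt: int, bytes_per_value: int = 8) -> int:
    # Assumption: dense complex64 array of shape (b, *N, Nt), 8 bytes per entry
    return b * math.prod(N) * Nt * bytes_per_value


def human_readable(nbytes: int) -> str:
    # Binary units (1 KiB = 1024 B), two decimals
    size = float(nbytes)
    for unit in ("B", "KiB", "MiB", "GiB"):
        if size < 1024:
            return f"{size:.2f} {unit}"
        size /= 1024
    return f"{size:.2f} TiB"


def max_batch(budget_bytes: int, N: Tuple[int, ...], Nt: int, bytes_per_value: int = 8) -> int:
    # Largest number of beams per batch whose assumed array fits the budget (at least 1)
    return max(1, budget_bytes // estimate_bytes(1, N, Nt, bytes_per_value))


print(human_readable(estimate_bytes(1000, (256, 256), 500)))  # 244.14 GiB
```

Even moderate problem sizes exceed device memory under this assumption. This is why beams are processed in batches, or accumulated sequentially, rather than materialized all at once.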

compute_forward_result(params: Tuple[jnp.ndarray, ...], c: Callable, lam: float, ts: jnp.ndarray, ode_solver: SolverFn, sensors: jnp.ndarray, domain_size: jnp.ndarray, periodic: jnp.ndarray, use_real: bool = True, aggregate_method: str = 'scan', solver_config: Optional[SolverConfig] = None) -> jnp.ndarray

Compute forward solution to the wave equation using Gaussian beams.

Parameters:

  • params (Tuple[ndarray, ...], required): Beam parameters (p0, M0, x0, omega, a0, mode), in the order returned by compute_forward_parameters.
  • c (Callable, required): Sound-speed function.
  • lam (float, required): Absorption parameter.
  • ts (ndarray, required): Time points.
  • ode_solver (SolverFn, required): ODE solver.
  • sensors (ndarray, required): Sensor positions.
  • domain_size (ndarray, required): Domain size.
  • periodic (ndarray, required): Boundary periodicity flags.
  • use_real (bool, default True): Whether to use real-valued beam computation.
  • aggregate_method ({"scan", "vmap", "all"}, default "scan"): Beam aggregation method.
  • solver_config (SolverConfig, default None): ODE solver configuration.

Returns:

  • ndarray: Forward solution at the sensor locations.
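The choice of aggregate_method trades peak memory against parallelism. The sketch below contrasts two ways of summing per-beam contributions. The vectorized variant uses jax.vmap, which materializes a (b, Nt) array of every beam's trace before summing. The sequential variant uses jax.lax.scan, which keeps only a running (Nt,) sum. beam_contribution is a toy cosine stand-in, not a Gaussian beam. How the beamax options "all", "vmap", and "scan" map onto these two patterns is an assumption.

```python
import jax
import jax.numpy as jnp


def beam_contribution(params, ts):
    """Toy stand-in for one beam's signal at a single sensor over time."""
    amp, omega = params
    return amp * jnp.cos(omega * ts)


def aggregate_vectorized(params, ts):
    # Materialize every beam's trace at once, shape (b, Nt), then sum over beams
    return jax.vmap(beam_contribution, in_axes=(0, None))(params, ts).sum(axis=0)


def aggregate_scan(params, ts):
    # Loop over beams, keeping only a running sum of shape (Nt,) in memory
    def step(acc, p):
        return acc + beam_contribution(p, ts), None

    total, _ = jax.lax.scan(step, jnp.zeros_like(ts), params)
    return total


ts = jnp.linspace(0.0, 1.0, 5)
params = (jnp.array([1.0, 0.5, 2.0]), jnp.array([1.0, 3.0, 5.0]))  # (amps, omegas)
a = aggregate_vectorized(params, ts)
s = aggregate_scan(params, ts)
```

Both variants return the same result. The vectorized version is usually faster when memory allows. The scan version bounds memory by one beam at a time, which fits "scan" being the default.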

compute_coefficients(p0: jnp.ndarray, dpdt: jnp.ndarray, input_type: str, domain: Domain, wpt: MSWPT, mode: str = 'both') -> Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]

Compute wavelet packet transform coefficients.

Parameters:

  • p0 (ndarray, required): Initial pressure field.
  • dpdt (ndarray, required): Initial time derivative of the pressure.
  • input_type ({"spatial", "fourier"}, required): Whether p0 and dpdt are given in physical ("spatial") or Fourier ("fourier") space.
  • domain (Domain, required): Physical domain.
  • wpt (MSWPT, required): Wave-packet transform.
  • mode ({"both", "pos_only"}, default "both"): "both" returns the positive- and negative-frequency coefficients; "pos_only" returns the masked positive-frequency coefficients.

Returns:

  • ndarray or Tuple[ndarray, ndarray]: The pair (cpos, cneg) if mode="both", otherwise the masked cpos.

Raises:

  • ValueError: If mode is invalid.
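The split into positive- and negative-frequency coefficients can be illustrated with the standard half-wave decomposition in Fourier space. The sketch assumes a 1-D periodic grid, unit sound speed, and no wave-packet transform. With ω = |k|, the solution is û(t) = A₊ e^{+iωt} + A₋ e^{−iωt}, where A± = ½ (p̂0 ∓ i q̂0 / ω) and q̂0 is the transform of dpdt. Whether beamax computes cpos and cneg this way, which branch it calls "positive", and how the MSWPT is applied afterwards are all assumptions. half_wave_split is a hypothetical name.

```python
import jax.numpy as jnp


def half_wave_split(p0, dpdt):
    """Assumed 1-D half-wave split with unit sound speed (not beamax's code)."""
    n = p0.shape[0]
    k = 2 * jnp.pi * jnp.fft.fftfreq(n)  # angular wavenumbers, unit grid spacing
    omega = jnp.abs(k)
    p_hat = jnp.fft.fft(p0)
    q_hat = jnp.fft.fft(dpdt)
    # q_hat / omega is undefined at k = 0; this sketch drops that term there
    safe = jnp.where(omega > 0, omega, 1.0)
    q_over_w = jnp.where(omega > 0, q_hat / safe, 0.0)
    a_plus = 0.5 * (p_hat - 1j * q_over_w)   # multiplies exp(+i omega t)
    a_minus = 0.5 * (p_hat + 1j * q_over_w)  # multiplies exp(-i omega t)
    return a_plus, a_minus, omega


x = jnp.linspace(0.0, 1.0, 16, endpoint=False)
p0 = jnp.sin(2 * jnp.pi * x)
dpdt = jnp.cos(2 * jnp.pi * x)  # zero mean, so the dropped k = 0 term is zero
a_plus, a_minus, omega = half_wave_split(p0, dpdt)
```

Two consistency checks follow from the formula. At t = 0, A₊ + A₋ recovers p̂0, and iω(A₊ − A₋) recovers the transform of dpdt wherever ω > 0.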