
beamax.utils

General utilities shared across decomposition, transforms, and solvers.

Typical Use Cases

  • FFT/interpolation helpers for resampling and signal manipulation.
  • Batching/index utilities for coefficient and beam pipelines.
  • Miscellaneous helpers for device checks, synthetic data generation, and numeric convenience operations.

API Reference

beamax.utils

Explicit public API of the utilities module. Heavy dependencies are imported inside the functions that use them rather than at module import time.

Interpolator(grid_points: List[jnp.ndarray], values: jnp.ndarray, **_)

Thin wrapper around make_c_function_from_grid using axis vectors.

Parameters:

Name Type Description Default
grid_points List[ndarray]

List of d one-dimensional axis arrays, each strictly increasing.

required
values ndarray

Grid values shaped to match grid_points.

required

Methods:

Name Description
__call__

Evaluate the interpolant at x, shape (..., d).

grad

Gradient ∇c(x).

hessian

Hessian ∇²c(x).

Construct an interpolator from axis vectors and grid values.

Parameters:

Name Type Description Default
grid_points List[ndarray]

One strictly increasing one-dimensional coordinate vector per axis.

required
values ndarray

Grid values with dimensionality matching grid_points.

required
**_ dict

Extra keyword arguments, accepted for compatibility and ignored.

{}

Raises:

Type Description
ValueError

If axis count and value dimensionality disagree, or if any axis is not one-dimensional with at least two points.
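
The documented ValueError conditions can be restated as a small standalone check. This is a minimal NumPy sketch, not beamax's code. The name check_axes is hypothetical, and the sketch tests only the conditions listed above; it does not check strict monotonicity.

```python
import numpy as np


def check_axes(grid_points, values):
    """Hypothetical restatement of the documented Interpolator input checks."""
    values = np.asarray(values)
    # One axis vector is required per dimension of the value grid.
    if len(grid_points) != values.ndim:
        raise ValueError(
            f"got {len(grid_points)} axes for values with ndim={values.ndim}"
        )
    for i, axis in enumerate(grid_points):
        axis = np.asarray(axis)
        # Each axis must be a 1D vector with at least two points.
        if axis.ndim != 1 or axis.shape[0] < 2:
            raise ValueError(f"axis {i} must be 1D with at least two points")
```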

grad(x: jnp.ndarray) -> jnp.ndarray

Evaluate the gradient of the interpolant.

Parameters:

Name Type Description Default
x ndarray, shape (d,)

Physical query coordinate.

required

Returns:

Type Description
ndarray, shape (d,)

Gradient at x.

hessian(x: jnp.ndarray) -> jnp.ndarray

Evaluate the Hessian of the interpolant.

Parameters:

Name Type Description Default
x ndarray, shape (d,)

Physical query coordinate.

required

Returns:

Type Description
ndarray, shape (d, d)

Hessian at x.

get_devices()

Inspect available JAX devices.

Returns:

Type Description
Tuple[bool, bool, bool]

Presence flags for CPU, GPU, and TPU devices, in that order. The function also prints the detected devices.

memory_estimate(dims: jnp.ndarray, dtype: jnp.dtype) -> str

Estimate the memory footprint of an array with the given shape and dtype.

Parameters:

Name Type Description Default
dims ndarray

Shape tuple or array of dimensions.

required
dtype dtype

Element data type of the array.

required

Returns:

Type Description
str

Estimated memory usage.

memory_str(x: jnp.ndarray) -> str

Report the memory usage of an existing array.

Parameters:

Name Type Description Default
x ndarray
required

Returns:

Type Description
str

Human-friendly memory string.
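
memory_estimate and memory_str both reduce to a byte count plus formatting. For a shape/dtype pair, the byte count is the element count times the itemsize. For an existing array, it is nbytes. The sketch below assumes binary units (KiB, MiB, ...) with two decimals, because beamax's actual string format is not documented here. The _sketch names are hypothetical.

```python
import math

import numpy as np


def human_bytes(n_bytes):
    """Format a byte count with 1024-based units (format is an assumption)."""
    units = ["B", "KiB", "MiB", "GiB", "TiB"]
    value = float(n_bytes)
    for unit in units[:-1]:
        if value < 1024:
            return f"{value:.2f} {unit}"
        value /= 1024
    return f"{value:.2f} {units[-1]}"


def memory_estimate_sketch(dims, dtype):
    # Bytes = number of elements * bytes per element.
    n_elements = math.prod(int(d) for d in dims)
    return human_bytes(n_elements * np.dtype(dtype).itemsize)


def memory_str_sketch(x):
    # An existing array already knows its size in bytes.
    return human_bytes(x.nbytes)
```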

array_str(x: Union[jnp.ndarray, None]) -> Union[str, None]

Short descriptor of an array (shape, dtype, and memory usage).

Parameters:

Name Type Description Default
x ndarray | None
required

Returns:

Type Description
str | None

detect_root() -> Path

Locate the repository root used by examples for output files.

Priority

  1) BEAMAX_ROOT environment variable
  2) Upward search from the current working directory
  3) Upward search from the package source location
  4) Current working directory (fallback)

Returns:

Type Description
Path

Detected repository root directory.

example_plot_dir(example_file: str | os.PathLike[str]) -> Path

Return the plot output directory for a public example script.

Plot directories mirror the example's first-level subdirectory under examples/:

  • examples/forward/2d_forward.py -> <root>/plots/forward
  • examples/rays/2d_ray_bending.py -> <root>/plots/rays

If example_file is outside the detected checkout's examples tree, the file's immediate parent directory name is used as the category.
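
The category rule can be sketched as follows, assuming the detected root contains the examples/ tree. plot_category and example_plot_dir_sketch are hypothetical names. The rule above does not say how to handle a file placed directly in examples/, so in that case the sketch falls back to the parent directory name.

```python
from pathlib import Path


def plot_category(example_file, root):
    """Return the category used for an example's plot directory (sketch)."""
    path = Path(example_file).resolve()
    examples_dir = (Path(root) / "examples").resolve()
    try:
        # First directory component below examples/, e.g. "forward".
        rel = path.relative_to(examples_dir)
        if len(rel.parts) > 1:
            return rel.parts[0]
    except ValueError:
        pass  # Not under the examples tree.
    # Fallback: the immediate parent directory name.
    return path.parent.name


def example_plot_dir_sketch(example_file, root):
    out = Path(root) / "plots" / plot_category(example_file, root)
    # The documented return value is an existing directory, so create it.
    out.mkdir(parents=True, exist_ok=True)
    return out
```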

Parameters:

Name Type Description Default
example_file str or PathLike

Usually __file__ from an example script.

required

Returns:

Type Description
Path

Existing output directory for plots from that example category.

unitary_fft(arr: jnp.ndarray) -> jnp.ndarray

Unitary N-D FFT with the zero-frequency component centered.

Parameters:

Name Type Description Default
arr ndarray

Real or complex array in the spatial domain.

required

Returns:

Type Description
ndarray

Fourier transform with norm="ortho" and fftshift applied.

unitary_ifft(arr: jnp.ndarray) -> jnp.ndarray

Unitary inverse FFT matching unitary_fft.

Parameters:

Name Type Description Default
arr ndarray

Fourier-domain array with the zero-frequency component already centered (fftshift applied), as returned by unitary_fft.

required

Returns:

Type Description
ndarray

Spatial-domain array with unitary scaling.

convert_space(array: jnp.ndarray, input_space: str, target_space: str) -> jnp.ndarray

Convert an array between spatial and Fourier domains.

Parameters:

Name Type Description Default
array ndarray
required
input_space {"spatial", "fourier"}

Declares the domain of array.

required
target_space {"spatial", "fourier"}

Desired output domain.

required

Returns:

Type Description
ndarray

Array in the requested domain.

Raises:

Type Description
ValueError

If an unsupported conversion is requested.
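
Using NumPy in place of jax.numpy, unitary_fft, unitary_ifft, and convert_space can be sketched as follows. The sketch assumes that unitary_fft shifts only its output and applies no ifftshift to its input, which is consistent with "fftshift applied". unitary_ifft then undoes the shift before inverting.

```python
import numpy as np


def unitary_fft(arr):
    # Orthonormal N-D FFT, then move the zero frequency to the array center.
    return np.fft.fftshift(np.fft.fftn(arr, norm="ortho"))


def unitary_ifft(arr):
    # Undo the centering first, then apply the orthonormal inverse FFT.
    return np.fft.ifftn(np.fft.ifftshift(arr), norm="ortho")


def convert_space(array, input_space, target_space):
    valid = ("spatial", "fourier")
    if input_space not in valid or target_space not in valid:
        raise ValueError(f"unsupported conversion {input_space!r} -> {target_space!r}")
    if input_space == target_space:
        return array  # nothing to do
    if input_space == "spatial":
        return unitary_fft(array)
    return unitary_ifft(array)
```

Because both transforms use norm="ortho", they preserve the L2 norm. A spatial → fourier → spatial round trip therefore returns the input up to floating-point error.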

make_c_function_from_grid(c_map: jnp.ndarray, spacing: Optional[Tuple[float, ...]] = None, origin: Optional[Tuple[float, ...]] = None)

Build a JAX-differentiable n-linear interpolant over a rectilinear grid.

Parameters:

Name Type Description Default
c_map ndarray, shape (N1, ..., Nd)

Grid values.

required
spacing Tuple[float, ...] | None

Per-axis spacing. Defaults to 1.0.

None
origin Tuple[float, ...] | None

Per-axis origin. Defaults to 0.0.

None

Returns:

Type Description
Callable[[ndarray], ndarray]

Function c_fun(x) that takes x of shape (..., d) in physical units.

Notes

  • Piecewise linear and differentiable almost everywhere; gradients are available through jax.grad.
  • Indices are clamped at the grid boundaries.
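
The following NumPy sketch shows the mechanics of n-linear interpolation on a uniform grid defined by spacing and origin. It is not the beamax implementation. Unlike a jax.numpy version, it cannot be differentiated with jax.grad. The clamping policy is an assumption. The cell index is clamped so both corners stay in range, and the fractional weight is clamped to [0, 1], so queries outside the grid return boundary values.

```python
import itertools

import numpy as np


def make_linear_interpolant(c_map, spacing=None, origin=None):
    """NumPy sketch of an n-linear interpolant on a uniform rectilinear grid."""
    c_map = np.asarray(c_map, dtype=float)
    d = c_map.ndim
    shape = np.array(c_map.shape)
    spacing = np.ones(d) if spacing is None else np.asarray(spacing, dtype=float)
    origin = np.zeros(d) if origin is None else np.asarray(origin, dtype=float)

    def c_fun(x):
        x = np.asarray(x, dtype=float)
        # Physical coordinates -> fractional grid indices.
        u = (x - origin) / spacing
        # Lower cell corner, clamped so that i0 + 1 is still a valid index.
        i0 = np.clip(np.floor(u).astype(int), 0, shape - 2)
        # Fractional position inside the cell, clamped to [0, 1].
        t = np.clip(u - i0, 0.0, 1.0)
        out = np.zeros(x.shape[:-1])
        # Sum the 2**d cell corners with tensor-product linear weights.
        for corner in itertools.product((0, 1), repeat=d):
            corner = np.array(corner)
            weight = np.prod(np.where(corner == 1, t, 1.0 - t), axis=-1)
            idx = tuple(np.moveaxis(i0 + corner, -1, 0))
            out = out + weight * c_map[idx]
        return out

    return c_fun
```

n-linear interpolation reproduces affine functions exactly inside the grid, which makes a field such as 2x + 3y a convenient sanity check.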

interpolate_nearest(array: jnp.ndarray, new_shape: Tuple) -> jnp.ndarray

Nearest-neighbour resampling to new_shape.

Parameters:

Name Type Description Default
array ndarray
required
new_shape Tuple[int, ...]
required

Returns:

Type Description
ndarray

Resampled array.
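
Nearest-neighbor resampling can be written as one index gather per axis. The mapping i -> floor(i * n_old / n_new) used in this NumPy sketch is an assumption. Other conventions, such as rounding cell centers, select slightly different samples.

```python
import numpy as np


def interpolate_nearest_sketch(array, new_shape):
    """Nearest-neighbor resampling; the index mapping is an assumption."""
    out = np.asarray(array)
    # zip captures the original shape, so each axis uses its old length.
    for axis, (n_old, n_new) in enumerate(zip(out.shape, new_shape)):
        # Map each output index to floor(i * n_old / n_new) along this axis.
        idx = (np.arange(n_new) * n_old) // n_new
        out = np.take(out, idx, axis=axis)
    return out
```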

pad_array(array: jnp.ndarray, desired_size: Tuple[int, ...], mode: str = 'constant') -> jnp.ndarray

Centered pad per axis to reach desired_size.

Parameters:

Name Type Description Default
array ndarray
required
desired_size Tuple[int, ...]
required
mode {"constant", "edge"}
"constant"

Returns:

Type Description
ndarray

pad_zero(array: jnp.ndarray, desired_size: Tuple) -> jnp.ndarray

Zero-pad to desired_size, keeping the original array centered.

Parameters:

Name Type Description Default
array ndarray
required
desired_size Tuple[int, ...]
required

Returns:

Type Description
ndarray

pad_edge(array: jnp.ndarray, desired_size: Tuple) -> jnp.ndarray

Edge-pad (replicate border values) to desired_size, keeping the original array centered.

Parameters:

Name Type Description Default
array ndarray
required
desired_size Tuple[int, ...]
required

Returns:

Type Description
ndarray

crop_centered(array: jnp.ndarray, desired_size: Tuple[int, ...]) -> jnp.ndarray

Centered crop per axis. Returns the array unchanged if any target dimension exceeds the current one.

Parameters:

Name Type Description Default
array ndarray
required
desired_size Tuple[int, ...]
required

Returns:

Type Description
ndarray
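
pad_array (whose "constant" and "edge" modes correspond to pad_zero and pad_edge) and crop_centered can be sketched with np.pad and slicing. How an odd amount of padding is split is an assumption; this sketch places the extra element after the data. The important property is that padding and cropping use the same offset, (n - target) // 2, so padding followed by cropping back to the original size is lossless.

```python
import numpy as np


def pad_array_sketch(array, desired_size, mode="constant"):
    """Centered per-axis padding; an odd leftover goes after the data."""
    pad_width = []
    for n, target in zip(array.shape, desired_size):
        extra = max(target - n, 0)  # never pad by a negative amount
        pad_width.append((extra // 2, extra - extra // 2))
    return np.pad(array, pad_width, mode=mode)


def crop_centered_sketch(array, desired_size):
    """Centered crop; returns the input unchanged if any target dim is larger."""
    if any(t > n for n, t in zip(array.shape, desired_size)):
        return array
    slices = tuple(
        slice((n - t) // 2, (n - t) // 2 + t) for n, t in zip(array.shape, desired_size)
    )
    return array[slices]
```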

interpolate_fourier(array: jnp.ndarray, desired_size: Tuple[int, ...], input_type: str, output_type: str) -> jnp.ndarray

Unitary FFT-based resampling (pad when upsampling, crop when downsampling).

Parameters:

Name Type Description Default
array ndarray
required
desired_size Tuple[int, ...]
required
input_type {"spatial", "fourier"}
required
output_type {"spatial", "fourier"}
required

Returns:

Type Description
ndarray

Resampled array in output_type domain.

Notes

Assumes periodic boundaries. Uses pad_array then crop_centered in frequency.
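
Here is a NumPy sketch of the spatial-to-spatial case. It transforms the input, zero-pads the centered spectrum on axes that grow, crops it on axes that shrink, and inverts. It uses the same centered pad/crop offsets assumed for pad_array and crop_centered. Because both transforms are unitary, the sketch preserves the L2 norm rather than sample values: upsampling a length-N signal to length M scales values by sqrt(N/M). The text above does not say whether beamax applies a compensating factor. The output is complex.

```python
import numpy as np


def interpolate_fourier_sketch(array, desired_size):
    """Spatial-in, spatial-out FFT resampling with unitary transforms (sketch)."""
    spec = np.fft.fftshift(np.fft.fftn(array, norm="ortho"))
    # Zero-pad the centered spectrum on axes that grow ...
    pad_width = []
    for n, t in zip(spec.shape, desired_size):
        extra = max(t - n, 0)
        pad_width.append((extra // 2, extra - extra // 2))
    spec = np.pad(spec, pad_width)
    # ... then crop the centered spectrum on axes that shrink.
    slices = tuple(
        slice((n - t) // 2, (n - t) // 2 + t) for n, t in zip(spec.shape, desired_size)
    )
    spec = spec[slices]
    return np.fft.ifftn(np.fft.ifftshift(spec), norm="ortho")
```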

extract_centered_box(arr, box_shape_tuple, center)

Wrap-around extraction of a centered N-D box (JIT/static-friendly).

Parameters:

Name Type Description Default
arr jnp.ndarray, shape S = (N1, ..., Nd)
required
box_shape_tuple Tuple[int, ...]

Static Python tuple of ints (box sizes per axis).

required
center ndarray, shape (d,)

Center index in the same index space as arr (0..Ni-1).

required

Returns:

Type Description
jnp.ndarray, shape `box_shape_tuple`
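
Wrap-around extraction amounts to gathering indices modulo each axis length. This NumPy sketch assumes that the box starts at center - size // 2 on each axis. In a JIT-compiled version, box_shape_tuple must be static because it fixes the output shape.

```python
import numpy as np


def extract_centered_box_sketch(arr, box_shape_tuple, center):
    """Periodic extraction of a box around `center` (start = center - size // 2)."""
    out = np.asarray(arr)
    for axis, (size, c) in enumerate(zip(box_shape_tuple, center)):
        # Indices along this axis, wrapped modulo the axis length.
        idx = (int(c) - size // 2 + np.arange(size)) % out.shape[axis]
        out = np.take(out, idx, axis=axis)
    return out
```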

rel_l2(a, b)

Compute relative L2 error between two arrays.

Parameters:

Name Type Description Default
a array-like

Reference array.

required
b array-like

Comparison array.

required

Returns:

Type Description
float

||a - b||_2 / (||a||_2 + 1e-30).
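
The formula above translates directly into code. The NumPy version below is a sketch, not the library source. The 1e-30 term only prevents division by zero when a is identically zero.

```python
import numpy as np


def rel_l2(a, b):
    """Relative L2 error ||a - b||_2 / (||a||_2 + 1e-30)."""
    a = np.asarray(a)
    b = np.asarray(b)
    # norm() without an axis flattens N-D inputs into a single vector norm.
    return float(np.linalg.norm(a - b) / (np.linalg.norm(a) + 1e-30))
```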

batch_data(*args, batch_size, zero_padded_args=())

Split the leading (beam) axis of each array into (num_batches, batch_size, ...).

Parameters:

Name Type Description Default
*args Tuple[ndarray, ...]

Arrays with leading dimension b.

()
batch_size int
required
zero_padded_args Tuple[int, ...]

Indices into args that should be zero-padded in the last batch; others repeat their last entry.

()

Returns:

Type Description
Tuple[ndarray, ...]

Batched arrays (with padding if needed).
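
The batching rule above can be sketched as follows, assuming every array shares the leading dimension b of the first one. Padded slots repeat the last entry unless the argument's index is listed in zero_padded_args, in which case they are filled with zeros.

```python
import numpy as np


def batch_data_sketch(*args, batch_size, zero_padded_args=()):
    """Reshape leading axis b into (num_batches, batch_size, ...), padding the tail."""
    b = args[0].shape[0]
    num_batches = -(-b // batch_size)  # ceiling division
    n_pad = num_batches * batch_size - b
    out = []
    for i, a in enumerate(args):
        if n_pad:
            if i in zero_padded_args:
                filler = np.zeros((n_pad, *a.shape[1:]), dtype=a.dtype)
            else:
                # Repeat the last entry so padded slots hold valid inputs.
                filler = np.repeat(a[-1:], n_pad, axis=0)
            a = np.concatenate([a, filler], axis=0)
        out.append(a.reshape(num_batches, batch_size, *a.shape[1:]))
    return tuple(out)
```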

find_level(dyadic_decomp: DyadicDecomposition, box_num: int) -> Int[Array, '']

Map global box index → dyadic level.

Parameters:

Name Type Description Default
dyadic_decomp DyadicDecomposition
required
box_num int
required

Returns:

Type Description
ndarray

Scalar (0-D) integer array holding the smallest level whose cumulative box count exceeds box_num. Inside JIT it is traced as an integer scalar.
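
Finding the level amounts to a search over cumulative box counts. The fields of DyadicDecomposition are not documented here, so this sketch takes a hypothetical boxes_per_level sequence directly, for example 1, 4, and 16 boxes at levels 0 to 2.

```python
import numpy as np


def find_level_sketch(boxes_per_level, box_num):
    """Smallest level whose cumulative box count exceeds box_num."""
    cumulative = np.cumsum(boxes_per_level)
    # side="right": a box_num equal to a cumulative total belongs to the next level.
    return np.searchsorted(cumulative, box_num, side="right")
```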

find_tensor_and_multiindex(flat_indices: jnp.ndarray, shapes: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]

Decode flat indices over concatenated tensors.

Parameters:

Name Type Description Default
flat_indices ndarray, shape (m,)
required
shapes ndarray, shape (L, k)

Shapes of each tensor (per level).

required

Returns:

Name Type Description
array_indices ndarray, shape (m,)

Which tensor each flat index belongs to.

multidimensional_indices ndarray, shape (m, k)

Unravelled indices within that tensor.
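
Decoding works by locating each flat index among the cumulative tensor sizes and then unravelling the remainder within that tensor's shape. The sketch assumes row-major (C-order) unravelling, which is the default of np.unravel_index. It also assumes that all tensors share the same number of dimensions k.

```python
import numpy as np


def find_tensor_and_multiindex_sketch(flat_indices, shapes):
    """Decode flat indices into (tensor id, row-major multi-index) pairs."""
    flat_indices = np.asarray(flat_indices)
    shapes = np.asarray(shapes)
    sizes = np.prod(shapes, axis=1)  # number of elements per tensor
    ends = np.cumsum(sizes)          # exclusive end offset of each tensor
    starts = ends - sizes
    array_indices = np.searchsorted(ends, flat_indices, side="right")
    local = flat_indices - starts[array_indices]
    # Row-major unravel, using each index's own tensor shape.
    tensor_shapes = shapes[array_indices]  # shape (m, k)
    multi = np.zeros_like(tensor_shapes)
    for axis in range(shapes.shape[1] - 1, -1, -1):
        multi[:, axis] = local % tensor_shapes[:, axis]
        local = local // tensor_shapes[:, axis]
    return array_indices, multi
```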

compute_coeff_shapes(dyadic_decomp: DyadicDecomposition, redundancy: int, level) -> jnp.ndarray

Per-level coefficient tensor shapes.

Parameters:

Name Type Description Default
dyadic_decomp DyadicDecomposition
required
redundancy int

1 (basis) or 2 (frame).

required
level int

Level index. The function is vmapped over this argument.

required

Returns:

Type Description
ndarray, shape (k,)

(num_boxes_level, *support_shape).