beamax.utils
General utilities shared across decomposition, transforms, and solvers.
Typical Use Cases
- FFT/interpolation helpers for resampling and signal manipulation.
- Batching/index utilities for coefficient and beam pipelines.
- Miscellaneous helpers for device checks, synthetic data generation, and numeric convenience operations.
API Reference
beamax.utils
Explicit public API for the utilities. Heavy dependencies must be imported inside call sites.
Interpolator(grid_points: List[jnp.ndarray], values: jnp.ndarray, **_)
Thin wrapper around make_c_function_from_grid using axis vectors.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| grid_points | List[ndarray] | 1D axis arrays (length d), each strictly increasing. | required |
| values | ndarray | Grid values shaped to match grid_points. | required |
Methods:
| Name | Description |
|---|---|
| __call__ | Evaluate the interpolant at a query coordinate. |
| grad | Gradient of the interpolant. |
| hessian | Hessian of the interpolant. |
Construct an interpolator from axis vectors and grid values.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| grid_points | List[ndarray] | One strictly increasing one-dimensional coordinate vector per axis. | required |
| values | ndarray | Grid values with dimensionality matching the number of axes in grid_points. | required |
| **_ | dict | Ignored compatibility keyword arguments. | {} |
Raises:
| Type | Description |
|---|---|
| ValueError | If axis count and value dimensionality disagree, or if any axis is not one-dimensional with at least two points. |
grad(x: jnp.ndarray) -> jnp.ndarray
Evaluate the gradient of the interpolant.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | ndarray, shape (d,) | Physical query coordinate. | required |
Returns:
| Type | Description |
|---|---|
| ndarray, shape (d,) | Gradient at x. |
hessian(x: jnp.ndarray) -> jnp.ndarray
Evaluate the Hessian of the interpolant.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | ndarray, shape (d,) | Physical query coordinate. | required |
Returns:
| Type | Description |
|---|---|
| ndarray, shape (d, d) | Hessian at x. |
get_devices()
Inspect available JAX devices.
Returns:
| Type | Description |
|---|---|
| Tuple[bool, bool, bool] | Presence flags for (CPU, GPU, TPU). Prints the detected devices. |
memory_estimate(dims: jnp.ndarray, dtype: jnp.dtype) -> str
Estimate memory footprint for an array shape/dtype.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| dims | ndarray | Shape tuple or array of dimensions. | required |
| dtype | dtype | | required |
Returns:
| Type | Description |
|---|---|
| str | Estimated memory usage. |
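The estimate can be pictured as element count times item size. The sketch below is not the beamax implementation: the helper names `estimate_bytes` and `human_readable` are hypothetical, and the binary-unit string format is an assumption, since the exact string returned by memory_estimate is not documented here.

```python
import numpy as np

def estimate_bytes(dims, dtype):
    """Raw byte count for an array of shape `dims` and element type `dtype`."""
    n_elements = int(np.prod(np.asarray(dims, dtype=np.int64)))
    return n_elements * np.dtype(dtype).itemsize

def human_readable(n_bytes):
    """Format a byte count with binary (1024-based) units (assumed format)."""
    size = float(n_bytes)
    for unit in ("B", "KiB", "MiB", "GiB", "TiB"):
        if size < 1024.0 or unit == "TiB":
            return f"{size:.2f} {unit}"
        size /= 1024.0

# A 1024 x 1024 complex64 array holds 8 bytes per element.
print(human_readable(estimate_bytes((1024, 1024), np.complex64)))  # 8.00 MiB
```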
memory_str(x: jnp.ndarray) -> str
Memory usage of an existing array.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | ndarray | | required |
Returns:
| Type | Description |
|---|---|
| str | Human-friendly memory string. |
array_str(x: Union[jnp.ndarray, None]) -> Union[str, None]
Short descriptor for an array (shape/dtype/memory).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | ndarray or None | | required |
Returns:
| Type | Description |
|---|---|
| str or None | |
detect_root() -> Path
Locate the repository root used by examples for output files.
Priority
1) BEAMAX_ROOT environment variable
2) Current working directory upward search
3) Package source location upward search
4) Current working directory
Returns:
| Type | Description |
|---|---|
| Path | |
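The priority order above can be sketched with the standard library alone. This is an illustration under stated assumptions, not the beamax code: the function names are hypothetical, and the marker file used to recognise a repository root (`pyproject.toml`) is an assumption because the actual criterion is not documented here.

```python
import os
from pathlib import Path
from typing import Optional

# Hypothetical marker file; the real root criterion is not documented here.
MARKER = "pyproject.toml"

def search_upward(start: Path, marker: str = MARKER) -> Optional[Path]:
    """Return the first directory at or above `start` that contains `marker`."""
    start = start.resolve()
    for candidate in (start, *start.parents):
        if (candidate / marker).exists():
            return candidate
    return None

def detect_root_sketch(package_file: Optional[str] = None) -> Path:
    # 1) Explicit override through the environment.
    env = os.environ.get("BEAMAX_ROOT")
    if env:
        return Path(env)
    # 2) Upward search from the current working directory.
    found = search_upward(Path.cwd())
    if found is not None:
        return found
    # 3) Upward search from the package source location, if one is given.
    if package_file is not None:
        found = search_upward(Path(package_file).parent)
        if found is not None:
            return found
    # 4) Fall back to the current working directory.
    return Path.cwd()
```

Setting BEAMAX_ROOT short-circuits every search step, which is the simplest way to redirect example output.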
example_plot_dir(example_file: str | os.PathLike[str]) -> Path
Return the plot output directory for a public example script.
Examples mirror their first directory under examples/:
examples/forward/2d_forward.py -> <root>/plots/forward
examples/rays/2d_ray_bending.py -> <root>/plots/rays
If example_file is outside the detected checkout's examples tree,
the file's immediate parent directory name is used as the category.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| example_file | str or PathLike | Path to the example script. | required |
Returns:
| Type | Description |
|---|---|
| Path | Existing output directory for plots from that example category. |
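The category rule described above (first directory under examples/, otherwise the immediate parent directory name) can be sketched as follows. The names `plot_category` and `plot_dir` and the explicit `root` argument are hypothetical, and the handling of a script placed directly in examples/ is an assumption.

```python
from pathlib import Path

def plot_category(example_file, root):
    """Category name used under <root>/plots for an example script."""
    example_path = Path(example_file).resolve()
    examples_dir = (Path(root) / "examples").resolve()
    try:
        relative = example_path.relative_to(examples_dir)
    except ValueError:
        # Outside the examples tree: use the immediate parent directory name.
        return example_path.parent.name
    if len(relative.parts) > 1:
        # Inside the tree: mirror the first directory under examples/.
        return relative.parts[0]
    # Script directly in examples/ (assumed behaviour): use the parent name.
    return example_path.parent.name

def plot_dir(example_file, root):
    """Create (if needed) and return <root>/plots/<category>."""
    out = Path(root) / "plots" / plot_category(example_file, root)
    out.mkdir(parents=True, exist_ok=True)
    return out
```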
unitary_fft(arr: jnp.ndarray) -> jnp.ndarray
Unitary N-D FFT with centred zero-frequency component.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| arr | ndarray | Real or complex array in the spatial domain. | required |
Returns:
| Type | Description |
|---|---|
| ndarray | Fourier transform with unitary scaling and centred zero-frequency component. |
unitary_ifft(arr: jnp.ndarray) -> jnp.ndarray
Unitary inverse FFT matching unitary_fft.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| arr | ndarray | Fourier-domain array (already shifted). | required |
Returns:
| Type | Description |
|---|---|
| ndarray | Spatial-domain array with unitary scaling. |
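A unitary, centred transform pair can be sketched as below. NumPy stands in for jax.numpy here (both expose `fft.fftn`, `fft.ifftn`, `fft.fftshift` and `fft.ifftshift`). The function names are hypothetical, and whether beamax also shifts the spatial input is not stated above, so this sketch assumes the spatial origin sits at index 0.

```python
import numpy as np

def unitary_fft_sketch(arr):
    # norm="ortho" applies the 1/sqrt(N) scaling; fftshift centres zero frequency.
    return np.fft.fftshift(np.fft.fftn(arr, norm="ortho"))

def unitary_ifft_sketch(arr):
    # Undo the shift first, then apply the inverse transform with the same scaling.
    return np.fft.ifftn(np.fft.ifftshift(arr), norm="ortho")

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 6)) + 1j * rng.normal(size=(4, 6))
X = unitary_fft_sketch(x)

# Unitary scaling preserves energy (Parseval) and the pair round-trips.
print(np.isclose(np.sum(np.abs(x) ** 2), np.sum(np.abs(X) ** 2)))  # True
print(np.allclose(unitary_ifft_sketch(X), x))                      # True
```

With this convention the zero-frequency bin of an axis of length n lands at index n // 2.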
convert_space(array: jnp.ndarray, input_space: str, target_space: str) -> jnp.ndarray
Convert an array between spatial and Fourier domains.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| array | ndarray | | required |
| input_space | 'spatial' or 'fourier' | Declares the domain of array. | "spatial" |
| target_space | 'spatial' or 'fourier' | Desired output domain. | "spatial" |
Returns:
| Type | Description |
|---|---|
| ndarray | Array in the requested domain. |
Raises:
| Type | Description |
|---|---|
| ValueError | If an unsupported conversion is requested. |
make_c_function_from_grid(c_map: jnp.ndarray, spacing: Optional[Tuple[float, ...]] = None, origin: Optional[Tuple[float, ...]] = None)
Build a JAX-differentiable n-linear interpolant over a rectilinear grid.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| c_map | ndarray, shape (*N) | Grid values. | required |
| spacing | Tuple[float, ...] or None | Per-axis spacing. Defaults to 1.0. | None |
| origin | Tuple[float, ...] or None | Per-axis origin. Defaults to 0.0. | None |
Returns:
| Type | Description |
|---|---|
| Callable[[ndarray], ndarray] | Function that evaluates the interpolant at a query coordinate. |
Notes
- Piecewise-linear; differentiable almost everywhere; gradients via jax.grad.
- Index clamping at grid boundaries.
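To make the n-linear scheme concrete, here is a minimal NumPy sketch of an interpolant over a regular grid with clamping at the boundaries. It is not the beamax implementation: the name `make_linear_interpolant` is hypothetical, every axis is assumed to have at least two points, and clamping is interpreted as constant extrapolation outside the grid. A JAX version would use jax.numpy so that jax.grad applies.

```python
import itertools
import numpy as np

def make_linear_interpolant(c_map, spacing=None, origin=None):
    c_map = np.asarray(c_map, dtype=float)
    d = c_map.ndim
    spacing = np.ones(d) if spacing is None else np.asarray(spacing, dtype=float)
    origin = np.zeros(d) if origin is None else np.asarray(origin, dtype=float)
    upper = np.array(c_map.shape) - 1  # last valid index per axis

    def f(x):
        # Physical coordinate -> fractional grid index, clamped to the grid.
        u = np.clip((np.asarray(x, dtype=float) - origin) / spacing, 0.0, upper)
        # Lower corner of the enclosing cell (kept inside the grid).
        i0 = np.minimum(np.floor(u).astype(int), upper - 1)
        t = u - i0  # position inside the cell, in [0, 1] per axis
        value = 0.0
        # Sum over the 2^d cell corners with product weights.
        for corner in itertools.product((0, 1), repeat=d):
            corner = np.array(corner)
            weight = np.prod(np.where(corner == 1, t, 1.0 - t))
            value += weight * c_map[tuple(i0 + corner)]
        return value

    return f

# A linear field is reproduced exactly by n-linear interpolation.
grid = 2.0 * np.arange(4)[:, None] + 3.0 * np.arange(5)[None, :]
f = make_linear_interpolant(grid)
print(f([1.5, 2.25]))  # 9.75
```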
interpolate_nearest(array: jnp.ndarray, new_shape: Tuple) -> jnp.ndarray
Nearest-neighbour resampling to new_shape.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| array | ndarray | | required |
| new_shape | Tuple[int, ...] | | required |
Returns:
| Type | Description |
|---|---|
| ndarray | Resampled array. |
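Nearest-neighbour resampling picks, for every output sample, the closest input sample along each axis. The sketch below uses NumPy, a hypothetical name `nearest_resample`, and a cell-centre index convention that is an assumption; the rounding rule used by interpolate_nearest may differ.

```python
import numpy as np

def nearest_resample(array, new_shape):
    array = np.asarray(array)
    index_vectors = []
    for old_n, new_n in zip(array.shape, new_shape):
        # Centre of each output cell, expressed in input index coordinates.
        centres = (np.arange(new_n) + 0.5) * old_n / new_n
        index_vectors.append(np.minimum(centres.astype(int), old_n - 1))
    # np.ix_ builds the outer product of the per-axis index vectors.
    return array[np.ix_(*index_vectors)]

print(nearest_resample(np.arange(4), (8,)))  # [0 0 1 1 2 2 3 3]
```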
pad_array(array: jnp.ndarray, desired_size: Tuple[int, ...], mode: str = 'constant') -> jnp.ndarray
Centered pad per axis to reach desired_size.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| array | ndarray | | required |
| desired_size | Tuple[int, ...] | | required |
| mode | 'constant' or 'edge' | | "constant" |
Returns:
| Type | Description |
|---|---|
| ndarray | |
pad_zero(array: jnp.ndarray, desired_size: Tuple) -> jnp.ndarray
Zero-pad to centered target size.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| array | ndarray | | required |
| desired_size | Tuple[int, ...] | | required |
Returns:
| Type | Description |
|---|---|
| ndarray | |
pad_edge(array: jnp.ndarray, desired_size: Tuple) -> jnp.ndarray
Edge-pad (replicate border) to centered target size.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| array | ndarray | | required |
| desired_size | Tuple[int, ...] | | required |
Returns:
| Type | Description |
|---|---|
| ndarray | |
crop_centered(array: jnp.ndarray, desired_size: Tuple[int, ...]) -> jnp.ndarray
Centered crop per axis. No-op if any target dim exceeds current.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| array | ndarray | | required |
| desired_size | Tuple[int, ...] | | required |
Returns:
| Type | Description |
|---|---|
| ndarray | |
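The centred pad and crop helpers above can be pictured with the following NumPy sketch. The function names are hypothetical, and the way an odd amount of padding is split between the two sides is an assumption: here the centre index n // 2 is mapped onto target // 2, which keeps a shifted zero-frequency bin in place.

```python
import numpy as np

def pad_centered(array, desired_size, mode="constant"):
    pads = []
    for n, target in zip(array.shape, desired_size):
        total = max(target - n, 0)
        # Keep index n // 2 aligned with index target // 2 (assumed convention).
        before = target // 2 - n // 2 if total else 0
        pads.append((before, total - before))
    return np.pad(array, pads, mode=mode)

def crop_centered_sketch(array, desired_size):
    # No-op if any target dimension exceeds the current one.
    if any(t > n for n, t in zip(array.shape, desired_size)):
        return array
    slices = []
    for n, target in zip(array.shape, desired_size):
        start = n // 2 - target // 2
        slices.append(slice(start, start + target))
    return array[tuple(slices)]

a = np.array([1, 2, 3])
print(pad_centered(a, (6,)))               # [0 0 1 2 3 0]
print(pad_centered(a, (6,), mode="edge"))  # [1 1 1 2 3 3]
```

Cropping the padded result back to the original size returns the original array, so the two operations are inverses under this convention.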
interpolate_fourier(array: jnp.ndarray, desired_size: Tuple[int, ...], input_type: str, output_type: str) -> jnp.ndarray
Unitary FFT-based resampling (pad when upsampling, crop when downsampling).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| array | ndarray | | required |
| desired_size | Tuple[int, ...] | | required |
| input_type | 'spatial' or 'fourier' | | "spatial" |
| output_type | 'spatial' or 'fourier' | | "spatial" |
Returns:
| Type | Description |
|---|---|
| ndarray | Resampled array in the output_type domain. |
Notes
Assumes periodic boundaries. Uses pad_array then crop_centered in frequency.
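The pad-or-crop-in-frequency idea can be sketched end to end for spatial input and spatial output. This is an illustration in NumPy, not the beamax code: `fourier_resample` and `_resize_centered` are hypothetical names, and the final amplitude rescaling (so that sample values rather than the L2 norm are preserved) is an assumption about the desired normalisation.

```python
import numpy as np

def _resize_centered(spec, desired_size):
    """Zero-pad or crop each axis so bin n // 2 lands on bin target // 2."""
    for axis, target in enumerate(desired_size):
        n = spec.shape[axis]
        if target > n:
            before = target // 2 - n // 2
            pad = [(0, 0)] * spec.ndim
            pad[axis] = (before, target - n - before)
            spec = np.pad(spec, pad)
        elif target < n:
            start = n // 2 - target // 2
            spec = np.take(spec, np.arange(start, start + target), axis=axis)
    return spec

def fourier_resample(array, desired_size):
    # Forward unitary FFT with the zero frequency moved to the centre.
    spec = np.fft.fftshift(np.fft.fftn(array, norm="ortho"))
    spec = _resize_centered(spec, desired_size)
    out = np.fft.ifftn(np.fft.ifftshift(spec), norm="ortho")
    # Unitary transforms keep the L2 norm; rescale to keep sample values instead.
    return out * np.sqrt(np.prod(desired_size) / array.size)

coarse = np.cos(2 * np.pi * np.arange(8) / 8)
fine = np.cos(2 * np.pi * np.arange(16) / 16)
print(np.allclose(fourier_resample(coarse, (16,)).real, fine))  # True
```

The check works because a single cosine period is band-limited and periodic, which matches the periodic-boundary assumption in the notes.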
extract_centered_box(arr, box_shape_tuple, center)
Wrap-around extraction of a centered N-D box (JIT/static-friendly).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| arr | jnp.ndarray, shape S = (N1, ..., Nd) | | required |
| box_shape_tuple | Tuple[int, ...] | Static Python tuple of ints (box sizes per axis). | required |
| center | ndarray, shape (d,) | Center index in the same index space as arr. | required |
Returns:
| Type | Description |
|---|---|
| jnp.ndarray, shape `box_shape_tuple` | |
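Wrap-around extraction amounts to modular indexing along every axis. The NumPy sketch below uses a hypothetical name and assumes the box starts size // 2 entries before the centre; the exact centring rule of extract_centered_box is not documented above.

```python
import numpy as np

def extract_wrapped_box(arr, box_shape, center):
    index_vectors = []
    for n, size, c in zip(arr.shape, box_shape, center):
        # Start size // 2 before the centre and wrap modulo the axis length.
        index_vectors.append((int(c) - size // 2 + np.arange(size)) % n)
    return arr[np.ix_(*index_vectors)]

# A box centred at index 0 wraps around to the end of the axis.
print(extract_wrapped_box(np.arange(10), (4,), (0,)))  # [8 9 0 1]
```

Because the box shape is a plain Python tuple, the output shape is known ahead of time, which is what makes this pattern friendly to JIT compilation with static arguments.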
rel_l2(a, b)
Compute relative L2 error between two arrays.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| a | array-like | Reference array. | required |
| b | array-like | Comparison array. | required |
Returns:
| Type | Description |
|---|---|
| float | |
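A common definition of the relative L2 error is ‖a − b‖₂ / ‖a‖₂, normalised by the reference array. The sketch below assumes that definition (the normalisation used by rel_l2 is not spelled out above) and uses a hypothetical function name.

```python
import numpy as np

def rel_l2_sketch(a, b):
    a = np.asarray(a)
    b = np.asarray(b)
    # Flatten so the 2-norm is taken over all entries, whatever the shape.
    return float(np.linalg.norm((a - b).ravel()) / np.linalg.norm(a.ravel()))

# ||(0, 4)|| / ||(3, 4)|| = 4 / 5
print(rel_l2_sketch([3.0, 4.0], [3.0, 0.0]))  # 0.8
```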
batch_data(*args, batch_size, zero_padded_args=())
Batch leading beam axis into (num_batches, batch_size, ...).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| *args | Tuple[ndarray, ...] | Arrays sharing the leading beam dimension. | () |
| batch_size | int | | required |
| zero_padded_args | Tuple[int, ...] | Indices into args of the arrays that are zero-padded. | () |
Returns:
| Type | Description |
|---|---|
| Tuple[ndarray, ...] | Batched arrays (with padding if needed). |
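Batching a leading axis into (num_batches, batch_size, ...) requires padding when the length is not a multiple of batch_size. The single-array NumPy sketch below uses a hypothetical name; padding with zeros mirrors zero_padded_args, while repeating the last entry for the other arrays is purely an assumption, since the fill rule for non-zero-padded arguments is not documented above.

```python
import numpy as np

def batch_leading_axis(x, batch_size, zero_pad=False):
    x = np.asarray(x)
    n = x.shape[0]
    num_batches = -(-n // batch_size)  # ceiling division
    missing = num_batches * batch_size - n
    if missing:
        if zero_pad:
            filler = np.zeros((missing,) + x.shape[1:], dtype=x.dtype)
        else:
            # Assumed fill rule: repeat the last entry.
            filler = np.repeat(x[-1:], missing, axis=0)
        x = np.concatenate([x, filler], axis=0)
    return x.reshape((num_batches, batch_size) + x.shape[1:])

print(batch_leading_axis(np.arange(5), 2, zero_pad=True).tolist())
# [[0, 1], [2, 3], [4, 0]]
```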
find_level(dyadic_decomp: DyadicDecomposition, box_num: int) -> Int[Array, '']
Map global box index → dyadic level.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| dyadic_decomp | DyadicDecomposition | | required |
| box_num | int | | required |
Returns:
| Type | Description |
|---|---|
| ndarray | Scalar (0-D) array with the level. |
find_tensor_and_multiindex(flat_indices: jnp.ndarray, shapes: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]
Decode flat indices over concatenated tensors.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| flat_indices | ndarray, shape (m,) | | required |
| shapes | ndarray, shape (L, k) | Shapes of each tensor (per level). | required |
Returns:
| Name | Type | Description |
|---|---|---|
| array_indices | ndarray, shape (m,) | Which tensor each flat index belongs to. |
| multidimensional_indices | ndarray, shape (m, k) | Unravelled indices within that tensor. |
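Decoding a flat index over concatenated tensors takes two steps: locate the tensor through cumulative sizes, then unravel the local offset. The NumPy sketch below uses a hypothetical name and assumes row-major (C-order) flattening, which is not confirmed by the description above.

```python
import numpy as np

def decode_flat_indices(flat_indices, shapes):
    flat_indices = np.asarray(flat_indices)
    shapes = np.asarray(shapes)
    sizes = shapes.prod(axis=1)                     # elements per tensor
    offsets = np.concatenate([[0], np.cumsum(sizes)])
    # Tensor t owns the flat range [offsets[t], offsets[t + 1]).
    which = np.searchsorted(offsets, flat_indices, side="right") - 1
    local = flat_indices - offsets[which]
    multi = np.empty((flat_indices.shape[0], shapes.shape[1]), dtype=int)
    for row, (t, idx) in enumerate(zip(which, local)):
        multi[row] = np.unravel_index(idx, tuple(shapes[t]))
    return which, multi

# Two tensors of shapes (2, 3) and (4, 4): flat indices 0..5 and 6..21.
which, multi = decode_flat_indices([0, 5, 6, 21], [[2, 3], [4, 4]])
print(which.tolist())  # [0, 0, 1, 1]
print(multi.tolist())  # [[0, 0], [1, 2], [0, 0], [3, 3]]
```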
compute_coeff_shapes(dyadic_decomp: DyadicDecomposition, redundancy: int, level) -> jnp.ndarray
Per-level coefficient tensor shapes.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| dyadic_decomp | DyadicDecomposition | | required |
| redundancy | int | 1 (basis) or 2 (frame). | required |
| level | int | Level index; the function is vectorized over this argument. | required |
Returns:
| Type | Description |
|---|---|
| ndarray, shape (k,) | |