beamax.solvers.hybrid_solver_utils

Supporting utilities for hybrid solver preprocessing and postprocessing.

Scope

  • Frequency-space partition helpers for LF/HF splitting.
  • Resampling/interpolation and windowing utilities used by HybridSolver.

API Reference

beamax.solvers.hybrid_solver_utils

gh_lowpass_filter(p0: jnp.ndarray, input_type: str, wpt: MSWPT, boxes_include: jnp.ndarray, windowing: str = 'rectangular', gh: Optional[jnp.ndarray] = None) -> Tuple[jnp.ndarray, jnp.ndarray]

Split p0 into low-frequency (LF) and high-frequency (HF) components using the g, h frame filters in the Fourier domain.

Parameters:

Name Type Description Default
p0 ndarray

Input field.

required
input_type (spatial, fourier)

Domain of p0: "spatial" or "fourier".

required
wpt MSWPT

Wave-packet transform defining the dyadic boxes.

required
boxes_include ndarray

Indices of low-frequency boxes to include.

required
windowing str

Windowing type passed to the filter construction.

'rectangular'
gh ndarray

Precomputed LF-projection filter from compute_gh_filter. Pass this when the same (wpt, boxes_include, windowing) is reused across many inputs, to skip the data-independent filter computation.

None

Returns:

Type Description
(p0_HF_ft, p0_LF_ft) : Tuple[jnp.ndarray, jnp.ndarray]

Fourier-domain HF and LF components.
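
As a rough illustration of what the split amounts to, the sketch below applies a precomputed LF-projection filter to the spectrum of a 1-D signal. It assumes the HF part is simply the remainder of the spectrum, which the reference above does not state explicitly. `split_with_gh` and the hard cutoff filter are hypothetical stand-ins written in NumPy, not beamax functions.

```python
import numpy as np

def split_with_gh(p0, gh):
    """Return (HF, LF) spectra of a spatial-domain signal p0.

    Assumption: LF = gh * FFT(p0) and HF = FFT(p0) - LF.
    """
    p0_ft = np.fft.fftn(p0)
    p0_lf_ft = gh * p0_ft
    p0_hf_ft = p0_ft - p0_lf_ft
    return p0_hf_ft, p0_lf_ft

# Toy data: two sinusoids, at wavenumbers 2 and 10.
n = 32
x = np.arange(n)
p0 = np.sin(2 * np.pi * 2 * x / n) + 0.5 * np.sin(2 * np.pi * 10 * x / n)

# Hypothetical filter: keep every bin with |k| <= 4.
k = np.fft.fftfreq(n, d=1.0 / n)  # integer wavenumbers
gh = (np.abs(k) <= 4).astype(float)

p0_hf_ft, p0_lf_ft = split_with_gh(p0, gh)
p0_lf = np.fft.ifftn(p0_lf_ft).real  # only the wavenumber-2 sinusoid survives
```

Under this assumption the two parts always sum back to the full spectrum, so no energy is lost by the split.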

compute_gh_filter(wpt: MSWPT, boxes_include: jnp.ndarray, windowing: str = 'rectangular') -> jnp.ndarray

Compute the LF-projection filter gh = (Σ_{b∈LF} g_b^2) / Σ_b g_b^2.

This is the data-independent piece of gh_lowpass_filter and is therefore cacheable across calls that share (wpt, boxes_include, windowing).

Implementation

Uses lax.fori_loop so that only one (*N,)-shape filter is materialised at a time, mirroring the pattern used by MSWPT.sum_gsquare. This avoids both the per-iteration eager-tracing overhead of a Python for loop and the (num_boxes, *N) peak memory of a pure vmap.
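
The formula and the one-filter-at-a-time accumulation can be sketched as follows. This is a NumPy illustration with a plain Python loop for readability, not the lax.fori_loop implementation described above. The Gaussian `box_filter` and the function name `compute_gh_sketch` are hypothetical; the real per-box filters come from the MSWPT object.

```python
import numpy as np

def box_filter(k, centre, width):
    """Hypothetical per-box filter g_b: a Gaussian bump around `centre`."""
    return np.exp(-0.5 * ((k - centre) / width) ** 2)

def compute_gh_sketch(k, centres, width, boxes_include):
    """gh = (sum over LF boxes of g_b^2) / (sum over all boxes of g_b^2)."""
    lf = {int(b) for b in boxes_include}
    num = np.zeros_like(k, dtype=float)
    den = np.zeros_like(k, dtype=float)
    for b, centre in enumerate(centres):
        g2 = box_filter(k, centre, width) ** 2  # only one filter alive at a time
        den += g2
        if b in lf:
            num += g2
    return num / den

k = np.linspace(-8.0, 8.0, 65)
centres = np.array([-6.0, -2.0, 2.0, 6.0])
gh = compute_gh_sketch(k, centres, width=2.0, boxes_include=[1, 2])
gh_all = compute_gh_sketch(k, centres, width=2.0, boxes_include=[0, 1, 2, 3])
```

Because the result depends only on the boxes and not on the data, it can be computed once and passed as the gh argument on later calls.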

Parameters:

Name Type Description Default
wpt MSWPT

Wave-packet transform defining the dyadic boxes.

required
boxes_include jnp.ndarray of int32

LF box indices.

required
windowing str

Windowing type passed to the filter construction.

'rectangular'

Returns:

Type Description
jnp.ndarray, shape (*N,), real dtype

LF-projection filter gh.

get_indices_between_two_opposing_corners(centers: jnp.ndarray, corner1_idx: int, corner2_idx: int) -> jnp.ndarray

Get the indices in the box defined by the two corners.

Parameters:

Name Type Description Default
centers (ndarray, shape(num_centers, ndim))

Box centre coordinates.

required
corner1_idx int

Index of the first corner.

required
corner2_idx int

Index of the opposing corner.

required

Returns:

Type Description
ndarray

Indices of centres inside the closed axis-aligned box.

Raises:

Type Description
ValueError

If both corner indices refer to the same centre.
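
A minimal sketch of the closed axis-aligned box test, assuming the box spans the component-wise minimum and maximum of the two corner centres. It uses NumPy, and `indices_between_corners` is a hypothetical name rather than the library function.

```python
import numpy as np

def indices_between_corners(centers, corner1_idx, corner2_idx):
    """Indices of centres inside the closed box spanned by two corner centres."""
    if corner1_idx == corner2_idx:
        raise ValueError("corner indices must refer to different centres")
    lo = np.minimum(centers[corner1_idx], centers[corner2_idx])
    hi = np.maximum(centers[corner1_idx], centers[corner2_idx])
    inside = np.all((centers >= lo) & (centers <= hi), axis=1)
    return np.flatnonzero(inside)

# 3x3 grid of centres at integer coordinates in [-1, 1]^2, row-major order.
centers = np.array([(i, j) for i in (-1, 0, 1) for j in (-1, 0, 1)], dtype=float)
idx = indices_between_corners(centers, 4, 8)  # corners (0, 0) and (1, 1)
```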

get_indices_with_norm_less_than(centers: jnp.ndarray, norm: float, inclusive: bool = True) -> jnp.ndarray

Get the indices of the boxes whose centre has an L-infinity norm less than (or equal to) the given value.

Parameters:

Name Type Description Default
centers (ndarray, shape(num_centers, ndim))

Box centre coordinates.

required
norm float

L-infinity norm threshold.

required
inclusive bool

If True, use <=; otherwise use <.

True

Returns:

Type Description
ndarray

Indices whose centre norm satisfies the threshold.
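
The threshold test can be sketched in NumPy as below. The L-infinity norm of a centre is its largest absolute coordinate. `indices_with_norm_less_than` is a hypothetical stand-in for the documented function.

```python
import numpy as np

def indices_with_norm_less_than(centers, norm, inclusive=True):
    """Indices of centres whose L-infinity norm is below the threshold."""
    linf = np.max(np.abs(centers), axis=1)
    mask = linf <= norm if inclusive else linf < norm
    return np.flatnonzero(mask)

centers = np.array([[0.0, 0.0], [1.0, -2.0], [3.0, 1.0], [-2.0, 2.0]])
closed = indices_with_norm_less_than(centers, 2.0)                   # uses <=
open_ = indices_with_norm_less_than(centers, 2.0, inclusive=False)   # uses <
```

The inclusive flag only matters for centres lying exactly on the threshold, as rows 1 and 3 do here.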

find_bounding_corner_indices(centers: jnp.ndarray, idx_box: jnp.ndarray) -> Tuple[int, int]

Find actual corner indices from a set of selected frequency indices.

Given a set of selected centre indices, find two opposing corners that exist in the centers array. Computing the component-wise min/max instead could yield points that do not correspond to actual centres.

Parameters:

Name Type Description Default
centers (ndarray, shape(num_centers, ndim))

All centre coordinates.

required
idx_box ndarray

Indices of selected centres.

required

Returns:

Name Type Description
corner1_idx int

Index of one selected bounding corner.

corner2_idx int

Index of the opposing selected bounding corner.

Raises:

Type Description
ValueError

If idx_box is empty.

are_opposing(corner1: int, corner2: int) -> bool

Check whether two corner indices are opposing, i.e. whether they differ.

Parameters:

Name Type Description Default
corner1 int

First corner index.

required
corner2 int

Second corner index.

required

Returns:

Type Description
bool

Whether the corner indices differ.

get_bounds(dyadic_decomp: DyadicDecomposition, domain: Domain, corner1: int, corner2: int) -> jnp.ndarray

Get the bounds of the filter banks required, using the opposite corners of the box.

Parameters:

Name Type Description Default
dyadic_decomp DyadicDecomposition

Dyadic decomposition.

required
domain Domain

Physical domain.

required
corner1 int

First corner index.

required
corner2 int

Opposing corner index.

required

Returns:

Type Description
(ndarray, shape(ndim, 2))

Inclusive/exclusive bounds in grid coordinates.

closest_power_of_two(size: int, max_size: int) -> int

Return the smallest power of two greater than or equal to the given size, capped at max_size.

Parameters:

Name Type Description Default
size int

Input size.

required
max_size int

Upper bound for the returned size.

required

Returns:

Type Description
int

Smallest power of two greater than or equal to size, capped at max_size.
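
One way to compute this rounding in pure Python is sketched below. The function name and the guard against non-positive sizes are assumptions; the reference does not say how such inputs are handled.

```python
def closest_power_of_two_sketch(size: int, max_size: int) -> int:
    """Smallest power of two >= size, capped at max_size."""
    if size < 1:
        # Assumption: non-positive sizes are rejected.
        raise ValueError("size must be positive")
    power = 1 << (size - 1).bit_length()
    return min(power, max_size)

results = [closest_power_of_two_sketch(s, 64) for s in (1, 5, 8, 100)]
```

Note that the cap is applied last, so the result is not itself a power of two when max_size is not one.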

downsample_p0(p0_LF: jnp.ndarray, bd: jnp.ndarray, use_power_of_two: bool = False) -> jnp.ndarray

Downsample the low-pass-filtered p0 by cropping it to the selected low-frequency support.

Parameters:

Name Type Description Default
p0_LF ndarray

Low-pass filtered field.

required
bd (ndarray, shape(ndim, 2))

Bounds of the selected low-frequency support.

required
use_power_of_two bool

Whether to round the crop size up to a power of two.

False

Returns:

Type Description
ndarray

Centred crop of p0_LF.
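
The reference does not spell out how bd maps onto the crop, so the following NumPy sketch rests on two assumptions. First, each row of bd is a (lo, hi) pair with hi exclusive, giving a crop size of hi - lo along that axis. Second, the power-of-two rounding is capped at the array size. `centred_crop_sketch` is a hypothetical name.

```python
import numpy as np

def centred_crop_sketch(arr, bd, use_power_of_two=False):
    """Crop `arr` around its centre to the extent implied by `bd`."""
    slices = []
    for n, (lo, hi) in zip(arr.shape, bd):
        size = int(hi - lo)
        if use_power_of_two:
            # Round up to a power of two, never exceeding the axis length.
            size = min(1 << (size - 1).bit_length(), n)
        start = (n - size) // 2
        slices.append(slice(start, start + size))
    return arr[tuple(slices)]

field = np.arange(16 * 16, dtype=float).reshape(16, 16)
bd = np.array([[5, 11], [5, 11]])
cropped = centred_crop_sketch(field, bd)
cropped_pow2 = centred_crop_sketch(field, bd, use_power_of_two=True)
```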

downsample_domain(domain: Domain, p0_LF_downsampled: jnp.ndarray) -> Domain

Downsample the domain to match the downsampled p0.

Parameters:

Name Type Description Default
domain Domain

Original domain.

required
p0_LF_downsampled ndarray

Downsampled low-frequency field.

required

Returns:

Type Description
Domain

Domain with shape and spacing adjusted to p0_LF_downsampled.

Notes

The domain is downsampled to match the downsampled field.

split_frequency_components(p0: jnp.ndarray, sensors_mask: jnp.ndarray, input_type: str, output_type: str, wpt: MSWPT, box_corners: Optional[jnp.ndarray], windowing: str, domain: Domain, cutoff_freq: Optional[float] = None, downsample: bool = False, use_pow2: bool = False, gh: Optional[jnp.ndarray] = None) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, Domain]

Split input into high- and low-frequency components.

Parameters:

Name Type Description Default
p0 ndarray

Input field.

required
sensors_mask ndarray

Sensor mask aligned with p0.

required
input_type (spatial, fourier)

Domain of p0: "spatial" or "fourier".

required
output_type (spatial, fourier)

Domain for returned components: "spatial" or "fourier".

required
wpt MSWPT

Wave-packet transform defining the dyadic boxes.

required
box_corners ndarray

Pair of box indices defining the low-frequency region.

required
windowing str

Windowing type passed to the filter construction.

required
domain Domain

Physical domain.

required
cutoff_freq float

Frequency-radius alternative to box_corners.

None
downsample bool

Whether to downsample the low-frequency part.

False
use_pow2 bool

Whether downsampled sizes should be powers of two.

False
gh ndarray

Precomputed low-pass filter.

None

Returns:

Name Type Description
p0_HF ndarray

High-frequency component in output_type space.

p0_LF ndarray

Low-frequency component in output_type space.

sensors_mask_ds ndarray

Possibly-downsampled sensors mask (same shape as p0_LF/p0_HF).

dom_downsample Domain

Possibly-downsampled domain matching p0_LF.

Notes

If the low-frequency index set is empty (no bins fall inside the requested box or cutoff), the function:

  • returns p0_LF = 0 (in output_type),
  • returns p0_HF = p0 (in output_type),
  • leaves sensors_mask and domain unchanged,
  • skips downsampling entirely.

This avoids shape, interpolation, and domain-consistency errors.

oversample_window(array: jnp.ndarray, dt_oversample: int = 0, axis: int = 0, window_type: str = 'cos2') -> jnp.ndarray

Apply a windowing function to the array and oversample it in the temporal domain.

Parameters:

Name Type Description Default
array ndarray

Input array to be windowed.

required
dt_oversample int

Number of points to oversample.

0
axis int

Axis along which to apply the window.

0
window_type str

Type of window to apply. Options: 'cos2', 'hann', 'hamming', 'blackman'

'cos2'

Returns:

Type Description
ndarray

Windowed array.

interpolate_LF_soln(lf_downsampled: jnp.ndarray, target_size: Tuple, interpolation_method: str = 'spline', interp_window: str = 'cos2', dt_oversample: int = 0, spline_order: int = 3) -> jnp.ndarray

Interpolate a downsampled solution from an LF wave solver to match the desired size.

Parameters:

Name Type Description Default
lf_downsampled ndarray

Downsampled low-frequency solver output.

required
target_size Tuple[int, ...]

Desired output shape.

required
interpolation_method (spline, fourier)

Interpolation method.

"spline"
interp_window (cos2, hann, hamming, blackman)

Temporal taper to apply before interpolation.

"cos2"
dt_oversample int

Number of oversampled time steps in the taper region.

0
spline_order int

Spline order for scipy.ndimage.zoom.

3

Returns:

Type Description
ndarray

Interpolated low-frequency solution.
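
For the "fourier" option, the underlying idea is zero-padding the spectrum. The 1-D NumPy sketch below illustrates that idea only. It is not the library's implementation, `fourier_upsample_1d` is a hypothetical name, and the temporal taper and spline path are left out.

```python
import numpy as np

def fourier_upsample_1d(x, target_len):
    """Upsample a real, periodic 1-D signal by zero-padding its spectrum.

    Only meaningful for target_len >= len(x); a shorter target would crop
    the spectrum instead of padding it.
    """
    n = x.shape[0]
    spectrum = np.fft.rfft(x)
    # irfft pads the spectrum with zeros up to target_len // 2 + 1 bins;
    # the factor target_len / n restores the original amplitude.
    return np.fft.irfft(spectrum, n=target_len) * (target_len / n)

n, m = 16, 64
coarse = np.cos(2 * np.pi * 2 * np.arange(n) / n)
fine = fourier_upsample_1d(coarse, m)
expected = np.cos(2 * np.pi * 2 * np.arange(m) / m)
```

For a band-limited periodic signal like this one the interpolation is exact up to rounding, which is why a smooth LF solution is a good fit for this approach.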