beamax.solvers.hybrid_solver_utils
Supporting utilities for hybrid solver preprocessing and postprocessing.
Scope
- Frequency-space partition helpers for low-/high-frequency (LF/HF) splitting.
- Resampling/interpolation and windowing utilities used by `HybridSolver`.
API Reference
gh_lowpass_filter(p0: jnp.ndarray, input_type: str, wpt: MSWPT, boxes_include: jnp.ndarray, windowing: str = 'rectangular', gh: Optional[jnp.ndarray] = None) -> Tuple[jnp.ndarray, jnp.ndarray]
Split `p0` into high-frequency (HF) and low-frequency (LF) components in the Fourier domain using the g/h frame filters.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| p0 | ndarray | Input field. | required |
| input_type | {"spatial", "fourier"} | Whether `p0` is given in the spatial or the Fourier domain. | required |
| wpt | MSWPT | Wave-packet transform defining the dyadic boxes. | required |
| boxes_include | ndarray | Indices of low-frequency boxes to include. | required |
| windowing | str | Windowing type passed to the filter construction. | 'rectangular' |
| gh | Optional[ndarray] | Precomputed LF-projection filter from `compute_gh_filter`. | None |
Returns:
| Type | Description |
|---|---|
| Tuple[jnp.ndarray, jnp.ndarray] | `(p0_HF_ft, p0_LF_ft)`: Fourier-domain HF and LF components. |
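The split can be illustrated with a small NumPy sketch. It assumes, without confirmation from the text above, that the LF part is the input spectrum multiplied pointwise by the LF-projection filter `gh`, and that the HF part is the remainder, so the two parts sum back to the input spectrum. The function `split_with_gh` and the hard toy mask are hypothetical. The real function works on `jnp` arrays and builds `gh` from the MSWPT boxes.

```python
import numpy as np

def split_with_gh(p0, gh, input_type="spatial"):
    """Hypothetical sketch: return (p0_HF_ft, p0_LF_ft) given an LF-projection filter.

    Assumes gh has the same shape as p0 and uses np.fft frequency ordering.
    """
    if input_type not in ("spatial", "fourier"):
        raise ValueError(f"unknown input_type: {input_type!r}")
    # Move to the Fourier domain if the input is spatial.
    p0_ft = np.fft.fftn(p0) if input_type == "spatial" else p0
    # LF component: spectrum weighted by the LF-projection filter.
    p0_LF_ft = gh * p0_ft
    # HF component: whatever the LF projection leaves behind.
    p0_HF_ft = p0_ft - p0_LF_ft
    return p0_HF_ft, p0_LF_ft

# Toy example: keep |k| <= 2 on a 16-point grid with a hard (rectangular) mask.
n = 16
k = np.fft.fftfreq(n, d=1.0 / n)  # integer wavenumbers in FFT order
gh = (np.abs(k) <= 2).astype(float)
t = np.arange(n)
p0 = np.cos(2 * np.pi * 1 * t / n) + np.cos(2 * np.pi * 5 * t / n)
hf_ft, lf_ft = split_with_gh(p0, gh)
lf = np.fft.ifftn(lf_ft).real  # recovers the k = 1 cosine
```

With a smooth `gh` (non-rectangular windowing), the same two lines give a soft split rather than a hard cut.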
compute_gh_filter(wpt: MSWPT, boxes_include: jnp.ndarray, windowing: str = 'rectangular') -> jnp.ndarray
Compute the LF-projection filter gh = (Σ_{b∈LF} g_b^2) / Σ_b g_b^2.
This is the data-independent part of `gh_lowpass_filter`, so it can be cached across calls that share (`wpt`, `boxes_include`, `windowing`).
Implementation
Uses `lax.fori_loop` so that only one filter of shape `(*N,)` is materialised at a time, mirroring the pattern used by `MSWPT.sum_gsquare`. This avoids both the per-iteration eager-tracing overhead of a Python `for` loop and the `(num_boxes, *N)` peak memory of a pure `vmap`.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| wpt | MSWPT | Wave-packet transform defining the dyadic boxes. | required |
| boxes_include | jnp.ndarray of int32 | LF box indices. | required |
| windowing | str | Windowing type passed to the filter construction. | 'rectangular' |
Returns:
| Type | Description |
|---|---|
| jnp.ndarray, shape (*N,), real dtype | LF-projection filter `gh`. |
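The formula and the one-filter-at-a-time accumulation can be illustrated with a plain NumPy loop. This is a sketch under stated assumptions. `box_filter(b)` is a hypothetical callable that returns g_b on the full grid. A Python loop stands in for `lax.fori_loop`, so the sketch shows the memory pattern (only one g_b alive at a time) but not the tracing benefit.

```python
import numpy as np

def compute_gh_sketch(box_filter, num_boxes, boxes_include, shape):
    """Hypothetical sketch of gh = sum_{b in LF} g_b^2 / sum_b g_b^2."""
    lf = {int(b) for b in boxes_include}
    num = np.zeros(shape)  # running sum of g_b^2 over LF boxes only
    den = np.zeros(shape)  # running sum of g_b^2 over all boxes
    for b in range(num_boxes):
        g2 = box_filter(b) ** 2  # only this box's filter is materialised
        den += g2
        if b in lf:
            num += g2
    # Leave gh = 0 wherever no box has support, instead of dividing by zero.
    return np.divide(num, den, out=np.zeros(shape), where=den > 0)

# Two overlapping toy boxes on a 3-point grid; box 0 is the only LF box.
g = {0: np.array([1.0, 0.6, 0.0]), 1: np.array([0.0, 0.8, 1.0])}
gh = compute_gh_sketch(lambda b: g[b], num_boxes=2, boxes_include=[0], shape=(3,))
# gh is approximately [1.0, 0.36, 0.0]
```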
get_indices_between_two_opposing_corners(centers: jnp.ndarray, corner1_idx: int, corner2_idx: int) -> jnp.ndarray
Return the indices of all centres that lie inside the closed axis-aligned box spanned by two opposing corner centres.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| centers | ndarray, shape (num_centers, ndim) | Box centre coordinates. | required |
| corner1_idx | int | Index of the first corner. | required |
| corner2_idx | int | Index of the opposing corner. | required |
Returns:
| Type | Description |
|---|---|
| ndarray | Indices of centres inside the closed axis-aligned box. |
Raises:
| Type | Description |
|---|---|
| ValueError | If both corner indices refer to the same centre. |
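A minimal NumPy sketch of the selection described above follows. The box is taken as the component-wise min/max of the two corner centres, closed on both ends. `indices_between_corners` is a hypothetical name, and the `jnp` implementation may differ in details such as floating-point tolerance.

```python
import numpy as np

def indices_between_corners(centers, corner1_idx, corner2_idx):
    """Sketch: indices of centres inside the closed box spanned by two centres."""
    if corner1_idx == corner2_idx:
        raise ValueError("corner indices must refer to different centres")
    c1, c2 = centers[corner1_idx], centers[corner2_idx]
    lo, hi = np.minimum(c1, c2), np.maximum(c1, c2)
    # A centre is inside when every coordinate lies in [lo, hi].
    inside = np.all((centers >= lo) & (centers <= hi), axis=1)
    return np.nonzero(inside)[0]

# 3x3 grid of centres at coordinates (-1, 0, 1); x varies slowest.
xs = np.array([-1.0, 0.0, 1.0])
centers = np.array([(x, y) for x in xs for y in xs])
idx = indices_between_corners(centers, 4, 8)  # box from (0, 0) to (1, 1)
print(idx)  # [4 5 7 8]
```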
get_indices_with_norm_less_than(centers: jnp.ndarray, norm: float, inclusive: bool = True) -> jnp.ndarray
Return the indices of the boxes whose centre has an L-infinity norm less than (or, if `inclusive`, equal to) the given value.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| centers | ndarray, shape (num_centers, ndim) | Box centre coordinates. | required |
| norm | float | L-infinity norm threshold. | required |
| inclusive | bool | If True, keep centres whose norm equals the threshold (`<=`); otherwise use a strict `<`. | True |
Returns:
| Type | Description |
|---|---|
| ndarray | Indices whose centre norm satisfies the threshold. |
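The threshold test reduces to a per-row maximum of absolute coordinates. The function below is a hypothetical NumPy sketch of that reading, not the library code.

```python
import numpy as np

def indices_with_linf_norm_below(centers, norm, inclusive=True):
    """Sketch: indices of centres with L-infinity norm <= norm (or < norm)."""
    linf = np.max(np.abs(centers), axis=1)  # per-centre L-infinity norm
    mask = linf <= norm if inclusive else linf < norm
    return np.nonzero(mask)[0]

centers = np.array([[0, 0], [1, 0], [0, -2], [2, 2], [1, -1]])
print(indices_with_linf_norm_below(centers, 1))                   # [0 1 4]
print(indices_with_linf_norm_below(centers, 1, inclusive=False))  # [0]
```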
find_bounding_corner_indices(centers: jnp.ndarray, idx_box: jnp.ndarray) -> Tuple[int, int]
Find two opposing corner indices that bound a set of selected frequency indices.
Given a set of selected centre indices, this returns two opposing corners that actually exist in `centers`, rather than the component-wise min/max, which may not correspond to any centre.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| centers | ndarray, shape (num_centers, ndim) | All centre coordinates. | required |
| idx_box | ndarray | Indices of selected centres. | required |
Returns:
| Name | Type | Description |
|---|---|---|
| corner1_idx | int | Index of one selected bounding corner. |
| corner2_idx | int | Index of the opposing selected bounding corner. |
Raises:
| Type | Description |
|---|---|
| ValueError | If … |
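One way to guarantee that both corners are real centres is to snap the component-wise min and max of the selection to the nearest selected centre. The NumPy sketch below does that. The snapping rule, the function name `bounding_corner_indices` and the error conditions are assumptions, not the documented algorithm.

```python
import numpy as np

def bounding_corner_indices(centers, idx_box):
    """Hypothetical sketch: snap the component-wise min/max of the selection
    to the nearest selected centres so both corners are real entries."""
    idx_box = np.asarray(idx_box)
    if idx_box.size == 0:
        raise ValueError("idx_box is empty")
    selected = centers[idx_box]
    lo = selected.min(axis=0)  # virtual lower corner (may not be a centre)
    hi = selected.max(axis=0)  # virtual upper corner (may not be a centre)
    i1 = idx_box[np.argmin(np.linalg.norm(selected - lo, axis=1))]
    i2 = idx_box[np.argmin(np.linalg.norm(selected - hi, axis=1))]
    if i1 == i2:
        raise ValueError("selection does not span two distinct corners")
    return int(i1), int(i2)

xs = np.array([-1.0, 0.0, 1.0])
centers = np.array([(x, y) for x in xs for y in xs])  # 3x3 grid, x varies slowest
print(bounding_corner_indices(centers, [4, 5, 7, 8]))  # (4, 8)
```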
are_opposing(corner1: int, corner2: int) -> bool
Check whether two corners are opposing.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| corner1 | int | First corner index. | required |
| corner2 | int | Second corner index. | required |
Returns:
| Type | Description |
|---|---|
| bool | Whether the corner indices differ. |
get_bounds(dyadic_decomp: DyadicDecomposition, domain: Domain, corner1: int, corner2: int) -> jnp.ndarray
Compute the bounds of the required filter banks from two opposing corners of the box.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| dyadic_decomp | DyadicDecomposition | Dyadic decomposition. | required |
| domain | Domain | Physical domain. | required |
| corner1 | int | First corner index. | required |
| corner2 | int | Opposing corner index. | required |
Returns:
| Type | Description |
|---|---|
| ndarray, shape (ndim, 2) | Per-axis bounds in grid coordinates (lower inclusive, upper exclusive). |
closest_power_of_two(size: int, max_size: int) -> int
Round `size` up to the next power of two (the returned size is bounded by `max_size`).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| size | int | Input size. | required |
| max_size | int | Upper bound for the returned size. | required |
Returns:
| Type | Description |
|---|---|
| int | Smallest power of two greater than or equal to `size`. |
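The rounding can be sketched in pure Python with `int.bit_length`. The function name is hypothetical. Clipping the result to `max_size` with `min` is an assumed reading of "upper bound for the returned size".

```python
def next_power_of_two_capped(size: int, max_size: int) -> int:
    """Hypothetical sketch: smallest power of two >= size, clipped to max_size."""
    if size < 1:
        raise ValueError("size must be positive")
    # (size - 1).bit_length() is the exponent of the next power of two >= size.
    p = 1 << (size - 1).bit_length()
    return min(p, max_size)  # assumed clipping rule

print(next_power_of_two_capped(5, 1024))   # 8
print(next_power_of_two_capped(600, 512))  # 512
```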
downsample_p0(p0_LF: jnp.ndarray, bd: jnp.ndarray, use_power_of_two: bool = False) -> jnp.ndarray
Downsample `p0` after the low-pass filter has been applied to it.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| p0_LF | ndarray | Low-pass filtered field. | required |
| bd | ndarray, shape (ndim, 2) | Bounds of the selected low-frequency support. | required |
| use_power_of_two | bool | Whether to round the crop size up to a power of two. | False |
Returns:
| Type | Description |
|---|---|
| ndarray | Centred crop of `p0_LF`. |
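A common way to downsample a field that has already been low-pass filtered is to keep the centred block of its `fftshift`-ed spectrum and transform back. The text above does not say whether `downsample_p0` crops the spectrum or the spatial array. Treat this NumPy sketch as an illustration of the technique only. The function name `spectral_crop_downsample` and the explicit `new_shape` argument (in place of `bd`) are hypothetical.

```python
import numpy as np

def spectral_crop_downsample(field, new_shape):
    """Hypothetical sketch: downsample a band-limited real field by keeping the
    centred block of its fftshift-ed spectrum and transforming back."""
    spec = np.fft.fftshift(np.fft.fftn(field))
    slices = []
    for n_old, n_new in zip(field.shape, new_shape):
        start = n_old // 2 - n_new // 2  # aligns the zero-frequency bins
        slices.append(slice(start, start + n_new))
    cropped = np.fft.ifftshift(spec[tuple(slices)])
    # Rescale so amplitudes survive the change in FFT length.
    scale = np.prod(new_shape) / np.prod(field.shape)
    return (np.fft.ifftn(cropped) * scale).real  # real input assumed

t = np.arange(16)
coarse = spectral_crop_downsample(np.cos(2 * np.pi * t / 16), (8,))
# coarse matches np.cos(2 * np.pi * np.arange(8) / 8)
```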
downsample_domain(domain: Domain, p0_LF_downsampled: jnp.ndarray) -> Domain
Downsample the domain to match a downsampled `p0`.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| domain | Domain | Original domain. | required |
| p0_LF_downsampled | ndarray | Downsampled low-frequency field. | required |
Returns:
| Type | Description |
|---|---|
| Domain | Domain with shape and spacing adjusted to match `p0_LF_downsampled`. |
Notes
The domain is downsampled to match the downsampled field.
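The shape-and-spacing adjustment can be sketched with a stand-in for `Domain`. This sketch assumes the physical extent `N * dx` is preserved, so a coarser grid gets a proportionally larger spacing. `GridSketch` and `downsample_grid` are hypothetical and are not the library's `Domain` API.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class GridSketch:
    """Hypothetical stand-in for a Domain: grid shape N and spacing dx per axis."""
    N: tuple
    dx: tuple

def downsample_grid(grid, new_shape):
    # Keep the physical extent N * dx fixed (assumption), so spacing grows.
    new_dx = tuple(d * n_old / n_new for d, n_old, n_new in zip(grid.dx, grid.N, new_shape))
    return GridSketch(N=tuple(new_shape), dx=new_dx)

fine_grid = GridSketch(N=(256, 128), dx=(1e-4, 1e-4))
coarse_grid = downsample_grid(fine_grid, (64, 32))
print(coarse_grid.N)  # (64, 32)
```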
split_frequency_components(p0: jnp.ndarray, sensors_mask: jnp.ndarray, input_type: str, output_type: str, wpt: MSWPT, box_corners: Optional[jnp.ndarray], windowing: str, domain: Domain, cutoff_freq: Optional[float] = None, downsample: bool = False, use_pow2: bool = False, gh: Optional[jnp.ndarray] = None) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, Domain]
Split the input field into high- and low-frequency components.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| p0 | ndarray | Input field. | required |
| sensors_mask | ndarray | Sensor mask aligned with `p0`. | required |
| input_type | {"spatial", "fourier"} | Domain of `p0`. | required |
| output_type | {"spatial", "fourier"} | Domain for the returned components. | required |
| wpt | MSWPT | Wave-packet transform defining the dyadic boxes. | required |
| box_corners | Optional[ndarray] | Pair of box indices defining the low-frequency region. | required |
| windowing | str | Windowing type passed to the filter construction. | required |
| domain | Domain | Physical domain. | required |
| cutoff_freq | Optional[float] | Frequency-radius alternative to `box_corners`. | None |
| downsample | bool | Whether to downsample the low-frequency part. | False |
| use_pow2 | bool | Whether downsampled sizes should be powers of two. | False |
| gh | Optional[ndarray] | Precomputed low-pass filter. | None |
Returns:
| Name | Type | Description |
|---|---|---|
| p0_HF | ndarray | High-frequency component in the `output_type` domain. |
| p0_LF | ndarray | Low-frequency component in the `output_type` domain. |
| sensors_mask_ds | ndarray | Possibly downsampled sensors mask (same shape as `p0_LF`/`p0_HF`). |
| dom_downsample | Domain | Possibly downsampled domain matching `p0_LF`. |
Notes
If the low-frequency index set is empty (no bins fall inside the requested
box / cutoff), we:
- return p0_LF = 0 (in output_type),
- return p0_HF = p0 (in output_type),
- leave sensors_mask and domain unchanged,
- skip downsampling entirely.
This avoids shape, interpolation, and domain-consistency errors.
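The empty-index-set fallback in the Notes amounts to an early return. This NumPy sketch covers only the field split. `lowpass` is a hypothetical callable that returns the LF part of `p0`. The real function also passes `sensors_mask` and `domain` through unchanged on this path.

```python
import numpy as np

def split_or_passthrough(p0, lf_indices, lowpass):
    """Hypothetical sketch of the empty-LF guard; returns (p0_HF, p0_LF)."""
    lf_indices = np.asarray(lf_indices)
    if lf_indices.size == 0:
        # No bins inside the requested box/cutoff: everything is HF,
        # and no downsampling or domain change should follow.
        return p0, np.zeros_like(p0)
    p0_LF = lowpass(p0)
    return p0 - p0_LF, p0_LF

p0 = np.arange(4.0)
hf, lf = split_or_passthrough(p0, [], lambda x: x / 2)  # empty LF set
```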
oversample_window(array: jnp.ndarray, dt_oversample: int = 0, axis: int = 0, window_type: str = 'cos2') -> jnp.ndarray
Apply a windowing function to the array and oversample it in the temporal domain.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| array | ndarray | Input array to be windowed. | required |
| dt_oversample | int | Number of points to oversample. | 0 |
| axis | int | Axis along which to apply the window. | 0 |
| window_type | str | Type of window to apply: 'cos2', 'hann', 'hamming' or 'blackman'. | 'cos2' |
Returns:
| Type | Description |
|---|---|
| ndarray | Windowed array. |
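The following NumPy sketch shows what a 'cos2' taper along one axis can look like. It assumes the taper occupies the last `n_taper` samples and falls from 1 to 0. Where the real function places its taper, and how it relates to `dt_oversample`, is not stated above, and the oversampling step is not sketched. `cos2_end_taper` is hypothetical. NumPy's `np.hanning`, `np.hamming` and `np.blackman` supply the other window shapes if needed.

```python
import numpy as np

def cos2_end_taper(array, n_taper, axis=0):
    """Hypothetical sketch: multiply the last n_taper samples along `axis`
    by a cos^2 ramp from 1 down to 0; earlier samples are unchanged."""
    array = np.asarray(array, dtype=float)
    n = array.shape[axis]
    w = np.ones(n)
    if n_taper > 1:
        k = np.arange(n_taper)
        w[n - n_taper:] = np.cos(0.5 * np.pi * k / (n_taper - 1)) ** 2
    # Reshape the 1-D window so it broadcasts along the chosen axis.
    shape = [1] * array.ndim
    shape[axis] = n
    return array * w.reshape(shape)

out = cos2_end_taper(np.ones((10, 3)), n_taper=4, axis=0)
# rows 0-6 stay 1; rows 7, 8, 9 become 0.75, 0.25, ~0
```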
interpolate_LF_soln(lf_downsampled: jnp.ndarray, target_size: Tuple, interpolation_method: str = 'spline', interp_window: str = 'cos2', dt_oversample: int = 0, spline_order: int = 3) -> jnp.ndarray
Interpolate a downsampled solution from a low-frequency (LF) wave solver to the desired size.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| lf_downsampled | ndarray | Downsampled low-frequency solver output. | required |
| target_size | Tuple[int, ...] | Desired output shape. | required |
| interpolation_method | {"spline", "fourier"} | Interpolation method. | 'spline' |
| interp_window | {"cos2", "hann", "hamming", "blackman"} | Temporal taper to apply before interpolation. | 'cos2' |
| dt_oversample | int | Number of oversampled time steps in the taper region. | 0 |
| spline_order | int | Spline order used when `interpolation_method` is 'spline'. | 3 |
Returns:
| Type | Description |
|---|---|
| ndarray | Interpolated low-frequency solution. |
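For the 'fourier' option, a standard technique is trigonometric interpolation: zero-pad the centred spectrum up to the target shape and transform back. The NumPy sketch below is a generic illustration, not the module's implementation. `fourier_upsample` is a hypothetical name. The sketch does not split the Nyquist bin of even-length inputs, so it is exact only when that bin is empty. A spline route in plain SciPy would typically use `scipy.ndimage.zoom(..., order=spline_order)`.

```python
import numpy as np

def fourier_upsample(x, target_shape):
    """Hypothetical sketch: trigonometric interpolation by zero-padding the
    centred spectrum of a real array up to target_shape."""
    spec = np.fft.fftshift(np.fft.fftn(x))
    padded = np.zeros(target_shape, dtype=complex)
    slices = []
    for n_small, n_big in zip(x.shape, target_shape):
        start = n_big // 2 - n_small // 2  # keep zero frequency aligned
        slices.append(slice(start, start + n_small))
    padded[tuple(slices)] = spec
    # Rescale so amplitudes survive the change in FFT length.
    scale = np.prod(target_shape) / np.prod(x.shape)
    return (np.fft.ifftn(np.fft.ifftshift(padded)) * scale).real

coarse = np.cos(2 * np.pi * np.arange(8) / 8)
fine = fourier_upsample(coarse, (16,))
# fine matches np.cos(2 * np.pi * np.arange(16) / 16)
```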