beamax.transforms

Multiscale wave packet transform (MSWPT) and frame construction utilities.

Key Objects

  • MSWPT: forward/inverse transforms between spatial/Fourier fields and wave-packet coefficients.
  • compute_frames: construct localized frame atoms in Fourier space.

API Reference

beamax.transforms

MSWPT(dyadic_decomp, redundancy, windowing)

Bases: Module

Multiscale Wave-Packet Transform.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| dyadic_decomp | DyadicDecomposition | Frequency tiling (centres, per-level box lengths). | required |
| redundancy | int | 1 (basis) or 2 (tight frame). Static under JIT. | required |
| windowing | (rectangular, rectangular_mirror, none) | Windowing for tile filters. Static under JIT. | "rectangular" |

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| dyadic_decomp | DyadicDecomposition | |
| redundancy | int | |
| windowing | str | |
| complex_dtype | dtype | complex128 if JAX x64 is enabled, otherwise complex64. |
| sum_gsquare | (ndarray, shape(*N)) | Precomputed Σ_b g_b^2. |
| boxes_cumsum | Tuple[int, ...] | Cumulative number of boxes per level (static). |
| coeff_shapes | Tuple[Tuple[int, ...], ...] | Shape per level: (n_boxes_level, *support_shape). |
| coeffs_cumsum | Tuple[int, ...] | Flat coefficient offsets per level (static). |
| total_coeffs | int | Total number of flat coefficients (static). |
| gfilts_packed | List[ndarray] | Per-level tile filter g, packed to its minimal support box. |
| _support_shapes | List[Tuple[int, ...]] | Per-level support shapes (static). |
| _box_shapes | List[Tuple[int, ...]] | Per-level box lengths along each axis (static). |
| _half_mask | ndarray | Mask selecting the positive-frequency half. |

Notes
  • Constructor performs shape bookkeeping to keep runtime kernels static-JIT friendly.
  • All heavy transforms are pure JAX functions; no side effects.

Build a transform instance with precomputed static metadata.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| dyadic_decomp | DyadicDecomposition | Provides centres, per-level box lengths, and meshgrid. | required |
| redundancy | int | 1 or 2. Controls per-level support sizes and total coeff count. | required |
| windowing | str | Window type for Gaussian tiles. | required |
Notes
  • Precomputes cumulative box/coeff offsets and per-level support shapes.
  • Chooses complex dtype from global JAX precision flag.
  • Packs representative per-level g filters for later slicing/rolls.
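
The bookkeeping described in these notes can be illustrated with a minimal sketch. The per-level shapes below are made up, and the leading zero in the offsets is an assumption; the actual layout of `coeffs_cumsum` inside MSWPT may differ.

```python
import itertools
import math

import jax
import jax.numpy as jnp

# Hypothetical per-level layout: (n_boxes_level, *support_shape).
coeff_shapes = ((1, 8, 8), (4, 8, 8), (12, 16, 16))

# Flat offsets per level: a leading 0 followed by running totals (assumed layout).
sizes = [math.prod(shape) for shape in coeff_shapes]
coeffs_cumsum = (0, *itertools.accumulate(sizes))
total_coeffs = coeffs_cumsum[-1]

# The complex dtype follows the global JAX precision flag.
complex_dtype = jnp.complex128 if jax.config.jax_enable_x64 else jnp.complex64


def unflatten(coeffs):
    """Split a flat coefficient vector into one array per level."""
    bounds = zip(coeffs_cumsum[:-1], coeffs_cumsum[1:], coeff_shapes)
    return [coeffs[start:stop].reshape(shape) for start, stop, shape in bounds]


flat = jnp.arange(total_coeffs).astype(complex_dtype)
levels = unflatten(flat)
```

Because every shape and offset is a plain Python tuple, the slices are static and the same code can be traced under JIT without dynamic shapes.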
forward(data: Num[Array, '*N'], input_type: str) -> Num[Array, ' total_coeffs']

Forward MSWPT.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| data | (ndarray, shape(*N), real or complex) | Input field in spatial or Fourier domain. | required |
| input_type | (spatial, fourier) | Declares the domain of data. | "spatial" |

Returns:

| Type | Description |
| --- | --- |
| (ndarray, shape(total_coeffs), complex) | Flat coefficient vector. |

Notes
  • Converts to Fourier (utils.unitary_fft) if needed.
  • Divides by Σ g^2 to form the canonical tight-frame analysis.
  • JIT-compiled via @eqx.filter_jit(donate="all").
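
To see why dividing by Σ g^2 matters, here is a 1-D toy sketch. It is not the library's implementation: the four Gaussian tiles, their width, and the use of full-length filters (no packing to support boxes) are assumptions made for illustration, and `jnp.fft.fft(..., norm="ortho")` stands in for the unitary FFT.

```python
import jax.numpy as jnp

N = 64
xi = jnp.arange(N)


def periodic_gaussian(centre, sigma):
    # Signed distance to the nearest periodic image of the centre.
    d = (xi - centre + N // 2) % N - N // 2
    return jnp.exp(-0.5 * (d / sigma) ** 2)


# Hypothetical tiling: four overlapping Gaussian tiles, stacked as (B, N).
g = jnp.stack([periodic_gaussian(c, 8.0) for c in (0, 16, 32, 48)])
sum_gsquare = jnp.sum(g**2, axis=0)

x = jnp.cos(2 * jnp.pi * 5 * xi / N) + 0.5 * jnp.sin(2 * jnp.pi * 11 * xi / N)
F = jnp.fft.fft(x, norm="ortho")  # unitary FFT

# Analysis: weight the normalised spectrum by each tile, then return to space.
coeffs = jnp.fft.ifft(g * (F / sum_gsquare), norm="ortho", axis=-1)  # (B, N)

# Synthesis: re-apply each tile and sum; sum_b g_b^2 / sum_gsquare equals 1.
F_rec = jnp.sum(g * jnp.fft.fft(coeffs, norm="ortho", axis=-1), axis=0)
x_rec = jnp.fft.ifft(F_rec, norm="ortho").real
```

With the normalisation in place, `x_rec` matches `x` up to floating-point error even though the tiles overlap and do not individually sum to one.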
inverse(coeffs: Num[Array, ' total_coeffs'], output_type: str) -> Num[Array, '*N']

Fast, exact inverse MSWPT.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| coeffs | (ndarray, shape(total_coeffs), complex) | Flat coefficient vector produced by forward. | required |
| output_type | (spatial, fourier) | Domain of the returned array. | "spatial" |

Returns:

| Type | Description |
| --- | --- |
| (ndarray, shape(*N), complex) | Reconstructed field in the requested domain. |

Notes

The synthesis mirrors the analysis steps in forward:

    forward:   patch = extract(F / Σg², centre);   c = IFFT(roll(g*patch, +r))
    inverse:   tmp = FFT(c);                       add += g * roll(tmp, -r)

where the periodic scatter-add happens at the true box centre.
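
The extract/roll/scatter-add pairing can be sketched for a single 1-D tile as follows. The support size, the roll amount `r`, and the Gaussian `g` are hypothetical choices; the point is only that the two steps compose to multiplication by g² on the tile's support, including when the support wraps around the periodic boundary.

```python
import jax.numpy as jnp

N, support, centre = 32, 8, 29            # this tile wraps past the end of the grid
r = support // 2                          # hypothetical roll amount
offsets = jnp.arange(-r, r)
idx = (centre + offsets) % N              # periodic indices of the support box

F = jnp.arange(N, dtype=jnp.float32) + 1.0    # stand-in for F / Σg²
g = jnp.exp(-0.5 * (offsets / 2.0) ** 2)      # tile filter packed to its support


def analyse(F):
    patch = F[idx]                                        # extract around the centre
    return jnp.fft.ifft(jnp.roll(g * patch, r), norm="ortho")


def synthesise(c):
    tmp = jnp.fft.fft(c, norm="ortho")
    contribution = g * jnp.roll(tmp, -r)
    # Periodic scatter-add back onto the full grid.
    return jnp.zeros(N, dtype=c.dtype).at[idx].add(contribution)


out = synthesise(analyse(F))
```

Summing such contributions over all tiles yields Σ g² times the normalised spectrum, which is why the forward pass divides by Σ g² first.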

convert_to_array(coeffs: Num[Array, ' total_coeffs']) -> Num[Array, '*M']

Reshape flat coefficients into a dense tensor arranged by spatial support.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| coeffs | (ndarray, shape(total_coeffs), complex) | Flat vector returned by forward. | required |

Returns:

| Type | Description |
| --- | --- |
| ndarray | Dense coefficient tensor with per-level boxes unflattened and placed at their centred positions. Shape: (redundancy * N1, redundancy * N2, …), dtype: complex. |

Notes
  • Intended for diagnostics/visualization; not required for forward/inverse.
  • Uses integer centres and per-level support shapes; pure JAX.
compute_windowed_gaussian(centre: Int[Array, ' d'], meshgrid: Int[Array, '*N d'], box_length, box_aspect_ratio: Union[Int[Array, ' d'], Tuple[int, ...]], domain_length: DomainLength, redundancy: int, windowing: str) -> Float[Array, '*N']

Windowed N-D Gaussian in Fourier index space.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| centre | (ndarray, shape(d)) | Tile centre in Fourier index units. | required |
| meshgrid | (ndarray, shape(*N, d)) | Integer-centred Fourier grid from DyadicDecomposition.fourier_meshgrid. | required |
| box_length | int | Smallest-axis tile length for this level. | required |
| box_aspect_ratio | (ndarray, shape(d)) | Per-axis aspect multipliers. Values ≥ 1 with at least one 1. | required |
| domain_length | int | Smallest side length min(N). | required |
| redundancy | int | 1 for basis, 2 for tight frame. | required |
| windowing | (none, rectangular, rectangular_mirror) | Windowing function applied to the Gaussian. | "none" |

Returns:

| Type | Description |
| --- | --- |
| (ndarray, shape(*N)) | Windowed Gaussian weights. dtype = float32/float64. |

Notes

Pure JAX, JIT-safe. Periodisation handled by modulo arithmetic.
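
As a rough illustration of periodisation by modulo arithmetic, the 1-D sketch below wraps index offsets before evaluating the Gaussian. The function name, the width `sigma = box_length / 4`, and the rectangular cut at half the box length are all assumptions for this example and are not taken from compute_windowed_gaussian itself.

```python
import jax.numpy as jnp


def windowed_gaussian_1d(centre, n, box_length, windowed=True):
    """Gaussian bump on a periodic index grid, optionally cut to its box."""
    xi = jnp.arange(n)
    # Modulo arithmetic maps every offset into [-n/2, n/2): the signed
    # distance to the nearest periodic image of the centre.
    d = (xi - centre + n // 2) % n - n // 2
    sigma = box_length / 4.0                  # hypothetical width choice
    gauss = jnp.exp(-0.5 * (d / sigma) ** 2)
    if windowed:
        # Rectangular window: zero everything outside the box.
        gauss = jnp.where(jnp.abs(d) <= box_length // 2, gauss, 0.0)
    return gauss


# A tile centred near the end of the grid spills over to the first indices.
g = windowed_gaussian_1d(centre=62, n=64, box_length=8)
```

Since `windowed` is a plain Python bool, the branch is resolved at trace time and the function stays JIT-compatible when that argument is marked static.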

single_filter_idx(centre_idx, meshgrid: Int[Array, '*N d'], dyadic_decomp: DyadicDecomposition, redundancy: int, windowing: str = 'rectangular') -> Float[Array, '*N']

Filter for a single tile (by global box index).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| centre_idx | int | Global box index in dyadic_decomp.centres_ndim. | required |
| meshgrid | (ndarray, shape(*N, d)) | Fourier meshgrid. | required |
| dyadic_decomp | DyadicDecomposition | Dyadic parameters providing centres and per-level box lengths. | required |
| redundancy | int | 1 (basis) or 2 (frame). | required |
| windowing | str | Windowing function (see compute_windowed_gaussian). | 'rectangular' |

Returns:

| Type | Description |
| --- | --- |
| (ndarray, shape(*N)) | Filter values. |

single_filter_coord(centre: Int[Array, ' d'], level: int, meshgrid: Int[Array, '*N d'], dyadic_decomp: DyadicDecomposition, redundancy: int, windowing: str = 'rectangular') -> Float[Array, '*N']

Filter for a single tile (by (centre, level) pair).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| centre | (ndarray, shape(d)) | Tile centre in Fourier indices. | required |
| level | int | Dyadic level (0..L-1). | required |
| meshgrid | (ndarray, shape(*N, d)) | Fourier meshgrid. | required |
| dyadic_decomp | DyadicDecomposition | Decomposition parameters. | required |
| redundancy | int | 1 or 2. | required |
| windowing | str | Windowing function. | 'rectangular' |

Returns:

| Type | Description |
| --- | --- |
| (ndarray, shape(*N)) | Filter values. |

compute_sum_gsquare(dyadic_decomp: DyadicDecomposition, redundancy: int, windowing: str = 'rectangular') -> Float[Array, '*N']

Sum of squares of all tile filters.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| dyadic_decomp | DyadicDecomposition | Frequency tiling. | required |
| redundancy | int | 1 (basis) or 2 (frame). | required |
| windowing | str | Windowing function. | 'rectangular' |

Returns:

| Type | Description |
| --- | --- |
| (ndarray, shape(*N)) | Σ_b g_b^2 over all boxes. |

Notes

Implemented with lax.fori_loop to avoid large vmaps.
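
The trade-off between `lax.fori_loop` and `vmap` can be sketched as below. The `tile` function is a hypothetical 1-D stand-in for a single tile filter; only the accumulation pattern is the point.

```python
import jax
import jax.numpy as jnp

N = 64
xi = jnp.arange(N)
centres = jnp.array([0, 16, 32, 48])
sigma = 8.0


def tile(b):
    """Hypothetical periodic Gaussian tile for box index b."""
    d = (xi - centres[b] + N // 2) % N - N // 2
    return jnp.exp(-0.5 * (d / sigma) ** 2)


def body(b, acc):
    # Only one (N,)-shaped filter is alive per iteration.
    return acc + tile(b) ** 2


sum_gsquare = jax.lax.fori_loop(0, centres.shape[0], body, jnp.zeros(N))

# Reference: vmap materialises the full (B, N) stack before reducing.
stack = jax.vmap(tile)(jnp.arange(centres.shape[0]))
reference = jnp.sum(stack**2, axis=0)
```

Both paths give the same sum; the loop keeps peak memory at one filter plus the accumulator, which matters when the number of boxes times the grid size is large.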

compute_gh_filters(dyadic_decomp: DyadicDecomposition, redundancy: int, windowing: str = 'rectangular') -> Tuple[Float[Array, 'B *N'], Float[Array, 'B *N']]

Compute g tiles and their dual h = g / Σ g^2.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| dyadic_decomp | DyadicDecomposition | Frequency tiling. | required |
| redundancy | int | 1 or 2. | required |
| windowing | str | Windowing function. | 'rectangular' |

Returns:

| Type | Description |
| --- | --- |
| (ndarray, ndarray) | (gfilt, hfilt) each with shape (num_boxes, *N). |

compute_frames(dyadic_decomp: DyadicDecomposition, boxidx: int, k: Int[Array, ' d'], fourier_space: Num[Array, '*N d'], redundancy: int, windowing: str = 'rectangular') -> Num[Array, '*N']

Frame atom for box boxidx with plane-wave modulation.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| dyadic_decomp | DyadicDecomposition | Frequency tiling. | required |
| boxidx | int | Box index. | required |
| k | (ndarray, shape(d)) | Wave-vector. | required |
| fourier_space | (ndarray, shape(*N, d)) | Physical Fourier coordinates. | required |
| redundancy | int | 1 or 2. | required |
| windowing | str | Windowing function. | 'rectangular' |

Returns:

| Type | Description |
| --- | --- |
| (ndarray, shape(*N)) | Complex atom, dtype = complex64/complex128. |
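
One way to picture a plane-wave-modulated atom is the 2-D sketch below. It assumes the atom is a tile filter multiplied by exp(-i k·ξ) in Fourier space; the sign convention, the Gaussian stand-in for the filter, and the reading of `k` as an integer shift in samples are assumptions and may not match compute_frames.

```python
import jax.numpy as jnp

n = 16
freqs = 2 * jnp.pi * jnp.fft.fftfreq(n)          # angular frequency per sample
fourier_space = jnp.stack(
    jnp.meshgrid(freqs, freqs, indexing="ij"), axis=-1
)                                                # shape (n, n, 2)

# Hypothetical stand-in for a tile filter: isotropic low-frequency Gaussian.
gfilt = jnp.exp(-0.5 * jnp.sum(fourier_space**2, axis=-1) / 0.5**2)

k = jnp.array([3, 5])                            # assumed shift, in samples
atom_hat = gfilt * jnp.exp(-1j * (fourier_space @ k))   # (n, n), complex

# Modulation in Fourier space corresponds to translation in physical space.
atom = jnp.fft.ifft2(atom_hat)
base = jnp.fft.ifft2(gfilt)
shifted = jnp.roll(base, (3, 5), axis=(0, 1))
```

Under these assumptions `atom` equals `shifted`, so each choice of `k` places the same localized packet at a different spatial position.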