beamax.transforms
Multiscale wave packet transform (MSWPT) and frame construction utilities.
Key Objects
- `MSWPT`: forward/inverse transforms between spatial/Fourier fields and wave-packet coefficients.
- `compute_frames`: construct localized frame atoms in Fourier space.
API Reference
beamax.transforms
MSWPT(dyadic_decomp, redundancy, windowing)
Bases: Module
Multiscale Wave-Packet Transform.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `dyadic_decomp` | `DyadicDecomposition` | Frequency tiling (centres, per-level box lengths). | *required* |
| `redundancy` | `int` | 1 (basis) or 2 (tight frame). Static under JIT. | *required* |
| `windowing` | `{"rectangular", "rectangular_mirror", "none"}` | Windowing for tile filters. Static under JIT. | `"rectangular"` |
Attributes:

| Name | Type | Description |
|---|---|---|
| `dyadic_decomp` | `DyadicDecomposition` | Frequency tiling passed to the constructor. |
| `redundancy` | `int` | Redundancy passed to the constructor (1 or 2). |
| `windowing` | `str` | Windowing passed to the constructor. |
| `complex_dtype` | `dtype` | `complex64`, or `complex128` when JAX x64 mode is enabled. |
| `sum_gsquare` | `(ndarray, shape(*N))` | Precomputed Σ_b g_b². |
| `boxes_cumsum` | `Tuple[int, ...]` | Cumulative number of boxes per level (static). |
| `coeff_shapes` | `Tuple[Tuple[int, ...], ...]` | Shape per level: `(n_boxes_level, *support_shape)`. |
| `coeffs_cumsum` | `Tuple[int, ...]` | Flat coefficient offsets per level (static). |
| `total_coeffs` | `int` | Total number of flat coefficients (static). |
| `gfilts_packed` | `List[ndarray]` | Per-level tile filters, packed for later slicing/rolls. |
| `_support_shapes` | `List[Tuple[int, ...]]` | Per-level support shapes (static). |
| `_box_shapes` | `List[Tuple[int, ...]]` | Per-level "box lengths" along each axis (static). |
| `_half_mask` | `ndarray` | Mask selecting the positive-frequency half. |
Notes
- Constructor performs shape bookkeeping to keep runtime kernels static-JIT friendly.
- All heavy transforms are pure JAX functions; no side effects.
Build a transform instance with precomputed static metadata.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `dyadic_decomp` | `DyadicDecomposition` | Provides centres, per-level box lengths, and meshgrid. | *required* |
| `redundancy` | `int` | 1 or 2. Controls per-level support sizes and the total coefficient count. | *required* |
| `windowing` | `str` | Window type for the Gaussian tiles. | *required* |
Notes
- Precomputes cumulative box/coeff offsets and per-level support shapes.
- Chooses complex dtype from global JAX precision flag.
- Packs representative per-level `gfilters` for later slicing/rolls.
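The offset bookkeeping described above can be sketched in plain Python. This is a minimal illustration, not beamax code: `coefficient_offsets` is a hypothetical helper, and the leading `0` in each cumulative tuple is an assumption about how the offsets are laid out. It takes per-level shapes of the form `(n_boxes_level, *support_shape)` (as in `coeff_shapes`) and returns plain integers, which can stay static under JIT.

```python
from itertools import accumulate
from math import prod
from typing import Sequence, Tuple


def coefficient_offsets(
    coeff_shapes: Sequence[Tuple[int, ...]],
) -> Tuple[Tuple[int, ...], Tuple[int, ...], int]:
    """Hypothetical helper, not part of beamax.

    Each entry of ``coeff_shapes`` is (n_boxes_level, *support_shape).
    Returns (boxes_cumsum, coeffs_cumsum, total_coeffs) as plain Python ints,
    so they can be treated as static (compile-time) values under JIT.
    """
    n_boxes = [shape[0] for shape in coeff_shapes]
    n_coeffs = [prod(shape) for shape in coeff_shapes]
    # Assumed leading 0: level l occupies the half-open range [cumsum[l], cumsum[l + 1]).
    boxes_cumsum = (0, *accumulate(n_boxes))
    coeffs_cumsum = (0, *accumulate(n_coeffs))
    return boxes_cumsum, coeffs_cumsum, coeffs_cumsum[-1]


boxes, offsets, total = coefficient_offsets([(1, 4, 4), (8, 4, 4), (16, 8, 8)])
print(boxes, offsets, total)  # (0, 1, 9, 25) (0, 16, 144, 1168) 1168
```

Keeping these as Python tuples rather than arrays is what lets slice bounds like `coeffs[offsets[l]:offsets[l + 1]]` be known at trace time.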
forward(data: Num[Array, '*N'], input_type: str) -> Num[Array, ' total_coeffs']
Forward MSWPT.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `data` | `(ndarray, shape(*N), real or complex)` | Input field in the spatial or Fourier domain. | *required* |
| `input_type` | `{"spatial", "fourier"}` | Declares the domain of `data`. | `"spatial"` |

Returns:

| Type | Description |
|---|---|
| `(ndarray, shape(total_coeffs), complex)` | Flat coefficient vector. |
Notes
- Converts to Fourier (`utils.unitary_fft`) if needed.
- Divides by Σ g² to form the canonical tight-frame analysis.
- JIT-compiled via `@eqx.filter_jit(donate="all")`.
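The first two steps of the forward transform can be sketched with NumPy as a stand-in for JAX. This assumes `utils.unitary_fft` behaves like an orthonormal n-D FFT; `prepare_fourier` is a hypothetical name, and any `fftshift` or centring the library applies is not modelled.

```python
import numpy as np


def prepare_fourier(data, sum_gsquare, input_type="spatial"):
    """Sketch of forward's first two steps (NumPy stand-in, not beamax code).

    Assumes utils.unitary_fft behaves like an orthonormal n-D FFT; any
    fftshift/centring the library applies is not modelled here.
    """
    if input_type == "spatial":
        F = np.fft.fftn(data, norm="ortho")  # unitary: preserves the L2 norm
    elif input_type == "fourier":
        F = np.asarray(data)
    else:
        raise ValueError(f"unknown input_type: {input_type!r}")
    # Divide by Σ g² once up front, so each tile later only multiplies by g.
    return F / sum_gsquare


rng = np.random.default_rng(0)
x = rng.standard_normal((8, 6))
F = prepare_fourier(x, np.ones((8, 6)))
print(np.allclose(np.sum(np.abs(x) ** 2), np.sum(np.abs(F) ** 2)))  # True (Parseval)
```

Using the unitary normalisation means energy is preserved between domains, so coefficient magnitudes are comparable whether the input was given in space or in Fourier.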
inverse(coeffs: Num[Array, ' total_coeffs'], output_type: str) -> Num[Array, '*N']
Fast, exact inverse MSWPT.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `coeffs` | `(ndarray, shape(total_coeffs), complex)` | Flat coefficient vector produced by `forward`. | *required* |
| `output_type` | `{"spatial", "fourier"}` | Domain of the returned array. | `"spatial"` |

Returns:

| Type | Description |
|---|---|
| `(ndarray, shape(*N), complex)` | Reconstructed field in the requested domain. |
Notes
The synthesis mirrors the analysis steps in `forward`:

    forward: patch = extract(F / Σg², centre); c = IFFT(roll(g * patch, +r))
    inverse: tmp = FFT(c); add += g * roll(tmp, -r)

where the periodic scatter-add happens at the true box centre.
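Why dividing by Σ g² makes the synthesis exact can be seen in a toy 1-D version with full-length tiles. This sketch omits the per-box patch extraction and rolls entirely; `toy_forward` and `toy_inverse` are hypothetical names, and the filters are random positive arrays rather than beamax tiles. Because the FFTs are unitary, `FFT(IFFT(y)) = y`, so the synthesis returns Σ_b g_b · g_b F / Σ g² = F.

```python
import numpy as np


def toy_forward(F, g):
    """F: (N,) Fourier field; g: (B, N) non-negative tile filters. Returns (B, N)."""
    S = np.sum(g ** 2, axis=0)  # Σ_b g_b², assumed strictly positive everywhere
    return np.fft.ifft(g * (F / S), axis=-1, norm="ortho")


def toy_inverse(c, g):
    """Synthesis: FFT each coefficient row, weight by g, and add up the tiles."""
    return np.sum(g * np.fft.fft(c, axis=-1, norm="ortho"), axis=0)


rng = np.random.default_rng(1)
g = rng.uniform(0.1, 1.0, size=(3, 16))
F = rng.standard_normal(16) + 1j * rng.standard_normal(16)
print(np.allclose(toy_inverse(toy_forward(F, g), g), F))  # True
```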
convert_to_array(coeffs: Num[Array, ' total_coeffs']) -> Num[Array, '*M']
Reshape flat coefficients into a dense tensor arranged by spatial support.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `coeffs` | `(ndarray, shape(total_coeffs), complex)` | Flat vector returned by `forward`. | *required* |

Returns:

| Type | Description |
|---|---|
| `ndarray` | Dense coefficient tensor with the per-level boxes unflattened and placed at their centred positions. Shape: `*M`. |
Notes
- Intended for diagnostics/visualization; not required for forward/inverse.
- Uses integer centres and per-level support shapes; pure JAX.
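The first half of this reshaping, viewing the flat vector level by level, can be sketched as follows. `split_levels` is a hypothetical helper, and it assumes levels are stored contiguously in order, each in row-major order with shape `(n_boxes_level, *support_shape)`. Placing boxes at their centred positions, which `convert_to_array` also does, is not reproduced here.

```python
import numpy as np


def split_levels(coeffs, coeff_shapes):
    """Hypothetical helper: view a flat coefficient vector as per-level arrays.

    Assumes levels are stored contiguously, in order, each in C (row-major) order
    with shape (n_boxes_level, *support_shape).
    """
    levels, start = [], 0
    for shape in coeff_shapes:
        size = int(np.prod(shape))
        levels.append(coeffs[start:start + size].reshape(shape))
        start += size
    if start != coeffs.size:
        raise ValueError(f"expected {start} coefficients, got {coeffs.size}")
    return levels


flat = np.arange(16 + 128).astype(np.complex64)
lv = split_levels(flat, [(1, 4, 4), (8, 4, 4)])
print([a.shape for a in lv])  # [(1, 4, 4), (8, 4, 4)]
```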
compute_windowed_gaussian(centre: Int[Array, ' d'], meshgrid: Int[Array, '*N d'], box_length, box_aspect_ratio: Union[Int[Array, ' d'], Tuple[int, ...]], domain_length: DomainLength, redundancy: int, windowing: str) -> Float[Array, '*N']
Windowed N-D Gaussian in Fourier index space.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `centre` | `(ndarray, shape(d))` | Tile centre in Fourier index units. | *required* |
| `meshgrid` | `(ndarray, shape(*N, d))` | Integer-centred Fourier grid from … | *required* |
| `box_length` | `int` | Smallest-axis tile length for this level. | *required* |
| `box_aspect_ratio` | `(ndarray, shape(d))` | Per-axis aspect multipliers; values ≥ 1, with at least one equal to 1. | *required* |
| `domain_length` | `int` | Smallest side length of the domain. | *required* |
| `redundancy` | `int` | 1 for basis, 2 for tight frame. | *required* |
| `windowing` | `{"none", "rectangular", "rectangular_mirror"}` | Windowing function applied to the Gaussian. | `"none"` |

Returns:

| Type | Description |
|---|---|
| `(ndarray, shape(*N))` | Windowed Gaussian weights; dtype `float32` or `float64`. |
Notes
Pure JAX, JIT-safe. Periodisation handled by modulo arithmetic.
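Periodisation by modulo arithmetic can be sketched as below. This is an illustration under stated assumptions, not the library's formula: the per-axis `sigma` and `half_width` are hypothetical parameters standing in for however beamax derives the width and window from `box_length`, `box_aspect_ratio`, `domain_length` and `redundancy`, and only the `"rectangular"` window is modelled.

```python
import numpy as np


def periodic_offset(x, centre, n):
    """Signed offset x - centre on a ring of length n, mapped into [-(n // 2), n - n // 2)."""
    return (x - centre + n // 2) % n - n // 2


def windowed_gaussian(centre, meshgrid, sigma, half_width):
    """Sketch: periodised Gaussian with a rectangular window.

    centre: (d,) tile centre; meshgrid: (*N, d) integer Fourier grid;
    sigma, half_width: (d,) per-axis width and window half-size (hypothetical
    parameters standing in for box_length, box_aspect_ratio and redundancy).
    """
    n = np.array(meshgrid.shape[:-1])  # grid size per axis
    d = periodic_offset(meshgrid, np.asarray(centre), n)
    gauss = np.exp(-0.5 * np.sum((d / sigma) ** 2, axis=-1))
    inside = np.all(np.abs(d) <= half_width, axis=-1)  # rectangular window
    return np.where(inside, gauss, 0.0)


ax = np.arange(-8, 8)  # integer-centred grid for N = 16
mg = np.stack(np.meshgrid(ax, ax, indexing="ij"), axis=-1)  # (16, 16, 2)
g = windowed_gaussian([7, 0], mg, sigma=np.array([2.0, 2.0]), half_width=np.array([3, 3]))
# The tile centred at +7 wraps around: grid point -8 sits at offset +1.
print(np.isclose(g[0, 8], g[14, 8]))  # True
```

Because the offset is taken modulo the grid size, the same code works whether the grid runs from `0` to `N-1` or is integer-centred.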
single_filter_idx(centre_idx, meshgrid: Int[Array, '*N d'], dyadic_decomp: DyadicDecomposition, redundancy: int, windowing: str = 'rectangular') -> Float[Array, '*N']
Filter for a single tile (by global box index).
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `centre_idx` | `int` | Global box index in … | *required* |
| `meshgrid` | `(ndarray, shape(*N, d))` | Fourier meshgrid. | *required* |
| `dyadic_decomp` | `DyadicDecomposition` | Dyadic parameters providing centres and per-level box lengths. | *required* |
| `redundancy` | `int` | 1 (basis) or 2 (frame). | *required* |
| `windowing` | `str` | Windowing function (see `compute_windowed_gaussian`). | `'rectangular'` |

Returns:

| Type | Description |
|---|---|
| `(ndarray, shape(*N))` | Filter values. |
single_filter_coord(centre: Int[Array, ' d'], level: int, meshgrid: Int[Array, '*N d'], dyadic_decomp: DyadicDecomposition, redundancy: int, windowing: str = 'rectangular') -> Float[Array, '*N']
Filter for a single tile (by (centre, level) pair).
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `centre` | `(ndarray, shape(d))` | Tile centre in Fourier indices. | *required* |
| `level` | `int` | Dyadic level (0..L-1). | *required* |
| `meshgrid` | `(ndarray, shape(*N, d))` | Fourier meshgrid. | *required* |
| `dyadic_decomp` | `DyadicDecomposition` | Decomposition parameters. | *required* |
| `redundancy` | `int` | 1 or 2. | *required* |
| `windowing` | `str` | Windowing function. | `'rectangular'` |

Returns:

| Type | Description |
|---|---|
| `(ndarray, shape(*N))` | Filter values. |
compute_sum_gsquare(dyadic_decomp: DyadicDecomposition, redundancy: int, windowing: str = 'rectangular') -> Float[Array, '*N']
Sum of squares of all tile filters.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `dyadic_decomp` | `DyadicDecomposition` | Frequency tiling. | *required* |
| `redundancy` | `int` | 1 (basis) or 2 (frame). | *required* |
| `windowing` | `str` | Windowing function. | `'rectangular'` |

Returns:

| Type | Description |
|---|---|
| `(ndarray, shape(*N))` | Σ_b g_b² over all boxes. |
Notes
Accumulates over boxes with `jax.lax.fori_loop` rather than a large `vmap`, so the filters for all boxes are never materialised at once.
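The memory argument behind the loop can be sketched with a NumPy stand-in. `filter_fn` is a hypothetical callback playing the role of `single_filter_idx`; in JAX the Python loop would become `jax.lax.fori_loop(0, n_boxes, body, acc)`, whose signature is `fori_loop(lower, upper, body_fun, init_val)`.

```python
import numpy as np


def sum_gsquare_loop(filter_fn, n_boxes, grid_shape):
    """Accumulate Σ_b g_b² one box at a time.

    filter_fn(b) -> (*N,) array is a hypothetical callback playing the role of
    single_filter_idx. Peak memory is one (*N,) filter plus the accumulator,
    versus (n_boxes, *N) if all filters were stacked (e.g. by vmap) first.
    In JAX the loop would be jax.lax.fori_loop(0, n_boxes, body, acc).
    """
    acc = np.zeros(grid_shape)
    for b in range(n_boxes):
        acc += filter_fn(b) ** 2
    return acc


def toy_filter(b):
    return np.full((4, 4), b + 1.0)


print(sum_gsquare_loop(toy_filter, 3, (4, 4))[0, 0])  # 14.0 = 1 + 4 + 9
```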
compute_gh_filters(dyadic_decomp: DyadicDecomposition, redundancy: int, windowing: str = 'rectangular') -> Tuple[Float[Array, 'B *N'], Float[Array, 'B *N']]
Compute g tiles and their dual h = g / Σ g^2.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `dyadic_decomp` | `DyadicDecomposition` | Frequency tiling. | *required* |
| `redundancy` | `int` | 1 or 2. | *required* |
| `windowing` | `str` | Windowing function. | `'rectangular'` |

Returns:

| Type | Description |
|---|---|
| `(ndarray, ndarray)` | `(g, h)`: the tile filters and their duals h = g / Σ g², each of shape `(B, *N)`. |
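The dual construction can be sketched on arbitrary stacked filters. `gh_from_filters` is a hypothetical name, and the filters here are random positive arrays rather than tiles built from a dyadic decomposition. The check confirms the property that makes h the canonical dual: Σ_b g_b h_b = 1 at every frequency.

```python
import numpy as np


def gh_from_filters(g):
    """Given stacked tile filters g of shape (B, *N), return (g, h) with h = g / Σ_b g_b².

    Sketch only: beamax builds g from the dyadic tiling, here g is supplied directly.
    """
    S = np.sum(g ** 2, axis=0)
    if np.any(S <= 0):
        raise ValueError("Σ g² must be positive everywhere for the dual to exist")
    return g, g / S


rng = np.random.default_rng(2)
g, h = gh_from_filters(rng.uniform(0.1, 1.0, size=(4, 8, 8)))
# Canonical dual property: Σ_b g_b h_b = 1 at every frequency.
print(np.allclose(np.sum(g * h, axis=0), 1.0))  # True
```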
compute_frames(dyadic_decomp: DyadicDecomposition, boxidx: int, k: Int[Array, ' d'], fourier_space: Num[Array, '*N d'], redundancy: int, windowing: str = 'rectangular') -> Num[Array, '*N']
Frame atom for box `boxidx` with plane-wave modulation.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `dyadic_decomp` | `DyadicDecomposition` | Frequency tiling. | *required* |
| `boxidx` | `int` | Box index. | *required* |
| `k` | `(ndarray, shape(d))` | Wave-vector. | *required* |
| `fourier_space` | `(ndarray, shape(*N, d))` | Physical Fourier coordinates. | *required* |
| `redundancy` | `int` | 1 or 2. | *required* |
| `windowing` | `str` | Windowing function. | `'rectangular'` |

Returns:

| Type | Description |
|---|---|
| `(ndarray, shape(*N))` | Complex atom; dtype `complex64` or `complex128`. |
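One plausible reading of "plane-wave modulation" is a tile filter multiplied by a phase factor exp(-i⟨ξ, k⟩) evaluated on the physical Fourier coordinates ξ. The sketch below assumes that reading; the sign of the exponent, and whether the filter is g or its dual h, are not stated above and are assumptions. `modulated_atom` is a hypothetical name, and a Gaussian stands in for a real tile filter.

```python
import numpy as np


def modulated_atom(filter_vals, k, fourier_space):
    """Assumed form: filter(ξ) * exp(-i <ξ, k>).

    filter_vals: (*N,) real tile filter; k: (d,) wave-vector;
    fourier_space: (*N, d) physical Fourier coordinates.
    """
    phase = fourier_space @ np.asarray(k)  # <ξ, k> at every grid point, shape (*N,)
    return filter_vals * np.exp(-1j * phase)


ax = np.linspace(-1.0, 1.0, 5)
xi = np.stack(np.meshgrid(ax, ax, indexing="ij"), axis=-1)  # (5, 5, 2) Fourier coordinates
filt = np.exp(-np.sum(xi ** 2, axis=-1))  # stand-in for a tile filter
atom = modulated_atom(filt, np.array([0.5, -2.0]), xi)
print(atom.dtype, np.allclose(np.abs(atom), filt))  # complex128 True
```

Under this reading the modulation changes only the phase, so the atom's magnitude is exactly the tile filter.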