Skip to content

Reconstruction

Reconstruction examples with MSGB. These cover time reversal and adjoint operators for recovering \(p_0\) from sensor data. Examples marked optional require beamax[kwave,viz-mpl] and are skipped by the default smoke suite.


2D time reversal and adjoint

Optional: compare k-Wave time-reversal and adjoint reconstructions for a tiny 2D \(p_0\).

Open In Colab

#!/usr/bin/env python
"""
2D MSGB vs k-Wave reconstruction: time reversal + adjoint.

Runs a compact inverse-comparison workflow on one small 2D problem. Steps:

  1. Build a smooth two-Gaussian $p_0$ and a one-sided boundary sensor line.
  2. Forward-simulate with k-Wave to get the sensor record.
  3. Reconstruct with both k-Wave and MSGB via time reversal AND adjoint
     back-propagation (four reconstructions in total).
  4. Plot a 3-row comparison figure: $p_0$ + 2 TR images on top,
     sensor data + 2 adjoint images in the middle, 1D profile through them on
     the bottom. Print relative-L2 metrics against the truth.

MSGB time-reversal in 2D needs a frequency-cropped data domain and a paired
data-WPT — the helper ``prepare_data_domain_for_msgb`` below keeps that setup
local and explicit.

Example category: Reconstruction
Example extras: kwave,viz-mpl
Example smoke: false
"""

import jax
import jax.numpy as jnp
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np

from beamax import utils
from beamax.decomposition import DyadicDecomposition
from beamax.geometry import Domain, Sensor
from beamax.gb import gb_solvers
from beamax.solvers import MSGBSolver
from beamax.transforms import MSWPT


jax.config.update("jax_enable_x64", True)

INSTALL_HINT = 'pip install -e ".[kwave,viz-mpl]"'


# ---------------------------------------------------------------------------
# Setup helpers
# ---------------------------------------------------------------------------


def load_kwave_solver():
    """Import k-Wave lazily so base beamax installs can still import this file."""
    try:
        from beamax.solvers import KWaveSolver
    except ImportError as exc:
        print(f"Skipping optional example: k-Wave is not installed ({INSTALL_HINT}).")
        raise SystemExit(0) from exc
    return KWaveSolver


def c_homogeneous(x: jnp.ndarray) -> jnp.ndarray:
    return 1500.0 + 0.0 * x[..., 0]


def make_two_gaussian_phantom(domain: Domain) -> jnp.ndarray:
    """Two smooth Gaussian inclusions with zero mean, normalised to peak |p| = 1."""
    lx, ly = domain.grid_size
    x, y = jnp.meshgrid(
        jnp.arange(domain.N[0]) * domain.dx[0],
        jnp.arange(domain.N[1]) * domain.dx[1],
        indexing="ij",
    )
    p0 = jnp.exp(
        -((x - 0.38 * lx) ** 2 + (y - 0.45 * ly) ** 2) / (2.0 * (0.08 * lx) ** 2)
    )
    p0 -= 0.7 * jnp.exp(
        -((x - 0.62 * lx) ** 2 + (y - 0.58 * ly) ** 2) / (2.0 * (0.09 * lx) ** 2)
    )
    p0 = p0 - jnp.mean(p0)
    return p0 / jnp.max(jnp.abs(p0))


def coerce_image(arr: jnp.ndarray, shape: tuple[int, int]) -> np.ndarray:
    """Coerce k-Wave image output to ``shape``; handle the transposed-output case."""
    image = np.asarray(arr)
    if image.shape == shape:
        return image
    if image.T.shape == shape:
        return image.T
    return image.reshape(shape)


def scaled(recon: np.ndarray, truth: np.ndarray) -> tuple[np.ndarray, float]:
    """Best L2 scale of recon onto truth; returns (scaled, rel_l2)."""
    r = np.asarray(recon).real
    t = np.asarray(truth).real
    s = float(np.vdot(r, t) / (np.vdot(r, r) + 1e-30))
    out = s * r
    rel_l2 = float(np.linalg.norm(out - t) / (np.linalg.norm(t) + 1e-30))
    return out, rel_l2


# ---------------------------------------------------------------------------
# MSGB data-domain construction
# ---------------------------------------------------------------------------


def _cut_out_middle(arr: jnp.ndarray, size: int) -> jnp.ndarray:
    """Keep the middle ``size`` samples along axis 0 (after fftshift)."""
    mid = arr.shape[0] // 2
    return arr[mid - size // 2 : mid + size // 2]


def prepare_data_domain_for_msgb(
    sensor_data_kw: jnp.ndarray,
    domain: Domain,
    ts: jnp.ndarray,
    *,
    over_resolve: int = 2,
):
    """
    Build the (Nt', Ns) data domain and its paired MSWPT that MSGB needs for
    TR/adjoint, by Fourier-cropping the k-Wave sensor record in time.

    Returns
    -------
    sensor_data_cropped : (Nt', Ns)
    domain_data : Domain on (Nt', Ns) with dx = (dt', dx_y)
    wpt_data : MSWPT on the data domain
    ts_data : (Nt',) new time grid
    """
    sensor_arr = jnp.asarray(sensor_data_kw)
    if sensor_arr.ndim != 2:
        raise ValueError(f"Expected (Nt, Ns) sensor data; got {sensor_arr.shape}")

    nt_cropped = over_resolve * domain.N[0]
    if sensor_arr.shape[0] < nt_cropped:
        raise ValueError(
            f"Need >= {nt_cropped} time samples; got {sensor_arr.shape[0]}."
        )
    fft = utils.unitary_fft(sensor_arr)
    cropped_fft = _cut_out_middle(fft, nt_cropped)
    sensor_data_cropped = utils.unitary_ifft(cropped_fft).real

    nt_data, ns = sensor_data_cropped.shape
    ts_data = jnp.linspace(float(ts[0]), float(ts[-1]), nt_data)
    dt_data = float(ts_data[1] - ts_data[0])
    dx_y = float(domain.dx[1])
    domain_data = Domain(
        N=(nt_data, ns),
        dx=(dt_data, dx_y),
        c=domain.c,
        periodic=domain.periodic,
        cfl=domain.cfl,
    )

    # Aspect ratio set so the dyadic decomposition matches the rectangular data
    # shape (nt_data is typically over_resolve * ns for over_resolve == 2).
    n_min = min(nt_data, ns)
    box_aspect = (nt_data // n_min, ns // n_min)
    dyadic_data = DyadicDecomposition(
        num_levels=2,
        N=(nt_data, ns),
        num_boxes_levels=(4, 8),
        box_aspect_ratio=box_aspect,
    )
    wpt_data = MSWPT(dyadic_data, redundancy=2, windowing="rectangular_mirror")
    return sensor_data_cropped, domain_data, wpt_data, ts_data


# ---------------------------------------------------------------------------
# Plotting
# ---------------------------------------------------------------------------


def plot_comparison(
    p0,
    sensor_data,
    tr_kw,
    tr_msgb,
    adj_kw,
    adj_msgb,
    domain,
    sensors,
    *,
    out_path,
):
    """Thesis-style 3-row layout: 2 imshow rows + 1 profile row."""
    arrays = [np.asarray(a).real for a in (p0, tr_kw, tr_msgb, adj_kw, adj_msgb)]
    sensor_arr = np.asarray(sensor_data).real
    if sensor_arr.ndim != 2:
        sensor_arr = sensor_arr.reshape(sensor_arr.shape[0], -1)
    vmax = max(float(np.max(np.abs(a))) for a in arrays)
    sensor_vmax = float(np.percentile(np.abs(sensor_arr), 99.5))
    if sensor_vmax == 0.0:
        sensor_vmax = 1.0
    extent = (0.0, float(domain.grid_size[1]), 0.0, float(domain.grid_size[0]))

    fig = plt.figure(figsize=(12, 9))
    gs = gridspec.GridSpec(
        3,
        3,
        height_ratios=[1.0, 1.0, 0.85],
        hspace=0.25,
        wspace=0.08,
        figure=fig,
    )

    top_titles = [
        r"$p_0$",
        r"$p_{\mathrm{TR}}^{\mathrm{k\!-\!Wave}}$",
        r"$p_{\mathrm{TR}}^{\mathrm{MSGB}}$",
    ]
    mid_titles = [
        r"$p_{\mathrm{Adj}}^{\mathrm{k\!-\!Wave}}$",
        r"$p_{\mathrm{Adj}}^{\mathrm{MSGB}}$",
    ]
    top_arrays = [arrays[0], arrays[1], arrays[2]]
    mid_arrays = [arrays[3], arrays[4]]
    image_axes = []

    for j, (title, arr) in enumerate(zip(top_titles, top_arrays)):
        ax = fig.add_subplot(gs[0, j])
        ax.imshow(
            arr,
            origin="lower",
            extent=extent,
            vmin=-vmax,
            vmax=vmax,
            cmap="RdBu_r",
            aspect="equal",
        )
        ax.set_title(title)
        ax.set_xticks([])
        ax.set_yticks([])
        image_axes.append(ax)

    ax_data = fig.add_subplot(gs[1, 0])
    ax_data.imshow(
        sensor_arr,
        origin="lower",
        aspect="auto",
        cmap="viridis",
        vmin=-sensor_vmax,
        vmax=sensor_vmax,
    )
    ax_data.set_title("sensor data")
    ax_data.set_xlabel(r"$x_s$")
    ax_data.set_ylabel(r"$t$")
    ax_data.set_box_aspect(1)
    ax_data.set_xticks([])
    ax_data.set_yticks([])

    for j, (title, arr) in enumerate(zip(mid_titles, mid_arrays), start=1):
        ax = fig.add_subplot(gs[1, j])
        ax.imshow(
            arr,
            origin="lower",
            extent=extent,
            vmin=-vmax,
            vmax=vmax,
            cmap="RdBu_r",
            aspect="equal",
        )
        ax.set_title(title)
        ax.set_xticks([])
        ax.set_yticks([])
        image_axes.append(ax)

    # Overlay sensor positions on every image panel.
    rr, cc = jnp.where(sensors.binary_mask)
    xs = (np.asarray(cc) + 0.5) * float(domain.dx[1])
    ys = (np.asarray(rr) + 0.5) * float(domain.dx[0])
    for ax in image_axes:
        ax.scatter(
            xs,
            ys,
            s=16,
            c="red",
            marker="^",
            alpha=0.9,
            edgecolors="white",
            linewidths=0.3,
            zorder=10,
        )

    # 1D profile down the middle column of the image.
    ax_prof = fig.add_subplot(gs[2, :])
    idx = arrays[0].shape[1] // 2
    y_axis = np.arange(arrays[0].shape[0]) * float(domain.dx[0])
    ax_prof.plot(y_axis, arrays[0][:, idx], color="black", lw=2.0, label=r"$p_0$")
    ax_prof.plot(y_axis, arrays[1][:, idx], color="C0", lw=1.5, label="TR k-Wave")
    ax_prof.plot(
        y_axis, arrays[2][:, idx], color="C0", lw=1.5, ls="--", label="TR MSGB"
    )
    ax_prof.plot(y_axis, arrays[3][:, idx], color="C3", lw=1.5, label="Adj k-Wave")
    ax_prof.plot(
        y_axis, arrays[4][:, idx], color="C3", lw=1.5, ls="--", label="Adj MSGB"
    )
    ax_prof.set_xlabel("y [m]")
    ax_prof.set_ylabel("pressure")
    ax_prof.set_title(f"profile at x = {idx * float(domain.dx[1]):.1e} m")
    ax_prof.legend(loc="lower center", bbox_to_anchor=(0.5, -0.45), ncol=5)
    ax_prof.axvline(idx * float(domain.dx[0]), color="grey", ls=":", lw=0.8)
    for ax in image_axes:
        ax.axvline(idx * float(domain.dx[1]), color="grey", ls=":", lw=0.8)

    fig.savefig(out_path, dpi=180, bbox_inches="tight")
    plt.close(fig)


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------


def main() -> None:
    KWaveSolver = load_kwave_solver()

    n = (64, 64)
    dx = (1.0e-4, 1.0e-4)
    domain = Domain(
        N=n,
        dx=dx,
        c=c_homogeneous,
        cfl=0.3,
        periodic=(False, False),
    )
    ts = domain.generate_time_domain()
    p0 = make_two_gaussian_phantom(domain)

    # Boundary sensors on the x = 0 row.
    sensor_mask = jnp.zeros(n).at[0, :].set(1.0)
    sensors = Sensor(domain=domain, binary_mask=sensor_mask)
    image_mask = jnp.ones(n)

    # --- k-Wave forward to generate sensor data ---
    kwave = KWaveSolver(
        backend="python",
        device="cpu",
        pml_size=8,
        smooth_p0=False,
        debug=False,
    )
    data = kwave.forward(p0, domain, sensor_mask, ts)

    # --- k-Wave TR and Adjoint ---
    tr_kw = -coerce_image(
        kwave.time_reversal(
            data=data,
            domain=domain,
            sensors=image_mask,
            sources=sensor_mask,
            ts=ts,
            data_layout="nt_ns",
        ),
        n,
    )
    adj_kw = -coerce_image(
        kwave.adjoint(
            data=data,
            domain=domain,
            sensors=image_mask,
            sources=sensor_mask,
            ts=ts,
            data_layout="nt_ns",
        ),
        n,
    )

    # --- MSGB TR and Adjoint (with frequency-cropped data domain) ---
    sensor_cropped, domain_data, wpt_data, _ts_data = prepare_data_domain_for_msgb(
        data,
        domain,
        ts,
    )
    img_dyadic = DyadicDecomposition(
        num_levels=2,
        N=n,
        num_boxes_levels=(4, 8),
        box_aspect_ratio=(1, 1),
    )
    img_wpt = MSWPT(img_dyadic, redundancy=2, windowing="rectangular_mirror")
    msgb = MSGBSolver(
        thr=int(img_wpt.total_coeffs),
        thr_strat="top_n",
        batch_size=64,
        input_type="spatial",
        ode_solver=gb_solvers.solve_ODE_base,
        tr_ode_solver=gb_solvers.solve_ODE_batch_t,
        sum_method="scan_real",
    )
    sensors_eval = Sensor(domain=domain, binary_mask=image_mask)
    tr_msgb_raw = msgb.time_reversal(
        data=sensor_cropped,
        domain=domain,
        sensors=sensors_eval,
        sources=sensors,
        ts=ts,
        data_domain=domain_data,
        data_wpt=wpt_data,
    )
    adj_msgb_raw = msgb.adjoint(
        data=sensor_cropped,
        domain=domain,
        sensors=sensors_eval,
        sources=sensors,
        ts=ts,
        data_domain=domain_data,
        data_wpt=wpt_data,
    )
    tr_msgb = np.asarray(tr_msgb_raw).real.reshape(n)
    adj_msgb = np.asarray(adj_msgb_raw).real.reshape(n)

    # --- Best-L2 scale each reconstruction and print metrics ---
    truth = np.asarray(p0)
    tr_kw_s, tr_kw_l2 = scaled(tr_kw, truth)
    adj_kw_s, adj_kw_l2 = scaled(adj_kw, truth)
    tr_msgb_s, tr_msgb_l2 = scaled(tr_msgb, truth)
    adj_msgb_s, adj_msgb_l2 = scaled(adj_msgb, truth)

    print(f"TR k-Wave   rel L2 = {tr_kw_l2:.3f}")
    print(f"TR MSGB     rel L2 = {tr_msgb_l2:.3f}")
    print(f"Adj k-Wave  rel L2 = {adj_kw_l2:.3f}")
    print(f"Adj MSGB    rel L2 = {adj_msgb_l2:.3f}")

    out_dir = utils.example_plot_dir(__file__)
    out_path = out_dir / "2d_time_reversal_and_adjoint.png"
    plot_comparison(
        truth,
        data,
        tr_kw_s,
        tr_msgb_s,
        adj_kw_s,
        adj_msgb_s,
        domain=domain,
        sensors=sensors,
        out_path=out_path,
    )
    print(f"Saved figure to {out_path}")


if __name__ == "__main__":
    main()