Skip to content

Rays and autodiff

Static ray-tracing examples for Gaussian beam trajectories and differentiable ray objectives.


2D ray bending

Trace a fan of initially parallel 2D rays through a smooth analytic speed field. The script overlays the ray paths on c(x) and reports the mean lateral displacement and the maximum change in ray direction.

Open In Colab

#!/usr/bin/env python
"""
Trace a small fan of 2D rays through a smooth speed field.

The Gaussian beam ray equations bend trajectories toward gradients in the
Hamiltonian `G(x, p) = c(x) |p|`. This example solves those ODEs for a few
parallel rays, overlays the paths on the speed map, and reports compact
diagnostics for the amount of bending.
"""

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np

from beamax import utils
from beamax.gb import gb_solvers
from beamax.plotter import use_beamax_style


# Enable 64-bit (double precision) JAX arrays for everything in this script.
jax.config.update("jax_enable_x64", True)


def speed_field(x: jnp.ndarray) -> jnp.ndarray:
    """Evaluate the analytic 2D speed map at points ``x``.

    The field is a constant background of 1.05, plus a linear tilt in ``y``
    (zero at ``y = 0.5``), minus a Gaussian "slow lens" of depth 0.28 centred
    at ``(0.46, 0.52)``. ``x`` may carry any leading batch axes; its last axis
    holds the ``(x, y)`` coordinates, and that axis is reduced away.
    """
    center = jnp.array([0.46, 0.52])
    sq_radius = jnp.sum((x - center) ** 2, axis=-1)
    lens_dip = 0.28 * jnp.exp(-35.0 * sq_radius)
    tilt = 0.18 * (x[..., 1] - 0.5)
    background = 1.05
    return background + tilt - lens_dip


def solve_rays() -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Launch a column of parallel rays heading in +x and integrate them.

    Returns:
        ``(xt, pt, x0, ts)``: the first two outputs of
        ``gb_solvers.solve_ODE_base`` (used downstream as ray positions and
        momenta), the launch points, and the output time grid.
    """
    count = 60
    times = jnp.linspace(0.0, 0.75, 48)

    # Launch points sit on the vertical line x = 0.08, spread over y in
    # [0.18, 0.82]; every ray starts with momentum (1, 0).
    launch_y = jnp.linspace(0.18, 0.82, count)
    launch = jnp.stack([jnp.full((count,), 0.08), launch_y], axis=-1)
    heading = jnp.tile(jnp.array([1.0, 0.0]), (count, 1))
    beam = 1j * jnp.repeat(jnp.eye(2)[None, :, :], count, axis=0)
    amplitude = jnp.ones((count,))
    beam_mode = jnp.ones((count,))

    config = gb_solvers.SolverConfig(rtol=1e-4, atol=1e-6, max_steps=1024)
    solution = gb_solvers.solve_ODE_base(
        launch, heading, beam, amplitude, beam_mode, times, speed_field, 0.0, config
    )
    positions, momenta, _, _ = solution
    return positions, momenta, launch, times


def main() -> None:
    """Trace the ray fan, plot it over c(x), and print bending diagnostics.

    Saves ``2d_ray_bending.png`` to the example plot directory. It then prints
    the number of rays and time samples, the speed range on the plot grid, the
    mean |Δy| between launch and final positions, and the largest absolute
    change in ray heading.
    """
    plot_dir = utils.example_plot_dir(__file__)
    use_beamax_style()

    xt, pt, x0, ts = solve_rays()
    final = xt[:, -1, :]
    direction_angles = jnp.arctan2(pt[:, -1, 1], pt[:, -1, 0])
    initial_angles = jnp.arctan2(pt[:, 0, 1], pt[:, 0, 0])
    # arctan2 returns angles in [-pi, pi], so the raw difference can reach
    # ~2*pi when a heading crosses the branch cut. Wrap it into [-pi, pi) so
    # the reported change is the actual turning angle.
    angle_change = (
        (direction_angles - initial_angles + jnp.pi) % (2.0 * jnp.pi) - jnp.pi
    )

    # Sample c(x) on a regular grid; "ij" indexing puts x on axis 0, hence the
    # transpose before imshow (which expects rows = y).
    grid_n = 128
    x = jnp.linspace(0.0, 1.0, grid_n)
    y = jnp.linspace(0.0, 1.0, grid_n)
    xy = jnp.stack(jnp.meshgrid(x, y, indexing="ij"), axis=-1)
    c_values = speed_field(xy)

    fig, ax = plt.subplots(figsize=(6.5, 5.2))
    im = ax.imshow(
        np.asarray(c_values.T),
        extent=[0.0, 1.0, 0.0, 1.0],
        origin="lower",
        cmap="viridis",
        aspect="equal",
    )
    fig.colorbar(im, ax=ax, label="c(x)")

    ray_color = "#d94801"
    for ray in np.asarray(xt):
        ax.plot(ray[:, 0], ray[:, 1], color=ray_color, lw=1.6, alpha=0.88)
    # White markers: launch points; orange markers: final positions.
    ax.scatter(
        np.asarray(x0[:, 0]),
        np.asarray(x0[:, 1]),
        s=26,
        color="white",
        edgecolor="black",
        zorder=3,
    )
    ax.scatter(
        np.asarray(final[:, 0]),
        np.asarray(final[:, 1]),
        s=24,
        color=ray_color,
        edgecolor="black",
        zorder=3,
    )
    ax.set_xlim(0.0, 1.0)
    ax.set_ylim(0.0, 1.0)
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.set_title("2D ray bending in a smooth speed field")
    fig.tight_layout()

    out_path = plot_dir / "2d_ray_bending.png"
    fig.savefig(out_path, dpi=180, bbox_inches="tight")
    plt.close(fig)

    mean_lateral_shift = float(jnp.mean(jnp.abs(final[:, 1] - x0[:, 1])))
    max_angle_change = float(jnp.max(jnp.abs(angle_change)))
    print(f"Rays solved: {xt.shape[0]}, time samples: {ts.shape[0]}")
    print(
        f"Speed range on plot grid: [{float(c_values.min()):.3f}, {float(c_values.max()):.3f}]"
    )
    print(f"Mean lateral displacement: {mean_lateral_shift:.3f}")
    print(f"Max direction change: {max_angle_change:.3f} rad")
    print(f"Saved ray-bending plot to {out_path}")


if __name__ == "__main__":
    main()

2D rays autodiff

Port the thesis ray-focusing example: represent c(x) with a small neural field, optimize it with autodiff through the Gaussian beam ray ODE, and save the before/after ray plots, the loss curve, and the \(\Delta c\) map. This optional example needs Optax, which is installed with the `beamax[viz-mpl,autodiff]` extras.

Open In Colab

#!/usr/bin/env python
"""
Differentiate through 2D Gaussian beam rays.

This example ports the thesis ray-focusing setup to the public gallery. A
small neural field represents `c(x)`, and autodiff through the Gaussian beam
ray ODE optimizes the medium so a fan of rays focuses at a target point.

Example category: Rays and autodiff
Example extras: viz-mpl,autodiff
Example smoke: false

Requires Optax. Install with `pip install "beamax[viz-mpl,autodiff]"`, or
from a checkout with `pip install -e ".[viz-mpl,autodiff]"`.
"""

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np

try:
    import optax
except ModuleNotFoundError as exc:
    print(
        "Skipping optional example: Optax is not installed "
        '(`pip install "beamax[viz-mpl,autodiff]"`).'
    )
    raise SystemExit(0) from exc

from beamax import utils
from beamax.gb import gb_solvers
from beamax.plotter import use_beamax_style


# Enable 64-bit (double precision) JAX arrays for everything in this script.
jax.config.update("jax_enable_x64", True)

# Square computational window [-10, 10] x [-10, 10].
XMIN = -10.0
XMAX = 10.0
YMIN = -10.0
YMAX = 10.0
# Matplotlib imshow extent ordering: (left, right, bottom, top).
EXTENT = [XMIN, XMAX, YMIN, YMAX]


class NeuralC:
    """Small MLP parametrization of the sound-speed field.

    ``c(x) = base_c + 0.3 * tanh(MLP(x_norm))``, where ``x_norm`` maps the
    ``[XMIN, XMAX] x [YMIN, YMAX]`` box affinely onto ``[-1, 1]^2``. The tanh
    output keeps the field strictly within ``base_c ± 0.3``.

    Initial weights are stored in ``self.params``. ``__call__`` takes the
    parameter dict explicitly, so one instance can be evaluated (and
    differentiated) at arbitrary parameter values.

    Args:
        hidden_dim: Width of both hidden layers.
        base_c: Background speed that the MLP perturbs.
        seed: Seed for the weight initialization. The default of 42
            reproduces the original example.
    """

    def __init__(self, hidden_dim: int = 32, base_c: float = 1.0, seed: int = 42):
        self.base_c = base_c
        key = jax.random.PRNGKey(seed)
        k1, k2, k3 = jax.random.split(key, 3)
        # Small (0.1-scaled) Gaussian weights and zero biases, so the initial
        # perturbation around ``base_c`` is modest.
        self.params = {
            "w1": 0.1 * jax.random.normal(k1, (2, hidden_dim)),
            "b1": jnp.zeros(hidden_dim),
            "w2": 0.1 * jax.random.normal(k2, (hidden_dim, hidden_dim)),
            "b2": jnp.zeros(hidden_dim),
            "w3": 0.1 * jax.random.normal(k3, (hidden_dim, 1)),
            "b3": jnp.zeros(1),
        }

    def __call__(self, x: jnp.ndarray, params: dict[str, jnp.ndarray]) -> jnp.ndarray:
        """Evaluate c at points ``x`` using the weights in ``params``.

        ``x`` may have any leading batch axes; its last axis holds the
        ``(x, y)`` coordinates. The result has ``x``'s shape without that axis.
        """
        x_norm = jnp.stack(
            [
                2.0 * (x[..., 0] - XMIN) / (XMAX - XMIN) - 1.0,
                2.0 * (x[..., 1] - YMIN) / (YMAX - YMIN) - 1.0,
            ],
            axis=-1,
        )
        h = jnp.tanh(x_norm @ params["w1"] + params["b1"])
        h = jnp.tanh(h @ params["w2"] + params["b2"])
        delta_c = 0.3 * jnp.tanh(h @ params["w3"] + params["b3"])[..., 0]
        return self.base_c + delta_c


def ray_setup():
    """Assemble the initial conditions and target for the focusing problem.

    Twenty rays start on the horizontal segment y = -7, x in [-5, 5], all with
    momentum (0, 1) (straight up). The target focus is (0, 5), and the solver
    samples 300 times in [0, 15].

    Returns:
        ``(x0, p0, m0, a0, mode, ts, focus)``. ``m0`` is ``1j`` times a stack of
        2x2 identity matrices; ``a0`` and ``mode`` are ones of shape
        ``(n_rays, 1)``.
    """
    n_rays, dim = 20, 2

    positions = jnp.column_stack(
        (jnp.linspace(-5.0, 5.0, n_rays), jnp.full(n_rays, -7.0))
    )
    momenta = jnp.column_stack((jnp.zeros(n_rays), jnp.ones(n_rays)))
    target = jnp.array([0.0, 5.0])

    # Per-ray diagonal entries broadcast against the identity -> (n_rays, dim, dim).
    diag_entries = jnp.ones((n_rays, dim))
    beam_matrices = 1j * (diag_entries[:, :, None] * jnp.eye(dim))
    amplitudes = jnp.ones((n_rays, 1))
    modes = jnp.ones((n_rays, 1))
    times = jnp.linspace(0.0, 15.0, 300)
    return positions, momenta, beam_matrices, amplitudes, modes, times, target


def solve_rays(
    params: dict[str, jnp.ndarray],
    param_c: NeuralC,
    x0: jnp.ndarray,
    p0: jnp.ndarray,
    m0: jnp.ndarray,
    a0: jnp.ndarray,
    mode: jnp.ndarray,
    ts: jnp.ndarray,
) -> jnp.ndarray:
    """Trace the ray bundle through the medium described by ``params``.

    ``param_c`` is bound to ``params`` so the solver receives a plain
    ``c(x)`` callable, which lets the caller differentiate the traced rays
    with respect to ``params``. The solver is called with ``None`` as its
    config argument. Only the first solver output (the ray positions) is
    returned.
    """

    def medium(points):
        return param_c(points, params)

    solution = gb_solvers.solve_ODE_base(
        x0, p0, m0, a0, mode, ts, medium, 0.0, None
    )
    positions, _, _, _ = solution
    return positions


def focusing_loss(
    params: dict[str, jnp.ndarray],
    param_c: NeuralC,
    x0: jnp.ndarray,
    p0: jnp.ndarray,
    m0: jnp.ndarray,
    a0: jnp.ndarray,
    mode: jnp.ndarray,
    ts: jnp.ndarray,
    focus: jnp.ndarray,
):
    """Score how well the traced rays focus on ``focus``.

    The total is ``focus + 0.5 * spread + smooth``, where:

    * ``focus``: mean over rays of each ray's closest approach to the target.
    * ``spread``: at the time sample where the bundle is, on average, nearest
      the target, the per-coordinate standard deviation of ray positions,
      averaged over coordinates.
    * ``smooth``: ``0.01 * sum(mean(w**2))`` over all MLP parameter arrays,
      i.e. an L2 weight penalty (not a spatial roughness measure of c).

    Returns:
        ``(total, aux)``, where ``aux`` holds the three terms under the keys
        ``"focus"``, ``"spread"`` and ``"smooth"``, plus the traced positions
        under ``"xt"``.
    """
    trajectories = solve_rays(params, param_c, x0, p0, m0, a0, mode, ts)

    # Axis 0 indexes rays and axis 1 indexes time samples (see the indexing
    # below); the last axis is reduced by the norm.
    distances = jnp.linalg.norm(trajectories - focus[None, None, :], axis=-1)
    closest_approach = jnp.min(distances, axis=1)
    focus_term = jnp.mean(closest_approach)

    best_time = jnp.argmin(jnp.mean(distances, axis=0))
    snapshot = trajectories[:, best_time, :]
    spread_term = jnp.mean(jnp.std(snapshot, axis=0))

    l2_term = 0.01 * sum(jnp.mean(weight**2) for weight in params.values())

    total = focus_term + 0.5 * spread_term + l2_term
    aux = {
        "focus": focus_term,
        "spread": spread_term,
        "smooth": l2_term,
        "xt": trajectories,
    }
    return total, aux


def speed_map(
    param_c: NeuralC,
    params: dict[str, jnp.ndarray],
    nx: int = 200,
    ny: int = 200,
) -> jnp.ndarray:
    """Evaluate the speed field on a regular ``nx`` x ``ny`` grid over the domain.

    Returns:
        Array of shape ``(nx, ny)`` indexed as ``[ix, iy]``, with x along
        axis 0. The plotting helpers display it with ``imshow(c_map.T,
        origin="lower")``, and that transpose requires x on axis 0.
    """
    xg = jnp.linspace(XMIN, XMAX, nx)
    yg = jnp.linspace(YMIN, YMAX, ny)
    # "ij" indexing puts x on axis 0. With "xy" indexing y would be on axis 0,
    # so the transposed image would be mirrored across the diagonal relative
    # to the ray paths overlaid on it.
    xx, yy = jnp.meshgrid(xg, yg, indexing="ij")
    grid_points = jnp.stack([xx, yy], axis=-1)
    return param_c(grid_points, params)


def plot_rays_before_after(
    out_path,
    c_init_map,
    c_opt_map,
    xt_init,
    xt_final,
    x0,
    focus,
) -> None:
    """Save side-by-side panels of the initial and optimized media with rays.

    Both panels share one colour range (the min/max over both maps), so they
    can be compared directly. Each panel shows its speed map, the traced ray
    paths in white, the sources as red dots and the focus as a red star. The
    colorbar is attached to the right-hand panel, and the figure is saved to
    ``out_path`` at 300 dpi and then closed.
    """
    lo = float(jnp.minimum(jnp.min(c_init_map), jnp.min(c_opt_map)))
    hi = float(jnp.maximum(jnp.max(c_init_map), jnp.max(c_opt_map)))

    fig, axes = plt.subplots(1, 2, figsize=(12, 5))
    panel_specs = [
        (c_init_map, xt_init, r"$c_{\mathrm{init}}(\mathbf{x})$"),
        (c_opt_map, xt_final, r"$c_{\mathrm{opt}}(\mathbf{x})$"),
    ]

    image = None
    for axis, (field, paths, label) in zip(axes, panel_specs):
        image = axis.imshow(
            np.asarray(field.T),
            extent=EXTENT,
            origin="lower",
            cmap="viridis",
            aspect="equal",
            vmin=lo,
            vmax=hi,
        )
        for path in np.asarray(paths):
            axis.plot(path[:, 0], path[:, 1], "w-", lw=1.0, alpha=0.5)
        axis.scatter(
            np.asarray(x0[:, 0]),
            np.asarray(x0[:, 1]),
            s=20,
            c="red",
            marker="o",
            label="sources",
        )
        axis.scatter(
            float(focus[0]),
            float(focus[1]),
            s=100,
            c="red",
            marker="*",
            label="focus",
        )
        axis.set_title(label)
        axis.legend(frameon=True, fancybox=True, loc="lower right")
        axis.set_xticks([])
        axis.set_yticks([])

    # Both images share vmin/vmax, so the last image's colorbar covers both.
    fig.colorbar(image, ax=axes[1])
    fig.tight_layout()
    fig.savefig(out_path, dpi=300, bbox_inches="tight")
    plt.close(fig)


def plot_loss(out_path, loss_history: list[float]) -> None:
    """Save the loss-vs-iteration curve to ``out_path`` (300 dpi) and close it."""
    fig, ax = plt.subplots(figsize=(8, 4))
    ax.plot(loss_history)
    ax.set_xlabel("iteration")
    ax.set_ylabel(r"$\mathcal{L}$")
    ax.grid(True, alpha=0.3)
    fig.tight_layout()
    fig.savefig(out_path, dpi=300, bbox_inches="tight")
    plt.close(fig)


def plot_speed_delta(out_path, c_init_map, c_opt_map, focus) -> None:
    """Save an image of ``c_opt_map - c_init_map`` with a symmetric colour scale.

    The diverging ``RdBu_r`` map is centred on zero and spans
    ``[-max|Δc|, +max|Δc|]``. The focus is marked with a black star. The
    figure is saved to ``out_path`` at 300 dpi and then closed.
    """
    delta = c_opt_map - c_init_map
    bound = float(jnp.max(jnp.abs(delta)))

    fig, ax = plt.subplots(figsize=(6, 5))
    image = ax.imshow(
        np.asarray(delta.T),
        extent=EXTENT,
        origin="lower",
        cmap="RdBu_r",
        aspect="equal",
        vmin=-bound,
        vmax=bound,
    )
    ax.scatter(float(focus[0]), float(focus[1]), s=100, c="black", marker="*")
    ax.set_title(r"$\Delta c(\mathbf{x})$")
    ax.set_xticks([])
    ax.set_yticks([])
    fig.colorbar(image, ax=ax)
    fig.tight_layout()
    fig.savefig(out_path, dpi=300, bbox_inches="tight")
    plt.close(fig)


def main() -> None:
    """Optimize the neural speed field for ray focusing and save the figures.

    Runs 300 optimizer steps on ``focusing_loss``: Adam with global-norm
    gradient clipping and an exponentially decaying learning rate. It keeps the
    lowest-loss parameters seen together with the rays traced for them, then
    writes three figures to the example plot directory: rays before/after, the
    loss curve, and the speed change Δc.
    """
    plot_dir = utils.example_plot_dir(__file__)
    use_beamax_style()

    x0, p0, m0, a0, mode, ts, focus = ray_setup()
    param_c = NeuralC(hidden_dim=32)
    params = param_c.params

    # Only the MLP parameters vary; all ray initial conditions are closed over.
    loss_grad = jax.jit(
        jax.value_and_grad(
            lambda current_params: focusing_loss(
                current_params,
                param_c,
                x0,
                p0,
                m0,
                a0,
                mode,
                ts,
                focus,
            ),
            has_aux=True,
        )
    )

    xt_init = solve_rays(params, param_c, x0, p0, m0, a0, mode, ts)
    lr_schedule = optax.exponential_decay(
        init_value=0.02,
        transition_steps=50,
        decay_rate=0.9,
    )
    optimizer = optax.chain(optax.clip_by_global_norm(1.0), optax.adam(lr_schedule))
    opt_state = optimizer.init(params)

    loss_history: list[float] = []
    best_loss = float("inf")
    best_params = params
    best_xt = xt_init

    num_iters = 300
    print("Optimizing speed of sound field for ray focusing...")
    for step in range(num_iters):
        (loss_value, aux), grads = loss_grad(params)
        loss_float = float(loss_value)
        loss_history.append(loss_float)

        # Snapshot BEFORE applying the update: ``loss_value`` and ``aux["xt"]``
        # were computed from the current ``params``. Storing these params keeps
        # the plotted optimized medium, the rays drawn on it, and the reported
        # loss consistent. (Storing the post-update params would put the map
        # one optimizer step ahead of its rays.)
        if loss_float < best_loss:
            best_loss = loss_float
            best_params = params
            best_xt = aux["xt"]

        updates, opt_state = optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)

        if step % 50 == 0:
            print(
                f"iter {step:03d} | loss {loss_float:.6f} | "
                f"focus {float(aux['focus']):.4f} | spread {float(aux['spread']):.4f}"
            )

    # ``param_c.params`` still holds the initial weights; the loop only rebinds
    # the local ``params``.
    c_init_map = speed_map(param_c, param_c.params)
    c_opt_map = speed_map(param_c, best_params)

    rays_path = plot_dir / "focusing_rays_before_after.png"
    loss_path = plot_dir / "focusing_loss_convergence.png"
    delta_path = plot_dir / "focusing_sound_speed_delta.png"

    plot_rays_before_after(
        rays_path,
        c_init_map,
        c_opt_map,
        xt_init,
        best_xt,
        x0,
        focus,
    )
    plot_loss(loss_path, loss_history)
    plot_speed_delta(delta_path, c_init_map, c_opt_map, focus)

    print("Optimization complete.")
    print(f"Best focusing loss: {best_loss:.6f}")
    print(f"Saved figures to {plot_dir.resolve()}")


if __name__ == "__main__":
    main()