Rays and autodiff
Static ray-tracing examples for Gaussian beam trajectories and differentiable ray objectives.
2D ray bending
Trace a bundle of initially parallel 2D rays through a smooth analytic speed field \(c(\mathbf{x})\). The script
overlays the ray paths on \(c(\mathbf{x})\) and reports the mean lateral displacement and the maximum
direction change of the rays.
#!/usr/bin/env python
"""
Trace a small fan of 2D rays through a smooth speed field.
The Gaussian beam ray equations bend trajectories toward gradients in the
Hamiltonian `G(x, p) = c(x) |p|`. This example solves those ODEs for a few
parallel rays, overlays the paths on the speed map, and reports compact
diagnostics for the amount of bending.
"""
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from beamax import utils
from beamax.gb import gb_solvers
from beamax.plotter import use_beamax_style
jax.config.update("jax_enable_x64", True)
def speed_field(x: jnp.ndarray) -> jnp.ndarray:
    """Evaluate the smooth analytic 2D sound speed used by the ray example.

    The field is the sum of three terms:

    * a constant background of 1.05;
    * a linear trend in the second coordinate of slope 0.18, which is zero at
      ``x[..., 1] == 0.5``;
    * a Gaussian "lens" centred at (0.46, 0.52) that lowers the speed by up
      to 0.28.

    Args:
        x: Points with the two coordinates on the last axis; any leading batch
            shape is accepted.

    Returns:
        Speed values with the leading (batch) shape of ``x``.
    """
    offset = x - jnp.array([0.46, 0.52])
    lens_weight = jnp.exp(-35.0 * jnp.sum(offset**2, axis=-1))
    trend = 0.18 * (x[..., 1] - 0.5)
    return 1.05 + trend - 0.28 * lens_weight
def solve_rays() -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Integrate a bundle of initially parallel rays through ``speed_field``.

    Sixty rays start on the vertical line ``x = 0.08``, evenly spaced in ``y``
    over ``[0.18, 0.82]``, all with initial momentum ``p = (1, 0)``.

    Returns:
        ``(xt, pt, x0, ts)``:

        * the position and momentum histories returned by
          ``gb_solvers.solve_ODE_base``;
        * the launch points;
        * the 48 output times on ``[0, 0.75]``.
    """
    count = 60
    ts = jnp.linspace(0.0, 0.75, 48)

    starts_y = jnp.linspace(0.18, 0.82, count)
    x0 = jnp.stack([jnp.full((count,), 0.08), starts_y], axis=-1)
    launch = jnp.tile(jnp.array([1.0, 0.0]), (count, 1))

    # Per-ray Gaussian-beam state: M = 1j * I, unit amplitude and mode.
    m_init = 1j * jnp.eye(2)[None, :, :].repeat(count, axis=0)
    amplitude = jnp.ones((count,))
    mode = jnp.ones((count,))

    config = gb_solvers.SolverConfig(rtol=1e-4, atol=1e-6, max_steps=1024)
    xt, pt, _, _ = gb_solvers.solve_ODE_base(
        x0, launch, m_init, amplitude, mode, ts, speed_field, 0.0, config
    )
    return xt, pt, x0, ts
def main() -> None:
    """Trace the ray bundle, save the overlay plot, and print bending diagnostics."""
    plot_dir = utils.example_plot_dir(__file__)
    use_beamax_style()
    xt, pt, x0, ts = solve_rays()

    end_points = xt[:, -1, :]
    start_heading = jnp.arctan2(pt[:, 0, 1], pt[:, 0, 0])
    end_heading = jnp.arctan2(pt[:, -1, 1], pt[:, -1, 0])

    # Sample c on an [ix, iy]-indexed grid; the transpose passed to imshow
    # below puts y on the image rows, as origin="lower" expects.
    axis_samples = jnp.linspace(0.0, 1.0, 128)
    grid = jnp.stack(
        jnp.meshgrid(axis_samples, axis_samples, indexing="ij"), axis=-1
    )
    c_values = speed_field(grid)

    fig, ax = plt.subplots(figsize=(6.5, 5.2))
    image = ax.imshow(
        np.asarray(c_values.T),
        extent=[0.0, 1.0, 0.0, 1.0],
        origin="lower",
        cmap="viridis",
        aspect="equal",
    )
    fig.colorbar(image, ax=ax, label="c(x)")

    ray_color = "#d94801"
    for path in np.asarray(xt):
        ax.plot(path[:, 0], path[:, 1], color=ray_color, lw=1.6, alpha=0.88)

    # Launch points in white first, then arrival points in the ray color.
    marker_styles = (
        (x0, 26, "white"),
        (end_points, 24, ray_color),
    )
    for points, size, face in marker_styles:
        ax.scatter(
            np.asarray(points[:, 0]),
            np.asarray(points[:, 1]),
            s=size,
            color=face,
            edgecolor="black",
            zorder=3,
        )

    ax.set_xlim(0.0, 1.0)
    ax.set_ylim(0.0, 1.0)
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.set_title("2D ray bending in a smooth speed field")
    fig.tight_layout()
    out_path = plot_dir / "2d_ray_bending.png"
    fig.savefig(out_path, dpi=180, bbox_inches="tight")
    plt.close(fig)

    lateral = jnp.abs(end_points[:, 1] - x0[:, 1])
    turning = jnp.abs(end_heading - start_heading)
    mean_lateral_shift = float(jnp.mean(lateral))
    max_angle_change = float(jnp.max(turning))
    c_min = float(c_values.min())
    c_max = float(c_values.max())

    print(f"Rays solved: {xt.shape[0]}, time samples: {ts.shape[0]}")
    print(f"Speed range on plot grid: [{c_min:.3f}, {c_max:.3f}]")
    print(f"Mean lateral displacement: {mean_lateral_shift:.3f}")
    print(f"Max direction change: {max_angle_change:.3f} rad")
    print(f"Saved ray-bending plot to {out_path}")


if __name__ == "__main__":
    main()
2D rays autodiff
This example ports the thesis ray-focusing setup. A small neural field represents \(c(\mathbf{x})\).
Autodiff through the Gaussian beam ray ODE optimizes that field so that a line of upward-launched
rays focuses at a target point. The script saves the before/after rays, the loss curve, and the
\(\Delta c\) panel. This optional example requires `beamax[viz-mpl,autodiff]` for Optax.
#!/usr/bin/env python
"""
Differentiate through 2D Gaussian beam rays.
This example ports the thesis ray-focusing setup to the public gallery. A
small neural field represents `c(x)`, and autodiff through the Gaussian beam
ray ODE optimizes the medium so a fan of rays focuses at a target point.
Example category: Rays and autodiff
Example extras: viz-mpl,autodiff
Example smoke: false
Requires Optax. Install with `pip install "beamax[viz-mpl,autodiff]"`, or
from a checkout with `pip install -e ".[viz-mpl,autodiff]"`.
"""
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
try:
import optax
except ModuleNotFoundError as exc:
print(
"Skipping optional example: Optax is not installed "
'(`pip install "beamax[viz-mpl,autodiff]"`).'
)
raise SystemExit(0) from exc
from beamax import utils
from beamax.gb import gb_solvers
from beamax.plotter import use_beamax_style
jax.config.update("jax_enable_x64", True)
# Square domain used to normalize network inputs and to frame every plot.
XMIN = -10.0
XMAX = 10.0
YMIN = -10.0
YMAX = 10.0
# imshow-style extent: [left, right, bottom, top].
EXTENT = [XMIN, XMAX, YMIN, YMAX]
class NeuralC:
    """Small MLP parametrization of the sound-speed field.

    ``c(x) = base_c + 0.3 * tanh(MLP(x_norm))``, where:

    * ``x_norm`` rescales the ``[XMIN, XMAX] x [YMIN, YMAX]`` domain to
      ``[-1, 1]^2``;
    * the MLP has two ``tanh`` hidden layers.

    The perturbation is therefore bounded to ``(-0.3, 0.3)`` around ``base_c``.

    Initial weights are stored on ``self.params``. ``__call__`` takes the
    parameters explicitly, so callers can evaluate (and differentiate) the
    same architecture at any parameter set.
    """

    def __init__(self, hidden_dim: int = 32, base_c: float = 1.0, seed: int = 42):
        """Draw the initial weights.

        Args:
            hidden_dim: Width of both hidden layers.
            base_c: Background speed returned when the network output is zero.
            seed: PRNG seed for weight initialization. The default (42)
                reproduces the original fixed initialization.
        """
        self.base_c = base_c
        # Weights are 0.1 * N(0, 1); biases start at zero.
        k1, k2, k3 = jax.random.split(jax.random.PRNGKey(seed), 3)
        self.params = {
            "w1": 0.1 * jax.random.normal(k1, (2, hidden_dim)),
            "b1": jnp.zeros(hidden_dim),
            "w2": 0.1 * jax.random.normal(k2, (hidden_dim, hidden_dim)),
            "b2": jnp.zeros(hidden_dim),
            "w3": 0.1 * jax.random.normal(k3, (hidden_dim, 1)),
            "b3": jnp.zeros(1),
        }

    def __call__(self, x: jnp.ndarray, params: dict[str, jnp.ndarray]) -> jnp.ndarray:
        """Evaluate the speed at points ``x``.

        Args:
            x: Points with the two coordinates on the last axis; any leading
                batch shape is accepted.
            params: Parameter dict with the same keys as ``self.params``.

        Returns:
            Speed values with the leading (batch) shape of ``x``.
        """
        # Map the physical domain to [-1, 1] per axis so tanh units start in
        # their responsive range.
        x_norm = jnp.stack(
            [
                2.0 * (x[..., 0] - XMIN) / (XMAX - XMIN) - 1.0,
                2.0 * (x[..., 1] - YMIN) / (YMAX - YMIN) - 1.0,
            ],
            axis=-1,
        )
        h = jnp.tanh(x_norm @ params["w1"] + params["b1"])
        h = jnp.tanh(h @ params["w2"] + params["b2"])
        delta_c = 0.3 * jnp.tanh(h @ params["w3"] + params["b3"])[..., 0]
        return self.base_c + delta_c
def ray_setup():
    """Build the ray-focusing launch configuration.

    Returns:
        ``(x0, p0, m0, a0, mode, ts, focus)``:

        * ``x0``: 20 sources evenly spaced along ``y = -7`` for ``x`` in
          ``[-5, 5]``;
        * ``p0``: momentum ``(0, 1)`` for every ray;
        * ``m0``: ``1j * I`` (2x2) per ray;
        * ``a0`` and ``mode``: unit columns of shape ``(20, 1)``;
        * ``ts``: 300 output times on ``[0, 15]``;
        * ``focus``: the target point ``(0, 5)``.
    """
    n_rays, dim = 20, 2

    xs = jnp.linspace(-5.0, 5.0, n_rays)
    x0 = jnp.stack([xs, jnp.full(n_rays, -7.0)], axis=-1)
    p0 = jnp.stack([jnp.zeros(n_rays), jnp.ones(n_rays)], axis=-1)

    m0 = 1j * jnp.einsum("bd,dj->bdj", jnp.ones((n_rays, dim)), jnp.eye(dim))
    a0 = jnp.ones((n_rays, 1))
    mode = jnp.ones((n_rays, 1))

    ts = jnp.linspace(0.0, 15.0, 300)
    focus = jnp.array([0.0, 5.0])
    return x0, p0, m0, a0, mode, ts, focus
def solve_rays(
    params: dict[str, jnp.ndarray],
    param_c: NeuralC,
    x0: jnp.ndarray,
    p0: jnp.ndarray,
    m0: jnp.ndarray,
    a0: jnp.ndarray,
    mode: jnp.ndarray,
    ts: jnp.ndarray,
) -> jnp.ndarray:
    """Trace rays through the neural speed field evaluated at ``params``.

    Returns:
        The position history (first output of ``gb_solvers.solve_ODE_base``).
        The momentum and beam outputs are discarded.
    """
    speed = lambda points: param_c(points, params)  # noqa: E731
    positions, _, _, _ = gb_solvers.solve_ODE_base(
        x0, p0, m0, a0, mode, ts, speed, 0.0, None
    )
    return positions
def focusing_loss(
    params: dict[str, jnp.ndarray],
    param_c: NeuralC,
    x0: jnp.ndarray,
    p0: jnp.ndarray,
    m0: jnp.ndarray,
    a0: jnp.ndarray,
    mode: jnp.ndarray,
    ts: jnp.ndarray,
    focus: jnp.ndarray,
):
    """Score how well rays traced through ``params`` focus on ``focus``.

    The total is ``focus + 0.5 * spread + smooth``, where:

    * ``focus`` is the mean over rays of each ray's closest approach to
      ``focus`` across all output times;
    * ``spread`` is the across-ray standard deviation of positions (averaged
      over both coordinates) at the output time where the mean ray-to-focus
      distance is smallest;
    * ``smooth`` is ``0.01`` times the sum over parameter arrays of their
      mean squared value. This is an L2 weight-magnitude penalty on the MLP
      parameters, not a spatial-roughness penalty on ``c`` itself.

    Returns:
        ``(total_loss, aux)``. ``aux`` holds the ``"focus"``, ``"spread"`` and
        ``"smooth"`` terms and the traced positions ``"xt"``; this is the
        layout expected by ``jax.value_and_grad(..., has_aux=True)``.
    """
    xt = solve_rays(params, param_c, x0, p0, m0, a0, mode, ts)

    # Distance of every (ray, time) sample to the focus. A plain
    # ``jnp.linalg.norm`` has a NaN gradient at exactly zero distance, which
    # would poison the optimizer state. The tiny epsilon keeps the gradient
    # finite while changing the value by at most ~1e-6.
    offset = xt - focus[None, None, :]
    dist_to_focus = jnp.sqrt(jnp.sum(offset**2, axis=-1) + 1e-12)

    min_dist = jnp.min(dist_to_focus, axis=1)
    focus_loss = jnp.mean(min_dist)

    # Pick the output time at which the bundle is, on average, closest to the
    # focus. The argmin index itself carries no gradient; gradients flow
    # through the positions sampled at that index.
    focus_time_idx = jnp.argmin(jnp.mean(dist_to_focus, axis=0))
    rays_at_focus = xt[:, focus_time_idx, :]
    spread_loss = jnp.mean(jnp.std(rays_at_focus, axis=0))

    smooth_loss = 0.01 * sum(jnp.mean(value**2) for value in params.values())
    total_loss = focus_loss + 0.5 * spread_loss + smooth_loss
    return total_loss, {
        "focus": focus_loss,
        "spread": spread_loss,
        "smooth": smooth_loss,
        "xt": xt,
    }
def speed_map(
    param_c: "NeuralC",
    params: dict[str, jnp.ndarray],
    nx: int = 200,
    ny: int = 200,
    extent=None,
) -> jnp.ndarray:
    """Sample the speed field on a regular grid for plotting.

    The map is indexed ``[ix, iy]`` (shape ``(nx, ny)``). The plotting helpers
    pass ``c_map.T`` to ``imshow(origin="lower")``, which needs ``y`` on the
    image rows after that transpose. An ``indexing="xy"`` grid is already
    ``[iy, ix]``, and transposing it would draw the field mirrored across the
    diagonal relative to the ray paths.

    Args:
        param_c: Callable ``param_c(points, params)`` returning the speed at
            ``points`` (coordinates on the last axis), e.g. a ``NeuralC``.
        params: Parameters forwarded to ``param_c``.
        nx: Number of samples along ``x``.
        ny: Number of samples along ``y``.
        extent: Optional ``(xmin, xmax, ymin, ymax)`` bounds; defaults to the
            module-level ``EXTENT``.

    Returns:
        Speed values of shape ``(nx, ny)``.
    """
    xmin, xmax, ymin, ymax = EXTENT if extent is None else extent
    xg = jnp.linspace(xmin, xmax, nx)
    yg = jnp.linspace(ymin, ymax, ny)
    xx, yy = jnp.meshgrid(xg, yg, indexing="ij")
    grid_points = jnp.stack([xx, yy], axis=-1)
    return param_c(grid_points, params)
def plot_rays_before_after(
    out_path,
    c_init_map,
    c_opt_map,
    xt_init,
    xt_final,
    x0,
    focus,
) -> None:
    """Save side-by-side speed maps with the rays traced through each.

    Both panels share one color range (the min/max over both maps), so the
    single colorbar attached to the right panel applies to the left one too.
    Sources are marked with red dots and the focus with a red star.
    """
    lo = float(jnp.minimum(jnp.min(c_init_map), jnp.min(c_opt_map)))
    hi = float(jnp.maximum(jnp.max(c_init_map), jnp.max(c_opt_map)))

    fig, axes = plt.subplots(1, 2, figsize=(12, 5))
    maps = (c_init_map, c_opt_map)
    trajectories = (xt_init, xt_final)
    titles = (
        r"$c_{\mathrm{init}}(\mathbf{x})$",
        r"$c_{\mathrm{opt}}(\mathbf{x})$",
    )

    image = None
    for ax, c_map, rays, title in zip(axes, maps, trajectories, titles):
        image = ax.imshow(
            np.asarray(c_map.T),
            extent=EXTENT,
            origin="lower",
            cmap="viridis",
            aspect="equal",
            vmin=lo,
            vmax=hi,
        )
        for path in np.asarray(rays):
            ax.plot(path[:, 0], path[:, 1], "w-", lw=1.0, alpha=0.5)
        ax.scatter(
            np.asarray(x0[:, 0]),
            np.asarray(x0[:, 1]),
            s=20,
            c="red",
            marker="o",
            label="sources",
        )
        ax.scatter(
            float(focus[0]),
            float(focus[1]),
            s=100,
            c="red",
            marker="*",
            label="focus",
        )
        ax.set_title(title)
        ax.legend(frameon=True, fancybox=True, loc="lower right")
        ax.set_xticks([])
        ax.set_yticks([])

    # One shared colorbar, placed next to the right-hand panel.
    fig.colorbar(image, ax=axes[1])
    fig.tight_layout()
    fig.savefig(out_path, dpi=300, bbox_inches="tight")
    plt.close(fig)
def plot_loss(out_path, loss_history: list[float]) -> None:
    """Save the per-iteration optimization loss curve to ``out_path``."""
    fig, ax = plt.subplots(figsize=(8, 4))
    ax.plot(loss_history)
    ax.set(xlabel="iteration", ylabel=r"$\mathcal{L}$")
    ax.grid(True, alpha=0.3)
    fig.tight_layout()
    fig.savefig(out_path, dpi=300, bbox_inches="tight")
    plt.close(fig)
def plot_speed_delta(out_path, c_init_map, c_opt_map, focus) -> None:
    """Save the optimized-minus-initial speed change on a symmetric scale.

    The color limits are ``+/- max|delta c|``, so zero change maps to the
    middle of the diverging ``RdBu_r`` colormap. The focus is marked with a
    black star.
    """
    delta = c_opt_map - c_init_map
    bound = float(jnp.max(jnp.abs(delta)))

    fig, ax = plt.subplots(figsize=(6, 5))
    image = ax.imshow(
        np.asarray(delta.T),
        extent=EXTENT,
        origin="lower",
        cmap="RdBu_r",
        aspect="equal",
        vmin=-bound,
        vmax=bound,
    )
    ax.scatter(float(focus[0]), float(focus[1]), s=100, c="black", marker="*")
    ax.set_title(r"$\Delta c(\mathbf{x})$")
    ax.set_xticks([])
    ax.set_yticks([])
    fig.colorbar(image, ax=ax)
    fig.tight_layout()
    fig.savefig(out_path, dpi=300, bbox_inches="tight")
    plt.close(fig)
def main() -> None:
    """Optimize the neural speed field for ray focusing and save the figures.

    Runs 300 optimizer steps on ``focusing_loss``: Adam with global-norm
    clipping at 1.0 and a learning rate of 0.02, decayed by 0.9 every 50
    steps. It then plots the initial field and rays against the lowest-loss
    field and the rays traced through that same field.
    """
    plot_dir = utils.example_plot_dir(__file__)
    use_beamax_style()
    x0, p0, m0, a0, mode, ts, focus = ray_setup()
    param_c = NeuralC(hidden_dim=32)
    params = param_c.params

    # Differentiate only with respect to the parameter pytree; the launch
    # configuration is closed over as constants.
    loss_grad = jax.jit(
        jax.value_and_grad(
            lambda current_params: focusing_loss(
                current_params,
                param_c,
                x0,
                p0,
                m0,
                a0,
                mode,
                ts,
                focus,
            ),
            has_aux=True,
        )
    )
    xt_init = solve_rays(params, param_c, x0, p0, m0, a0, mode, ts)

    lr_schedule = optax.exponential_decay(
        init_value=0.02,
        transition_steps=50,
        decay_rate=0.9,
    )
    optimizer = optax.chain(optax.clip_by_global_norm(1.0), optax.adam(lr_schedule))
    opt_state = optimizer.init(params)

    loss_history: list[float] = []
    best_loss = float("inf")
    best_params = params
    best_xt = xt_init
    num_iters = 300
    print("Optimizing speed of sound field for ray focusing...")
    for step in range(num_iters):
        (loss_value, aux), grads = loss_grad(params)
        loss_float = float(loss_value)
        loss_history.append(loss_float)
        # Record the best state *before* applying the update. ``aux["xt"]``
        # and ``loss_value`` were computed from the current ``params``, so
        # the plotted optimized field, its rays, and the reported best loss
        # all describe the same medium. JAX arrays are immutable, so holding
        # a reference to ``params`` here is safe.
        if loss_float < best_loss:
            best_loss = loss_float
            best_params = params
            best_xt = aux["xt"]
        updates, opt_state = optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)
        if step % 50 == 0:
            print(
                f"iter {step:03d} | loss {loss_float:.6f} | "
                f"focus {float(aux['focus']):.4f} | spread {float(aux['spread']):.4f}"
            )

    c_init_map = speed_map(param_c, param_c.params)
    c_opt_map = speed_map(param_c, best_params)
    rays_path = plot_dir / "focusing_rays_before_after.png"
    loss_path = plot_dir / "focusing_loss_convergence.png"
    delta_path = plot_dir / "focusing_sound_speed_delta.png"
    plot_rays_before_after(
        rays_path,
        c_init_map,
        c_opt_map,
        xt_init,
        best_xt,
        x0,
        focus,
    )
    plot_loss(loss_path, loss_history)
    plot_speed_delta(delta_path, c_init_map, c_opt_map, focus)
    print("Optimization complete.")
    print(f"Final focusing loss: {best_loss:.6f}")
    print(f"Saved figures to {plot_dir.resolve()}")


if __name__ == "__main__":
    main()