hyperiax.prebuilt.bffg.continuous_forward_sweep

hyperiax.prebuilt.bffg.continuous_forward_sweep(n_steps, drift_fn, diffusion_fn)

Build the unconditional SDE forward-sampling down-sweep.

For each non-root node, integrates the true SDE \(dX_u = b(u, X_u)\,du + \sigma(u, X_u)\,dW_u\) along the edge with Euler-Maruyama, starting from the parent's terminal value parent.vals[-1]. The Brownian increments are not drawn fresh: each one is \(dW = \sqrt{\Delta t}\, z\), where \(z\) comes from the pre-stored standard-normal draws in node.zs and \(\Delta t\) is the substep size.
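A minimal NumPy sketch of the per-edge Euler-Maruyama update described above. The helper name euler_maruyama_edge, the uniform substep \(\Delta t\) = edge length / n_steps, and returning the starting value as the first row of the trajectory are assumptions made for illustration; this is not hyperiax's own code.

```python
import numpy as np


def euler_maruyama_edge(x0, zs, t0, edge_length, drift_fn, diffusion_fn, params):
    """Integrate dX = b(t, X) dt + sigma(t, X) dW along one edge (sketch).

    x0:  (d,) starting value, i.e. the parent's terminal value.
    zs:  (n_steps, noise) pre-drawn standard-normal samples.
    Returns a (n_steps + 1, d) trajectory whose first row is x0.
    """
    n_steps = zs.shape[0]
    dt = edge_length / n_steps            # uniform substep (assumption)
    xs = [np.asarray(x0, dtype=float)]
    for k in range(n_steps):
        t = t0 + k * dt                   # left endpoint of substep k
        x = xs[-1]
        dW = np.sqrt(dt) * zs[k]          # reuse stored noise: dW = sqrt(dt) * z
        xs.append(x + drift_fn(t, x, params) * dt
                  + diffusion_fn(t, x, params) @ dW)
    return np.stack(xs)
```

Because the noise is passed in rather than sampled inside the function, calling this sketch twice with the same zs reproduces the same path.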

Parameters:
  • n_steps – Number of Euler-Maruyama substeps per edge.

  • drift_fn – Callable drift_fn(t, x, params) -> (d,) returning the true drift \(b(t, x)\).

  • diffusion_fn – Callable diffusion_fn(t, x, params) -> (d, noise) returning the true diffusion matrix \(\sigma(t, x)\).
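For illustration, here are callables with the shapes listed above for a hypothetical Ornstein-Uhlenbeck process. The params keys theta, mu and sigma are invented for this example, since the structure of params is not specified here. NumPy is used to keep the sketch dependency-light; the third function shows that noise need not equal d.

```python
import numpy as np


def ou_drift(t, x, params):
    # b(t, x) = theta * (mu - x); returns shape (d,)
    return params["theta"] * (params["mu"] - x)


def ou_diffusion(t, x, params):
    # sigma(t, x) = sigma * I; returns shape (d, noise) with noise == d
    return params["sigma"] * np.eye(x.shape[0])


def shared_noise_diffusion(t, x, params):
    # One Brownian motion driving all d coordinates; returns shape (d, 1)
    return params["sigma"] * np.ones((x.shape[0], 1))
```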

Return type:

SweepFn

Returns:

A hyperiax.SweepFn that writes the full per-edge trajectory to vals at every non-root node. The root’s vals must be set by the caller (typically via init_continuous_tree()’s root_val).
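The down-sweep semantics (root value supplied by the caller, each child starting from its parent's terminal value, full trajectory stored at every non-root node) can be mimicked on a toy tree. The parent-index list, edge-local time starting at 0, and the helper names _em_edge and toy_forward_sweep are assumptions for illustration; this is not the hyperiax tree or SweepFn API.

```python
import numpy as np


def _em_edge(x0, zs, edge_length, drift_fn, diffusion_fn, params):
    # Euler-Maruyama over one edge; edge-local time starts at 0 (assumption).
    dt = edge_length / zs.shape[0]
    xs = [x0]
    for k, z in enumerate(zs):
        x = xs[-1]
        xs.append(x + drift_fn(k * dt, x, params) * dt
                  + diffusion_fn(k * dt, x, params) @ (np.sqrt(dt) * z))
    return np.stack(xs)


def toy_forward_sweep(parents, edge_lengths, zs, root_val,
                      drift_fn, diffusion_fn, params):
    """Hypothetical down-sweep over a tree given as a parent-index list.

    parents[i] is the parent of node i (-1 marks the root). Nodes are assumed
    to be ordered so that every parent comes before its children.
    """
    vals = [None] * len(parents)
    for i, p in enumerate(parents):
        if p < 0:
            # Root: value supplied by the caller, stored as a one-row trajectory.
            vals[i] = np.atleast_2d(np.asarray(root_val, dtype=float))
        else:
            # Child: start from the parent's terminal value vals[p][-1].
            vals[i] = _em_edge(vals[p][-1], zs[i], edge_lengths[i],
                               drift_fn, diffusion_fn, params)
    return vals


# Usage: root 0 with children 1 and 2; node 3 hangs below node 1.
rng = np.random.default_rng(0)
parents = [-1, 0, 0, 1]
zs = rng.standard_normal((4, 10, 2))  # (nodes, n_steps, noise); zs[0] unused
vals = toy_forward_sweep(parents, [0.0, 1.0, 1.0, 0.5], zs, [0.0, 0.0],
                         lambda t, x, p: -x, lambda t, x, p: np.eye(2), None)
```

Processing nodes in parent-before-child order is what makes a single pass sufficient: by the time a child is visited, its parent's trajectory, and hence vals[p][-1], already exists.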