hyperiax.prebuilt.bffg.continuous_schema

hyperiax.prebuilt.bffg.continuous_schema(d, n_steps)

Field layout for a continuous-edge (SDE) BFFG tree.

Parameters:
  • d (int) – State dimension.

  • n_steps (int) – Number of Euler-Maruyama / RK4 substeps per edge.
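The sketch below illustrates what n_steps controls by simulating one edge with Euler–Maruyama. n_steps substeps yield n_steps + 1 time points, and each substep consumes one row of standard-normal increments, which matches the shapes of the vals and zs fields. It assumes a uniform substep of edge_len / n_steps and user-supplied drift and diffusion functions. The helper euler_maruyama_edge and its arguments are hypothetical and not part of hyperiax, and the RK4 option is not shown.

```python
import numpy as np

def euler_maruyama_edge(x0, edge_len, zs, drift, diffusion):
    """Simulate one edge with Euler-Maruyama (hypothetical helper, not hyperiax API).

    x0:        (d,) state at the parent end (t = 0).
    edge_len:  total time span of the edge.
    zs:        (n_steps, d) i.i.d. standard-normal increments.
    drift:     callable (t, x) -> (d,)
    diffusion: callable (t, x) -> (d, d)
    Returns the trajectory with shape (n_steps + 1, d).
    """
    n_steps, d = zs.shape
    dt = edge_len / n_steps  # assumption: uniform substeps
    vals = np.empty((n_steps + 1, d))
    vals[0] = x0
    for k in range(n_steps):
        t = k * dt
        x = vals[k]
        dw = np.sqrt(dt) * zs[k]  # Brownian increment over [t, t + dt]
        vals[k + 1] = x + drift(t, x) * dt + diffusion(t, x) @ dw
    return vals

rng = np.random.default_rng(0)
d, n_steps = 2, 10
zs = rng.standard_normal((n_steps, d))
vals = euler_maruyama_edge(
    x0=np.zeros(d),
    edge_len=1.0,
    zs=zs,
    drift=lambda t, x: -x,                   # example mean-reverting drift
    diffusion=lambda t, x: 0.5 * np.eye(d),  # example constant diffusion
)
print(vals.shape)  # (11, 2)
```

Because the noise is passed in as zs rather than drawn inside the loop, the same increments can be reused to re-run the forward pass deterministically.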

Return type:

dict[str, tuple[int, ...]]

Returns:

Mapping of field name to per-node trailing shape. Each non-root node’s edge is discretised into n_steps substeps (n_steps + 1 time points). Fields:

  • edge_len: () — time length Δt of this node’s incoming edge, running from t = 0 at the parent end to t = edge_len at this node.

  • vals: (n_steps + 1, d) — forward-sampled trajectory along the edge.

  • zs: (n_steps, d) — i.i.d. standard-normal increments driving the SDE.

  • ptnls: (n_steps + 1, d) / precs: (n_steps + 1, d, d) — canonical (F, H) trajectory along the edge, one entry per time point, with F stored in ptnls and H in precs; produced by the backward filter (BF).

  • ptnl_v: (d,) / prec_v: (d, d) — vertex (fused) canonical message at this node; written by the up-sweep and used as the terminal condition for the edge one level up.

  • anchor: (d,) — child-end linearisation point of this node’s edge (t = edge_len); refined toward the BFFG posterior mean.

  • anchor_pa: (d,) — parent-end linearisation point of this node’s edge (t = 0); cached on the child so the up-sweep can read both endpoints inside children.map(...).

  • log_norm: () — canonical-message log-norm at the vertex; accumulated up the tree, so log_norm[root] is the marginal log-evidence.

  • log_corr: () — per-edge importance log-weight from forward guiding (Theorem 23, eq. (32) / Remark 24).
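The sketch below shows how the returned mapping can be read. It rebuilds a dict with the shapes listed above and allocates zero-initialised storage with a leading node axis in front of each per-node trailing shape. Both continuous_schema_sketch and allocate are hypothetical illustrations, not hyperiax's source. Storing all nodes along a single leading axis is an assumption, and the library may lay out tree data differently.

```python
import numpy as np

def continuous_schema_sketch(d: int, n_steps: int) -> dict[str, tuple[int, ...]]:
    """Rebuild the documented per-node trailing shapes (illustrative only)."""
    n_pts = n_steps + 1  # time points per edge
    return {
        "edge_len": (),
        "vals": (n_pts, d),
        "zs": (n_steps, d),
        "ptnls": (n_pts, d),
        "precs": (n_pts, d, d),
        "ptnl_v": (d,),
        "prec_v": (d, d),
        "anchor": (d,),
        "anchor_pa": (d,),
        "log_norm": (),
        "log_corr": (),
    }

def allocate(schema, n_nodes, dtype=np.float64):
    """Zero-initialised arrays with a leading node axis (hypothetical helper)."""
    return {
        name: np.zeros((n_nodes, *shape), dtype=dtype)
        for name, shape in schema.items()
    }

schema = continuous_schema_sketch(d=3, n_steps=8)
data = allocate(schema, n_nodes=5)
print(data["precs"].shape)     # (5, 9, 3, 3)
print(data["log_norm"].shape)  # (5,)
```

Scalar fields such as edge_len, log_norm and log_corr have trailing shape (), so they become one value per node.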