hyperiax: JAX tree traversals
hyperiax is a small, pure-JAX library for message passing on phylogenetic and
other rooted trees. Trees are immutable JAX pytrees, and sweeps are
decorator-defined Tree -> Tree transforms that compose cleanly under @jax.jit
and jax.lax.scan.
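The pytree design is what lets sweeps compose with JAX transformations. The sketch below illustrates the idea with plain JAX only. `MiniTree`, `shift`, and `body` are hypothetical names, not the hyperiax API. A `NamedTuple` is a pytree out of the box, so a pure Tree -> Tree function can be jitted and carried through `jax.lax.scan`, and every step returns a new tree instead of mutating the old one.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class MiniTree(NamedTuple):
    """Hypothetical stand-in for hx.Tree (not the hyperiax class).

    NamedTuples are pytrees automatically, so jit and scan can trace them.
    """
    parent: jax.Array  # topology: parent index per node, -1 for the root
    value: jax.Array   # per-node data, shape (n_nodes, d)


@jax.jit
def shift(tree: MiniTree) -> MiniTree:
    # Toy Tree -> Tree transform (ignores topology): returns a new tree.
    return tree._replace(value=tree.value + 1.0)


def body(tree: MiniTree, _):
    new_tree = shift(tree)
    # Carry the new tree to the next step and emit one summary per step.
    return new_tree, new_tree.value.sum()


tree0 = MiniTree(parent=jnp.array([-1, 0, 0]), value=jnp.zeros((3, 2)))
final, sums = jax.lax.scan(body, tree0, xs=None, length=4)
# sums holds 6, 12, 18, 24; tree0 itself is left unchanged.
```

Because `shift` never writes into its input, `tree0.value` is still all zeros after the scan, which is the property that makes such transforms safe to reuse inside traced code.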
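```python
import jax, jax.numpy as jnp
import hyperiax as hx

# Complete binary tree of depth 3 (8 leaves).
topo = hx.symmetric_topology(depth=3, degree=2)
# Per-node schema: a single field "value" of shape (2,).
tree = hx.Tree.empty(topo, {"value": (2,)})
# Fill every leaf with ones.
tree = tree.at[topo.is_leaf].set(value=jnp.ones((8, 2)))

# Upward sweep: each node's "value" becomes the mean of its children's.
@hx.up(reads_children=("value",), writes=("value",))
def avg(node, children, params):
    return {"value": children.value.mean(0)}

# Row 0 holds the root.
root_value = avg(tree)["value"][0]
```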
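To make explicit what the `avg` up sweep computes, here is a plain-JAX reference for the same depth-3 binary tree. It is a sketch under a stated assumption: nodes are taken to be stored breadth-first with each parent's children consecutive within their level, which is a hypothetical layout and not necessarily hyperiax's internal node order. `up_average`, `DEPTH`, and `DEGREE` are illustrative names.

```python
import jax
import jax.numpy as jnp

DEPTH, DEGREE = 3, 2  # same shape as symmetric_topology(depth=3, degree=2)


def up_average(leaf_values: jax.Array) -> jax.Array:
    """Mean-of-children up sweep on a complete tree, in plain JAX.

    Assumes (hypothetically) a breadth-first layout in which the DEGREE
    children of each parent are consecutive within their level.
    Returns every node's value, root first.
    """
    levels = [leaf_values]  # deepest level first
    for _ in range(DEPTH):  # static Python loop, unrolled under jit
        child = levels[-1]
        # (n_children, d) -> (n_parents, DEGREE, d), then average siblings.
        parent = child.reshape(-1, DEGREE, child.shape[-1]).mean(axis=1)
        levels.append(parent)
    return jnp.concatenate(levels[::-1], axis=0)


leaves = jnp.arange(16.0).reshape(8, 2)
values = jax.jit(up_average)(leaves)
root_value = values[0]  # [7., 8.]
```

On a complete tree with a fixed degree, a mean of means equals the plain mean, so the root ends up as the average of the eight leaves; with all-ones leaves, as in the example above, the root is all ones.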
What’s in the box
core: Topology, Tree, Schema, the @up/@down sweep decorators, the unified segment-based dispatch engine, and Newick read/write (from_newick()/to_newick()).
prebuilt: BFFG (backward filtering, forward guiding) for discrete Gaussian and continuous SDE transitions (van der Meulen & Sommer 2025), plus a weighted phylo-mean.
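For trees that are not symmetric, a per-parent reduction over children can be phrased as a segment reduction. The sketch below shows that general idea with `jax.ops.segment_sum`; it is an illustration of the technique, not a description of hyperiax's dispatch engine, whose internals are not documented here. The tree, `parent`, `depth`, and `up_mean` are hypothetical.

```python
import numpy as np
import jax
import jax.numpy as jnp

# Hypothetical irregular rooted tree (parents are numbered before children):
#          0
#        / | \
#       1  2  3
#      / \
#     4   5
parent = np.array([-1, 0, 0, 0, 1, 1])
depth = np.array([0, 1, 1, 1, 2, 2])
n_nodes = parent.shape[0]
is_leaf = np.ones(n_nodes, dtype=bool)
is_leaf[parent[parent >= 0]] = False  # any node that is some node's parent is internal


def up_mean(value: jax.Array) -> jax.Array:
    """Set each internal node's value to the mean of its children's values.

    Each level is one segment reduction: children are grouped by parent id
    and summed with jax.ops.segment_sum. The index arrays are NumPy
    constants, so the loop and the level masks are static under jit.
    """
    for k in range(int(depth.max()) - 1, -1, -1):  # deepest internal level first
        kids = np.nonzero(depth == k + 1)[0]
        seg = parent[kids]  # segment id of each child = its parent
        sums = jax.ops.segment_sum(value[kids], seg, num_segments=n_nodes)
        counts = jax.ops.segment_sum(jnp.ones(kids.shape[0]), seg, num_segments=n_nodes)
        mean = sums / jnp.maximum(counts, 1.0)[:, None]
        update = (depth == k) & ~is_leaf  # internal nodes on level k
        value = jnp.where(update[:, None], mean, value)
    return value


# Leaves 2, 3, 4, 5 carry data; internal rows 0 and 1 start at zero.
init = jnp.zeros((n_nodes, 1)).at[jnp.array([2, 3, 4, 5])].set(
    jnp.array([[2.0], [3.0], [4.0], [6.0]])
)
out = jax.jit(up_mean)(init)
# out[1] == 5.0 (mean of 4 and 6); out[0] == 10/3 (mean of 5, 2 and 3)
```

Note that on an irregular tree the root is 10/3, not the plain leaf mean of 3.75: averaging over children weights each subtree equally rather than each leaf.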
Installation
```shell
pip install hyperiax               # core, CPU JAX (jax + jaxlib + numpy)
pip install 'hyperiax[gpu]'        # + CUDA 12 JAX
pip install 'hyperiax[notebook]'   # tutorial notebooks
```
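After installing, one quick way to confirm which backend is active (for example, whether the gpu extra picked up CUDA) is to ask JAX directly. These are standard JAX calls, not part of hyperiax.

```python
import jax

# Name of the default backend JAX selected, typically "cpu", "gpu" or "tpu".
print(jax.default_backend())
# The concrete devices JAX can see on that backend.
print(jax.devices())
```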
Contents
Tutorials
API reference