hyperiax — JAX tree traversals

hyperiax is a small, pure-JAX library for message passing on phylogenetic and other rooted trees. Trees are immutable JAX pytrees; sweeps are decorator-style Tree -> Tree transforms that compose cleanly under @jax.jit and jax.lax.scan.
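The immutability and composition claim can be illustrated in plain JAX, without hyperiax itself. The sketch below uses a hypothetical `ToyTree` NamedTuple (not hyperiax's `Tree`): because NamedTuples are pytrees, a jitted Tree -> Tree function returns a fresh tree and can be used directly as the carry of `jax.lax.scan`.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyTree(NamedTuple):
    # Hypothetical stand-in for a tree: NamedTuples are JAX pytrees,
    # so every field is a leaf that jit/scan can trace through.
    parent: jax.Array  # parent[i] = index of node i's parent (-1 at the root)
    value: jax.Array   # per-node data, shape (n_nodes, d)


@jax.jit
def halve(tree: ToyTree) -> ToyTree:
    # Tree -> Tree: builds a new tree; the input is never mutated.
    return tree._replace(value=tree.value * 0.5)


def step(tree, _):
    # scan carry is the whole tree; also emit the pre-step total per iteration.
    return halve(tree), tree.value.sum()


tree0 = ToyTree(parent=jnp.array([-1, 0, 0]), value=jnp.ones((3, 2)))
tree3, sums = jax.lax.scan(step, tree0, xs=None, length=3)
print(sums)         # totals before each step: 6.0, 3.0, 1.5
print(tree0.value)  # still all ones: the original tree is untouched
```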

import jax, jax.numpy as jnp
import hyperiax as hx

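topo = hx.symmetric_topology(depth=3, degree=2)            # binary tree, 2**3 = 8 leaves
tree = hx.Tree.empty(topo, {"value": (2,)})                 # one (2,)-shaped "value" per node
tree = tree.at[topo.is_leaf].set(value=jnp.ones((8, 2)))    # fill the 8 leaves

# Upward sweep: each node's "value" becomes the mean of its children's.
@hx.up(reads_children=("value",), writes=("value",))
def avg(node, children, params):
    return {"value": children.value.mean(0)}

root_value = avg(tree)["value"][0]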
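To picture what an upward mean sweep computes level by level, here is a plain-JAX sketch using `jax.ops.segment_sum` to aggregate children into their parents. It is an illustration of the segment-based idea under stated assumptions, not hyperiax's dispatch engine: the breadth-first node layout and the names `starts`, `parent` and `upward_mean` are all hypothetical. It uses the same depth-3 binary shape as the quickstart, but distinct leaf values so the averaging is visible.

```python
import jax
import jax.numpy as jnp

# Hypothetical layout: a perfect tree stored in breadth-first order,
# so level k occupies indices starts[k] .. starts[k + 1] - 1.
depth, degree = 3, 2
starts = [sum(degree**j for j in range(k)) for k in range(depth + 2)]  # [0, 1, 3, 7, 15]
n_nodes = starts[-1]
idx = jnp.arange(n_nodes)
parent = jnp.where(idx == 0, -1, (idx - 1) // degree)  # parent index per node, -1 at root


@jax.jit
def upward_mean(value):
    """Set every internal node's value to the mean of its children's values."""
    for k in range(depth, 0, -1):  # child levels, deepest first (unrolled at trace time)
        lo, hi = starts[k], starts[k + 1]
        # Segment id of each child = its parent's position within level k - 1.
        seg = parent[lo:hi] - starts[k - 1]
        n_parents = starts[k] - starts[k - 1]
        sums = jax.ops.segment_sum(value[lo:hi], seg, num_segments=n_parents)
        counts = jax.ops.segment_sum(jnp.ones(hi - lo), seg, num_segments=n_parents)
        value = value.at[starts[k - 1]:starts[k]].set(sums / counts[:, None])
    return value


# Leaves (the last 8 nodes) get rows [0, 1], [2, 3], ..., [14, 15].
values = jnp.zeros((n_nodes, 2)).at[starts[depth]:].set(jnp.arange(16.0).reshape(8, 2))
out = upward_mean(values)
print(out[0])  # root: mean over all 8 leaves -> [7. 8.]
```

Because every level here is processed with one vectorized segment reduction, the loop length depends only on tree depth, not on the number of nodes.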

What’s in the box

  • core: Topology, Tree, Schema, the @up / @down sweep decorators, the unified segment-based dispatch engine, and Newick read/write (from_newick() / to_newick()).

  • prebuilt: BFFG (backward filtering, forward guiding) for discrete Gaussian and continuous SDE transitions (van der Meulen & Sommer 2025), plus weighted phylo-mean.
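For readers new to the Newick format mentioned under core, the sketch below serializes a parent-array tree into a Newick string. It is a hypothetical helper (`to_newick_sketch`) written for illustration, not hyperiax's `to_newick()`; it assumes the root is marked with parent `-1` and omits branch lengths (the `:length` suffixes Newick also allows).

```python
def to_newick_sketch(parent, names):
    """Serialize a parent-array tree (parent[root] == -1) as a Newick string."""
    children = {i: [] for i in range(len(parent))}
    root = None
    for i, p in enumerate(parent):
        if p < 0:
            root = i
        else:
            children[p].append(i)

    def write(i):
        # Leaves are bare labels; internal nodes wrap their children in parentheses
        # and put their own label after the closing parenthesis.
        if not children[i]:
            return names[i]
        return "(" + ",".join(write(c) for c in children[i]) + ")" + names[i]

    return write(root) + ";"


print(to_newick_sketch([-1, 0, 0, 1, 1], ["r", "a", "b", "c", "d"]))  # ((c,d)a,b)r;
```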

Installation

pip install hyperiax                  # core, CPU JAX (jax + jaxlib + numpy)
pip install 'hyperiax[gpu]'           # + CUDA 12 JAX
pip install 'hyperiax[notebook]'      # tutorial notebooks
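After installing, a quick way to confirm which device JAX will run on is the check below. It exercises JAX only (it does not import hyperiax); with the base install the backend should be "cpu", and with a working CUDA setup from the [gpu] extra it should report "gpu".

```python
import jax
import jax.numpy as jnp

# Which backend did jaxlib pick up? "cpu" for the base install,
# "gpu" when a CUDA-enabled JAX is active.
backend = jax.default_backend()
print(jax.__version__, backend, jax.devices())

# A tiny jitted computation confirms the backend actually executes code.
total = jax.jit(lambda x: x.sum())(jnp.ones((4, 2)))
print(total)  # 8.0
```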
