hyperiax.up

hyperiax.up(*, reads=None, reads_children=None, writes=(), writes_children=())

Decorator: mark a function as an up-sweep.

The decorated function has signature (node, children, params) and must return a dict[str, Array] whose key set equals the union of writes and writes_children exactly. Shapes are per parent, as seen inside jax.vmap(). Values for keys in writes have shape (*trailing,). Values for keys in writes_children have shape (k, *trailing), where k is the number of children, and are scattered to the children’s slots.

At least one of writes or writes_children must be non-empty.
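As a rough illustration of this return contract, the sketch below checks one parent's output dict in plain NumPy. `check_up_output`, its `k` and `trailing` arguments, and the field names `value` and `traj` are hypothetical and are not part of hyperiax. The library's own checks may differ.

```python
import numpy as np

def check_up_output(out, writes, writes_children, k, trailing):
    """Hypothetical checker for the documented return contract, for one parent.

    out             -- dict returned by the up-sweep function for a single parent
    writes          -- keys written to the parent, each of shape trailing[key]
    writes_children -- keys written per child, each of shape (k, *trailing[key])
    k               -- number of children (outgoing edges) of this parent
    trailing        -- mapping from field name to its trailing shape (assumed input)
    """
    if not writes and not writes_children:
        raise ValueError("at least one of writes or writes_children must be non-empty")
    expected = set(writes) | set(writes_children)
    if set(out) != expected:
        raise ValueError(f"returned keys {sorted(out)} != {sorted(expected)}")
    for key in writes:                       # one value per parent
        want = tuple(trailing[key])
        if np.shape(out[key]) != want:
            raise ValueError(f"{key!r}: expected shape {want}, got {np.shape(out[key])}")
    for key in writes_children:              # one row per child
        want = (k, *trailing[key])
        if np.shape(out[key]) != want:
            raise ValueError(f"{key!r}: expected shape {want}, got {np.shape(out[key])}")
    return True

# A parent with 3 children: a scalar per-parent field and a 5-step per-edge field.
shapes = {'value': (), 'traj': (5,)}
ok = check_up_output({'value': np.zeros(()), 'traj': np.zeros((3, 5))},
                     writes=('value',), writes_children=('traj',), k=3, trailing=shapes)
```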

Parameters:
  • reads (Sequence[str] | None) – Fields the function reads from the current node. None (the default) means “all schema fields at call time.”

  • reads_children (Sequence[str] | None) – Fields the function reads from each child. None (the default) means “all schema fields at call time.”

  • writes (Sequence[str]) – Fields written back to each parent (the usual case).

  • writes_children (Sequence[str]) – Fields written per child (one row per outgoing edge). Use this when an up-sweep also produces per-edge byproducts that you want to cache on each child (e.g. per-edge ODE-filter trajectories).
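The scatter described for writes_children can be pictured with the following NumPy sketch, which writes one batch of parents back into per-node arrays. The flat layout (one row per node plus a `(P, k)` child-index table), the helper `apply_level`, and the field `edge_diff` are assumptions made for illustration. They are not hyperiax's internal representation.

```python
import numpy as np

def apply_level(fields, parents, child_idx, writes, writes_children, out):
    """Hypothetical write-back step for one batch of parents.

    fields    -- dict of per-node arrays, row i belongs to node i (assumed layout)
    parents   -- node ids of the P parents in this batch, shape (P,)
    child_idx -- child_idx[p, j] is the node id of parent p's j-th child, shape (P, k)
    out       -- batched outputs: (P, *trailing) for `writes`,
                 (P, k, *trailing) for `writes_children`
    """
    for key in writes:
        fields[key][parents] = out[key]                    # one row per parent
    for key in writes_children:
        rows = out[key].reshape(-1, *out[key].shape[2:])   # (P*k, *trailing)
        fields[key][child_idx.reshape(-1)] = rows          # one row per child slot
    # NumPy in-place assignment stands in for JAX's functional x.at[idx].set(...).
    return fields

# Root 0 with children 1 and 2; 'edge_diff' is a made-up per-edge field.
fields = {'value': np.array([0.0, 2.0, 4.0]), 'edge_diff': np.zeros(3)}
parents = np.array([0])
child_idx = np.array([[1, 2]])
child_vals = fields['value'][child_idx]                    # shape (1, 2)
mean = child_vals.mean(axis=1)                             # shape (1,)
out = {'value': mean, 'edge_diff': child_vals - mean[:, None]}
fields = apply_level(fields, parents, child_idx, ('value',), ('edge_diff',), out)
```

The per-parent output lands on the root's row, and each of the two per-edge rows lands on the corresponding child's row.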

Return type:

Callable[[Callable], SweepFn]

Example:

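import hyperiax as hx

# Replace each parent's 'value' with the mean of its children's 'value'
# (axis 0 of children.value indexes the children).
@hx.up(reads_children=('value',), writes=('value',))
def avg(node, children, params):
    return {'value': children.value.mean(0)}

# `tree` is an existing hyperiax tree that has a 'value' field.
new_tree = avg(tree)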
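To make the example concrete, here is a plain-NumPy sketch of the values `avg` would produce on a five-node tree. It assumes that the sweep finishes every child before its parent and that leaf values are left untouched. The parent-list encoding and the `depth` helper are hypothetical and stand in for a real hyperiax tree.

```python
import numpy as np

# Hypothetical tree encoding: parent[i] is the parent of node i, -1 marks the root.
parent = [-1, 0, 0, 1, 1]            # root 0 has children 1, 2; node 1 has children 3, 4
value = np.array([0.0, 0.0, 6.0, 1.0, 3.0])

def depth(i):
    """Number of edges from node i up to the root."""
    d = 0
    while parent[i] != -1:
        i = parent[i]
        d += 1
    return d

children = {i: [c for c, p in enumerate(parent) if p == i] for i in range(len(parent))}
internal = [i for i in children if children[i]]

# Bottom-up: deeper parents are processed before shallower ones.
for node in sorted(internal, key=depth, reverse=True):
    # mirrors: return {'value': children.value.mean(0)}
    value[node] = value[children[node]].mean(0)
```

Node 1 becomes the mean of leaves 3 and 4 (2.0). The root then averages the updated node 1 with leaf 2, giving 4.0. Each node therefore receives the mean of its children, not the mean over all descendant leaves, which would be 10/3 here.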