hyperiax.down

hyperiax.down(*, reads=None, reads_parent=None, writes)

Decorator that marks a function as a down-sweep, which passes data from each parent to its children.

The decorated function has the signature (node, parent, params) and must return a dict[str, Array] whose key set equals writes exactly. Each returned value must have the per-node shape (*trailing,) given by the schema for that field, because the function is mapped over nodes with jax.vmap() and sees one node's data per call.
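The return-value contract can be checked with a small helper. This is a minimal sketch, not part of hyperiax: `check_down_output` is hypothetical, NumPy arrays stand in for JAX arrays, and `schema` is an assumed dict mapping each field name to its per-node trailing shape. It validates a single node's output, which is the view the function has under jax.vmap().

```python
import numpy as np

def check_down_output(out, writes, schema):
    """Check one node's return value against the down-sweep contract.

    Hypothetical helper for illustration. `schema` maps each field name to
    its per-node trailing shape, e.g. {'value': (3,)}; no node axis is included.
    """
    expected_keys, got_keys = set(writes), set(out)
    if got_keys != expected_keys:
        # The key set must equal `writes` exactly: no missing and no extra fields.
        raise ValueError(
            f"missing={sorted(expected_keys - got_keys)}, "
            f"extra={sorted(got_keys - expected_keys)}"
        )
    for name, value in out.items():
        want = tuple(schema[name])
        # Compare the per-node shape only; stacking over nodes happens outside.
        if np.shape(value) != want:
            raise ValueError(f"{name!r}: expected shape {want}, got {np.shape(value)}")
```

Returning an extra field (for example echoing 'delta' back) or a value that still carries a node axis would both be rejected.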

The root is never visited because it has no parent. Seed its data with tree.set(...) before calling the sweep.
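As a mental model of why the root must be seeded, here is a sequential down-sweep in plain Python. It is a hypothetical sketch, not hyperiax's implementation: the parent-index encoding, the `down_sweep` function, and the per-node dicts are assumptions for illustration, and the library's actual scheduling (including batching with jax.vmap()) is not shown.

```python
from types import SimpleNamespace

def down_sweep(fn, parents, data, writes, params=None):
    """Sequential down-sweep sketch (hypothetical; not hyperiax's implementation).

    parents -- parents[i] is the index of node i's parent, None for the root;
               nodes are ordered so that every parent precedes its children.
    data    -- list of dicts, one per node, holding that node's fields.
    """
    out = [dict(d) for d in data]  # work on copies; the input is left untouched
    for i, p in enumerate(parents):
        if p is None:
            continue  # the root has no parent, so fn is never called on it
        node = SimpleNamespace(**out[i])
        parent = SimpleNamespace(**out[p])  # already updated, since p precedes i
        result = fn(node, parent, params)
        if set(result) != set(writes):
            raise ValueError(f"node {i}: keys {sorted(result)} != writes {sorted(writes)}")
        out[i].update(result)
    return out

def propagate(node, parent, params):
    return {'value': parent.value + node.delta}

# Root 0 has children 1 and 2; node 3 is a child of node 1.
parents = [None, 0, 0, 1]
data = [
    {'value': 10.0, 'delta': 0.0},  # seeded root
    {'value': 0.0, 'delta': 1.0},
    {'value': 0.0, 'delta': 2.0},
    {'value': 0.0, 'delta': 3.0},
]
result = down_sweep(propagate, parents, data, ('value',))
print([r['value'] for r in result])  # [10.0, 11.0, 12.0, 14.0]
```

If the root's 'value' were left at 0.0 instead of being seeded, every descendant would inherit that placeholder and the result would be [0.0, 1.0, 2.0, 4.0].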

Parameters:
  • reads (Sequence[str] | None) – Fields the function reads from the current node. If None (the default), all schema fields are used, resolved when the sweep is called.

  • reads_parent (Sequence[str] | None) – Fields the function reads from each node’s parent. If None (the default), all schema fields are used, resolved when the sweep is called.

  • writes (Sequence[str]) – Required. The fields the function writes back into the tree. Must be non-empty.
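The parameter rules can be sketched in plain Python. `down_options` and `schema_fields` are hypothetical names used only for illustration; the sketch mirrors the keyword-only signature shown above (a bare `*` followed by `writes` with no default, so omitting it is a TypeError) and defers the None defaults until the schema is known.

```python
def down_options(*, reads=None, reads_parent=None, writes):
    """Hypothetical sketch of how the keyword arguments could be resolved."""
    # `writes` has no default after the bare `*`, so Python itself raises
    # TypeError when it is omitted; emptiness has to be checked explicitly.
    writes = tuple(writes)
    if not writes:
        raise ValueError("writes must be non-empty")

    def resolve(schema_fields):
        # None is resolved lazily, against the schema in effect at call time.
        all_fields = tuple(schema_fields)
        return {
            'reads': all_fields if reads is None else tuple(reads),
            'reads_parent': all_fields if reads_parent is None else tuple(reads_parent),
            'writes': writes,
        }

    return resolve

resolve = down_options(writes=('value',))
print(resolve(['value', 'delta']))
```

Deferring the None case to `resolve` is what "at call time" buys: fields added to the schema after decoration are still picked up.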

Return type:

Callable[[Callable], SweepFn]

Example:

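import hyperiax as hx

# Assumes `tree` exists with 'value' and 'delta' fields, and that the
# root's 'value' was seeded with tree.set(...) (the root is never visited).
@hx.down(reads=('delta',), reads_parent=('value',), writes=('value',))
def propagate(node, parent, params):
    # Each non-root node's value becomes its parent's value plus its own delta.
    return {'value': parent.value + node.delta}

new_tree = propagate(tree)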
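Because propagate sets each node's value to its parent's value plus its own delta, unrolling that recursion from the seeded root r to any node v gives the closed form below, where path(r, v) is the sequence of nodes from r down to v. The root's own delta never enters, since the root is not visited. This follows from the example alone and says nothing about other sweeps.

```latex
\mathrm{value}_v = \mathrm{value}_{\mathrm{parent}(v)} + \delta_v
\quad\Longrightarrow\quad
\mathrm{value}_v = \mathrm{value}_r + \sum_{u \in \mathrm{path}(r,\,v),\; u \neq r} \delta_u
```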