hyperiax.down
- hyperiax.down(*, reads=None, reads_parent=None, writes)
Decorator: mark a function as a down-sweep.
The decorated function has signature (node, parent, params) and must return a dict[str, Array] whose key set equals writes exactly. Each returned value must have shape (*trailing,) matching the schema for that field. Shapes are per node, because the function runs under jax.vmap().
The root is never visited, because it has no parent. Seed its data with tree.set(...) before calling the sweep.
- Parameters:
  - reads (Sequence[str] | None) – Fields the function reads from the current node. None defaults to all schema fields at call time.
  - reads_parent (Sequence[str] | None) – Fields the function reads from each node’s parent. None defaults to all schema fields at call time.
  - writes (Sequence[str]) – Required. The fields the function writes back into the tree. Must be non-empty.
- Return type:
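To make the parameter rules concrete, here is a minimal pure-Python sketch of a decorator with the same keyword-only parameters. `down_sketch` and its `apply_one(schema_fields, node_data, parent_data, params)` call form are hypothetical and are not hyperiax's API. The sketch handles a single node with plain dicts instead of a tree and jax.vmap(). It resolves `None` to all schema fields at call time, rejects an empty `writes`, and checks that the returned keys equal `writes`. Hiding undeclared fields from the function is an assumption about how `reads` might be enforced, not documented behavior.

```python
from types import SimpleNamespace

def down_sketch(*, reads=None, reads_parent=None, writes):
    """Hypothetical stand-in mirroring hyperiax.down's documented parameters."""
    writes = tuple(writes)  # keyword-only with no default, so it is required
    if not writes:
        raise ValueError("writes must be non-empty")

    def decorate(fn):
        def apply_one(schema_fields, node_data, parent_data, params=None):
            # None means "all schema fields", resolved here, at call time.
            node_fields = tuple(schema_fields) if reads is None else tuple(reads)
            parent_fields = tuple(schema_fields) if reads_parent is None else tuple(reads_parent)
            # Assumption: only the declared fields are visible to fn.
            node = SimpleNamespace(**{f: node_data[f] for f in node_fields})
            parent = SimpleNamespace(**{f: parent_data[f] for f in parent_fields})
            result = fn(node, parent, params)
            # The returned key set must equal writes exactly.
            if set(result) != set(writes):
                raise ValueError(f"returned keys {sorted(result)} != writes {sorted(writes)}")
            return result
        return apply_one
    return decorate


@down_sketch(reads=('delta',), reads_parent=('value',), writes=('value',))
def propagate(node, parent, params):
    return {'value': parent.value + node.delta}

schema = ('value', 'delta')
out = propagate(schema, {'value': 0.0, 'delta': 2.0}, {'value': 5.0, 'delta': 1.0})
# out == {'value': 7.0}
```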
Example:
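```python
import hyperiax as hx

# Assumes `tree` exists with 'value' and 'delta' in its schema and that the
# root's 'value' was seeded with tree.set(...) beforehand.
@hx.down(reads=('delta',), reads_parent=('value',), writes=('value',))
def propagate(node, parent, params):
    return {'value': parent.value + node.delta}  # parent's value plus own delta

new_tree = propagate(tree)
```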
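The following plain-Python sketch shows the traversal that the propagate example implies. The root is skipped, and nodes are visited parent-before-child. It assumes, as is usual for a down-sweep, that each node sees its parent's already-updated fields. `down_sweep_sketch` and the parent-index encoding of the tree (`parents[i]`, with `-1` for the root) are hypothetical stand-ins. They are not hyperiax's tree type, and the sketch does not reproduce its batched jax.vmap() execution.

```python
from types import SimpleNamespace

def down_sweep_sketch(fn, parents, data, params=None):
    """Hypothetical helper: apply fn(node, parent, params) top-down over a tree.

    parents[i] is the index of node i's parent, or -1 for the root.
    data maps each field name to a list holding one value per node.
    """
    out = {field: list(values) for field, values in data.items()}  # copy; input is untouched
    for i, p in enumerate(parents):
        if p == -1:
            continue  # the root has no parent, so it is never visited
        if not 0 <= p < i:
            raise ValueError("parents must be listed before their children")
        # The parent was already updated earlier in this loop.
        node = SimpleNamespace(**{f: v[i] for f, v in out.items()})
        parent = SimpleNamespace(**{f: v[p] for f, v in out.items()})
        for field, value in fn(node, parent, params).items():
            out[field][i] = value
    return out


def propagate(node, parent, params):
    return {'value': parent.value + node.delta}

# Node 0 is the root; nodes 1 and 2 are its children; node 3 is a child of node 1.
parents = [-1, 0, 0, 1]
data = {'value': [1.0, 0.0, 0.0, 0.0], 'delta': [0.0, 2.0, 5.0, 3.0]}
result = down_sweep_sketch(propagate, parents, data)
# result['value'] == [1.0, 3.0, 6.0, 6.0]
```

With the root seeded to 1.0, each node ends up holding the root value plus the sum of the deltas along its path from the root. For example, node 3 gets 1.0 + 2.0 + 3.0 = 6.0.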