hyperiax.up
- hyperiax.up(*, reads=None, reads_children=None, writes=(), writes_children=())
Decorator: mark a function as an up-sweep.
The decorated function has signature (node, children, params) and must return a dict[str, Array] whose key set equals writes ∪ writes_children exactly. Values for keys in writes have shape (*trailing,) per parent (under jax.vmap()). Values for keys in writes_children have shape (k, *trailing) per parent, with one row per child (k is the parent's number of children), and are scattered to the children's slots. At least one of writes or writes_children must be non-empty.
- Parameters:
  - reads (Sequence[str] | None) – Fields the function reads from the current node. None (the default) means "all schema fields at call time."
  - reads_children (Sequence[str] | None) – Fields the function reads from each child. None (the default) means all schema fields at call time.
  - writes (Sequence[str]) – Fields written back to each parent (the canonical case).
  - writes_children (Sequence[str]) – Fields written to each child (one row per outgoing edge). Use this when an up-sweep also produces per-edge byproducts that you want to cache on each child (e.g. per-edge ODE-filter trajectories).
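The return contract can be made concrete with a small checker. This is a hypothetical sketch, not part of hyperiax: check_up_output and scatter_to_children are illustrative names, NumPy stands in for JAX arrays, and the check looks at a single parent (the per-parent view that jax.vmap maps over). The trailing argument is an assumed mapping from each field name to its per-node trailing shape.

```python
import numpy as np


def check_up_output(out, writes, writes_children, k, trailing):
    """Hypothetical checker (not a hyperiax function) for ONE parent's return value.

    k is the parent's number of children; trailing maps field name -> trailing shape.
    """
    if not writes and not writes_children:
        raise ValueError("at least one of writes / writes_children must be non-empty")
    expected = set(writes) | set(writes_children)
    # The key set must match writes ∪ writes_children exactly: no missing or extra keys.
    if set(out) != expected:
        raise KeyError(f"returned keys {sorted(out)} != {sorted(expected)}")
    for name in writes:
        # Per-parent field: shape (*trailing,)
        want = tuple(trailing[name])
        if np.shape(out[name]) != want:
            raise ValueError(f"{name}: expected shape {want}, got {np.shape(out[name])}")
    for name in writes_children:
        # Per-child field: one row per outgoing edge, shape (k, *trailing)
        want = (k, *trailing[name])
        if np.shape(out[name]) != want:
            raise ValueError(f"{name}: expected shape {want}, got {np.shape(out[name])}")


def scatter_to_children(field, rows, child_ids):
    """Write row i of rows (shape (k, *trailing)) into the slot of child_ids[i]
    in field (shape (n_nodes, *trailing)). Returns a new array."""
    out = field.copy()
    out[np.asarray(child_ids)] = rows
    return out
```

For example, a function declared with writes=('value',) and writes_children=('traj',) must return, for a parent with two children, a value of shape (*trailing,) and a traj of shape (2, *trailing); the two traj rows then land in the two children's slots.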
- Return type:
Example:
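```python
import hyperiax as hx

# Up-sweep: each parent's 'value' becomes the mean of its children's 'value'
# (children.value holds one row per child along axis 0).
@hx.up(reads_children=('value',), writes=('value',))
def avg(node, children, params):
    return {'value': children.value.mean(0)}

new_tree = avg(tree)  # tree: an existing hyperiax tree with a 'value' field
```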
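To make the up-sweep order concrete, the sketch below emulates avg on a five-node tree in plain NumPy. It is an illustration under stated assumptions, not hyperiax's implementation: the CHILDREN/DEPTH dicts and up_sweep are hypothetical, SimpleNamespace stands in for hyperiax's node and child views, and parents are visited deepest level first so each parent sees its children's already-updated values (the usual meaning of an up-sweep). The Python loop over the parents of one level stands in for the batched jax.vmap evaluation.

```python
from types import SimpleNamespace

import numpy as np

# Hypothetical tree encoding (not hyperiax's): node id -> child ids, node id -> depth.
CHILDREN = {0: [1, 2], 1: [3, 4], 2: [], 3: [], 4: []}
DEPTH = {0: 0, 1: 1, 2: 1, 3: 2, 4: 2}


def avg(node, children, params):
    # Same body as the documented example: average the children's rows (axis 0).
    return {'value': children.value.mean(0)}


def up_sweep(value, children, depth, fn, params=None):
    """Emulate an up-sweep that reads and writes a single 'value' field."""
    value = value.copy()
    parents = [n for n, cs in children.items() if cs]
    # Deepest parents first, so every parent sees its children's updated values.
    for level in sorted({depth[p] for p in parents}, reverse=True):
        for p in (q for q in parents if depth[q] == level):
            rows = value[np.asarray(children[p])]  # shape (k, *trailing)
            out = fn(SimpleNamespace(value=value[p]), SimpleNamespace(value=rows), params)
            value[p] = out['value']
    return value


value0 = np.array([0.0, 0.0, 6.0, 2.0, 4.0])  # internal nodes 0 and 1 get overwritten
new_value = up_sweep(value0, CHILDREN, DEPTH, avg)
```

With leaves 3 and 4 holding 2.0 and 4.0 and leaf 2 holding 6.0, node 1 becomes 3.0 and the root becomes the mean of 3.0 and 6.0, i.e. 4.5; leaf values are left unchanged.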