hyperiax.Children
- class hyperiax.Children(fields)
Bases: _FieldsView
View of the children of all parents at the current up-sweep level.
Each attribute is a ChildrenAxis proxy backed by a flat (M_total, *trailing) block (every child at the level, concatenated) plus segment ids. The proxy exposes:
  - reductions .sum / .max / .min / .prod / .mean (axis=0), dispatched to jax.ops.segment_* and yielding per-parent results;
  - map(), a segment-preserving per-child jax.vmap, so a non-linear per-child transform feeds the same reduction surface.
Direct array ops (indexing, arithmetic, __array__) are deliberately rejected: use children.map(fn) for any non-reduction per-child work.

Methods
- map(fn)
Apply fn to each child, preserving the children axis. fn receives a per-child Node view (the children fields, one child at a time) and returns a dict[str, jax.Array]. The result is a new Children view of the per-child outputs whose fields are ChildrenAxis proxies sharing the input segments, so trailing reductions (msgs.field.sum(0) / .mean(0) / …) dispatch to jax.ops.segment_*.

This lets a sweep apply a non-linear per-child transform and then reduce the results with the same body regardless of node degree (trees in which every parent has the same number of children are just the special case where all segments have the same size).
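The map-then-reduce pattern can be sketched the same way. This is hypothetical plain-JAX code, not a call into hyperiax: `per_child`, `children_fields` and the field names `value` and `msg` are made up, and the per-child dict stands in for the Node view that map passes to fn. `jax.vmap` maps over the children axis, and the unchanged segment ids then drive a per-parent `jax.ops.segment_sum`; here the non-linear transform is `exp`, so the result is a log-sum-exp over each parent's children.

```python
import jax
import jax.numpy as jnp

# Hypothetical flat children block: 3 parents with 2, 1 and 3 children.
children_fields = {"value": jnp.array([0.0, 1.0, 2.0, 3.0, 4.0, 5.0])}
segment_ids = jnp.array([0, 0, 1, 2, 2, 2])
num_parents = 3


def per_child(child):
    # `child` holds the fields of ONE child (scalars here) and returns a dict of arrays.
    return {"msg": jnp.exp(child["value"])}


# Segment-preserving map: vmap over the children axis; segment ids are reused as-is.
msgs = jax.vmap(per_child)(children_fields)

# Per-parent trailing reduction: log-sum-exp of each parent's children.
per_parent = jnp.log(
    jax.ops.segment_sum(msgs["msg"], segment_ids, num_segments=num_parents)
)
```

For large values, a numerically stable variant would subtract each parent's `jax.ops.segment_max` from its children before exponentiating.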