hyperiax.ChildrenAxis
- class hyperiax.ChildrenAxis(flat, segments, num_segments, trailing)
Bases: `object`

Virtual children-axis proxy for unequal-degree trees.
Backs a flat `(M_total, *trailing)` JAX array (every child at the current level, concatenated) plus a `segments` array that assigns each row to its parent. `num_segments` is a static Python int derived from the topology; it is required so that `jax.ops.segment_*` produces a statically shaped output.

The user calls one of `.sum` / `.prod` / `.max` / `.min` / `.mean` (over `axis=0`) to reduce over the children axis. Each dispatches to the matching segment reduction. The output shape is `(num_segments, *trailing)`, matching the corresponding `Node` view at the same level.

Non-reduction operations (indexing, broadcasted arithmetic, NumPy coercion) are deliberately rejected: on the flat layout they would either silently produce wrong results or imply a padded dense form, which the user should request explicitly via `Children.gather()` (TODO).

Methods
- `__init__(flat, segments, num_segments, trailing)`
- `max([axis])`
- `mean([axis])`
- `min([axis])`
- `prod([axis])`
- `sum([axis])`

Attributes
- `flat`: The backing flat `(M_total, *trailing)` array (all children of the level concatenated, in node order).

- property flat: Array
The backing flat `(M_total, *trailing)` array (all children of the level concatenated, in node order). Used by `Children.map()` and by the dispatcher's per-child (`writes_children`) scatter.
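The reduction contract described above (flat rows, a `segments` array of parent ids, a static `num_segments`, reductions only over axis 0) can be illustrated with a minimal sketch. `SketchChildrenAxis` is a hypothetical stand-in written for this illustration, not hyperiax's implementation: it drops the `trailing` argument (reading it from `flat.shape` instead), computes `.mean` as a segment sum divided by per-parent child counts (an assumed strategy; the real dispatch may differ), and rejects indexing and arithmetic with a `TypeError`.

```python
import jax
import jax.numpy as jnp


class SketchChildrenAxis:
    """Hypothetical children-axis proxy (illustration only, not hyperiax code).

    flat:         (M_total, *trailing) array, every child at one level.
    segments:     (M_total,) int array, parent index of each row.
    num_segments: static Python int, number of parents at this level.
    """

    # Segment reductions that exist in jax.ops.
    _REDUCERS = {
        "sum": jax.ops.segment_sum,
        "prod": jax.ops.segment_prod,
        "max": jax.ops.segment_max,
        "min": jax.ops.segment_min,
    }

    def __init__(self, flat, segments, num_segments):
        self.flat = flat
        self.segments = segments
        # int() fails on a traced value, which keeps the output shape static.
        self.num_segments = int(num_segments)

    def _reduce(self, name, axis=0):
        if axis != 0:
            raise ValueError("only the children axis (axis=0) can be reduced")
        op = self._REDUCERS[name]
        # Output shape: (num_segments, *trailing).
        return op(self.flat, self.segments, num_segments=self.num_segments)

    def sum(self, axis=0):
        return self._reduce("sum", axis)

    def prod(self, axis=0):
        return self._reduce("prod", axis)

    def max(self, axis=0):
        return self._reduce("max", axis)

    def min(self, axis=0):
        return self._reduce("min", axis)

    def mean(self, axis=0):
        # Assumed strategy: segment sum divided by the number of children.
        total = self._reduce("sum", axis)
        ones = jnp.ones(self.flat.shape[0], dtype=self.flat.dtype)
        counts = jax.ops.segment_sum(ones, self.segments, num_segments=self.num_segments)
        # Reshape counts so they broadcast over the trailing dimensions.
        counts = counts.reshape((self.num_segments,) + (1,) * (self.flat.ndim - 1))
        return total / counts  # a parent with no children would give nan

    def _reject(self, *args, **kwargs):
        raise TypeError("only .sum/.prod/.max/.min/.mean over axis=0 are supported")

    # Indexing, arithmetic, and NumPy coercion are refused outright.
    __getitem__ = _reject
    __add__ = __radd__ = __sub__ = __rsub__ = _reject
    __mul__ = __rmul__ = __truediv__ = __rtruediv__ = _reject
    __array__ = _reject


# Three parents with 2, 1 and 3 children; trailing shape (2,).
flat = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0],
                  [7.0, 8.0], [9.0, 10.0], [11.0, 12.0]])
segments = jnp.array([0, 0, 1, 2, 2, 2])
children = SketchChildrenAxis(flat, segments, num_segments=3)

total = children.sum()    # values [[4, 6], [5, 6], [27, 30]]
average = children.mean()  # values [[2, 3], [5, 6], [9, 10]]


@jax.jit
def level_max(flat, segments):
    # num_segments is a Python constant here, so jit sees a static output shape.
    return SketchChildrenAxis(flat, segments, num_segments=3).max()
```

Computing `mean` from two segment sums keeps every intermediate at shape `(num_segments, ...)`, so nothing is padded to the maximum number of children per parent.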
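Because `flat` is kept in node order and stays row-aligned with `segments`, a per-child operation can be applied row by row, and its result can then be reduced per parent or scattered back into per-node storage. The sketch below assumes a toy level of 5 children under 2 parents, uses `jax.vmap` for the per-child map, and writes into a global buffer with `.at[...].set(...)`. The names `per_child` and `child_node_ids`, and the 8-node buffer, are hypothetical; this is not necessarily how `Children.map()` or the `writes_children` scatter are implemented.

```python
import jax
import jax.numpy as jnp

# Toy level: 5 children under 2 parents (degrees 3 and 2), trailing shape (1,).
flat = jnp.array([[0.5], [1.0], [1.5], [2.0], [2.5]])  # (M_total, 1)
segments = jnp.array([0, 0, 0, 1, 1])                 # parent of each row
num_segments = 2                                      # static Python int


def per_child(x):
    # Hypothetical function of a single child's row; a toy transform here.
    return jnp.exp(-x)


# Per-child map: one call per row; the result stays (M_total, 1),
# still aligned with `segments`.
mapped = jax.vmap(per_child)(flat)

# The same segment ids give a per-parent reduction of shape (2, 1).
per_parent = jax.ops.segment_sum(mapped, segments, num_segments=num_segments)

# Hypothetical scatter: a tree-wide buffer of 8 nodes in which these children
# are assumed to occupy node ids 3..7 (illustrative only).
child_node_ids = jnp.arange(3, 8)
node_values = jnp.zeros((8, 1)).at[child_node_ids].set(mapped)
```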