hyperiax.Children

class hyperiax.Children(fields)

Bases: _FieldsView

View of the children of all parents at the current up-sweep level.

Each attribute is a ChildrenAxis proxy backed by a single flat block of shape (M_total, *trailing), in which every child at the level is concatenated, together with segment ids that record the parent of each row. The proxy exposes:

  • reductions .sum / .max / .min / .prod / .mean over axis=0, dispatched to jax.ops.segment_* and yielding one result per parent;

  • map(), a segment-preserving per-child jax.vmap, so that a non-linear per-child transform feeds into the same reduction surface.
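The segment mechanism behind these reductions can be sketched with plain jax.ops calls. This is a minimal illustration, not hyperiax's internal code: the level (three parents with 2, 1 and 3 children) is made up, and because jax.ops provides segment_sum, segment_max, segment_min and segment_prod but no segment_mean, the mean here is assumed to be a segment sum divided by the per-parent child counts.

```python
import jax
import jax.numpy as jnp

# Hypothetical level: 3 parents with 2, 1 and 3 children (M_total = 6).
# Row i of `block` belongs to the parent given by segment_ids[i].
block = jnp.array([[1.0, 2.0],
                   [3.0, 4.0],
                   [5.0, 6.0],
                   [7.0, 8.0],
                   [9.0, 10.0],
                   [11.0, 12.0]])  # shape (M_total, *trailing) = (6, 2)
segment_ids = jnp.array([0, 0, 1, 2, 2, 2])
num_parents = 3

# Per-parent reductions over the children axis (axis 0).
sums = jax.ops.segment_sum(block, segment_ids, num_segments=num_parents)
maxs = jax.ops.segment_max(block, segment_ids, num_segments=num_parents)
mins = jax.ops.segment_min(block, segment_ids, num_segments=num_parents)
prods = jax.ops.segment_prod(block, segment_ids, num_segments=num_parents)

# Assumed mean: segment sum divided by the number of children per parent,
# with the counts reshaped to broadcast over any trailing dimensions.
counts = jax.ops.segment_sum(jnp.ones(block.shape[0]), segment_ids,
                             num_segments=num_parents)
means = sums / counts.reshape((-1,) + (1,) * (block.ndim - 1))
```

Every result has one row per parent, whatever the number of children each parent has.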

Direct array operations (indexing, arithmetic, __array__) are deliberately rejected; use children.map(fn) for any per-child work that is not a reduction.
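To illustrate how a proxy can offer segment reductions while refusing direct array operations, here is a minimal sketch of such an object. ToyChildrenAxis is a hypothetical stand-in, not the hyperiax ChildrenAxis class; it implements only .sum, .max and .mean (the other reductions would follow the same pattern) and raises TypeError for indexing, arithmetic and __array__.

```python
import jax
import jax.numpy as jnp


class ToyChildrenAxis:
    """Hypothetical stand-in for a ChildrenAxis proxy (not the hyperiax class)."""

    def __init__(self, block, segment_ids, num_segments):
        self._block = block                # flat (M_total, *trailing) array
        self._segment_ids = segment_ids    # parent index of each row
        self._num_segments = num_segments  # number of parents at this level

    def _check_axis(self, axis):
        # Only the children axis can be reduced.
        if axis != 0:
            raise ValueError("only the children axis (axis=0) can be reduced")

    def sum(self, axis=0):
        self._check_axis(axis)
        return jax.ops.segment_sum(self._block, self._segment_ids,
                                   num_segments=self._num_segments)

    def max(self, axis=0):
        self._check_axis(axis)
        return jax.ops.segment_max(self._block, self._segment_ids,
                                   num_segments=self._num_segments)

    def mean(self, axis=0):
        # Assumed: segment sum divided by the child count of each parent.
        total = self.sum(axis)
        ones = jnp.ones(self._block.shape[0], dtype=self._block.dtype)
        counts = jax.ops.segment_sum(ones, self._segment_ids,
                                     num_segments=self._num_segments)
        return total / counts.reshape((-1,) + (1,) * (self._block.ndim - 1))

    # Direct array operations are refused outright.
    def __getitem__(self, key):
        raise TypeError("indexing is not supported; use children.map(fn)")

    def __add__(self, other):
        raise TypeError("arithmetic is not supported; use children.map(fn)")

    __radd__ = __sub__ = __rsub__ = __mul__ = __rmul__ = __add__

    def __array__(self, *args, **kwargs):
        raise TypeError("conversion to an array is not supported")


# Usage: two parents with 2 and 1 children; the per-parent means are 2.0 and 5.0.
axis = ToyChildrenAxis(jnp.array([[1.0], [3.0], [5.0]]), jnp.array([0, 0, 1]), 2)
per_parent_mean = axis.mean(0)
```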

Parameters:

fields (Mapping[str, jax.Array])

__init__(fields)
Parameters:

fields (Mapping[str, jax.Array])

Methods

__init__(fields)

map(fn)

Apply fn to each child, preserving the children axis.

map(fn)

Apply fn to each child, preserving the children axis.

fn receives a per-child Node view (the children fields, one child at a time) and returns a dict[str, jax.Array]. The result is a new Children view of the per-child outputs. Its fields are ChildrenAxis proxies that share the input segments, so reductions applied to them afterwards (msgs.field.sum(0), .mean(0), …, where msgs is the returned view) dispatch to jax.ops.segment_*.

This lets a sweep apply a non-linear per-child transform and then aggregate the results with the same body regardless of node degree. Trees in which every parent has the same number of children are just the special case where all segments have the same size.
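The effect of map followed by a reduction can be sketched in plain JAX, assuming the children fields are stored as flat arrays next to a segment-id vector. toy_children_map and message are hypothetical names, not hyperiax functions. The same transform and reduction handle a ragged level and a uniform-degree level, and in the uniform case the segment sum matches a dense reshape-and-sum.

```python
import jax
import jax.numpy as jnp


def toy_children_map(fn, fields):
    """Hypothetical sketch of a segment-preserving per-child map.

    `fields` maps names to flat (M_total, *trailing) arrays; `fn` sees one
    child at a time (a dict of per-child slices) and returns a dict.
    jax.vmap maps over axis 0 of every leaf, so row i of each output still
    belongs to the same parent as row i of the input.
    """
    return jax.vmap(fn)(fields)


def message(child):
    # A non-linear per-child transform of one child's fields.
    return {"msg": jnp.tanh(child["value"]) * child["weight"]}


# Ragged level: parent 0 has 1 child, parent 1 has 3 children.
fields = {
    "value": jnp.array([0.5, -1.0, 2.0, 0.0]),
    "weight": jnp.array([1.0, 2.0, 0.5, 3.0]),
}
segment_ids = jnp.array([0, 1, 1, 1])
msgs = toy_children_map(message, fields)
per_parent = jax.ops.segment_sum(msgs["msg"], segment_ids, num_segments=2)

# Uniform-degree level: 2 parents with 2 children each. The segment reduction
# agrees with the dense reshape-and-sum one could write for such trees.
eq_fields = {"value": jnp.arange(4.0), "weight": jnp.ones(4)}
eq_ids = jnp.array([0, 0, 1, 1])
eq_msgs = toy_children_map(message, eq_fields)
seg = jax.ops.segment_sum(eq_msgs["msg"], eq_ids, num_segments=2)
dense = eq_msgs["msg"].reshape(2, 2).sum(axis=1)
```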