hyperiax.ChildrenAxis

class hyperiax.ChildrenAxis(flat, segments, num_segments, trailing)[source]

Bases: object

Virtual children-axis proxy for unequal-degree trees.

Wraps a flat (M_total, *trailing) JAX array (every child at the current level, concatenated) together with a segments array that assigns each row to its parent. num_segments is a static Python int derived from the topology; jax.ops.segment_* requires it in order to produce a statically shaped output.
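The layout described above can be sketched directly in JAX. This is a minimal illustration, not the library's construction code: the topology (child_counts), the array values, and the helper sum_children are hypothetical, chosen only to show how flat, segments, and num_segments relate.

```python
import jax
import jax.numpy as jnp

# Hypothetical topology: three parents with 2, 1 and 3 children.
child_counts = [2, 1, 3]
num_segments = len(child_counts)   # static Python int, known from the topology
m_total = sum(child_counts)        # 6 rows in the flat array

# segments[i] is the parent index of row i: [0, 0, 1, 2, 2, 2]
segments = jnp.repeat(
    jnp.arange(num_segments),
    jnp.array(child_counts),
    total_repeat_length=m_total,
)

# Flat (M_total, *trailing) array with trailing == (2,)
flat = jnp.arange(12, dtype=jnp.float32).reshape(m_total, 2)

@jax.jit
def sum_children(flat, segments):
    # num_segments is closed over as a Python int, so the output
    # shape (num_segments, *trailing) is known at trace time.
    return jax.ops.segment_sum(flat, segments, num_segments=num_segments)

per_parent = sum_children(flat, segments)
print(per_parent.shape)  # (3, 2)
```

If num_segments were omitted, jax.ops.segment_sum would have to infer it from the values in segments, which is not possible while tracing under jax.jit. Supplying it as a static int avoids that.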

Call one of .sum, .prod, .max, .min, or .mean (with axis=0) to reduce over the children axis. Each method dispatches to the matching segment reduction. The output has shape (num_segments, *trailing), matching the corresponding Node view at the same level.
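For reference, the five reductions can be reproduced on a raw flat array with the jax.ops segment functions. This sketch does not use ChildrenAxis itself, and the example data is made up. Note that jax.ops provides no segment_mean, so the mean below is computed as the segment sum divided by the per-segment child count; that is an assumption about one reasonable approach, not a statement of how the class implements .mean.

```python
import jax
import jax.numpy as jnp

# Hypothetical data: 5 children belonging to 3 parents, trailing == (2,)
flat = jnp.array([[1.0, 2.0],
                  [3.0, 4.0],
                  [5.0, 6.0],
                  [7.0, 8.0],
                  [9.0, 10.0]])
segments = jnp.array([0, 0, 1, 2, 2])
num_segments = 3

child_sum = jax.ops.segment_sum(flat, segments, num_segments=num_segments)
child_prod = jax.ops.segment_prod(flat, segments, num_segments=num_segments)
child_max = jax.ops.segment_max(flat, segments, num_segments=num_segments)
child_min = jax.ops.segment_min(flat, segments, num_segments=num_segments)

# Assumed formulation of the mean: segment sum divided by child count.
counts = jax.ops.segment_sum(
    jnp.ones(flat.shape[0]), segments, num_segments=num_segments
)
child_mean = child_sum / counts[:, None]

# Every result has shape (num_segments, *trailing) == (3, 2)
print(child_sum.shape, child_mean.shape)
```

A parent with no children would have a count of zero, so a real implementation has to decide how to handle that case; the sketch above does not.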

Non-reduction ops (indexing, broadcasted arithmetic, NumPy coercion) are deliberately rejected: they would either silently produce wrong results on the flat layout or imply a padded dense form the user should request explicitly via Children.gather() (TODO).

__init__(flat, segments, num_segments, trailing)[source]
Parameters:
  • flat (Array)
  • segments (Array)
  • num_segments (int)
  • trailing (tuple)

Methods

__init__(flat, segments, num_segments, trailing)

max([axis])

mean([axis])

min([axis])

prod([axis])

sum([axis])

Attributes

flat

The backing flat (M_total, *trailing) array (all children of the level concatenated, in node order).

property flat: Array

The backing flat (M_total, *trailing) array (all children of the level concatenated, in node order). Used by Children.map() and by the dispatcher’s per-child (writes_children) scatter.

sum(axis=0)[source]
Parameters:
  • axis (int)
Return type:
  Array

prod(axis=0)[source]
Parameters:
  • axis (int)
Return type:
  Array

max(axis=0)[source]
Parameters:
  • axis (int)
Return type:
  Array

min(axis=0)[source]
Parameters:
  • axis (int)
Return type:
  Array

mean(axis=0)[source]
Parameters:
  • axis (int)
Return type:
  Array