hyperiax.Node

class hyperiax.Node(fields)

Bases: _FieldsView

Sliced per-node fields for one level (or for one selected subset of nodes).

Each attribute is a JAX array of shape (scope_size, *trailing), where scope_size is the number of nodes in this dispatch scope (typically the non-leaf nodes of a level).
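The attribute-per-field, shared-leading-axis convention can be illustrated with a small stand-in class. This is a minimal sketch that assumes only what the description above states: attribute access over a mapping of arrays that all share a leading scope_size axis. FieldsViewSketch, its scope_size attribute, and the field names value and edge_length are hypothetical; they are not part of Hyperiax's API, and the real _FieldsView may be implemented differently.

```python
from collections.abc import Mapping

import jax
import jax.numpy as jnp


class FieldsViewSketch:
    """Hypothetical stand-in for hyperiax.Node (not the library's code)."""

    def __init__(self, fields: Mapping[str, jax.Array]):
        # Shallow copy so later changes to the caller's dict do not leak in.
        self._fields = dict(fields)
        # Every field must share the same leading axis: scope_size.
        sizes = {name: arr.shape[0] for name, arr in self._fields.items()}
        if len(set(sizes.values())) > 1:
            raise ValueError(f"fields disagree on scope_size: {sizes}")
        self.scope_size = next(iter(sizes.values()), 0)

    def __getattr__(self, name: str) -> jax.Array:
        # Only called when normal attribute lookup fails, so this exposes
        # each field as an attribute without shadowing _fields/scope_size.
        fields = self.__dict__.get("_fields", {})
        if name in fields:
            return fields[name]
        raise AttributeError(name)


# Hypothetical scope of 4 nodes: a 3-dim state and a scalar per node.
node = FieldsViewSketch({
    "value": jnp.zeros((4, 3)),
    "edge_length": jnp.ones((4,)),
})
print(node.scope_size, node.value.shape, node.edge_length.shape)  # 4 (4, 3) (4,)
```

The validation in __init__ is a design choice of this sketch, made to surface the shape invariant explicitly; the documentation above does not say whether Hyperiax checks it.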

Parameters:

fields (Mapping[str, jax.Array])
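Because fields is a mapping of arrays that share a leading scope_size axis, it is also a valid JAX pytree, so a function written for a single node can be vectorized across the whole scope with jax.vmap. The sketch below uses plain JAX only and makes no claim about how Hyperiax dispatches internally; the field names and the per_node function are hypothetical.

```python
import jax
import jax.numpy as jnp

# Hypothetical fields for a scope of 4 nodes: a 3-dim value and a scalar per node.
fields = {
    "value": jnp.arange(12.0).reshape(4, 3),
    "edge_length": jnp.array([1.0, 2.0, 3.0, 4.0]),
}


def per_node(f):
    # Inside vmap, each leaf has its leading scope_size axis removed:
    # f["value"] has shape (3,) and f["edge_length"] has shape ().
    return f["edge_length"] * jnp.sum(f["value"])


# A dict is a JAX pytree, so vmap maps axis 0 of every leaf in lockstep.
result = jax.vmap(per_node)(fields)
print(result.shape)  # (4,)
```

Each output entry corresponds to one node in the scope, so the result again has a leading axis of length scope_size.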

__init__(fields)
Parameters:

fields (Mapping[str, jax.Array])

Methods

__init__(fields)