liblaf.jarp.tree
Helpers for defining, flattening, and transforming JAX PyTrees.
Most users start with define, frozen, field specifiers such as array and static, and ravel. Lower-level partitioning, registration, and code-generation helpers remain available for custom integrations. Importing this package also registers JAX adapters for bound methods and warp.array.
Modules:
- attrs – attrs helpers for classes that should behave like JAX PyTrees.
- codegen – Code-generation helpers for high-performance PyTree registrations.
- prelude – PyTree-aware wrappers for callables and transparent object proxies.
Classes:
- AuxData – Store the static part of a partitioned PyTree.
- FieldType – Describe how a field participates in PyTree flattening.
- Partial – Store a partially applied callable as a PyTree-aware proxy.
- PyTreeProxy – Wrap an arbitrary object and flatten the wrapped value as a PyTree.
- PyTreeType – Choose how a class should participate in JAX PyTree flattening.
- Structure – Record how to flatten and rebuild a PyTree's dynamic leaves.
Functions:
- array – Create a data field whose default is normalized to a JAX array.
- auto – Create a field whose PyTree role is chosen from the runtime value.
- codegen_pytree_functions – Generate flatten and unflatten callbacks for a class.
- combine – Rebuild a PyTree from dynamic leaves and recorded metadata.
- combine_leaves – Merge dynamic leaves back together with their static counterparts.
- define – Define an attrs class and optionally register it as a PyTree.
- field – Create an attrs field using jarp's static metadata convention.
- frozen – Define a frozen attrs class and register it as a data PyTree.
- frozen_static – Define a frozen attrs class and register it as a static PyTree.
- is_data – Return whether a value stays on the dynamic side of a partition.
- is_leaf – Return whether a leaf contributes numeric data to a flat vector.
- partial – Partially apply a callable and keep bound values visible to JAX trees.
- partition – Split a PyTree into dynamic leaves and static metadata.
- partition_leaves – Separate raw tree leaves into dynamic leaves and metadata leaves.
- ravel – Flatten a PyTree's dynamic leaves into one vector.
- register_fieldz – Register an attrs class with JAX using field metadata.
- register_generic – Register a class as a PyTree using explicit field groups.
- select – Select among matching PyTrees with jax.numpy.select.
- static – Create a field that is always treated as static metadata.
- where – Choose between matching PyTrees with jax.numpy.where.
AuxData
FieldType
Partial
Bases: PartialCallableObjectProxy
Store a partially applied callable as a PyTree-aware proxy.
Bound arguments and keyword arguments flatten as PyTree children, while the wrapped callable itself is partitioned between dynamic data and static metadata when needed.
Examples:
>>> import jax
>>> import jax.numpy as jnp
>>> from liblaf import jarp
>>> def add(left, right):
...     return left + right
>>> part = jarp.partial(add, jnp.array([1, 2]))
>>> leaves, _treedef = jax.tree.flatten(part)
>>> [leaf.tolist() for leaf in leaves]
[[1, 2]]
>>> part(jnp.array([3, 4])).tolist()
[4, 6]
Methods:
- __call__
Attributes:
- __wrapped__ (Callable[..., T])
Source code in src/liblaf/jarp/tree/prelude/_partial.py
PyTreeProxy
Bases: BaseObjectProxy
Wrap an arbitrary object and flatten the wrapped value as a PyTree.
The proxy itself stays transparent while JAX sees the wrapped object's PyTree structure.
Attributes:
- __wrapped__ (T)
PyTreeType
Bases: StrEnum
Choose how a class should participate in JAX PyTree flattening.
Attributes:
Structure
Record how to flatten and rebuild a PyTree's dynamic leaves.
Instances are returned by ravel and capture the
original tree definition, the static leaves that were removed from the flat
vector, and the offsets needed to reconstruct each dynamic leaf.
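The bookkeeping described above can be sketched in plain JAX. This is only an illustration, not the library's implementation: it assumes that offsets are the cumulative end positions of each array leaf in the flat vector, and the `unravel` helper here is a hypothetical stand-in for the method of the same name.

```python
import itertools
import math

import jax
import jax.numpy as jnp

tree = {"a": jnp.array([1.0, 2.0]), "b": jnp.array([[3.0], [4.0]]), "tag": "train"}

# Flatten the tree, then record which leaves are arrays and which are static.
leaves, treedef = jax.tree_util.tree_flatten(tree)
is_array = [isinstance(leaf, jax.Array) for leaf in leaves]
shapes = tuple(leaf.shape if ok else None for leaf, ok in zip(leaves, is_array))
meta_leaves = tuple(None if ok else leaf for leaf, ok in zip(leaves, is_array))

# Assumed convention: offsets are cumulative end positions of the array leaves.
sizes = [math.prod(shape) for shape in shapes if shape is not None]
offsets = tuple(itertools.accumulate(sizes))

# Concatenate only the array leaves into one flat vector.
flat = jnp.concatenate([leaf.ravel() for leaf, ok in zip(leaves, is_array) if ok])


def unravel(vector):
    """Hypothetical helper: split a flat vector and restore the static leaves."""
    chunks = iter(jnp.split(vector, list(offsets[:-1])))
    rebuilt_leaves = [
        next(chunks).reshape(shape) if shape is not None else meta
        for shape, meta in zip(shapes, meta_leaves)
    ]
    return jax.tree_util.tree_unflatten(treedef, rebuilt_leaves)


rebuilt = unravel(flat * 10)
```

Keeping the static leaves next to the tree definition means a new flat vector of the same length is enough to rebuild a complete tree.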
Parameters:
- dtype (str | type[Any] | dtype | SupportsDType)
- meta_leaves (tuple[Any, ...])
- offsets (tuple[int, ...])
- shapes (tuple[Shape | None, ...])
- treedef (PyTreeDef)
Methods:
- ravel – Flatten a compatible tree or flatten an array directly.
- unravel – Rebuild the original tree shape from a flat vector.
Attributes:
- dtype (DTypeLike)
- is_leaf (bool) – Return whether the recorded tree was a single leaf.
- meta_leaves (tuple[Any, ...])
- offsets (tuple[int, ...])
- shapes (tuple[Shape | None, ...])
- treedef (PyTreeDef)
ravel
Flatten a compatible tree or flatten an array directly.
Parameters:
- tree (T | Array) – A tree with the same structure and static leaves used to build this Structure, or a JAX array that should be flattened directly.
Returns:
- Array1D – A one-dimensional array containing the dynamic leaves.
Source code in src/liblaf/jarp/tree/_ravel.py
unravel
Rebuild the original tree shape from a flat vector.
Parameters:
- flat (T | Array) – One-dimensional data produced by ravel, or a tree that already matches the recorded structure.
- dtype (DTypeLike | None, default: None) – Optional dtype override applied to the flat array before it is split and reshaped.
Returns:
- T – A tree with the same structure and static metadata as the original input to ravel.
Source code in src/liblaf/jarp/tree/_ravel.py
array
array(
*,
default: T = ...,
validator: _ValidatorArgType[T] | None = ...,
repr: _ReprArgType = ...,
hash: bool | None = ...,
init: bool = ...,
metadata: Mapping[Any, Any] | None = ...,
converter: _ConverterType
| list[_ConverterType]
| tuple[_ConverterType, ...]
| None = ...,
factory: Callable[[], T] | None = ...,
kw_only: bool | None = ...,
eq: _EqOrderType | None = ...,
order: _EqOrderType | None = ...,
on_setattr: _OnSetAttrArgType | None = ...,
alias: str | None = ...,
type: type | None = ...,
static: FieldType | bool | None = ...,
) -> Array
Create a data field whose default is normalized to a JAX array.
When default is a concrete array-like value, array rewrites it into
a factory so each instance receives its own array object.
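The default-to-factory rewrite can be illustrated with the standard library. The sketch below uses `dataclasses` as a stand-in for attrs, and `array_field` is a hypothetical helper, not the `array` specifier itself. It only shows the idea of building the array per instance instead of sharing one default object.

```python
import dataclasses

import jax
import jax.numpy as jnp


def array_field(default):
    """Hypothetical helper: wrap a concrete default in a per-instance factory."""
    # The factory runs once per instance, so each one gets a fresh JAX array.
    return dataclasses.field(default_factory=lambda: jnp.asarray(default))


@dataclasses.dataclass
class Params:
    weights: jax.Array = array_field([0.0, 0.0])


first, second = Params(), Params()
```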
Parameters:
- default (T, default: ...)
- validator (_ValidatorArgType[T] | None, default: ...)
- repr (_ReprArgType, default: ...)
- hash (bool | None, default: ...)
- init (bool, default: ...)
- metadata (Mapping[Any, Any] | None, default: ...)
- converter (_ConverterType | list[_ConverterType] | tuple[_ConverterType, ...] | None, default: ...)
- factory (Callable[[], T] | None, default: ...)
- kw_only (bool | None, default: ...)
- eq (_EqOrderType | None, default: ...)
- order (_EqOrderType | None, default: ...)
- on_setattr (_OnSetAttrArgType | None, default: ...)
- alias (str | None, default: ...)
- type (type | None, default: ...)
- static (FieldType | bool | None, default: ...)
Source code in src/liblaf/jarp/tree/attrs/_field_specifiers.py
codegen_pytree_functions
codegen_pytree_functions(
cls: type,
data_fields: Sequence[str] = (),
meta_fields: Sequence[str] = (),
auto_fields: Sequence[str] = (),
*,
filter_spec: Callable[[Any], bool] = is_data,
bypass_setattr: bool | None = None,
) -> PyTreeFunctions
Generate flatten and unflatten callbacks for a class.
Parameters:
- cls (type) – Class whose instances should become PyTree nodes.
- data_fields (Sequence[str], default: ()) – Field names that are always emitted as dynamic children.
- meta_fields (Sequence[str], default: ()) – Field names that are always emitted as auxiliary metadata.
- auto_fields (Sequence[str], default: ()) – Field names filtered at runtime with filter_spec.
- filter_spec (Callable[[Any], bool], default: is_data) – Predicate used to split auto_fields into dynamic data or metadata.
- bypass_setattr (bool | None, default: None) – Whether generated unflattening code should use object.__setattr__ instead of normal attribute assignment.
Returns:
- PyTreeFunctions – A PyTreeFunctions tuple containing flatten, unflatten, and flatten_with_keys callables.
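For orientation, here is a hand-written sketch of the three kinds of callbacks named above, registered through `jax.tree_util.register_pytree_with_keys`. It does not show the generated code, which is not reproduced on this page. The `Spring` class and its field groups are hypothetical, and the sketch always uses `object.__setattr__` when rebuilding.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import GetAttrKey


class Spring:
    """Hypothetical plain class used only for this illustration."""

    def __init__(self, stiffness, rest_length, name):
        self.stiffness = stiffness
        self.rest_length = rest_length
        self.name = name


DATA_FIELDS = ("stiffness", "rest_length")  # emitted as dynamic children
META_FIELDS = ("name",)  # emitted as auxiliary metadata


def flatten(obj):
    children = tuple(getattr(obj, field) for field in DATA_FIELDS)
    aux = tuple(getattr(obj, field) for field in META_FIELDS)
    return children, aux


def flatten_with_keys(obj):
    children, aux = flatten(obj)
    keyed = tuple((GetAttrKey(f), c) for f, c in zip(DATA_FIELDS, children))
    return keyed, aux


def unflatten(aux, children):
    # Skip __init__ and write attributes directly, bypassing any custom __setattr__.
    obj = object.__new__(Spring)
    values = tuple(children) + tuple(aux)
    for field, value in zip(DATA_FIELDS + META_FIELDS, values):
        object.__setattr__(obj, field, value)
    return obj


jax.tree_util.register_pytree_with_keys(Spring, flatten_with_keys, unflatten, flatten)

spring = Spring(jnp.array(2.0), jnp.array(1.5), "k0")
doubled = jax.tree_util.tree_map(lambda x: 2 * x, spring)
```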
Source code in src/liblaf/jarp/tree/codegen/_compile.py
combine
Rebuild a PyTree from dynamic leaves and recorded metadata.
Source code in src/liblaf/jarp/tree/_filters.py
combine_leaves
Merge dynamic leaves back together with their static counterparts.
Source code in src/liblaf/jarp/tree/_filters.py
define
Define an attrs class and optionally register it as a PyTree.
Parameters:
- maybe_cls (T | None, default: None) – Class being decorated. When omitted, return a configured decorator.
- **kwargs (Any, default: {}) – Options forwarded to attrs.define, plus pytree to control JAX registration. pytree="data" registers fields with fieldz semantics, "static" registers the whole instance as a static value, and "none" leaves the class unregistered.
Returns:
- Any – The decorated class or a class decorator.
Source code in src/liblaf/jarp/tree/attrs/_define.py
field
field(**kwargs) -> Any
Create an attrs field using jarp's static metadata convention.
Source code in src/liblaf/jarp/tree/attrs/_field_specifiers.py
frozen
Define a frozen attrs class and register it as a data PyTree.
This is the common choice for immutable structures whose array fields should participate in JAX transformations.
Source code in src/liblaf/jarp/tree/attrs/_define.py
frozen_static
Define a frozen attrs class and register it as a static PyTree.
Use this for immutable helper objects that should be treated as static metadata instead of flattening into JAX leaves.
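What "static PyTree" means can be sketched in plain JAX: the node has no children and the whole instance travels as auxiliary data. The example below is an assumption-laden illustration that uses a frozen stdlib dataclass and `jax.tree_util.register_pytree_node`. The `Config` class is hypothetical and this is not how `frozen_static` is implemented.

```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class Config:
    """Hypothetical immutable helper object (hashable because it is frozen)."""

    name: str
    steps: int


# No children: the instance itself is the auxiliary data of the node.
jax.tree_util.register_pytree_node(
    Config,
    lambda cfg: ((), cfg),
    lambda cfg, _children: cfg,
)

tree = {"config": Config("demo", 3), "x": jnp.array([1.0, 2.0])}
leaves = jax.tree_util.tree_leaves(tree)
shifted = jax.tree_util.tree_map(lambda v: v + 1, tree)
```

Because the instance is part of the tree definition, it must be hashable and comparable, which is one reason frozen classes are a natural fit.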
Source code in src/liblaf/jarp/tree/attrs/_define.py
is_data
Return whether a value stays on the dynamic side of a partition.
Dynamic values include JAX arrays, None placeholders, and objects whose
type already has a JAX PyTree registration. Everything else is treated as
static metadata by partition.
Source code in src/liblaf/jarp/tree/_filters.py
is_leaf
Return whether a leaf contributes numeric data to a flat vector.
This is intentionally narrower than is_data:
only arrays and None participate in the flat-vector protocol used by
liblaf.jarp.ravel.
Source code in src/liblaf/jarp/tree/_filters.py
partial
Partially apply a callable and keep bound values visible to JAX trees.
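For comparison, JAX ships a related built-in, `jax.tree_util.Partial`, whose bound arguments are also PyTree leaves. The sketch below uses that built-in rather than this package's `partial`, so it illustrates only the general idea. The `scale` and `apply` functions are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import Partial


def scale(factor, x):
    return factor * x


# The bound array is a leaf, so JAX transformations can trace through it.
double = Partial(scale, jnp.array(2.0))
leaves = jax.tree_util.tree_leaves(double)


@jax.jit
def apply(fn, x):
    # `fn` crosses the jit boundary as a PyTree, not as a static argument.
    return fn(x)


result = apply(double, jnp.array([1.0, 2.0]))
```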
partition
Split a PyTree into dynamic leaves and static metadata.
The returned leaf list preserves tree order. Non-dynamic positions become
None in the data list and are stored in the accompanying
AuxData.
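The None-placeholder layout described above can be sketched with `jax.tree_util` alone. The `partition` and `combine` functions below are simplified, hypothetical stand-ins: they treat only JAX arrays as dynamic, whereas is_data also accepts None and registered PyTree types, and they use a plain tuple where the library uses AuxData.

```python
import jax
import jax.numpy as jnp


def partition(tree):
    """Hypothetical sketch: split leaves into dynamic data and static metadata."""
    leaves, treedef = jax.tree_util.tree_flatten(tree)
    data = [leaf if isinstance(leaf, jax.Array) else None for leaf in leaves]
    meta = [None if isinstance(leaf, jax.Array) else leaf for leaf in leaves]
    return data, (meta, treedef)


def combine(data, aux):
    """Hypothetical sketch: fill the None placeholders back in from the metadata."""
    meta, treedef = aux
    leaves = [m if d is None else d for d, m in zip(data, meta)]
    return jax.tree_util.tree_unflatten(treedef, leaves)


tree = {"x": jnp.array([1.0, 2.0]), "tag": "train"}
data, aux = partition(tree)  # dict keys flatten in sorted order: "tag", then "x"
rebuilt = combine(data, aux)
```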
Source code in src/liblaf/jarp/tree/_filters.py
partition_leaves
Separate raw tree leaves into dynamic leaves and metadata leaves.
Source code in src/liblaf/jarp/tree/_filters.py
ravel
Flatten a PyTree's dynamic leaves into one vector.
Non-array leaves are treated as static metadata and preserved in the
returned Structure instead of being
concatenated into the flat array.
Examples:
>>> import jax.numpy as jnp
>>> from liblaf import jarp
>>> flat, structure = jarp.ravel({"x": jnp.array([1.0, 2.0]), "tag": "train"})
>>> flat.tolist()
[1.0, 2.0]
>>> rebuilt = structure.unravel(jnp.array([3.0, 4.0]))
>>> rebuilt["x"].tolist(), rebuilt["tag"]
([3.0, 4.0], 'train')
Parameters:
- tree (T) – PyTree to flatten.
Returns:
- Array, Structure[T] – A tuple of (flat, structure) where flat is a one-dimensional JAX array and structure can rebuild compatible trees later.
Source code in src/liblaf/jarp/tree/_ravel.py
register_fieldz
register_fieldz[T: type](
cls: T,
data_fields: Sequence[str] | None = None,
meta_fields: Sequence[str] | None = None,
auto_fields: Sequence[str] | None = None,
*,
filter_spec: Callable[[Any], bool] = is_data,
bypass_setattr: bool | None = None,
) -> T
Register an attrs class with JAX using field metadata.
Field groups default to the metadata written by
array, auto, and
static. Pass explicit field lists when you
need to register a class that was not declared with liblaf.jarp field
helpers.
Parameters:
- cls (T) – Class to register.
- data_fields (Sequence[str] | None, default: None) – Field names that are always treated as dynamic children.
- meta_fields (Sequence[str] | None, default: None) – Field names that are always treated as static metadata.
- auto_fields (Sequence[str] | None, default: None) – Field names filtered at runtime with filter_spec.
- filter_spec (Callable[[Any], bool], default: is_data) – Predicate used to split auto_fields into dynamic data or metadata.
- bypass_setattr (bool | None, default: None) – Whether generated unflattening code should use object.__setattr__ instead of normal attribute assignment.
Returns:
- T – The same class object, for decorator-style usage.
Source code in src/liblaf/jarp/tree/attrs/_register.py
register_generic
register_generic(
cls: type,
data_fields: Sequence[str] = (),
meta_fields: Sequence[str] = (),
auto_fields: Sequence[str] = (),
*,
filter_spec: Callable[[Any], bool] = is_data,
bypass_setattr: bool | None = None,
) -> None
Register a class as a PyTree using explicit field groups.
Use this lower-level helper when you want to control the flattening layout directly instead of relying on attrs metadata.
Parameters:
- cls (type) – Class to register.
- data_fields (Sequence[str], default: ()) – Field names that are always emitted as dynamic children.
- meta_fields (Sequence[str], default: ()) – Field names that are always emitted as auxiliary metadata.
- auto_fields (Sequence[str], default: ()) – Field names filtered at runtime with filter_spec.
- filter_spec (Callable[[Any], bool], default: is_data) – Predicate used to split auto_fields into dynamic data or metadata.
- bypass_setattr (bool | None, default: None) – Whether generated unflattening code should use object.__setattr__ instead of normal attribute assignment.
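Explicit field groups have a close analogue in JAX itself. The sketch below uses `jax.tree_util.register_dataclass`, which is available in recent JAX releases, on a hypothetical `Layer` dataclass. It illustrates the data/meta split only and does not call register_generic.

```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class Layer:
    """Hypothetical class with two dynamic fields and one static field."""

    weight: jax.Array
    bias: jax.Array
    activation: str


# Name the dynamic children and the auxiliary metadata explicitly.
jax.tree_util.register_dataclass(
    Layer,
    data_fields=["weight", "bias"],
    meta_fields=["activation"],
)

layer = Layer(jnp.ones((2, 2)), jnp.zeros(2), "relu")
leaves, treedef = jax.tree_util.tree_flatten(layer)
scaled = jax.tree_util.tree_map(lambda x: 3 * x, layer)
```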
Source code in src/liblaf/jarp/tree/codegen/_compile.py
select
Select among matching PyTrees with jax.numpy.select.
Each leaf is selected independently with ordered conditions. The first true
condition at each position selects the corresponding choice leaf; default
supplies the leaf where no condition is true.
Parameters:
- condlist (Sequence[Bool[ArrayLike, ' ...']]) – Non-empty sequence of boolean scalar or array-like conditions.
- choicelist (Sequence[T]) – PyTrees to choose from. It must have the same length as condlist, and every choice must have the same tree structure as default.
- default (T) – PyTree returned where no condition is true.
Returns:
- T – A PyTree with the same structure as default.
Raises:
- ValueError – If condlist is empty or its length does not match choicelist.
Examples:
>>> import jax.numpy as jnp
>>> from liblaf import jarp
>>> result = jarp.tree.select(
... [jnp.array([False, True, False]), jnp.array([True, True, False])],
... [{"value": jnp.array([1, 1, 1])}, {"value": jnp.array([2, 2, 2])}],
... {"value": jnp.array([9, 9, 9])},
... )
>>> result["value"].tolist()
[2, 1, 9]
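The "first true condition wins" rule can be sketched per leaf with nested `jnp.where` calls applied in reverse order. The `tree_select` function below is a hypothetical re-implementation for illustration, not the source of `select`.

```python
import jax
import jax.numpy as jnp


def tree_select(condlist, choicelist, default):
    """Hypothetical sketch of leaf-wise selection with ordered conditions."""
    if len(condlist) == 0 or len(condlist) != len(choicelist):
        raise ValueError("condlist must be non-empty and match choicelist")

    def select_leaf(default_leaf, *choice_leaves):
        out = default_leaf
        # Apply later conditions first so earlier ones overwrite them.
        for cond, choice in zip(reversed(condlist), reversed(choice_leaves)):
            out = jnp.where(cond, choice, out)
        return out

    return jax.tree_util.tree_map(select_leaf, default, *choicelist)


picked = tree_select(
    [jnp.array([False, True, False]), jnp.array([True, True, False])],
    [{"value": jnp.array([1, 1, 1])}, {"value": jnp.array([2, 2, 2])}],
    {"value": jnp.array([9, 9, 9])},
)
```

At the middle position both conditions are true, and the first choice is kept because its condition is applied last.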
Source code in src/liblaf/jarp/tree/_ops.py
static
static(**kwargs) -> Any
Create a field that is always treated as static metadata.
Source code in src/liblaf/jarp/tree/attrs/_field_specifiers.py
where
Choose between matching PyTrees with jax.numpy.where.
Parameters:
- condition (Bool[ArrayLike, ' ...']) – Boolean scalar or array-like condition.
- x (T) – PyTree used where condition is true.
- y (T) – PyTree used where condition is false.
Returns:
- T – A PyTree with the same structure as x and y.
Examples:
>>> import jax.numpy as jnp
>>> from liblaf import jarp
>>> result = jarp.tree.where(
... jnp.array([True, False]),
... {"value": jnp.array([1, 2])},
... {"value": jnp.array([3, 4])},
... )
>>> result["value"].tolist()
[1, 4]
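A scalar condition broadcasts across every leaf. The sketch below shows that behavior with a hypothetical `tree_where` built from `jax.tree_util.tree_map` and `jnp.where`. It is an illustration of the idea, not the source of `where`.

```python
import jax
import jax.numpy as jnp


def tree_where(condition, x, y):
    """Hypothetical sketch: apply jnp.where to each pair of matching leaves."""
    return jax.tree_util.tree_map(lambda a, b: jnp.where(condition, a, b), x, y)


chosen = tree_where(
    jnp.array(False),  # scalar condition, broadcast to every leaf
    {"a": jnp.array([1, 2]), "b": jnp.array(5)},
    {"a": jnp.array([3, 4]), "b": jnp.array(6)},
)
```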