liblaf.jarp.tree.attrs
attrs helpers for classes that should behave like JAX PyTrees.
The decorators and field specifiers in liblaf.jarp.tree.attrs wrap attrs
while recording which fields should flatten as dynamic data, remain static
metadata, or be decided from the runtime value.
Classes:
- FieldType – Describe how a field participates in PyTree flattening.
- PyTreeType – Choose how a class should participate in JAX PyTree flattening.
Functions:
- array – Create a data field whose default is normalized to a JAX array.
- auto – Create a field whose PyTree role is chosen from the runtime value.
- define – Define an attrs class and optionally register it as a PyTree.
- field – Create an attrs field using jarp's static metadata convention.
- frozen – Define a frozen attrs class and register it as a data PyTree.
- frozen_static – Define a frozen attrs class and register it as a static PyTree.
- register_fieldz – Register an attrs class with JAX using field metadata.
- static – Create a field that is always treated as static metadata.
FieldType
PyTreeType
Bases: StrEnum
Choose how a class should participate in JAX PyTree flattening.
Attributes:
array
array(
*,
default: T = ...,
validator: _ValidatorArgType[T] | None = ...,
repr: _ReprArgType = ...,
hash: bool | None = ...,
init: bool = ...,
metadata: Mapping[Any, Any] | None = ...,
converter: _ConverterType
| list[_ConverterType]
| tuple[_ConverterType, ...]
| None = ...,
factory: Callable[[], T] | None = ...,
kw_only: bool | None = ...,
eq: _EqOrderType | None = ...,
order: _EqOrderType | None = ...,
on_setattr: _OnSetAttrArgType | None = ...,
alias: str | None = ...,
type: type | None = ...,
static: FieldType | bool | None = ...,
) -> Array
Create a data field whose default is normalized to a JAX array.
When default is a concrete array-like value, array rewrites it into
a factory so each instance receives its own array object.
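The default-to-factory rewrite can be sketched with the standard library's dataclasses. This is not jarp's implementation: array_like_field is a hypothetical name, a Python list stands in for an array, and copy.deepcopy stands in for the normalization to a JAX array.

```python
import copy
from dataclasses import dataclass, field
from typing import Any


def array_like_field(default: Any) -> Any:
    """Hypothetical helper: turn a concrete default into a per-instance factory."""
    # The factory copies the template on every construction, so no two
    # instances ever share the same default object.
    return field(default_factory=lambda: copy.deepcopy(default))


@dataclass
class State:
    # A list stands in for an array; dataclasses would reject a bare list
    # default, and the factory avoids sharing one object in any case.
    values: list = array_like_field([0.0, 0.0, 0.0])


a = State()
b = State()
a.values[0] = 1.0
print(a.values, b.values)  # [1.0, 0.0, 0.0] [0.0, 0.0, 0.0]
```

Mutating one instance's default leaves the other untouched, which is the property the factory rewrite is meant to guarantee.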
Parameters:
- default (T, default: ...)
- validator (_ValidatorArgType[T] | None, default: ...)
- repr (_ReprArgType, default: ...)
- hash (bool | None, default: ...)
- init (bool, default: ...)
- metadata (Mapping[Any, Any] | None, default: ...)
- converter (_ConverterType | list[_ConverterType] | tuple[_ConverterType, ...] | None, default: ...)
- factory (Callable[[], T] | None, default: ...)
- kw_only (bool | None, default: ...)
- eq (_EqOrderType | None, default: ...)
- order (_EqOrderType | None, default: ...)
- on_setattr (_OnSetAttrArgType | None, default: ...)
- alias (str | None, default: ...)
- type (type | None, default: ...)
- static (FieldType | bool | None, default: ...)
Apart from static, these parameters have the same names as those of attrs.field.
Source code in src/liblaf/jarp/tree/attrs/_field_specifiers.py
define
Define an attrs class and optionally register it as a PyTree.
Parameters:
- maybe_cls (T | None, default: None) – Class being decorated. When omitted, return a configured decorator.
- **kwargs (Any, default: {}) – Options forwarded to attrs.define, plus pytree to control JAX registration: pytree="data" registers fields with fieldz semantics, "static" registers the whole instance as a static value, and "none" leaves the class unregistered.
Returns:
- Any – The decorated class or a class decorator.
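The "decorator with or without arguments" pattern behind maybe_cls, combined with a pytree switch, can be sketched as follows. Everything here is hypothetical: define_sketch is not jarp's function, dataclasses.dataclass stands in for attrs.define, a plain REGISTRY dict stands in for JAX PyTree registration, and the default of "data" for pytree is an assumption.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Any

# Hypothetical registry standing in for JAX PyTree registration.
REGISTRY: dict[type, str] = {}


def define_sketch(maybe_cls: type | None = None, *, pytree: str = "data", **kwargs: Any) -> Any:
    """Hypothetical analogue of define: works as @define_sketch or @define_sketch(...)."""
    if pytree not in ("data", "static", "none"):
        raise ValueError(f"unknown pytree mode: {pytree!r}")

    def wrap(cls: type) -> type:
        cls = dataclass(cls, **kwargs)  # dataclass stands in for attrs.define
        if pytree != "none":
            REGISTRY[cls] = pytree  # record how the class should flatten
        return cls

    # Bare decorator: maybe_cls is the class being decorated.
    if maybe_cls is not None:
        return wrap(maybe_cls)
    # Called with options: return the configured decorator.
    return wrap


@define_sketch
class Point:
    x: float
    y: float


@define_sketch(pytree="static", frozen=True)
class Config:
    name: str


@define_sketch(pytree="none")
class Plain:
    n: int


print(sorted(c.__name__ for c in REGISTRY))  # ['Config', 'Point']
```

Checking maybe_cls first is what lets one function serve both call styles: Python passes the class directly for a bare decorator and passes nothing when options are given.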
Source code in src/liblaf/jarp/tree/attrs/_define.py
field
field(**kwargs) -> Any
Create an attrs field using jarp's static metadata convention.
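A static-metadata convention of this kind amounts to storing a flag in each field's metadata and reading it back at registration time. The sketch below uses dataclasses instead of attrs; field_sketch, static_field_names, and the metadata key "static" are hypothetical, and jarp's actual key and keyword may differ.

```python
from dataclasses import dataclass, field, fields
from typing import Any

STATIC_KEY = "static"  # hypothetical metadata key; jarp's real key may differ


def field_sketch(*, static: bool = False, **kwargs: Any) -> Any:
    """Hypothetical analogue of field: record the static flag in field metadata."""
    metadata = dict(kwargs.pop("metadata", None) or {})
    metadata[STATIC_KEY] = static
    return field(metadata=metadata, **kwargs)


@dataclass
class Solver:
    tolerance: float = field_sketch(default=1e-6)
    method: str = field_sketch(default="cg", static=True)


def static_field_names(cls: type) -> list[str]:
    """Read the flag back, as a registration step would."""
    return [f.name for f in fields(cls) if f.metadata.get(STATIC_KEY, False)]


print(static_field_names(Solver))  # ['method']
```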
Source code in src/liblaf/jarp/tree/attrs/_field_specifiers.py
frozen
Define a frozen attrs class and register it as a data PyTree.
This is the common choice for immutable structures whose array fields should participate in JAX transformations.
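Immutable structures fit JAX's functional style: an update returns a new object rather than mutating the old one. The sketch uses a frozen dataclass and dataclasses.replace as stand-ins (with attrs, the counterpart of replace is attrs.evolve); Particle and step are hypothetical names, and tuples stand in for arrays.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class Particle:
    position: tuple[float, float]
    velocity: tuple[float, float]


def step(p: Particle, dt: float) -> Particle:
    # Frozen instances cannot be mutated, so a step builds a new object.
    x, y = p.position
    vx, vy = p.velocity
    return dataclasses.replace(p, position=(x + dt * vx, y + dt * vy))


p0 = Particle(position=(0.0, 0.0), velocity=(1.0, 2.0))
p1 = step(p0, 0.5)
print(p1.position)  # (0.5, 1.0)

try:
    p0.position = (9.0, 9.0)
except dataclasses.FrozenInstanceError:
    print("frozen")  # direct assignment is rejected
```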
Source code in src/liblaf/jarp/tree/attrs/_define.py
frozen_static
Define a frozen attrs class and register it as a static PyTree.
Use this for immutable helper objects that should be treated as static metadata instead of flattening into JAX leaves.
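Treating a whole instance as static means it contributes no leaves and travels as auxiliary data. JAX's PyTree documentation asks that auxiliary data be hashable and comparable, because it becomes part of the cache key for jit; a frozen class provides the hash. The functions below are a hypothetical sketch of that idea, not jarp's code.

```python
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class Options:
    # frozen=True together with the default eq=True makes dataclasses
    # generate __hash__, so instances can serve as dictionary keys.
    method: str
    max_iter: int


def flatten_static(obj: Any) -> tuple[tuple, Any]:
    """Hypothetical static flattening: no leaves, the whole object is aux data."""
    return (), obj


def unflatten_static(aux: Any, children: tuple) -> Any:
    """Rebuilding is trivial: the aux data already is the instance."""
    return aux


leaves_a, aux_a = flatten_static(Options("cg", 100))
leaves_b, aux_b = flatten_static(Options("cg", 100))
# Two equal, hashable aux values hit the same entry, as a tracing cache would.
cache = {aux_a: "compiled"}
print(leaves_a, cache[aux_b])  # () compiled
```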
Source code in src/liblaf/jarp/tree/attrs/_define.py
register_fieldz
register_fieldz[T: type](
cls: T,
data_fields: Sequence[str] | None = None,
meta_fields: Sequence[str] | None = None,
auto_fields: Sequence[str] | None = None,
*,
filter_spec: Callable[[Any], bool] = is_data,
bypass_setattr: bool | None = None,
) -> T
Register an attrs class with JAX using field metadata.
When a field list is omitted, it defaults to the metadata written by array, auto, and static. Pass explicit field lists when you need to register a class that was not declared with liblaf.jarp field helpers.
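The three field groups can be pictured as a flattening step that always sends data_fields to the children, always sends meta_fields to the metadata, and routes each auto field by calling a predicate on its current value. This is a hypothetical sketch: flatten_fields and is_number_list are illustrative names, and is_number_list only stands in for is_data, whose exact rule is not described here.

```python
from dataclasses import dataclass
from typing import Any, Callable, Sequence


def is_number_list(value: Any) -> bool:
    """Hypothetical stand-in for is_data: lists of floats count as dynamic data."""
    return isinstance(value, list) and all(isinstance(v, float) for v in value)


def flatten_fields(
    obj: Any,
    data_fields: Sequence[str],
    meta_fields: Sequence[str],
    auto_fields: Sequence[str],
    filter_spec: Callable[[Any], bool] = is_number_list,
) -> tuple[list[Any], tuple[Any, ...]]:
    """Split obj into dynamic children and static (name, value) metadata."""
    children = [getattr(obj, name) for name in data_fields]
    meta = [(name, getattr(obj, name)) for name in meta_fields]
    for name in auto_fields:
        value = getattr(obj, name)
        # The runtime value, not the declaration, decides the role.
        if filter_spec(value):
            children.append(value)
        else:
            meta.append((name, value))
    return children, tuple(meta)


@dataclass
class Model:
    weights: list
    name: str
    extra: Any


children, meta = flatten_fields(
    Model(weights=[1.0, 2.0], name="mlp", extra=[0.5]),
    data_fields=["weights"],
    meta_fields=["name"],
    auto_fields=["extra"],
)
print(children, meta)  # [[1.0, 2.0], [0.5]] (('name', 'mlp'),)
```

A real unflatten step would also need to remember which auto fields went to the children; the sketch covers only the flattening half.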
Parameters:
- cls (T) – Class to register.
- data_fields (Sequence[str] | None, default: None) – Field names that are always treated as dynamic children.
- meta_fields (Sequence[str] | None, default: None) – Field names that are always treated as static metadata.
- auto_fields (Sequence[str] | None, default: None) – Field names filtered at runtime with filter_spec.
- filter_spec (Callable[[Any], bool], default: is_data) – Predicate used to split auto_fields into dynamic data or metadata.
- bypass_setattr (bool | None, default: None) – Whether generated unflattening code should use object.__setattr__ instead of normal attribute assignment.
Returns:
- T – The same class object, for decorator-style usage.
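The bypass_setattr option matters for frozen classes, whose __setattr__ rejects assignment. Rebuilding an instance with object.__new__ and object.__setattr__ sidesteps both that guard and __init__ (and with it any validators or converters), which also helps because JAX may unflatten with placeholder leaves that a validator would reject. The sketch below is hypothetical and only illustrates the mechanism; unflatten_fields and Frozen are not jarp names.

```python
import dataclasses
from typing import Any, Sequence


@dataclasses.dataclass(frozen=True)
class Frozen:
    a: float
    b: str


def unflatten_fields(
    cls: type,
    names: Sequence[str],
    values: Sequence[Any],
    bypass_setattr: bool,
) -> Any:
    """Rebuild an instance without calling __init__ (hypothetical sketch)."""
    obj = object.__new__(cls)  # skip __init__, and with it validation
    for name, value in zip(names, values):
        if bypass_setattr:
            # object.__setattr__ sidesteps the frozen class's __setattr__ guard.
            object.__setattr__(obj, name, value)
        else:
            setattr(obj, name, value)  # raises FrozenInstanceError on a frozen class
    return obj


rebuilt = unflatten_fields(Frozen, ["a", "b"], [1.5, "x"], bypass_setattr=True)
print(rebuilt)  # Frozen(a=1.5, b='x')
```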