liblaf.jarp

Utilities for mixed JAX PyTrees and NVIDIA Warp interop.

The top-level package re-exports the filtered call wrappers filter_jit and fallback_jit, selected liblaf.jarp.lax helpers, the most common helpers from liblaf.jarp.tree, and Warp integration utilities such as to_warp, struct, jax_callable, and jax_kernel. Import liblaf.jarp.lax, liblaf.jarp.tree, or liblaf.jarp.warp directly when you need a helper that is not re-exported here.

Modules:

  • lax

    Control-flow wrappers and ordered condition helpers.

  • tree

    Helpers for defining, flattening, and transforming JAX PyTrees.

  • warp

    Interop helpers between JAX arrays and NVIDIA Warp.

Classes:

  • Enum

    JAX-compatible enum base class with traceable integer values.

  • Partial

    Store a partially applied callable as a PyTree-aware proxy.

  • PyTreeProxy

    Wrap an arbitrary object and flatten the wrapped value as a PyTree.

  • Structure

    Record how to flatten and rebuild a PyTree's dynamic leaves.

Functions:

  • array

    Create a data field whose default is normalized to a JAX array.

  • auto

    Create a field whose PyTree role is chosen from the runtime value.

  • cond

    Choose between two branches, then retry eagerly if JAX rejects them.

  • define

    Define an attrs class and optionally register it as a PyTree.

  • fallback_jit

    Wrap a callable and cache Python fallbacks for failing metadata shapes.

  • field

    Create an attrs field using jarp's static metadata convention.

  • filter_jit

    Wrap a callable with liblaf.jarp data-versus-metadata partitioning.

  • first_true_index

    Return the index of the first true condition.

  • fori_loop

    Run a counted loop, then retry in Python if JAX rejects the body.

  • frozen

    Define a frozen attrs class and register it as a data PyTree.

  • frozen_static

    Define a frozen attrs class and register it as a static PyTree.

  • jax_callable

    Wrap warp.jax_experimental.jax_callable with optional dtype dispatch.

  • jax_kernel

    Wrap warp.jax_experimental.jax_kernel with optional overload lookup.

  • partial

    Partially apply a callable and keep bound values visible to JAX trees.

  • ravel

    Flatten a PyTree's dynamic leaves into one vector.

  • static

    Create a field that is always treated as static metadata.

  • struct

    Decorate a class as a Warp struct.

  • switch

    Choose one branch by index, then retry eagerly if JAX rejects it.

  • to_warp

    Convert a supported array object into a warp.array.

  • while_loop

    Run a loop, then retry in Python if JAX rejects the callbacks.

Attributes:

__commit_id__ module-attribute

__commit_id__: str | None = None

__version__ module-attribute

__version__: str = '0.2.2.dev2+g759ba685d'

__version_tuple__ module-attribute

__version_tuple__: tuple[int | str, ...] = (
    0,
    2,
    2,
    "dev2",
    "g759ba685d",
)

Enum

Bases: Enum



JAX-compatible enum base class with traceable integer values.

Enum behaves like enum.Enum, but its value is a JAX array leaf. Subclasses are registered as PyTrees, so enum state can travel through jax.jit and jax.lax loops as dynamic data.

Array-valued results can represent several members at once. In that case the enum object's name is "<unknown>", while value remains the traceable integer array that JAX operates on.

Examples:

>>> import enum
>>> import jax.numpy as jnp
>>> from liblaf import jarp
>>> class Phase(jarp.Enum):
...     START = enum.auto()
...     RUNNING = enum.auto()
...     DONE = enum.auto()
>>> int(Phase.RUNNING.value)
1
>>> Phase.where(
...     jnp.array([True, False]), Phase.START, Phase.RUNNING
... ).value.tolist()
[0, 1]

Methods:

  • __eq__

    Compare members from the same enum class by value.

  • __hash__

    Return a hash that keeps enum classes distinct.

  • __init_subclass__

    Register each subclass as a keyed JAX PyTree.

  • select

    Select among enum-bearing PyTrees with ordered conditions.

  • tree_flatten

    Flatten the enum value as the only dynamic child.

  • tree_flatten_with_keys

    Flatten the enum value with a stable value path key.

  • tree_unflatten

    Rebuild an enum object from its flattened integer value.

  • where

    Choose between enum-bearing PyTrees leaf by leaf.

Attributes:

  • value (Integer[Array, '']) –

    Return the enum's dynamic integer value as an int32 JAX array.

value property

value: Integer[Array, '']

Return the enum's dynamic integer value as an int32 JAX array.

__eq__

__eq__(other: object) -> bool

Compare members from the same enum class by value.

Source code in src/liblaf/jarp/_enum.py
def __eq__(self, other: object) -> bool:
    """Compare members from the same enum class by value."""
    if not isinstance(other, type(self)):
        return NotImplemented
    return cast("bool", self.value == other.value)

__hash__

__hash__() -> int

Return a hash that keeps enum classes distinct.

Source code in src/liblaf/jarp/_enum.py
def __hash__(self) -> int:
    """Return a hash that keeps enum classes distinct."""
    return hash((type(self), int(self._value_)))

__init_subclass__

__init_subclass__(**kwargs) -> None

Register each subclass as a keyed JAX PyTree.

Source code in src/liblaf/jarp/_enum.py
def __init_subclass__(cls, **kwargs) -> None:
    """Register each subclass as a keyed JAX PyTree."""
    super().__init_subclass__(**kwargs)
    jtu.register_pytree_with_keys_class(cls)

select staticmethod

select[T](
    condlist: Sequence[Bool[ArrayLike, " ..."]],
    choicelist: Sequence[T],
    default: T,
) -> T

Select among enum-bearing PyTrees with ordered conditions.

This delegates to tree.select. Conditions follow jax.numpy.select semantics: the first true condition at each position selects the corresponding choice, and default is used where no condition is true.

Parameters:

  • condlist (Sequence[Bool[ArrayLike, ' ...']]) –

    Non-empty sequence of boolean scalar or array-like conditions.

  • choicelist (Sequence[T]) –

    PyTrees to choose from. It must have the same length as condlist, and every choice must have the same tree structure as default.

  • default (T) –

    PyTree returned where no condition is true.

Returns:

  • T

    A PyTree with the same structure as default.

Raises:

  • ValueError

    If condlist is empty or its length does not match choicelist.

Examples:

>>> import enum
>>> import jax.numpy as jnp
>>> from liblaf import jarp
>>> class Phase(jarp.Enum):
...     START = enum.auto()
...     RUNNING = enum.auto()
...     DONE = enum.auto()
>>> result = Phase.select(
...     [jnp.array([False, True]), jnp.array([True, True])],
...     [Phase.START, Phase.RUNNING],
...     default=Phase.DONE,
... )
>>> result.value.tolist()
[1, 0]

Source code in src/liblaf/jarp/_enum.py
@staticmethod
def select[T](
    condlist: Sequence[Bool[ArrayLike, " ..."]], choicelist: Sequence[T], default: T
) -> T:
    """Select among enum-bearing PyTrees with ordered conditions.

    This delegates to [`tree.select`][liblaf.jarp.tree.select]. Conditions
    follow [`jax.numpy.select`][jax.numpy.select] semantics: the first true
    condition at each position selects the corresponding choice, and
    `default` is used where no condition is true.

    Args:
        condlist: Non-empty sequence of boolean scalar or array-like
            conditions.
        choicelist: PyTrees to choose from. It must have the same length as
            `condlist`, and every choice must have the same tree structure
            as `default`.
        default: PyTree returned where no condition is true.

    Returns:
        A PyTree with the same structure as `default`.

    Raises:
        ValueError: If `condlist` is empty or its length does not match
            `choicelist`.

    Examples:
        >>> import enum
        >>> import jax.numpy as jnp
        >>> from liblaf import jarp
        >>> class Phase(jarp.Enum):
        ...     START = enum.auto()
        ...     RUNNING = enum.auto()
        ...     DONE = enum.auto()
        >>> result = Phase.select(
        ...     [jnp.array([False, True]), jnp.array([True, True])],
        ...     [Phase.START, Phase.RUNNING],
        ...     default=Phase.DONE,
        ... )
        >>> result.value.tolist()
        [1, 0]
    """
    return tree.select(condlist, choicelist, default)
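
The first-true-wins rule described above can be sketched with stock JAX. This is an illustration under assumptions, not the body of tree.select: the helper name tree_select_sketch and the dictionary-shaped state are hypothetical, and the sketch only maps jax.numpy.select over matching leaves.

```python
import jax
import jax.numpy as jnp


def tree_select_sketch(condlist, choicelist, default):
    """Hypothetical stand-in: apply jnp.select to each group of matching leaves."""
    if not condlist or len(condlist) != len(choicelist):
        raise ValueError("condlist must be non-empty and match choicelist")
    return jax.tree.map(
        # `leaf` comes from `default`; `choices` holds the matching leaf of each choice.
        lambda leaf, *choices: jnp.select(list(condlist), list(choices), default=leaf),
        default,
        *choicelist,
    )


conds = [jnp.array([False, True]), jnp.array([True, True])]
choices = [{"code": jnp.asarray(0)}, {"code": jnp.asarray(1)}]
picked = tree_select_sketch(conds, choices, {"code": jnp.asarray(2)})
print(picked["code"].tolist())
```

At position 0 only the second condition holds, so the second choice wins; at position 1 the first condition already holds, so the first choice wins.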

tree_flatten

tree_flatten() -> tuple[tuple[Integer[Array, '']], None]

Flatten the enum value as the only dynamic child.

Source code in src/liblaf/jarp/_enum.py
def tree_flatten(self) -> tuple[tuple[Integer[Array, ""]], None]:
    """Flatten the enum value as the only dynamic child."""
    child: Integer[Array, ""] = jnp.asarray(self.value, jnp.int32)
    return (child,), None

tree_flatten_with_keys

tree_flatten_with_keys() -> tuple[
    tuple[tuple[GetAttrKey, Integer[Array, ""]]], None
]

Flatten the enum value with a stable value path key.

Source code in src/liblaf/jarp/_enum.py
def tree_flatten_with_keys(
    self,
) -> tuple[tuple[tuple[jtu.GetAttrKey, Integer[Array, ""]]], None]:
    """Flatten the enum value with a stable `value` path key."""
    key: jtu.GetAttrKey = jtu.GetAttrKey("value")
    child: Integer[Array, ""] = jnp.asarray(self.value, jnp.int32)
    return ((key, child),), None

tree_unflatten classmethod

tree_unflatten(
    meta: None, data: tuple[Integer[Array, ""]]
) -> Self

Rebuild an enum object from its flattened integer value.

Source code in src/liblaf/jarp/_enum.py
@classmethod
def tree_unflatten(cls, meta: None, data: tuple[Integer[Array, ""]]) -> Self:
    """Rebuild an enum object from its flattened integer value."""
    del meta
    (value,) = data
    return cls._missing_(value)
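
The three flatten and unflatten hooks above follow the standard keyed-PyTree protocol from jax.tree_util. A minimal sketch of that protocol, assuming a hypothetical Counter class rather than the real Enum implementation:

```python
import jax
import jax.numpy as jnp
import jax.tree_util as jtu


@jtu.register_pytree_with_keys_class
class Counter:
    """Hypothetical class whose only dynamic child is an int32 `value`."""

    def __init__(self, value):
        self.value = jnp.asarray(value, jnp.int32)

    def tree_flatten(self):
        return (self.value,), None

    def tree_flatten_with_keys(self):
        return ((jtu.GetAttrKey("value"), self.value),), None

    @classmethod
    def tree_unflatten(cls, meta, data):
        del meta
        (value,) = data
        # Bypass __init__ so tracers and placeholders pass through untouched.
        obj = object.__new__(cls)
        obj.value = value
        return obj


@jax.jit
def advance(state):
    # The registered class travels through jit as dynamic data.
    return jax.tree.map(lambda v: v + 1, state)


advanced = advance(Counter(1))
paths, _treedef = jtu.tree_flatten_with_path(Counter(1))
path_string = jtu.keystr(paths[0][0])
print(int(advanced.value), path_string)
```

The key path gives the leaf a stable name, which is what tree_flatten_with_keys adds over plain tree_flatten.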

where staticmethod

where[T](
    condition: Bool[ArrayLike, " ..."], x: T, y: T
) -> T

Choose between enum-bearing PyTrees leaf by leaf.

This delegates to tree.where. It applies jax.numpy.where to each matching pair of leaves in x and y.

Parameters:

  • condition (Bool[ArrayLike, ' ...']) –

    Boolean scalar or array-like condition.

  • x (T) –

    PyTree used where condition is true.

  • y (T) –

    PyTree used where condition is false.

Returns:

  • T

    A PyTree with the same structure as x and y.

Examples:

>>> import enum
>>> import jax.numpy as jnp
>>> from liblaf import jarp
>>> class Phase(jarp.Enum):
...     START = enum.auto()
...     RUNNING = enum.auto()
>>> Phase.where(
...     jnp.array([True, False]), Phase.START, Phase.RUNNING
... ).value.tolist()
[0, 1]

Source code in src/liblaf/jarp/_enum.py
@staticmethod
def where[T](condition: Bool[ArrayLike, " ..."], x: T, y: T) -> T:
    """Choose between enum-bearing PyTrees leaf by leaf.

    This delegates to [`tree.where`][liblaf.jarp.tree.where]. It applies
    [`jax.numpy.where`][jax.numpy.where] to each matching pair of leaves in
    `x` and `y`.

    Args:
        condition: Boolean scalar or array-like condition.
        x: PyTree used where `condition` is true.
        y: PyTree used where `condition` is false.

    Returns:
        A PyTree with the same structure as `x` and `y`.

    Examples:
        >>> import enum
        >>> import jax.numpy as jnp
        >>> from liblaf import jarp
        >>> class Phase(jarp.Enum):
        ...     START = enum.auto()
        ...     RUNNING = enum.auto()
        >>> Phase.where(
        ...     jnp.array([True, False]), Phase.START, Phase.RUNNING
        ... ).value.tolist()
        [0, 1]
    """
    return tree.where(condition, x, y)
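
The leaf-by-leaf behaviour can be approximated with jax.tree.map. The helper tree_where_sketch and the sample dictionaries are hypothetical; the sketch shows the documented semantics, not necessarily how tree.where is written.

```python
import jax
import jax.numpy as jnp


def tree_where_sketch(condition, x, y):
    """Hypothetical stand-in: apply jnp.where to each matching pair of leaves."""
    return jax.tree.map(lambda a, b: jnp.where(condition, a, b), x, y)


on = {"gain": jnp.asarray(2.0), "offset": jnp.asarray(1.0)}
off = {"gain": jnp.asarray(0.0), "offset": jnp.asarray(-1.0)}
mixed = tree_where_sketch(jnp.array([True, False]), on, off)
print(mixed["gain"].tolist(), mixed["offset"].tolist())
```

Each scalar leaf broadcasts against the condition, so the result leaves take the condition's shape.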

Partial

Partial(
    func: Callable[..., T], /, *args: Any, **kwargs: Any
)

Bases: PartialCallableObjectProxy



Store a partially applied callable as a PyTree-aware proxy.

Bound arguments and keyword arguments flatten as PyTree children, while the wrapped callable itself is partitioned between dynamic data and static metadata when needed.

Examples:

>>> import jax
>>> import jax.numpy as jnp
>>> from liblaf import jarp
>>> def add(left, right):
...     return left + right
>>> part = jarp.partial(add, jnp.array([1, 2]))
>>> leaves, _treedef = jax.tree.flatten(part)
>>> [leaf.tolist() for leaf in leaves]
[[1, 2]]
>>> part(jnp.array([3, 4])).tolist()
[4, 6]
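
For comparison, stock JAX ships jax.tree_util.Partial, which also exposes bound arguments as PyTree children. The sketch below uses that class, not jarp.Partial, to show why visible bound values matter: the partial can be passed into a jitted function as an ordinary argument. The helper call_with is hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import Partial


def add(left, right):
    return left + right


bound = Partial(add, jnp.array([1, 2]))
# The bound array is a leaf; the function itself is static metadata.
leaves, _treedef = jax.tree.flatten(bound)


@jax.jit
def call_with(fn, right):
    # `fn` arrives as a PyTree, so its bound array is traced like any input.
    return fn(right)


result = call_with(bound, jnp.array([3, 4]))
print([leaf.tolist() for leaf in leaves], result.tolist())
```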

Source code in src/liblaf/jarp/tree/prelude/_partial.py
def __init__(self, func: Callable[..., T], /, *args: Any, **kwargs: Any) -> None:
    """Create a proxy that records bound arguments for PyTree flattening."""
    super().__init__(func, *args, **kwargs)
    self._self_args = args
    self._self_kwargs = kwargs

__wrapped__ instance-attribute

__wrapped__: Callable[..., T]

__call__

__call__(*args: P.args, **kwargs: P.kwargs) -> T
Source code in src/liblaf/jarp/tree/prelude/_partial.py
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: ...

PyTreeProxy

Bases: BaseObjectProxy



Wrap an arbitrary object and flatten the wrapped value as a PyTree.

The proxy itself stays transparent while JAX sees the wrapped object's PyTree structure.

Attributes:

__wrapped__ instance-attribute

__wrapped__: T

Structure

Record how to flatten and rebuild a PyTree's dynamic leaves.

Instances are returned by ravel and capture the original tree definition, the static leaves that were removed from the flat vector, and the offsets needed to reconstruct each dynamic leaf.

Methods:

  • ravel

    Flatten a compatible tree or flatten an array directly.

  • unravel

    Rebuild the original tree shape from a flat vector.

Attributes:

dtype instance-attribute

dtype: DTypeLike

is_leaf property

is_leaf: bool

Return whether the recorded tree was a single leaf.

meta_leaves instance-attribute

meta_leaves: tuple[Any, ...]

offsets instance-attribute

offsets: tuple[int, ...]

shapes instance-attribute

shapes: tuple[Shape | None, ...]

treedef instance-attribute

treedef: PyTreeDef

ravel

ravel(tree: T | Array) -> Array1D

Flatten a compatible tree or flatten an array directly.

Parameters:

  • tree (T | Array) –

    A tree with the same structure and static leaves used to build this Structure, or a JAX array that should be flattened directly.

Returns:

  • Array1D

    A one-dimensional array containing the dynamic leaves.

Source code in src/liblaf/jarp/tree/_ravel.py
def ravel(self, tree: T | Array) -> Array1D:
    """Flatten a compatible tree or flatten an array directly.

    Args:
        tree: A tree with the same structure and static leaves used to
            build this [`Structure`][liblaf.jarp.tree.Structure], or a JAX
            array that should be flattened directly.

    Returns:
        A one-dimensional array containing the dynamic leaves.
    """
    if isinstance(tree, Array):
        # do not flatten if already flat
        return jnp.ravel(tree)
    leaves, treedef = jax.tree.flatten(tree)
    assert treedef == self.treedef
    data_leaves, meta_leaves = partition_leaves(leaves)
    assert tuple(meta_leaves) == self.meta_leaves
    return _ravel(data_leaves)

unravel

unravel(
    flat: T | Array, dtype: DTypeLike | None = None
) -> T

Rebuild the original tree shape from a flat vector.

Parameters:

  • flat (T | Array) –

    One-dimensional data produced by ravel, or a tree that already matches the recorded structure.

  • dtype (DTypeLike | None, default: None ) –

    Optional dtype override applied to the flat array before it is split and reshaped.

Returns:

  • T

    A tree with the same structure and static metadata as the original input to ravel.

Source code in src/liblaf/jarp/tree/_ravel.py
def unravel(self, flat: T | Array, dtype: DTypeLike | None = None) -> T:
    """Rebuild the original tree shape from a flat vector.

    Args:
        flat: One-dimensional data produced by
            [`ravel`][liblaf.jarp.tree.Structure.ravel], or a tree that
            already matches the recorded structure.
        dtype: Optional dtype override applied to the flat array before it
            is split and reshaped.

    Returns:
        A tree with the same structure and static metadata as the original
        input to [`ravel`][liblaf.jarp.tree.ravel].
    """
    if not isinstance(flat, Array):
        # do not unravel if already a pytree
        assert jax.tree.structure(flat) == self.treedef
        return cast("T", flat)
    flat: Array = jnp.asarray(flat, self.dtype if dtype is None else dtype)
    if self.is_leaf:
        if self.shapes[0] is None:
            assert jnp.size(flat) == 0
            return cast("T", self.meta_leaves[0])
        return cast("T", jnp.reshape(flat, self.shapes[0]))
    data_leaves: list[Array | None] = _unravel(flat, self.offsets, self.shapes)
    leaves: list[Any] = combine_leaves(data_leaves, self.meta_leaves)
    return jax.tree.unflatten(self.treedef, leaves)
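
The ravel and unravel round trip resembles jax.flatten_util.ravel_pytree from stock JAX. The sketch below uses that function as an analogy only; unlike Structure, it does not set static leaves aside, so every leaf in the hypothetical params tree is an array.

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Hypothetical parameter tree; dictionary keys are flattened in sorted order.
params = {
    "bias": jnp.array([0.5]),
    "weight": jnp.array([[1.0, 2.0], [3.0, 4.0]]),
}

flat, unravel = ravel_pytree(params)
# `unravel` plays the role of Structure.unravel: it restores shapes and keys.
restored = unravel(flat * 2.0)
print(flat.shape, restored["weight"].tolist())
```

Working on the flat vector is convenient for optimizers and solvers that expect one-dimensional input.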

array

array(
    *,
    default: T = ...,
    validator: _ValidatorArgType[T] | None = ...,
    repr: _ReprArgType = ...,
    hash: bool | None = ...,
    init: bool = ...,
    metadata: Mapping[Any, Any] | None = ...,
    converter: _ConverterType
    | list[_ConverterType]
    | tuple[_ConverterType, ...]
    | None = ...,
    factory: Callable[[], T] | None = ...,
    kw_only: bool | None = ...,
    eq: _EqOrderType | None = ...,
    order: _EqOrderType | None = ...,
    on_setattr: _OnSetAttrArgType | None = ...,
    alias: str | None = ...,
    type: type | None = ...,
    static: FieldType | bool | None = ...,
) -> Array

Create a data field whose default is normalized to a JAX array.

When default is a concrete array-like value, array rewrites it into a factory so each instance receives its own array object.

Parameters:

  • default (T, default: ... ) –
  • validator (_ValidatorArgType[T] | None, default: ... ) –
  • repr (_ReprArgType, default: ... ) –
  • hash (bool | None, default: ... ) –
  • init (bool, default: ... ) –
  • metadata (Mapping[Any, Any] | None, default: ... ) –
  • converter (_ConverterType | list[_ConverterType] | tuple[_ConverterType, ...] | None, default: ... ) –
  • factory (Callable[[], T] | None, default: ... ) –
  • kw_only (bool | None, default: ... ) –
  • eq (_EqOrderType | None, default: ... ) –
  • order (_EqOrderType | None, default: ... ) –
  • on_setattr (_OnSetAttrArgType | None, default: ... ) –
  • alias (str | None, default: ... ) –
  • type (type | None, default: ... ) –
  • static (FieldType | bool | None, default: ... ) –
Source code in src/liblaf/jarp/tree/attrs/_field_specifiers.py
def array(**kwargs: Unpack[FieldOptions[Any]]) -> Array:
    """Create a data field whose default is normalized to a JAX array.

    When `default` is a concrete array-like value, `array` rewrites it into
    a factory so each instance receives its own array object.
    """
    if "default" in kwargs and "factory" not in kwargs:
        default: Any = kwargs["default"]
        if not (default is None or isinstance(default, attrs.Factory)):  # ty:ignore[invalid-argument-type]
            default: Array = jnp.asarray(default)
            kwargs.pop("default")
            kwargs["factory"] = lambda: default
    return field(**kwargs)  # ty:ignore[no-matching-overload]

auto

auto(**kwargs) -> Any

Create a field whose PyTree role is chosen from the runtime value.

Source code in src/liblaf/jarp/tree/attrs/_field_specifiers.py
@_wraps(attrs.field)
def auto(**kwargs) -> Any:
    """Create a field whose PyTree role is chosen from the runtime value."""
    kwargs.setdefault("static", FieldType.AUTO)
    return field(**kwargs)

cond

cond[*Ts, T](
    pred: ScalarLike,
    true_fun: Callable[[*Ts], T],
    false_fun: Callable[[*Ts], T],
    *operands: *Ts,
) -> T

Choose between two branches, then retry eagerly if JAX rejects them.

The wrapper first calls jax.lax.cond. If that raises jax.errors.JAXTypeError or jax.errors.JAXIndexError, it logs the exception and reruns the selected branch in plain Python.

Parameters:

  • pred (ScalarLike) –

    Scalar predicate. Python truthiness decides which branch runs on the fallback path.

  • true_fun (Callable[[*Ts], T]) –

    Branch evaluated when pred is true.

  • false_fun (Callable[[*Ts], T]) –

    Branch evaluated when pred is false.

  • *operands (*Ts, default: () ) –

    Positional operands forwarded to the selected branch.

Returns:

  • T

    The value returned by the selected branch.

Source code in src/liblaf/jarp/lax/_control.py
@lax_wrapper(jax.lax.cond)
def cond[*Ts, T](
    pred: ScalarLike,
    true_fun: Callable[[*Ts], T],
    false_fun: Callable[[*Ts], T],
    *operands: *Ts,
) -> T:
    """Choose between two branches, then retry eagerly if JAX rejects them.

    The wrapper first calls [`jax.lax.cond`][jax.lax.cond]. If that raises
    [`jax.errors.JAXTypeError`][jax.errors.JAXTypeError] or
    [`jax.errors.JAXIndexError`][jax.errors.JAXIndexError], it logs the
    exception and reruns the selected branch in plain Python.

    Args:
        pred: Scalar predicate. Python truthiness decides which branch runs on
            the fallback path.
        true_fun: Branch evaluated when `pred` is true.
        false_fun: Branch evaluated when `pred` is false.
        *operands: Positional operands forwarded to the selected branch.

    Returns:
        The value returned by the selected branch.
    """
    if pred:
        return true_fun(*operands)
    return false_fun(*operands)
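
The try-then-retry behaviour can be sketched as follows. This is a simplified stand-in under assumptions: the name cond_with_fallback is hypothetical, logging is omitted, and the real wrapper is built through lax_wrapper.

```python
import jax
import jax.numpy as jnp


def cond_with_fallback(pred, true_fun, false_fun, *operands):
    """Hypothetical stand-in for the documented fallback behaviour."""
    try:
        return jax.lax.cond(pred, true_fun, false_fun, *operands)
    except (jax.errors.JAXTypeError, jax.errors.JAXIndexError):
        # Tracing failed, so run only the selected branch in plain Python.
        branch = true_fun if pred else false_fun
        return branch(*operands)


def double_if_positive(x):
    # A Python `if` on a traced value cannot be traced by jax.lax.cond.
    return x * 2 if x > 0 else x


def negate(x):
    return -x


result = cond_with_fallback(True, double_if_positive, negate, jnp.asarray(3.0))
print(float(result))
```

jax.lax.cond traces both branches, so one untraceable branch triggers the fallback even when the other branch is the one selected.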

define

define[T: type](
    cls: T, /, **kwargs: Unpack[DefineOptions]
) -> T
define[T: type](
    cls: None = None, **kwargs: Unpack[DefineOptions]
) -> Callable[[T], T]

Define an attrs class and optionally register it as a PyTree.

Parameters:

  • maybe_cls (T | None, default: None ) –

    Class being decorated. When omitted, return a configured decorator.

  • **kwargs (Any, default: {} ) –

    Options forwarded to attrs.define, plus pytree to control JAX registration. pytree="data" registers fields with fieldz semantics, "static" registers the whole instance as a static value, and "none" leaves the class unregistered.

Returns:

  • Any

    The decorated class or a class decorator.

Source code in src/liblaf/jarp/tree/attrs/_define.py
def define[T: type](maybe_cls: T | None = None, **kwargs: Any) -> Any:
    """Define an `attrs` class and optionally register it as a PyTree.

    Args:
        maybe_cls: Class being decorated. When omitted, return a configured
            decorator.
        **kwargs: Options forwarded to [`attrs.define`][attrs.define], plus
            `pytree` to control JAX registration. `pytree="data"`
            registers fields with `fieldz` semantics, `"static"` registers
            the whole instance as a static value, and `"none"` leaves the
            class unregistered.

    Returns:
        The decorated class or a class decorator.
    """
    if maybe_cls is None:
        return functools.partial(define, **kwargs)
    pytree: PyTreeType = PyTreeType(kwargs.pop("pytree", None))
    frozen: bool = kwargs.get("frozen", False)
    if pytree is PyTreeType.STATIC and not frozen:
        warnings.warn(
            "Defining a static class that is not frozen may lead to unexpected behavior.",
            stacklevel=2,
        )
    cls: T = attrs.define(maybe_cls, **kwargs)  # ty:ignore[invalid-assignment]
    match pytree:
        case PyTreeType.DATA:
            register_fieldz(cls)
        case PyTreeType.STATIC:
            jtu.register_static(cls)
    return cls
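
The difference between the "data" and "static" registration modes can be illustrated with standard dataclasses and jax.tree_util. This is an analogy under assumptions, not the attrs and fieldz path used by define; the Particle and Config classes are hypothetical.

```python
import dataclasses
import functools

import jax
import jax.numpy as jnp
import jax.tree_util as jtu


# "data"-style registration: chosen fields become leaves, the rest stay metadata.
@functools.partial(
    jtu.register_dataclass, data_fields=["position"], meta_fields=["name"]
)
@dataclasses.dataclass(frozen=True)
class Particle:
    position: jax.Array
    name: str


# "static"-style registration: the whole instance is metadata and has no leaves.
@jtu.register_static
@dataclasses.dataclass(frozen=True)
class Config:
    steps: int


particle_leaves = jax.tree.leaves(Particle(jnp.zeros(2), "probe"))
config_leaves = jax.tree.leaves(Config(steps=3))
print(len(particle_leaves), len(config_leaves))
```

A static value takes part in hashing and equality checks during tracing, which is why define warns when a static class is not frozen.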

fallback_jit

fallback_jit[F: Callable[..., Any]](
    fun: F, **kwargs: Unpack[FilterJitOptions]
) -> F
fallback_jit(
    fun: None = None, **kwargs: Unpack[FilterJitOptions]
) -> IdentityFunction

Wrap a callable and cache Python fallbacks for failing metadata shapes.

The wrapper first uses the same partitioned call path as filter_jit. If that path raises jax.errors.JAXTypeError or jax.errors.JAXIndexError, the exception is logged, the current static-metadata signature is marked as unsupported, and the original callable is invoked directly in Python. Later calls with the same static metadata skip the partitioned path and reuse the Python fallback immediately.

Examples:

>>> import jax.numpy as jnp
>>> from liblaf import jarp
>>> @jarp.fallback_jit
... def add_one(value):
...     return value + 1
>>> int(add_one(jnp.array(2)))
3

Parameters:

  • fun (Callable[P, T] | None, default: None ) –

    Callable to wrap. When omitted, return a configured decorator.

  • keep_unused (bool, default: ... ) –
  • device (Any | None, default: ... ) –
  • backend (str | None, default: ... ) –
  • inline (bool, default: ... ) –

Returns:

  • Callable

    The wrapped callable, or a decorator that produces one.

Source code in src/liblaf/jarp/_jit/_fallback_jit.py
def fallback_jit[**P, T](
    fun: Callable[P, T] | None = None, **kwargs: Unpack[FilterJitOptions]
) -> Callable:
    """Wrap a callable and cache Python fallbacks for failing metadata shapes.

    The wrapper first uses the same partitioned call path as
    [`filter_jit`][liblaf.jarp.filter_jit]. If that path raises
    [`jax.errors.JAXTypeError`][jax.errors.JAXTypeError] or
    [`jax.errors.JAXIndexError`][jax.errors.JAXIndexError], the exception is
    logged, the current static-metadata signature is marked as unsupported,
    and the original callable is invoked directly in Python. Later calls with
    the same static metadata skip the partitioned path and reuse the Python
    fallback immediately.

    Examples:
        >>> import jax.numpy as jnp
        >>> from liblaf import jarp
        >>> @jarp.fallback_jit
        ... def add_one(value):
        ...     return value + 1
        >>> int(add_one(jnp.array(2)))
        3

    Args:
        fun: Callable to wrap. When omitted, return a configured decorator.
        **kwargs: Options forwarded to [`jax.jit`][jax.jit] for the inner
            filtered callable.

    Returns:
        The wrapped callable, or a decorator that produces one.
    """
    if fun is None:
        return functools.partial(fallback_jit, **kwargs)
    fun_data, fun_meta = tree.partition(fun)
    inner: FilterInner = FilterInner(fun_meta=fun_meta)
    inner_jit: jax.stages.Wrapped = jax.jit(inner, **kwargs)
    outer: FallbackOuter = FallbackOuter(
        fun_data=fun_data, fun_meta=fun_meta, inner=inner_jit
    )
    functools.update_wrapper(outer, fun)
    return outer
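
The cached fallback can be sketched with plain jax.jit. This is a reduced model under assumptions: the name fallback_jit_sketch is hypothetical, and a single keyword argument called mode stands in for the static-metadata signature that the real wrapper derives from partitioning.

```python
import jax
import jax.numpy as jnp


def fallback_jit_sketch(fun):
    """Hypothetical stand-in: remember which `mode` values failed to trace."""
    jitted = jax.jit(fun, static_argnames="mode")
    failed_modes = set()

    def wrapper(x, *, mode):
        if mode not in failed_modes:
            try:
                return jitted(x, mode=mode)
            except (jax.errors.JAXTypeError, jax.errors.JAXIndexError):
                # Mark this static signature so later calls skip the jit attempt.
                failed_modes.add(mode)
        return fun(x, mode=mode)

    wrapper.failed_modes = failed_modes
    return wrapper


@fallback_jit_sketch
def scale(x, *, mode):
    if mode == "branchy":
        # A Python `if` on a traced value cannot be compiled by jax.jit.
        return x * 2 if x > 0 else x
    return x * 2


plain = scale(jnp.asarray(3.0), mode="plain")
branchy = scale(jnp.asarray(3.0), mode="branchy")
print(float(plain), float(branchy), scale.failed_modes)
```

Only the failing signature is routed to Python; calls with other static metadata still use the compiled path.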

field

field(**kwargs) -> Any

Create an attrs field using jarp's static metadata convention.

Source code in src/liblaf/jarp/tree/attrs/_field_specifiers.py
@_wraps(attrs.field)
def field(**kwargs) -> Any:
    """Create an `attrs` field using jarp's `static` metadata convention."""
    if "static" in kwargs:
        kwargs["metadata"] = {
            "static": kwargs.pop("static"),
            **(kwargs.get("metadata") or {}),
        }
    return attrs.field(**kwargs)

filter_jit

filter_jit[F: Callable[..., Any]](
    fun: F, **kwargs: Unpack[FilterJitOptions]
) -> F
filter_jit(
    fun: None = None, **kwargs: Unpack[FilterJitOptions]
) -> IdentityFunction

Wrap a callable with liblaf.jarp data-versus-metadata partitioning.

The wrapper partitions the callable and each invocation's arguments with partition, rebuilds the original call shape, and partitions the return value again before handing it back. This keeps JAX arrays on the dynamic side of the partition while preserving ordinary Python metadata such as strings, bound methods, or configuration objects.

Examples:

>>> import jax.numpy as jnp
>>> from liblaf import jarp
>>> @jarp.filter_jit
... def scale(value, *, label):
...     assert label == "active"
...     return value * 2
>>> scale(jnp.array([1, 2]), label="active").tolist()
[2, 4]

Parameters:

  • fun (F | None, default: None ) –

    Callable to wrap. When omitted, return a configured decorator.

  • keep_unused (bool, default: ... ) –
  • device (Any | None, default: ... ) –
  • backend (str | None, default: ... ) –
  • inline (bool, default: ... ) –

Returns:

  • Callable[..., Any]

    The wrapped callable, or a decorator that produces one.

Source code in src/liblaf/jarp/_jit/_filter_jit.py
def filter_jit[F: Callable[..., Any]](
    fun: F | None = None, **kwargs: Unpack[FilterJitOptions]
) -> Callable[..., Any]:
    """Wrap a callable with `liblaf.jarp` data-versus-metadata partitioning.

    The wrapper partitions the callable and each invocation's arguments with
    [`partition`][liblaf.jarp.tree.partition], rebuilds the original call shape,
    and partitions the return value again before handing it back. This keeps
    JAX arrays on the dynamic side of the partition while preserving ordinary
    Python metadata such as strings, bound methods, or configuration objects.

    Examples:
        >>> import jax.numpy as jnp
        >>> from liblaf import jarp
        >>> @jarp.filter_jit
        ... def scale(value, *, label):
        ...     assert label == "active"
        ...     return value * 2
        >>> scale(jnp.array([1, 2]), label="active").tolist()
        [2, 4]

    Args:
        fun: Callable to wrap. When omitted, return a configured decorator.
        **kwargs: Options forwarded to [`jax.jit`][jax.jit] for the inner
            filtered callable.

    Returns:
        The wrapped callable, or a decorator that produces one.
    """
    if fun is None:
        return functools.partial(filter_jit, **kwargs)
    fun_data, fun_meta = tree.partition(fun)
    inner: FilterInner = FilterInner(fun_meta=fun_meta)
    inner_jit: InnerLike = jax.jit(inner, **kwargs)
    outer: FilterOuter = FilterOuter(fun_data=fun_data, inner=inner_jit)
    functools.update_wrapper(outer, fun)
    return outer
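
The data-versus-metadata split can be sketched with jax.jit and static arguments. This is an illustration under assumptions, not the FilterInner and FilterOuter machinery: the name filter_jit_sketch is hypothetical, only JAX arrays count as data, metadata leaves must be hashable, and the return value must contain arrays only.

```python
import functools

import jax
import jax.numpy as jnp


def filter_jit_sketch(fun):
    """Hypothetical stand-in: jit `fun` with array leaves dynamic, the rest static."""

    @functools.partial(jax.jit, static_argnums=(1, 2, 3))
    def inner(data, meta, is_data, treedef):
        # Rebuild the original call shape from the two partitions.
        data_iter, meta_iter = iter(data), iter(meta)
        leaves = [next(data_iter) if flag else next(meta_iter) for flag in is_data]
        args, kwargs = jax.tree.unflatten(treedef, leaves)
        return fun(*args, **kwargs)

    @functools.wraps(fun)
    def outer(*args, **kwargs):
        leaves, treedef = jax.tree.flatten((args, kwargs))
        is_data = tuple(isinstance(leaf, jax.Array) for leaf in leaves)
        data = [leaf for leaf, flag in zip(leaves, is_data) if flag]
        meta = tuple(leaf for leaf, flag in zip(leaves, is_data) if not flag)
        return inner(data, meta, is_data, treedef)

    return outer


@filter_jit_sketch
def scale(value, *, label):
    assert label == "active"
    return value * 2


doubled = scale(jnp.array([1, 2]), label="active")
print(doubled.tolist())
```

Because the metadata tuple is a static argument, changing label triggers a new trace, while changing only array values reuses the compiled function.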

first_true_index

first_true_index(
    condlist: Sequence[ArrayLike],
) -> Integer[Array, "*shape"]

Return the index of the first true condition.

This is a small jax.numpy.select wrapper for cases where an ordered condition list should become integer labels. Each result value is the zero-based index of the first true condition at that position. When no condition is true, the result is len(condlist).

Parameters:

  • condlist (Sequence[ArrayLike]) –

    Non-empty ordered sequence of scalar or array-like boolean conditions. Array conditions follow jax.numpy.select broadcasting rules.

Returns:

  • Integer[Array, '*shape']

    A JAX integer array with the broadcast condition shape. Scalar conditions return a zero-dimensional array.

Raises:

  • ValueError

    If condlist is empty.

Examples:

>>> from liblaf.jarp.lax import first_true_index
>>> int(first_true_index([False, True, True]))
1
>>> int(first_true_index([False, False]))
2

Array conditions are evaluated elementwise.

>>> import jax.numpy as jnp
>>> first_true_index(
...     [
...         jnp.array([False, True, False, False]),
...         jnp.array([True, True, False, False]),
...         jnp.array([True, False, True, False]),
...     ]
... ).tolist()
[1, 0, 2, 3]

Source code in src/liblaf/jarp/lax/_extras.py
def first_true_index(condlist: Sequence[ArrayLike]) -> Integer[Array, "*shape"]:
    """Return the index of the first true condition.

    This is a small [`jax.numpy.select`][jax.numpy.select] wrapper for cases
    where an ordered condition list should become integer labels. Each result
    value is the zero-based index of the first true condition at that position.
    When no condition is true, the result is `len(condlist)`.

    Args:
        condlist: Non-empty ordered sequence of scalar or array-like boolean
            conditions. Array conditions follow `jax.numpy.select` broadcasting
            rules.

    Returns:
        A JAX integer array with the broadcast condition shape. Scalar
        conditions return a zero-dimensional array.

    Raises:
        ValueError: If `condlist` is empty.

    Examples:
        >>> from liblaf.jarp.lax import first_true_index
        >>> int(first_true_index([False, True, True]))
        1
        >>> int(first_true_index([False, False]))
        2

        Array conditions are evaluated elementwise.

        >>> import jax.numpy as jnp
        >>> first_true_index(
        ...     [
        ...         jnp.array([False, True, False, False]),
        ...         jnp.array([True, True, False, False]),
        ...         jnp.array([True, False, True, False]),
        ...     ]
        ... ).tolist()
        [1, 0, 2, 3]
    """
    return jnp.select(condlist, range(len(condlist)), default=len(condlist))

fori_loop

fori_loop[T](
    lower: int,
    upper: int,
    body_fun: Callable[[int, T], T],
    init_val: T,
    **kwargs: Any,
) -> T

Run a counted loop, then retry in Python if JAX rejects the body.

The wrapper first calls jax.lax.fori_loop. If that raises jax.errors.JAXTypeError or jax.errors.JAXIndexError, it logs the exception and runs an ordinary Python for loop instead.

Parameters:

  • lower (int) –

    Inclusive loop lower bound.

  • upper (int) –

    Exclusive loop upper bound.

  • body_fun (Callable[[int, T], T]) –

    Callback that receives the iteration index and current loop value, then returns the next loop value.

  • init_val (T) –

    Initial loop value.

  • **kwargs (Any, default: {} ) –

    Extra keyword arguments forwarded to jax.lax.fori_loop on the first attempt. They are ignored on the Python fallback path.

Returns:

  • T

    The final loop value.
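
The try-first, fall-back-second behaviour described above can be sketched without JAX. The implementation of `lax_wrapper` is not shown on this page, so `fallback_wrapper` and `strict_loop` below are hypothetical stand-ins, and `TypeError` stands in for the JAX error types.

```python
import functools
import logging

logger = logging.getLogger(__name__)


def fallback_wrapper(primary, exceptions):
    """Hypothetical decorator: try `primary`, then run the decorated fallback."""

    def decorator(fallback):
        @functools.wraps(fallback)
        def wrapped(*args, **kwargs):
            try:
                return primary(*args, **kwargs)
            except exceptions as err:
                # Log the rejection, then retry with the plain-Python body.
                logger.warning("primary failed, falling back: %r", err)
                return fallback(*args, **kwargs)

        return wrapped

    return decorator


def strict_loop(lower, upper, body_fun, init_val):
    """Stand-in for a traced loop that only accepts numeric loop state."""
    if not isinstance(init_val, (int, float)):
        raise TypeError("loop state must be numeric")
    val = init_val
    for i in range(lower, upper):
        val = body_fun(i, val)
    return val


@fallback_wrapper(strict_loop, (TypeError,))
def fori_loop_sketch(lower, upper, body_fun, init_val):
    val = init_val
    for i in range(lower, upper):
        val = body_fun(i, val)
    return val


total = fori_loop_sketch(0, 5, lambda i, acc: acc + i, 0)  # primary path
names = fori_loop_sketch(0, 3, lambda i, acc: [*acc, f"step{i}"], [])  # fallback path
```

In this sketch the first call succeeds on the primary path, while the second is rejected by the stand-in and completes in the Python loop.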

Source code in src/liblaf/jarp/lax/_control.py
@lax_wrapper(jax.lax.fori_loop)
def fori_loop[T](
    lower: int,
    upper: int,
    body_fun: Callable[[int, T], T],
    init_val: T,
    **kwargs: Any,
) -> T:
    """Run a counted loop, then retry in Python if JAX rejects the body.

    The wrapper first calls [`jax.lax.fori_loop`][jax.lax.fori_loop]. If that
    raises [`jax.errors.JAXTypeError`][jax.errors.JAXTypeError] or
    [`jax.errors.JAXIndexError`][jax.errors.JAXIndexError], it logs the
    exception and runs an ordinary Python `for` loop instead.

    Args:
        lower: Inclusive loop lower bound.
        upper: Exclusive loop upper bound.
        body_fun: Callback that receives the iteration index and current loop
            value, then returns the next loop value.
        init_val: Initial loop value.
        **kwargs: Extra keyword arguments forwarded to
            [`jax.lax.fori_loop`][jax.lax.fori_loop] on the first attempt.
            They are ignored on the Python fallback path.

    Returns:
        The final loop value.
    """
    del kwargs
    val: T = init_val
    for i in range(lower, upper):
        val: T = body_fun(i, val)
    return val

frozen

frozen[T: type](
    cls: T, /, **kwargs: Unpack[DefineOptions]
) -> T
frozen[T: type](
    cls: None = None, /, **kwargs: Unpack[DefineOptions]
) -> Callable[[T], T]

Define a frozen attrs class and register it as a data PyTree.

This is the common choice for immutable structures whose array fields should participate in JAX transformations.
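
The two overloads correspond to using the decorator bare or with keyword options. A minimal sketch of that dual-use pattern follows, using `dataclasses.dataclass` as a stand-in for `define`; `frozen_sketch`, `Point`, and `Version` are hypothetical, and no PyTree registration is performed here.

```python
import dataclasses
import functools


def frozen_sketch(maybe_cls=None, **kwargs):
    """Stand-in decorator usable as @frozen_sketch or @frozen_sketch(...)."""
    if maybe_cls is None:
        # Called with options only: return a decorator that remembers them.
        return functools.partial(frozen_sketch, **kwargs)
    kwargs.setdefault("frozen", True)  # caller may still override explicitly
    return dataclasses.dataclass(maybe_cls, **kwargs)


@frozen_sketch
class Point:
    x: float
    y: float


@frozen_sketch(order=True)
class Version:
    major: int
    minor: int


point = Point(1.0, 2.0)
try:
    point.x = 3.0
    mutated = True
except dataclasses.FrozenInstanceError:
    mutated = False
```

Using `setdefault` rather than forcing the option means an explicit `frozen=False` from the caller would still win.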

Source code in src/liblaf/jarp/tree/attrs/_define.py
def frozen[T: type](maybe_cls: T | None = None, **kwargs: Any) -> Any:
    """Define a frozen `attrs` class and register it as a data PyTree.

    This is the common choice for immutable structures whose array fields
    should participate in JAX transformations.
    """
    _warnings_hide = True
    if maybe_cls is None:
        return functools.partial(frozen, **kwargs)
    kwargs.setdefault("frozen", True)
    return define(maybe_cls, **kwargs)

frozen_static

frozen_static[T: type](
    cls: T, /, **kwargs: Unpack[DefineOptions]
) -> T
frozen_static[T: type](
    cls: None = None, /, **kwargs: Unpack[DefineOptions]
) -> Callable[[T], T]

Define a frozen attrs class and register it as a static PyTree.

Use this for immutable helper objects that should be treated as static metadata instead of flattening into JAX leaves.

Source code in src/liblaf/jarp/tree/attrs/_define.py
def frozen_static[T: type](maybe_cls: T | None = None, **kwargs: Any) -> Any:
    """Define a frozen `attrs` class and register it as a static PyTree.

    Use this for immutable helper objects that should be treated as static
    metadata instead of flattening into JAX leaves.
    """
    _warnings_hide = True
    if maybe_cls is None:
        return functools.partial(frozen_static, **kwargs)
    kwargs.setdefault("frozen", True)
    kwargs.setdefault("pytree", PyTreeType.STATIC)
    return define(maybe_cls, **kwargs)

jax_callable

jax_callable(
    func: _FfiCallableFunction,
    *,
    generic: Literal[False] = False,
    **kwargs: Unpack[JaxCallableOptions],
) -> FfiCallableProtocol
jax_callable(
    *,
    generic: Literal[False] = False,
    **kwargs: Unpack[JaxCallableOptions],
) -> Callable[[_FfiCallableFunction], FfiCallableProtocol]
jax_callable(
    func: _FfiCallableFactory,
    *,
    generic: Literal[True],
    **kwargs: Unpack[JaxCallableOptions],
) -> _FfiCallable
jax_callable(
    *,
    generic: Literal[True],
    **kwargs: Unpack[JaxCallableOptions],
) -> Callable[[_FfiCallableFactory], _FfiCallable]

Wrap warp.jax_experimental.jax_callable with optional dtype dispatch.

When generic=True, func is treated as a factory keyed by the Warp scalar dtypes inferred from the runtime JAX arguments. The factory output is cached, so repeated calls with the same dtype signature reuse the same Warp callable.

Parameters:

  • func (Callable | None, default: None ) –

    Warp callable function or factory. When omitted, return a decorator.

  • generic (bool, default: False ) –

    When true, func is treated as a factory that receives Warp scalar dtypes inferred from the runtime JAX arguments and returns a concrete Warp callable implementation.

  • num_outputs (int, default: ... ) –
  • graph_mode (GraphMode, default: ... ) –
  • vmap_method (VmapMethod | None, default: ... ) –
  • output_dims (dict[str, ShapeLike] | None, default: ... ) –
  • in_out_argnames (Iterable[str], default: ... ) –
  • stage_in_argnames (Iterable[str], default: ... ) –
  • stage_out_argnames (Iterable[str], default: ... ) –
  • graph_cache_max (int | None, default: ... ) –
  • module_preload_mode (ModulePreloadMode, default: ... ) –
  • has_side_effect (bool, default: ... ) –

Returns:

  • Any

    A callable compatible with JAX tracing, or a decorator producing one. The callable returns the output arrays produced by Warp's FFI wrapper.
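
The cached-factory idea behind generic=True can be illustrated in plain Python. This is only a sketch of the caching pattern: `make_doubler` and `dispatch` are hypothetical, Python type names stand in for Warp scalar dtypes, and the internals of `_FfiCallable` are not modelled.

```python
import functools

build_log = []  # records each time the factory actually runs


def make_doubler(dtype_name):
    """Hypothetical factory: build one implementation per dtype name."""
    build_log.append(dtype_name)

    def doubler(values):
        return [v * 2 for v in values]

    return doubler


# Same call shape as the source listing: lru_cache applied directly to the factory.
cached_factory = functools.lru_cache(make_doubler)


def dispatch(values):
    """Infer a dtype key from the runtime argument, then reuse the cached build."""
    dtype_name = type(values[0]).__name__  # stand-in for Warp dtype inference
    return cached_factory(dtype_name)(values)


ints_a = dispatch([1, 2])  # builds the "int" implementation
ints_b = dispatch([3, 4])  # reuses it
floats = dispatch([0.5])   # builds the "float" implementation
```

The factory runs once per distinct key. When `functools.lru_cache` is applied without arguments, as here, it keeps at most 128 entries by default.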

Source code in src/liblaf/jarp/warp/_jax_callable.py
def jax_callable(
    func: Callable | None = None,
    *,
    generic: bool = False,
    **kwargs: Unpack[JaxCallableOptions],
) -> Any:
    """Wrap `warp.jax_experimental.jax_callable` with optional dtype dispatch.

    When `generic=True`, `func` is treated as a factory keyed by the Warp
    scalar dtypes inferred from the runtime JAX arguments. The factory output is
    cached, so repeated calls with the same dtype signature reuse the same Warp
    callable.

    Args:
        func: Warp callable function or factory. When omitted, return a
            decorator.
        generic: When true, `func` is treated as a factory that receives Warp
            scalar dtypes inferred from the runtime JAX arguments and returns a
            concrete Warp callable implementation.
        **kwargs: Options forwarded to Warp's JAX callable adapter.

    Returns:
        A callable compatible with JAX tracing, or a decorator producing one.
        The callable returns the output arrays produced by Warp's FFI wrapper.
    """
    if func is None:
        return functools.partial(jax_callable, generic=generic, **kwargs)
    if not generic:
        return warp.jax_experimental.jax_callable(func, **kwargs)
    factory: _FfiCallableFactory = functools.lru_cache(func)
    return _FfiCallable(factory=factory, options=kwargs)

jax_kernel

jax_kernel(
    *,
    arg_types_factory: Callable[[WarpScalarDType], ArgTypes]
    | None = None,
    **kwargs: Unpack[JaxKernelOptions],
) -> Callable[[Callable], FfiKernelProtocol]
jax_kernel(
    kernel: Callable,
    *,
    arg_types_factory: Callable[[WarpScalarDType], ArgTypes]
    | None = None,
    **kwargs: Unpack[JaxKernelOptions],
) -> FfiKernelProtocol

Wrap warp.jax_experimental.jax_kernel with optional overload lookup.

When arg_types_factory is provided, the wrapper infers Warp scalar dtypes from the runtime JAX arguments, builds an overload signature, and resolves the corresponding Warp kernel before dispatch.

Parameters:

  • kernel (Callable | None, default: None ) –

    Warp kernel to expose to JAX. When omitted, return a decorator.

  • arg_types_factory (Callable[[WarpScalarDType], ArgTypes] | None, default: None ) –

    Optional callback that maps runtime Warp scalar dtypes to the overloaded kernel argument types expected by warp.overload.

  • num_outputs (int, default: ... ) –
  • vmap_method (VmapMethod, default: ... ) –
  • launch_dims (ShapeLike | None, default: ... ) –
  • output_dims (ShapeLike | dict[str, ShapeLike] | None, default: ... ) –
  • in_out_argnames (Iterable[str], default: ... ) –
  • module_preload_mode (ModulePreloadMode, default: ... ) –
  • enable_backward (bool, default: ... ) –

Returns:

  • Any

    A callable compatible with JAX tracing, or a decorator producing one. The callable returns the output arrays produced by Warp's FFI wrapper.

Source code in src/liblaf/jarp/warp/_jax_kernel.py
def jax_kernel(
    kernel: Callable | None = None,
    *,
    arg_types_factory: Callable[[WarpScalarDType], ArgTypes] | None = None,
    **kwargs: Unpack[JaxKernelOptions],
) -> Any:
    """Wrap `warp.jax_experimental.jax_kernel` with optional overload lookup.

    When `arg_types_factory` is provided, the wrapper infers Warp scalar
    dtypes from the runtime JAX arguments, builds an overload signature, and
    resolves the corresponding Warp kernel before dispatch.

    Args:
        kernel: Warp kernel to expose to JAX. When omitted, return a decorator.
        arg_types_factory: Optional callback that maps runtime Warp scalar dtypes
            to the overloaded kernel argument types expected by
            [warp.overload][].
        **kwargs: Options forwarded to Warp's JAX kernel adapter.

    Returns:
        A callable compatible with JAX tracing, or a decorator producing one.
        The callable returns the output arrays produced by Warp's FFI wrapper.
    """
    if kernel is None:
        return functools.partial(
            jax_kernel, arg_types_factory=arg_types_factory, **kwargs
        )
    if arg_types_factory is None:
        return warp.jax_experimental.jax_kernel(kernel, **kwargs)
    return _FfiKernel(
        kernel=cast("wp.Kernel", kernel),
        options=kwargs,
        arg_types_factory=arg_types_factory,
    )

partial

partial[T](
    func: Callable[..., T], /, *args: Any, **kwargs: Any
) -> Partial[..., T]

Partially apply a callable and keep bound values visible to JAX trees.

Source code in src/liblaf/jarp/tree/prelude/_partial.py
def partial[T](func: Callable[..., T], /, *args: Any, **kwargs: Any) -> Partial[..., T]:
    """Partially apply a callable and keep bound values visible to JAX trees."""
    return Partial(func, *args, **kwargs)

ravel

ravel[T](tree: T) -> tuple[Array, Structure[T]]

Flatten a PyTree's dynamic leaves into one vector.

Non-array leaves are treated as static metadata and preserved in the returned Structure instead of being concatenated into the flat array.

Examples:

>>> import jax.numpy as jnp
>>> from liblaf import jarp
>>> flat, structure = jarp.ravel({"x": jnp.array([1.0, 2.0]), "tag": "train"})
>>> flat.tolist()
[1.0, 2.0]
>>> rebuilt = structure.unravel(jnp.array([3.0, 4.0]))
>>> rebuilt["x"].tolist(), rebuilt["tag"]
([3.0, 4.0], 'train')

Parameters:

  • tree (T) –

    PyTree to flatten.

Returns:

  • tuple[Array, Structure[T]]

    A tuple of (flat, structure) where flat is a one-dimensional JAX array and structure can rebuild compatible trees later.
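
The split between a flat vector and rebuild metadata can be modelled with plain lists. This is a loose sketch under stated assumptions: Python lists stand in for array leaves, the `slots` bookkeeping is a hypothetical simplification, and the real Structure stores offsets, shapes, static leaves, a treedef, and a dtype instead.

```python
def ravel_sketch(leaves):
    """Concatenate list leaves into one flat list and remember how to undo it."""
    flat, slots = [], []
    for leaf in leaves:
        if isinstance(leaf, list):  # stand-in for "array leaf"
            slots.append(("dynamic", len(flat), len(leaf)))  # start offset, size
            flat.extend(leaf)
        else:
            slots.append(("static", leaf, None))  # kept aside, never flattened
    return flat, slots


def unravel_sketch(flat, slots):
    """Rebuild the leaves from a new flat list of the same length."""
    rebuilt = []
    for kind, first, size in slots:
        if kind == "dynamic":
            rebuilt.append(flat[first : first + size])
        else:
            rebuilt.append(first)  # the stored static value
    return rebuilt


flat, slots = ravel_sketch([[1.0, 2.0], "train", [3.0]])
rebuilt = unravel_sketch([10.0, 20.0, 30.0], slots)
```

Only the numeric data travels through the flat list, so a static value such as `"train"` comes back unchanged whatever the new vector contains.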

Source code in src/liblaf/jarp/tree/_ravel.py
def ravel[T](tree: T) -> tuple[Array, Structure[T]]:
    """Flatten a PyTree's dynamic leaves into one vector.

    Non-array leaves are treated as static metadata and preserved in the
    returned [`Structure`][liblaf.jarp.tree.Structure] instead of being
    concatenated into the flat array.

    Examples:
        >>> import jax.numpy as jnp
        >>> from liblaf import jarp
        >>> flat, structure = jarp.ravel({"x": jnp.array([1.0, 2.0]), "tag": "train"})
        >>> flat.tolist()
        [1.0, 2.0]
        >>> rebuilt = structure.unravel(jnp.array([3.0, 4.0]))
        >>> rebuilt["x"].tolist(), rebuilt["tag"]
        ([3.0, 4.0], 'train')

    Args:
        tree: PyTree to flatten.

    Returns:
        A tuple of `(flat, structure)` where `flat` is a one-dimensional JAX
        array and `structure` can rebuild compatible trees later.
    """
    leaves, treedef = jax.tree.flatten(tree)
    dynamic_leaves, static_leaves = partition_leaves(leaves)
    flat: Array = _ravel(dynamic_leaves)
    structure: Structure[T] = Structure(
        offsets=_offsets_from_leaves(dynamic_leaves),
        shapes=_shapes_from_leaves(dynamic_leaves),
        meta_leaves=tuple(static_leaves),
        treedef=treedef,
        dtype=flat.dtype,
    )
    return flat, structure

static

static(**kwargs) -> Any

Create a field that is always treated as static metadata.

Source code in src/liblaf/jarp/tree/attrs/_field_specifiers.py
@_wraps(attrs.field)
def static(**kwargs) -> Any:
    """Create a field that is always treated as static metadata."""
    # for consistency with `jax.tree_util.register_dataclass`
    kwargs.setdefault("static", True)
    return field(**kwargs)

struct

struct[T: type](cls: T) -> T

Decorate a class as a Warp struct.

Plain classes are forwarded to warp.struct. Classes that define __annotations_factory__(dtype) stay generic: MyStruct[wp.float64] builds and caches a specialized Warp struct from the factory annotations, while MyStruct() instantiates MyStruct[liblaf.jarp.warp.types.floating] so the default follows JAX's active precision mode.

Parameters:

  • cls (T) –

    Class to decorate.

Returns:

  • T

    The Warp struct for plain classes, or the original generic class with dtype subscription and default construction hooks installed.
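
The dtype-subscription hook can be sketched without Warp. In this illustrative version, `generic_struct` and `Particle` are hypothetical, strings stand in for Warp dtypes, the specialized class is a plain subclass, not a Warp struct, and the default-construction hook is omitted.

```python
import functools


def generic_struct(cls):
    """Hypothetical decorator: install a cached MyStruct[dtype] hook."""

    @functools.cache
    def specialize(dtype):
        # Build a subclass whose annotations come from the factory.
        return type(
            f"{cls.__name__}[{dtype}]",
            (cls,),
            {"__annotations__": cls.__annotations_factory__(dtype)},
        )

    cls.__class_getitem__ = classmethod(lambda owner, dtype: specialize(dtype))
    return cls


@generic_struct
class Particle:
    @staticmethod
    def __annotations_factory__(dtype):
        return {"mass": dtype, "charge": dtype}


f32 = Particle["float32"]
f64 = Particle["float64"]
```

Caching the specialization means repeated subscriptions with the same dtype return the same class object, which keeps identity checks and downstream caches stable.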

Source code in src/liblaf/jarp/warp/_struct.py
def struct[T: type](cls: T) -> T:
    """Decorate a class as a Warp struct.

    Plain classes are forwarded to `warp.struct`. Classes that define
    `__annotations_factory__(dtype)` stay generic: `MyStruct[wp.float64]`
    builds and caches a specialized Warp struct from the factory annotations,
    while `MyStruct()` instantiates `MyStruct[liblaf.jarp.warp.types.floating]`
    so the default follows JAX's active precision mode.

    Args:
        cls: Class to decorate.

    Returns:
        The Warp struct for plain classes, or the original generic class with
        dtype subscription and default construction hooks installed.
    """
    if not hasattr(cls, "__annotations_factory__"):
        return cast("T", wp.struct(cls))

    @functools.cache
    def __class_getitem__(cls: T, key: Any) -> T:  # noqa: N807
        c: type = type(
            cls.__name__,
            (cls,),
            {
                "__module__": cls.__module__,
                "__qualname__": cls.__qualname__,
                "__annotations__": cls.__annotations_factory__(key),  # ty:ignore[unresolved-attribute]
            },
        )
        return cast("T", wp.struct(c, module="unique"))

    def __new__(owner: type) -> object:  # noqa: N807
        if owner is cls:
            return __class_getitem__(cls, wpt.floating)()
        return object.__new__(owner)

    cls.__class_getitem__ = classmethod(__class_getitem__)  # ty:ignore[invalid-assignment]
    cls.__new__ = staticmethod(__new__)  # ty:ignore[invalid-assignment]
    return cls

switch

switch[*Ts, T](
    index: ArrayLike,
    branches: Sequence[Callable[[*Ts], T]],
    *operands: *Ts,
) -> T

Choose one branch by index, then retry eagerly if JAX rejects it.

The wrapper first calls jax.lax.switch. If that raises jax.errors.JAXTypeError or jax.errors.JAXIndexError, it logs the exception, clamps index into the valid range, and dispatches in plain Python.

Parameters:

  • index (ArrayLike) –

    Branch index. The fallback path clamps the value into the valid range before dispatch.

  • branches (Sequence[Callable[[*Ts], T]]) –

    Candidate branch functions.

  • *operands (*Ts, default: () ) –

    Positional operands forwarded to the selected branch.

Returns:

  • T

    The value returned by the selected branch.
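
The clamped dispatch used on the fallback path can be sketched in plain Python. `switch_fallback` is a hypothetical name, and plain integers stand in for JAX index values.

```python
def switch_fallback(index, branches, *operands):
    """Clamp `index` into [0, len(branches) - 1], then call that branch."""
    clamped = min(max(int(index), 0), len(branches) - 1)
    return branches[clamped](*operands)


branches = [lambda x: x - 1, lambda x: x * 2, lambda x: x**2]

in_range = switch_fallback(1, branches, 5)     # calls branches[1]
below = switch_fallback(-3, branches, 5)       # clamped to branches[0]
above = switch_fallback(99, branches, 5)       # clamped to branches[2]
```

Clamping means an out-of-range index selects the nearest valid branch instead of raising `IndexError`.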

Source code in src/liblaf/jarp/lax/_control.py
@lax_wrapper(jax.lax.switch)
def switch[*Ts, T](
    index: ArrayLike, branches: Sequence[Callable[[*Ts], T]], *operands: *Ts
) -> T:
    """Choose one branch by index, then retry eagerly if JAX rejects it.

    The wrapper first calls [`jax.lax.switch`][jax.lax.switch]. If that raises
    [`jax.errors.JAXTypeError`][jax.errors.JAXTypeError] or
    [`jax.errors.JAXIndexError`][jax.errors.JAXIndexError], it logs the
    exception, clamps `index` into the valid range, and dispatches in plain
    Python.

    Args:
        index: Branch index. The fallback path clamps the value into the valid
            range before dispatch.
        branches: Candidate branch functions.
        *operands: Positional operands forwarded to the selected branch.

    Returns:
        The value returned by the selected branch.
    """
    index: Array = jax.lax.clamp(index, 0, len(branches) - 1)
    return branches[cast("int", index)](*operands)

to_warp

to_warp(
    arr: array | ndarray | Array,
    *_args: Any,
    **_kwargs: Any,
) -> array

Convert a supported array object into a warp.array.

The dispatcher supports existing Warp arrays, NumPy arrays, and JAX arrays. A dtype hint may be a concrete Warp dtype or a tuple that describes a vector or matrix dtype inferred from the trailing dimensions of arr. Use (-1, Any) for vector inference and (-1, -1, Any) for matrix inference when the element type should follow the source array.

Parameters:

  • arr (array | ndarray | Array) –

    Array object to convert.

  • *_args (Any, default: () ) –

    Reserved for singledispatch compatibility.

  • **_kwargs (Any, default: {} ) –

    Reserved for singledispatch compatibility.

Returns:

  • array

    A Warp array view or converted array, depending on the source type.

Raises:

  • TypeError

    If arr uses an unsupported type.
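
The dispatch structure can be illustrated with `functools.singledispatch` from the standard library. This sketch uses `array.array` as a stand-in target type; `to_buffer` and its registered handlers are hypothetical and do not reflect the actual Warp conversions.

```python
import functools
from array import array


@functools.singledispatch
def to_buffer(arr, *_args, **_kwargs):
    """Base case: reached only when no handler is registered for type(arr)."""
    raise TypeError(arr)


@to_buffer.register(array)
def _(arr, *_args, **_kwargs):
    return arr  # already the target type, pass through unchanged


@to_buffer.register(list)
def _(arr, typecode="d", **_kwargs):
    return array(typecode, arr)  # convert, honouring an optional type hint


converted = to_buffer([1.0, 2.0])
passthrough = to_buffer(converted)

try:
    to_buffer("not an array")
    rejected = False
except TypeError:
    rejected = True
```

The base function doubles as the error path, which is why its extra parameters exist only to keep the signature compatible with the registered handlers.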

Source code in src/liblaf/jarp/warp/_to_warp.py
@functools.singledispatch
def to_warp(arr: Any, *_args: Any, **_kwargs: Any) -> wp.array:
    """Convert a supported array object into a [`warp.array`][warp.array].

    The dispatcher supports existing Warp arrays, NumPy arrays, and JAX arrays.
    A `dtype` hint may be a concrete Warp dtype or a tuple that describes a
    vector or matrix dtype inferred from the trailing dimensions of `arr`.
    Use `(-1, Any)` for vector inference and `(-1, -1, Any)` for matrix
    inference when the element type should follow the source array.

    Args:
        arr: Array object to convert.
        *_args: Reserved for singledispatch compatibility.
        **_kwargs: Reserved for singledispatch compatibility.

    Returns:
        A Warp array view or converted array, depending on the source type.

    Raises:
        TypeError: If `arr` uses an unsupported type.
    """
    raise TypeError(arr)

while_loop

while_loop[T](
    cond_fun: Callable[[T], BooleanNumeric],
    body_fun: Callable[[T], T],
    init_val: T,
) -> T

Run a loop, then retry in Python if JAX rejects the callbacks.

The wrapper first calls jax.lax.while_loop. If that raises jax.errors.JAXTypeError or jax.errors.JAXIndexError, it logs the exception and reruns the loop eagerly in Python.

Parameters:

  • cond_fun (Callable[[T], BooleanNumeric]) –

    Predicate evaluated on the loop state.

  • body_fun (Callable[[T], T]) –

    Function that produces the next loop state.

  • init_val (T) –

    Initial loop state.

Returns:

  • T

    The final loop state.
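
The callback contract is easiest to see on the eager path. The sketch below mirrors the Python fallback loop shown in the source listing; `while_loop_eager` is a hypothetical name and the state tuple is an illustrative choice.

```python
def while_loop_eager(cond_fun, body_fun, init_val):
    """Plain-Python loop with the same cond/body/init contract."""
    val = init_val
    while cond_fun(val):
        val = body_fun(val)
    return val


# State is (step_count, value): keep doubling until value reaches 100.
steps, value = while_loop_eager(
    lambda state: state[1] < 100,
    lambda state: (state[0] + 1, state[1] * 2),
    (0, 1),
)
```

The whole loop state travels through a single value, so several quantities are carried by packing them into a tuple or another PyTree.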

Source code in src/liblaf/jarp/lax/_control.py
@lax_wrapper(jax.lax.while_loop)
def while_loop[T](
    cond_fun: Callable[[T], BooleanNumeric], body_fun: Callable[[T], T], init_val: T
) -> T:
    """Run a loop, then retry in Python if JAX rejects the callbacks.

    The wrapper first calls [`jax.lax.while_loop`][jax.lax.while_loop]. If
    that raises [`jax.errors.JAXTypeError`][jax.errors.JAXTypeError] or
    [`jax.errors.JAXIndexError`][jax.errors.JAXIndexError], it logs the
    exception and reruns the loop eagerly in Python.

    Args:
        cond_fun: Predicate evaluated on the loop state.
        body_fun: Function that produces the next loop state.
        init_val: Initial loop state.

    Returns:
        The final loop state.
    """
    val: T = init_val
    while cond_fun(val):
        val: T = body_fun(val)
    return val