Call Wrappers

liblaf.jarp exposes two callable wrappers and a small lax compatibility layer for mixed JAX-and-Python code.

Partition Mixed Call Arguments

filter_jit splits each call into dynamic array leaves and static metadata, rebuilds the original call shape, and partitions the return value again on the way out.

from typing import Any

import jax.numpy as jnp
from jax import Array
from liblaf import jarp


@jarp.filter_jit
def pack(x: Array, label: str = "tag") -> dict[str, Any]:
    return {"x": x + 1, "label": label}


result = pack(jnp.array([1, 2]), label="train")
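The partitioning idea can be sketched with plain JAX pytree utilities. This is an illustration of the general technique, not jarp's actual implementation; the helper names partition and combine are hypothetical, and the sketch assumes that "dynamic" means jax.Array leaves and "static" means everything else.

```python
from typing import Any

import jax
import jax.numpy as jnp


def partition(tree: Any) -> tuple[list[Any], list[Any], Any]:
    """Split a pytree into array leaves, non-array leaves, and its structure."""
    leaves, treedef = jax.tree_util.tree_flatten(tree)
    # Each list keeps one slot per leaf; None marks "stored in the other list".
    dynamic = [leaf if isinstance(leaf, jax.Array) else None for leaf in leaves]
    static = [None if isinstance(leaf, jax.Array) else leaf for leaf in leaves]
    return dynamic, static, treedef


def combine(dynamic: list[Any], static: list[Any], treedef: Any) -> Any:
    """Rebuild the original pytree from the two partitions."""
    leaves = [s if d is None else d for d, s in zip(dynamic, static)]
    return jax.tree_util.tree_unflatten(treedef, leaves)


# Hypothetical call arguments mixing an array with string metadata.
call_kwargs = {"x": jnp.array([1, 2]), "label": "train"}
dynamic, static, treedef = partition(call_kwargs)
restored = combine(dynamic, static, treedef)
```

In a scheme like this, only the dynamic list would be traced by jax.jit, while the static list and tree structure would act as part of the compilation cache key.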

The wrapper also preserves method binding, so @filter_jit() works on instance methods as well as free functions.

Cache Python Fallbacks By Metadata Shape

fallback_jit starts with the same partitioned call path as filter_jit. If that path raises jax.errors.JAXTypeError or jax.errors.JAXIndexError, jarp logs the exception, marks the current static-metadata signature as unsupported, and reuses the direct Python call path for later calls with the same metadata.

Use it when the same callable sometimes works cleanly with JAX-style inputs but needs a stable eager fallback for particular metadata layouts.
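The caching behaviour described above can be sketched in pure Python. This is a simplified model under stated assumptions, not jarp's code: FallbackCache is a hypothetical name, the built-in TypeError stands in for the JAX error types, and a single keyword argument stands in for the static-metadata signature.

```python
from typing import Any, Callable, Hashable


class FallbackCache:
    """Try a fast path first; remember metadata keys that must use the slow path."""

    def __init__(
        self,
        fast: Callable[..., Any],
        slow: Callable[..., Any],
        errors: tuple[type[BaseException], ...] = (TypeError,),
    ) -> None:
        self.fast = fast
        self.slow = slow
        self.errors = errors
        self.unsupported: set[Hashable] = set()

    def __call__(self, value: Any, *, mode: str) -> Any:
        key = mode  # stand-in for the static-metadata signature
        if key not in self.unsupported:
            try:
                return self.fast(value, mode=mode)
            except self.errors:
                # Mark this signature so later calls skip the fast path.
                self.unsupported.add(key)
        return self.slow(value, mode=mode)


fast_calls: list[str] = []


def fast(value: int, *, mode: str) -> int:
    fast_calls.append(mode)
    if mode == "ragged":
        raise TypeError("unsupported on the fast path")
    return value * 2


def slow(value: int, *, mode: str) -> int:
    return value * 2


scale = FallbackCache(fast, slow)
results = [scale(1, mode="dense"), scale(2, mode="ragged"), scale(3, mode="ragged")]
```

The second "ragged" call never reaches the fast path, which is the point of caching the failure by metadata rather than retrying every time.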

Retry jax.lax Helpers Eagerly

jarp.lax wraps jax.lax.cond, jax.lax.switch, jax.lax.fori_loop, and jax.lax.while_loop. Each wrapper tries the JAX primitive first and reruns eagerly if JAX raises one of the errors handled by LaxWrapper.

from liblaf import jarp


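# Indexing a Python list with a traced value fails while JAX traces the body,
# so this call is expected to end up on the eager fallback path.
state = jarp.lax.while_loop(
    lambda value: value[0] < 3,
    lambda value: (value[0] + 1, value[1] + [10, 20, 30][value[0]]),
    (0, 0),
)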

On the eager fallback path, jarp.lax.switch clamps the branch index into range before dispatch.
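Clamping matters because jax.lax.switch itself clamps an out-of-range index, so an eager replacement has to do the same to stay consistent. A minimal sketch of such a fallback, with the hypothetical name eager_switch, might look like this:

```python
from typing import Any, Callable, Sequence


def eager_switch(
    index: int, branches: Sequence[Callable[..., Any]], *operands: Any
) -> Any:
    """Pick a branch in plain Python, clamping the index into range first."""
    clamped = min(max(int(index), 0), len(branches) - 1)
    return branches[clamped](*operands)


branches = [lambda x: x - 1, lambda x: x, lambda x: x + 1]
below = eager_switch(-5, branches, 10)  # clamps to the first branch
inside = eager_switch(1, branches, 10)
above = eager_switch(7, branches, 10)  # clamps to the last branch
```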

Collapse Ordered Conditions To An Index

first_true_index turns an ordered list of scalar or array conditions into a JAX integer array. It returns the first matching condition index at each position, and uses len(condlist) where no condition matches.

import jax.numpy as jnp
from liblaf import jarp


labels = jarp.first_true_index(
    [
        jnp.array([False, True, False, False]),
        jnp.array([True, True, False, False]),
        jnp.array([True, False, True, False]),
    ]
)
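By the rule above, labels for this input should be [1, 0, 2, 3]: position 3 matches nothing, so it gets len(condlist). One way to reproduce that behaviour with plain jax.numpy is sketched below. The function first_true_index_sketch is hypothetical and is not taken from jarp; it assumes a non-empty condlist whose entries broadcast to a common shape.

```python
import jax.numpy as jnp


def first_true_index_sketch(condlist):
    """Return, per position, the index of the first true condition."""
    conds = jnp.stack(
        jnp.broadcast_arrays(*[jnp.asarray(c, dtype=bool) for c in condlist])
    )
    # An always-true final row makes "no match" resolve to len(condlist).
    sentinel = jnp.ones((1, *conds.shape[1:]), dtype=bool)
    padded = jnp.concatenate([conds, sentinel]).astype(jnp.int32)
    # argmax returns the first occurrence of the maximum along axis 0.
    return jnp.argmax(padded, axis=0)


sketch_labels = first_true_index_sketch(
    [
        jnp.array([False, True, False, False]),
        jnp.array([True, True, False, False]),
        jnp.array([True, False, True, False]),
    ]
)
```

An index array of this form is a natural input for a switch-style dispatch, where the extra index len(condlist) selects a default branch.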

Preserve Primitive Metadata

The public jarp.lax helpers are LaxWrapper instances. They keep the wrapped jax.lax primitive available through __wrapped__, preserve the primitive signature for inspection, and cache metadata signatures that should skip directly to the Python fallback after a supported failure.

LaxWrapper copies ordinary function metadata when it is available, but it does not require it. Callable objects with only __call__ still work as the wrapped JAX-side callable.
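The metadata behaviour described here matches what the standard library's functools.update_wrapper provides, so the sketch below uses it to illustrate the idea. Whether LaxWrapper actually relies on update_wrapper is an assumption; MetadataWrapper, select, and Doubler are hypothetical names.

```python
import functools
import inspect
from typing import Any, Callable


class MetadataWrapper:
    """Callable wrapper that copies whatever metadata the target offers."""

    def __init__(self, wrapped: Callable[..., Any]) -> None:
        # Copies __name__, __doc__, etc. when present, skips missing ones,
        # and always sets __wrapped__ to the original callable.
        functools.update_wrapper(self, wrapped)

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        return self.__wrapped__(*args, **kwargs)


def select(pred: bool, on_true: Any, on_false: Any) -> Any:
    """Return one of two values."""
    return on_true if pred else on_false


class Doubler:
    """Callable object with no __name__ of its own."""

    def __call__(self, x: int) -> int:
        return 2 * x


wrapped_select = MetadataWrapper(select)
wrapped_doubler = MetadataWrapper(Doubler())

# inspect.signature follows __wrapped__, so the original parameters show up.
parameter_names = list(inspect.signature(wrapped_select).parameters)
```

Because update_wrapper ignores attributes the target lacks, the Doubler instance wraps without error even though it has no function-style metadata.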