liblaf.jarp.lax
Control-flow wrappers and ordered condition helpers.
Use liblaf.jarp.lax when you want to try the corresponding jax.lax
primitive first but still rerun the same callbacks eagerly if JAX raises
jax.errors.JAXTypeError or jax.errors.JAXIndexError, the tracing and
indexing errors that commonly appear with Python-only code. Use
first_true_index when an ordered condition list should become a
JAX-friendly integer index.
Classes:
- LaxWrapper – Call a JAX primitive first and cache Python fallback signatures.
Functions:
- cond – Choose between two branches, then retry eagerly if JAX rejects them.
- first_true_index – Return the index of the first true condition.
- fori_loop – Run a counted loop, then retry in Python if JAX rejects the body.
- lax_wrapper – Decorate an eager fallback with a LaxWrapper.
- switch – Choose one branch by index, then retry eagerly if JAX rejects it.
- while_loop – Run a loop, then retry in Python if JAX rejects the callbacks.
LaxWrapper
Call a JAX primitive first and cache Python fallback signatures.
LaxWrapper powers the public helpers in
liblaf.jarp.lax. It copies wrapper metadata from the wrapped JAX primitive
when that metadata exists, tries the primitive for each new call shape, and
records the metadata signature of any call that hit a supported JAX error so
that later calls with the same signature skip directly to the Python
fallback. Callable objects without ordinary function metadata are also
accepted.
Examples:
>>> from liblaf.jarp.lax import LaxWrapper
>>> class Wrapped:
...     def __call__(self, value):
...         return value + 1
>>> wrapper = LaxWrapper(Wrapped(), lambda value: value - 1)
>>> wrapper(2)
3
Attributes:
- __wrapped__ (Callable[P, T]) – JAX callable attempted before the fallback.
Parameters:
- fallback (Callable[P, T]) – Eager Python implementation called when the JAX primitive raises a supported error, or called directly when the call's signature is already recorded as failing.
- success_cache (dict[AuxData, bool], default: empty dict) – Records, per metadata signature, whether the JAX primitive succeeded; signatures recorded as failing skip directly to the fallback.
Members:
__wrapped__ (class attribute, instance attribute)
success_cache (class attribute, instance attribute)
__attrs_post_init__
Source code in src/liblaf/jarp/lax/_wrapper.py
__call__
Source code in src/liblaf/jarp/lax/_wrapper.py
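The caching behavior described above can be sketched in plain Python. Everything in this block is hypothetical: FallbackWrapper, the two stand-in error classes, and keying the cache on argument type names are assumptions for illustration only (the documented success_cache is keyed on AuxData, whose contents are not described here). The sketch also omits the logging the public helpers perform.

```python
import functools
from collections.abc import Callable, Hashable
from typing import Any


class StandInTypeError(TypeError):
    """Hypothetical stand-in for jax.errors.JAXTypeError."""


class StandInIndexError(IndexError):
    """Hypothetical stand-in for jax.errors.JAXIndexError."""


SUPPORTED_ERRORS = (StandInTypeError, StandInIndexError)


class FallbackWrapper:
    """Try `wrapped` first; remember signatures that need the eager fallback."""

    def __init__(self, wrapped: Callable[..., Any], fallback: Callable[..., Any]) -> None:
        self.fallback = fallback
        # Maps a call signature to whether `wrapped` succeeded for it.
        self.success_cache: dict[Hashable, bool] = {}
        # Copies __name__, __doc__, ... when present (missing ones are skipped)
        # and sets self.__wrapped__ = wrapped.
        functools.update_wrapper(self, wrapped)

    def _signature(self, args: tuple[Any, ...]) -> Hashable:
        # Assumption: argument type names stand in for the real AuxData key.
        return tuple(type(arg).__name__ for arg in args)

    def __call__(self, *args: Any) -> Any:
        key = self._signature(args)
        # Unknown signatures default to True, so the primitive is attempted.
        if self.success_cache.get(key, True):
            try:
                result = self.__wrapped__(*args)
            except SUPPORTED_ERRORS:
                # Later calls with this signature go straight to the fallback.
                self.success_cache[key] = False
            else:
                self.success_cache[key] = True
                return result
        return self.fallback(*args)


def primitive(x: Any) -> Any:
    if isinstance(x, str):
        raise StandInTypeError("strings are not traceable")
    return x + 1


wrapper = FallbackWrapper(primitive, lambda x: x + "!")
print(wrapper(2), wrapper("a"), wrapper("b"))  # 3 a! b!
print(wrapper.success_cache)  # {('int',): True, ('str',): False}
```

Here only the first string call pays for a failed attempt at the primitive; the second string call reads the cached False and goes directly to the fallback.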
cond
cond[*Ts, T](
pred: ScalarLike,
true_fun: Callable[[*Ts], T],
false_fun: Callable[[*Ts], T],
*operands: *Ts,
) -> T
Choose between two branches, then retry eagerly if JAX rejects them.
The wrapper first calls jax.lax.cond. If that raises
jax.errors.JAXTypeError or
jax.errors.JAXIndexError, it logs the
exception and reruns the selected branch in plain Python.
Parameters:
- pred (ScalarLike) – Scalar predicate. Python truthiness decides which branch runs on the fallback path.
- true_fun (Callable[[*Ts], T]) – Branch evaluated when pred is true.
- false_fun (Callable[[*Ts], T]) – Branch evaluated when pred is false.
- *operands (*Ts, default: ()) – Positional operands forwarded to the selected branch.
Returns:
- T – The value returned by the selected branch.
Source code in src/liblaf/jarp/lax/_control.py
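The fallback path described above amounts to an ordinary if/else. The sketch below is a hypothetical model named eager_cond, not part of liblaf.jarp.lax; it shows that pred is judged by Python truthiness and that only the selected branch is ever called.

```python
from collections.abc import Callable
from typing import Any


def eager_cond(
    pred: Any,
    true_fun: Callable[..., Any],
    false_fun: Callable[..., Any],
    *operands: Any,
) -> Any:
    """Hypothetical model of cond's eager fallback path."""
    # Python truthiness picks the branch; the other branch is never called.
    if pred:
        return true_fun(*operands)
    return false_fun(*operands)


def must_not_run(*operands: Any) -> Any:
    raise AssertionError("the unselected branch was called")


print(eager_cond(1, lambda x: x + 1, must_not_run, 41))  # 42
print(eager_cond([], must_not_run, lambda x, y: x * y, 6, 7))  # 42
```

This differs from jax.lax.cond, which normally traces both branches when it stages the computation.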
first_true_index
first_true_index(
condlist: Sequence[ArrayLike],
) -> Integer[Array, "*shape"]
Return the index of the first true condition.
This is a small jax.numpy.select wrapper for cases
where an ordered condition list should become integer labels. Each result
value is the zero-based index of the first true condition at that position.
When no condition is true, the result is len(condlist).
Parameters:
- condlist (Sequence[ArrayLike]) – Non-empty ordered sequence of scalar or array-like boolean conditions. Array conditions follow jax.numpy.select broadcasting rules.
Returns:
- Integer[Array, '*shape'] – A JAX integer array with the broadcast condition shape. Scalar conditions return a zero-dimensional array.
Raises:
- ValueError – If condlist is empty.
Examples:
>>> from liblaf.jarp.lax import first_true_index
>>> int(first_true_index([False, True, True]))
1
>>> int(first_true_index([False, False]))
2
Array conditions are evaluated elementwise.
>>> import jax.numpy as jnp
>>> first_true_index(
...     [
...         jnp.array([False, True, False, False]),
...         jnp.array([True, True, False, False]),
...         jnp.array([True, False, True, False]),
...     ]
... ).tolist()
[1, 0, 2, 3]
Source code in src/liblaf/jarp/lax/_extras.py
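To see the semantics without JAX, here is a hypothetical pure-Python reference model, first_true_index_ref, restricted to equal-length rows of Python booleans (no general broadcasting). The library describes first_true_index as a jax.numpy.select wrapper; this loop is an equivalent restatement of the documented result, not its implementation.

```python
from collections.abc import Sequence


def first_true_index_ref(condlist: Sequence[Sequence[bool]]) -> list[int]:
    """Index of the first true row at each position; len(condlist) if none."""
    if not condlist:
        raise ValueError("condlist must be non-empty")
    width = len(condlist[0])
    if any(len(row) != width for row in condlist):
        raise ValueError("this sketch requires rows of equal length")
    result = []
    for position in range(width):
        # Default label when no condition holds at this position.
        label = len(condlist)
        for index, row in enumerate(condlist):
            if row[position]:
                label = index
                break
        result.append(label)
    return result


print(first_true_index_ref([
    [False, True, False, False],
    [True, True, False, False],
    [True, False, True, False],
]))  # [1, 0, 2, 3]
```

Because the no-match label equals len(condlist), a scalar result passed to switch with exactly len(condlist) branches sends the no-match case to the last branch, since switch clamps out-of-range indices; supply len(condlist) + 1 branches to give it a branch of its own.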
fori_loop
fori_loop[T](
lower: int,
upper: int,
body_fun: Callable[[int, T], T],
init_val: T,
**kwargs: Any,
) -> T
Run a counted loop, then retry in Python if JAX rejects the body.
The wrapper first calls jax.lax.fori_loop. If that
raises jax.errors.JAXTypeError or
jax.errors.JAXIndexError, it logs the
exception and runs an ordinary Python for loop instead.
Parameters:
- lower (int) – Inclusive loop lower bound.
- upper (int) – Exclusive loop upper bound.
- body_fun (Callable[[int, T], T]) – Callback that receives the iteration index and current loop value, then returns the next loop value.
- init_val (T) – Initial loop value.
- **kwargs (Any, default: {}) – Extra keyword arguments forwarded to jax.lax.fori_loop on the first attempt. They are ignored on the Python fallback path.
Returns:
- T – The final loop value.
Source code in src/liblaf/jarp/lax/_control.py
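The Python fallback described above can be modeled as a plain for loop over range(lower, upper). The sketch below uses a hypothetical eager_fori_loop; it drops keyword arguments such as unroll (a real jax.lax.fori_loop option) exactly as the fallback path is documented to. The second call grows a list each iteration, a carry whose structure changes, which jax.lax.fori_loop requires to stay fixed; the eager loop has no such restriction.

```python
from collections.abc import Callable
from typing import Any, TypeVar

T = TypeVar("T")


def eager_fori_loop(
    lower: int,
    upper: int,
    body_fun: Callable[[int, T], T],
    init_val: T,
    **kwargs: Any,
) -> T:
    """Hypothetical model of fori_loop's Python fallback path."""
    # Options such as unroll only affect jax.lax.fori_loop; ignore them here.
    del kwargs
    val = init_val
    # lower is inclusive and upper is exclusive, matching range().
    for i in range(lower, upper):
        val = body_fun(i, val)
    return val


print(eager_fori_loop(0, 5, lambda i, total: total + i, 0))  # 10
print(eager_fori_loop(1, 4, lambda i, acc: acc + [i * i], [], unroll=2))  # [1, 4, 9]
```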
lax_wrapper
lax_wrapper[**P, T](
wrapped: Callable[..., Any],
) -> Callable[[Callable[P, T]], LaxWrapper[P, T]]
Decorate an eager fallback with a LaxWrapper.
Parameters:
- wrapped (Callable[..., Any]) – JAX primitive to try before the decorated fallback.
Returns:
- Callable[[Callable[P, T]], LaxWrapper[P, T]] – A decorator that turns the fallback function into a LaxWrapper.
Source code in src/liblaf/jarp/lax/_wrapper.py
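As a rough sketch of the decorator pattern, the hypothetical lax_wrapper_sketch below pairs a primitive with the function it decorates. It is not the library's code: it skips the signature cache and logging, uses the built-in TypeError and IndexError as stand-ins for the JAX error classes, and copies metadata from the primitive, as LaxWrapper does when that metadata exists. strict_double is likewise a made-up primitive.

```python
import functools
from collections.abc import Callable
from typing import Any


def lax_wrapper_sketch(
    wrapped: Callable[..., Any],
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Hypothetical decorator factory: try `wrapped`, else the decorated fallback."""

    def decorator(fallback: Callable[..., Any]) -> Callable[..., Any]:
        @functools.wraps(wrapped)  # metadata comes from the primitive
        def call(*args: Any, **kwargs: Any) -> Any:
            try:
                return wrapped(*args, **kwargs)
            except (TypeError, IndexError):  # stand-ins for JAX's error classes
                return fallback(*args, **kwargs)

        return call

    return decorator


def strict_double(x: int) -> int:
    """Only accepts ints, like a primitive that rejects Python-only inputs."""
    if not isinstance(x, int):
        raise TypeError("unsupported input")
    return 2 * x


@lax_wrapper_sketch(strict_double)
def double(x: Any) -> Any:
    return x * 2


print(double(21), double("ab"), double.__name__)  # 42 abab strict_double
```

Catching the broad built-in exceptions would also hide genuine bugs, which is one reason to restrict the fallback to JAX-specific error classes as the documented helpers do.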
switch
Choose one branch by index, then retry eagerly if JAX rejects it.
The wrapper first calls jax.lax.switch. If that raises
jax.errors.JAXTypeError or
jax.errors.JAXIndexError, it logs the
exception, clamps index into the valid range, and dispatches in plain
Python.
Parameters:
- index (ArrayLike) – Branch index. The fallback path clamps the value into the valid range before dispatch.
- branches (Sequence[Callable[[*Ts], T]]) – Candidate branch functions.
- *operands (*Ts, default: ()) – Positional operands forwarded to the selected branch.
Returns:
- T – The value returned by the selected branch.
Source code in src/liblaf/jarp/lax/_control.py
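The clamp-then-dispatch fallback described above can be sketched as follows. eager_switch is hypothetical, and two details are assumptions rather than documented behavior: the index is converted with int(), and an empty branch list raises ValueError.

```python
from collections.abc import Callable, Sequence
from typing import Any


def eager_switch(index: Any, branches: Sequence[Callable[..., Any]], *operands: Any) -> Any:
    """Hypothetical model of switch's eager fallback path."""
    if not branches:
        raise ValueError("branches must be non-empty")
    # Clamp into [0, len(branches) - 1] before dispatching.
    clamped = min(max(int(index), 0), len(branches) - 1)
    return branches[clamped](*operands)


branches = [lambda x: x + 1, lambda x: x * 2, lambda x: -x]
print(eager_switch(1, branches, 5))   # 10
print(eager_switch(-4, branches, 5))  # 6, clamped to branch 0
print(eager_switch(99, branches, 5))  # -5, clamped to branch 2
```

jax.lax.switch documents the same clamping for out-of-range indices, so the primitive and the fallback agree on that edge case.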
while_loop
while_loop[T](
cond_fun: Callable[[T], BooleanNumeric],
body_fun: Callable[[T], T],
init_val: T,
) -> T
Run a loop, then retry in Python if JAX rejects the callbacks.
The wrapper first calls jax.lax.while_loop. If
that raises jax.errors.JAXTypeError or
jax.errors.JAXIndexError, it logs the
exception and reruns the loop eagerly in Python.
Parameters:
- cond_fun (Callable[[T], BooleanNumeric]) – Predicate evaluated on the loop state.
- body_fun (Callable[[T], T]) – Function that produces the next loop state.
- init_val (T) – Initial loop state.
Returns:
- T – The final loop state.
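The eager rerun described above is a plain Python while loop. The hypothetical eager_while_loop below models it; the Collatz step is only an illustration of a data-dependent trip count.

```python
from collections.abc import Callable
from typing import Any, TypeVar

T = TypeVar("T")


def eager_while_loop(cond_fun: Callable[[T], Any], body_fun: Callable[[T], T], init_val: T) -> T:
    """Hypothetical model of while_loop's eager Python fallback."""
    val = init_val
    # cond_fun is checked before every iteration, including the first.
    while cond_fun(val):
        val = body_fun(val)
    return val


def collatz_step(state: tuple[int, int]) -> tuple[int, int]:
    n, steps = state
    return (n // 2 if n % 2 == 0 else 3 * n + 1, steps + 1)


print(eager_while_loop(lambda s: s[0] != 1, collatz_step, (6, 0)))  # (1, 8)
```

If cond_fun is false for the initial state, the loop body never runs and init_val is returned unchanged.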