liblaf.jarp.warp
Interop helpers between JAX arrays and NVIDIA Warp. Use to_warp for array conversion, jax_callable and jax_kernel to expose Warp functions through JAX tracing, struct for dtype-specialized Warp struct declarations, and liblaf.jarp.warp.types for dtypes that follow JAX's active precision mode.
Modules:
- types – Convenience accessors for Warp scalar, vector, and matrix dtypes.
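Classes:
- FfiCallableProtocol – Callable interface returned by jax_callable.
- FfiKernelProtocol – Callable interface returned by jax_kernel.
- JaxCallableCallOptions – Per-call keyword options (output_dims, vmap_method) for callables returned by jax_callable.
- JaxCallableOptions – Keyword options accepted by jax_callable.
- JaxKernelCallOptions – Per-call keyword options accepted by an FfiKernelProtocol.
- JaxKernelOptions – Keyword options accepted by jax_kernel.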
Functions:
- jax_callable – Wrap warp.jax_experimental.jax_callable with optional dtype dispatch.
- jax_kernel – Wrap warp.jax_experimental.jax_kernel with optional overload lookup.
- struct – Decorate a class as a Warp struct.
- to_warp – Convert a supported array object into a warp.array.
FfiCallableProtocol
Bases: Protocol
Callable interface returned by jax_callable.
Methods:
- __call__
FfiKernelProtocol
Bases: Protocol
Callable interface returned by jax_kernel.
Methods:
- __call__
__call__
__call__(
*args: Array,
output_dims: ShapeLike
| dict[str, ShapeLike]
| None = ...,
launch_dims: ShapeLike | None = ...,
vmap_method: VmapMethod | None = ...,
) -> Sequence[Array]
JaxCallableCallOptions
typed-dict
JaxCallableCallOptions(
*,
output_dims: ShapeLike
| dict[str, ShapeLike]
| None = ...,
vmap_method: VmapMethod | None = ...,
)
Bases: TypedDict
Parameters:
- output_dims (ShapeLike | dict[str, ShapeLike] | None, default: ...)
- vmap_method (VmapMethod | None, default: ...)
JaxCallableOptions
typed-dict
JaxCallableOptions(
*,
num_outputs: int = ...,
graph_mode: GraphMode = ...,
vmap_method: VmapMethod | None = ...,
output_dims: dict[str, ShapeLike] | None = ...,
in_out_argnames: Iterable[str] = ...,
stage_in_argnames: Iterable[str] = ...,
stage_out_argnames: Iterable[str] = ...,
graph_cache_max: int | None = ...,
module_preload_mode: ModulePreloadMode = ...,
has_side_effect: bool = ...,
)
Bases: TypedDict
Parameters:
- num_outputs (int, default: ...)
- graph_mode (GraphMode, default: ...)
- vmap_method (VmapMethod | None, default: ...)
- output_dims (dict[str, ShapeLike] | None, default: ...)
- in_out_argnames (Iterable[str], default: ...)
- stage_in_argnames (Iterable[str], default: ...)
- stage_out_argnames (Iterable[str], default: ...)
- graph_cache_max (int | None, default: ...)
- module_preload_mode (ModulePreloadMode, default: ...)
- has_side_effect (bool, default: ...)
JaxKernelCallOptions
typed-dict
JaxKernelCallOptions(
*,
output_dims: ShapeLike
| dict[str, ShapeLike]
| None = ...,
launch_dims: ShapeLike | None = ...,
vmap_method: VmapMethod | None = ...,
)
Bases: TypedDict
Parameters:
- output_dims (ShapeLike | dict[str, ShapeLike] | None, default: ...)
- launch_dims (ShapeLike | None, default: ...)
- vmap_method (VmapMethod | None, default: ...)
JaxKernelOptions
typed-dict
JaxKernelOptions(
*,
num_outputs: int = ...,
vmap_method: VmapMethod = ...,
launch_dims: ShapeLike | None = ...,
output_dims: ShapeLike
| dict[str, ShapeLike]
| None = ...,
in_out_argnames: Iterable[str] = ...,
module_preload_mode: ModulePreloadMode = ...,
enable_backward: bool = ...,
)
Bases: TypedDict
Parameters:
- num_outputs (int, default: ...)
- vmap_method (VmapMethod, default: ...)
- launch_dims (ShapeLike | None, default: ...)
- output_dims (ShapeLike | dict[str, ShapeLike] | None, default: ...)
- in_out_argnames (Iterable[str], default: ...)
- module_preload_mode (ModulePreloadMode, default: ...)
- enable_backward (bool, default: ...)
jax_callable
jax_callable(
func: _FfiCallableFunction,
*,
generic: Literal[False] = False,
**kwargs: Unpack[JaxCallableOptions],
) -> FfiCallableProtocol
jax_callable(
*,
generic: Literal[False] = False,
**kwargs: Unpack[JaxCallableOptions],
) -> Callable[[_FfiCallableFunction], FfiCallableProtocol]
jax_callable(
func: _FfiCallableFactory,
*,
generic: Literal[True],
**kwargs: Unpack[JaxCallableOptions],
) -> _FfiCallable
jax_callable(
*,
generic: Literal[True],
**kwargs: Unpack[JaxCallableOptions],
) -> Callable[[_FfiCallableFactory], _FfiCallable]
Wrap warp.jax_experimental.jax_callable with optional dtype dispatch.
When generic=True, func is treated as a factory keyed by the Warp
scalar dtypes inferred from the runtime JAX arguments. The factory output is
cached, so repeated calls with the same dtype signature reuse the same Warp
callable.
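A minimal sketch of the caching behaviour described above, in plain Python with NumPy. It is not the library's implementation: `generic_dispatch` and `scale` are hypothetical names, NumPy dtypes stand in for Warp scalar dtypes, and the factory is assumed to receive one dtype per positional argument. The point is that the factory runs once per distinct dtype signature and is then served from a cache.

```python
import functools
from collections.abc import Callable

import numpy as np


def generic_dispatch(factory: Callable[..., Callable]) -> Callable:
    """Hypothetical model of jax_callable(..., generic=True)."""

    @functools.cache
    def specialize(*dtypes: np.dtype) -> Callable:
        # Called at most once per distinct dtype signature.
        return factory(*dtypes)

    def wrapper(*args: np.ndarray) -> np.ndarray:
        # Infer the dtype signature from the runtime arguments.
        dtypes = tuple(np.asarray(a).dtype for a in args)
        return specialize(*dtypes)(*args)

    wrapper.cache_info = specialize.cache_info  # expose cache statistics
    return wrapper


factory_calls: list = []


@generic_dispatch
def scale(*dtypes: np.dtype) -> Callable:
    factory_calls.append(dtypes)  # record each factory invocation
    (dtype,) = dtypes

    def impl(x: np.ndarray) -> np.ndarray:
        return (x * 2).astype(dtype)

    return impl


scale(np.ones(3, dtype=np.float32))
scale(np.zeros(5, dtype=np.float32))  # same signature: served from the cache
scale(np.ones(3, dtype=np.float64))  # new signature: the factory runs again
print(len(factory_calls))  # 2
```

In this model only the dtype tuple forms the cache key, so changing array shapes alone never triggers a new specialization.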
Parameters:
- func (Callable | None, default: None) – Warp callable function or factory. When omitted, a decorator is returned.
- generic (bool, default: False) – When true, func is treated as a factory that receives the Warp scalar dtypes inferred from the runtime JAX arguments and returns a concrete Warp callable implementation.
- num_outputs (int, default: ...)
- graph_mode (GraphMode, default: ...)
- vmap_method (VmapMethod | None, default: ...)
- output_dims (dict[str, ShapeLike] | None, default: ...)
- in_out_argnames (Iterable[str], default: ...)
- stage_in_argnames (Iterable[str], default: ...)
- stage_out_argnames (Iterable[str], default: ...)
- graph_cache_max (int | None, default: ...)
- module_preload_mode (ModulePreloadMode, default: ...)
- has_side_effect (bool, default: ...)
Returns:
- Any – A callable compatible with JAX tracing, or a decorator that produces one. When called, it returns the output arrays produced by Warp's FFI wrapper.
Source code in src/liblaf/jarp/warp/_jax_callable.py
jax_kernel
jax_kernel(
*,
arg_types_factory: Callable[[WarpScalarDType], ArgTypes]
| None = None,
**kwargs: Unpack[JaxKernelOptions],
) -> Callable[[Callable], FfiKernelProtocol]
jax_kernel(
kernel: Callable,
*,
arg_types_factory: Callable[[WarpScalarDType], ArgTypes]
| None = None,
**kwargs: Unpack[JaxKernelOptions],
) -> FfiKernelProtocol
Wrap warp.jax_experimental.jax_kernel with optional overload lookup.
When arg_types_factory is provided, the wrapper infers Warp scalar
dtypes from the runtime JAX arguments, builds an overload signature, and
resolves the corresponding Warp kernel before dispatch.
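The overload lookup can be pictured the same way: infer a scalar dtype from the runtime arguments, ask arg_types_factory for that overload's argument types, then resolve and cache the matching kernel. In the sketch below, which is not the library's code, NumPy dtypes stand in for Warp scalar dtypes, type-name strings stand in for Warp argument types, and a dict stands in for warp.overload; `OverloadedKernel` and `axpy_arg_types` are hypothetical, and taking the dtype from the first argument is an assumption.

```python
from collections.abc import Callable, Sequence

import numpy as np


class OverloadedKernel:
    """Hypothetical model of jax_kernel(kernel, arg_types_factory=...)."""

    def __init__(self, name: str, arg_types_factory: Callable[[np.dtype], Sequence[str]]):
        self.name = name
        self.arg_types_factory = arg_types_factory
        self.overloads: dict = {}  # signature tuple -> resolved "kernel"

    def resolve(self, *args: object) -> str:
        # Assumption for this sketch: the scalar dtype comes from the first argument.
        dtype = np.asarray(args[0]).dtype
        signature = tuple(self.arg_types_factory(dtype))
        if signature not in self.overloads:
            # Stand-in for building the overload with warp.overload.
            self.overloads[signature] = f"{self.name}[{', '.join(signature)}]"
        return self.overloads[signature]


def axpy_arg_types(dtype: np.dtype) -> list[str]:
    # Map the runtime scalar dtype to one type name per kernel argument.
    return [f"array({dtype.name})", f"array({dtype.name})", dtype.name]


axpy = OverloadedKernel("axpy", axpy_arg_types)
print(axpy.resolve(np.ones(4, np.float32)))
# axpy[array(float32), array(float32), float32]
print(axpy.resolve(np.ones(4, np.float64)))
# axpy[array(float64), array(float64), float64]
```

In real code the factory would return the Warp types that warp.overload expects rather than strings.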
Parameters:
- kernel (Callable | None, default: None) – Warp kernel to expose to JAX. When omitted, a decorator is returned.
- arg_types_factory (Callable[[WarpScalarDType], ArgTypes] | None, default: None) – Optional callback that maps runtime Warp scalar dtypes to the overloaded kernel argument types expected by warp.overload.
- num_outputs (int, default: ...)
- vmap_method (VmapMethod, default: ...)
- launch_dims (ShapeLike | None, default: ...)
- output_dims (ShapeLike | dict[str, ShapeLike] | None, default: ...)
- in_out_argnames (Iterable[str], default: ...)
- module_preload_mode (ModulePreloadMode, default: ...)
- enable_backward (bool, default: ...)
Returns:
- Any – A callable compatible with JAX tracing, or a decorator that produces one. When called, it returns the output arrays produced by Warp's FFI wrapper.
Source code in src/liblaf/jarp/warp/_jax_kernel.py
struct
struct[T: type](cls: T) -> T
Decorate a class as a Warp struct.
Plain classes are forwarded to warp.struct. Classes that define
__annotations_factory__(dtype) stay generic: MyStruct[wp.float64]
builds and caches a specialized Warp struct from the factory annotations,
while MyStruct() instantiates MyStruct[liblaf.jarp.warp.types.floating]
so the default follows JAX's active precision mode.
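The generic-struct mechanics can be modelled without Warp. In the hypothetical sketch below, `generic_struct` stands in for struct, `DEFAULT_FLOAT` stands in for liblaf.jarp.warp.types.floating, specializations are ordinary dataclasses instead of Warp structs, and `__annotations_factory__` is assumed to be a staticmethod. It shows the three behaviours described above: plain classes pass straight through, subscripting builds and caches a specialization from `__annotations_factory__(dtype)`, and calling the generic class constructs the default-dtype specialization.

```python
import dataclasses
import functools

DEFAULT_FLOAT = float  # hypothetical stand-in for the precision-aware default dtype


def generic_struct(cls: type) -> type:
    """Hypothetical model of the struct decorator's generic behaviour."""
    factory = getattr(cls, "__annotations_factory__", None)
    if factory is None:
        # Plain class: the real decorator forwards it to warp.struct.
        return dataclasses.dataclass(cls)

    @functools.cache
    def specialize(dtype: type) -> type:
        # Build (once per dtype) a concrete class from the factory annotations.
        name = f"{cls.__name__}[{dtype.__name__}]"
        return dataclasses.dataclass(type(name, (), {"__annotations__": factory(dtype)}))

    def new(klass, *args, **kwargs):
        # Default construction delegates to the default-dtype specialization.
        return specialize(DEFAULT_FLOAT)(*args, **kwargs)

    cls.__class_getitem__ = classmethod(lambda klass, dtype: specialize(dtype))
    cls.__new__ = staticmethod(new)
    return cls


@generic_struct
class Particle:
    @staticmethod
    def __annotations_factory__(dtype: type) -> dict[str, type]:
        return {"position": dtype, "mass": dtype}


assert Particle[int] is Particle[int]  # specializations are cached
p = Particle(position=1.0, mass=2.0)  # constructs Particle[float]
print(type(p).__name__)  # Particle[float]
```

Because the specialized classes do not inherit from the generic class, returning one from `__new__` skips the generic class's own `__init__`.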
Parameters:
- cls (T) – Class to decorate.
Returns:
- T – The Warp struct for plain classes or, for generic classes, the original class with dtype subscription and default construction hooks installed.
Source code in src/liblaf/jarp/warp/_struct.py
to_warp
Convert a supported array object into a warp.array.
The dispatcher supports existing Warp arrays, NumPy arrays, and JAX arrays.
A dtype hint may be a concrete Warp dtype or a tuple that describes a
vector or matrix dtype inferred from the trailing dimensions of arr.
Use (-1, Any) for vector inference and (-1, -1, Any) for matrix
inference when the element type should follow the source array.
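How a tuple hint maps onto the trailing dimensions can be shown with NumPy alone. The hypothetical `resolve_dtype_hint` helper below is not part of the library. It treats -1 as "take this extent from the trailing dimensions of arr" and `Any` as "use the source array's element type", and returns the element shape, the scalar type, and the outer shape that would remain as the Warp array shape. Real code would then build the corresponding Warp vector or matrix type.

```python
from typing import Any

import numpy as np


def resolve_dtype_hint(arr: np.ndarray, hint: tuple) -> tuple:
    """Hypothetical helper: interpret a vector/matrix dtype hint tuple.

    Returns (kind, element_shape, scalar_type, outer_shape).
    """
    *dims, scalar = hint
    ndim = len(dims)  # 1 -> vector hint, 2 -> matrix hint
    if ndim not in (1, 2) or arr.ndim < ndim:
        raise ValueError(f"cannot apply hint {hint!r} to shape {arr.shape}")
    trailing = arr.shape[arr.ndim - ndim :]
    # -1 means "take this extent from the trailing dimensions of arr".
    element_shape = tuple(t if d == -1 else d for d, t in zip(dims, trailing))
    if element_shape != trailing:
        raise ValueError(f"hint {hint!r} does not match trailing dims {trailing}")
    # Any means "follow the element type of the source array".
    scalar_type = arr.dtype if scalar is Any else scalar
    kind = "vector" if ndim == 1 else "matrix"
    return kind, element_shape, scalar_type, arr.shape[: arr.ndim - ndim]


points = np.zeros((100, 3), dtype=np.float32)
print(resolve_dtype_hint(points, (-1, Any)))
# ('vector', (3,), dtype('float32'), (100,))
frames = np.zeros((10, 3, 3))
print(resolve_dtype_hint(frames, (-1, -1, Any)))
# ('matrix', (3, 3), dtype('float64'), (10,))
```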
Parameters:
- arr (array | ndarray | Array) – Array object to convert.
- *_args (Any, default: ()) – Reserved for singledispatch compatibility.
- **_kwargs (Any, default: {}) – Reserved for singledispatch compatibility.
Returns:
- array – A Warp array view or a converted array, depending on the source type.
Raises:
- TypeError – If arr has an unsupported type.
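The `*_args`/`**_kwargs` parameters reserved for singledispatch compatibility suggest that to_warp is built on functools.singledispatch. A minimal sketch of that dispatch pattern follows, using a hypothetical `to_array` with NumPy and list inputs in place of Warp and JAX arrays. Registered types are converted, or returned unchanged when they already have the target type, and everything else reaches the base implementation, which raises TypeError.

```python
import functools
from typing import Any

import numpy as np


@functools.singledispatch
def to_array(arr: Any, *_args: Any, **_kwargs: Any) -> np.ndarray:
    # Base case: no converter is registered for this type.
    raise TypeError(f"Unsupported array type: {type(arr).__name__}")


@to_array.register
def _(arr: np.ndarray, *_args: Any, **_kwargs: Any) -> np.ndarray:
    # Already the target type: return it as-is (no copy).
    return arr


@to_array.register
def _(arr: list, *_args: Any, **_kwargs: Any) -> np.ndarray:
    # Convertible source type: build a new array.
    return np.asarray(arr)


x = np.arange(3)
print(to_array(x) is x)  # True
print(to_array([1.0, 2.0]).shape)  # (2,)
try:
    to_array("abc")
except TypeError as err:
    print(err)  # Unsupported array type: str
```

Because every registered implementation shares the base signature, extra positional or keyword arguments (such as a dtype hint) can be forwarded to whichever converter is selected.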