
liblaf.jarp.warp

Interop helpers between JAX arrays and NVIDIA Warp.

Use to_warp for array conversion, jax_callable and jax_kernel to expose Warp functions through JAX tracing, struct for dtype-specialized Warp struct declarations, and liblaf.jarp.warp.types for dtypes that follow JAX's active precision mode.

Modules:

  • types

    Convenience accessors for Warp scalar, vector, and matrix dtypes.

Classes:

Functions:

  • jax_callable

    Wrap warp.jax_experimental.jax_callable with optional dtype dispatch.

  • jax_kernel

    Wrap warp.jax_experimental.jax_kernel with optional overload lookup.

  • struct

    Decorate a class as a Warp struct.

  • to_warp

    Convert a supported array object into a warp.array.

FfiCallableProtocol

Bases: Protocol


Callable interface returned by jax_callable.

Methods:

__call__

__call__(
    *args: Array,
    output_dims: ShapeLike
    | dict[str, ShapeLike]
    | None = ...,
    vmap_method: VmapMethod | None = ...,
) -> Sequence[Array]

Parameters:

  • output_dims (ShapeLike | dict[str, ShapeLike] | None, default: ... ) –
  • vmap_method (VmapMethod | None, default: ... ) –

Source code in src/liblaf/jarp/warp/_jax_callable.py
def __call__(
    self, *args: Array, **kwargs: Unpack[JaxCallableCallOptions]
) -> Sequence[Array]: ...

FfiKernelProtocol

Bases: Protocol


Callable interface returned by jax_kernel.

Methods:

__call__

__call__(
    *args: Array,
    output_dims: ShapeLike
    | dict[str, ShapeLike]
    | None = ...,
    launch_dims: ShapeLike | None = ...,
    vmap_method: VmapMethod | None = ...,
) -> Sequence[Array]

Parameters:

  • output_dims (ShapeLike | dict[str, ShapeLike] | None, default: ... ) –
  • launch_dims (ShapeLike | None, default: ... ) –
  • vmap_method (VmapMethod | None, default: ... ) –

Source code in src/liblaf/jarp/warp/_jax_kernel.py
def __call__(
    self, *args: Array, **kwargs: Unpack[JaxKernelCallOptions]
) -> Sequence[Array]: ...

JaxCallableCallOptions typed-dict

JaxCallableCallOptions(
    *,
    output_dims: ShapeLike
    | dict[str, ShapeLike]
    | None = ...,
    vmap_method: VmapMethod | None = ...,
)

Bases: TypedDict


Parameters:

  • output_dims (ShapeLike | dict[str, ShapeLike] | None) –
  • vmap_method (VmapMethod | None) –
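
As a rough illustration of how a typed dict of per-call options is passed as keyword arguments, the following sketch uses only the standard library. `CallOptions` and `call` are hypothetical stand-ins, not part of liblaf.jarp.warp, and treating every key as optional (`total=False`) is an assumption based on the `= ...` defaults in the signature above.

```python
from typing import Any, Dict, Optional, Tuple, TypedDict


class CallOptions(TypedDict, total=False):
    """Hypothetical per-call options; every key is optional (assumption)."""

    output_dims: Optional[Tuple[int, ...]]
    vmap_method: Optional[str]


def call(*args: Any, **kwargs: Any) -> Dict[str, Any]:
    # Stand-in for an FFI callable: report what it was called with.
    return {"num_args": len(args), **kwargs}


# Only the options that are set get forwarded as keyword arguments.
options: CallOptions = {"output_dims": (4,)}
result = call(1.0, 2.0, **options)
```

Keys left out of the dict are simply never passed, so the wrapped callable falls back to its own defaults.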

JaxCallableOptions typed-dict

JaxCallableOptions(
    *,
    num_outputs: int = ...,
    graph_mode: GraphMode = ...,
    vmap_method: VmapMethod | None = ...,
    output_dims: dict[str, ShapeLike] | None = ...,
    in_out_argnames: Iterable[str] = ...,
    stage_in_argnames: Iterable[str] = ...,
    stage_out_argnames: Iterable[str] = ...,
    graph_cache_max: int | None = ...,
    module_preload_mode: ModulePreloadMode = ...,
    has_side_effect: bool = ...,
)

Bases: TypedDict


Parameters:

  • num_outputs (int) –
  • graph_mode (GraphMode) –
  • vmap_method (VmapMethod | None) –
  • output_dims (dict[str, ShapeLike] | None) –
  • in_out_argnames (Iterable[str]) –
  • stage_in_argnames (Iterable[str]) –
  • stage_out_argnames (Iterable[str]) –
  • graph_cache_max (int | None) –
  • module_preload_mode (ModulePreloadMode) –
  • has_side_effect (bool) –

JaxKernelCallOptions typed-dict

JaxKernelCallOptions(
    *,
    output_dims: ShapeLike
    | dict[str, ShapeLike]
    | None = ...,
    launch_dims: ShapeLike | None = ...,
    vmap_method: VmapMethod | None = ...,
)

Bases: TypedDict


Parameters:

  • output_dims (ShapeLike | dict[str, ShapeLike] | None) –
  • launch_dims (ShapeLike | None) –
  • vmap_method (VmapMethod | None) –

JaxKernelOptions typed-dict

JaxKernelOptions(
    *,
    num_outputs: int = ...,
    vmap_method: VmapMethod = ...,
    launch_dims: ShapeLike | None = ...,
    output_dims: ShapeLike
    | dict[str, ShapeLike]
    | None = ...,
    in_out_argnames: Iterable[str] = ...,
    module_preload_mode: ModulePreloadMode = ...,
    enable_backward: bool = ...,
)

Bases: TypedDict


Parameters:

  • num_outputs (int) –
  • vmap_method (VmapMethod) –
  • launch_dims (ShapeLike | None) –
  • output_dims (ShapeLike | dict[str, ShapeLike] | None) –
  • in_out_argnames (Iterable[str]) –
  • module_preload_mode (ModulePreloadMode) –
  • enable_backward (bool) –

jax_callable

jax_callable(
    func: _FfiCallableFunction,
    *,
    generic: Literal[False] = False,
    **kwargs: Unpack[JaxCallableOptions],
) -> FfiCallableProtocol
jax_callable(
    *,
    generic: Literal[False] = False,
    **kwargs: Unpack[JaxCallableOptions],
) -> Callable[[_FfiCallableFunction], FfiCallableProtocol]
jax_callable(
    func: _FfiCallableFactory,
    *,
    generic: Literal[True],
    **kwargs: Unpack[JaxCallableOptions],
) -> _FfiCallable
jax_callable(
    *,
    generic: Literal[True],
    **kwargs: Unpack[JaxCallableOptions],
) -> Callable[[_FfiCallableFactory], _FfiCallable]

Wrap warp.jax_experimental.jax_callable with optional dtype dispatch.

When generic=True, func is treated as a factory keyed by the Warp scalar dtypes inferred from the runtime JAX arguments. The factory output is cached, so repeated calls with the same dtype signature reuse the same Warp callable.
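
The caching behaviour described above can be sketched in plain Python. This is only an illustration under stated assumptions: Python type names stand in for Warp scalar dtypes, and `GenericCallable`, `make_scale`, and `build_log` are hypothetical names that do not exist in liblaf.jarp.warp.

```python
import functools
from typing import Any, Callable

# Records every dtype signature the factory is asked to build (illustration only).
build_log: list = []


def make_scale(*dtypes: str) -> Callable[..., Any]:
    """Hypothetical factory: return an implementation for one dtype signature."""
    build_log.append(dtypes)

    def scale(x: Any) -> Any:
        return 2 * x

    return scale


class GenericCallable:
    """Stand-in wrapper that dispatches on the dtypes of its runtime arguments."""

    def __init__(self, factory: Callable[..., Callable[..., Any]]) -> None:
        # Cache the factory so each dtype signature is built only once.
        self._factory = functools.lru_cache(factory)

    def __call__(self, *args: Any) -> Any:
        # Assumption: the Python type name plays the role of a Warp scalar dtype.
        dtypes = tuple(type(a).__name__ for a in args)
        return self._factory(*dtypes)(*args)


double = GenericCallable(make_scale)
double(1.5)  # builds the "float" implementation
double(2.5)  # reuses the cached "float" implementation
double(3)    # builds a separate "int" implementation
```

After these three calls the factory has run twice, once per distinct dtype signature.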

Parameters:

  • func (Callable | None, default: None ) –

    Warp callable function or factory. When omitted, return a decorator.

  • generic (bool, default: False ) –

    When true, func is treated as a factory that receives Warp scalar dtypes inferred from the runtime JAX arguments and returns a concrete Warp callable implementation.

  • num_outputs (int, default: ... ) –
  • graph_mode (GraphMode, default: ... ) –
  • vmap_method (VmapMethod | None, default: ... ) –
  • output_dims (dict[str, ShapeLike] | None, default: ... ) –
  • in_out_argnames (Iterable[str], default: ... ) –
  • stage_in_argnames (Iterable[str], default: ... ) –
  • stage_out_argnames (Iterable[str], default: ... ) –
  • graph_cache_max (int | None, default: ... ) –
  • module_preload_mode (ModulePreloadMode, default: ... ) –
  • has_side_effect (bool, default: ... ) –

Returns:

  • Any

    A callable compatible with JAX tracing, or a decorator producing one.
    The callable returns the output arrays produced by Warp's FFI wrapper.
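
The "callable or decorator" return value comes from a common Python pattern: when the function argument is omitted, the wrapper returns a partial application of itself that remembers the options. A minimal sketch of that pattern follows, using a hypothetical `wrap` decorator with a made-up `label` option rather than any Warp API.

```python
import functools
from typing import Any, Callable, Optional


def wrap(func: Optional[Callable[..., Any]] = None, *, label: str = "plain") -> Any:
    """Hypothetical decorator usable as `wrap(f)`, `@wrap`, or `@wrap(label=...)`."""
    if func is None:
        # Called with options only: return a decorator that remembers them.
        return functools.partial(wrap, label=label)

    @functools.wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        # Tag the result so the effect of the option is visible.
        return (label, func(*args, **kwargs))

    return wrapper


@wrap
def add(a: int, b: int) -> int:
    return a + b


@wrap(label="tagged")
def mul(a: int, b: int) -> int:
    return a * b
```

Both spellings end up in the same code path; the second simply goes through `functools.partial` first.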

Source code in src/liblaf/jarp/warp/_jax_callable.py
def jax_callable(
    func: Callable | None = None,
    *,
    generic: bool = False,
    **kwargs: Unpack[JaxCallableOptions],
) -> Any:
    """Wrap `warp.jax_experimental.jax_callable` with optional dtype dispatch.

    When `generic=True`, `func` is treated as a factory keyed by the Warp
    scalar dtypes inferred from the runtime JAX arguments. The factory output is
    cached, so repeated calls with the same dtype signature reuse the same Warp
    callable.

    Args:
        func: Warp callable function or factory. When omitted, return a
            decorator.
        generic: When true, `func` is treated as a factory that receives Warp
            scalar dtypes inferred from the runtime JAX arguments and returns a
            concrete Warp callable implementation.
        **kwargs: Options forwarded to Warp's JAX callable adapter.

    Returns:
        A callable compatible with JAX tracing, or a decorator producing one.
        The callable returns the output arrays produced by Warp's FFI wrapper.
    """
    if func is None:
        return functools.partial(jax_callable, generic=generic, **kwargs)
    if not generic:
        return warp.jax_experimental.jax_callable(func, **kwargs)
    factory: _FfiCallableFactory = functools.lru_cache(func)
    return _FfiCallable(factory=factory, options=kwargs)

jax_kernel

jax_kernel(
    *,
    arg_types_factory: Callable[[WarpScalarDType], ArgTypes]
    | None = None,
    **kwargs: Unpack[JaxKernelOptions],
) -> Callable[[Callable], FfiKernelProtocol]
jax_kernel(
    kernel: Callable,
    *,
    arg_types_factory: Callable[[WarpScalarDType], ArgTypes]
    | None = None,
    **kwargs: Unpack[JaxKernelOptions],
) -> FfiKernelProtocol

Wrap warp.jax_experimental.jax_kernel with optional overload lookup.

When arg_types_factory is provided, the wrapper infers Warp scalar dtypes from the runtime JAX arguments, builds an overload signature, and resolves the corresponding Warp kernel before dispatch.
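
The lookup flow described above (infer a dtype, build a signature, resolve an implementation) can be illustrated without Warp. Everything in this sketch is a hypothetical stand-in: `OverloadTable` is a toy registry, strings stand in for Warp types, and Warp's actual overload resolution may work differently.

```python
from typing import Any, Callable, Dict, List, Tuple

ArgTypes = Dict[str, str]
Key = Tuple[Tuple[str, str], ...]


def arg_types_factory(dtype: str) -> ArgTypes:
    """Hypothetical callback: argument types of the overload for one dtype."""
    return {"x": f"array[{dtype}]", "alpha": dtype}


class OverloadTable:
    """Toy registry mapping an argument-type signature to an implementation."""

    def __init__(self) -> None:
        self._impls: Dict[Key, Callable[..., Any]] = {}

    @staticmethod
    def _key(arg_types: ArgTypes) -> Key:
        # Sort items so the key does not depend on dict insertion order.
        return tuple(sorted(arg_types.items()))

    def register(self, arg_types: ArgTypes, impl: Callable[..., Any]) -> None:
        self._impls[self._key(arg_types)] = impl

    def resolve(self, arg_types: ArgTypes) -> Callable[..., Any]:
        return self._impls[self._key(arg_types)]


table = OverloadTable()
table.register(arg_types_factory("float"), lambda x, alpha: ("float", [alpha * v for v in x]))
table.register(arg_types_factory("int"), lambda x, alpha: ("int", [alpha * v for v in x]))


def dispatch(x: List[Any], alpha: Any) -> Any:
    # Assumption: the Python type name of `alpha` plays the role of the
    # Warp scalar dtype inferred from the runtime arguments.
    dtype = type(alpha).__name__
    impl = table.resolve(arg_types_factory(dtype))
    return impl(x, alpha)
```

Calling `dispatch` with float arguments selects the float overload, and integer arguments select the integer one.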

Parameters:

  • kernel (Callable | None, default: None ) –

    Warp kernel to expose to JAX. When omitted, return a decorator.

  • arg_types_factory (Callable[[WarpScalarDType], ArgTypes] | None, default: None ) –

    Optional callback that maps runtime Warp scalar dtypes to the overloaded kernel argument types expected by warp.overload.

  • num_outputs (int, default: ... ) –
  • vmap_method (VmapMethod, default: ... ) –
  • launch_dims (ShapeLike | None, default: ... ) –
  • output_dims (ShapeLike | dict[str, ShapeLike] | None, default: ... ) –
  • in_out_argnames (Iterable[str], default: ... ) –
  • module_preload_mode (ModulePreloadMode, default: ... ) –
  • enable_backward (bool, default: ... ) –

Returns:

  • Any

    A callable compatible with JAX tracing, or a decorator producing one.
    The callable returns the output arrays produced by Warp's FFI wrapper.

Source code in src/liblaf/jarp/warp/_jax_kernel.py
def jax_kernel(
    kernel: Callable | None = None,
    *,
    arg_types_factory: Callable[[WarpScalarDType], ArgTypes] | None = None,
    **kwargs: Unpack[JaxKernelOptions],
) -> Any:
    """Wrap `warp.jax_experimental.jax_kernel` with optional overload lookup.

    When `arg_types_factory` is provided, the wrapper infers Warp scalar
    dtypes from the runtime JAX arguments, builds an overload signature, and
    resolves the corresponding Warp kernel before dispatch.

    Args:
        kernel: Warp kernel to expose to JAX. When omitted, return a decorator.
        arg_types_factory: Optional callback that maps runtime Warp scalar dtypes
            to the overloaded kernel argument types expected by
            [warp.overload][].
        **kwargs: Options forwarded to Warp's JAX kernel adapter.

    Returns:
        A callable compatible with JAX tracing, or a decorator producing one.
        The callable returns the output arrays produced by Warp's FFI wrapper.
    """
    if kernel is None:
        return functools.partial(
            jax_kernel, arg_types_factory=arg_types_factory, **kwargs
        )
    if arg_types_factory is None:
        return warp.jax_experimental.jax_kernel(kernel, **kwargs)
    return _FfiKernel(
        kernel=cast("wp.Kernel", kernel),
        options=kwargs,
        arg_types_factory=arg_types_factory,
    )

struct

struct[T: type](cls: T) -> T

Decorate a class as a Warp struct.

Plain classes are forwarded to warp.struct. Classes that define __annotations_factory__(dtype) stay generic: MyStruct[wp.float64] builds and caches a specialized Warp struct from the factory annotations, while MyStruct() instantiates MyStruct[liblaf.jarp.warp.types.floating] so the default follows JAX's active precision mode.
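
The subscription and default-construction hooks can be sketched in plain Python. This is a simplified stand-in, not the library's implementation: `generic_struct`, `Particle`, and `DEFAULT_DTYPE` are hypothetical names, Python's `float` and `int` stand in for Warp dtypes, and the specialized class is an ordinary subclass rather than a Warp struct.

```python
import functools
from typing import Any

# Assumption: stands in for liblaf.jarp.warp.types.floating.
DEFAULT_DTYPE = float


def generic_struct(cls: Any) -> Any:
    """Install dtype subscription and default construction on `cls`."""

    @functools.lru_cache(maxsize=None)
    def specialize(key: Any) -> type:
        # Build one subclass per dtype, with annotations from the factory.
        annotations = cls.__annotations_factory__(key)
        return type(cls.__name__, (cls,), {"__annotations__": annotations})

    def __class_getitem__(owner: type, key: Any) -> type:
        return specialize(key)

    def __new__(owner: type) -> object:
        if owner is cls:
            # Bare construction falls back to the default dtype.
            return specialize(DEFAULT_DTYPE)()
        return object.__new__(owner)

    cls.__class_getitem__ = classmethod(__class_getitem__)
    cls.__new__ = staticmethod(__new__)
    return cls


@generic_struct
class Particle:
    @staticmethod
    def __annotations_factory__(dtype: Any) -> dict:
        return {"mass": dtype, "charge": dtype}


p = Particle()  # an instance of the float specialization
```

Because `specialize` is cached, `Particle[int]` returns the same class object every time it is subscripted.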

Parameters:

  • cls (T) –

    Class to decorate.

Returns:

  • T

    The Warp struct for plain classes, or the original generic class with
    dtype subscription and default construction hooks installed.

Source code in src/liblaf/jarp/warp/_struct.py
def struct[T: type](cls: T) -> T:
    """Decorate a class as a Warp struct.

    Plain classes are forwarded to `warp.struct`. Classes that define
    `__annotations_factory__(dtype)` stay generic: `MyStruct[wp.float64]`
    builds and caches a specialized Warp struct from the factory annotations,
    while `MyStruct()` instantiates `MyStruct[liblaf.jarp.warp.types.floating]`
    so the default follows JAX's active precision mode.

    Args:
        cls: Class to decorate.

    Returns:
        The Warp struct for plain classes, or the original generic class with
        dtype subscription and default construction hooks installed.
    """
    if not hasattr(cls, "__annotations_factory__"):
        return cast("T", wp.struct(cls))

    @functools.cache
    def __class_getitem__(cls: T, key: Any) -> T:  # noqa: N807
        c: type = type(
            cls.__name__,
            (cls,),
            {
                "__module__": cls.__module__,
                "__qualname__": cls.__qualname__,
                "__annotations__": cls.__annotations_factory__(key),  # ty:ignore[unresolved-attribute]
            },
        )
        return cast("T", wp.struct(c, module="unique"))

    def __new__(owner: type) -> object:  # noqa: N807
        if owner is cls:
            return __class_getitem__(cls, wpt.floating)()
        return object.__new__(owner)

    cls.__class_getitem__ = classmethod(__class_getitem__)  # ty:ignore[invalid-assignment]
    cls.__new__ = staticmethod(__new__)  # ty:ignore[invalid-assignment]
    return cls

to_warp

to_warp(
    arr: array | ndarray | Array,
    *_args: Any,
    **_kwargs: Any,
) -> array

Convert a supported array object into a warp.array.

The dispatcher supports existing Warp arrays, NumPy arrays, and JAX arrays. A dtype hint may be a concrete Warp dtype or a tuple that describes a vector or matrix dtype inferred from the trailing dimensions of arr. Use (-1, Any) for vector inference and (-1, -1, Any) for matrix inference when the element type should follow the source array.
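
One plausible reading of the tuple hints is that each leading entry consumes one trailing dimension of arr, with -1 meaning "take whatever size is there". The sketch below illustrates that reading only. `split_shape` is a hypothetical helper, and the library's actual inference rules are not shown in this page and may differ.

```python
from typing import Any, Sequence, Tuple

Shape = Tuple[int, ...]


def split_shape(shape: Sequence[int], hint: Tuple[Any, ...]) -> Tuple[Shape, Shape]:
    """Split `shape` into (array shape, element shape) under a tuple dtype hint."""
    *dims, _scalar = hint  # the last entry names the scalar type (or Any)
    n = len(dims)
    if n == 0 or n > len(shape):
        raise ValueError(f"hint {hint!r} does not fit shape {tuple(shape)!r}")
    leading, trailing = tuple(shape[:-n]), tuple(shape[-n:])
    for want, have in zip(dims, trailing):
        # -1 accepts any size; a concrete size must match the source array.
        if want != -1 and want != have:
            raise ValueError(f"expected trailing size {want}, got {have}")
    return leading, trailing


vec = split_shape((100, 3), (-1, Any))        # 100 vectors of length 3
mat = split_shape((10, 3, 3), (-1, -1, Any))  # 10 matrices of shape 3x3
```

Under this reading, a (100, 3) source becomes a one-dimensional array of 100 length-3 vectors, and a (10, 3, 3) source becomes 10 matrices.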

Parameters:

  • arr (array | ndarray | Array) –

    Array object to convert.

  • *_args (Any, default: () ) –

    Reserved for singledispatch compatibility.

  • **_kwargs (Any, default: {} ) –

    Reserved for singledispatch compatibility.

Returns:

  • array

    A Warp array view or converted array, depending on the source type.

Raises:

  • TypeError

    If arr uses an unsupported type.
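
The dispatch-by-source-type behaviour, including the TypeError fallback, follows the standard functools.singledispatch pattern. A minimal sketch with standard-library types follows; `to_buffer` is a hypothetical function, and `bytearray`, `bytes`, and `list` stand in for Warp, NumPy, and JAX arrays.

```python
import functools
from typing import Any


@functools.singledispatch
def to_buffer(arr: Any, *_args: Any, **_kwargs: Any) -> bytearray:
    """Fallback: reached only when no handler is registered for type(arr)."""
    raise TypeError(arr)


@to_buffer.register(bytearray)
def _from_bytearray(arr: bytearray, *_args: Any, **_kwargs: Any) -> bytearray:
    # Already the target type: return the same object without copying.
    return arr


@to_buffer.register(bytes)
def _from_bytes(arr: bytes, *_args: Any, **_kwargs: Any) -> bytearray:
    # Immutable source: convert into a new buffer.
    return bytearray(arr)


@to_buffer.register(list)
def _from_list(arr: list, *_args: Any, **_kwargs: Any) -> bytearray:
    # Generic container: convert element by element.
    return bytearray(arr)


existing = bytearray(b"ab")
same = to_buffer(existing)
converted = to_buffer([1, 2, 3])
```

Passing an unregistered type, such as a `str`, reaches the base function and raises TypeError.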

Source code in src/liblaf/jarp/warp/_to_warp.py
@functools.singledispatch
def to_warp(arr: Any, *_args: Any, **_kwargs: Any) -> wp.array:
    """Convert a supported array object into a [`warp.array`][warp.array].

    The dispatcher supports existing Warp arrays, NumPy arrays, and JAX arrays.
    A `dtype` hint may be a concrete Warp dtype or a tuple that describes a
    vector or matrix dtype inferred from the trailing dimensions of `arr`.
    Use `(-1, Any)` for vector inference and `(-1, -1, Any)` for matrix
    inference when the element type should follow the source array.

    Args:
        arr: Array object to convert.
        *_args: Reserved for singledispatch compatibility.
        **_kwargs: Reserved for singledispatch compatibility.

    Returns:
        A Warp array view or converted array, depending on the source type.

    Raises:
        TypeError: If `arr` uses an unsupported type.
    """
    raise TypeError(arr)