CuPy backend

CuPy is a near drop-in replacement for NumPy that runs on NVIDIA GPUs via CUDA. Use it when you want GPU acceleration for array-API operations and you don’t need JAX autodiff/JIT.

Installation

pip install "saiunit[cupy]"

Requires a working CUDA toolkit; the extra pulls in the cupy-cuda12x wheel. If you have CUDA 11, install saiunit without the extra and then install cupy-cuda11x manually, so that two CuPy wheels are not installed side by side.
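If you are unsure which wheel matches your machine, an installed CuPy can report the CUDA runtime version it sees and the highest CUDA version the driver supports. The sketch below calls cupy.cuda.runtime.runtimeGetVersion() and driverGetVersion(), which return integers encoded as 1000 * major + 10 * minor (for example 12020 for CUDA 12.2). The helper decode_cuda_version is a hypothetical name, not part of saiunit or CuPy.

```python
def decode_cuda_version(v: int) -> tuple[int, int]:
    """Split a CUDA version integer (1000 * major + 10 * minor) into (major, minor)."""
    return v // 1000, (v % 1000) // 10


try:
    import cupy
except ImportError:
    cupy = None

if cupy is not None:
    try:
        rt = cupy.cuda.runtime.runtimeGetVersion()   # CUDA runtime version seen by CuPy
        drv = cupy.cuda.runtime.driverGetVersion()   # highest CUDA version the driver supports
        print('runtime', decode_cuda_version(rt), 'driver', decode_cuda_version(drv))
    except Exception as exc:  # e.g. no NVIDIA driver on this machine
        print('CuPy is installed but CUDA is unavailable:', exc)
else:
    print('cupy not installed')
```

For a fuller report (CUDA paths, library versions), CuPy also provides cupy.show_config().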

Graceful import

If CuPy isn’t installed (as on most CI runners and on laptops without an NVIDIA GPU), the snippets below skip cleanly instead of crashing. The GPU snippets are guarded by the HAVE_CUPY flag set here.

import saiunit as u

try:
    import cupy
    HAVE_CUPY = True
except ImportError:
    HAVE_CUPY = False
    print('cupy not installed; install with: pip install saiunit[cupy]')

print('is_cupy_array on a non-cupy object:', u.is_cupy_array([1, 2, 3]))
cupy not installed; install with: pip install saiunit[cupy]
is_cupy_array on a non-cupy object: False
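A successful import cupy does not necessarily mean a GPU is usable: CuPy may import on a machine where no CUDA device is visible, and the failure then surfaces only at the first device call. A stricter flag, sketched below, also asks the CUDA runtime for a device count. gpu_available is a hypothetical helper, not part of saiunit.

```python
def gpu_available() -> bool:
    """Return True only if CuPy imports and at least one CUDA device is visible."""
    try:
        import cupy
    except ImportError:
        return False
    try:
        return cupy.cuda.runtime.getDeviceCount() > 0
    except Exception:  # CUDA runtime error, e.g. no driver or no device
        return False


HAVE_GPU = gpu_available()
print('usable CUDA device:', HAVE_GPU)
```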

Quick start

if HAVE_CUPY:
    q = u.Quantity(cupy.array([1.0, 2.0, 3.0]), unit=u.meter)
    print(q)
    print('backend =', q.backend)
    print('(q + q).backend =', (q + q).backend)

Math operations

For CuPy-backed quantities, saiunit.math dispatches to array_api_compat.cupy, so the computation runs on the GPU.

if HAVE_CUPY:
    x = u.Quantity(cupy.linspace(0.0, cupy.pi, 5), unit=u.UNITLESS)
    print(u.math.sin(x))
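The same dispatch idea can be shown in miniature on plain arrays. The sketch below swaps in CuPy's own cupy.get_array_module, which returns the cupy module when any argument is a CuPy array and numpy otherwise. It illustrates the pattern only and is not how saiunit.math is implemented (that goes through array_api_compat, as stated above); array_module and xp_sin are hypothetical helpers that work on raw arrays, not Quantities.

```python
import numpy as np

try:
    import cupy
except ImportError:
    cupy = None


def array_module(*arrays):
    """Return cupy if any argument is a CuPy array, otherwise numpy."""
    if cupy is None:
        return np
    return cupy.get_array_module(*arrays)


def xp_sin(x):
    """Compute sin on whichever device holds `x`."""
    xp = array_module(x)
    return xp.sin(x)


print(xp_sin(np.array([0.0, np.pi / 2])))   # NumPy input, so this runs on the CPU
```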

Converting between backends

Quantity.to_cupy(device=...) moves the mantissa to the chosen GPU, and Quantity.to_numpy() copies it back to host memory.

if HAVE_CUPY:
    import numpy as np
    q_cpu = u.Quantity(np.array([1.0, 2.0]), unit=u.meter)
    q_gpu = q_cpu.to_cupy(device=0)
    print('mantissa lives on device', q_gpu.mantissa.device)
    # round-trip back to NumPy
    print(q_gpu.to_numpy())
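In plain CuPy terms, placing data on a particular GPU means allocating it while that device is current, and bringing it back is an explicit device-to-host copy. The sketch below does this with cupy.cuda.Device, cupy.asarray and cupy.asnumpy. It only illustrates what a round trip like to_cupy(device=0) followed by to_numpy() has to accomplish and is not saiunit's implementation; roundtrip is a hypothetical name, and it returns its input unchanged when CuPy is absent.

```python
import numpy as np

try:
    import cupy
except ImportError:
    cupy = None


def roundtrip(host: np.ndarray, device: int = 0) -> np.ndarray:
    """Copy `host` to GPU `device` and back; return it unchanged without CuPy."""
    if cupy is None:
        return host
    with cupy.cuda.Device(device):      # make `device` current for the allocation
        on_gpu = cupy.asarray(host)     # host -> device copy
    print('allocated on', on_gpu.device)
    return cupy.asnumpy(on_gpu)         # device -> host copy, yields a numpy.ndarray


data = np.array([1.0, 2.0])
print(np.array_equal(roundtrip(data), data))
```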

Requesting the backend explicitly

If you ask for the CuPy backend when CuPy isn’t installed, saiunit raises BackendError, with the install hint in its message, rather than a bare ImportError. When CuPy is installed, the block below runs without printing anything.

from saiunit import BackendError

try:
    with u.using_backend('cupy'):
        u.Quantity([1.0, 2.0], unit=u.meter)
except BackendError as exc:
    print('expected without cupy:', exc)
expected without cupy: cupy backend requested but cupy is not installed. Install with: pip install saiunit[cupy]
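The design choice here is to translate a missing optional dependency into a library-specific error that says how to fix it. A minimal sketch of that pattern follows. The BackendError class below is a local stand-in (its RuntimeError base is an assumption, not saiunit's definition), and require_backend is a hypothetical helper.

```python
import importlib
from types import ModuleType


class BackendError(RuntimeError):
    """Stand-in for saiunit.BackendError; the real base class may differ."""


def require_backend(name: str, module: str, extra: str) -> ModuleType:
    """Import `module` for backend `name`, or raise BackendError with an install hint."""
    try:
        return importlib.import_module(module)
    except ImportError as exc:
        raise BackendError(
            f'{name} backend requested but {module} is not installed. '
            f'Install with: pip install saiunit[{extra}]'
        ) from exc  # keep the original ImportError as __cause__ for debugging


try:
    require_backend('example', 'no_such_module_xyz', 'example')
except BackendError as exc:
    print(exc)
```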

Limitations

  • CuPy has no autograd. saiunit.autograd is JAX-only.

  • saiunit.lax and saiunit.sparse are JAX-only.

  • To use those modules, convert with .to_jax(); use .to_numpy() when you only need the data back on the CPU.
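Moving data off the GPU has to be explicit: CuPy deliberately refuses implicit device-to-host conversion, so np.asarray on a cupy.ndarray raises TypeError instead of silently copying. A minimal sketch of an explicit host copy for plain arrays follows; to_host is a hypothetical helper, not part of saiunit.

```python
import numpy as np

try:
    import cupy
except ImportError:
    cupy = None


def to_host(arr) -> np.ndarray:
    """Return `arr` as a NumPy array, copying from GPU memory if needed."""
    if cupy is not None and isinstance(arr, cupy.ndarray):
        # np.asarray(arr) would raise TypeError here: CuPy forbids implicit
        # conversion, so request the device-to-host copy explicitly.
        return cupy.asnumpy(arr)
    return np.asarray(arr)


print(to_host([1.0, 2.0]))
```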