CuPy backend
CuPy is a near drop-in replacement for NumPy that runs on NVIDIA GPUs via CUDA. Use it when you want GPU acceleration for array-API operations and you don't need JAX's automatic differentiation or JIT compilation.
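Because CuPy mirrors the NumPy API, one function body can often serve both backends. A common CuPy idiom for this is cupy.get_array_module, which returns the cupy module for CuPy arrays and numpy otherwise. The sketch below is plain CuPy/NumPy code, not saiunit's internal dispatch; array_module and rms are hypothetical helper names, and the NumPy fallback only exists so the snippet runs without CuPy.

```python
import numpy as np

try:
    import cupy
except ImportError:  # no CuPy: everything runs on NumPy
    cupy = None


def array_module(*arrays):
    """Return cupy for CuPy inputs, numpy otherwise (hypothetical helper)."""
    if cupy is not None:
        return cupy.get_array_module(*arrays)
    return np


def rms(x):
    """Root-mean-square written once, executed by whichever backend owns x."""
    xp = array_module(x)
    return xp.sqrt(xp.mean(x * x))


print(rms(np.array([3.0, 4.0])))
```

Passing a CuPy array to rms would run the same arithmetic on the GPU without any code change.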
Installation
```shell
pip install saiunit[cupy]
```

This requires a working CUDA toolkit; the extra pulls in the cupy-cuda12x wheel. If you have CUDA 11, skip the extra and install cupy-cuda11x manually instead (pip install saiunit cupy-cuda11x).
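The wheel choice above depends only on the CUDA major version, for example the release reported by nvcc --version. The helpers below encode that rule as a small function; cupy_wheel_for and install_command are hypothetical names, not part of saiunit, and they deliberately cover only the CUDA 11 and 12 cases this page discusses.

```python
def cupy_wheel_for(cuda_version: str) -> str:
    """Map a CUDA version such as '12.4' or '11.8' to the matching CuPy wheel name.

    Hypothetical helper: only the CUDA 11 and 12 wheels from this page are covered.
    """
    major = int(cuda_version.strip().split(".")[0])
    if major in (11, 12):
        return f"cupy-cuda{major}x"
    raise ValueError(f"no CuPy wheel mapping for CUDA {cuda_version!r} in this sketch")


def install_command(cuda_version: str) -> str:
    """Build the pip command suggested above for a given CUDA version."""
    wheel = cupy_wheel_for(cuda_version)
    if wheel == "cupy-cuda12x":
        return "pip install saiunit[cupy]"  # the extra already pulls in cupy-cuda12x
    return f"pip install saiunit {wheel}"  # CUDA 11: skip the extra, add the wheel by hand


print(install_command("11.8"))
```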
Graceful import
If CuPy isn't installed (as on most CI runners and on laptops without an NVIDIA GPU), the snippets below skip their GPU work instead of crashing.
```python
import saiunit as u

try:
    import cupy
    HAVE_CUPY = True
except ImportError:
    HAVE_CUPY = False
    print('cupy not installed; install with: pip install saiunit[cupy]')

print('is_cupy_array on a non-cupy object:', u.is_cupy_array([1, 2, 3]))
```

Output without CuPy:

```
cupy not installed; install with: pip install saiunit[cupy]
is_cupy_array on a non-cupy object: False
```
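The try/except pattern imports CuPy eagerly. If you only need to know whether CuPy can be found, for instance to decide whether to skip GPU tests in CI, the standard library's importlib.util.find_spec checks without importing the module. This is a cheaper but weaker test: a module that is found can still fail at import or first use if the CUDA libraries are missing. module_available is a hypothetical helper name.

```python
import importlib.util


def module_available(name: str) -> bool:
    """True if top-level module `name` can be located; the module is not imported."""
    return importlib.util.find_spec(name) is not None


HAVE_CUPY_SPEC = module_available("cupy")
print("cupy locatable:", HAVE_CUPY_SPEC)
```

With pytest, pytest.importorskip("cupy") combines the availability check and the skip in a single call.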
Quick start
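```python
if HAVE_CUPY:
    # wrap a CuPy array; the mantissa stays on the GPU
    q = u.Quantity(cupy.array([1.0, 2.0, 3.0]), unit=u.meter)
    print(q)
    print('backend =', q.backend)
    # check which backend the result of an arithmetic operation reports
    print('(q + q).backend =', (q + q).backend)
```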
Math operations
For CuPy-backed quantities, the functions in saiunit.math dispatch to array_api_compat.cupy, so the computation executes on the GPU.
```python
if HAVE_CUPY:
    # trigonometric functions need a dimensionless argument, hence u.UNITLESS
    x = u.Quantity(cupy.linspace(0.0, cupy.pi, 5), unit=u.UNITLESS)
    print(u.math.sin(x))
```
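The dispatch idea can be illustrated with array_api_compat.array_namespace, which returns the compat namespace matching the arrays it receives (array_api_compat.cupy for CuPy arrays, array_api_compat.numpy for NumPy arrays). This is a sketch of the mechanism, not saiunit's internal code; namespace_of and sin_of_mantissa are hypothetical names, and the NumPy fallback is only there so the snippet runs without array-api-compat installed.

```python
import numpy as np

try:
    from array_api_compat import array_namespace
except ImportError:  # sketch-only fallback when array-api-compat is not installed
    array_namespace = None


def namespace_of(x):
    """Pick the namespace that should execute operations on x."""
    if array_namespace is not None:
        return array_namespace(x)  # e.g. array_api_compat.cupy for a CuPy array
    return np


def sin_of_mantissa(mantissa):
    """Apply sin using whichever backend owns the mantissa."""
    xp = namespace_of(mantissa)
    return xp.sin(mantissa)  # runs on the GPU when mantissa is a CuPy array


print(sin_of_mantissa(np.linspace(0.0, np.pi, 5)))
```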
Converting between backends
Quantity.to_cupy(device=...) moves the mantissa to the chosen GPU.
```python
import numpy as np

if HAVE_CUPY:
    q_cpu = u.Quantity(np.array([1.0, 2.0]), unit=u.meter)
    q_gpu = q_cpu.to_cupy(device=0)  # move the mantissa onto GPU 0
    print('mantissa lives on device', q_gpu.mantissa.device)
    # round-trip back to NumPy
    print(q_gpu.to_numpy())
```
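At the plain CuPy level, a device move looks roughly like the sketch below: cupy.asarray inside a cupy.cuda.Device context allocates on that GPU, and cupy.asnumpy copies the data back to host memory. This is an assumption about what such a move amounts to, not saiunit's to_cupy implementation; host_to_device and device_to_host are hypothetical helper names.

```python
import numpy as np

try:
    import cupy
except ImportError:
    cupy = None


def host_to_device(host_array, device=0):
    """Copy a NumPy array onto GPU `device` (hypothetical helper)."""
    if cupy is None:
        raise RuntimeError("cupy is not installed")
    with cupy.cuda.Device(device):  # allocations inside this block go to `device`
        return cupy.asarray(host_array)


def device_to_host(device_array):
    """Copy a CuPy array back into a NumPy array in host memory."""
    return cupy.asnumpy(device_array)


if cupy is not None:
    x_gpu = host_to_device(np.array([1.0, 2.0]), device=0)
    print(x_gpu.device)           # the CUDA device holding the data
    print(device_to_host(x_gpu))  # back to a NumPy array
```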
Requesting the backend explicitly
If you request the CuPy backend while CuPy isn't installed, saiunit raises BackendError (not a bare ImportError), and its message includes the install hint.
```python
from saiunit import BackendError

try:
    with u.using_backend('cupy'):
        u.Quantity([1.0, 2.0], unit=u.meter)
except BackendError as exc:
    print('expected without cupy:', exc)
```

Output without CuPy:

```
expected without cupy: cupy backend requested but cupy is not installed. Install with: pip install saiunit[cupy]
```
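Wrapping the import failure in a library-specific exception lets callers catch one error type and still get an actionable message. A minimal sketch of that general pattern follows; ToyBackendError and require_backend are hypothetical stand-ins, not saiunit's code, and the original ImportError is kept as the exception's cause via raise ... from.

```python
import importlib


class ToyBackendError(RuntimeError):
    """Hypothetical stand-in for a library-specific backend error."""


def require_backend(name: str, extra: str):
    """Import backend module `name`, or raise ToyBackendError with an install hint."""
    try:
        return importlib.import_module(name)
    except ImportError as exc:
        raise ToyBackendError(
            f"{name} backend requested but {name} is not installed. "
            f"Install with: pip install saiunit[{extra}]"
        ) from exc  # keep the original ImportError as __cause__ for debugging


try:
    require_backend("cupy", extra="cupy")
    print("cupy backend available")
except ToyBackendError as exc:
    print("expected without cupy:", exc)
```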
Limitations
- CuPy has no autograd; saiunit.autograd is JAX-only.
- saiunit.lax and saiunit.sparse are also JAX-only.

For these features, move the data to NumPy or JAX with .to_numpy() or .to_jax() first.