NumPy backend
The NumPy backend runs eagerly on the CPU and is always available, because numpy is
a core dependency of saiunit. Choose it for interop with the broader scientific Python
stack (scipy, pandas, sklearn, matplotlib) when JAX tracing would get in
the way.
Installation
Nothing extra is needed; NumPy ships with every saiunit install.
```bash
pip install saiunit
```
Quick start
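```python
import numpy as np
import saiunit as u

# A quantity whose mantissa is a NumPy array is NumPy-backed
q = u.Quantity(np.array([1.0, 2.0, 3.0]), unit=u.meter)
print(q)
print('backend =', q.backend)
# Arithmetic between NumPy-backed quantities stays on NumPy
print('(q + q).backend =', (q + q).backend)
```

```text
[1. 2. 3.] m
backend = numpy
(q + q).backend = numpy
```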
Math, linalg, FFT
All three subpackages (saiunit.math, saiunit.linalg and saiunit.fft) dispatch to NumPy when the quantity's mantissa is a NumPy array.
```python
x = u.Quantity(np.linspace(0.0, np.pi, 5), unit=u.UNITLESS)
u.math.sin(x)
```

```text
array([0.00000000e+00, 7.07106781e-01, 1.00000000e+00, 7.07106781e-01,
       1.22464680e-16])
```
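The dispatch rule can be pictured as a lookup on the mantissa's type. The sketch below is a toy and not saiunit's implementation. `pick_array_module` and `unit_sin` are hypothetical names. It also assumes that JAX array and tracer types are defined in modules whose names start with `jax`. NumPy arrays go to `numpy`, and JAX inputs would go to `jax.numpy`. The JAX import is done lazily so the sketch runs without JAX installed.

```python
import types

import numpy as np


def pick_array_module(mantissa) -> types.ModuleType:
    """Hypothetical dispatcher: choose the array library from the mantissa type."""
    if isinstance(mantissa, np.ndarray):
        return np
    # Assumption: JAX array/tracer types live in modules named "jax..." or "jaxlib..."
    if type(mantissa).__module__.startswith("jax"):
        import jax.numpy as jnp  # only imported when a JAX input shows up
        return jnp
    raise TypeError(f"unsupported mantissa type: {type(mantissa).__name__}")


def unit_sin(mantissa):
    """Hypothetical sin for dimensionless input, computed by the matching backend."""
    xp = pick_array_module(mantissa)
    return xp.sin(mantissa)


y = unit_sin(np.linspace(0.0, np.pi, 5))
print(type(y).__module__)  # numpy
```

The small nonzero last element in the output above (about 1.2e-16) is ordinary floating-point rounding of sin(pi). Plain NumPy gives the same value, which is what NumPy dispatch should produce.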
```python
A = u.Quantity(np.array([[1.0, 2.0], [3.0, 4.0]]), unit=u.meter)
u.linalg.norm(A)
```

```text
Quantity(5.477226, "m")
```
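The printed value is consistent with the Frobenius norm, which is what `np.linalg.norm` computes by default for a 2-D array. This derivation assumes `u.linalg.norm` follows the same default. Squaring each entry gives m², and taking the square root brings the unit back to m:

```latex
\lVert A \rVert_F = \sqrt{\sum_{i,j} a_{ij}^2}
  = \sqrt{(1^2 + 2^2 + 3^2 + 4^2)\,\mathrm{m}^2}
  = \sqrt{30}\,\mathrm{m} \approx 5.477226\,\mathrm{m}
```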
Setting NumPy as the default
Use the u.using_backend('numpy') context manager to make NumPy the default inside a block
of code. Within the block, even quantities built from Python lists get a NumPy mantissa.
```python
with u.using_backend('numpy'):
    # A plain Python list is converted with the active (NumPy) backend
    q = u.Quantity([1.0, 2.0], unit=u.meter)
    print(type(q.mantissa).__module__, q.backend)
```

```text
numpy numpy
```
NumPy ufunc interop
Standard NumPy ufuncs such as np.add accept NumPy-backed quantities. The result keeps its unit, and mixing incompatible dimensions raises UnitMismatchError.
```python
a = u.Quantity(np.array([1.0]), unit=u.meter)
b = u.Quantity(np.array([2.0]), unit=u.meter)
np.add(a, b)
```

```text
Quantity([3.], "m")
```
```python
from saiunit import UnitMismatchError

c = u.Quantity(np.array([1.0]), unit=u.second)
try:
    # Adding metres to seconds is dimensionally invalid
    np.add(a, c)
except UnitMismatchError as exc:
    print('expected:', exc)
```

```text
expected: Cannot convert to a unit with different dimensions. (units are s and m).
```
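NumPy lets array-like objects take over ufunc calls through the `__array_ufunc__` protocol. A unit-carrying type can use that hook to check units before it touches the raw data. The toy class below shows the idea. `ToyQuantity` is hypothetical, represents units as plain strings and only handles same-unit `np.add`/`np.subtract`. The sketch does not assume that this is how saiunit is implemented.

```python
import numpy as np


class ToyQuantity:
    """Hypothetical unit-carrying wrapper; units are plain strings."""

    def __init__(self, value, unit: str):
        self.value = np.asarray(value)
        self.unit = unit

    def __repr__(self):
        return f"ToyQuantity({self.value!r}, {self.unit!r})"

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        # NumPy calls this hook instead of applying the ufunc to raw data.
        # Only plain np.add / np.subtract calls between ToyQuantity operands
        # are handled; anything else defers by returning NotImplemented.
        if (method != "__call__" or kwargs
                or ufunc not in (np.add, np.subtract)
                or not all(isinstance(x, ToyQuantity) for x in inputs)):
            return NotImplemented
        units = {x.unit for x in inputs}
        if len(units) != 1:
            # Adding metres to seconds is dimensionally meaningless.
            raise ValueError(f"unit mismatch: {sorted(units)}")
        result = ufunc(*(x.value for x in inputs))
        return ToyQuantity(result, units.pop())


a = ToyQuantity([1.0], "m")
b = ToyQuantity([2.0], "m")
print(np.add(a, b))  # ToyQuantity(array([3.]), 'm')

try:
    np.add(a, ToyQuantity([1.0], "s"))
except ValueError as exc:
    print("expected:", exc)  # expected: unit mismatch: ['m', 's']
```

A real implementation also has to handle unit conversion, such as mm + m, and multiplicative ufuncs that combine units. This toy skips both.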
JAX-only subpackages
saiunit.lax, saiunit.autograd, and saiunit.sparse are built on JAX primitives,
so calling them on a NumPy-backed quantity raises BackendError. Convert the input
with q.to_jax() first.
```python
from saiunit import BackendError

q = u.Quantity(np.array([1.0, 2.0, 3.0]), unit=u.meter)
try:
    # lax operations reject NumPy-backed input
    u.lax.slice(q, (0,), (1,))
except BackendError as exc:
    print('expected:', exc)

# Converting to JAX first makes the call succeed
u.lax.slice(q.to_jax(), (0,), (1,))
```

```text
expected: saiunit.lax.slice requires the jax backend; got numpy-backed Quantity. Call .to_jax() on the input first.
Quantity([1.], "m")
```
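A guard like this can be written as a decorator that inspects the input type before running the JAX-only function. The sketch below is illustrative only. `BackendErrorSketch`, `requires_jax` and `first_element` are hypothetical names. It assumes JAX array and tracer types live in modules whose names start with `jax`. saiunit's real check may work differently.

```python
import functools

import numpy as np


class BackendErrorSketch(Exception):
    """Hypothetical stand-in for saiunit.BackendError."""


def requires_jax(func):
    """Hypothetical decorator: reject non-JAX inputs before calling func."""
    @functools.wraps(func)
    def wrapper(x, *args, **kwargs):
        # Assumption: JAX array/tracer types live in modules named "jax..." or "jaxlib..."
        module = type(x).__module__
        if not module.startswith("jax"):
            raise BackendErrorSketch(
                f"{func.__name__} requires the jax backend; "
                f"got {module}-backed input. Convert it to JAX first."
            )
        return func(x, *args, **kwargs)
    return wrapper


@requires_jax
def first_element(x):
    # Placeholder for a JAX-only primitive such as a lax operation.
    return x[:1]


try:
    first_element(np.array([1.0, 2.0, 3.0]))
except BackendErrorSketch as exc:
    print("expected:", exc)
```

Failing fast with an explicit message is more helpful than letting a NumPy array reach a JAX primitive, where the error would surface deeper in the call stack.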
See also
Backends overview: the supported backends and how one is selected.
JAX backend: the default backend.