NumPy backend

The NumPy backend runs eagerly on the CPU and is always available, because NumPy is a core dependency of saiunit. Pick it for interop with the broader scientific Python stack (scipy, pandas, sklearn, matplotlib) when JAX tracing would get in the way.

Installation

Nothing extra is needed; NumPy ships with every saiunit install.

pip install saiunit

Quick start

import numpy as np
import saiunit as u

q = u.Quantity(np.array([1.0, 2.0, 3.0]), unit=u.meter)
print(q)
print('backend =', q.backend)
print('(q + q).backend =', (q + q).backend)
[1. 2. 3.] m
backend = numpy
(q + q).backend = numpy

Math, linalg, FFT

The saiunit.math, saiunit.linalg, and saiunit.fft subpackages all dispatch to NumPy when a quantity's mantissa is a NumPy array.

x = u.Quantity(np.linspace(0.0, np.pi, 5), unit=u.UNITLESS)
u.math.sin(x)
array([0.00000000e+00, 7.07106781e-01, 1.00000000e+00, 7.07106781e-01,
       1.22464680e-16])
A = u.Quantity(np.array([[1.0, 2.0], [3.0, 4.0]]), unit=u.meter)
u.linalg.norm(A)
Quantity(5.477226, "m")
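For the 2 × 2 matrix above, the result equals the Frobenius norm, sqrt(1 + 4 + 9 + 16) = sqrt(30) ≈ 5.477226, and it keeps the unit of the input. The sketch below reproduces that number with plain NumPy while mimicking dispatch on the mantissa type. `frobenius_norm` is a hypothetical helper written for this illustration, not saiunit code, and the unit is just a string.

```python
import numpy as np


def frobenius_norm(mantissa, unit):
    """Hypothetical helper: norm of a (mantissa, unit) pair.

    Dispatches on the mantissa type: a NumPy array is handled by NumPy,
    anything else is rejected in this sketch.
    """
    if not isinstance(mantissa, np.ndarray):
        raise TypeError(f"expected a NumPy array, got {type(mantissa).__name__}")
    # For a 2-D array, np.linalg.norm with the default ord is the Frobenius norm.
    value = float(np.linalg.norm(mantissa))
    # A square root of a sum of squares has the same unit as the entries.
    return value, unit


A = np.array([[1.0, 2.0], [3.0, 4.0]])
value, unit = frobenius_norm(A, "m")
print(f"{value:.6f} {unit}")  # 5.477226 m
```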

Setting NumPy as the default

Wrap a block in u.using_backend('numpy') to make NumPy the default backend inside it; quantities built from plain Python lists then get NumPy mantissas too.

with u.using_backend('numpy'):
    q = u.Quantity([1.0, 2.0], unit=u.meter)
    print(type(q.mantissa).__module__, q.backend)
numpy numpy

NumPy ufunc interop

Standard NumPy ufuncs such as np.add accept NumPy-backed quantities, preserve their units, and raise UnitMismatchError when the operands' dimensions are inconsistent.

a = u.Quantity(np.array([1.0]), unit=u.meter)
b = u.Quantity(np.array([2.0]), unit=u.meter)
np.add(a, b)
Quantity([3.], "m")
from saiunit import UnitMismatchError

c = u.Quantity(np.array([1.0]), unit=u.second)
try:
    np.add(a, c)
except UnitMismatchError as exc:
    print('expected:', exc)
expected: Cannot convert to a unit with different dimensions. (units are s and m).
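NumPy lets an array-like class intercept ufunc calls through the `__array_ufunc__` protocol, which is one way a quantity type can carry units through `np.add`. The toy class below is a hypothetical illustration, not saiunit's implementation: `ToyQuantity` stores units as plain strings, handles only `np.add` and `np.subtract`, and raises `ValueError` (not saiunit's `UnitMismatchError`) when units differ.

```python
import numpy as np


class ToyQuantity:
    """Minimal array-plus-unit container; units are plain strings here."""

    def __init__(self, mantissa, unit):
        self.mantissa = np.asarray(mantissa)
        self.unit = unit

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        # Support only plain calls of np.add / np.subtract without out=.
        if method != "__call__" or ufunc not in (np.add, np.subtract) or "out" in kwargs:
            return NotImplemented
        # Mixing with bare arrays or scalars is not handled in this sketch.
        if not all(isinstance(x, ToyQuantity) for x in inputs):
            return NotImplemented
        units = {x.unit for x in inputs}
        if len(units) != 1:
            raise ValueError(f"unit mismatch in {ufunc.__name__}: {sorted(units)}")
        # Apply the ufunc to the bare arrays, then reattach the shared unit.
        result = ufunc(*(x.mantissa for x in inputs), **kwargs)
        return ToyQuantity(result, units.pop())

    def __repr__(self):
        return f"ToyQuantity({self.mantissa!r}, {self.unit!r})"


a = ToyQuantity([1.0], "m")
b = ToyQuantity([2.0], "m")
c = ToyQuantity([1.0], "s")
print(np.add(a, b))  # ToyQuantity(array([3.]), 'm')
try:
    np.add(a, c)
except ValueError as exc:
    print("expected:", exc)  # expected: unit mismatch in add: ['m', 's']
```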

JAX-only subpackages

saiunit.lax, saiunit.autograd, and saiunit.sparse are built on JAX primitives, so calling them on a NumPy-backed quantity raises BackendError. Convert the input with q.to_jax() first.

from saiunit import BackendError

q = u.Quantity(np.array([1.0, 2.0, 3.0]), unit=u.meter)
try:
    u.lax.slice(q, (0,), (1,))
except BackendError as exc:
    print('expected:', exc)

u.lax.slice(q.to_jax(), (0,), (1,))
expected: saiunit.lax.slice requires the jax backend; got numpy-backed Quantity. Call .to_jax() on the input first.
Quantity([1.], "m")
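A guard like this can be a small decorator that inspects the input before any JAX primitive runs. The sketch below is hypothetical: `ToyBackendError`, `requires_jax`, and `toy_slice` are invented names, the backend is inferred from whether the argument is a NumPy array, and saiunit's real check and message may differ.

```python
import functools

import numpy as np


class ToyBackendError(TypeError):
    """Raised when a JAX-only function receives a NumPy array."""


def requires_jax(fn):
    """Reject NumPy arrays before fn runs and tell the caller how to fix it."""

    @functools.wraps(fn)
    def wrapper(x, *args, **kwargs):
        if isinstance(x, np.ndarray):
            raise ToyBackendError(
                f"{fn.__name__} requires the jax backend; got a numpy array. "
                "Convert the input to a JAX array first."
            )
        return fn(x, *args, **kwargs)

    return wrapper


@requires_jax
def toy_slice(x, start_indices, limit_indices):
    from jax import lax  # reached only for non-NumPy inputs
    return lax.slice(x, start_indices, limit_indices)


try:
    toy_slice(np.array([1.0, 2.0, 3.0]), (0,), (1,))
except ToyBackendError as exc:
    print("expected:", exc)
```

Failing fast in the wrapper gives a clear message instead of an error from deep inside JAX; subclassing `TypeError` is a choice made for this sketch only.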

See also