JAX backend (default)
JAX is saiunit’s default array backend. Every install ships with it because jax is a required core dependency. The JAX backend is the only one that supports saiunit.autograd, saiunit.lax, and saiunit.sparse. It is also the only backend that participates in jax.jit, jax.vmap, and jax.pmap.
This notebook covers the JAX backend on its own. For how the backends fit together, see the Backends overview.
Installation
```shell
pip install saiunit             # core: pulls in jax + numpy
pip install "saiunit[cpu]"      # core + jax[cpu] CPU wheels
pip install "saiunit[cuda12]"   # core + jax[cuda12] for NVIDIA GPU
pip install "saiunit[tpu]"      # core + jax[tpu] for Google TPU
```

The quotes stop shells such as zsh from treating the square brackets as glob patterns.
Quick start
Wrapping a JAX array in a Quantity, or multiplying it by a unit, produces a JAX-backed Quantity. The .backend property reports this, and arithmetic between JAX-backed quantities stays on JAX.
```python
import jax
import jax.numpy as jnp
import saiunit as u

q = u.Quantity(jnp.array([1.0, 2.0, 3.0]), unit=u.meter)
print(q)
print('backend =', q.backend)
print('(q + q).backend =', (q + q).backend)
```

```
[1. 2. 3.] m
backend = jax
(q + q).backend = jax
```
Automatic differentiation
saiunit.autograd.grad is JAX-only and propagates units through the derivative. The derivative of x ** 3 with respect to a length is an area. Analytically it is 3 x ** 2, which at x = 3 m gives 27 m^2.
```python
# d(x**3)/dx = 3 * x**2, so at x = 3 m the gradient is 27 m^2
f = lambda x: x ** 3
x = 3.0 * u.meter
u.autograd.grad(f)(x)
```

```
Quantity(27., "m^2")
```
JIT and vmap
Quantities are registered as JAX pytrees, so they flow through jit and
vmap without manual unpacking. The compiled function sees the mantissa
as the leaf and the unit as static metadata.
```python
@jax.jit
def kinetic_energy(m, v):
    # E = 1/2 m v^2; the result's unit is derived from the input units
    return 0.5 * m * v ** 2

m = 1.5 * u.kgram
v = 4.0 * u.meter / u.second
kinetic_energy(m, v)
```

```
Quantity(12., "J")
```
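```python
# Speeds 0, 1, 2, 3, 4 m/s (the stop value 5 m/s is excluded)
speeds = u.math.arange(0.0 * u.meter / u.second,
                       5.0 * u.meter / u.second,
                       1.0 * u.meter / u.second)
# Map over the speeds with the mass held fixed
jax.vmap(lambda v: kinetic_energy(1.5 * u.kgram, v))(speeds)
```

```
Quantity([ 0. 0.75 3. 6.75 12. ], "J")
```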
Setting JAX as the default explicitly
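jax is already the fallback default, but you can also select it explicitly with u.using_backend. This is useful when you build a Quantity from a Python list, where there is no array yet to infer the backend from, and want to pin the backend.

```python
with u.using_backend('jax'):
    # A Python list carries no backend of its own; the context pins it to JAX
    q = u.Quantity([1.0, 2.0], unit=u.meter)
    print(type(q.mantissa).__module__, q.backend)
```

```
jaxlib._jax jax
```

The module name in the first column depends on the installed jaxlib version.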
Mixed backends
Mixing a JAX-backed quantity with one from another backend falls through to the default-backend tiebreaker, so by default the result lands on JAX. In the recorded output below, the result is also JAX-backed inside u.using_backend('numpy'). For the exact selection rules, see the Backends overview.

```python
import numpy as np

q_np = u.Quantity(np.array([1.0]), unit=u.meter)
q_jax = u.Quantity(jnp.array([2.0]), unit=u.meter)
print((q_np + q_jax).backend)  # default settings

with u.using_backend('numpy'):
    print((q_np + q_jax).backend)  # still 'jax' in the recorded output
```

```
jax
jax
```
See also
Backends overview: selection rules and capabilities.
saiunit.autograd, saiunit.lax, saiunit.sparse: JAX-only subpackages.