JAX backend (default)

JAX is saiunit’s default array backend. Every install ships with it because jax is a required core dependency. The JAX backend is the only one that supports saiunit.autograd, saiunit.lax, and saiunit.sparse, and it is the only backend that participates in jax.jit / jax.vmap / jax.pmap.

This notebook shows the JAX backend in isolation. For how saiunit selects among multiple backends, see the Backends overview.

Installation

pip install saiunit            # core: pulls in jax + numpy
pip install "saiunit[cpu]"     # core + jax[cpu] CPU wheels (quotes keep zsh from globbing the brackets)
pip install "saiunit[cuda12]"  # core + jax[cuda12] for NVIDIA GPU
pip install "saiunit[tpu]"     # core + jax[tpu] for Google TPU
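
After installing one of the extras, it can help to confirm which accelerator platform JAX itself picked up. The following is a minimal sketch that uses only standard JAX calls. Note that `jax.default_backend()` reports JAX's hardware platform, which is a separate notion from saiunit's array backend; the variable names are illustrative.

```python
import jax

# Platform JAX will run on by default, e.g. 'cpu', 'gpu', or 'tpu'.
platform = jax.default_backend()

# Devices visible to JAX on that platform (at least one CPU device on a plain install).
devices = jax.devices()

print("jax version:", jax.__version__)
print("platform   :", platform)
print("devices    :", devices)
```

If the platform is `cpu` after installing a GPU or TPU extra, the accelerator wheels or drivers were most likely not picked up.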

Quick start

Wrapping a JAX array in a Quantity, or multiplying it by a unit, produces a JAX-backed Quantity. The .backend property reports which backend holds the data, and results of arithmetic keep that backend.

import jax
import jax.numpy as jnp
import saiunit as u

q = u.Quantity(jnp.array([1.0, 2.0, 3.0]), unit=u.meter)
print(q)
print('backend =', q.backend)
print('(q + q).backend =', (q + q).backend)
[1. 2. 3.] m
backend = jax
(q + q).backend = jax

Automatic differentiation

saiunit.autograd.grad is JAX-only and propagates units through the derivative. The derivative of x ** 3 with respect to a length x is 3 * x ** 2, which is an area; at x = 3 m it evaluates to 27 m^2.

f = lambda x: x ** 3
x = 3.0 * u.meter
u.autograd.grad(f)(x)
Quantity(27., "m^2")
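
The unit bookkeeping behind this result can be sketched in plain JAX: differentiate the bare number, then divide the output unit by the input unit. The `Length` class and `toy_grad` helper below are hypothetical illustrations written for this page, not saiunit's implementation. They track only a metre exponent.

```python
import jax


class Length:
    """Toy quantity: `value` times metre**dim. Hypothetical, not part of saiunit."""

    def __init__(self, value, dim):
        self.value = value  # bare number (or a JAX tracer during differentiation)
        self.dim = dim      # integer exponent of the metre

    def __pow__(self, n):
        # Raising to a power scales the unit exponent by the same power.
        return Length(self.value ** n, self.dim * n)


def toy_grad(f):
    """Return df/dx for f: Length -> Length, with unit(out) / unit(in) attached."""

    def df(x):
        out_dim = f(x).dim  # unit exponent of f(x)
        # Differentiate only the numeric part with jax.grad.
        dvalue = jax.grad(lambda v: f(Length(v, x.dim)).value)(x.value)
        # Dividing by the input unit subtracts its exponent.
        return Length(dvalue, out_dim - x.dim)

    return df


d = toy_grad(lambda x: x ** 3)(Length(3.0, 1))
print(float(d.value), d.dim)  # value 27.0 with metre exponent 2, i.e. 27 m^2
```

The same rule explains the saiunit result above: m^3 divided by m leaves m^2.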

JIT and vmap

Quantities are registered as JAX pytrees, so they flow through jit and vmap without manual unpacking. The compiled function sees the mantissa as the leaf and the unit as static metadata.

@jax.jit
def kinetic_energy(m, v):
    return 0.5 * m * v ** 2

m = 1.5 * u.kgram
v = 4.0 * u.meter / u.second
kinetic_energy(m, v)
Quantity(12., "J")

speeds = u.math.arange(0.0 * u.meter / u.second,
                       5.0 * u.meter / u.second,
                       1.0 * u.meter / u.second)
jax.vmap(lambda v: kinetic_energy(1.5 * u.kgram, v))(speeds)
Quantity([ 0.    0.75  3.    6.75 12.  ], "J")
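
To make the "mantissa as leaf, unit as static metadata" idea concrete, here is a minimal sketch of a pytree-registered container in plain JAX. `ToyQuantity` and `double` are hypothetical names used for illustration only. saiunit's real Quantity is considerably richer, and its registration details may differ.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class, tree_flatten


@register_pytree_node_class
class ToyQuantity:
    """Hypothetical stand-in for a unit-carrying array; not saiunit's Quantity."""

    def __init__(self, mantissa, unit):
        self.mantissa = mantissa  # array data: traced by jit / vmap
        self.unit = unit          # hashable string: treated as static metadata

    def tree_flatten(self):
        # Children (leaves) first, then auxiliary static data.
        return (self.mantissa,), self.unit

    @classmethod
    def tree_unflatten(cls, unit, children):
        return cls(children[0], unit)


@jax.jit
def double(q):
    # Inside jit, q.mantissa is a tracer while q.unit is an ordinary Python string.
    return ToyQuantity(2.0 * q.mantissa, q.unit)


q = ToyQuantity(jnp.array([1.0, 2.0, 3.0]), "m")
leaves, treedef = tree_flatten(q)  # one leaf: the mantissa array
doubled = double(q)                # jit rebuilds the container on the way out
mapped = jax.vmap(double)(q)       # vmap maps over axis 0 of the leaf
print(len(leaves), doubled.unit, doubled.mantissa, mapped.mantissa)
```

Because the unit lives in the auxiliary data, calling a jitted function with a different unit triggers a retrace in this sketch, while changing only the numbers reuses the compiled code.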

Setting JAX as the default explicitly

JAX is already the fallback default, but you can select it explicitly. This is useful when you build a Quantity from a Python list, where no array exists yet to indicate a backend, and you want to pin the result to JAX.

with u.using_backend('jax'):
    q = u.Quantity([1.0, 2.0], unit=u.meter)
    print(type(q.mantissa).__module__, q.backend)
jaxlib._jax jax

Mixed backends

Mixing a JAX-backed quantity with a quantity from another backend falls through to the default-backend tiebreaker. By default, the result lands on JAX.

import numpy as np

q_np  = u.Quantity(np.array([1.0]), unit=u.meter)
q_jax = u.Quantity(jnp.array([2.0]), unit=u.meter)
print((q_np + q_jax).backend)         # 'jax' (default tiebreaker)

with u.using_backend('numpy'):
    print((q_np + q_jax).backend)     # intended 'numpy'; the recorded run below printed 'jax'
jax
jax

See also

  • Backends overview — selection rules and capabilities.

  • saiunit.autograd, saiunit.lax, saiunit.sparse — JAX-only subpackages.