PyTorch backend
The PyTorch backend lets you embed unit-aware computations inside an
existing PyTorch model. PyTorch’s own autograd is preserved through
saiunit operations, so tensor.backward() works on a quantity-derived
loss.
Installation
pip install saiunit[torch]
The extra requires torch>=2.0. If you need a specific accelerator build (CPU-only or a particular CUDA version), pick the matching wheel from PyTorch's own install matrix.
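To confirm which torch build actually ended up in your environment, a small check like the one below can help. It is a sketch that uses only the standard library's importlib.metadata and torch.cuda.is_available(); the function name torch_install_report is hypothetical, and the version test is simplified to the major version (so a pre-release such as 2.0.0rc1 is counted as satisfying torch>=2.0).

```python
from importlib.metadata import PackageNotFoundError, version


def torch_install_report() -> dict:
    """Hypothetical helper: summarise the local torch install.

    Reports the installed version, whether it meets the torch>=2.0
    requirement (major-version check only), and whether CUDA is usable.
    """
    try:
        torch_version = version("torch")
    except PackageNotFoundError:
        return {"installed": False}

    # Local build tags such as "2.1.0+cu121" still start with the major version.
    major = int(torch_version.split(".")[0])

    import torch  # only imported once we know the distribution is present

    return {
        "installed": True,
        "version": torch_version,
        "meets_requirement": major >= 2,
        "cuda_available": torch.cuda.is_available(),
    }


print(torch_install_report())
```

If "cuda_available" is False on a machine with a GPU, the CPU-only wheel was most likely installed.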
Graceful import
import saiunit as u

try:
    import torch
    HAVE_TORCH = True
except ImportError:
    HAVE_TORCH = False
    print('torch not installed; install with: pip install saiunit[torch]')

Output without torch:
torch not installed; install with: pip install saiunit[torch]
Quick start
if HAVE_TORCH:
    q = u.Quantity(torch.tensor([1.0, 2.0, 3.0]), unit=u.meter)
    print(q)
    print('backend =', q.backend)
    print('(q + q).backend =', (q + q).backend)
Conversion
Quantity.to_torch(device=..., dtype=...) accepts either a torch dtype
(torch.float32) or a numpy dtype (np.float32) — saiunit translates the
latter automatically.
if HAVE_TORCH:
    import numpy as np

    q_cpu = u.Quantity([1.0, 2.0], unit=u.meter)
    q_f32 = q_cpu.to_torch(dtype=np.float32)  # NumPy dtype, translated to torch.float32
    print(q_f32.mantissa.dtype)
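saiunit's own dtype translation is not shown on this page. As an illustration of the idea only, the mapping from a NumPy dtype to the equivalent torch dtype can be delegated to torch itself by calling torch.from_numpy on an empty array. The helper numpy_dtype_to_torch is hypothetical and is not part of saiunit's API.

```python
import numpy as np

try:
    import torch
except ImportError:
    torch = None


def numpy_dtype_to_torch(dtype):
    """Hypothetical helper: return the torch dtype matching `dtype`.

    Torch dtypes pass through unchanged. For NumPy dtypes (or anything
    np.dtype accepts), torch.from_numpy on a zero-length array lets torch
    choose the equivalent dtype; unsupported dtypes raise TypeError.
    """
    if isinstance(dtype, torch.dtype):
        return dtype
    return torch.from_numpy(np.empty(0, dtype=np.dtype(dtype))).dtype


if torch is not None:
    print(numpy_dtype_to_torch(np.float32))     # torch.float32
    print(numpy_dtype_to_torch(torch.float64))  # torch.float64
```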
Gradients with PyTorch autograd
saiunit.autograd.grad is JAX-only. For PyTorch use torch.autograd.grad
(or .backward()) on the mantissa of the result. Units propagate through
saiunit operations even though the gradient itself is computed by torch.
if HAVE_TORCH:
    x = torch.tensor([1.0, 2.0], requires_grad=True)
    q = u.Quantity(x, unit=u.meter) ** 2  # area, unit m**2
    loss = q.mantissa.sum()               # plain torch scalar
    grads, = torch.autograd.grad(loss, x)
    print('grads:', grads)                # 2 * x
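The same gradient can be obtained with .backward(), which stores it in x.grad. Because torch differentiates the bare mantissa, the result is a plain tensor; the sketch below re-attaches the unit by dimensional analysis (d(area)/d(length) has unit m**2 / m = m). Re-wrapping the gradient with u.Quantity is a manual step chosen for this illustration, not something saiunit does automatically.

```python
try:
    import saiunit as u
    import torch
    HAVE_TORCH = True
except ImportError:
    HAVE_TORCH = False

if HAVE_TORCH:
    # Leaf tensor tracked by torch: the mantissa of a length in metres.
    x = torch.tensor([1.0, 2.0], requires_grad=True)
    area = u.Quantity(x, unit=u.meter) ** 2  # unit: m**2

    # backward() on the scalar mantissa fills x.grad with d(sum)/dx = 2 * x.
    area.mantissa.sum().backward()

    # torch returns a unitless tensor; the derivative of m**2 w.r.t. m is in m,
    # so the unit is re-attached by hand.
    grad_q = u.Quantity(x.grad, unit=u.meter)
    print(x.grad)
    print(grad_q)
```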
Requesting the backend explicitly
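Requesting the torch backend with u.using_backend('torch') when torch is not installed raises BackendError, with the same semantics as the other optional backends.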
from saiunit import BackendError

if not HAVE_TORCH:
    try:
        with u.using_backend('torch'):
            u.Quantity([1.0], unit=u.meter)
    except BackendError as exc:
        print('expected without torch:', exc)

Output without torch:
expected without torch: torch backend requested but torch is not installed. Install with: pip install saiunit[torch]
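When code should use the torch backend if it is available and otherwise fall back to saiunit's default backend, the BackendError check can be wrapped in a small helper. The names torch_backend_usable and backend_context are hypothetical; the sketch assumes only u.using_backend, u.Quantity and BackendError as used in this section, plus contextlib.nullcontext from the standard library.

```python
from contextlib import nullcontext

try:
    import saiunit as u
    from saiunit import BackendError
except ImportError:  # saiunit itself is missing
    u = None


def torch_backend_usable() -> bool:
    """Hypothetical helper: True if saiunit can build a quantity on the torch backend."""
    if u is None:
        return False
    try:
        with u.using_backend('torch'):
            u.Quantity([1.0], unit=u.meter)
    except BackendError:
        return False
    return True


def backend_context():
    """Hypothetical helper: the torch backend if usable, else a no-op context."""
    return u.using_backend('torch') if torch_backend_usable() else nullcontext()


if u is not None:
    with backend_context():
        q = u.Quantity([1.0, 2.0], unit=u.meter)
        print(q)
```

Probing once and falling back to nullcontext() keeps the calling code identical whether or not torch is installed.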
Limitations
saiunit.autograd, saiunit.lax, and saiunit.sparse are JAX-only. Convert tensors with .to_jax() or .to_numpy() when you need those modules.