PyTorch backend

The PyTorch backend lets you embed unit-aware computations inside an existing PyTorch model. PyTorch's autograd graph is preserved through saiunit operations, so tensor.backward() works on a loss derived from a quantity.

Installation

pip install saiunit[torch]

The torch extra requires torch>=2.0. If you need a specific accelerator build (CPU-only or a particular CUDA version), install the matching torch wheel from PyTorch's own install matrix.

Graceful import

import saiunit as u

try:
    import torch
    HAVE_TORCH = True
except ImportError:
    HAVE_TORCH = False
    print('torch not installed; install with: pip install saiunit[torch]')
torch not installed; install with: pip install saiunit[torch]
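The try/except above only tells you whether torch can be imported. Because the torch extra requires torch>=2.0, a stricter guard might also check the installed version. The sketch below does this with the standard library alone. torch_available is a hypothetical helper and is not part of saiunit. Its two-component version parse assumes a conventional release string such as 2.1.0 or 2.1.0+cu121.

```python
import importlib.util
from importlib import metadata


def torch_available(minimum=(2, 0)):
    """Hypothetical helper: True if torch is installed and its version is >= minimum."""
    # find_spec looks for the package without importing it (importing torch is slow).
    if importlib.util.find_spec("torch") is None:
        return False
    try:
        version = metadata.version("torch")
    except metadata.PackageNotFoundError:
        return False
    # Drop local build tags such as "+cu121", then compare (major, minor).
    major, minor = (int(part) for part in version.split("+")[0].split(".")[:2])
    return (major, minor) >= minimum


HAVE_TORCH = torch_available()
print('torch >= 2.0 available:', HAVE_TORCH)
```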

Quick start

if HAVE_TORCH:
    q = u.Quantity(torch.tensor([1.0, 2.0, 3.0]), unit=u.meter)
    print(q)
    print('backend =', q.backend)
    print('(q + q).backend =', (q + q).backend)

Conversion

Quantity.to_torch(device=..., dtype=...) converts a quantity to the torch backend. Its dtype argument accepts either a torch dtype (torch.float32) or a numpy dtype (np.float32); saiunit translates the latter automatically.

if HAVE_TORCH:
    import numpy as np
    q_cpu = u.Quantity([1.0, 2.0], unit=u.meter)
    q_f32 = q_cpu.to_torch(dtype=np.float32)   # numpy dtype, translated to torch.float32
    print(q_f32.mantissa.dtype)
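One possible way to perform that numpy-to-torch translation is sketched below. as_torch_dtype is a hypothetical helper and is not saiunit's internal implementation. It normalises the argument with np.dtype and then lets torch.from_numpy report the matching torch dtype.

```python
import numpy as np

try:
    import torch
    HAVE_TORCH = True
except ImportError:
    HAVE_TORCH = False


def as_torch_dtype(dtype):
    """Hypothetical helper: map a torch dtype or a numpy dtype-like to a torch dtype."""
    if not HAVE_TORCH:
        raise ImportError('torch not installed; install with: pip install saiunit[torch]')
    if isinstance(dtype, torch.dtype):
        return dtype  # already a torch dtype: pass through unchanged
    # np.dtype accepts np.float32, 'float32', np.dtype('float32'), ...
    np_dtype = np.dtype(dtype)
    # A zero-length array is enough for torch.from_numpy to report the matching dtype.
    return torch.from_numpy(np.empty(0, dtype=np_dtype)).dtype


if HAVE_TORCH:
    print(as_torch_dtype(np.float32))     # torch.float32
    print(as_torch_dtype(torch.float64))  # torch.float64
```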

Gradients with PyTorch autograd

saiunit.autograd.grad is JAX-only. For PyTorch use torch.autograd.grad (or .backward()) on the mantissa of the result. Units propagate through saiunit operations even though the gradient itself is computed by torch.

if HAVE_TORCH:
    x = torch.tensor([1.0, 2.0], requires_grad=True)
    q = u.Quantity(x, unit=u.meter) ** 2     # area
    loss = q.mantissa.sum()
    grads, = torch.autograd.grad(loss, x)
    print('grads:', grads)                   # 2 * x
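You can check the gradient printed above by hand. Write the loss as the sum of the squared mantissas. Its derivative with respect to each input is 2x_i, and dimensional analysis gives the physical unit of that gradient. torch.autograd.grad returns only the numeric tensor, so attach the unit yourself if you need it.

```latex
L = \sum_i x_i^2, \qquad \frac{\partial L}{\partial x_i} = 2 x_i
\;\Rightarrow\; \nabla_x L \big|_{x = (1,\,2)} = (2,\,4),
\qquad \left[\frac{\partial L}{\partial x}\right] = \frac{\mathrm{m}^2}{\mathrm{m}} = \mathrm{m}
```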

Requesting the backend explicitly

Requesting the torch backend when torch is not installed raises saiunit.BackendError, with the same semantics as the other optional backends.

from saiunit import BackendError

if not HAVE_TORCH:
    try:
        with u.using_backend('torch'):
            u.Quantity([1.0], unit=u.meter)
    except BackendError as exc:
        print('expected without torch:', exc)
expected without torch: torch backend requested but torch is not installed. Install with: pip install saiunit[torch]
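If your code should run whether or not torch is installed, one option is to catch BackendError and fall back to the default backend. This is a sketch under that assumption. length_prefer_torch is a hypothetical helper. The sketch assumes the error is raised either when entering u.using_backend('torch') or when constructing a quantity inside it, as in the example above.

```python
try:
    import saiunit as u
    from saiunit import BackendError
    HAVE_SAIUNIT = True
except ImportError:
    HAVE_SAIUNIT = False


def length_prefer_torch(values):
    """Hypothetical helper: build a length Quantity on the torch backend if possible,
    otherwise fall back to saiunit's default backend."""
    try:
        with u.using_backend('torch'):
            return u.Quantity(values, unit=u.meter)
    except BackendError:
        # torch is missing: same values and unit, default backend.
        return u.Quantity(values, unit=u.meter)


if HAVE_SAIUNIT:
    q = length_prefer_torch([1.0, 2.0])
    print('backend =', q.backend)
```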

Limitations

  • saiunit.autograd, saiunit.lax, saiunit.sparse are JAX-only.

  • Convert torch-backed quantities with .to_jax() or .to_numpy() before using those modules.
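Here is a minimal sketch of that conversion step. It moves a torch-backed quantity to numpy before the data would be passed to a JAX-only module. By analogy with to_torch, it assumes that to_numpy() returns a Quantity that keeps its unit and exposes the converted array as .mantissa.

```python
try:
    import saiunit as u
    import torch
    HAVE_TORCH = True
except ImportError:
    HAVE_TORCH = False

if HAVE_TORCH:
    q_torch = u.Quantity(torch.tensor([1.0, 2.0]), unit=u.meter)
    # Leave the torch backend before handing data to JAX-only modules
    # such as saiunit.autograd, saiunit.lax or saiunit.sparse.
    # Assumption: to_numpy() returns a Quantity, mirroring to_torch().
    q_np = q_torch.to_numpy()
    print(q_np)
    print(type(q_np.mantissa))
```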