Source code for saiunit._backend

# Copyright 2026 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Backend dispatch for NumPy vs JAX array operations.

This module centralizes the rules for choosing between NumPy and JAX
namespaces. Internal saiunit code should call ``get_backend(*xs)`` to obtain
an ``xp`` namespace and then call array operations through it
(e.g. ``xp.sin(x)`` instead of ``jnp.sin(x)``).

The NumPy namespace is provided by ``array_api_compat.numpy`` (a thin wrapper
that exposes the array-API standard surface on top of plain NumPy). The JAX
namespace is plain ``jax.numpy`` — JAX 0.9+ is already array-API-compatible
and ``array_api_compat`` returns it unmodified.
"""

from __future__ import annotations

import functools
import importlib
from contextlib import contextmanager
from contextvars import ContextVar
from types import ModuleType
from typing import Iterator, Literal, Optional

import array_api_compat.numpy as _numpy_xp
import numpy as np

from saiunit._exceptions import BackendError
from saiunit._jax_compat import HAS_JAX, jax, jnp
from saiunit._typing import Array

# Local alias kept for backwards compatibility with callers that reference
# ``_jax_xp`` directly. ``None`` when JAX is not installed.
# This is also the object ``_xp_for("jax")`` hands back as the JAX namespace.
_jax_xp = jnp


@functools.lru_cache(maxsize=None)
def _try_import(module_name: str):
    """Best-effort import of an optional dependency.

    Parameters
    ----------
    module_name : str
        Dotted module path, e.g. ``"cupy"`` or ``"dask.array"``.

    Returns
    -------
    module or None
        The imported module, or ``None`` when importing it raised
        ``ImportError`` (which includes ``ModuleNotFoundError``).

    Notes
    -----
    The outcome — success or ``None`` — is memoized per ``module_name``, so
    a failed import is not retried on later calls. Only ``ImportError`` is
    suppressed; any other exception raised while importing propagates.
    """
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        module = None
    return module

# Names that make up this module's public API.
__all__ = [
    "get_backend",
    "get_default_backend",
    "set_default_backend",
    "using_backend",
    "is_jax_array",
    "is_numpy_array",
    "is_cupy_array",
    "is_torch_array",
    "is_dask_array",
    "is_ndonnx_array",
    "to_backend",
]

# The closed set of backend names accepted wherever a backend is chosen by string.
BackendName = Literal["numpy", "jax", "cupy", "torch", "dask", "ndonnx"]

# Context-local default backend. ``get_backend`` consults it when its inputs
# do not identify exactly one backend. It starts as "jax" when ``HAS_JAX`` is
# truthy and "numpy" otherwise; it can hold ``None`` once cleared through
# ``set_default_backend(None)``.
_default_backend: ContextVar[Optional[BackendName]] = ContextVar(
    "saiunit_default_backend", default=("jax" if HAS_JAX else "numpy")
)


def is_numpy_array(x) -> bool:
    """Report whether ``x`` should be treated as NumPy data.

    Both ``numpy.ndarray`` instances and NumPy scalar types
    (``numpy.generic`` subclasses such as ``numpy.float64`` or
    ``numpy.int32``) qualify: reductions like ``np.linalg.norm([3., 4.])``
    return a NumPy scalar, and for backend-routing purposes that is NumPy
    too. When JAX is installed, anything that is also an ``Array`` is
    rejected.

    Parameters
    ----------
    x : object
        Value to classify.

    Returns
    -------
    bool
        ``True`` for NumPy arrays/scalars that are not JAX arrays.
    """
    looks_like_numpy = isinstance(x, (np.ndarray, np.generic))
    if not looks_like_numpy:
        return False
    # Exclude objects that additionally register as JAX arrays.
    return not (HAS_JAX and isinstance(x, Array))
def is_jax_array(x) -> bool:
    """Report whether ``x`` is an ``Array``.

    Returns
    -------
    bool
        ``isinstance(x, Array)`` when JAX is installed; always ``False``
        when it is not.
    """
    return isinstance(x, Array) if HAS_JAX else False
def is_cupy_array(x) -> bool:
    """Report whether ``x`` is a ``cupy.ndarray``.

    Returns
    -------
    bool
        ``False`` when CuPy cannot be imported; otherwise the result of the
        ``isinstance`` check against ``cupy.ndarray``.
    """
    cupy_mod = _try_import("cupy")
    return cupy_mod is not None and isinstance(x, cupy_mod.ndarray)
def is_torch_array(x) -> bool:
    """Report whether ``x`` is a ``torch.Tensor``.

    Returns
    -------
    bool
        ``False`` when PyTorch cannot be imported; otherwise the result of
        the ``isinstance`` check against ``torch.Tensor``.
    """
    torch_mod = _try_import("torch")
    return torch_mod is not None and isinstance(x, torch_mod.Tensor)
def is_dask_array(x) -> bool:
    """Report whether ``x`` is a ``dask.array.Array``.

    Returns
    -------
    bool
        ``False`` when ``dask.array`` cannot be imported; otherwise the
        result of the ``isinstance`` check against ``dask.array.Array``.
    """
    dask_array_mod = _try_import("dask.array")
    return dask_array_mod is not None and isinstance(x, dask_array_mod.Array)
def is_ndonnx_array(x) -> bool:
    """Report whether ``x`` is an ``ndonnx.Array``.

    Returns
    -------
    bool
        ``False`` when ndonnx cannot be imported; otherwise the result of
        the ``isinstance`` check against ``ndonnx.Array``.
    """
    ndonnx_mod = _try_import("ndonnx")
    return ndonnx_mod is not None and isinstance(x, ndonnx_mod.Array)
def get_default_backend() -> Optional[BackendName]:
    """Return the default backend name for the current context.

    Returns
    -------
    str or None
        The value currently held by the module's ``_default_backend``
        context variable; ``None`` means no default is configured.
    """
    current = _default_backend.get()
    return current
def set_default_backend(name: Optional[BackendName]) -> None:
    """Set the default backend used when the input backend is ambiguous.

    Parameters
    ----------
    name : {'numpy', 'jax', 'cupy', 'torch', 'dask', 'ndonnx', None}
        Backend to use as the default. Pass ``None`` to clear the default;
        with no default, the tie-breaker prefers JAX when installed and
        falls back to NumPy otherwise.

    Raises
    ------
    ValueError
        If ``name`` is not one of the accepted values.
    BackendError
        If ``name`` is ``'jax'`` but JAX is not installed.
    """
    accepted = ("numpy", "jax", "cupy", "torch", "dask", "ndonnx", None)
    if name not in accepted:
        raise ValueError(
            f"default backend must be 'numpy', 'jax', 'cupy', 'torch', 'dask', "
            f"'ndonnx', or None; got {name!r}"
        )

    # Only the JAX backend is checked for availability at this point.
    jax_unavailable = name == "jax" and not HAS_JAX
    if jax_unavailable:
        raise BackendError(
            "jax backend requested but jax is not installed. "
            "Install with: pip install saiunit[jax]"
        )

    _default_backend.set(name)
@contextmanager
def using_backend(name: BackendName) -> Iterator[None]:
    """Temporarily make ``name`` the default backend inside a ``with`` block.

    Parameters
    ----------
    name : {'numpy', 'jax', 'cupy', 'torch', 'dask', 'ndonnx'}
        Backend to install as the default for the duration of the block.

    Raises
    ------
    ValueError
        If ``name`` is not one of the accepted backend names.
    BackendError
        If ``name`` is ``'jax'`` but JAX is not installed.

    Notes
    -----
    Validation runs when the ``with`` block is entered. The previous default
    is restored on exit, including when the block raises.
    """
    valid_names = ("numpy", "jax", "cupy", "torch", "dask", "ndonnx")
    if name not in valid_names:
        raise ValueError(
            f"backend must be 'numpy', 'jax', 'cupy', 'torch', 'dask', or 'ndonnx'; "
            f"got {name!r}"
        )
    if name == "jax" and not HAS_JAX:
        raise BackendError(
            "jax backend requested but jax is not installed. "
            "Install with: pip install saiunit[jax]"
        )

    # Keep the token so the exact previous value can be put back afterwards.
    restore_token = _default_backend.set(name)
    try:
        yield
    finally:
        _default_backend.reset(restore_token)
# Resolved ``xp`` namespaces keyed by backend name. Filled lazily by
# ``_xp_for`` so each optional ``array_api_compat`` sub-module is imported at
# most once.
_XP_CACHE: dict[str, ModuleType] = {}


def _backend_not_installed(name: str) -> BackendError:
    """Build the ``BackendError`` reported when backend ``name`` is unavailable."""
    return BackendError(
        f"{name} backend requested but {name} is not installed. "
        f"Install with: pip install saiunit[{name}]"
    )


def _xp_for(name: BackendName) -> ModuleType:
    """Return (and cache) the ``xp`` namespace for ``name``.

    Parameters
    ----------
    name : {'numpy', 'jax', 'cupy', 'torch', 'dask', 'ndonnx'}
        Backend whose array namespace is wanted.

    Returns
    -------
    module
        ``array_api_compat``'s wrapper for numpy/cupy/torch/dask, the
        module-level ``_jax_xp`` for jax, and the ``ndonnx`` package itself
        for ndonnx.

    Raises
    ------
    BackendError
        If the requested backend's package is not installed.
    ValueError
        If ``name`` is not a known backend.
    """
    cached = _XP_CACHE.get(name)
    if cached is not None:
        return cached

    mod: ModuleType
    if name == "numpy":
        mod = _numpy_xp
    elif name == "jax":
        if not HAS_JAX:
            raise _backend_not_installed("jax")
        mod = _jax_xp
    elif name == "cupy":
        if _try_import("cupy") is None:
            raise _backend_not_installed("cupy")
        import array_api_compat.cupy as _cupy_xp
        mod = _cupy_xp
    elif name == "torch":
        if _try_import("torch") is None:
            raise _backend_not_installed("torch")
        import array_api_compat.torch as _torch_xp
        mod = _torch_xp
    elif name == "dask":
        if _try_import("dask.array") is None:
            raise _backend_not_installed("dask")
        import array_api_compat.dask.array as _dask_xp
        mod = _dask_xp
    elif name == "ndonnx":
        ndonnx = _try_import("ndonnx")
        if ndonnx is None:
            raise _backend_not_installed("ndonnx")
        mod = ndonnx  # ndonnx is itself array-API-compatible
    else:
        raise ValueError(f"unknown backend: {name!r}")

    _XP_CACHE[name] = mod
    return mod


def _name_to_xp(name: BackendName) -> ModuleType:
    """Deprecated alias retained for any external callers; prefer ``_xp_for``."""
    return _xp_for(name)


def get_backend(*arrays_or_quantities) -> ModuleType:
    """Return the ``xp`` namespace appropriate for the given inputs.

    ``Quantity`` arguments are inspected through their ``mantissa``. Each
    value is classified as numpy, jax, cupy, torch, dask, or ndonnx. When
    exactly one backend kind is present, its namespace is returned.

    On mixed inputs, or when no argument is a recognized array, the result
    of the configured default backend is used. If no default is configured,
    JAX is used when installed and NumPy otherwise.

    Parameters
    ----------
    *arrays_or_quantities
        Arrays, ``Quantity`` objects, or arbitrary other values.

    Returns
    -------
    module
        The selected array namespace.

    Raises
    ------
    BackendError
        If the selected backend is not installed.
    """
    from saiunit._base_quantity import Quantity  # local import to avoid cycle

    mantissas = [
        a.mantissa if isinstance(a, Quantity) else a
        for a in arrays_or_quantities
    ]

    detectors = (
        ("numpy", is_numpy_array),
        ("jax", is_jax_array),
        ("cupy", is_cupy_array),
        ("torch", is_torch_array),
        ("dask", is_dask_array),
        ("ndonnx", is_ndonnx_array),
    )
    kinds: list[BackendName] = [
        name  # type: ignore[misc]
        for name, matches in detectors
        if any(matches(x) for x in mantissas)
    ]
    if len(kinds) == 1:
        return _xp_for(kinds[0])

    default = _default_backend.get()
    if default is not None:
        return _xp_for(default)
    # Tie-breaker: prefer JAX when installed, fall back to NumPy otherwise so
    # that the package remains usable without the [jax] extra.
    return _xp_for("jax" if HAS_JAX else "numpy")


# NumPy dtype name -> attribute name on the ``torch`` module. The names
# coincide; the table exists to restrict translation to these dtypes.
_NUMPY_TO_TORCH_DTYPE = {
    "float16": "float16",
    "float32": "float32",
    "float64": "float64",
    "int8": "int8",
    "int16": "int16",
    "int32": "int32",
    "int64": "int64",
    "uint8": "uint8",
    "bool": "bool",
    "complex64": "complex64",
    "complex128": "complex128",
}


def _numpy_to_torch_dtype(np_dtype, torch_mod):
    """Translate a numpy dtype (or np.dtype-like) to a torch dtype.

    Parameters
    ----------
    np_dtype
        Anything ``np.dtype`` accepts.
    torch_mod : module
        The imported ``torch`` module; the dtype is looked up as an
        attribute on it.

    Raises
    ------
    TypeError
        If the dtype's name has no entry in ``_NUMPY_TO_TORCH_DTYPE``.
    """
    name = np.dtype(np_dtype).name
    torch_name = _NUMPY_TO_TORCH_DTYPE.get(name)
    if torch_name is None:
        raise TypeError(f"no torch dtype mapping for numpy dtype {name!r}")
    return getattr(torch_mod, torch_name)


def _translate_dtype(dtype, xp):
    """Translate ``dtype`` (often a numpy dtype) to the equivalent on ``xp``.

    torch and ndonnx reject numpy dtype objects: torch's ``Tensor.to`` raises
    ``TypeError: received an invalid combination of arguments`` and ndonnx
    raises ``AttributeError: type object 'numpy.float32' has no attribute
    '__ndx_cast_from__'``.

    Both expose array-API dtype attributes (``xp.float32``, ``xp.int64``, …)
    that match numpy's dtype names, so the safest cross-backend bridge is
    ``getattr(xp, np.dtype(dtype).name)``.

    Returns ``None`` for ``None``. Returns ``dtype`` unchanged if it isn't a
    numpy dtype-like value or if ``xp`` doesn't expose the named attribute.
    """
    if dtype is None:
        return None
    try:
        name = np.dtype(dtype).name
    except (TypeError, ValueError):
        return dtype
    return getattr(xp, name, dtype)


def to_backend(x, name: BackendName, **kwargs):
    """Convert ``x`` to the given backend; no-op if already there.

    Parameters
    ----------
    x
        Array (or array-like) to convert.
    name : {'numpy', 'jax', 'cupy', 'torch', 'dask', 'ndonnx'}
        Target backend.
    **kwargs
        Backend-specific options:

        - cupy: ``device``
        - torch: ``device``, ``dtype`` (numpy dtypes are translated)
        - dask: ``chunks`` (defaults to ``"auto"`` for new arrays)

        numpy, jax and ndonnx accept no kwargs.

    Returns
    -------
    array
        ``x`` itself when it already belongs to the target backend and no
        conversion option was supplied; otherwise the converted array.

    Raises
    ------
    TypeError
        On kwargs the target backend does not accept.
    BackendError
        If the target backend is not installed.
    ValueError
        If ``name`` is not a known backend.
    """
    if name == "numpy":
        if kwargs:
            raise TypeError(f"to_backend(name='numpy') does not accept kwargs; got {sorted(kwargs)}")
        if is_numpy_array(x):
            return x
        # ndonnx requires explicit materialization; np.asarray returns a 0-d
        # object wrapper instead of evaluating the symbolic graph.
        if is_ndonnx_array(x):
            return x.unwrap_numpy()
        # CuPy refuses implicit host conversion (np.asarray raises
        # TypeError), so copy device memory to the host explicitly.
        if is_cupy_array(x):
            return x.get()
        # np.asarray fails on CUDA or grad-tracking tensors; detach and move
        # to CPU first. For a plain CPU tensor this is what np.asarray does.
        if is_torch_array(x):
            return x.detach().cpu().numpy()
        return np.asarray(x)

    if name == "jax":
        if not HAS_JAX:
            raise _backend_not_installed("jax")
        if kwargs:
            raise TypeError(f"to_backend(name='jax') does not accept kwargs; got {sorted(kwargs)}")
        if is_jax_array(x):
            return x
        return jnp.asarray(x)

    if name == "cupy":
        cupy = _try_import("cupy")
        if cupy is None:
            raise _backend_not_installed("cupy")
        unknown = set(kwargs) - {"device"}
        if unknown:
            raise TypeError(f"to_backend(name='cupy') does not accept {sorted(unknown)}")
        if is_cupy_array(x) and "device" not in kwargs:
            return x
        device = kwargs.get("device")
        if device is not None:
            with cupy.cuda.Device(device):
                return cupy.asarray(x)
        return cupy.asarray(x)

    if name == "torch":
        torch = _try_import("torch")
        if torch is None:
            raise _backend_not_installed("torch")
        unknown = set(kwargs) - {"device", "dtype"}
        if unknown:
            raise TypeError(f"to_backend(name='torch') does not accept {sorted(unknown)}")
        if is_torch_array(x) and not kwargs:
            return x
        # Translate numpy dtype to torch dtype if needed.
        dtype = kwargs.get("dtype")
        if dtype is not None and not isinstance(dtype, torch.dtype):
            dtype = _numpy_to_torch_dtype(dtype, torch)
        device = kwargs.get("device")
        # torch.as_tensor shares memory where possible; we accept that.
        return torch.as_tensor(x, device=device, dtype=dtype)

    if name == "dask":
        da = _try_import("dask.array")
        if da is None:
            raise _backend_not_installed("dask")
        unknown = set(kwargs) - {"chunks"}
        if unknown:
            raise TypeError(f"to_backend(name='dask') does not accept {sorted(unknown)}")
        if is_dask_array(x):
            if "chunks" not in kwargs:
                return x
            # da.from_array raises ValueError for inputs that are already
            # dask arrays; rechunk is the supported way to change chunking.
            return x.rechunk(kwargs["chunks"])
        return da.from_array(x, chunks=kwargs.get("chunks", "auto"))

    if name == "ndonnx":
        ndonnx = _try_import("ndonnx")
        if ndonnx is None:
            raise _backend_not_installed("ndonnx")
        if kwargs:
            raise TypeError(f"to_backend(name='ndonnx') does not accept kwargs; got {sorted(kwargs)}")
        if is_ndonnx_array(x):
            return x
        return ndonnx.asarray(x)

    raise ValueError(
        f"backend must be one of 'numpy', 'jax', 'cupy', 'torch', 'dask', 'ndonnx'; "
        f"got {name!r}"
    )