Source code for saiunit.math._fun_array_creation

# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import annotations

from collections.abc import Sequence
from typing import (Union, Optional, List, Any, Tuple)

from saiunit._jax_compat import jax, jnp, tree as _tree
from saiunit._typing import Array, ArrayLike, DTypeLike
import numpy as np

from saiunit._backend import get_backend, get_default_backend, _xp_for, _translate_dtype
from saiunit._base_dimension import UnitMismatchError
from saiunit._base_unit import UNITLESS, Unit
from saiunit._base_getters import fail_for_unit_mismatch, get_unit, unit_scale_align_to_first
from saiunit._base_quantity import Quantity
from saiunit._misc import set_module_as, maybe_custom_array_tree, maybe_custom_array

import array_api_compat.numpy as _numpy_xp


def _safe_call_xp(fn, args, kwargs):
    """Local mirror of :func:`_fun_keep_unit._dispatch_call` for direct ``xp.fn(...)``
    call sites in this module. Imported lazily to avoid a circular import."""
    from ._fun_keep_unit import _dispatch_call
    return _dispatch_call(fn, args, kwargs)


def _default_xp():
    """Return the backend namespace selected by the current default backend.

    When no default is set, prefer JAX if installed (preserves legacy
    behaviour where ``jnp`` was the fallback), otherwise NumPy. If the
    configured default backend isn't importable — e.g. CI runs that test
    individual backends in isolation without JAX — fall back to NumPy.
    """
    name = get_default_backend()
    if name is None:
        return jnp if jnp is not None else _numpy_xp
    try:
        return _xp_for(name)
    except Exception:
        return _numpy_xp

Shape = Union[int, Sequence[int]]

__all__ = [
    # array creation(given shape)
    'full', 'eye', 'identity', 'tri',
    'empty', 'ones', 'zeros',

    # array creation(given array)
    'full_like', 'diag', 'tril', 'triu',
    'empty_like', 'ones_like', 'zeros_like', 'fill_diagonal',

    # array creation(misc)
    'array', 'asarray', 'arange', 'linspace', 'logspace',
    'meshgrid', 'vander',

    # indexing funcs
    'tril_indices', 'tril_indices_from', 'triu_indices',
    'triu_indices_from',

    # others
    'from_numpy',
    'as_numpy',
    'tree_ones_like',
    'tree_zeros_like',
]


@set_module_as('saiunit.math')
def full(
    shape: Shape,
    fill_value: Union[Quantity, int, float],
    dtype: Optional[DTypeLike] = None,
) -> Union[Array, Quantity]:
    """
    Return a new quantity or array of given shape, filled with ``fill_value``.

    If ``fill_value`` is a :class:`~saiunit.Quantity`, the result is a
    ``Quantity`` carrying the same unit.  Otherwise a plain JAX array is
    returned.

    Parameters
    ----------
    shape : int or sequence of ints
        Shape of the new array, e.g., ``(2, 3)`` or ``2``.
    fill_value : scalar, array_like, or Quantity
        Fill value.  When a ``Quantity`` is given its unit is preserved.
    dtype : data-type, optional
        The desired data-type for the array.  The default, ``None``, means
        the dtype is inferred from ``fill_value``.

    Returns
    -------
    out : Quantity or Array
        Array (or ``Quantity``) of ``fill_value`` with the given shape and
        dtype.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> u.math.full((2, 3), 7.0)
        Array([[7., 7., 7.],
               [7., 7., 7.]], dtype=float32)
        >>> u.math.full((3,), 5.0 * u.meter)
        Quantity([5. 5. 5.], "m")
    """
    fill_value = maybe_custom_array(fill_value)
    if isinstance(fill_value, Quantity):
        xp = get_backend(fill_value.mantissa)
        return Quantity(xp.full(shape, fill_value.mantissa, dtype=dtype), unit=fill_value.unit)
    xp = _default_xp()
    return xp.full(shape, fill_value, dtype=dtype)


@set_module_as('saiunit.math')
def eye(
    N: int,
    M: Optional[int] = None,
    k: int = 0,
    dtype: Optional[DTypeLike] = None,
    unit: Unit = UNITLESS,
) -> Union[Array, Quantity]:
    """
    Return a 2-D identity-like quantity or array with ones on the diagonal.

    Parameters
    ----------
    N : int
        Number of rows in the output.
    M : int, optional
        Number of columns in the output.  If ``None``, defaults to ``N``.
    k : int, optional
        Index of the diagonal: 0 (the default) refers to the main diagonal,
        a positive value refers to an upper diagonal, and a negative value
        to a lower diagonal.
    dtype : data-type, optional
        Data-type of the returned array.
    unit : Unit, optional
        Unit of the returned ``Quantity``.  When ``UNITLESS`` (the default)
        a plain array is returned.

    Returns
    -------
    out : Quantity or Array
        An array of shape ``(N, M)`` where all elements are zero except for
        the ``k``-th diagonal, whose values are one (optionally carrying
        ``unit``).

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> u.math.eye(2)
        Array([[1., 0.],
               [0., 1.]], dtype=float32)
        >>> u.math.eye(2, unit=u.meter)
        Quantity([[1. 0.]
                  [0. 1.]], "m")
    """
    if not isinstance(unit, Unit):
        raise TypeError(f'eye requires "unit" to be a Unit instance, got {type(unit).__name__}: {unit!r}.')
    # ``k`` is keyword-only under the array-API spec (e.g. ``array_api_compat.numpy.eye``).
    arr = _default_xp().eye(N, M, k=k, dtype=dtype)
    if not unit.is_unitless:
        return arr * unit
    return arr


@set_module_as('saiunit.math')
def identity(
    n: int,
    dtype: Optional[DTypeLike] = None,
    unit: Unit = UNITLESS
) -> Union[Array, Quantity]:
    """
    Return the identity quantity or array.

    The identity array is a square array with ones on the main diagonal.

    Parameters
    ----------
    n : int
        Number of rows (and columns) in the ``n x n`` output.
    dtype : data-type, optional
        Data-type of the output.  Defaults to ``float``.
    unit : Unit, optional
        Unit of the returned ``Quantity``.  When ``UNITLESS`` (the default)
        a plain array is returned.

    Returns
    -------
    out : Quantity or Array
        ``n x n`` array with its main diagonal set to one and all other
        elements zero.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> u.math.identity(3)
        Array([[1., 0., 0.],
               [0., 1., 0.],
               [0., 0., 1.]], dtype=float32)
        >>> u.math.identity(2, unit=u.second)
        Quantity([[1. 0.]
                  [0. 1.]], "s")
    """
    if not isinstance(unit, Unit):
        raise TypeError(f'identity requires "unit" to be a Unit instance, got {type(unit).__name__}: {unit!r}.')
    if not unit.is_unitless:
        return _default_xp().identity(n, dtype=dtype) * unit
    else:
        return _default_xp().identity(n, dtype=dtype)


@set_module_as('saiunit.math')
def tri(
    N: int,
    M: Optional[int] = None,
    k: int = 0,
    dtype: Optional[DTypeLike] = None,
    unit: Unit = UNITLESS
) -> Union[Array, Quantity]:
    """
    Return an array with ones at and below the given diagonal and zeros elsewhere.

    Parameters
    ----------
    N : int
        Number of rows in the array.
    M : int, optional
        Number of columns in the array.  By default, ``M`` is taken equal
        to ``N``.
    k : int, optional
        The sub-diagonal at and below which the array is filled.
        ``k = 0`` is the main diagonal, ``k < 0`` is below it, and
        ``k > 0`` is above.  The default is 0.
    dtype : data-type, optional
        Data type of the returned array.  The default is ``float``.
    unit : Unit, optional
        Unit of the returned ``Quantity``.

    Returns
    -------
    out : Quantity or Array
        Array of shape ``(N, M)`` with its lower triangle filled with ones
        and zero elsewhere; i.e. ``T[i, j] == 1`` for ``j <= i + k``,
        0 otherwise.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> u.math.tri(3)
        Array([[1., 0., 0.],
               [1., 1., 0.],
               [1., 1., 1.]], dtype=float32)
        >>> u.math.tri(2, 3, unit=u.meter)
        Quantity([[1. 0. 0.]
                  [1. 1. 0.]], "m")
    """
    if not isinstance(unit, Unit):
        raise TypeError(f'tri requires "unit" to be a Unit instance, got {type(unit).__name__}: {unit!r}.')
    xp = _default_xp()
    # ``dask.array.tri`` rejects ``dtype=None`` ("dtype must be known for auto-chunking"),
    # while numpy/jax default to ``float``. Materialize the default explicitly for portability.
    if dtype is None:
        dtype = xp.float64 if hasattr(xp, "float64") else float
    arr = xp.tri(N, M, k, dtype=dtype)
    if not unit.is_unitless:
        return arr * unit
    return arr


@set_module_as('saiunit.math')
def empty(
    shape: Shape,
    dtype: Optional[DTypeLike] = None,
    unit: Unit = UNITLESS
) -> Union[Array, Quantity]:
    """
    Return a new quantity or array of given shape and type, without initializing entries.

    Parameters
    ----------
    shape : int or sequence of ints
        Shape of the empty quantity or array, e.g., ``(2, 3)`` or ``2``.
    dtype : data-type, optional
        Data-type of the output.  Defaults to ``float``.
    unit : Unit, optional
        Unit of the returned ``Quantity``.  When ``UNITLESS`` (the default)
        a plain array is returned.

    Returns
    -------
    out : Quantity or Array
        Array of uninitialized (arbitrary) data of the given shape and dtype.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> result = u.math.empty((2, 3))
        >>> result.shape
        (2, 3)
        >>> result = u.math.empty((2,), unit=u.meter)
        >>> u.get_unit(result) == u.meter
        True
    """
    if not isinstance(unit, Unit):
        raise TypeError(f'empty requires "unit" to be a Unit instance, got {type(unit).__name__}: {unit!r}.')
    if not unit.is_unitless:
        return _default_xp().empty(shape, dtype=dtype) * unit
    else:
        return _default_xp().empty(shape, dtype=dtype)


@set_module_as('saiunit.math')
def ones(
    shape: Shape,
    dtype: Optional[DTypeLike] = None,
    unit: Unit = UNITLESS
) -> Union[Array, Quantity]:
    """
    Return a new quantity or array of given shape and type, filled with ones.

    Parameters
    ----------
    shape : int or sequence of ints
        Shape of the new quantity or array, e.g., ``(2, 3)`` or ``2``.
    dtype : data-type, optional
        The desired data-type for the array.  Default is ``float``.
    unit : Unit, optional
        Unit of the returned ``Quantity``.  When ``UNITLESS`` (the default)
        a plain array is returned.

    Returns
    -------
    out : Quantity or Array
        Array of ones with the given shape and dtype.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> u.math.ones((3,))
        Array([1., 1., 1.], dtype=float32)
        >>> u.math.ones((2, 2), unit=u.meter)
        Quantity([[1. 1.]
                  [1. 1.]], "m")
    """
    if not isinstance(unit, Unit):
        raise TypeError(f'ones requires "unit" to be a Unit instance, got {type(unit).__name__}: {unit!r}.')
    if not unit.is_unitless:
        return _default_xp().ones(shape, dtype=dtype) * unit
    else:
        return _default_xp().ones(shape, dtype=dtype)


@set_module_as('saiunit.math')
def zeros(
    shape: Shape,
    dtype: Optional[DTypeLike] = None,
    unit: Unit = UNITLESS
) -> Union[Array, Quantity]:
    """
    Return a new quantity or array of given shape and type, filled with zeros.

    Parameters
    ----------
    shape : int or sequence of ints
        Shape of the new quantity or array, e.g., ``(2, 3)`` or ``2``.
    dtype : data-type, optional
        The desired data-type for the array.  Default is ``float``.
    unit : Unit, optional
        Unit of the returned ``Quantity``.  When ``UNITLESS`` (the default)
        a plain array is returned.

    Returns
    -------
    out : Quantity or Array
        Array of zeros with the given shape and dtype.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> u.math.zeros((3,))
        Array([0., 0., 0.], dtype=float32)
        >>> u.math.zeros((2,), unit=u.second)
        Quantity([0. 0.], "s")
    """
    if not isinstance(unit, Unit):
        raise TypeError(f'zeros requires "unit" to be a Unit instance, got {type(unit).__name__}: {unit!r}.')
    if not unit.is_unitless:
        return _default_xp().zeros(shape, dtype=dtype) * unit
    else:
        return _default_xp().zeros(shape, dtype=dtype)


@set_module_as('saiunit.math')
def full_like(
    a: Union[Quantity, ArrayLike],
    fill_value: Union[Quantity, ArrayLike],
    dtype: Optional[DTypeLike] = None,
    shape: Shape | None = None
) -> Union[Quantity, Array]:
    """
    Return a new quantity or array with the same shape and type as a given array, filled with ``fill_value``.

    Parameters
    ----------
    a : Quantity or array_like
        The shape and data-type of ``a`` define these same attributes of the
        returned array.
    fill_value : Quantity or array_like
        Value to fill the new quantity or array with.  When ``a`` is a
        ``Quantity``, ``fill_value`` must have a compatible unit.
    dtype : data-type, optional
        Overrides the data type of the result.
    shape : int or sequence of ints, optional
        Overrides the shape of the result.  If not given, ``a.shape`` is
        used.

    Returns
    -------
    out : Quantity or Array
        New array with the same shape and type as ``a``, filled with
        ``fill_value``.

    Raises
    ------
    TypeError
        If ``fill_value`` carries a unit but ``a`` is a plain array (not
        unitless), or vice-versa.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> import jax.numpy as jnp
        >>> u.math.full_like(jnp.array([1.0, 2.0]), 9.0)
        Array([9., 9.], dtype=float32)
        >>> u.math.full_like(jnp.array([1.0, 2.0]) * u.meter, 9.0 * u.meter)
        Quantity([9. 9.], "m")
    """
    a = maybe_custom_array(a)
    fill_value = maybe_custom_array(fill_value)
    xp = get_backend(a.mantissa if isinstance(a, Quantity) else a)
    if isinstance(fill_value, Quantity):
        if isinstance(a, Quantity):
            fill_value = fill_value.in_unit(a.unit)
            return Quantity(
                _safe_call_xp(xp.full_like, (a.mantissa, fill_value.mantissa), dict(dtype=dtype, shape=shape)),
                unit=a.unit
            )
        else:
            if not fill_value.is_unitless:
                raise TypeError(
                    f'full_like requires "fill_value" to be dimensionless when "a" is a plain array, '
                    f'but got fill_value with unit={fill_value.unit}. '
                    f'Either pass a plain number as fill_value or wrap "a" as a Quantity.'
                )
            return Quantity(
                _safe_call_xp(xp.full_like, (a, fill_value.mantissa), dict(dtype=dtype, shape=shape)),
                unit=fill_value.unit
            )
    else:
        if isinstance(a, Quantity):
            if not a.is_unitless:
                raise TypeError(
                    f'full_like requires "a" to be dimensionless when "fill_value" is a plain value, '
                    f'but got a with unit={a.unit}. '
                    f'Either pass a Quantity as fill_value or use a plain array for "a".'
                )
            return _safe_call_xp(xp.full_like, (a.mantissa, fill_value), dict(dtype=dtype, shape=shape))
        else:
            return _safe_call_xp(xp.full_like, (a, fill_value), dict(dtype=dtype, shape=shape))


@set_module_as('saiunit.math')
def diag(
    v: Union[Quantity, ArrayLike],
    k: int = 0,
    unit: Unit = UNITLESS
) -> Union[Quantity, Array]:
    """
    Extract a diagonal or construct a diagonal array.

    If ``v`` is a 1-D array, ``diag`` constructs a 2-D array with ``v`` on
    the ``k``-th diagonal.  If ``v`` is a 2-D array, ``diag`` extracts the
    ``k``-th diagonal and returns a 1-D array.

    Parameters
    ----------
    v : Quantity or array_like
        Input array.  1-D inputs produce a 2-D diagonal matrix; 2-D inputs
        have their ``k``-th diagonal extracted.
    k : int, optional
        Diagonal in question.  The default is 0.  Use ``k > 0`` for
        diagonals above the main diagonal and ``k < 0`` for diagonals below.
    unit : Unit, optional
        Unit of the returned ``Quantity``.  Ignored when ``v`` already
        carries a unit.

    Returns
    -------
    out : Quantity or Array
        The extracted diagonal or constructed diagonal array.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> import jax.numpy as jnp
        >>> u.math.diag(jnp.array([1.0, 2.0, 3.0]))
        Array([[1., 0., 0.],
               [0., 2., 0.],
               [0., 0., 3.]], dtype=float32)
        >>> u.math.diag(jnp.array([1.0, 2.0]), unit=u.meter)
        Quantity([[1. 0.]
                  [0. 2.]], "m")
    """
    if not isinstance(unit, Unit):
        raise TypeError(f'diag requires "unit" to be a Unit instance, got {type(unit).__name__}: {unit!r}.')
    v = maybe_custom_array(v)
    xp = get_backend(v.mantissa if isinstance(v, Quantity) else v)
    if isinstance(v, Quantity):
        if not unit.is_unitless:
            v = v.in_unit(unit)
        return Quantity(_safe_call_xp(xp.diag, (v.mantissa,), dict(k=k)), unit=v.unit)
    else:
        if not unit.is_unitless:
            return _safe_call_xp(xp.diag, (v,), dict(k=k)) * unit
        else:
            return _safe_call_xp(xp.diag, (v,), dict(k=k))


@set_module_as('saiunit.math')
def tril(
    m: Union[Quantity, ArrayLike],
    k: int = 0,
    unit: Unit = UNITLESS
) -> Union[Quantity, Array]:
    """
    Return the lower triangle of an array.

    Return a copy of a matrix with the elements above the ``k``-th diagonal
    zeroed.  For arrays with ``ndim > 2``, ``tril`` applies to the final two
    axes.

    Parameters
    ----------
    m : Quantity or array_like
        Input array.
    k : int, optional
        Diagonal above which to zero elements.  ``k = 0`` is the main
        diagonal, ``k < 0`` is below it, and ``k > 0`` is above.
    unit : Unit, optional
        Unit of the returned ``Quantity``.

    Returns
    -------
    out : Quantity or Array
        Lower triangle of ``m``, of the same shape and data-type as ``m``.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> import jax.numpy as jnp
        >>> u.math.tril(jnp.ones((3, 3)))
        Array([[1., 0., 0.],
               [1., 1., 0.],
               [1., 1., 1.]], dtype=float32)
    """
    if not isinstance(unit, Unit):
        raise TypeError(f'tril requires "unit" to be a Unit instance, got {type(unit).__name__}: {unit!r}.')
    m = maybe_custom_array(m)
    xp = get_backend(m.mantissa if isinstance(m, Quantity) else m)
    if isinstance(m, Quantity):
        if not unit.is_unitless:
            m = m.in_unit(unit)
        return Quantity(xp.tril(m.mantissa, k=k), unit=m.unit)
    else:
        if not unit.is_unitless:
            return xp.tril(m, k=k) * unit
        else:
            return xp.tril(m, k=k)


@set_module_as('saiunit.math')
def triu(
    m: Union[Quantity, ArrayLike],
    k: int = 0,
    unit: Unit = UNITLESS
) -> Union[Quantity, Array]:
    """
    Return the upper triangle of an array.

    Return a copy of an array with the elements below the ``k``-th diagonal
    zeroed.  For arrays with ``ndim > 2``, ``triu`` applies to the final two
    axes.

    Parameters
    ----------
    m : Quantity or array_like
        Input array.
    k : int, optional
        Diagonal below which to zero elements.  ``k = 0`` is the main
        diagonal, ``k < 0`` is below it, and ``k > 0`` is above.
    unit : Unit, optional
        Unit of the returned ``Quantity``.

    Returns
    -------
    out : Quantity or Array
        Upper triangle of ``m``, of the same shape and data-type as ``m``.

    See Also
    --------
    tril : lower triangle of an array

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> import jax.numpy as jnp
        >>> u.math.triu(jnp.ones((3, 3)))
        Array([[1., 1., 1.],
               [0., 1., 1.],
               [0., 0., 1.]], dtype=float32)
    """
    if not isinstance(unit, Unit):
        raise TypeError(f'triu requires "unit" to be a Unit instance, got {type(unit).__name__}: {unit!r}.')
    m = maybe_custom_array(m)
    xp = get_backend(m.mantissa if isinstance(m, Quantity) else m)
    if isinstance(m, Quantity):
        if not unit.is_unitless:
            m = m.in_unit(unit)
        return Quantity(xp.triu(m.mantissa, k=k), unit=m.unit)
    else:
        if not unit.is_unitless:
            return xp.triu(m, k=k) * unit
        else:
            return xp.triu(m, k=k)


@set_module_as('saiunit.math')
def empty_like(
    prototype: Union[Quantity, ArrayLike],
    dtype: Optional[DTypeLike] = None,
    shape: Shape | None = None,
    unit: Unit = UNITLESS
) -> Union[Quantity, Array]:
    """
    Return a new uninitialized quantity or array with the same shape and type as a given array.

    Parameters
    ----------
    prototype : Quantity or array_like
        The shape and data-type of ``prototype`` define these same attributes
        of the returned array.
    dtype : data-type, optional
        Overrides the data type of the result.
    shape : int or tuple of ints, optional
        Overrides the shape of the result.  If not given,
        ``prototype.shape`` is used.
    unit : Unit, optional
        Unit of the returned ``Quantity``.

    Returns
    -------
    out : Quantity or Array
        Array of uninitialized (arbitrary) data with the same shape and type
        as ``prototype``.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> import jax.numpy as jnp
        >>> result = u.math.empty_like(jnp.array([1.0, 2.0, 3.0]))
        >>> result.shape
        (3,)
    """
    if not isinstance(unit, Unit):
        raise TypeError(f'empty_like requires "unit" to be a Unit instance, got {type(unit).__name__}: {unit!r}.')
    prototype = maybe_custom_array(prototype)
    xp = get_backend(prototype.mantissa if isinstance(prototype, Quantity) else prototype)
    if isinstance(prototype, Quantity):
        if not unit.is_unitless:
            prototype = prototype.in_unit(unit)
        return Quantity(
            _safe_call_xp(xp.empty_like, (prototype.mantissa,), dict(dtype=dtype)),
            unit=prototype.unit,
        )
    else:
        if not unit.is_unitless:
            return _safe_call_xp(xp.empty_like, (prototype,), dict(dtype=dtype, shape=shape)) * unit
        else:
            return _safe_call_xp(xp.empty_like, (prototype,), dict(dtype=dtype, shape=shape))


@set_module_as('saiunit.math')
def ones_like(
    a: Union[Quantity, ArrayLike],
    dtype: Optional[DTypeLike] = None,
    shape: Shape | None = None,
    unit: Unit = UNITLESS
) -> Union[Quantity, Array]:
    """
    Return a quantity or array of ones with the same shape and type as a given array.

    Parameters
    ----------
    a : Quantity or array_like
        The shape and data-type of ``a`` define these same attributes of the
        returned array.
    dtype : data-type, optional
        Overrides the data type of the result.
    shape : int or tuple of ints, optional
        Overrides the shape of the result.  If not given, ``a.shape`` is
        used.
    unit : Unit, optional
        Unit of the returned ``Quantity``.

    Returns
    -------
    out : Quantity or Array
        Array of ones with the same shape and type as ``a``.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> import jax.numpy as jnp
        >>> u.math.ones_like(jnp.array([1.0, 2.0, 3.0]))
        Array([1., 1., 1.], dtype=float32)
        >>> u.math.ones_like(jnp.array([1.0, 2.0]) * u.meter)
        Quantity([1. 1.], "m")
    """
    if not isinstance(unit, Unit):
        raise TypeError(f'ones_like requires "unit" to be a Unit instance, got {type(unit).__name__}: {unit!r}.')
    a = maybe_custom_array(a)
    xp = get_backend(a.mantissa if isinstance(a, Quantity) else a)
    if isinstance(a, Quantity):
        if not unit.is_unitless:
            a = a.in_unit(unit)
        return Quantity(
            _safe_call_xp(xp.ones_like, (a.mantissa,), dict(dtype=dtype, shape=shape)),
            unit=a.unit,
        )
    else:
        if not unit.is_unitless:
            return _safe_call_xp(xp.ones_like, (a,), dict(dtype=dtype, shape=shape)) * unit
        else:
            return _safe_call_xp(xp.ones_like, (a,), dict(dtype=dtype, shape=shape))


@set_module_as('saiunit.math')
def zeros_like(
    a: Union[Quantity, ArrayLike],
    dtype: Optional[DTypeLike] = None,
    shape: Shape | None = None,
    unit: Unit = UNITLESS
) -> Union[Quantity, Array]:
    """
    Return a quantity or array of zeros with the same shape and type as a given array.

    Parameters
    ----------
    a : Quantity or array_like
        The shape and data-type of ``a`` define these same attributes of the
        returned array.
    dtype : data-type, optional
        Overrides the data type of the result.
    shape : int or tuple of ints, optional
        Overrides the shape of the result.  If not given, ``a.shape`` is
        used.
    unit : Unit, optional
        Unit of the returned ``Quantity``.

    Returns
    -------
    out : Quantity or Array
        Array of zeros with the same shape and type as ``a``.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> import jax.numpy as jnp
        >>> u.math.zeros_like(jnp.array([1.0, 2.0, 3.0]))
        Array([0., 0., 0.], dtype=float32)
        >>> u.math.zeros_like(jnp.array([1.0, 2.0]) * u.meter)
        Quantity([0. 0.], "m")
    """
    if not isinstance(unit, Unit):
        raise TypeError(f'zeros_like requires "unit" to be a Unit instance, got {type(unit).__name__}: {unit!r}.')
    a = maybe_custom_array(a)
    xp = get_backend(a.mantissa if isinstance(a, Quantity) else a)
    if isinstance(a, Quantity):
        if not unit.is_unitless:
            a = a.in_unit(unit)
        return Quantity(
            _safe_call_xp(xp.zeros_like, (a.mantissa,), dict(dtype=dtype, shape=shape)),
            unit=a.unit,
        )
    else:
        if not unit.is_unitless:
            return _safe_call_xp(xp.zeros_like, (a,), dict(dtype=dtype, shape=shape)) * unit
        else:
            return _safe_call_xp(xp.zeros_like, (a,), dict(dtype=dtype, shape=shape))


@set_module_as('saiunit.math')
def asarray(
    a: Any,
    dtype: Optional[DTypeLike] = None,
    order: Optional[str] = None,
    unit: Optional[Unit] = None,
) -> Quantity | Array | None:
    """
    Convert the input to a quantity or array.

    If ``unit`` is provided, the input is checked for compatible units and
    converted accordingly.  If ``unit`` is not provided, the unit is inferred
    from the input data.

    The function ``array`` is an alias for ``asarray``.

    Parameters
    ----------
    a : Quantity, array_like, list[Quantity], or list[array_like]
        Input data, in any form that can be converted to an array.  When a
        list of ``Quantity`` objects is given, all elements must share the
        same dimension.
    dtype : data-type, optional
        By default, the data-type is inferred from the input data.
    order : {'C', 'F', 'A', 'K'}, optional
        Memory layout.  Defaults to ``'K'``.
    unit : Unit, optional
        Target unit of the returned ``Quantity``.  When given, all elements
        are converted to this unit.

    Returns
    -------
    out : Quantity or Array
        Array interpretation of ``a``.

    Raises
    ------
    UnitMismatchError
        If elements of ``a`` have incompatible units, or if ``unit`` is
        specified but does not match the dimension of ``a``.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> u.math.array([1, 2, 3])
        Array([1, 2, 3], dtype=int32)
        >>> u.math.array([1, 2, 3] * u.meter)
        Quantity([1 2 3], "m")
        >>> u.math.asarray([1 * u.meter, 2 * u.meter])
        Quantity([1 2], "m")
    """
    if a is None:
        return a
    if isinstance(a, dict):
        raise TypeError(
            f"asarray does not accept dict inputs (got {type(a).__name__}); "
            "pass an array, list, or Quantity."
        )

    # get leaves
    leaves, treedef = _tree.flatten(a, is_leaf=lambda x: isinstance(x, Quantity))
    leaves = unit_scale_align_to_first(*leaves)
    leaf_unit = leaves[0].unit

    # get unit
    if unit is not None and not leaf_unit.is_unitless:
        if not isinstance(unit, Unit):
            raise TypeError(f'asarray requires "unit" to be a Unit instance, got {type(unit).__name__}: {unit!r}.')
        leaves = [leaf.in_unit(unit) for leaf in leaves]
    else:
        unit = leaf_unit

    # reconstruct mantissa
    a = treedef.unflatten([leaf.mantissa for leaf in leaves])  # type: ignore[attr-defined]
    xp = _default_xp()
    # ``order`` is a numpy/jax-only kwarg; torch / dask / ndonnx ``asarray``
    # reject it. Only forward when explicitly provided.
    extra = {}
    if order is not None:
        extra["order"] = order
    a = xp.asarray(a, dtype=dtype, **extra)

    # returns
    if unit.is_unitless:
        return a
    return Quantity(a, unit=unit)


array = asarray


@set_module_as('saiunit.math')
def arange(
    start: Optional[Union[Quantity, ArrayLike]] = None,
    stop: Optional[Union[Quantity, ArrayLike]] = None,
    step: Optional[Union[Quantity, ArrayLike]] = None,
    dtype: Optional[DTypeLike] = None
) -> Union[Quantity, Array]:
    """
    Return evenly spaced values within a given interval.

    Values are generated within the half-open interval ``[start, stop)``
    (in other words, the interval including ``start`` but excluding
    ``stop``).  All of ``start``, ``stop``, and ``step`` must share the
    same unit when any of them is a ``Quantity``.

    Parameters
    ----------
    start : Quantity or array_like, optional
        Start of the interval (inclusive).  The default start value is 0.
    stop : Quantity or array_like
        End of the interval (exclusive).
    step : Quantity or array_like, optional
        Spacing between values.  The default step size is 1.
    dtype : data-type, optional
        The type of the output array.  If not given, the dtype is inferred
        from the other input arguments.

    Returns
    -------
    out : Quantity or Array
        Array of evenly spaced values.

    Raises
    ------
    UnitMismatchError
        If ``start``, ``stop``, and ``step`` do not share the same unit.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> u.math.arange(5)
        Array([0, 1, 2, 3, 4], dtype=int32)
        >>> u.math.arange(0 * u.meter, 3 * u.meter, 1 * u.meter)
        Quantity([0 1 2], "m")
    """
    # apply maybe_custom_array to inputs
    start = maybe_custom_array(start) if start is not None else start
    stop = maybe_custom_array(stop) if stop is not None else stop
    step = maybe_custom_array(step) if step is not None else step

    # checking the dimension of the data
    non_none_data = [d for d in (start, stop, step) if d is not None]
    if len(non_none_data) == 0:
        raise ValueError('arange requires at least one of start, stop, or step to be provided.')
    d1 = non_none_data[0]
    for d2 in non_none_data[1:]:
        fail_for_unit_mismatch(
            d1,
            d2,
            error_message="Start value {d1} and stop value {d2} have to have the same units.",
            d1=d1,
            d2=d2
        )

    # convert to array
    unit = get_unit(d1)
    start = start.in_unit(unit).mantissa if isinstance(start, Quantity) else start
    stop = stop.in_unit(unit).mantissa if isinstance(stop, Quantity) else stop
    step = step.in_unit(unit).mantissa if isinstance(step, Quantity) else step
    # Build positional args without leading/trailing ``None``s. torch / dask /
    # ndonnx reject a ``None`` ``step`` positional that numpy and jax silently
    # treat as the default 1, and ndonnx additionally requires a non-``None``
    # ``stop``. Normalize the single-arg form ``arange(stop)`` to ``(0, stop)``
    # the way numpy does, then drop ``step`` when not given.
    if stop is None:
        head: Tuple[Any, ...] = (0, start)
    else:
        head = (start, stop)
    pos = head if step is None else (*head, step)
    xp = _default_xp()
    kwargs = {} if dtype is None else {"dtype": _translate_dtype(dtype, xp)}
    with jax.ensure_compile_time_eval():
        r = xp.arange(*pos, **kwargs)
    return r if unit.is_unitless else Quantity(r, unit=unit)


@set_module_as('saiunit.math')
def linspace(
    start: Union[Quantity, ArrayLike],
    stop: Union[Quantity, ArrayLike],
    num: int = 50,
    endpoint: Optional[bool] = True,
    retstep: Optional[bool] = False,
    dtype: Optional[DTypeLike] = None
) -> Union[Quantity, Array]:
    """
    Return evenly spaced numbers over a specified interval.

    Returns ``num`` evenly spaced samples, calculated over the interval
    ``[start, stop]``.  The endpoint of the interval can optionally be
    excluded.

    Parameters
    ----------
    start : Quantity or array_like
        The starting value of the sequence.
    stop : Quantity or array_like
        The end value of the sequence.  Must have the same unit as
        ``start`` when either is a ``Quantity``.
    num : int, optional
        Number of samples to generate.  Default is 50.
    endpoint : bool, optional
        If ``True``, ``stop`` is the last sample.  Otherwise, it is not
        included.  Default is ``True``.
    retstep : bool, optional
        If ``True``, return ``(samples, step)``, where ``step`` is the
        spacing between samples.
    dtype : data-type, optional
        The type of the output array.

    Returns
    -------
    samples : Quantity or Array
        ``num`` equally spaced samples in the closed interval
        ``[start, stop]`` or the half-open interval ``[start, stop)``.

    Raises
    ------
    UnitMismatchError
        If ``start`` and ``stop`` do not share the same unit.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> u.math.linspace(0, 10, 5)
        Array([ 0. ,  2.5,  5. ,  7.5, 10. ], dtype=float32)
        >>> u.math.linspace(0 * u.meter, 10 * u.meter, 5)
        Quantity([ 0.   2.5  5.   7.5 10. ], "m")
    """
    start = maybe_custom_array(start)
    stop = maybe_custom_array(stop)
    fail_for_unit_mismatch(
        start,
        stop,
        error_message="Start value {start} and stop value {stop} have to have the same units.",
        start=start,
        stop=stop,
    )
    unit = get_unit(start)
    start = start.in_unit(unit).mantissa if isinstance(start, Quantity) else start
    stop = stop.in_unit(unit).mantissa if isinstance(stop, Quantity) else stop
    with jax.ensure_compile_time_eval():
        xp = _default_xp()
        result = _safe_call_xp(
            xp.linspace, (start, stop),
            dict(num=num, endpoint=endpoint, retstep=retstep, dtype=dtype),
        )
    return result if unit.is_unitless else Quantity(result, unit=unit)


@set_module_as('saiunit.math')
def logspace(
    start: ArrayLike,
    stop: ArrayLike,
    num: Optional[int] = 50,
    endpoint: Optional[bool] = True,
    base: Optional[float] = 10.0,
    dtype: Optional[DTypeLike] = None
):
    """
    Return numbers spaced evenly on a log scale.

    In linear space, the sequence starts at ``base ** start`` and ends with
    ``base ** stop`` in ``num`` steps. Because ``base ** x`` is dimensionless,
    ``start`` and ``stop`` must be dimensionless and the result is a plain
    array (never a :class:`Quantity`).

    Parameters
    ----------
    start : array_like
        ``base ** start`` is the starting value of the sequence. Must be
        dimensionless.
    stop : array_like
        ``base ** stop`` is the final value of the sequence (unless
        ``endpoint`` is ``False``). Must be dimensionless.
    num : int, optional
        Number of samples to generate.  Default is 50.
    endpoint : bool, optional
        If ``True``, ``stop`` is the last sample.  Otherwise, it is not
        included.  Default is ``True``.
    base : float, optional
        The base of the log space.  Default is 10.0.
    dtype : data-type, optional
        The type of the output array.

    Returns
    -------
    samples : Array
        ``num`` samples, equally spaced on a log scale.

    Raises
    ------
    UnitMismatchError
        If ``start`` or ``stop`` carries a non-trivial unit.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> u.math.logspace(0, 2, 4)
        Array([  1.       ,   4.6415887,  21.544348 , 100.       ], dtype=float32)
    """
    start = maybe_custom_array(start)
    stop = maybe_custom_array(stop)
    for argname, value in (("start", start), ("stop", stop)):
        u = get_unit(value)
        if not u.is_unitless:
            raise UnitMismatchError(
                f"logspace requires dimensionless `{argname}`, got unit {u!r}. "
                f"`base ** x` is intrinsically dimensionless; pass a plain "
                f"scalar/array instead.",
                u,
            )
    start = start.mantissa if isinstance(start, Quantity) else start
    stop = stop.mantissa if isinstance(stop, Quantity) else stop
    with jax.ensure_compile_time_eval():
        xp = _default_xp()
        return _safe_call_xp(
            xp.logspace, (start, stop),
            dict(num=num, endpoint=endpoint, base=base, dtype=dtype),
        )


@set_module_as('saiunit.math')
def fill_diagonal(
    a: Union[Quantity, ArrayLike],
    val: Union[Quantity, ArrayLike],
    wrap: Optional[bool] = False,
    inplace: Optional[bool] = False
) -> Union[Quantity, Array]:
    """
    Fill the main diagonal of the given array of any dimensionality.

    For an array ``a`` with ``a.ndim >= 2``, the diagonal is the list of
    locations with indices ``a[i, i, ..., i]`` all identical.

    Parameters
    ----------
    a : Quantity or array_like
        Array in which to fill the diagonal.
    val : Quantity or array_like
        Value to be written on the diagonal.  Its unit must be compatible
        with that of ``a``.
    wrap : bool, optional
        If ``True``, the diagonal is "wrapped" after ``a.shape[1]`` and
        continues in the first column (for tall matrices).  Default is
        ``False``.
    inplace : bool, optional
        If ``True``, the diagonal is filled in-place.  Default is ``False``.

    Returns
    -------
    out : Quantity or Array
        The input array with the diagonal filled.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> import jax.numpy as jnp
        >>> u.math.fill_diagonal(jnp.zeros((3, 3)), 5.0)
        Array([[5., 0., 0.],
               [0., 5., 0.],
               [0., 0., 5.]], dtype=float32)
    """
    a = maybe_custom_array(a)
    val = maybe_custom_array(val)
    xp = get_backend(a.mantissa if isinstance(a, Quantity) else a)
    if isinstance(val, Quantity):
        if isinstance(a, Quantity):
            val = val.in_unit(a.unit)
            return Quantity(
                _safe_call_xp(xp.fill_diagonal, (a.mantissa, val.mantissa, wrap), dict(inplace=inplace)),
                unit=a.unit,
            )
        else:
            return Quantity(
                _safe_call_xp(xp.fill_diagonal, (a, val.mantissa, wrap), dict(inplace=inplace)),
                unit=val.unit,
            )
    else:
        if isinstance(a, Quantity):
            return Quantity(
                _safe_call_xp(xp.fill_diagonal, (a.mantissa, val, wrap), dict(inplace=inplace)),
                unit=a.unit,
            )
        else:
            return _safe_call_xp(xp.fill_diagonal, (a, val, wrap), dict(inplace=inplace))


@set_module_as('saiunit.math')
def meshgrid(
    *xi: Union[Quantity, ArrayLike],
    copy: Optional[bool] = True,
    sparse: Optional[bool] = False,
    indexing: Optional[str] = 'xy'
) -> List[Union[Quantity, Array]]:
    """
    Return coordinate matrices from coordinate vectors.

    Make N-D coordinate arrays for vectorized evaluations of N-D
    scalar/vector fields over N-D grids, given one-dimensional coordinate
    arrays ``x1, x2, ..., xn``.

    Parameters
    ----------
    xi : Quantity or array_like
        1-D arrays representing the coordinates of a grid.
    copy : bool, optional
        Must be ``True`` (the default).  JAX does not support
        ``copy=False``.
    sparse : bool, optional
        If ``True``, return a sparse grid instead of a dense grid.
    indexing : {'xy', 'ij'}, optional
        Cartesian (``'xy'``, default) or matrix (``'ij'``) indexing of
        output.

    Returns
    -------
    X1, X2, ..., XN : list of Quantity or Array
        Coordinate matrices.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> import jax.numpy as jnp
        >>> x, y = u.math.meshgrid(jnp.array([1, 2]), jnp.array([3, 4]))
        >>> x
        Array([[1, 2],
               [1, 2]], dtype=int32)
    """

    # Apply maybe_custom_array to inputs before processing
    xi = tuple(maybe_custom_array(x) for x in xi)
    args = [asarray(x) for x in xi]
    if not copy:
        raise ValueError("jax.numpy.meshgrid only supports copy=True")
    if indexing not in ["xy", "ij"]:
        raise ValueError(f"Valid values for indexing are 'xy' and 'ij', got {indexing}")
    if any(a.ndim != 1 for a in args):
        raise ValueError("Arguments to jax.numpy.meshgrid must be 1D, got shapes "
                         f"{[a.shape for a in args]}")
    if indexing == "xy" and len(args) >= 2:
        args[0], args[1] = args[1], args[0]
    shape = [1 if sparse else a.shape[0] for a in args]
    f_shape = lambda i, a: [*shape[:i], a.shape[0], *shape[i + 1:]] if sparse else shape

    def _broadcast_in_dim(x, target_shape, i):
        xp = get_backend(x)
        reshape_shape = [1] * len(args)
        reshape_shape[i] = x.shape[0]
        return xp.broadcast_to(xp.reshape(x, reshape_shape), target_shape)

    # use ``_tree.map`` to be Quantity-aware (Quantity is a registered pytree
    # when JAX is installed; the fallback ``_tree`` only descends standard
    # containers so plain arrays are passed straight through).
    output = [
        _tree.map(lambda x: _broadcast_in_dim(x, f_shape(i, x), i), a)
        for i, a, in enumerate(args)
    ]
    if indexing == "xy" and len(args) >= 2:
        output[0], output[1] = output[1], output[0]
    return output


@set_module_as('saiunit.math')
def vander(
    x: Union[Quantity, ArrayLike],
    N: Optional[bool] = None,
    increasing: Optional[bool] = False,
    unit: Unit = UNITLESS
) -> Union[Quantity, Array]:
    """
    Generate a Vandermonde matrix.

    The columns of the output matrix are powers of the input vector.

    Parameters
    ----------
    x : Quantity or array_like
        1-D input array.  Must be dimensionless if a ``Quantity``.
    N : int, optional
        Number of columns in the output.  If ``N`` is not specified, a
        square array is returned (``N = len(x)``).
    increasing : bool, optional
        Order of the powers of the columns.  If ``True``, the powers
        increase from left to right; if ``False`` (the default), they are
        reversed.
    unit : Unit, optional
        Unit of the returned ``Quantity``.

    Returns
    -------
    out : Quantity or Array
        Vandermonde matrix.

    Raises
    ------
    TypeError
        If ``x`` carries a non-trivial unit.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> import jax.numpy as jnp
        >>> u.math.vander(jnp.array([1, 2, 3]), 3)
        Array([[1, 1, 1],
               [4, 2, 1],
               [9, 3, 1]], dtype=int32)
    """
    x = maybe_custom_array(x)
    if isinstance(x, Quantity):
        if not x.is_unitless:
            raise TypeError(
                f'vander requires "x" to be dimensionless, '
                f'but got x with unit={x.unit}. '
                f'Pass "unit_to_scale" or strip the unit before calling vander.'
            )
        x = x.mantissa
    r = get_backend(x).vander(x, N=N, increasing=increasing)
    if not isinstance(unit, Unit):
        raise TypeError(f'vander requires "unit" to be a Unit instance, got {type(unit).__name__}: {unit!r}.')
    if not unit.is_unitless:
        return Quantity(r, unit=unit)
    else:
        return r


# indexing funcs
# --------------

[docs] def tril_indices(n, k=0, m=None): xp = _default_xp() name = getattr(xp, '__name__', '') # torch's binding spells the signature ``tril_indices(row, col, offset)``; # array-API / numpy / jax / dask spell it ``(n, k=0, m=None)``. if 'torch' in name: return xp.tril_indices(n, n if m is None else m, offset=k) # ndonnx has no ``tril_indices``; compute the static index pair on numpy # and wrap with ``xp.asarray`` so the caller receives backend-native arrays. if 'ndonnx' in name: rows, cols = np.tril_indices(n, k=k, m=m) return (xp.asarray(rows), xp.asarray(cols)) return _safe_call_xp(xp.tril_indices, (n,), {'k': k, 'm': m})
@set_module_as('saiunit.math') def tril_indices_from( arr: Union[Quantity, ArrayLike], k: Optional[int] = 0 ) -> Tuple[Array, Array]: """ Return the indices for the lower-triangle of an ``(n, m)`` array. Parameters ---------- arr : Quantity or array_like The array for which the returned indices will be valid. k : int, optional Diagonal offset. ``k = 0`` is the main diagonal, ``k < 0`` is below, and ``k > 0`` is above. Returns ------- out : tuple of Array Row and column indices for the lower triangle. Examples -------- .. code-block:: python >>> import saiunit as u >>> import jax.numpy as jnp >>> row, col = u.math.tril_indices_from(jnp.ones((3, 3))) >>> row Array([0, 1, 1, 2, 2, 2], dtype=int32) """ arr = maybe_custom_array(arr) inner = arr.mantissa if isinstance(arr, Quantity) else arr return get_backend(inner).tril_indices_from(inner, k=k)
[docs] def triu_indices(n, k=0, m=None): xp = _default_xp() name = getattr(xp, '__name__', '') if 'torch' in name: return xp.triu_indices(n, n if m is None else m, offset=k) if 'ndonnx' in name: rows, cols = np.triu_indices(n, k=k, m=m) return (xp.asarray(rows), xp.asarray(cols)) return _safe_call_xp(xp.triu_indices, (n,), {'k': k, 'm': m})
@set_module_as('saiunit.math') def triu_indices_from( arr: Union[Quantity, ArrayLike], k: Optional[int] = 0 ) -> Tuple[Array, Array]: """ Return the indices for the upper-triangle of an ``(n, m)`` array. Parameters ---------- arr : Quantity or array_like The array for which the returned indices will be valid. k : int, optional Diagonal offset. ``k = 0`` is the main diagonal, ``k < 0`` is below, and ``k > 0`` is above. Returns ------- out : tuple of Array Row and column indices for the upper triangle. Examples -------- .. code-block:: python >>> import saiunit as u >>> import jax.numpy as jnp >>> row, col = u.math.triu_indices_from(jnp.ones((3, 3))) >>> row Array([0, 0, 0, 1, 1, 2], dtype=int32) """ arr = maybe_custom_array(arr) inner = arr.mantissa if isinstance(arr, Quantity) else arr return get_backend(inner).triu_indices_from(inner, k=k) # --- others --- @set_module_as('saiunit.math') def from_numpy( x: np.ndarray, unit: Unit = UNITLESS ) -> Array | Quantity: """ Convert a NumPy array to a JAX array, optionally attaching a unit. Parameters ---------- x : numpy.ndarray The NumPy array to convert. unit : Unit, optional Unit of the returned ``Quantity``. When ``UNITLESS`` (the default) a plain JAX array is returned. Returns ------- out : Quantity or Array JAX array (or ``Quantity``) created from ``x``. Examples -------- .. code-block:: python >>> import saiunit as u >>> import numpy as np >>> u.math.from_numpy(np.array([1.0, 2.0]), unit=u.meter) Quantity([1. 2.], "m") """ x = maybe_custom_array(x) if not isinstance(unit, Unit): raise TypeError(f'from_numpy requires "unit" to be a Unit instance, got {type(unit).__name__}: {unit!r}.') xp = _default_xp() if not unit.is_unitless: return xp.asarray(x) * unit return xp.asarray(x) @set_module_as('saiunit.math') def as_numpy(x): """ Convert a JAX array (or ``Quantity``) to a NumPy array. Parameters ---------- x : Quantity or array_like The input to convert. If ``x`` is a ``Quantity``, the underlying mantissa (in current unit scale) is returned as a NumPy array. Returns ------- out : numpy.ndarray NumPy array representation of ``x``. Examples -------- .. code-block:: python >>> import saiunit as u >>> u.math.as_numpy(u.math.ones((3,))) array([1., 1., 1.], dtype=float32) """ x = maybe_custom_array(x) return np.array(x) @set_module_as('saiunit.math') def tree_zeros_like(tree): """ Create a tree with the same structure as the input, but with zeros in each leaf. Parameters ---------- tree : pytree A JAX-compatible pytree (nested dicts, lists, tuples, etc.) whose leaves are arrays or ``Quantity`` objects. Returns ------- out : pytree A tree with the same structure, where every leaf is replaced by a zero-filled array (or ``Quantity``) of the same shape, dtype, and unit. Examples -------- .. code-block:: python >>> import saiunit as u >>> import jax.numpy as jnp >>> tree = {'a': jnp.array([1.0, 2.0]), 'b': jnp.array([3.0])} >>> u.math.tree_zeros_like(tree) {'a': Array([0., 0.], dtype=float32), 'b': Array([0.], dtype=float32)} """ tree = maybe_custom_array_tree(tree) return _tree.map(zeros_like, tree) @set_module_as('saiunit.math') def tree_ones_like(tree): """ Create a tree with the same structure as the input, but with ones in each leaf. Parameters ---------- tree : pytree A JAX-compatible pytree (nested dicts, lists, tuples, etc.) whose leaves are arrays or ``Quantity`` objects. Returns ------- out : pytree A tree with the same structure, where every leaf is replaced by a ones-filled array (or ``Quantity``) of the same shape, dtype, and unit. Examples -------- .. code-block:: python >>> import saiunit as u >>> import jax.numpy as jnp >>> tree = {'a': jnp.array([1.0, 2.0]), 'b': jnp.array([3.0])} >>> u.math.tree_ones_like(tree) {'a': Array([1., 1.], dtype=float32), 'b': Array([1.], dtype=float32)} """ tree = maybe_custom_array_tree(tree) return _tree.map(ones_like, tree)