Quantity#

class saiunit.Quantity(mantissa, unit=Unit('1'), dtype=None)#

A numerical value paired with a physical unit.

Quantity is the central data structure in saiunit. It stores a mantissa (the raw numerical data, typically a JAX array) together with a Unit that describes the physical dimensions and scale. Arithmetic on Quantity objects automatically tracks and checks units, raising UnitMismatchError when incompatible quantities are combined.

Quantity is registered as a JAX pytree, so it works transparently with jax.jit, jax.grad, jax.vmap, and other JAX transformations.

Parameters:
  • mantissa (Union[Any, saiunit.Unit]) – The numerical value(s). If a Unit is passed, the mantissa is set to 1.0 and that unit is adopted. If a Quantity is passed, its mantissa and unit are used (converted to unit when given).

  • unit (Union[saiunit.Unit, str, None]) – The physical unit. Defaults to UNITLESS.

  • dtype (Union[str, type[Any], dtype, SupportsDType, None]) – If provided, the mantissa is cast to this dtype on construction.

Examples

>>> import saiunit as u
>>> import jax.numpy as jnp
>>> # Scalar with unit
>>> q = u.Quantity(3.0, unit=u.mV)
>>> q
Quantity(3., "mV")
>>> # Array with unit via multiplication shorthand
>>> arr = jnp.array([1.0, 2.0, 3.0]) * u.mV
>>> arr.shape
(3,)
>>> # From a Unit object directly
>>> u.Quantity(u.metre)
Quantity(1., "m")

See also

Unit

Represents a physical unit (dimension + scale).

astype(dtype)[source]#

Return a copy of this quantity with the mantissa cast to dtype.

Parameters:

dtype (Union[str, type[Any], dtype, SupportsDType]) – Target data type (e.g. jnp.float64).

Returns:

A new quantity with the converted dtype.

Return type:

Quantity

Examples

>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0]), unit=u.mV)
>>> q.astype(jnp.float64).dtype
float64
property at#

Helper property for index update functionality.

The at property provides a functionally pure equivalent of in-place array modifications.

In particular:

Alternate syntax                Equivalent in-place expression
x = x.at[idx].set(y)            x[idx] = y
x = x.at[idx].add(y)            x[idx] += y
x = x.at[idx].multiply(y)       x[idx] *= y
x = x.at[idx].divide(y)         x[idx] /= y
x = x.at[idx].power(y)          x[idx] **= y
x = x.at[idx].min(y)            x[idx] = minimum(x[idx], y)
x = x.at[idx].max(y)            x[idx] = maximum(x[idx], y)
x = x.at[idx].apply(ufunc)      ufunc.at(x, idx)
x = x.at[idx].get()             x = x[idx]

None of the x.at expressions modify the original x; instead they return a modified copy of x. However, inside a jit() compiled function, expressions like x = x.at[idx].set(y) are guaranteed to be applied in-place.
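
As a rough illustration of the "modified copy" semantics described above, the following NumPy-only sketch models x.at[idx].set(y) with a hypothetical helper named functional_set. It is not saiunit's implementation, only a model of the observable behaviour.

```python
import numpy as np


def functional_set(x, idx, y):
    """Hypothetical helper: emulate ``x.at[idx].set(y)`` without mutating ``x``."""
    out = x.copy()   # work on a copy so the caller's array is untouched
    out[idx] = y     # the in-place write happens only on the copy
    return out


x = np.arange(5.0)
updated = functional_set(x, 2, 10.0)
```

After the call, updated holds the new value at position 2 while x still holds its original contents.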

Quantity.at is multi-backend: it works for numpy, jax, cupy, torch, and dask mantissas. The ndonnx backend cannot represent functional in-place updates in its symbolic graph and raises saiunit.BackendError — call .to_numpy() first. On dask the update is expressed via da.where so the task graph stays lazy; only slice / scalar-int / 1-D integer / boolean-mask indices are supported on dask (multi-dim fancy-integer indexing raises NotImplementedError — use .to_numpy() for that case).

Repeated-index semantics differ across backends. When multiple indices refer to the same location, JAX applies all updates, whereas the plain NumPy in-place statement x[idx] += y applies only the last one. Quantity.at behaves per backend as follows:

Backend   add                                     multiply / divide / min / max / apply
jax       accumulates                             accumulates
numpy     accumulates (np.add.at)                 accumulates (np.<op>.at)
cupy      accumulates                             accumulates
torch     accumulates (index_put_(accumulate))    last-write-wins
dask      last-write-wins via mask                last-write-wins via mask
ndonnx    raises BackendError                     raises BackendError
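
The difference between "accumulates" and "last-write-wins" can be seen with plain NumPy arrays, independent of saiunit. The sketch below contrasts the buffered in-place statement with np.add.at, which the summary above names as the NumPy accumulation path.

```python
import numpy as np

idx = np.array([0, 0, 1])   # index 0 appears twice

# Buffered in-place add: a repeated index is written only once.
buffered = np.zeros(3)
buffered[idx] += 1.0

# Unbuffered add: every occurrence of an index contributes.
accumulated = np.zeros(3)
np.add.at(accumulated, idx, 1.0)
```

Here buffered ends with 1.0 at position 0, while accumulated ends with 2.0 there.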

By default, JAX assumes that all indices are in-bounds. Alternative out-of-bound semantics can be specified via the mode parameter (see below). On non-JAX backends, mode and fill_value are emulated for scalar-int and 1-D integer-array indices; for slice / boolean / ellipsis / tuple indices both kwargs are silently ignored because out-of-bounds cannot occur for same-shape sources. indices_are_sorted and unique_indices are hints only and are silently ignored outside JAX.

Parameters:
  • mode (str) –

    Specify out-of-bound indexing mode. Options are:

    • "promise_in_bounds": (default) The user promises that indices are in bounds. No additional checking will be performed. In practice, this means that out-of-bounds indices in get() will be clipped, and out-of-bounds indices in set(), add(), etc. will be dropped.

    • "clip": clamp out of bounds indices into valid range.

    • "drop": ignore out-of-bound indices.

    • "fill": alias for "drop". For get(), the optional fill_value argument specifies the value that will be returned.

  • indices_are_sorted (bool) – If True, the implementation will assume that the indices passed to at[] are sorted in ascending order, which can lead to more efficient execution on some backends.

  • unique_indices (bool) – If True, the implementation will assume that the indices passed to at[] are unique, which can result in more efficient execution on some backends.

  • fill_value (Any) – Only applies to the get() method: the fill value to return for out-of-bounds slices when mode is 'fill'. Ignored otherwise. Defaults to NaN for inexact types, the largest negative value for signed types, the largest positive value for unsigned types, and True for booleans.
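
To make the mode options concrete, here is a minimal NumPy sketch of how "clip", "fill", and "drop" could be emulated for 1-D integer indices. The helpers get_with_mode and set_with_drop are hypothetical names introduced for illustration; they are an assumption about the semantics, not saiunit's or JAX's actual code.

```python
import numpy as np


def _normalize(indices, n):
    """Wrap negative indices the way ordinary indexing does."""
    indices = np.asarray(indices)
    return np.where(indices < 0, indices + n, indices)


def get_with_mode(x, indices, mode="clip", fill_value=np.nan):
    """Hypothetical helper: out-of-bounds reads for 1-D integer indices."""
    n = x.shape[0]
    idx = _normalize(indices, n)
    values = x[np.clip(idx, 0, n - 1)]      # "clip": clamp into the valid range
    if mode == "clip":
        return values
    if mode == "fill":
        out_of_bounds = (idx < 0) | (idx >= n)
        return np.where(out_of_bounds, fill_value, values)
    raise ValueError(f"unsupported mode: {mode!r}")


def set_with_drop(x, indices, y):
    """Hypothetical helper: writes to out-of-bounds indices are dropped."""
    n = x.shape[0]
    idx = _normalize(indices, n)
    in_bounds = (idx >= 0) & (idx < n)
    out = x.copy()
    out[idx[in_bounds]] = np.broadcast_to(y, idx.shape)[in_bounds]
    return out


x = np.arange(5.0)
clipped = get_with_mode(x, [1, 7])                               # index 7 reads x[4]
filled = get_with_mode(x, [1, 7], mode="fill", fill_value=-1.0)  # index 7 reads fill_value
dropped = set_with_drop(x, [1, 7], 99.0)                         # write to index 7 is ignored
```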

Examples

>>> import saiunit as u
>>> import jax.numpy as jnp
>>> x = jnp.arange(5.0) * u.mV
>>> x.at[2].add(10 * u.mV)
Quantity([ 0.  1. 12.  3.  4.], "mV")
>>> x.at[2].get()
Quantity(2., "mV")
property backend: str#

The backend of the underlying mantissa: one of 'numpy', 'jax', 'cupy', 'torch', 'dask', 'ndonnx'.

clip(min=None, max=None)[source]#

Clip (limit) the values in the array to [min, max].

At least one of min or max must be given. Both must be compatible with the unit of self.

Parameters:
  • min (Quantity | Array | ndarray | number | bool | None) – Minimum value.

  • max (Quantity | Array | ndarray | number | bool | None) – Maximum value.

Returns:

The clipped quantity.

Return type:

Quantity

Examples

>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0, 3.0]), unit=u.mV)
>>> q.clip(min=u.Quantity(1.5, unit=u.mV), max=u.Quantity(2.5, unit=u.mV))
Quantity([1.5 2.  2.5], "mV")
clone()[source]#

Return a copy of this quantity (PyTorch-style alias for copy()).

Returns:

An independent copy.

Return type:

Quantity

Examples

>>> import saiunit as u
>>> q = u.Quantity(3.0, unit=u.mV)
>>> q.clone()
Quantity(3., "mV")
conj()[source]#

Return the complex conjugate, element-wise, preserving units.

Returns:

The conjugated quantity.

Return type:

Quantity

Examples

>>> import saiunit as u
>>> q = u.Quantity(1.0 + 2.0j, unit=u.mV)
>>> q.conj()
Quantity((1-2j), "mV")
conjugate()[source]#

Return the complex conjugate, element-wise.

Alias for conj().

Returns:

The conjugated quantity.

Return type:

Quantity

copy()[source]#

Return a deep copy of this quantity.

Returns:

An independent copy with the same mantissa and unit.

Return type:

Quantity

Examples

>>> import saiunit as u
>>> q = u.Quantity(3.0, unit=u.mV)
>>> q2 = q.copy()
>>> q2
Quantity(3., "mV")
cross(b, axisa=-1, axisb=-1, axisc=-1, axis=None)[source]#

Cross product of two arrays.

The resulting unit is self.unit * b.unit.
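
The unit rule for products can be pictured as adding exponents per unit symbol. The sketch below uses a hypothetical helper, multiply_units, that represents a unit as a plain {symbol: exponent} mapping; saiunit's own Unit class is richer than this model.

```python
from collections import Counter


def multiply_units(a, b):
    """Hypothetical helper: multiply two units given as {symbol: exponent} maps."""
    total = Counter(a)
    total.update(b)  # Counter.update adds exponents for shared symbols
    return {sym: exp for sym, exp in total.items() if exp != 0}


mV_times_s = multiply_units({"mV": 1}, {"s": 1})   # distinct symbols are kept side by side
mV_squared = multiply_units({"mV": 1}, {"mV": 1})  # a shared symbol has its exponents summed
```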

Parameters:
  • b (Quantity) – Second operand.

  • axisa (int) – Axis of self that defines the vector(s) (default -1).

  • axisb (int) – Axis of b that defines the vector(s) (default -1).

  • axisc (int) – Axis of the result containing the cross product (default -1).

  • axis (int | None) – Overrides axisa, axisb, and axisc simultaneously.

Returns:

The cross product.

Return type:

Quantity

Examples

>>> import saiunit as u
>>> import jax.numpy as jnp
>>> a = u.Quantity(jnp.array([1.0, 0.0, 0.0]), unit=u.mV)
>>> b = u.Quantity(jnp.array([0.0, 1.0, 0.0]), unit=u.second)
>>> a.cross(b)
Quantity([0. 0. 1.], "mV * s")
cumprod(*args, **kwds)[source]#

Return the cumulative product of elements along a given axis.

Because each position in the result corresponds to a different number of multiplied elements, the unit exponent varies across the output. This is only representable when the quantity is dimensionless.
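
A small pure-Python sketch shows why a single unit cannot describe the result. It assumes three mantissas that all carry the same unit (say mV) and tracks the unit exponent that each output position would need.

```python
from itertools import accumulate
from operator import mul

values = [2.0, 3.0, 4.0]   # mantissas, all assumed to be in the same unit

running = list(accumulate(values, mul))       # cumulative product of the mantissas
exponents = list(range(1, len(values) + 1))   # unit exponent needed at each position
```

The positions would need unit^1, unit^2, and unit^3 respectively, so only a dimensionless input yields a consistent result.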

Returns:

The cumulative product.

Return type:

Quantity

Raises:

TypeError – If the quantity is not dimensionless.

diagonal(offset=0, axis1=0, axis2=1)[source]#

Return specified diagonals, preserving units.

Parameters:
  • offset (int) – Offset from the main diagonal (default 0).

  • axis1 (int) – First axis (default 0).

  • axis2 (int) – Second axis (default 1).

Returns:

The diagonal elements.

Return type:

Quantity

Examples

>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([[1.0, 2.0], [3.0, 4.0]]), unit=u.mV)
>>> q.diagonal()
Quantity([1. 4.], "mV")
property dim: saiunit.Dimension#

The physical dimension of this quantity (e.g. length, mass, time).

The dimension is independent of scale (metres vs kilometres both have the length dimension).

Returns:

The physical dimension object.

Return type:

Dimension

Examples

>>> import saiunit as u
>>> q = u.Quantity(5.0, unit=u.metre)
>>> q.dim
m

See also

unit

The full unit (dimension + scale).

dot(b)[source]#

Dot product of two arrays.

The resulting unit is self.unit * b.unit.

Parameters:

b (Quantity or array_like) – Second operand.

Returns:

The dot product.

Return type:

Quantity

Examples

>>> import saiunit as u
>>> import jax.numpy as jnp
>>> a = u.Quantity(jnp.array([1.0, 2.0, 3.0]), unit=u.mV)
>>> b = u.Quantity(jnp.array([1.0, 1.0, 1.0]), unit=u.mV)
>>> a.dot(b)
Quantity(6., "mV^2")
property dtype#

The data type of the mantissa.

Returns:

The JAX/NumPy dtype of the underlying array.

Return type:

dtype

Examples

>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0]), unit=u.mV)
>>> q.dtype
float32
expand_as(array)[source]#

Expand this quantity to the shape of another array.

Parameters:

array (Quantity | Array | ndarray | number | bool) – The array whose shape is used as the target.

Returns:

expanded – A read-only view of the original array with the shape of array. It is typically not contiguous. Furthermore, more than one element of an expanded array may refer to a single memory location.

Return type:

Quantity

expand_dims(axis)[source]#

Insert new axes at the given positions.

Parameters:

axis (int | Sequence[int]) – Position(s) where the new axis (axes) are placed.

Returns:

The expanded quantity.

Return type:

Quantity

Examples

>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0]), unit=u.mV)
>>> q.expand_dims(0).shape
(1, 2)
factorless()[source]#

Return an equivalent quantity whose unit has factor == 1.0.

If the unit already has no extra factor the original object is returned unchanged.

Returns:

A quantity with the factor folded into the mantissa.

Return type:

Quantity

Examples

>>> import saiunit as u
>>> q = u.Quantity(3.0, unit=u.mV)
>>> q.factorless()
Quantity(3., "mV")
fill(value)[source]#

Fill the array with a scalar mantissa.

Return type:

Quantity

property flat#

1-D iterator over the mantissa elements, unit preserved.

flatten()[source]#

Return a 1-D copy of this quantity.

Returns:

Flattened quantity with the same unit.

Return type:

Quantity

Examples

>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([[1.0, 2.0], [3.0, 4.0]]), unit=u.mV)
>>> q.flatten()
Quantity([1. 2. 3. 4.], "mV")
has_same_unit(other)[source]#

Check whether this quantity shares the same physical dimension as other.

Two quantities that differ only in scale (e.g. mV vs V) are considered to have the same unit dimension.
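
The distinction between dimension and scale can be modelled in a few lines. SimpleUnit and same_dimension below are hypothetical names used only for this sketch; they are not part of saiunit.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class SimpleUnit:
    """Hypothetical unit model: a dimension label plus a scale relative to the base unit."""
    dim: str
    scale: float


def same_dimension(a, b):
    # Scale is deliberately ignored; only the dimension is compared.
    return a.dim == b.dim


mV = SimpleUnit("voltage", 1e-3)
V = SimpleUnit("voltage", 1.0)
s = SimpleUnit("time", 1.0)

mV_vs_V = same_dimension(mV, V)   # differ only in scale
mV_vs_s = same_dimension(mV, s)   # differ in dimension
```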

Parameters:

other (Quantity or Unit) – The object to compare with.

Returns:

True if both have identical physical dimensions.

Return type:

bool

Examples

>>> import saiunit as u
>>> a = u.Quantity(1.0, unit=u.mV)
>>> b = u.Quantity(2.0, unit=u.volt)
>>> a.has_same_unit(b)
True
>>> c = u.Quantity(1.0, unit=u.second)
>>> a.has_same_unit(c)
False
in_unit(unit, err_msg=None)[source]#

Convert this quantity to a compatible unit.

Behaves identically to to(); kept for API compatibility.

Parameters:
  • unit (saiunit.Unit) – Target unit. Must share the same dimension as self.unit.

  • err_msg (str | None) – Custom error message used when the dimensions do not match.

Returns:

A new Quantity expressed in unit.

Return type:

Quantity

Raises:

UnitMismatchError – If unit has a different dimension.

Examples

>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0, 3.0]), unit=u.mV)
>>> q.in_unit(u.volt)
Quantity([0.001 0.002 0.003], "V")
property is_unitless: bool#

True if this quantity is dimensionless (has no physical unit).

Returns:

Whether the quantity is unitless.

Return type:

bool

Examples

>>> import saiunit as u
>>> u.Quantity(5.0).is_unitless
True
>>> u.Quantity(5.0, unit=u.mV).is_unitless
False
item(*args)[source]#

Extract a single element as a scalar Quantity.

Parameters:

*args (int) – Index into the flat array.

Returns:

A 0-D Quantity containing the selected element.

Return type:

Quantity

Examples

>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([10.0, 20.0]), unit=u.mV)
>>> q.item(0)
Quantity(10., "mV")
property itemsize: int#

Length (in bytes) of one array element.

property mT: saiunit.Quantity#

Matrix transpose of the last two dimensions, preserving units.

The array must be at least 2-D.

Returns:

The matrix-transposed quantity.

Return type:

Quantity

Examples

>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([[1.0, 2.0], [3.0, 4.0]]), unit=u.mV)
>>> q.mT.shape
(2, 2)
property magnitude: Array | ndarray | number | bool#

Alias for mantissa.

Returns:

The raw numerical data of this quantity.

Return type:

array_like

Examples

>>> import saiunit as u
>>> q = u.Quantity(5.0, unit=u.metre)
>>> q.magnitude
5.0

See also

mantissa

Primary accessor for the numerical data.

property mantissa: Array | ndarray | number | bool#

The raw numerical data of this quantity (without the unit).

In scientific notation \(x = a \times 10^{b}\), the mantissa is the coefficient \(a\). For a Quantity, it is the underlying JAX/NumPy array (or Python scalar) that stores the numeric value.

Returns:

The mantissa array or scalar.

Return type:

array_like

Examples

>>> import saiunit as u
>>> q = u.Quantity(3.0, unit=u.mV)
>>> q.mantissa
3.0

See also

magnitude

Alias for mantissa.

unit

The physical unit attached to this quantity.

nancumprod(*args, **kwds)[source]#

Return the cumulative product of elements along a given axis, treating NaNs as ones.

Because each position in the result corresponds to a different number of multiplied elements, the unit exponent varies across the output. This is only representable when the quantity is dimensionless.

Returns:

The cumulative product (NaNs treated as ones).

Return type:

Quantity

Raises:

TypeError – If the quantity is not dimensionless.

nanprod(*args, **kwds)[source]#

Return the product of array elements over a given axis treating Not a Numbers (NaNs) as ones.

When reducing along a specific axis, the number of non-NaN elements must be the same for every position in the result so that a single unit exponent can be assigned. If the non-NaN counts differ and the quantity is not dimensionless, a ValueError is raised.
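
The uniformity requirement can be checked by counting the non-NaN entries that feed each output position. The following NumPy sketch illustrates the idea on a small array; it is an assumption about how such a check might look, not saiunit's internal code.

```python
import numpy as np

data = np.array([[2.0, np.nan],
                 [3.0, 4.0]])

# Number of non-NaN factors contributing to each output position (axis=0).
counts = np.sum(~np.isnan(data), axis=0)

# A single unit exponent exists only if every position multiplies the same
# number of elements.
uniform = bool(np.all(counts == counts.flat[0]))

products = np.nanprod(data, axis=0)   # NaNs are treated as ones
```

The first column multiplies two elements and the second only one, so the two results would need different unit exponents.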

Returns:

The product (NaNs treated as ones).

Return type:

Quantity

Raises:

ValueError – If the non-NaN counts are not uniform along the reduction axis for a non-dimensionless quantity.

property nbytes: int#

Total bytes consumed by the mantissa array.

outer(b)[source]#

Outer product of two 1-D arrays.

The resulting unit is self.unit * b.unit.

Parameters:

b (Quantity) – Second operand.

Returns:

The outer product matrix.

Return type:

Quantity

Examples

>>> import saiunit as u
>>> import jax.numpy as jnp
>>> a = u.Quantity(jnp.array([1.0, 2.0]), unit=u.mV)
>>> b = u.Quantity(jnp.array([3.0, 4.0]), unit=u.second)
>>> a.outer(b).shape
(2, 2)
pow(oc)[source]#

Raise this quantity to the power oc.

The exponent must be dimensionless. The resulting unit is self.unit ** oc.

Parameters:

oc (int, float, or dimensionless Quantity) – The exponent.

Returns:

self ** oc.

Return type:

Quantity

Examples

>>> import saiunit as u
>>> q = u.Quantity(2.0, unit=u.mV)
>>> q.pow(2)
Quantity(4., "mV^2")
prod(*args, **kwds)[source]#

Return the product of array elements over the given axis.

The unit of the result is self.unit ** n where n is the number of elements multiplied together.
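
As a sketch of how n follows from the array shape, the hypothetical helper prod_unit_exponent below returns the number of elements multiplied into each output position, assuming the usual NumPy reduction semantics for axis.

```python
import numpy as np


def prod_unit_exponent(shape, axis=None):
    """Hypothetical helper: number of elements multiplied per output position."""
    if axis is None:
        return int(np.prod(shape))   # full reduction multiplies every element
    return shape[axis]               # reduction along one axis multiplies shape[axis] elements


full = prod_unit_exponent((2, 3))
along_rows = prod_unit_exponent((2, 3), axis=0)
along_cols = prod_unit_exponent((2, 3), axis=1)
```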

Returns:

The product.

Return type:

Quantity

Examples

>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([2.0, 3.0]), unit=u.mV)
>>> q.prod()
Quantity(6., "mV^2")
put(indices, values)[source]#

Replaces specified elements of an array with given values.

Parameters:
  • indices (array_like) – Target indices, interpreted as integers.

  • values (array_like) – Values to place in the array at target indices.

Return type:

Quantity

repeat(repeats, axis=None)[source]#

Repeat elements of the array.

Parameters:
  • repeats (int or array of ints) – Number of repetitions for each element.

  • axis (int, optional) – Axis along which to repeat.

Returns:

The repeated quantity.

Return type:

Quantity

Examples

>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0]), unit=u.mV)
>>> q.repeat(2)
Quantity([1. 1. 2. 2.], "mV")
repr_in_unit(precision=None)[source]#

Return a human-readable string of this quantity in its current unit.

The format is "<value> <unit>", e.g. "3. mV" or "[1. 2. 3.] mV".

Parameters:

precision (int | None) – Number of significant digits. When None the value from numpy.get_printoptions is used.

Returns:

The formatted string.

Return type:

str

Examples

>>> import saiunit as u
>>> x = u.Quantity(25.0, unit=u.mV)
>>> x.repr_in_unit()
'25. mV'
>>> x.to(u.volt).repr_in_unit(3)
'0.025 V'
reshape(shape, order='C')[source]#

Return a quantity with the same data but a new shape.

Parameters:
  • shape (int or tuple of ints) – New shape.

  • order ({'C', 'F'}, optional) – Memory layout order (default 'C').

Returns:

Reshaped quantity.

Return type:

Quantity

Examples

>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0, 3.0]), unit=u.mV)
>>> q.reshape((3, 1)).shape
(3, 1)
resize(new_shape)[source]#

Change shape and size of array in-place.

Return type:

Quantity

round(decimals=0)[source]#

Evenly round the mantissa to the given number of decimals.

Parameters:

decimals (int) – Number of decimal places (default 0). Negative values round to positions left of the decimal point.
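
"Evenly round" refers to round-half-to-even. The NumPy sketch below shows that rule and a negative decimals value; it assumes Quantity.round follows the same convention as numpy.round on the mantissa.

```python
import numpy as np

# Ties are rounded to the nearest even value.
halves = np.round(np.array([0.5, 1.5, 2.5, 3.5]))

# Negative decimals round to positions left of the decimal point.
hundreds = np.round(np.array([1234.0, 1260.0]), decimals=-2)
```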

Returns:

A new quantity with the rounded mantissa.

Return type:

Quantity

Examples

>>> import saiunit as u
>>> q = u.Quantity(1.567, unit=u.mV)
>>> q.round(1)
Quantity(1.6, "mV")
scatter_add(index, value)[source]#

Return a copy with value added at index.

Parameters:
  • index (Array | ndarray | number | bool) – Target index (indices).

  • value (Quantity | Array | ndarray | number | bool) – The value to add. Must have the same unit dimension.

Returns:

A new quantity with the update applied.

Return type:

Quantity

Examples

>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0, 3.0]), unit=u.mV)
>>> q.scatter_add(0, u.Quantity(10.0, unit=u.mV))
Quantity([11.  2.  3.], "mV")
scatter_div(index, value)[source]#

Return a copy with the element at index divided by value.

value must be dimensionless (a pure scale factor).

Parameters:
  • index (Array | ndarray | number | bool) – Target index (indices).

  • value (Quantity | Array | ndarray | number | bool) – Dimensionless scale factor.

Returns:

A new quantity with the update applied.

Return type:

Quantity

Raises:

TypeError – If value is not dimensionless.

Examples

>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0, 3.0]), unit=u.mV)
>>> q.scatter_div(0, u.Quantity(2.0))
Quantity([0.5 2.  3. ], "mV")
scatter_max(index, value)[source]#

Return a copy where the element at index is the maximum of the current value and value.

Parameters:
  • index (Array | ndarray | number | bool) – Target index (indices).

  • value (Quantity | Array | ndarray | number | bool) – The comparison value. Must have the same unit dimension.

Returns:

A new quantity with the update applied.

Return type:

Quantity

Examples

>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0, 3.0]), unit=u.mV)
>>> q.scatter_max(0, u.Quantity(10.0, unit=u.mV))
Quantity([10.  2.  3.], "mV")
scatter_min(index, value)[source]#

Return a copy where the element at index is the minimum of the current value and value.

Parameters:
  • index (Array | ndarray | number | bool) – Target index (indices).

  • value (Quantity | Array | ndarray | number | bool) – The comparison value. Must have the same unit dimension.

Returns:

A new quantity with the update applied.

Return type:

Quantity

Examples

>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0, 3.0]), unit=u.mV)
>>> q.scatter_min(0, u.Quantity(0.5, unit=u.mV))
Quantity([0.5 2.  3. ], "mV")
scatter_mul(index, value)[source]#

Return a copy with the element at index multiplied by value.

value must be dimensionless (a pure scale factor).

Parameters:
  • index (Array | ndarray | number | bool) – Target index (indices).

  • value (Quantity | Array | ndarray | number | bool) – Dimensionless scale factor.

Returns:

A new quantity with the update applied.

Return type:

Quantity

Raises:

TypeError – If value is not dimensionless.

Examples

>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0, 3.0]), unit=u.mV)
>>> q.scatter_mul(0, u.Quantity(10.0))
Quantity([10.  2.  3.], "mV")
scatter_sub(index, value)[source]#

Return a copy with value subtracted at index.

Parameters:
  • index (Array | ndarray | number | bool) – Target index (indices).

  • value (Quantity | Array | ndarray | number | bool) – The value to subtract. Must have the same unit dimension.

Returns:

A new quantity with the update applied.

Return type:

Quantity

Examples

>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0, 3.0]), unit=u.mV)
>>> q.scatter_sub(0, u.Quantity(1.0, unit=u.mV))
Quantity([0. 2. 3.], "mV")
searchsorted(v, side='left', sorter=None)[source]#

Find indices where elements should be inserted to maintain order.

Return type:

Array

property shape: tuple[int, ...]#

The shape of the mantissa array.

Returns:

Shape tuple, identical to jnp.shape(self.mantissa).

Return type:

tuple of int

Examples

>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([[1.0, 2.0], [3.0, 4.0]]), unit=u.mV)
>>> q.shape
(2, 2)
sort(axis=-1, stable=True, order=None)[source]#

Sort the array in-place along the given axis.

Parameters:
  • axis (int, optional) – Axis along which to sort (default -1).

  • stable (bool, optional) – Whether to use a stable sort (default True).

  • order (str or list of str, optional) – Field ordering for structured arrays.

Returns:

self, with the mantissa sorted in-place.

Return type:

Quantity

Examples

>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([3.0, 1.0, 2.0]), unit=u.mV)
>>> q.sort()
Quantity([1. 2. 3.], "mV")
split(indices_or_sections, axis=0)[source]#

Split the array into multiple sub-arrays.

Parameters:
  • indices_or_sections (int or 1-D array) – If an integer N, the array is divided into N equal parts. If a sorted 1-D array of indices, the entries indicate split points along axis.

  • axis (int, optional) – Axis along which to split (default 0).
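
The two forms of indices_or_sections can be illustrated with numpy.split on a plain array. This assumes Quantity.split mirrors NumPy's convention, with each piece additionally carrying the unit.

```python
import numpy as np

arr = np.arange(6.0)

equal_parts = np.split(arr, 3)        # integer: three equal sections
at_indices = np.split(arr, [1, 4])    # indices: arr[:1], arr[1:4], arr[4:]
```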

Returns:

Sub-arrays, each carrying the same unit.

Return type:

list[Quantity]

Examples

>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0, 3.0]), unit=u.mV)
>>> parts = q.split(3)
>>> len(parts)
3
squeeze(axis=None)[source]#

Remove length-one axes from the array.

Parameters:

axis (int or tuple of ints, optional) – Axes to remove. If None, all length-one axes are removed.

Returns:

The squeezed quantity.

Return type:

Quantity

Examples

>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([[[1.0]]]), unit=u.mV)
>>> q.squeeze().shape
()
property strides#

Tuple of byte-steps in each dimension (mirrors numpy.ndarray.strides).

swapaxes(axis1, axis2)[source]#

Interchange two axes of the array.

Parameters:
  • axis1 (int) – First axis.

  • axis2 (int) – Second axis.

Returns:

The quantity with axes swapped.

Return type:

Quantity

Examples

>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([[1.0, 2.0], [3.0, 4.0]]), unit=u.mV)
>>> q.swapaxes(0, 1).shape
(2, 2)
take(indices, axis=None, mode=None, unique_indices=False, indices_are_sorted=False, fill_value=None)[source]#

Select elements from the array at the given indices.

Parameters:
  • indices (array_like) – Indices of the values to extract.

  • axis (int, optional) – Axis along which to take (default flattened).

  • mode (str, optional) – Out-of-bounds index handling.

  • unique_indices (bool, optional) – Hint that indices are unique.

  • indices_are_sorted (bool, optional) – Hint that indices are sorted.

  • fill_value (Quantity or scalar, optional) – Value for out-of-bounds positions when mode is 'fill'.

Returns:

The selected elements.

Return type:

Quantity

Examples

>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([10.0, 20.0, 30.0]), unit=u.mV)
>>> q.take(jnp.array([0, 2]))
Quantity([10. 30.], "mV")
tile(reps)[source]#

Construct an array by repeating this quantity.

Parameters:

reps (int or array_like) – Number of repetitions along each axis.

Returns:

The tiled quantity.

Return type:

Quantity

Examples

>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0]), unit=u.mV)
>>> q.tile(2)
Quantity([1. 2. 1. 2.], "mV")
to(new_unit)[source]#

Convert this quantity to a different (compatible) unit.

The mantissa is rescaled so that the physical value stays the same, and the returned Quantity carries new_unit.
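
The rescaling amounts to multiplying by the ratio of the two unit scales. The sketch below uses a hypothetical helper, convert, with scales expressed relative to the base unit (1e-3 for mV, 1.0 for V); it is a model of the arithmetic, not saiunit's implementation.

```python
def convert(value, from_scale, to_scale):
    """Hypothetical helper: rescale a mantissa between two units of one dimension."""
    # value * from_scale is the quantity in base units; dividing by to_scale
    # expresses that same physical value in the target unit.
    return value * from_scale / to_scale


millivolts = [1.0, 2.0, 3.0]
volts = [convert(v, from_scale=1e-3, to_scale=1.0) for v in millivolts]
```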

Parameters:

new_unit (saiunit.Unit) – Target unit. Must have the same dimension as self.unit.

Returns:

A new Quantity expressed in new_unit.

Return type:

Quantity

Raises:

UnitMismatchError – If new_unit has a different dimension.

Examples

>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0, 3.0]), unit=u.mV)
>>> q.to(u.volt)
Quantity([0.001 0.002 0.003], "V")

See also

in_unit

Identical behaviour; to() delegates to in_unit().

to_decimal

Convert to a plain number in the target unit.

to_cupy(*, device=None)[source]#

Return a new Quantity with mantissa converted to a cupy.ndarray.

No-op (returns self) if the mantissa is already a CuPy array and no device was specified.

Return type:

Quantity

to_dask(*, chunks='auto')[source]#

Return a new Quantity with mantissa converted to a dask.array.Array.

No-op (returns self) if the mantissa is already a dask array and no chunks was specified.

Return type:

Quantity

to_decimal(unit=Unit('1'))[source]#

Return the numerical value expressed in the given unit, without wrapping the result in a Quantity.

This is useful when you need a plain JAX array for downstream computation that does not support units.

Parameters:

unit (saiunit.Unit) – The reference unit. Defaults to UNITLESS.

Returns:

A plain number or JAX array representing the quantity in unit.

Return type:

Array | ndarray | number | bool

Examples

>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0, 3.0]), unit=u.mV)
>>> q.to_decimal(u.volt)
Array([0.001, 0.002, 0.003], dtype=float32)

See also

to

Convert while keeping the Quantity wrapper.

to_jax()[source]#

Return a new Quantity with mantissa converted to a jax.Array.

No-op (returns self) if the mantissa is already a JAX array.

Return type:

Quantity

to_ndonnx()[source]#

Return a new Quantity with mantissa converted to an ndonnx.Array.

No-op (returns self) if the mantissa is already an ndonnx array. ndonnx arrays are symbolic — operations build an ONNX graph rather than eagerly computing.

Return type:

Quantity

to_numpy()[source]#

Return a new Quantity with mantissa converted to numpy.ndarray.

No-op (returns self) if the mantissa is already a NumPy array.

Return type:

Quantity

to_torch(*, device=None, dtype=None)[source]#

Return a new Quantity with mantissa converted to a torch.Tensor.

No-op (returns self) if the mantissa is already a torch tensor and no device/dtype was specified. dtype accepts either a torch dtype (e.g. torch.float32) or a numpy dtype (e.g. np.float32).

Return type:

Quantity

tolist()[source]#

Convert the array to a (nested) Python list of Quantity scalars.

Each leaf element is a 0-D Quantity with the same unit.

Returns:

A nested list of scalar Quantity objects, or a single Quantity for 0-D arrays.

Return type:

list or Quantity

Examples

>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0]), unit=u.mV)
>>> q.tolist()
[Quantity(1., "mV"), Quantity(2., "mV")]
trace(offset=0, axis1=0, axis2=1)[source]#

Sum along diagonals of the array, preserving units.

Parameters:
  • offset (int) – Offset of the diagonal from the main diagonal (default 0).

  • axis1 (int) – First axis of the 2-D sub-arrays (default 0).

  • axis2 (int) – Second axis of the 2-D sub-arrays (default 1).

Returns:

The trace value(s).

Return type:

Quantity

Examples

>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.eye(3), unit=u.mV)
>>> q.trace()
Quantity(3., "mV")
transpose(*axes)[source]#

Return the array with axes transposed.

For a 2-D array this is the standard matrix transpose.

Parameters:

*axes (None, tuple of ints, or n ints) – If omitted, axes are reversed. Otherwise specifies the permutation.

Returns:

Transposed quantity.

Return type:

Quantity

Examples

>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([[1.0, 2.0], [3.0, 4.0]]), unit=u.mV)
>>> q.transpose().shape
(2, 2)
tree_flatten()[source]#

Flatten this quantity for JAX pytree handling.

Returns:

The data (mantissa) and the unit.

Return type:

tuple[tuple[Array | ndarray | number | bool], saiunit.Unit]

classmethod tree_unflatten(unit, values)[source]#

Rebuild a Quantity from its flattened pytree representation.

Parameters:
  • unit – The unit.

  • values – The data.

Returns:

The Quantity object.

Return type:

Quantity
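
The tree_flatten / tree_unflatten pair follows JAX's pytree protocol. The sketch below registers a hypothetical stand-in class, Tagged, with jax.tree_util.register_pytree_node_class to show how array data travels as a traced child while a static label travels as auxiliary data. It is an illustration of the protocol, not saiunit's source.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class
class Tagged:
    """Hypothetical stand-in for Quantity: array data plus a static label."""

    def __init__(self, data, label):
        self.data = data
        self.label = label

    def tree_flatten(self):
        # Children are traced by JAX; auxiliary data must be static and hashable.
        return (self.data,), self.label

    @classmethod
    def tree_unflatten(cls, label, children):
        (data,) = children
        return cls(data, label)


@jax.jit
def double(t):
    # The label passes through jit untouched; only the data is traced.
    return Tagged(t.data * 2.0, t.label)


out = double(Tagged(jnp.array([1.0, 2.0]), "mV"))
```

Because the label is auxiliary data, a change of label leads to a separate compilation under jit, whereas new data values of the same shape and dtype reuse the cached one.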

property unit: saiunit.Unit#

The Unit attached to this quantity.

The unit carries both the physical dimension and the scale factor (e.g. mV has dimension voltage with scale 1e-3).

Returns:

The unit of this quantity.

Return type:

Unit

Examples

>>> import saiunit as u
>>> q = u.Quantity(5.0, unit=u.mV)
>>> q.unit
mV

See also

dim

The physical dimension without scale information.

mantissa

The numerical value.

unsqueeze(axis)[source]#

Insert a length-one axis (PyTorch-style alias for expand_dims()).

Parameters:

axis (int) – Position where the new axis is inserted.

Returns:

The quantity with an extra dimension.

Return type:

Quantity

Examples

>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0]), unit=u.mV)
>>> q.unsqueeze(0).shape
(1, 2)
update_mantissa(mantissa)[source]#

Replace the mantissa in-place, keeping the same unit.

The new mantissa must have the same shape and dtype as the current one.

Parameters:

mantissa (Any) – The new numerical data. Must not be a Quantity.

Raises:

ValueError – If mantissa is a Quantity, or if shape/dtype do not match.

Return type:

None

Examples

>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0, 3.0]), unit=u.mV)
>>> q.update_mantissa(jnp.array([4.0, 5.0, 6.0]))
>>> q
Quantity([4. 5. 6.], "mV")
view(*args, dtype=None)[source]#

New view of array with the same data.

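This method follows the syntax of PyTorch's Tensor.view(). The description below is adapted from the PyTorch documentation, so "tensor" refers to this quantity's underlying array.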

Returns a new tensor with the same data as the self tensor but of a different shape.

The returned tensor shares the same data and must have the same number of elements, but may have a different size. For a tensor to be viewed, the new view size must be compatible with its original size and stride, i.e., each new view dimension must either be a subspace of an original dimension, or only span across original dimensions \(d, d+1, \dots, d+k\) that satisfy the following contiguity-like condition that \(\forall i = d, \dots, d+k-1\),

\[\text{stride}[i] = \text{stride}[i+1] \times \text{size}[i+1]\]

Otherwise, it will not be possible to view self tensor as shape without copying it (e.g., via contiguous()). When it is unclear whether a view() can be performed, it is advisable to use reshape(), which returns a view if the shapes are compatible, and copies (equivalent to calling contiguous()) otherwise.
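The contiguity-like condition above can be checked mechanically. The following pure-Python sketch is independent of saiunit; the helper names row_major_strides and can_merge are hypothetical and exist only to illustrate the formula, with strides counted in elements.

```python
def row_major_strides(shape):
    """Strides (in elements) of a contiguous row-major array of this shape."""
    strides = []
    step = 1
    for size in reversed(shape):
        strides.append(step)
        step *= size
    return tuple(reversed(strides))


def can_merge(shape, strides, d, k):
    """Check stride[i] == stride[i+1] * size[i+1] for i = d, ..., d+k-1."""
    return all(
        strides[i] == strides[i + 1] * shape[i + 1]
        for i in range(d, d + k)
    )


shape = (4, 4)
row_major = row_major_strides(shape)              # contiguous layout
mergeable = can_merge(shape, row_major, 0, 1)     # 4 == 1 * 4

transposed = (1, 4)                               # strides of the transposed layout
not_mergeable = can_merge(shape, transposed, 0, 1)  # 1 != 4 * 4

print(row_major, mergeable, not_mergeable)
```

A contiguous (4, 4) layout satisfies the condition, so both axes can be spanned by a single view dimension such as view(16); the transposed layout does not.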

Parameters:

*args (int...) – The desired shape, given as one or more ints.

Return type:

Quantity

Example:

>>> import jax.numpy as jnp, saiunit
>>> x = saiunit.Quantity(jnp.ones((4, 4)))
>>> x.shape
(4, 4)
>>> y = x.view(16)
>>> y.shape
(16,)
>>> z = x.view(2, 8)
>>> z.shape
(2, 8)
view(dtype) Tensor[source]

Returns a new tensor with the same data as the self tensor but of a different dtype.

If the element size of dtype is different than that of self.dtype, then the size of the last dimension of the output will be scaled proportionally. For instance, if dtype element size is twice that of self.dtype, then each pair of elements in the last dimension of self will be combined, and the size of the last dimension of the output will be half that of self. If dtype element size is half that of self.dtype, then each element in the last dimension of self will be split in two, and the size of the last dimension of the output will be double that of self. For this to be possible, the following conditions must be true:

  • self.dim() must be greater than 0.

  • self.stride(-1) must be 1.

Additionally, if the element size of dtype is greater than that of self.dtype, the following conditions must be true as well:

  • self.size(-1) must be divisible by the ratio between the element sizes of the dtypes.

  • self.storage_offset() must be divisible by the ratio between the element sizes of the dtypes.

  • The strides of all dimensions, except the last dimension, must be divisible by the ratio between the element sizes of the dtypes.

If any of the above conditions are not met, an error is thrown.
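The scaling of the last dimension can be seen with plain NumPy, whose ndarray.view(dtype) reinterprets the underlying bytes in the same way. This is an illustration of the element-size rule only, not a statement about how saiunit implements view().

```python
import numpy as np

x = np.zeros((4, 4), dtype=np.float32)  # 4 bytes per element

# complex64 has 8 bytes per element: pairs of floats are combined,
# so the last dimension is halved.
as_complex = x.view(np.complex64)

# uint8 has 1 byte per element: each float is split into four bytes,
# so the last dimension is multiplied by four.
as_bytes = x.view(np.uint8)

print(as_complex.shape, as_bytes.shape)
```

In both cases the total number of bytes is unchanged; only the element size and the length of the last axis differ.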

Parameters:

dtype (dtype) – the desired dtype

Example:

>>> import numpy
>>> import brainstate
>>> x = brainstate.random.randn(4, 4)
>>> x
Array([[ 0.9482, -0.0310,  1.4999, -0.5316],
        [-0.1520,  0.7472,  0.5617, -0.8649],
        [-2.4724, -0.0334, -0.2976, -0.8499],
        [-0.2109,  1.9913, -0.9607, -0.6123]])
>>> x.dtype
brainstate.math.float32

>>> y = x.view(numpy.int32)
>>> y
tensor([[ 1064483442, -1124191867,  1069546515, -1089989247],
        [-1105482831,  1061112040,  1057999968, -1084397505],
        [-1071760287, -1123489973, -1097310419, -1084649136],
        [-1101533110,  1073668768, -1082790149, -1088634448]],
    dtype=numpy.int32)
>>> y[0, 0] = 1000000000
>>> x
tensor([[ 0.0047, -0.0310,  1.4999, -0.5316],
        [-0.1520,  0.7472,  0.5617, -0.8649],
        [-2.4724, -0.0334, -0.2976, -0.8499],
        [-0.2109,  1.9913, -0.9607, -0.6123]])

>>> x.view(numpy.complex64)
tensor([[ 0.0047-0.0310j,  1.4999-0.5316j],
        [-0.1520+0.7472j,  0.5617-0.8649j],
        [-2.4724-0.0334j, -0.2976-0.8499j],
        [-0.2109+1.9913j, -0.9607-0.6123j]])
>>> x.view(numpy.complex64).shape
(4, 2)

>>> x.view(numpy.uint8)
tensor([[  0, 202, 154,  59, 182, 243, 253, 188, 185, 252, 191,  63, 240,  22,
             8, 191],
        [227, 165,  27, 190, 128,  72,  63,  63, 146, 203,  15,  63,  22, 106,
            93, 191],
        [205,  59,  30, 192, 112, 206,   8, 189,   7,  95, 152, 190,  12, 147,
            89, 191],
        [ 43, 246,  87, 190, 235, 226, 254,  63, 111, 240, 117, 191, 177, 191,
            28, 191]], dtype=uint8)
>>> x.view(numpy.uint8).shape
(4, 16)
static with_unit(mantissa, unit)[source]#

Create a Quantity from a raw value and a unit.

This is a convenience factory that reads more naturally in some contexts than the standard constructor.

Parameters:
  • mantissa (Any) – The numerical value(s).

  • unit (saiunit.Unit) – The physical unit.

Returns:

A new Quantity with the given mantissa and unit.

Return type:

Quantity

Examples

>>> import saiunit as u
>>> u.Quantity.with_unit(2.0, unit=u.metre)
Quantity(2., "m")