Quantity
- class saiunit.Quantity(mantissa, unit=Unit('1'), dtype=None)
A numerical value paired with a physical unit.
Quantity is the central data structure in saiunit. It stores a mantissa (the raw numerical data, typically a JAX array) together with a Unit that describes the physical dimensions and scale. Arithmetic on Quantity objects automatically tracks and checks units, raising UnitMismatchError when incompatible quantities are combined. Quantity is registered as a JAX pytree, so it works transparently with jax.jit, jax.grad, jax.vmap, and other JAX transformations.
- Parameters:
mantissa (Union[Any, saiunit.Unit]) – The numerical value(s). If a Unit is passed, the mantissa is set to 1.0 and that unit is adopted. If a Quantity is passed, its mantissa and unit are used (converted to unit when given).
unit (Union[saiunit.Unit, str, None]) – The physical unit. Defaults to UNITLESS.
dtype (Union[str, type[Any], dtype, SupportsDType, None]) – If provided, the mantissa is cast to this dtype on construction.
Examples
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> # Scalar with unit
>>> q = u.Quantity(3.0, unit=u.mV)
>>> q
Quantity(3., "mV")
>>> # Array with unit via multiplication shorthand
>>> arr = jnp.array([1.0, 2.0, 3.0]) * u.mV
>>> arr.shape
(3,)
>>> # From a Unit object directly
>>> u.Quantity(u.metre)
Quantity(1., "m")
See also
Unit – Represents a physical unit (dimension + scale).
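The following is a deliberately tiny, self-contained model of a value paired with a unit. It makes the unit-tracking rules above concrete and is not saiunit's implementation. It assumes that a unit can be modelled as a scale plus a tuple of SI base-dimension exponents. The names `ToyQuantity` and `ToyUnitMismatch` are hypothetical stand-ins for `Quantity` and `UnitMismatchError`. Addition requires identical dimensions. Multiplication adds exponents, and powers multiply them, which is why products such as `dot` or `cross` carry `self.unit * b.unit`.

```python
from dataclasses import dataclass

# Exponents of (metre, kilogram, second, ampere): a toy subset of SI base dimensions.
Dims = tuple[int, int, int, int]


class ToyUnitMismatch(Exception):
    """Hypothetical stand-in for saiunit.UnitMismatchError."""


@dataclass(frozen=True)
class ToyQuantity:
    mantissa: float
    scale: float  # e.g. 1e-3 for a "milli" unit
    dims: Dims

    def __add__(self, other: "ToyQuantity") -> "ToyQuantity":
        # Addition only makes sense for identical physical dimensions.
        if self.dims != other.dims:
            raise ToyUnitMismatch(f"cannot add {self.dims} and {other.dims}")
        # Re-express `other` in this quantity's scale before adding mantissas.
        converted = other.mantissa * other.scale / self.scale
        return ToyQuantity(self.mantissa + converted, self.scale, self.dims)

    def __mul__(self, other: "ToyQuantity") -> "ToyQuantity":
        # Multiplication multiplies scales and adds dimension exponents.
        dims = tuple(a + b for a, b in zip(self.dims, other.dims))
        return ToyQuantity(self.mantissa * other.mantissa, self.scale * other.scale, dims)

    def __pow__(self, n: int) -> "ToyQuantity":
        # Raising to a power multiplies every exponent by n.
        return ToyQuantity(self.mantissa ** n, self.scale ** n, tuple(n * d for d in self.dims))


METRE = (1, 0, 0, 0)
SECOND = (0, 0, 1, 0)

length = ToyQuantity(3.0, 1.0, METRE)
short = ToyQuantity(500.0, 1e-3, METRE)  # 500 mm
total = length + short                   # expressed in metres
area = length * length                   # metre exponent becomes 2
cubed = length ** 3                      # metre exponent becomes 3

try:
    length + ToyQuantity(1.0, 1.0, SECOND)
except ToyUnitMismatch as err:
    mismatch_message = str(err)
```

Keeping dimensions as integer exponents makes the compatibility check a simple tuple comparison, independent of scale.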
- astype(dtype)
Return a copy of this quantity with the mantissa cast to dtype.
- Parameters:
dtype (Union[str, type[Any], dtype, SupportsDType]) – Target data type (e.g. jnp.float16).
- Returns:
A new quantity with the converted dtype.
- Return type:
Quantity
Examples
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0]), unit=u.mV)
>>> q.astype(jnp.float16).dtype
dtype('float16')
- property at
Helper property for index update functionality.
The at property provides a functionally pure equivalent of in-place array modifications. In particular:

Alternate syntax                       Equivalent in-place expression
x = x.at[idx].set(y)                   x[idx] = y
x = x.at[idx].add(y)                   x[idx] += y
x = x.at[idx].multiply(y)              x[idx] *= y
x = x.at[idx].divide(y)                x[idx] /= y
x = x.at[idx].power(y)                 x[idx] **= y
x = x.at[idx].min(y)                   x[idx] = minimum(x[idx], y)
x = x.at[idx].max(y)                   x[idx] = maximum(x[idx], y)
x = x.at[idx].apply(ufunc)             ufunc.at(x, idx)
x = x.at[idx].get()                    x = x[idx]

None of the x.at expressions modify the original x; instead they return a modified copy of x. However, inside a jit() compiled function, expressions like x = x.at[idx].set(y) are guaranteed to be applied in-place.
Quantity.at is multi-backend: it works for numpy, jax, cupy, torch, and dask mantissas. The ndonnx backend cannot represent functional in-place updates in its symbolic graph and raises saiunit.BackendError; call .to_numpy() first. On dask the update is expressed via da.where so the task graph stays lazy. Only slice, scalar-int, 1-D integer, and boolean-mask indices are supported on dask; multi-dimensional fancy-integer indexing raises NotImplementedError, so use .to_numpy() for that case.
Repeated-index semantics differ across backends. When multiple indices refer to the same location, JAX applies all updates, whereas a plain NumPy in-place x[idx] += y would apply only the last. In summary:

Backend   add                                    multiply / divide / min / max / apply
jax       accumulates                            accumulates
numpy     accumulates (np.add.at)                accumulates (np.<op>.at)
cupy      accumulates                            accumulates
torch     accumulates (index_put_(accumulate))   last-write-wins
dask      last-write-wins via mask               last-write-wins via mask
ndonnx    raises BackendError                    raises BackendError
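The difference between "accumulates" and a single write can be reproduced on plain NumPy arrays, independently of saiunit. The sketch below applies an `add` with a repeated index in three ways. It uses a plain fancy-index `+=`, then `np.add.at` (the accumulating route the table names for the numpy backend), and finally a mask-plus-`np.where` update. The last one is only an analogue of the `da.where` approach described for dask, not saiunit's actual dask code.

```python
import numpy as np

idx = np.array([0, 0, 2])  # index 0 appears twice

# Plain fancy-index in-place update: the repeated index 0 is incremented only once.
fancy = np.zeros(4)
fancy[idx] += 1.0

# np.add.at applies every update, so repeated indices accumulate.
accumulated = np.zeros(4)
np.add.at(accumulated, idx, 1.0)

# Mask-based update (analogue of a where()-style lazy update): each position is touched at most once.
base = np.zeros(4)
mask = np.zeros(4, dtype=bool)
mask[idx] = True
masked = np.where(mask, base + 1.0, base)
```

Only the `np.add.at` result counts index 0 twice; the other two update it once.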
By default, JAX assumes that all indices are in-bounds. Alternative out-of-bound semantics can be specified via the mode parameter (see below). On non-JAX backends, mode and fill_value are emulated for scalar-int and 1-D integer-array indices. For slice, boolean, ellipsis, and tuple indices these keyword arguments are silently ignored, because out-of-bounds access cannot occur for same-shape sources. indices_are_sorted and unique_indices are hints only; they are silently ignored outside JAX.
- Parameters:
mode (str) – Specify the out-of-bound indexing mode. Options are:
  - "promise_in_bounds": (default) The user promises that indices are in bounds. No additional checking will be performed. In practice, this means that out-of-bounds indices in get() will be clipped, and out-of-bounds indices in set(), add(), etc. will be dropped.
  - "clip": clamp out-of-bounds indices into the valid range.
  - "drop": ignore out-of-bounds indices.
  - "fill": alias for "drop". For get(), the optional fill_value argument specifies the value that will be returned.
indices_are_sorted (bool) – If True, the implementation will assume that the indices passed to at[] are sorted in ascending order, which can lead to more efficient execution on some backends.
unique_indices (bool) – If True, the implementation will assume that the indices passed to at[] are unique, which can result in more efficient execution on some backends.
fill_value (Any) – Only applies to the get() method: the fill value to return for out-of-bounds slices when mode is 'fill'. Ignored otherwise. Defaults to NaN for inexact types, the largest negative value for signed types, the largest positive value for unsigned types, and True for booleans.
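The sketch below shows one way "clip" and "drop" could be emulated for a 1-D integer index on a NumPy mantissa, which is the case the text says is emulated on non-JAX backends. It is a sketch under stated assumptions, not saiunit's code. The helper `emulated_add` is hypothetical. It only handles non-negative indices and a scalar update value, and it treats "fill" like "drop" for updates, as the mode list above specifies.

```python
import numpy as np


def emulated_add(x, idx, y, mode="drop"):
    """Hypothetical helper: out-of-place x.at[idx].add(y) for non-negative 1-D integer indices."""
    idx = np.asarray(idx)
    out = np.array(x, copy=True)  # never modify the caller's array
    n = out.shape[0]
    if mode == "clip":
        # Clamp out-of-range indices into [0, n - 1].
        np.add.at(out, np.clip(idx, 0, n - 1), y)
    elif mode in ("drop", "fill"):
        # Skip out-of-range indices; y is a scalar here, so it needs no filtering.
        np.add.at(out, idx[idx < n], y)
    else:
        raise ValueError(f"unsupported mode: {mode!r}")
    return out


x = np.zeros(5)
clipped = emulated_add(x, [1, 7], 1.0, mode="clip")  # index 7 is clamped to 4
dropped = emulated_add(x, [1, 7], 1.0, mode="drop")  # index 7 is ignored
```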
Examples
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> x = jnp.arange(5.0) * u.mV
>>> x.at[2].add(10 * u.mV)
Quantity([ 0. 1. 12. 3. 4.], "mV")
>>> x.at[2].get()
Quantity(2., "mV")
- property backend: str
The backend of the underlying mantissa: one of 'numpy', 'jax', 'cupy', 'torch', 'dask', 'ndonnx'.
- Type:
str
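- clip(min=None, max=None)
Clip (limit) the values in the array to [min, max].
At least one of min or max must be given. Any bound that is given must have a unit compatible with the unit of self.
- Parameters:
min (Quantity or None) – Lower bound; values below it are replaced by min.
max (Quantity or None) – Upper bound; values above it are replaced by max.
- Returns:
The clipped quantity.
- Return type:
Quantity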
Examples
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0, 3.0]), unit=u.mV)
>>> q.clip(min=u.Quantity(1.5, unit=u.mV), max=u.Quantity(2.5, unit=u.mV))
Quantity([1.5 2. 2.5], "mV")
- clone()
Return a copy of this quantity (PyTorch-style alias for copy()).
- Returns:
An independent copy.
- Return type:
Quantity
Examples
>>> import saiunit as u
>>> q = u.Quantity(3.0, unit=u.mV)
>>> q.clone()
Quantity(3., "mV")
- conj()
Return the complex conjugate, element-wise, preserving units.
- Returns:
The conjugated quantity.
- Return type:
Quantity
Examples
>>> import saiunit as u
>>> q = u.Quantity(1.0 + 2.0j, unit=u.mV)
>>> q.conj()
Quantity((1-2j), "mV")
- conjugate()
Return the complex conjugate, element-wise.
Alias for conj().
- Returns:
The conjugated quantity.
- Return type:
Quantity
- copy()
Return a deep copy of this quantity.
- Returns:
An independent copy with the same mantissa and unit.
- Return type:
Quantity
Examples
>>> import saiunit as u
>>> q = u.Quantity(3.0, unit=u.mV)
>>> q2 = q.copy()
>>> q2
Quantity(3., "mV")
- cross(b, axisa=-1, axisb=-1, axisc=-1, axis=None)
Cross product of two arrays.
The resulting unit is self.unit * b.unit.
- Parameters:
b (Quantity) – Second operand.
axisa (int) – Axis of self that defines the vector(s) (default -1).
axisb (int) – Axis of b that defines the vector(s) (default -1).
axisc (int) – Axis of the result containing the cross product (default -1).
axis (int | None) – If given, overrides axisa, axisb, and axisc simultaneously.
- Returns:
The cross product.
- Return type:
Quantity
Examples
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> a = u.Quantity(jnp.array([1.0, 0.0, 0.0]), unit=u.mV)
>>> b = u.Quantity(jnp.array([0.0, 1.0, 0.0]), unit=u.second)
>>> a.cross(b)
Quantity([0. 0. 1.], "mV * s")
- cumprod(*args, **kwds)
Return the cumulative product of elements along a given axis.
Because each position in the result corresponds to a different number of multiplied elements, the unit exponent varies across the output. This is only representable when the quantity is dimensionless.
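The following plain-Python sketch works through why the exponent varies. It computes the running products of three mantissas, assumed to be in mV, alongside the unit exponent each running product would need. The k-th output multiplies k + 1 factors, so the outputs would carry mV, mV^2, and mV^3, and no single unit fits them all. The helper name `cumprod_with_exponents` is hypothetical and not part of saiunit.

```python
import operator
from itertools import accumulate


def cumprod_with_exponents(mantissas):
    """Return the running products and the unit exponent each one would carry."""
    products = list(accumulate(mantissas, operator.mul))
    # The k-th running product (0-based) multiplies k + 1 factors, so its unit is unit ** (k + 1).
    exponents = list(range(1, len(mantissas) + 1))
    return products, exponents


products, exponents = cumprod_with_exponents([2.0, 3.0, 4.0])  # mantissas assumed to be in mV
```

For a dimensionless quantity every power of the unit is still dimensionless, which is why that case remains representable.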
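- diagonal(offset=0, axis1=0, axis2=1)
Return specified diagonals, preserving units.
- Parameters:
offset (int) – Offset of the diagonal from the main diagonal; positive values select diagonals above it (default 0).
axis1 (int) – First axis of the 2-D sub-arrays from which the diagonals are taken (default 0).
axis2 (int) – Second axis of the 2-D sub-arrays (default 1).
- Returns:
The diagonal elements.
- Return type:
Quantity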
Examples
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([[1.0, 2.0], [3.0, 4.0]]), unit=u.mV)
>>> q.diagonal()
Quantity([1. 4.], "mV")
- property dim: saiunit.Dimension
The physical dimension of this quantity (e.g. length, mass, time).
The dimension is independent of scale (metres and kilometres both have the length dimension).
- Returns:
The physical dimension object.
- Return type:
saiunit.Dimension
Examples
>>> import saiunit as u
>>> q = u.Quantity(5.0, unit=u.metre)
>>> q.dim
m
See also
unit – The full unit (dimension + scale).
- dot(b)
Dot product of two arrays.
The resulting unit is self.unit * b.unit.
- Parameters:
b (Quantity or array_like) – Second operand.
- Returns:
The dot product.
- Return type:
Quantity
Examples
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> a = u.Quantity(jnp.array([1.0, 2.0, 3.0]), unit=u.mV)
>>> b = u.Quantity(jnp.array([1.0, 1.0, 1.0]), unit=u.mV)
>>> a.dot(b)
Quantity(6., "mV^2")
- property dtype
The data type of the mantissa.
- Returns:
The JAX/NumPy dtype of the underlying array.
- Return type:
dtype
Examples
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0]), unit=u.mV)
>>> q.dtype
dtype('float32')
- expand_dims(axis)
Insert new axes at the given positions.
- Parameters:
axis (int | Sequence[int]) – Position(s) where the new axis (axes) are placed.
- Returns:
The expanded quantity.
- Return type:
Quantity
Examples
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0]), unit=u.mV)
>>> q.expand_dims(0).shape
(1, 2)
- factorless()
Return an equivalent quantity whose unit has factor == 1.0.
If the unit already has no extra factor, the original object is returned unchanged.
- Returns:
A quantity with the factor folded into the mantissa.
- Return type:
Quantity
Examples
>>> import saiunit as u
>>> q = u.Quantity(3.0, unit=u.mV)
>>> q.factorless()
Quantity(3., "mV")
- property flat
1-D iterator over the mantissa elements, unit preserved.
- flatten()
Return a 1-D copy of this quantity.
- Returns:
Flattened quantity with the same unit.
- Return type:
Quantity
Examples
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([[1.0, 2.0], [3.0, 4.0]]), unit=u.mV)
>>> q.flatten()
Quantity([1. 2. 3. 4.], "mV")
- has_same_unit(other)
Check whether this quantity has the same physical dimension as other.
Two quantities that differ only in scale (e.g. mV vs V) are considered to have the same unit dimension.
- Parameters:
other (Quantity) – The quantity to compare against.
- Returns:
True if both have identical physical dimensions.
- Return type:
bool
Examples
>>> import saiunit as u
>>> a = u.Quantity(1.0, unit=u.mV)
>>> b = u.Quantity(2.0, unit=u.volt)
>>> a.has_same_unit(b)
True
>>> c = u.Quantity(1.0, unit=u.second)
>>> a.has_same_unit(c)
False
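- in_unit(unit, err_msg=None)
Convert this quantity to a compatible unit.
Behaves identically to to(); kept for API compatibility.
- Parameters:
unit (saiunit.Unit) – Target unit. Must have the same dimension as self.unit.
err_msg (str, optional) – Optional custom error message.
- Returns:
A new Quantity expressed in unit.
- Return type:
Quantity
- Raises: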
UnitMismatchError – If unit has a different dimension.
Examples
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0, 3.0]), unit=u.mV)
>>> q.in_unit(u.volt)
Quantity([0.001 0.002 0.003], "V")
- property is_unitless: bool
True if this quantity is dimensionless (has no physical unit).
- Returns:
Whether the quantity is unitless.
- Return type:
bool
Examples
>>> import saiunit as u
>>> u.Quantity(5.0).is_unitless
True
>>> u.Quantity(5.0, unit=u.mV).is_unitless
False
- item(*args)
Extract a single element as a scalar Quantity.
- Parameters:
*args (int) – Index into the flat array.
- Returns:
A 0-D Quantity containing the selected element.
- Return type:
Quantity
Examples
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([10.0, 20.0]), unit=u.mV)
>>> q.item(0)
Quantity(10., "mV")
- property mT: saiunit.Quantity
Matrix transpose of the last two dimensions, preserving units.
The array must be at least 2-D.
- Returns:
The matrix-transposed quantity.
- Return type:
Quantity
Examples
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([[1.0, 2.0], [3.0, 4.0]]), unit=u.mV)
>>> q.mT.shape
(2, 2)
- property magnitude: Array | ndarray | number | bool
Alias for mantissa.
- Returns:
The raw numerical data of this quantity.
- Return type:
array_like
Examples
>>> import saiunit as u
>>> q = u.Quantity(5.0, unit=u.metre)
>>> q.magnitude
5.0
See also
mantissa – Primary accessor for the numerical data.
- property mantissa: Array | ndarray | number | bool
The raw numerical data of this quantity (without the unit).
In scientific notation \(x = a \times 10^{b}\), the mantissa is the coefficient \(a\). For a Quantity, it is the underlying JAX/NumPy array (or Python scalar) that stores the numeric value.
- Returns:
The mantissa array or scalar.
- Return type:
array_like
Examples
>>> import saiunit as u
>>> q = u.Quantity(3.0, unit=u.mV)
>>> q.mantissa
3.0
- nancumprod(*args, **kwds)
Return the cumulative product of elements along a given axis, treating NaNs as ones.
Because each position in the result corresponds to a different number of multiplied elements, the unit exponent varies across the output. This is only representable when the quantity is dimensionless.
- nanprod(*args, **kwds)
Return the product of array elements over a given axis, treating NaNs (Not a Number) as ones.
When reducing along a specific axis, the number of non-NaN elements must be the same for every position in the result so that a single unit exponent can be assigned. If the non-NaN counts differ and the quantity is not dimensionless, a ValueError is raised.
- Returns:
The product (NaNs treated as ones).
- Return type:
Quantity
- Raises:
ValueError – If the non-NaN counts are not uniform along the reduction axis for a non-dimensionless quantity.
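The uniform-count rule can be expressed directly in NumPy. The sketch counts the non-NaN factors that feed each output position and accepts the reduction only when every count matches. The shared count is then the exponent n in unit ** n. The helper `nanprod_unit_exponent` is hypothetical and mirrors the rule stated above; it is not saiunit's implementation.

```python
import numpy as np


def nanprod_unit_exponent(a, axis=None):
    """Hypothetical helper: exponent n in unit ** n for a NaN-ignoring product."""
    # Number of non-NaN factors contributing to each output position.
    counts = np.asarray(np.sum(~np.isnan(a), axis=axis)).ravel()
    if not np.all(counts == counts[0]):
        raise ValueError(f"non-NaN counts differ along the reduction axis: {counts.tolist()}")
    return int(counts[0])


uniform = np.array([[2.0, np.nan], [np.nan, 3.0]])  # one non-NaN value per row
ragged = np.array([[2.0, 3.0], [4.0, np.nan]])      # rows keep 2 and 1 values

exponent = nanprod_unit_exponent(uniform, axis=1)
full_exponent = nanprod_unit_exponent(ragged)  # full reduction: one output, always uniform
try:
    nanprod_unit_exponent(ragged, axis=1)
except ValueError as err:
    ragged_error = str(err)
```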
- outer(b)
Outer product of two 1-D arrays.
The resulting unit is self.unit * b.unit.
Examples
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> a = u.Quantity(jnp.array([1.0, 2.0]), unit=u.mV)
>>> b = u.Quantity(jnp.array([3.0, 4.0]), unit=u.second)
>>> a.outer(b).shape
(2, 2)
- pow(oc)
Raise this quantity to the power oc.
The exponent must be dimensionless. The resulting unit is self.unit ** oc.
- Parameters:
oc – The exponent (dimensionless).
- Returns:
self ** oc.
- Return type:
Quantity
Examples
>>> import saiunit as u
>>> q = u.Quantity(2.0, unit=u.mV)
>>> q.pow(2)
Quantity(4., "mV^2")
- prod(*args, **kwds)
Return the product of array elements over the given axis.
The unit of the result is self.unit ** n, where n is the number of elements multiplied into each output entry: all elements for a full reduction, otherwise the product of the lengths of the reduced axes.
- Returns:
The product.
- Return type:
Quantity
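The exponent n depends only on the input shape and the reduced axes, so it can be computed without touching the data. The stdlib sketch below does that. The helper `prod_unit_exponent` is hypothetical. It reproduces the mV^2 of the 2-element example that follows and shows how reducing a (2, 3) array along different axes changes the exponent.

```python
import math


def prod_unit_exponent(shape, axis=None):
    """Hypothetical helper: exponent n such that prod(q, axis) carries q.unit ** n."""
    if axis is None:
        # A full reduction multiplies every element together.
        return math.prod(shape)
    axes = (axis,) if isinstance(axis, int) else tuple(axis)
    # Each output entry multiplies one element from every reduced axis.
    return math.prod(shape[a] for a in axes)
```

For example, prod_unit_exponent((2, 3), axis=1) gives 3, so each row product would carry unit ** 3.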
Examples
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([2.0, 3.0]), unit=u.mV)
>>> q.prod()
Quantity(6., "mV^2")
- put(indices, values)
Replace specified elements of an array with given values.
- Parameters:
indices (array_like) – Target indices, interpreted as integers.
values (array_like) – Values to place in the array at target indices.
- Return type:
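- repeat(repeats, axis=None)
Repeat elements of the array.
- Parameters:
repeats (int or array_like) – Number of repetitions for each element.
axis (int, optional) – Axis along which to repeat values. By default, the flattened array is used.
- Returns:
The repeated quantity.
- Return type:
Quantity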
Examples
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0]), unit=u.mV)
>>> q.repeat(2)
Quantity([1. 1. 2. 2.], "mV")
- repr_in_unit(precision=None)
Return a human-readable string of this quantity in its current unit.
The format is "<value> <unit>", e.g. "3. mV" or "[1. 2. 3.] mV".
- Parameters:
precision (int | None) – Number of digits of precision for floating-point output. When None, the value from numpy.get_printoptions is used.
- Returns:
The formatted string.
- Return type:
str
Examples
>>> import saiunit as u
>>> x = u.Quantity(25.0, unit=u.mV)
>>> x.repr_in_unit()
'25. mV'
>>> x.to(u.volt).repr_in_unit(3)
'0.025 V'
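- reshape(shape, order='C')
Return a quantity with the same data but a new shape.
- Parameters:
shape (int or tuple of ints) – The new shape; it must contain the same number of elements as the original.
order (str) – Index order used to read and place elements: 'C' (row-major, default) or 'F' (column-major).
- Returns:
Reshaped quantity.
- Return type:
Quantity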
Examples
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0, 3.0]), unit=u.mV)
>>> q.reshape((3, 1)).shape
(3, 1)
- round(decimals=0)
Evenly round the mantissa to the given number of decimals (ties are rounded to the nearest even value).
- Parameters:
decimals (int) – Number of decimal places (default 0). Negative values round to positions left of the decimal point.
- Returns:
A new quantity with the rounded mantissa.
- Return type:
Quantity
Examples
>>> import saiunit as u
>>> q = u.Quantity(1.567, unit=u.mV)
>>> q.round(1)
Quantity(1.6, "mV")
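"Evenly round" refers to NumPy's tie-breaking rule: values exactly halfway between two integers go to the even neighbour, so the function does not always round up. The plain NumPy sketch below shows the rule on a bare mantissa. That Quantity.round delegates to the backend's round function is an assumption, not something stated above.

```python
import numpy as np

halves = np.array([0.5, 1.5, 2.5, 3.5])
rounded = np.round(halves)  # ties go to the nearest even integer: 0, 2, 2, 4
```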
- scatter_add(index, value)
Return a copy with value added at index.
- Parameters:
index (Array | ndarray | number | bool) – Target index (indices).
value (Quantity | Array | ndarray | number | bool) – The value to add. Must have the same unit dimension.
- Returns:
A new quantity with the update applied.
- Return type:
Quantity
Examples
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0, 3.0]), unit=u.mV)
>>> q.scatter_add(0, u.Quantity(10.0, unit=u.mV))
Quantity([11. 2. 3.], "mV")
- scatter_div(index, value)
Return a copy with the element at index divided by value.
value must be dimensionless (a pure scale factor). All elements of a Quantity share one unit, so dividing a single element by a dimensioned value would leave that element in a different unit from the rest.
- Parameters:
index (Array | ndarray | number | bool) – Target index (indices).
value (Quantity | Array | ndarray | number | bool) – Dimensionless scale factor.
- Returns:
A new quantity with the update applied.
- Return type:
Quantity
- Raises:
TypeError – If value is not dimensionless.
Examples
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0, 3.0]), unit=u.mV)
>>> q.scatter_div(0, u.Quantity(2.0))
Quantity([0.5 2. 3. ], "mV")
- scatter_max(index, value)
Return a copy where the element at index is the maximum of the current value and value.
- Parameters:
index (Array | ndarray | number | bool) – Target index (indices).
value (Quantity | Array | ndarray | number | bool) – The comparison value. Must have the same unit dimension.
- Returns:
A new quantity with the update applied.
- Return type:
Quantity
Examples
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0, 3.0]), unit=u.mV)
>>> q.scatter_max(0, u.Quantity(10.0, unit=u.mV))
Quantity([10. 2. 3.], "mV")
- scatter_min(index, value)
Return a copy where the element at index is the minimum of the current value and value.
- Parameters:
index (Array | ndarray | number | bool) – Target index (indices).
value (Quantity | Array | ndarray | number | bool) – The comparison value. Must have the same unit dimension.
- Returns:
A new quantity with the update applied.
- Return type:
Quantity
Examples
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0, 3.0]), unit=u.mV)
>>> q.scatter_min(0, u.Quantity(0.5, unit=u.mV))
Quantity([0.5 2. 3. ], "mV")
- scatter_mul(index, value)
Return a copy with the element at index multiplied by value.
value must be dimensionless (a pure scale factor). All elements of a Quantity share one unit, so multiplying a single element by a dimensioned value would leave that element in a different unit from the rest.
- Parameters:
index (Array | ndarray | number | bool) – Target index (indices).
value (Quantity | Array | ndarray | number | bool) – Dimensionless scale factor.
- Returns:
A new quantity with the update applied.
- Return type:
Quantity
- Raises:
TypeError – If value is not dimensionless.
Examples
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0, 3.0]), unit=u.mV)
>>> q.scatter_mul(0, u.Quantity(10.0))
Quantity([10. 2. 3.], "mV")
- scatter_sub(index, value)
Return a copy with value subtracted at index.
- Parameters:
index (Array | ndarray | number | bool) – Target index (indices).
value (Quantity | Array | ndarray | number | bool) – The value to subtract. Must have the same unit dimension.
- Returns:
A new quantity with the update applied.
- Return type:
Quantity
Examples
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0, 3.0]), unit=u.mV)
>>> q.scatter_sub(0, u.Quantity(1.0, unit=u.mV))
Quantity([0. 2. 3.], "mV")
- searchsorted(v, side='left', sorter=None)
Find indices where elements should be inserted to maintain order.
- Return type:
Array
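The sketch below shows the search on a bare mantissa using NumPy's `searchsorted` semantics. It assumes two things that the entry above does not state: that a query v given in another unit is first re-expressed in the sorted array's unit, and that side follows the NumPy meaning. It also shows how side="left" and side="right" differ on an exact tie.

```python
import numpy as np

sorted_mantissa = np.array([1.0, 2.0, 3.0])  # mantissa of a sorted Quantity, in mV

# Assumption: a query expressed in volts is converted to mV before searching.
query_volt = 0.0025
query_mV = query_volt * 1e3  # about 2.5 mV
insert_at = np.searchsorted(sorted_mantissa, query_mV)

# On a tie, "left" inserts before the equal element and "right" after it.
left = np.searchsorted(sorted_mantissa, 2.0, side="left")
right = np.searchsorted(sorted_mantissa, 2.0, side="right")
```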
- property shape: tuple[int, ...]
The shape of the mantissa array.
Examples
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([[1.0, 2.0], [3.0, 4.0]]), unit=u.mV)
>>> q.shape
(2, 2)
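- sort(axis=-1, stable=True, order=None)
Sort the array in-place along the given axis.
- Parameters:
axis (int) – Axis along which to sort (default -1, the last axis).
stable (bool) – Whether to use a stable sort, which keeps equal elements in their original relative order (default True).
order – Field order for structured arrays, as in numpy.sort; usually left as None.
- Returns:
self, with the mantissa sorted in-place.
- Return type:
Quantity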
Examples
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([3.0, 1.0, 2.0]), unit=u.mV)
>>> q.sort()
Quantity([1. 2. 3.], "mV")
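- split(indices_or_sections, axis=0)
Split the array into multiple sub-arrays.
- Parameters:
indices_or_sections (int or sequence of ints) – If an integer N, the array is divided into N equal sections along axis; if a sequence, its entries give the split points.
axis (int) – Axis along which to split (default 0).
- Returns:
Sub-arrays, each carrying the same unit.
- Return type:
list of Quantity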
Examples
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0, 3.0]), unit=u.mV)
>>> parts = q.split(3)
>>> len(parts)
3
- squeeze(axis=None)
Remove length-one axes from the array.
- Parameters:
axis (int or tuple of ints, optional) – Axes to remove. If None, all length-one axes are removed.
- Returns:
The squeezed quantity.
- Return type:
Quantity
Examples
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([[[1.0]]]), unit=u.mV)
>>> q.squeeze().shape
()
- property strides
Tuple of byte-steps in each dimension (mirrors numpy.ndarray.strides).
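Because this property mirrors numpy.ndarray.strides, the NumPy sketch below shows what the byte-steps mean for a C-ordered float32 array and its transpose. It covers NumPy semantics on a plain array only. What the property returns for the other backends is not covered here.

```python
import numpy as np

a = np.zeros((4, 4), dtype=np.float32)  # 4-byte elements, C (row-major) order
row_major = a.strides       # next row is 16 bytes away, next column 4 bytes
transposed = a.T.strides    # transposing swaps the steps without copying data
```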
- swapaxes(axis1, axis2)
Interchange two axes of the array.
- Parameters:
axis1 (int) – First axis.
axis2 (int) – Second axis.
- Returns:
The quantity with axes swapped.
- Return type:
Quantity
Examples
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([[1.0, 2.0], [3.0, 4.0]]), unit=u.mV)
>>> q.swapaxes(0, 1).shape
(2, 2)
- take(indices, axis=None, mode=None, unique_indices=False, indices_are_sorted=False, fill_value=None)
Select elements from the array at the given indices.
- Parameters:
indices (array_like) – Indices of the values to extract.
axis (int, optional) – Axis along which to take (default flattened).
mode (str, optional) – Out-of-bounds index handling.
unique_indices (bool, optional) – Hint that indices are unique.
indices_are_sorted (bool, optional) – Hint that indices are sorted.
fill_value (Quantity or scalar, optional) – Value for out-of-bounds positions when mode is 'fill'.
- Returns:
The selected elements.
- Return type:
Quantity
Examples
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([10.0, 20.0, 30.0]), unit=u.mV)
>>> q.take(jnp.array([0, 2]))
Quantity([10. 30.], "mV")
- tile(reps)
Construct an array by repeating this quantity.
- Parameters:
reps (int or array_like) – Number of repetitions along each axis.
- Returns:
The tiled quantity.
- Return type:
Quantity
Examples
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0]), unit=u.mV)
>>> q.tile(2)
Quantity([1. 2. 1. 2.], "mV")
- to(new_unit)
Convert this quantity to a different (compatible) unit.
The mantissa is rescaled so that the physical value stays the same, and the returned Quantity carries new_unit.
- Parameters:
new_unit (saiunit.Unit) – Target unit. Must have the same dimension as self.unit.
- Returns:
A new Quantity expressed in new_unit.
- Return type:
Quantity
- Raises:
UnitMismatchError – If new_unit has a different dimension.
Examples
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0, 3.0]), unit=u.mV)
>>> q.to(u.volt)
Quantity([0.001 0.002 0.003], "V")
See also
in_unit – Identical behaviour (to delegates to in_unit).
to_decimal – Convert to a plain number in the target unit.
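The rescaling rule "the physical value stays the same" reduces to mantissa_old × scale_old = mantissa_new × scale_new. The stdlib sketch below implements that rule. It assumes units are modelled as (scale, dimension-exponent) pairs, and the helper `convert` is hypothetical. A plain ValueError stands in for UnitMismatchError when the dimensions differ.

```python
def convert(mantissa, from_unit, to_unit):
    """Hypothetical helper: rescale a mantissa between units given as (scale, dims) pairs."""
    from_scale, from_dims = from_unit
    to_scale, to_dims = to_unit
    if from_dims != to_dims:
        # Stand-in for saiunit.UnitMismatchError.
        raise ValueError(f"{from_dims} is not convertible to {to_dims}")
    # Keep the physical value fixed: mantissa * from_scale == result * to_scale.
    return mantissa * from_scale / to_scale


VOLTAGE = (2, 1, -3, -1)       # exponents of (m, kg, s, A) for the volt
mV = (1e-3, VOLTAGE)
V = (1.0, VOLTAGE)
SECOND = (1.0, (0, 0, 1, 0))

volts = [convert(m, mV, V) for m in (1.0, 2.0, 3.0)]
try:
    convert(1.0, mV, SECOND)
except ValueError as err:
    mismatch = str(err)
```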
- to_cupy(*, device=None)
Return a new Quantity with the mantissa converted to a cupy.ndarray.
No-op (returns self) if the mantissa is already a CuPy array and no device was specified.
- Return type:
Quantity
- to_dask(*, chunks='auto')
Return a new Quantity with the mantissa converted to a dask.array.Array.
No-op (returns self) if the mantissa is already a dask array and no chunks argument was specified.
- Return type:
Quantity
- to_decimal(unit=Unit('1'))
Return the numerical value expressed in the given unit, without wrapping the result in a Quantity.
This is useful when you need a plain JAX array for downstream computation that does not support units.
- Parameters:
unit (saiunit.Unit) – The reference unit. Defaults to UNITLESS.
- Returns:
A plain number or JAX array representing the quantity in unit.
- Return type:
Array | ndarray | number | bool
- Raises:
UnitMismatchError – If unit has a different dimension than self.unit.
Examples
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0, 3.0]), unit=u.mV)
>>> q.to_decimal(u.volt)
Array([0.001, 0.002, 0.003], dtype=float32)
See also
to – Convert while keeping the Quantity wrapper.
- to_jax()
Return a new Quantity with the mantissa converted to a jax.Array.
No-op (returns self) if the mantissa is already a JAX array.
- Return type:
Quantity
- to_ndonnx()
Return a new Quantity with the mantissa converted to an ndonnx.Array.
No-op (returns self) if the mantissa is already an ndonnx array. ndonnx arrays are symbolic: operations build an ONNX graph rather than computing eagerly.
- Return type:
Quantity
- to_numpy()
Return a new Quantity with the mantissa converted to a numpy.ndarray.
No-op (returns self) if the mantissa is already a NumPy array.
- Return type:
Quantity
- to_torch(*, device=None, dtype=None)
Return a new Quantity with the mantissa converted to a torch.Tensor.
No-op (returns self) if the mantissa is already a torch tensor and neither device nor dtype was specified. dtype accepts either a torch dtype (e.g. torch.float32) or a NumPy dtype (e.g. np.float32).
- Return type:
Quantity
- tolist()
Convert the array to a (nested) Python list of Quantity scalars.
Each leaf element is a 0-D Quantity with the same unit.
- Returns:
A nested list of scalar Quantity objects, or a single Quantity for 0-D arrays.
- Return type:
list | Quantity
Examples
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0]), unit=u.mV)
>>> q.tolist()
[Quantity(1., "mV"), Quantity(2., "mV")]
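- trace(offset=0, axis1=0, axis2=1)
Sum along diagonals of the array, preserving units.
- Parameters:
offset (int) – Offset of the diagonal from the main diagonal; positive values select diagonals above it (default 0).
axis1 (int) – First axis of the 2-D sub-arrays from which the diagonals are taken (default 0).
axis2 (int) – Second axis of the 2-D sub-arrays (default 1).
- Returns:
The trace value(s).
- Return type:
Quantity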
Examples
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.eye(3), unit=u.mV)
>>> q.trace()
Quantity(3., "mV")
- transpose(*axes)
Return the array with axes transposed.
For a 2-D array this is the standard matrix transpose.
- Parameters:
*axes (None, tuple of ints, or n ints) – If omitted, axes are reversed. Otherwise specifies the permutation.
- Returns:
Transposed quantity.
- Return type:
Quantity
Examples
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([[1.0, 2.0], [3.0, 4.0]]), unit=u.mV)
>>> q.transpose().shape
(2, 2)
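- classmethod tree_unflatten(unit, values)
Reconstruct a Quantity from its flattened pytree form (part of the JAX pytree protocol).
- Parameters:
unit – The unit (the auxiliary data recorded when the quantity was flattened).
values – The data (the flattened leaves).
- Returns:
The Quantity object.
- Return type:
Quantity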
- property unit: saiunit.Unit
The Unit attached to this quantity.
The unit carries both the physical dimension and the scale factor (e.g. mV has dimension voltage with scale 1e-3).
- Returns:
The unit of this quantity.
- Return type:
saiunit.Unit
Examples
>>> import saiunit as u
>>> q = u.Quantity(5.0, unit=u.mV)
>>> q.unit
mV
- unsqueeze(axis)
Insert a length-one axis (PyTorch-style alias for expand_dims()).
- Parameters:
axis (int) – Position where the new axis is inserted.
- Returns:
The quantity with an extra dimension.
- Return type:
Quantity
Examples
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0]), unit=u.mV)
>>> q.unsqueeze(0).shape
(1, 2)
- update_mantissa(mantissa)
Replace the mantissa in-place, keeping the same unit.
The new mantissa must have the same shape and dtype as the current one.
- Parameters:
mantissa (Any) – The new numerical data. Must not be a Quantity.
- Raises:
ValueError – If mantissa is a Quantity, or if its shape or dtype does not match.
- Return type:
None
Examples
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0, 3.0]), unit=u.mV)
>>> q.update_mantissa(jnp.array([4.0, 5.0, 6.0]))
>>> q
Quantity([4. 5. 6.], "mV")
- view(*args, dtype=None)
Return a new view of the array with the same data.
This method follows PyTorch syntax.
Returns a new quantity with the same data as self but with a different shape.
The returned quantity shares the same data and must have the same number of elements, but may have a different shape. For a view to be possible, the new shape must be compatible with the original size and stride: each new view dimension must either be a subspace of an original dimension, or only span across original dimensions \(d, d+1, \dots, d+k\) that satisfy the following contiguity-like condition for all \(i = d, \dots, d+k-1\):
\[\text{stride}[i] = \text{stride}[i+1] \times \text{size}[i+1]\]
Otherwise, self cannot be viewed with the requested shape without copying it (e.g., via contiguous()). When it is unclear whether view() can be performed, it is advisable to use reshape(), which returns a view if the shapes are compatible and copies (equivalent to calling contiguous()) otherwise.
Example:
>>> import jax.numpy as jnp, saiunit
>>> x = saiunit.Quantity(jnp.ones((4, 4)))
>>> x.shape
(4, 4)
>>> y = x.view(16)
>>> y.shape
(16,)
>>> z = x.view(2, 8)
>>> z.shape
(2, 8)
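The contiguity condition can be checked mechanically from a shape and its strides. The NumPy sketch below uses a hypothetical helper, `can_merge`, that tests whether dimensions d through d + k satisfy stride[i] = stride[i+1] × size[i+1]. It then confirms against NumPy's own reshape behaviour that a contiguous array flattens without copying while its transpose does not. The strides here are in bytes, which leaves the condition unchanged because both sides use the same unit. This illustrates NumPy behaviour and is not saiunit's dispatch logic.

```python
import numpy as np


def can_merge(shape, strides, d, k):
    """Hypothetical helper: can dimensions d..d+k be merged into one view dimension?"""
    return all(strides[i] == strides[i + 1] * shape[i + 1] for i in range(d, d + k))


a = np.ones((4, 4), dtype=np.float32)
merge_ok = can_merge(a.shape, a.strides, 0, 1)            # 16 == 4 * 4
merge_transposed = can_merge(a.T.shape, a.T.strides, 0, 1)  # 4 != 16 * 4

flat_view = a.reshape(16)    # contiguous: NumPy returns a view
flat_copy = a.T.reshape(16)  # non-contiguous: NumPy has to copy
```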
- view(dtype) → Quantity
Returns a new tensor with the same data as the
selftensor but of a differentdtype.If the element size of
dtypeis different than that ofself.dtype, then the size of the last dimension of the output will be scaled proportionally. For instance, ifdtypeelement size is twice that ofself.dtype, then each pair of elements in the last dimension ofselfwill be combined, and the size of the last dimension of the output will be half that ofself. Ifdtypeelement size is half that ofself.dtype, then each element in the last dimension ofselfwill be split in two, and the size of the last dimension of the output will be double that ofself. For this to be possible, the following conditions must be true:self.dim()must be greater than 0.self.stride(-1)must be 1.
Additionally, if the element size of
dtypeis greater than that ofself.dtype, the following conditions must be true as well:self.size(-1)must be divisible by the ratio between the element sizes of the dtypes.self.storage_offset()must be divisible by the ratio between the element sizes of the dtypes.The strides of all dimensions, except the last dimension, must be divisible by the ratio between the element sizes of the dtypes.
If any of the above conditions are not met, an error is thrown.
- Parameters:
dtype (
dtype) – the desired dtype
Example:
>>> import jax.numpy as jnp, saiunit
>>> x = saiunit.Quantity(jnp.ones((4, 4), dtype=jnp.float32))
>>> x.view(jnp.int32).shape
(4, 4)
>>> x.view(jnp.uint8).shape
(4, 16)
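The element-size rule above matches NumPy's ndarray.view(dtype). The NumPy sketch below applies it to a plain float32 array: same-size reinterpretation keeps the shape, smaller items multiply the last dimension, and larger items divide it. This illustrates NumPy semantics on a bare array only and does not show how saiunit dispatches view for each backend.

```python
import numpy as np

x = np.ones((4, 4), dtype=np.float32)

as_int = x.view(np.int32)          # same 4-byte items reinterpreted: shape stays (4, 4)
as_bytes = x.view(np.uint8)        # each 4-byte float split into 4 bytes: (4, 16)
as_complex = x.view(np.complex64)  # each pair of float32 values combined: (4, 2)

bits_of_one = int(as_int[0, 0])    # IEEE-754 bit pattern of 1.0f, i.e. 0x3F800000
```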
- static with_unit(mantissa, unit)
Create a Quantity from a raw value and a unit.
This is a convenience factory that reads more naturally in some contexts than the standard constructor.
- Parameters:
mantissa (Any) – The numerical value(s).
unit (saiunit.Unit) – The physical unit.
- Returns:
A new Quantity with the given mantissa and unit.
- Return type:
Quantity
Examples
>>> import saiunit as u
>>> u.Quantity.with_unit(2.0, unit=u.metre)
Quantity(2., "m")