multi_dot
- class saiunit.math.multi_dot(arrays, *, precision=None, **kwargs)
Efficiently compute matrix products between a sequence of arrays.
JAX internally uses the opt_einsum library to compute the most efficient operation order. The resulting unit is the product of the units of all input arrays.
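To see why the operation order matters, the following sketch counts scalar multiplications for the two ways of parenthesizing a three-matrix product. The shapes and the `matmul_cost` helper are hypothetical, chosen only for illustration. This is plain arithmetic and does not show how opt_einsum itself searches for an order.

```python
def matmul_cost(m, k, n):
    """Scalar multiplications needed for an (m, k) @ (k, n) product."""
    return m * k * n


# Hypothetical shapes: A is (10, 1000), B is (1000, 5), C is (5, 500).
# Option 1: (A @ B) @ C
left_first = matmul_cost(10, 1000, 5) + matmul_cost(10, 5, 500)
# Option 2: A @ (B @ C)
right_first = matmul_cost(1000, 5, 500) + matmul_cost(10, 1000, 500)

print(left_first, right_first)
```

Both orders give the same mathematical result, but with these shapes the first needs 75,000 multiplications and the second 7,500,000.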
- Parameters:
arrays (Sequence[Union[Array, ndarray, number, bool, saiunit.Quantity]]) – Sequence of arrays or quantities. All must be two-dimensional, except the first and last, which may be one-dimensional.
precision (Union[None, str, Precision, tuple[str, str], tuple[Precision, Precision], DotAlgorithm, DotAlgorithmPreset]) – Either None (default), or a Precision enum value.
- Returns:
output – An array representing the equivalent of reduce(jnp.matmul, arrays), evaluated in the optimal order. The resulting unit is the product of all input units.
- Return type:
Union[Array, saiunit.Quantity]
Examples
>>> import saiunit as u
>>> import jax
>>> k1, k2 = jax.random.split(jax.random.key(0))
>>> a = jax.random.normal(k1, shape=(3, 4)) * u.meter
>>> b = jax.random.normal(k2, shape=(4, 2)) * u.second
>>> u.math.multi_dot([a, b])  # unit is meter * second
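
The equivalence with reduce(jnp.matmul, arrays) and the rule that the last operand may be one-dimensional can be checked with a short sketch. It uses plain JAX arrays and `jnp.linalg.multi_dot` rather than saiunit quantities, so it makes no assumptions about unit handling; the shapes and variable names are illustrative only.

```python
from functools import reduce

import jax
import jax.numpy as jnp

# Three operands: two matrices and a one-dimensional last operand.
k1, k2, k3 = jax.random.split(jax.random.key(0), 3)
a = jax.random.normal(k1, shape=(3, 4))
b = jax.random.normal(k2, shape=(4, 2))
v = jax.random.normal(k3, shape=(2,))

# multi_dot chooses the evaluation order itself.
chained = jnp.linalg.multi_dot([a, b, v])
# Reference: strict left-to-right chaining of matmul.
reference = reduce(jnp.matmul, [a, b, v])

print(chained.shape)
print(jnp.allclose(chained, reference, rtol=1e-4, atol=1e-5))
```

Because the last operand is one-dimensional, the result here is a vector of shape (3,) rather than a matrix.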