batch_matmul
- class saiunit.lax.batch_matmul(x, y, precision=None, **kwargs)
Batch matrix multiplication.
- Parameters:
  x (Union[saiunit.Quantity, Array, ndarray, number, bool]) – Left input array of shape [..., m, k].
  y (Union[saiunit.Quantity, Array, ndarray, number, bool]) – Right input array of shape [..., k, n].
  precision (Union[None, str, Precision, tuple[str, str], tuple[Precision, Precision], DotAlgorithm, DotAlgorithmPreset]) – Numerical precision of the computation.
- Returns:
  result – The batch matrix product, of shape [..., m, n]. The resulting unit is unit(x) * unit(y).
- Return type:
  Union[saiunit.Quantity, Array, ndarray, number, bool]
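The shape contract above ([..., m, k] times [..., k, n] gives [..., m, n]) can be illustrated with a plain NumPy reference sketch. It is not saiunit's implementation: it ignores units and `precision`, and the helper name `batch_matmul_reference` is hypothetical. It also assumes both operands share identical leading batch dimensions, which is a strictness choice of this sketch rather than a documented saiunit guarantee.

```python
import numpy as np


def batch_matmul_reference(x, y):
    """Hypothetical reference: multiply the trailing (m, k) and (k, n)
    matrices of every batch element. Units and precision are ignored."""
    x = np.asarray(x)
    y = np.asarray(y)
    if x.ndim < 2 or y.ndim < 2:
        raise ValueError("inputs must have at least 2 dimensions")
    # Assumption of this sketch: batch dimensions must match exactly.
    if x.shape[:-2] != y.shape[:-2]:
        raise ValueError(f"batch dimensions differ: {x.shape[:-2]} vs {y.shape[:-2]}")
    if x.shape[-1] != y.shape[-2]:
        raise ValueError(f"contraction sizes differ: {x.shape[-1]} vs {y.shape[-2]}")
    # '...mk,...kn->...mn' contracts over k separately for each batch index.
    return np.einsum("...mk,...kn->...mn", x, y)


x = np.arange(2 * 3 * 4, dtype=float).reshape(2, 3, 4)
y = np.arange(2 * 4 * 5, dtype=float).reshape(2, 4, 5)
out = batch_matmul_reference(x, y)
print(out.shape)  # (2, 3, 5)
```

Each slice out[b] equals the ordinary matrix product x[b] @ y[b], so the batch axes are carried through untouched while k is summed away.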
Examples
>>> import saiunit as u
>>> import saiunit.lax as sulax
>>> import jax.numpy as jnp
>>> x = jnp.ones((2, 3, 4)) * u.meter
>>> y = jnp.ones((2, 4, 5)) * u.second
>>> result = sulax.batch_matmul(x, y)
>>> result.mantissa.shape
(2, 3, 5)
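Why the result in the example carries the unit meter * second: each output entry is a sum over k of products x[..., i, l] * y[..., l, j]. Every product has unit unit(x) * unit(y), and adding terms of one unit keeps that unit, so the units multiply exactly once. The NumPy sketch below tracks units as a hypothetical `{base_unit: exponent}` dictionary; the helpers `unit_mul` and `batch_matmul_with_units` are illustrative only and are not part of saiunit.

```python
from collections import Counter

import numpy as np


def unit_mul(u1, u2):
    """Multiply two toy units stored as {base_unit: exponent} dicts (hypothetical)."""
    out = Counter(u1)
    for base, exp in u2.items():
        out[base] += exp  # exponents add under multiplication
    # Drop bases whose exponents cancelled, e.g. m * m**-1.
    return {b: e for b, e in out.items() if e != 0}


def batch_matmul_with_units(x_mantissa, x_unit, y_mantissa, y_unit):
    """Hypothetical unit-aware batch matmul: contract mantissas, multiply units once."""
    value = np.matmul(x_mantissa, y_mantissa)
    return value, unit_mul(x_unit, y_unit)


# Mirror the documented example: ones of shape (2, 3, 4) in meters
# times ones of shape (2, 4, 5) in seconds.
value, unit = batch_matmul_with_units(
    np.ones((2, 3, 4)), {"m": 1}, np.ones((2, 4, 5)), {"s": 1}
)
print(value.shape)     # (2, 3, 5)
print(value[0, 0, 0])  # 4.0, the sum of k = 4 products 1 * 1
print(unit)            # {'m': 1, 's': 1}
```

With all-ones inputs every mantissa entry equals the contraction size k = 4, which is a quick sanity check on the value of result.mantissa in the example.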