diag_indices_from
- class saiunit.math.diag_indices_from(arr, **kwargs)
Return indices for accessing the main diagonal of a given array.
If arr is a saiunit.Quantity, its unit is stripped before the indices are computed, so the returned indices are plain (unitless) integer arrays.
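Because the unit is stripped, only the shape of arr determines the result. As a minimal sketch of the idea, assuming the function follows the same rules as NumPy's `numpy.diag_indices_from` once units are removed, the hypothetical helper `diag_indices_sketch` below (not part of saiunit) checks the shape and returns one copy of `[0, 1, ..., n-1]` per dimension:

```python
import numpy as np


def diag_indices_sketch(arr):
    """Hypothetical helper, not a saiunit API: mimic diag_indices_from.

    Assumes any unit has already been stripped, so only the shape matters.
    """
    shape = np.shape(arr)
    # The main diagonal is only defined here for arrays of at least 2-D ...
    if len(shape) < 2:
        raise ValueError("input array must be at least 2-d")
    # ... whose dimensions all have the same length.
    if any(n != shape[0] for n in shape):
        raise ValueError("all dimensions of input must be of equal length")
    n = shape[0]
    # One copy of [0, 1, ..., n-1] per dimension: the k-th diagonal element
    # sits at position (k, k, ..., k).
    return (np.arange(n),) * len(shape)


a = np.arange(9).reshape(3, 3)
idx = diag_indices_sketch(a)
print(a[idx])  # [0 4 8]
```

The same rule extends to higher dimensions: for a 2×2×2 array the sketch returns three index arrays, selecting `a[0, 0, 0]` and `a[1, 1, 1]`.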
- Parameters:
arr (Union[Array, ndarray, number, bool, saiunit.Quantity]) – Input array. Must be at least 2-D, and all of its dimensions must have the same length.
- Returns:
indices – A tuple of integer index arrays, one per dimension of arr, such that arr[indices] selects the main diagonal.
- Return type:
tuple[Array, ...]
Examples
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> arr = jnp.array([[1, 2], [3, 4]])
>>> u.math.diag_indices_from(arr)
(Array([0, 1], dtype=int32), Array([0, 1], dtype=int32))
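The returned tuple can be used directly to read or update the diagonal. The sketch below calls `jnp.diag_indices_from` on a plain JAX array, assuming (as the output of the example above suggests) that it yields the same indices as `u.math.diag_indices_from`. Since JAX arrays are immutable, writing goes through the functional `.at[...].set(...)` update rather than in-place assignment. Indexing a `saiunit.Quantity` with the same tuple is a plausible extension but is not shown here.

```python
import jax.numpy as jnp

arr = jnp.array([[1.0, 2.0], [3.0, 4.0]])
idx = jnp.diag_indices_from(arr)

# Reading: the tuple of index arrays selects arr[0, 0] and arr[1, 1].
diag = arr[idx]
print(diag)  # [1. 4.]

# Writing: .at[...].set(...) returns a new array with the diagonal replaced;
# the original `arr` is left unchanged.
zeroed = arr.at[idx].set(0.0)
print(zeroed)
```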