gather
- class saiunit.math.gather(input, dim, index, **kwargs)
Gather values along an axis specified by dim, according to index.
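JAX implementation of torch.gather.
- Parameters:
input – Array or Quantity to gather values from.
dim – Axis along which to gather.
index – Integer array of indices selecting the elements to gather along dim.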
- Returns:
out – Array with the gathered elements. Quantity if input is a Quantity.
- Return type:
Array, Quantity
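For a 2-D input, torch.gather, which this function mirrors, selects elements by the rule below. The output has the same shape as index. Higher-rank inputs follow the same pattern, with index substituted into position dim.

$$
\mathrm{out}[i][j] =
\begin{cases}
\mathrm{input}\big[\mathrm{index}[i][j]\big][j], & \mathrm{dim} = 0,\\
\mathrm{input}[i]\big[\mathrm{index}[i][j]\big], & \mathrm{dim} = 1.
\end{cases}
$$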
Examples
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> a = jnp.array([[1, 2], [3, 4]]) * u.mV
>>> index = jnp.array([[0, 0], [1, 0]])
>>> u.math.gather(a, 1, index)
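The example above does not show its result. The sketch below illustrates the indexing rule using plain jax.numpy on the same values, without units. It uses `jnp.take_along_axis` and an explicit loop version for cross-checking. The helper names `gather_reference` and `gather_loops_2d` are hypothetical, and the sketch does not describe how saiunit.math.gather is actually implemented. It assumes index matches input's shape in every axis except dim. torch.gather also accepts a smaller index in those axes, but take_along_axis would try to broadcast it instead.

```python
import jax.numpy as jnp


def gather_reference(x, dim, index):
    """Hypothetical plain-JAX sketch of torch.gather semantics (not saiunit's code).

    Assumes `index` has the same shape as `x` in every axis except `dim`.
    """
    return jnp.take_along_axis(x, index, axis=dim)


def gather_loops_2d(x, dim, index):
    """Hypothetical explicit-loop version of the same rule, 2-D inputs only."""
    rows, cols = index.shape
    out = [[0] * cols for _ in range(rows)]
    for i in range(rows):
        for j in range(cols):
            k = int(index[i, j])
            # dim=0: index replaces the row; dim=1: index replaces the column.
            out[i][j] = int(x[k, j]) if dim == 0 else int(x[i, k])
    return jnp.array(out)


a = jnp.array([[1, 2], [3, 4]])
index = jnp.array([[0, 0], [1, 0]])

out = gather_reference(a, 1, index)   # [[1, 1], [4, 3]]
out0 = gather_reference(a, 0, index)  # [[1, 2], [3, 2]]

# The vectorized and loop versions agree for both axes.
for d in (0, 1):
    assert (gather_reference(a, d, index) == gather_loops_2d(a, d, index)).all()
```

The saiunit call above should select the same positions, giving values [[1, 1], [4, 3]]. According to the Returns entry, the result is a Quantity because a is a Quantity in mV.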