sort_key_val

class saiunit.lax.sort_key_val(keys, values, dimension=-1, is_stable=True, **kwargs)

Sort keys along the given dimension and apply the same permutation to values, so that each value stays paired with its key.
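The operation can be pictured as an argsort of the keys followed by a gather on both arrays. The sketch below reproduces that behavior in plain NumPy on unitless arrays. It is an illustration under that assumption, not saiunit's implementation, and the helper name sort_key_val_np is hypothetical.

```python
import numpy as np


def sort_key_val_np(keys, values, dimension=-1, is_stable=True):
    """Hypothetical NumPy sketch: sort keys along `dimension` and
    reorder values with the same permutation."""
    keys = np.asarray(keys)
    values = np.asarray(values)
    if keys.shape != values.shape:
        raise ValueError("keys and values must have the same shape")
    # "stable" keeps equal keys in their original order; "quicksort" makes no such promise.
    kind = "stable" if is_stable else "quicksort"
    order = np.argsort(keys, axis=dimension, kind=kind)
    # Apply the identical permutation to both arrays.
    sorted_keys = np.take_along_axis(keys, order, axis=dimension)
    sorted_values = np.take_along_axis(values, order, axis=dimension)
    return sorted_keys, sorted_values


k, v = sort_key_val_np([3.0, 1.0, 2.0], [30, 10, 20])
print(k, v)  # [1. 2. 3.] [10 20 30]
```

Because the permutation is computed once from the keys and reused for the values, the pairing between each key and its value is preserved by construction.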

Parameters:
  • keys (Union[saiunit.Quantity, Array, ndarray, number, bool]) – The array of keys to sort.

  • values (Union[saiunit.Quantity, Array, ndarray, number, bool]) – The array of values to permute according to the sorted order of keys. Must have the same shape as keys.

  • dimension (int) – The dimension along which to sort. Default is -1.

  • is_stable (bool) – Whether to use a stable sort, which keeps elements with equal keys in their original relative order. Default is True.
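The effect of dimension and is_stable can be illustrated on a 2-D array. This NumPy sketch assumes that saiunit follows the same tie and axis semantics as a NumPy argsort with kind="stable"; it shows the expected behavior, not saiunit's code.

```python
import numpy as np

# Row 0 contains a tie: the key 1.0 appears at positions 0 and 2.
keys = np.array([[1.0, 0.0, 1.0],
                 [3.0, -1.0, 1.0]])
values = np.array([[10, 20, 30],
                   [40, 50, 60]])

# dimension=-1: each row is sorted independently. kind="stable" mirrors
# is_stable=True, so the tied keys keep their original order (10 before 30).
order_rows = np.argsort(keys, axis=-1, kind="stable")
by_row = np.take_along_axis(values, order_rows, axis=-1)
print(by_row)
# [[20 10 30]
#  [50 60 40]]

# dimension=0: each column is sorted independently, across rows.
order_cols = np.argsort(keys, axis=0, kind="stable")
by_col = np.take_along_axis(values, order_cols, axis=0)
print(by_col)
# [[10 50 30]
#  [40 20 60]]
```

In the last column both keys equal 1.0, so with a stable sort the values stay in their original order (30 above 60).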

Return type:

tuple[Union[saiunit.Quantity, Array], Union[saiunit.Quantity, Array]]

Returns:

  • sorted_keys (Array or Quantity) – The sorted keys. Preserves the unit of keys.

  • sorted_values (Array or Quantity) – The values permuted to match the sorted order of keys. Preserves the unit of values.
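Unit preservation follows from the fact that a permutation only reorders elements. The toy model below makes this explicit: it sorts on the numeric parts and reattaches each input's unit unchanged. ToyQuantity and toy_sort_key_val are hypothetical stand-ins, not saiunit's Quantity class or its internals.

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class ToyQuantity:
    """Hypothetical stand-in for a quantity: a numeric array plus a unit label."""
    mantissa: np.ndarray
    unit: str


def toy_sort_key_val(keys, values, dimension=-1):
    """Hypothetical sketch: sort on mantissas, then reattach each input's unit."""
    k_mant = keys.mantissa if isinstance(keys, ToyQuantity) else np.asarray(keys)
    v_mant = values.mantissa if isinstance(values, ToyQuantity) else np.asarray(values)
    # All keys share one unit, so comparing mantissas gives the correct order.
    order = np.argsort(k_mant, axis=dimension, kind="stable")
    k_sorted = np.take_along_axis(k_mant, order, axis=dimension)
    v_sorted = np.take_along_axis(v_mant, order, axis=dimension)
    # Reordering does not change magnitudes' units, so each output keeps its input's unit.
    if isinstance(keys, ToyQuantity):
        k_sorted = ToyQuantity(k_sorted, keys.unit)
    if isinstance(values, ToyQuantity):
        v_sorted = ToyQuantity(v_sorted, values.unit)
    return k_sorted, v_sorted


keys = ToyQuantity(np.array([3.0, 1.0, 2.0]), "m")
vals = ToyQuantity(np.array([0.3, 0.1, 0.2]), "s")
sk, sv = toy_sort_key_val(keys, vals)
print(sk.unit, sk.mantissa)  # m [1. 2. 3.]
print(sv.unit, sv.mantissa)  # s [0.1 0.2 0.3]
```

Unitless values pass through as plain arrays, matching the Array-or-Quantity return types listed above.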

Examples

>>> import saiunit as u
>>> import saiunit.lax as sulax
>>> import jax.numpy as jnp
>>> keys = jnp.array([3.0, 1.0, 2.0]) * u.meter
>>> vals = jnp.array([30, 10, 20])
>>> sorted_keys, sorted_vals = sulax.sort_key_val(keys, vals)
>>> sorted_keys.mantissa
Array([1., 2., 3.], dtype=float32)
>>> sorted_vals
Array([10, 20, 30], dtype=int32)