einreduce


class saiunit.math.einreduce(x, pattern, reduction, **axes_lengths)

Combine reordering and reduction using reader-friendly notation.

einreduce combines axis reordering and reduction in a single call, using a readable pattern notation similar to that of einops.

Parameters:
  • x (Union[Array, ndarray, number, bool, saiunit.Quantity, Sequence[Array | ndarray | number | bool], Sequence[saiunit.Quantity]]) – Input tensor(s). A list of tensors of the same type and shape is also accepted.

  • pattern (str) – Reduction pattern in 'input -> output' form. Axes that appear on the left but not on the right are reduced.

  • reduction (Union[str, Callable[[Array | ndarray | number | bool, Tuple[int, ...]], Array | ndarray | number | bool]]) – Reduction operation to apply. A callable with signature f(tensor, reduced_axes) -> tensor may also be provided.

  • **axes_lengths (int) – Sizes of axes named in the pattern, passed as keyword arguments, for dimensions that need an explicit length.
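To make the pattern and reduction parameters concrete, here is a minimal sketch in plain NumPy of what 'b c h -> b c' is expected to compute, assuming einops-style semantics. It does not call einreduce itself; the helper max_reduce is a hypothetical callable written only to illustrate the documented f(tensor, reduced_axes) -> tensor form.

```python
import numpy as np

# Input with axes named b=4, c=3, h=5 in the pattern 'b c h -> b c'.
x = np.arange(4 * 3 * 5, dtype=float).reshape(4, 3, 5)

# 'h' appears on the left of the arrow only, so axis 2 is the reduced axis.
reduced_axes = (2,)

# Plain NumPy equivalent of reduction='sum' for this pattern (assumed semantics).
summed = x.sum(axis=reduced_axes)

# Hypothetical callable in the documented form f(tensor, reduced_axes) -> tensor.
def max_reduce(tensor, axes):
    return tensor.max(axis=axes)

maxed = max_reduce(x, reduced_axes)

print(summed.shape, maxed.shape)
```

Both results keep the b and c axes and drop h, which is the shape the pattern's right-hand side describes.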

Returns:

out – The reduced tensor with the same type as the input.

Return type:

Union[Array, ndarray, number, bool, saiunit.Quantity]

Examples

>>> import jax.numpy as jnp
>>> import saiunit.math as sumath
>>> x = jnp.ones((4, 3, 5))
>>> sumath.einreduce(x, 'b c h -> b c', 'sum').shape
(4, 3)
>>> sumath.einreduce(x, 'b c h -> b c', 'mean').shape
(4, 3)
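The examples above do not use axes_lengths. In einops, keyword sizes are typically needed when a pattern splits one axis into a parenthesised group, such as 'b c (h h2) -> b c h' with h2=2. Whether einreduce accepts that exact pattern is an assumption here; the sketch below only shows the equivalent computation in plain NumPy.

```python
import numpy as np

x = np.ones((4, 3, 6))

# Assumed pattern: 'b c (h h2) -> b c h' with h2=2 and reduction 'mean'.
h2 = 2
b, c, grouped = x.shape

# Split the last axis into (h, h2), then reduce over h2 only.
pooled = x.reshape(b, c, grouped // h2, h2).mean(axis=3)

print(pooled.shape)
```

Here the last axis of length 6 is treated as 3 groups of 2, and each group is averaged, so the result has shape (4, 3, 3).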