einrepeat


class saiunit.math.einrepeat(x, pattern, **axes_lengths)

Reorder elements and repeat them in arbitrary combinations.

This operation covers the functionality of the repeat, tile, and broadcast functions. When axes are composed, elements are enumerated in C order: within a group such as (a b), the last axis varies fastest, so consecutive elements differ in their last-axis index.
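The effect of C-order composition can be sketched with plain NumPy, without calling einrepeat itself. The snippet assumes that a grouped pattern such as '(repeat h) w' is equivalent to adding the new axis, broadcasting, and reshaping in C order; the array x and the count r are illustrative values, not part of the saiunit API.

```python
import numpy as np

x = np.arange(6).reshape(3, 2)  # illustrative input with h=3, w=2
r = 2                           # illustrative repeat count

# Analogue of '(repeat h) w': repeat is the outer (slow) axis, h the inner (fast) axis.
outer = np.broadcast_to(x[None, :, :], (r, 3, 2)).reshape(r * 3, 2)

# Analogue of '(h repeat) w': h is the outer axis, repeat the inner (fast) axis.
inner = np.broadcast_to(x[:, None, :], (3, r, 2)).reshape(3 * r, 2)

# Placing the new axis first tiles the whole block of rows;
# placing it last repeats each row in place.
tiled = np.tile(x, (r, 1))
repeated = np.repeat(x, r, axis=0)
```

Under that assumption, the position of the new axis inside the parentheses decides whether the result behaves like a tile (new axis first) or like an element-wise repeat (new axis last).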

Parameters:
  • x (Union[Array, ndarray, number, bool, saiunit.Quantity, Sequence[Array | ndarray | number | bool], Sequence[saiunit.Quantity]]) – Input tensor(s). A list of tensors of the same type and shape is also accepted.

  • pattern (str) – Repeat pattern in 'input -> output' form. New axes in the output expression are repeated according to axes_lengths.

  • **axes_lengths (int) – Sizes of new or decomposed axes.

Returns:

out – Tensor of the same type as input.

Return type:

Union[Array, ndarray, number, bool, saiunit.Quantity]

Examples

>>> import jax.numpy as jnp
>>> import saiunit.math as sumath
>>> x = jnp.zeros((3, 4))
>>> sumath.einrepeat(x, 'h w -> h w c', c=3).shape
(3, 4, 3)
>>> sumath.einrepeat(x, 'h w -> (repeat h) w', repeat=2).shape
(6, 4)
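As a rough cross-check of the first example above, the new-axis pattern 'h w -> h w c' can be mimicked in plain NumPy by appending an axis and broadcasting it to length c. This is a sketch of the expected values only, under the assumption that einrepeat copies the input along each new axis; it does not call saiunit.

```python
import numpy as np

x = np.arange(12).reshape(3, 4)  # illustrative input with h=3, w=4
c = 3                            # size of the new axis, as in c=3 above

# Analogue of 'h w -> h w c': add a trailing axis and broadcast it to length c.
out = np.broadcast_to(x[:, :, None], (3, 4, c))

# Every slice along the new axis is a copy of the input.
slices_match = all(np.array_equal(out[:, :, k], x) for k in range(c))
```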