squareplus
- saiunit.math.squareplus(x, b=4, unit_to_scale=None)
Squareplus activation function.
Computes the element-wise function
\[\mathrm{squareplus}(x) = \frac{x + \sqrt{x^2 + b}}{2},\]
as described in https://arxiv.org/abs/2112.11687.
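As an illustration of the formula, here is a minimal NumPy sketch of the element-wise computation. `squareplus_reference` is a hypothetical name used only here; it is not saiunit's implementation and ignores units entirely. It also shows that setting b = 0 collapses squareplus to ReLU, since (x + |x|) / 2 = max(x, 0).

```python
import numpy as np


def squareplus_reference(x, b=4.0):
    """Hypothetical reference: (x + sqrt(x**2 + b)) / 2, element-wise.

    Plain NumPy sketch of the formula; unit handling is deliberately omitted.
    """
    x = np.asarray(x, dtype=np.float64)
    return (x + np.sqrt(x * x + b)) / 2


xs = np.array([-2.0, 0.0, 2.0])
print(squareplus_reference(xs))         # [sqrt(2) - 1, 1, sqrt(2) + 1]
print(squareplus_reference(xs, b=0.0))  # -> [0. 0. 2.]  (ReLU when b = 0)
```

With the default b = 4, the value at x = 0 is sqrt(b) / 2 = 1, and the inputs ±2 map to 1 ± sqrt(2) exactly.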
- Parameters:
x (Union[saiunit.Quantity, Array, ndarray, number, bool]) – Input array. Must be unitless if a Quantity.
b (Array | ndarray | number | bool) – Smoothness parameter. Default is 4.
unit_to_scale (Optional[saiunit.Unit]) – Unit used to convert x to a dimensionless number before applying the activation.
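The choice of `unit_to_scale` matters because squareplus is not scale-equivariant: b is a fixed dimensionless constant, so the same physical input expressed in different units produces different outputs. The sketch below assumes a quantity is stored as its magnitude in SI units and that a unit is just an SI scale factor (millivolt = 1e-3); `to_dimensionless` and `squareplus_reference` are hypothetical helpers, not saiunit's API, and saiunit's internal conversion may differ.

```python
import numpy as np


def squareplus_reference(x, b=4.0):
    """Hypothetical reference: (x + sqrt(x**2 + b)) / 2, element-wise."""
    x = np.asarray(x, dtype=np.float64)
    return (x + np.sqrt(x * x + b)) / 2


def to_dimensionless(si_values, unit_scale):
    """Hypothetical helper: express SI magnitudes as multiples of a unit.

    Mimics the role of `unit_to_scale`, assuming a unit is a plain SI factor.
    """
    return np.asarray(si_values, dtype=np.float64) / unit_scale


# The same three voltages stored in volts: -2 mV, 0 mV, 2 mV.
volts = np.array([-2e-3, 0.0, 2e-3])

in_mV = to_dimensionless(volts, 1e-3)  # approximately [-2, 0, 2]
in_V = to_dimensionless(volts, 1.0)    # [-0.002, 0, 0.002]

# Same physical input, different unit: the outputs differ because b is not
# rescaled along with x. Only the value at 0 (sqrt(b) / 2) is unit-independent.
print(squareplus_reference(in_mV))
print(squareplus_reference(in_V))
```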
- Returns:
out – An array with non-negative values.
- Return type:
Array
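The non-negativity of the output follows directly from the formula, assuming b ≥ 0. Since \(\sqrt{x^2 + b} \ge \sqrt{x^2} = |x|\),
\[\mathrm{squareplus}(x) = \frac{x + \sqrt{x^2 + b}}{2} \ge \frac{x + |x|}{2} = \max(x, 0) \ge 0,\]
with strict inequality whenever b > 0, so the output is then strictly positive. Differentiating the same expression,
\[\frac{d}{dx}\,\mathrm{squareplus}(x) = \frac{1}{2}\left(1 + \frac{x}{\sqrt{x^2 + b}}\right) \in (0, 1) \quad (b > 0),\]
because \(|x| / \sqrt{x^2 + b} < 1\); the function is therefore smooth and strictly increasing, approaching 0 as \(x \to -\infty\) and \(x\) as \(x \to +\infty\).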
Examples
>>> import jax.numpy as jnp
>>> import saiunit.math as sumath
>>> sumath.squareplus(jnp.array([-2., 0., 2.]))
Array([0.41421354, 1.        , 2.4142137 ], dtype=float32)