
Comparing changes

This is a direct comparison between two commits made in this repository or its related repositories.

base repository: jax-ml/jax
base: 8ca39a1664a9eab11c13aefe83146f594532c1ef
head repository: jax-ml/jax
compare: 86f6f07b892fce53a50732c49c203f92fd736f72
Showing with 20 additions and 8 deletions.
28 changes: 20 additions & 8 deletions jax/_src/nn/functions.py
@@ -141,7 +141,9 @@ def silu(x: ArrayLike) -> Array:
See also:
:func:`sigmoid`
"""
-  return jnp.multiply(x, sigmoid(x))
+  numpy_util.check_arraylike("silu", x)
+  x_arr = jnp.asarray(x)
+  return x_arr * sigmoid(x_arr)

swish = silu

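The check_arraylike addition changes error behavior as well as style: non-arraylike inputs such as plain Python lists are now rejected by silu up front rather than converted implicitly. A minimal sketch of the intended behavior, assuming check_arraylike raises TypeError for list inputs as it does elsewhere in jnp (example code, not part of this diff):

import jax
import jax.numpy as jnp

x = jnp.array([-1.0, 0.0, 1.0])
print(jax.nn.silu(x))              # x * sigmoid(x), elementwise

try:
  jax.nn.silu([-1.0, 0.0, 1.0])    # a plain list should now raise
except TypeError as err:
  print(err)

The same pattern applies to the other activations touched below (log_sigmoid, elu, leaky_relu, hard_tanh, standardize).
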
@@ -163,7 +165,9 @@ def log_sigmoid(x: ArrayLike) -> Array:
See also:
:func:`sigmoid`
"""
-  return -softplus(jnp.negative(x))
+  numpy_util.check_arraylike("log_sigmoid", x)
+  x_arr = jnp.asarray(x)
+  return -softplus(-x_arr)

@jax.jit
def elu(x: ArrayLike, alpha: ArrayLike = 1.0) -> Array:
@@ -187,9 +191,11 @@ def elu(x: ArrayLike, alpha: ArrayLike = 1.0) -> Array:
See also:
:func:`selu`
"""
-  x = jnp.asarray(x)
-  safe_x = jnp.where(jnp.greater(x, 0), 0., x)
-  return jnp.where(jnp.greater(x, 0), x, alpha * jnp.expm1(safe_x))
+  numpy_util.check_arraylike("elu", x)
+  x_arr = jnp.asarray(x)
+  return jnp.where(x_arr > 0,
+                   x_arr,
+                   alpha * jnp.expm1(jnp.where(x_arr > 0, 0., x_arr)))

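The rewritten elu keeps the intent of the old safe_x variable by nesting a second jnp.where inside expm1: on the untaken branch the argument is pinned to 0., so exp never overflows and reverse-mode gradients stay finite. A small illustration of why the inner where matters (sketch only, not part of this diff):

import jax
import jax.numpy as jnp

def naive_elu(x, alpha=1.0):
  # for large x, d/dx expm1(x) is inf; the zero cotangent from where times inf gives nan
  return jnp.where(x > 0, x, alpha * jnp.expm1(x))

def guarded_elu(x, alpha=1.0):
  # same values, but the untaken branch evaluates expm1(0.) and stays finite
  return jnp.where(x > 0, x, alpha * jnp.expm1(jnp.where(x > 0, 0., x)))

print(jax.grad(naive_elu)(1000.))    # nan
print(jax.grad(guarded_elu)(1000.))  # 1.0
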
@jax.jit
def leaky_relu(x: ArrayLike, negative_slope: ArrayLike = 1e-2) -> Array:
@@ -215,7 +221,9 @@ def leaky_relu(x: ArrayLike, negative_slope: ArrayLike = 1e-2) -> Array:
See also:
:func:`relu`
"""
-  return jnp.where(jnp.greater_equal(x, 0), x, negative_slope * x)
+  numpy_util.check_arraylike("leaky_relu", x)
+  x_arr = jnp.asarray(x)
+  return jnp.where(x_arr >= 0, x_arr, negative_slope * x_arr)

@jax.jit
def hard_tanh(x: ArrayLike) -> Array:
@@ -236,7 +244,9 @@ def hard_tanh(x: ArrayLike) -> Array:
Returns:
An array.
"""
-  return jnp.where(jnp.greater(x, 1), 1, jnp.where(jnp.less(x, -1), -1, x))
+  numpy_util.check_arraylike("hard_tanh", x)
+  x_arr = jnp.asarray(x)
+  return jnp.where(x_arr > 1, 1, jnp.where(x_arr < -1, -1, x_arr))

@jax.jit
def celu(x: ArrayLike, alpha: ArrayLike = 1.0) -> Array:
@@ -482,6 +492,8 @@ def standardize(x: ArrayLike,
epsilon: ArrayLike = 1e-5,
where: Optional[ArrayLike] = None) -> Array:
r"""Normalizes an array by subtracting ``mean`` and dividing by :math:`\sqrt{\mathrm{variance}}`."""
+  numpy_util.check_arraylike("standardize", x)
+  numpy_util.check_arraylike_or_none("standardize", mean, variance, where)
if mean is None:
mean = jnp.mean(x, axis, keepdims=True, where=where)
if variance is None:
@@ -491,7 +503,7 @@ def standardize(x: ArrayLike,
# when used in neural network normalization layers
variance = jnp.mean(
jnp.square(x), axis, keepdims=True, where=where) - jnp.square(mean)
-  return jnp.subtract(x, mean) * lax.rsqrt(jnp.add(variance, epsilon))
+  return jnp.subtract(x, jnp.asarray(mean)) * lax.rsqrt(jnp.asarray(variance) + epsilon)

def normalize(x: ArrayLike,
axis: Optional[Union[int, tuple[int, ...]]] = -1,
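
For standardize, the new checks also cover the optional mean, variance, and where arguments, and the trailing jnp.asarray calls make the final arithmetic work uniformly whether the statistics are computed internally or passed in by the caller. A usage sketch, not part of this diff:

import jax
import jax.numpy as jnp

x = jnp.arange(6.0).reshape(2, 3)

# statistics computed internally along the last axis
y = jax.nn.standardize(x, axis=-1)
print(jnp.mean(y, axis=-1), jnp.var(y, axis=-1))   # ~0 and ~1 per row

# statistics supplied by the caller, e.g. running stats from a norm layer;
# these now pass through check_arraylike_or_none and jnp.asarray
y2 = jax.nn.standardize(x, mean=jnp.mean(x), variance=jnp.var(x))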