diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index 707a09fc6930..58a52e52ea4f 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -827,9 +827,24 @@ def _logaddexp2_jvp(primals, tangents): return primal_out, tangent_out -@implements(np.log2, module='numpy') @partial(jit, inline=True) def log2(x: ArrayLike, /) -> Array: + """Calculates the base-2 logarithm of x element-wise + + LAX-backend implementation of :func:`numpy.log2`. + + Args: + x: Input array + + Returns: + An array containing the base-2 logarithm of each element in ``x``, promotes + to inexact dtype. + + Examples: + >>> x1 = jnp.array([0.25, 0.5, 1, 2, 4, 8]) + >>> jnp.log2(x1) + Array([-2., -1., 0., 1., 2., 3.], dtype=float32) + """ x, = promote_args_inexact("log2", x) return lax.div(lax.log(x), lax.log(_constant_like(x, 2)))