From 0e85ffee405abf0577ca297398a5fe4f01d684e4 Mon Sep 17 00:00:00 2001 From: Piseth Ky Date: Fri, 28 Jun 2024 15:27:33 -0700 Subject: [PATCH] better log2 doc removing complex example removing negative example --- jax/_src/numpy/ufuncs.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index 707a09fc6930..58a52e52ea4f 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -827,9 +827,24 @@ def _logaddexp2_jvp(primals, tangents): return primal_out, tangent_out -@implements(np.log2, module='numpy') @partial(jit, inline=True) def log2(x: ArrayLike, /) -> Array: + """Calculates the base-2 logarithm of x element-wise + + LAX-backend implementation of :func:`numpy.log2`. + + Args: + x: Input array + + Returns: + An array containing the base-2 logarithm of each element in ``x``, promotes + to inexact dtype. + + Examples: + >>> x1 = jnp.array([0.25, 0.5, 1, 2, 4, 8]) + >>> jnp.log2(x1) + Array([-2., -1., 0., 1., 2., 3.], dtype=float32) + """ x, = promote_args_inexact("log2", x) return lax.div(lax.log(x), lax.log(_constant_like(x, 2)))