Skip to content

Commit

Permalink
Merge pull request jax-ml#22191 from pkgoogle:better_log2_doc
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 649038499
  • Loading branch information
jax authors committed Jul 3, 2024
2 parents 467c62c + 0e85ffe commit c00ac4f
Showing 1 changed file with 16 additions and 1 deletion.
17 changes: 16 additions & 1 deletion jax/_src/numpy/ufuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -827,9 +827,24 @@ def _logaddexp2_jvp(primals, tangents):
return primal_out, tangent_out


@implements(np.log2, module='numpy')
@partial(jit, inline=True)
def log2(x: ArrayLike, /) -> Array:
"""Calculates the base-2 logarithm of x element-wise
LAX-backend implementation of :func:`numpy.log2`.
Args:
x: Input array
Returns:
An array containing the base-2 logarithm of each element in ``x``, promotes
to inexact dtype.
Examples:
>>> x1 = jnp.array([0.25, 0.5, 1, 2, 4, 8])
>>> jnp.log2(x1)
Array([-2., -1., 0., 1., 2., 3.], dtype=float32)
"""
x, = promote_args_inexact("log2", x)
return lax.div(lax.log(x), lax.log(_constant_like(x, 2)))

Expand Down

0 comments on commit c00ac4f

Please sign in to comment.