diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index 61fa45330fe3..b70752910aa2 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -201,7 +201,7 @@ def arccosh(x: ArrayLike, /) -> Array: @jit def bitwise_count(x: ArrayLike, /) -> Array: # Following numpy we take the absolute value and return uint8. - return lax.population_count(lax.abs(x)).astype('uint8') + return lax.population_count(abs(x)).astype('uint8') @_wraps(np.right_shift, module='numpy') @partial(jit, inline=True) diff --git a/tests/lax_numpy_operators_test.py b/tests/lax_numpy_operators_test.py index aad3318a098c..a90cefb1c2de 100644 --- a/tests/lax_numpy_operators_test.py +++ b/tests/lax_numpy_operators_test.py @@ -581,7 +581,7 @@ def testBitwiseOp(self, name, rng_factory, shapes, dtypes): @jtu.sample_product( shape=array_shapes, - dtype=int_dtypes, + dtype=int_dtypes + unsigned_dtypes, ) def testBitwiseCount(self, shape, dtype): # np.bitwise_count added after numpy 1.26, but