Skip to content

Commit

Permalink
Merge pull request #17962 from jakevdp:fix-bitwise-count
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 571398043
  • Loading branch information
jax authors committed Oct 6, 2023
2 parents fae53d9 + cd18f8e commit 6681e64
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion jax/_src/numpy/ufuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def arccosh(x: ArrayLike, /) -> Array:
@jit
def bitwise_count(x: ArrayLike, /) -> Array:
# Following numpy we take the absolute value and return uint8.
return lax.population_count(lax.abs(x)).astype('uint8')
return lax.population_count(abs(x)).astype('uint8')

@_wraps(np.right_shift, module='numpy')
@partial(jit, inline=True)
Expand Down
2 changes: 1 addition & 1 deletion tests/lax_numpy_operators_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -581,7 +581,7 @@ def testBitwiseOp(self, name, rng_factory, shapes, dtypes):

@jtu.sample_product(
shape=array_shapes,
dtype=int_dtypes,
dtype=int_dtypes + unsigned_dtypes,
)
def testBitwiseCount(self, shape, dtype):
# np.bitwise_count added after numpy 1.26, but
Expand Down

0 comments on commit 6681e64

Please sign in to comment.