Skip to content

Commit

Permalink
Fix tests that fail if enable_checks is true under NumPy 2.0.0rc1.
Browse files Browse the repository at this point in the history
np.vecdot is missing `__module__` under NumPy 2.0.0rc1.

PiperOrigin-RevId: 621532796
  • Loading branch information
hawkinsp authored and jax authors committed Apr 3, 2024
1 parent d89f0d6 commit e2f4774
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3443,8 +3443,11 @@ def vdot(
preferred_element_type=preferred_element_type)


@util.implements(getattr(np, "vecdot", None), lax_description=_PRECISION_DOC,
extra_params=_DOT_PREFERRED_ELEMENT_TYPE_DESCRIPTION)
@util.implements(
getattr(np, "vecdot", None), lax_description=_PRECISION_DOC,
extra_params=_DOT_PREFERRED_ELEMENT_TYPE_DESCRIPTION,
# TODO(phawkins): numpy.vecdot doesn't have a __module__ attribute.
module="numpy")
def vecdot(x1: ArrayLike, x2: ArrayLike, /, *, axis: int = -1,
precision: PrecisionLike = None,
preferred_element_type: DTypeLike | None = None) -> Array:
Expand Down

0 comments on commit e2f4774

Please sign in to comment.