From e2f47748e3ac9fda652398c20314d4ae5f3ade9a Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 3 Apr 2024 08:34:31 -0700 Subject: [PATCH] Fix tests that fail if enable_checks is true under NumPy 2.0.0rc1. np.vecdot is missing `__module__` under NumPy 2.0.0rc1. PiperOrigin-RevId: 621532796 --- jax/_src/numpy/lax_numpy.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 0c2640186d74..826251c82945 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -3443,8 +3443,11 @@ def vdot( preferred_element_type=preferred_element_type) -@util.implements(getattr(np, "vecdot", None), lax_description=_PRECISION_DOC, - extra_params=_DOT_PREFERRED_ELEMENT_TYPE_DESCRIPTION) +@util.implements( + getattr(np, "vecdot", None), lax_description=_PRECISION_DOC, + extra_params=_DOT_PREFERRED_ELEMENT_TYPE_DESCRIPTION, + # TODO(phawkins): numpy.vecdot doesn't have a __module__ attribute. + module="numpy") def vecdot(x1: ArrayLike, x2: ArrayLike, /, *, axis: int = -1, precision: PrecisionLike = None, preferred_element_type: DTypeLike | None = None) -> Array: