-
Notifications
You must be signed in to change notification settings - Fork 247
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
filter out tests waiting for next tfp release #1817
Conversation
There is a new test failing probably because the new jax and/or numpy releases __________________ test_discrete_site_without_infer_enumerate __________________
def test_discrete_site_without_infer_enumerate():
def model():
numpyro.sample("x", dist.Bernoulli(0.5))
mcmc = MCMC(NUTS(model), num_warmup=10, num_samples=10)
with pytest.warns(FutureWarning, match="enumerated sites"):
> mcmc.run(random.PRNGKey(0))
E FutureWarning: unhashable type: <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>. Attempting to hash a tracer will lead to an error in a future JAX release.
test/infer/test_mcmc.py:1104: FutureWarning I added a different match group in aa30c69 but I think it is essential to address these warnings. Especially because we are also getting DeprecationWarning: numpy.core.numeric is deprecated and has been renamed to numpy._core.numeric. The numpy._core namespace contains private NumPy internals and its use is discouraged, as NumPy internals can change without warning in any release. In practice, most real-world usage of numpy.core is to access functionality in the public NumPy API. If that is the case, use the public NumPy API. If not, you are using NumPy internals. If you would still like to access an internal attribute, use numpy._core.numeric.normalize_axis_tuple.
from numpy.core.numeric import normalize_axis_tuple
test/infer/test_mcmc.py::test_discrete_site_without_infer_enumerate
/Users/juanitorduz/Documents/envs/numpyro-env/lib/python3.12/site-packages/jax/_src/linear_util.py:192: DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy.clip is deprecated. Please use 'x', 'min', and 'max' respectively instead. |
It seems
is all over the place now 😓 |
Thanks @juanitorduz. I'm looking at them. |
ok! You can either push to this branch or create a new one if needed |
It turns out that in funsor, we have some checks for tracers to be Hashable. I don't think that the new behavior will cause issues: it is fine to let arrays to be either hashable or unhashable. So I think we can simply filter out these warnings:
|
ok! We are making progress 😅 ! Now we have FAILED test/contrib/test_control_flow.py::test_scan - TypeError: body_fun output and input must have identical types, got
('ShapedArray(int32[], weak_type=True)', ['ShapedArray(float32[10])', 'ShapedArray(float32[10])', 'DIFFERENT ShapedArray(int32[], weak_type=True) vs. ShapedArray(float0[])', 'ShapedArray(float32[])', 'ShapedArray(float32[])'], []).
FAILED test/contrib/test_control_flow.py::test_scan_svi - TypeError: body_fun output and input must have identical types, got
('ShapedArray(int32[], weak_type=True)', ['ShapedArray(float32[3,5])', 'DIFFERENT ShapedArray(int32[], weak_type=True) vs. ShapedArray(float0[])', 'ShapedArray(float32[5])'], []). and TypeError: body_fun output and input must have identical types, got
('ShapedArray(int32[], weak_type=True)', ['ShapedArray(float32[])', 'ShapedArray(float32[])', 'ShapedArray(float32[])', 'ShapedArray(float32[])', 'DIFFERENT ShapedArray(int32[], weak_type=True) vs. ShapedArray(float0[])', 'ShapedArray(float32[])', 'ShapedArray(float32[])', 'ShapedArray(float32[10])'], []).
test/test_examples.py::test_cpu[holt_winters.py --T 4 --num-samples 10 --num-warmup 10 --num-chains 2] Running:
python examples/holt_winters.py --T 4 --num-samples 10 --num-warmup 10 --num-chains 2 I am so seeing The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.
--------------------
The above exception was the direct cause of the following exception:
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these. 🤔 |
It seems like a type problem ... could this be again a numpy or jax recent change? |
Hi @juanitorduz, I raised the upstream error in jax-ml/jax#22045. For a fix, could you help me change every tree_map(device_put, foo) |
In 19f8232 I still saw other test failing: FAILED test/contrib/test_tfp.py::test_unnormalized_normal_chain[1-MetropolisAdjustedLangevinAlgorithm-kwargs0] - OverflowError: Python int too large to convert to C long
FAILED test/contrib/test_tfp.py::test_unnormalized_normal_chain[1-SliceSampler-kwargs2] - OverflowError: Python int too large to convert to C long
FAILED test/contrib/test_tfp.py::test_unnormalized_normal_chain[1-UncalibratedLangevin-kwargs3] - OverflowError: Python int too large to convert to C long
FAILED test/contrib/test_tfp.py::test_unnormalized_normal_chain[2-MetropolisAdjustedLangevinAlgorithm-kwargs0] - OverflowError: Python int too large to convert to C long
FAILED test/contrib/test_tfp.py::test_unnormalized_normal_chain[2-SliceSampler-kwargs2] - OverflowError: Python int too large to convert to C long
FAILED test/contrib/test_tfp.py::test_unnormalized_normal_chain[2-UncalibratedLangevin-kwargs3] - OverflowError: Python int too large to convert to C long Hence: 038dd84 |
Ok! Finally is 🟢! Shall we revert these changes once the JAX issue is fixed and released? Similarly for tfp next release? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for addressing the issues, @juanitorduz!!
Partially addresses #1814 . We must keep in mind removing these
skip
statements once we see a new release.