LaxBackedNumpyTests.testClipStaticBounds18 failing #20664

Closed
vladbelit opened this issue Apr 9, 2024 · 1 comment · Fixed by #20665

vladbelit commented Apr 9, 2024

Description

This failure appears to have been caused by #20550.

Test location:
https://github.com/google/jax/blob/f5cc272615ce2795f9133e63b7b535ec5ada7e52/tests/lax_numpy_test.py#L869

__________________ LaxBackedNumpyTests.testClipStaticBounds18 __________________
[gw6] linux -- Python 3.12.0 /home/kbuilder/.pyenv/versions/3.12.0/bin/python
tests/lax_numpy_test.py:877: in testClipStaticBounds
    self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
jax/_src/test_util.py:1180: in _CheckAgainstNumpy
    lax_ans = lax_op(*args)
tests/lax_numpy_test.py:875: in <lambda>
    jnp_fun = lambda x: jnp.clip(x, min=a_min, max=a_max)
jax/_src/pjit.py:304: in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper(
jax/_src/pjit.py:171: in _python_pjit_helper
    attrs_tracked) = _infer_params(jit_info, args, kwargs)
jax/_src/pjit.py:605: in _infer_params
    jaxpr, consts, out_shardings_flat, out_layouts_flat, attrs_tracked = _pjit_jaxpr(
jax/_src/pjit.py:1222: in _pjit_jaxpr
    jaxpr, final_consts, out_type, attrs_tracked = _create_pjit_jaxpr(
jax/_src/linear_util.py:350: in memoized_fun
    ans = call(fun, *args)
jax/_src/pjit.py:1170: in _create_pjit_jaxpr
    jaxpr, global_out_avals, consts, attrs_tracked = pe.trace_to_jaxpr_dynamic(
jax/_src/profiler.py:335: in wrapper
    return func(*args, **kwargs)
jax/_src/interpreters/partial_eval.py:2326: in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts, attrs_tracked = trace_to_subjaxpr_dynamic(
jax/_src/interpreters/partial_eval.py:2348: in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers_)
jax/_src/linear_util.py:192: in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
jax/_src/numpy/lax_numpy.py:1343: in clip
    warnings.warn(
E   DeprecationWarning: Clip received a complex value either through the input or the min/max keywords. Complex values have no ordering and cannot be clipped. Attempting to clip using complex numbers is deprecated and will soon raise a ValueError. Please convert to a real value or array by taking the real or imaginary components via jax.numpy.real/imag respectively.
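
The traceback ends in a `warnings.warn(...)` call that pytest reports as an error, which suggests the test run escalates warnings to exceptions. The sketch below illustrates that mechanism with the standard library only. `clip_like` and `call_with_warnings_as_errors` are hypothetical stand-ins written for this illustration; they are not the JAX implementation or its test configuration.

```python
import warnings


def clip_like(x, lo, hi):
    """Hypothetical stand-in for a clip function that deprecates complex input."""
    if any(isinstance(v, complex) for v in (x, lo, hi)):
        warnings.warn(
            "complex values have no ordering and cannot be clipped",
            DeprecationWarning,
            stacklevel=2,
        )
        # Stand-in fallback for this sketch only: clip on the real parts.
        x, lo, hi = (complex(v).real for v in (x, lo, hi))
    return min(max(x, lo), hi)


def call_with_warnings_as_errors(fn, *args):
    """Run fn with DeprecationWarning escalated to an exception."""
    with warnings.catch_warnings():
        warnings.simplefilter("error", DeprecationWarning)
        return fn(*args)


# Real inputs pass even under the strict filter.
real_result = call_with_warnings_as_errors(clip_like, 5.0, 0.0, 1.0)

# A complex bound triggers the warning, which the filter raises as an error.
try:
    call_with_warnings_as_errors(clip_like, 0.5, 0.0, 1 + 2j)
    complex_raised = False
except DeprecationWarning:
    complex_raised = True

# Taking the real component first, as the warning message suggests, avoids it.
fixed_result = call_with_warnings_as_errors(clip_like, 0.5, 0.0, (1 + 2j).real)
```

Under this assumption, any test case that still feeds complex values to the clipped function fails as soon as the deprecation is introduced, even though the call itself would otherwise return a result.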

System info (python version, jaxlib version, accelerator, etc.)

Linux x86_64
Python 3.9-3.12
CPU

vladbelit added the bug (Something isn't working) label Apr 9, 2024
vladbelit (Author) commented

@Micky774 Hi, please take a look.
