Quadrature within jitted functions can cause leaked trace #18

Open
wcxve opened this issue Dec 6, 2024 · 1 comment

Comments

@wcxve

wcxve commented Dec 6, 2024

Hi! This library is fantastic for performing numerical quadrature with JAX!

I encountered an issue when enabling JAX’s leak-checking mechanism for tracers. The following example raises an exception:

```python
import jax
import jax.numpy as jnp
from quadax import quadgk

@jax.jit
def integral(interval):
    return quadgk(jnp.square, interval)

with jax.checking_leaks():
    jax.block_until_ready(integral([0.0, 1.0]))
```

The error message indicates a trace leak:

```
...... Exception: Leaked trace MainTrace(1,DynamicJaxprTrace). Leaked tracer(s):
Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>
This DynamicJaxprTracer was created on line .../python3.9/site-packages/quadax/utils.py:99 (map_interval)
<DynamicJaxprTracer 5063988464> is referred to by ......
```
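For context, here is a minimal sketch of the kind of situation `jax.checking_leaks()` is designed to catch. It uses a deliberately leaky function rather than quadax: a jitted function stashes its traced argument in a module-level list, so the tracer outlives the trace that created it. This is only a generic illustration. It does not claim that quadax's `map_interval` leaks for the same reason, and the names `leaked`, `leaky`, and `leak_detected` are hypothetical.

```python
import jax

# Hypothetical module-level container that outlives any single trace.
leaked = []

@jax.jit
def leaky(x):
    # Storing the tracer here keeps a reference to the trace alive
    # after jit has finished tracing `leaky`.
    leaked.append(x)
    return x * 2

def leak_detected():
    """Return True if JAX's leak checker raises a 'Leaked trace' error."""
    try:
        with jax.checking_leaks():
            leaky(3.0)  # first call triggers tracing, where the leak happens
    except Exception as err:  # JAX raises a plain Exception for leaked traces
        return "Leaked" in str(err)
    return False

detected = leak_detected()
print(detected)
```

The same mechanism applies to any tracer that stays reachable after tracing ends, whether through a global, a cache, or an object attribute. In the report above, the checker points at a tracer created in `quadax/utils.py` (`map_interval`).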

@f0uriest
Owner

f0uriest commented Dec 7, 2024

Hmm. I'm able to reproduce this but haven't been able to fully isolate the cause yet, so I'm still working on a fix.
