Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[callback] Improve caching effectiveness in presence of callbacks. #20514

Merged
merged 1 commit into from
Apr 2, 2024

Conversation

gnecula
Copy link
Collaborator

@gnecula gnecula commented Apr 1, 2024

Previously, the user-provided Python callback function was first flattened and then the result passed as a primitive parameter to the callback primitives. This means that two separate io_callback invocations with the same Python callable will generate different Jaxprs. To prevent this we defer the flattening to lowering time.

I discovered this problem by trying to ensure that io_callback passes the host_callback_test.py.

@gnecula gnecula self-assigned this Apr 1, 2024
@gnecula gnecula added the pull ready Ready for copybara import and testing label Apr 1, 2024
@gnecula gnecula requested review from superbobry and sharadmv April 1, 2024 09:39
Copy link
Collaborator

@superbobry superbobry left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch!

Previously, the user-provided Python callback function was first
flattened and then the result passed as a primitive parameter to
the callback primitives. This means that two separate io_callback
invocations with the same Python callable will generate different
Jaxprs. To prevent this we defer the flattening to lowering time.
@copybara-service copybara-service bot merged commit 4c41c12 into jax-ml:main Apr 2, 2024
13 checks passed
@gnecula gnecula deleted the callback_cache branch April 2, 2024 14:24
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants