This repository has been archived by the owner on Jun 12, 2024. It is now read-only.

JAX: Cache intermediates to speed up guide vjp #6

Open
jiawen opened this issue Sep 25, 2020 · 0 comments

@jiawen
Contributor

jiawen commented Sep 25, 2020

bilateral_slice and bilateral_slice_guide_vjp perform nearly identical computations. The intermediates computed by the former can be cached as the "residual" in _bilateral_slice_fwd and passed to _bilateral_slice_bwd, so the guide VJP does not have to recompute them.

@jiawen jiawen self-assigned this Sep 25, 2020