-
Notifications
You must be signed in to change notification settings - Fork 2.8k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[JAX] Add end-to-end execution support in colocated Python API
This change adds a capability to run colocated Python function calls through `PyLoadedExecutable`. This capability is not yet used for McJAX, but is tested with a prototype of a colocated Python backend. The overall behavior remains the same for McJAX (running the user code inline when colocated Python is called); the new logic will be used once we introduce a colocated Python backend for McJAX. Key highlights: * Colocated Python is compiled into `PyLoadedExeutable` and uses the JAX C++ dispatch path. * `CustomCallProgram` for a colocated Python compilation nows includes specialization (input/output specs, devices). This information allows a colocated Python backend to transform input/outputs and validate PyTree/dtype/shape/sharding. * `out_specs_fn` now receives `jax.ShapeDTypeStruct`s instead of concrete values. * Deserialization of devices now prefers the default backend. This improves the compatibility with an environment using both multi-platform backend as well as the standard "cpu" backend at the same time. * Several bugs have been fixed (e.g., correctly using `{}` for kwargs). PiperOrigin-RevId: 703172997
- Loading branch information
1 parent
3f5f3e1
commit e20a483
Showing
4 changed files
with
109 additions
and
34 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters