From 76d5938062360177a7246af8f9ffc6e0dd4ead15 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Mon, 7 Oct 2024 15:33:24 -0700 Subject: [PATCH] [pallas] Added `MemoryRef` and `run_scoped` to the API docs PiperOrigin-RevId: 683349061 --- docs/jax.experimental.pallas.rst | 4 ++++ jax/_src/pallas/primitives.py | 12 +++++------- jax/experimental/pallas/__init__.py | 4 +++- 3 files changed, 12 insertions(+), 8 deletions(-) diff --git a/docs/jax.experimental.pallas.rst b/docs/jax.experimental.pallas.rst index d250d682d32a..1cddbc177e6f 100644 --- a/docs/jax.experimental.pallas.rst +++ b/docs/jax.experimental.pallas.rst @@ -13,6 +13,8 @@ Classes GridSpec Slice + MemoryRef + Functions --------- @@ -36,3 +38,5 @@ Functions atomic_xchg debug_print + + run_scoped diff --git a/jax/_src/pallas/primitives.py b/jax/_src/pallas/primitives.py index 40caae76bd8f..3bf815cd3cdd 100644 --- a/jax/_src/pallas/primitives.py +++ b/jax/_src/pallas/primitives.py @@ -819,15 +819,13 @@ def debug_print_lowering_rule(ctx, *args, **params): run_scoped_p.multiple_results = True -def run_scoped(f: Callable[..., Any], *types, **kw_types) -> Any: - """Call the function with allocated references. +def run_scoped(f: Callable[..., Any], *types: Any, **kw_types: Any) -> Any: + """Calls the function with allocated references and returns the result. - Args: - f: The function that generates the jaxpr. - *types: The types of the function's positional arguments. - **kw_types: The types of the function's keyword arguments. + The positional and keyword arguments describe which reference types + to allocate for each argument. Each backend has its own set of reference + types in addition to :class:`jax.experimental.pallas.MemoryRef`. """ - flat_types, in_tree = tree_util.tree_flatten((types, kw_types)) flat_fun, out_tree_thunk = api_util.flatten_fun(lu.wrap_init(f), in_tree) avals = [t.get_ref_aval() for t in flat_types] diff --git a/jax/experimental/pallas/__init__.py b/jax/experimental/pallas/__init__.py index 51751f7ec96d..0a82137f8dd6 100644 --- a/jax/experimental/pallas/__init__.py +++ b/jax/experimental/pallas/__init__.py @@ -24,10 +24,11 @@ from jax._src.pallas.core import CostEstimate from jax._src.pallas.core import GridSpec from jax._src.pallas.core import IndexingMode +from jax._src.pallas.core import MemorySpace +from jax._src.pallas.core import MemoryRef from jax._src.pallas.core import no_block_spec from jax._src.pallas.core import Unblocked from jax._src.pallas.core import unblocked -from jax._src.pallas.core import MemorySpace from jax._src.pallas.pallas_call import pallas_call from jax._src.pallas.pallas_call import pallas_call_p from jax._src.pallas.primitives import atomic_add @@ -57,4 +58,5 @@ from jax._src.state.indexing import Slice from jax._src.state.primitives import broadcast_to + ANY = MemorySpace.ANY