Skip to content

Commit

Permalink
[JAX] Add end-to-end execution support in colocated Python API
Browse files Browse the repository at this point in the history
This change adds a capability to run colocated Python function calls through
`PyLoadedExecutable`. This capability is not yet used for McJAX, but is tested
with a prototype of a colocated Python backend. The overall behavior remains
the same for McJAX (running the user code inline when colocated Python is
called); the new logic will be used once we introduce a colocated Python
backend for McJAX.

Key highlights:

* Colocated Python is compiled into `PyLoadedExeutable` and uses the JAX C++
dispatch path.

* `CustomCallProgram` for a colocated Python compilation nows includes
specialization (input/output specs, devices). This information allows a
colocated Python backend to transform input/outputs and validate
PyTree/dtype/shape/sharding.

* `out_specs_fn` now receives `jax.ShapeDTypeStruct`s instead of concrete values.

* Deserialization of devices now prefers the default backend. This improves the
compatibility with an environment using both multi-platform backend as well as
the standard "cpu" backend at the same time.

* Several bugs have been fixed (e.g., correctly using `{}` for kwargs).

PiperOrigin-RevId: 703172997
  • Loading branch information
hyeontaek authored and Google-ML-Automation committed Dec 5, 2024
1 parent 3f5f3e1 commit e20a483
Show file tree
Hide file tree
Showing 4 changed files with 109 additions and 34 deletions.
77 changes: 59 additions & 18 deletions jax/experimental/colocated_python/func.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import jax
from jax._src import api
from jax._src import tree_util
from jax._src.interpreters import pxla
from jax._src.lib import xla_client as xc
from jax._src.traceback_util import api_boundary
from jax._src.util import wraps
Expand Down Expand Up @@ -137,23 +138,54 @@ def _infer_devices_from_args(args: Sequence[Any]) -> xc.DeviceList | None:
def _compile_to_executable(
name: str,
fun: Callable[..., Any],
in_specs_treedef: tree_util.PyTreeDef,
in_specs_leaves: tuple[api.ShapeDtypeStruct, ...],
out_specs_treedef: tree_util.PyTreeDef,
out_specs_leaves: tuple[api.ShapeDtypeStruct, ...],
devices: xc.DeviceList,
) -> Callable[..., Any]:
"""Compiles a Python function into a runtime executable."""
pickled_function = _serialize(fun)
fun_and_specialization = (
fun,
in_specs_treedef,
in_specs_leaves,
out_specs_treedef,
out_specs_leaves,
devices,
)
pickled_function = _serialize(fun_and_specialization)
program = ifrt_programs.make_colocated_python_program(
name, pickled_function, devices, in_specs_leaves, out_specs_leaves
)
# TODO(hyeontaek): Compile the program and use the executable.
del program
ifrt_client = devices[0].client
out_sdss = tuple(
jax.core.ShapedArray(sds.shape, sds.dtype) for sds in out_specs_leaves
)
out_shardings = tuple(sds.sharding for sds in out_specs_leaves)
try:
compile_options = ifrt_programs.make_colocated_python_compile_options()
loaded_executable = ifrt_client.compile_ifrt_program(
program, compile_options
)
out_handlers = pxla.global_avals_to_results_handler(
out_sdss, out_shardings, committed=True
).handlers

def call(*args, **kwargs):
args_leaves = tree_util.tree_leaves((args, kwargs))
execute_result = loaded_executable.execute_sharded(
args_leaves, with_tokens=False
)
results = execute_result.consume_with_handlers(out_handlers)
return tree_util.tree_unflatten(out_specs_treedef, results)

del name
del in_specs_leaves
del out_specs_leaves
del devices
return fun
return call
except jax.errors.JaxRuntimeError as e:
# TODO(hyeontaek): Implement colocated Python support in McJAX and remove
# this fallback path.
if "PjRtCompiler requires an HloProgram" in str(e):
return fun
raise


def _make_output_specs_and_push_result_fun(
Expand All @@ -170,20 +202,22 @@ def _make_output_specs_and_push_result_fun(

def lowered_fun(*args, **kwargs) -> Sequence[jax.Array]:
result = info.fun(*args, **kwargs)
out_leaves, out_treedef = tree_util.tree_flatten(result)
out_spec_leaves = tuple(_get_spec(x) for x in out_leaves)
func_backend.SINGLETON_RESULT_STORE.push(uid, out_leaves)
result_leaves, out_treedef = tree_util.tree_flatten(result)
out_spec_leaves = tuple(_get_spec(x) for x in result_leaves)
func_backend.SINGLETON_RESULT_STORE.push(uid, result_leaves)
return _serialize_specs(out_treedef, out_spec_leaves, devices)

out_specs_leaves, _ = tree_util.tree_flatten(
out_specs_leaves, out_specs_treedef = tree_util.tree_flatten(
_make_specs_for_serialized_specs(specialization.devices),
)
name = getattr(info.fun, "__name__", "unknown")
name = f"{name}_output_specs_and_push_result"
return _compile_to_executable(
name=name,
fun=lowered_fun,
in_specs_treedef=specialization.in_specs_treedef,
in_specs_leaves=specialization.in_specs_leaves,
out_specs_treedef=out_specs_treedef,
out_specs_leaves=tuple(out_specs_leaves),
devices=specialization.devices,
)
Expand All @@ -200,21 +234,23 @@ def _make_pop_result_fun(
out_specs_treedef = specialization.out_specs_treedef

def lowered_fun() -> Any:
flat_result = func_backend.SINGLETON_RESULT_STORE.pop(uid)
return tree_util.tree_unflatten(out_specs_treedef, flat_result)
result_leaves = func_backend.SINGLETON_RESULT_STORE.pop(uid)
return tree_util.tree_unflatten(out_specs_treedef, result_leaves)

in_specs, _ = tree_util.tree_flatten((
in_specs_leaves, in_specs_treedef = tree_util.tree_flatten((
# args
(),
# kwargs
(),
{},
))
name = getattr(info.fun, "__name__", "unknown")
name = f"{name}_pop_result"
return _compile_to_executable(
name=name,
fun=lowered_fun,
in_specs_leaves=tuple(in_specs),
in_specs_treedef=in_specs_treedef,
in_specs_leaves=tuple(in_specs_leaves),
out_specs_treedef=specialization.out_specs_treedef,
out_specs_leaves=specialization.out_specs_leaves,
devices=specialization.devices,
)
Expand All @@ -234,7 +270,9 @@ def _make_async_execution_fun(
return _compile_to_executable(
name=name,
fun=info.fun,
in_specs_treedef=specialization.in_specs_treedef,
in_specs_leaves=specialization.in_specs_leaves,
out_specs_treedef=specialization.out_specs_treedef,
out_specs_leaves=specialization.out_specs_leaves,
devices=specialization.devices,
)
Expand Down Expand Up @@ -283,7 +321,10 @@ def specialized_func(*args, **kwargs) -> Any:
return _make_pop_result_fun(info, specialization, uid)()
else:
# Compute out_specs using out_specs_fn and inputs.
out_specs = specialization.out_specs_fn(*args, **kwargs)
args_specs, kwargs_specs = tree_util.tree_map(
_get_spec, (args, kwargs)
)
out_specs = specialization.out_specs_fn(*args_specs, **kwargs_specs)
# Type checking is ignored to silence mypy error: Incompatible types
# in assignment (expression has type "list[Any]", variable has type
# "tuple[ShapeDtypeStruct, ...]") [assignment]
Expand Down
20 changes: 17 additions & 3 deletions jax/experimental/colocated_python/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,22 @@ def _get_cpu_device_map() -> dict[int, jax.Device]:
# associated with colocated_python. When deserializing on the colocated_python
# executor, it should be the CPU backend visible to the user function running
# under colocated_python.
for backed in xb.backends().values():
for d in backed._get_all_devices(): # pylint: disable=protected-access

# Look for CPU devices in the default backend.
for d in xb.local_devices()[0].client._get_all_devices(): # pylint: disable=protected-access
if d.device_kind == "cpu":
if d.id in cpu_device_map:
raise ValueError(
f"Multiple CPU devices with id {d.id} found:"
f" {cpu_device_map[d.id]} and {d}"
)
cpu_device_map[d.id] = d
if cpu_device_map:
return cpu_device_map

# Fall back to searching CPU devices in all backends.
for backend in xb.backends().values():
for d in backend._get_all_devices(): # pylint: disable=protected-access
if d.device_kind == "cpu":
if d.id in cpu_device_map:
raise ValueError(
Expand Down Expand Up @@ -87,7 +101,7 @@ def make_device_list(device_ids: Sequence[int]) -> DeviceList:
devices = np.vectorize(lambda device_id: cpu_device_map[device_id])(
device_ids
)
return DeviceList(devices)
return DeviceList(tuple(devices))

device_ids = [d.id for d in device_list]
return make_device_list, (device_ids,)
Expand Down
1 change: 1 addition & 0 deletions tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1562,6 +1562,7 @@ exports_files(
"api_test.py",
"array_test.py",
"cache_key_test.py",
"colocated_python_test.py",
"compilation_cache_test.py",
"memories_test.py",
"pmap_test.py",
Expand Down
45 changes: 32 additions & 13 deletions tests/colocated_python_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,20 @@ def _colocated_cpu_devices(
devices: Sequence[jax.Device],
) -> Sequence[jax.Device]:
"""Returns CPU devices colocated with the given devices."""
# TODO(hyeontaek): Use `colocated_python.colocated_cpu_devices(devices)` once
# PjRt-IFRT prepares CPU devices by its own.
cpu_backend_devices = jax.local_devices(backend="cpu")
device_index_map = {device.id: i for i, device in enumerate(jax.devices())}
try:
return colocated_python.colocated_cpu_devices(devices)
except (ValueError, AttributeError):
# PjRt-IFRT prepares CPU devices by its own.
# TODO(hyeontaek): Remove this fallback path once PjRt-IFRT prepares CPU
# devices by its own.
cpu_backend_devices = jax.local_devices(backend="cpu")
device_index_map = {device.id: i for i, device in enumerate(jax.devices())}

available_devices = devices[: min(len(cpu_backend_devices), len(devices))]
return [
cpu_backend_devices[device_index_map[d.id]] for d in available_devices
]

available_devices = devices[:min(len(cpu_backend_devices), len(devices))]
return [
cpu_backend_devices[device_index_map[d.id]] for d in available_devices
]

@contextlib.contextmanager
def _count_colocated_python_specialization_cache_miss() -> list[int]:
Expand Down Expand Up @@ -79,20 +84,20 @@ class ColocatedPythonTest(jtu.JaxTestCase):

def setUp(self):
super().setUp()
if xla_extension_version < 298:
self.skipTest("Requires xla_extension_version >= 298")
if xla_extension_version < 300:
self.skipTest("Requires xla_extension_version >= 300")

def testMakeColocatedPythonProgram(self):
def add_one(x):
return x + 1

cpu_devices = _colocated_cpu_devices(jax.local_devices())
sharding = jax.sharding.SingleDeviceSharding(cpu_devices[0])
aval = jax.ShapeDtypeStruct((), jnp.int32, sharding=sharding)
sds = jax.ShapeDtypeStruct((), jnp.int32, sharding=sharding)

pickled_function = serialization._serialize(add_one)
program = ifrt_programs.make_colocated_python_program(
"add_one", pickled_function, [cpu_devices[0]], [aval], [aval]
"add_one", pickled_function, [cpu_devices[0]], [sds], [sds]
)
del program

Expand All @@ -107,10 +112,12 @@ def add_one(x):

with _count_colocated_python_specialization_cache_miss() as count:
out = add_one(x)
out = jax.device_get(out)
self.assertEqual(out, np.array(2))
self.assertEqual(count[0], 1)

out = add_one(x)
out = jax.device_get(out)
self.assertEqual(out, np.array(2))
self.assertEqual(count[0], 1)

Expand All @@ -125,10 +132,12 @@ def add_one(x):

with _count_colocated_python_specialization_cache_miss() as count:
out = add_one(x)
out = jax.device_get(out)
self.assertEqual(out, [np.array(2), (np.array(3), {"v": np.array(4)})])
self.assertEqual(count[0], 1)

out = add_one(x)
out = jax.device_get(out)
self.assertEqual(out, [np.array(2), (np.array(3), {"v": np.array(4)})])
self.assertEqual(count[0], 1)

Expand All @@ -154,10 +163,12 @@ def make_zero():
with _count_colocated_python_specialization_cache_miss() as count:
make_zero = make_zero.specialize(devices=cpu_devices[:1])
out = make_zero()
out = jax.device_get(out)
self.assertEqual(out, np.array(0))
self.assertEqual(count[0], 1)

out = make_zero()
out = jax.device_get(out)
self.assertEqual(out, np.array(0))
self.assertEqual(count[0], 1)

Expand All @@ -172,10 +183,12 @@ def add_one(x):

with _count_colocated_python_specialization_cache_miss() as count:
out = add_one(x)
out = jax.device_get(out)
self.assertEqual(out, np.array(2))
self.assertEqual(count[0], 1)

out = add_one(x)
out = jax.device_get(out)
self.assertEqual(out, np.array(2))
self.assertEqual(count[0], 1)

Expand All @@ -184,10 +197,12 @@ def add_one(x):
x = jax.device_put(x, jax.sharding.SingleDeviceSharding(cpu_devices[0]))

out = add_one(x)
out = jax.device_get(out)
self.assertEqual(out, [np.array(2), (np.array(3), {"v": np.array(4)})])
self.assertEqual(count[0], 2)

out = add_one(x)
out = jax.device_get(out)
self.assertEqual(out, [np.array(2), (np.array(3), {"v": np.array(4)})])
self.assertEqual(count[0], 2)

Expand All @@ -203,22 +218,26 @@ def add_one(x):
with _count_colocated_python_specialization_cache_miss() as count:
add_one = add_one.specialize(out_specs_fn=lambda x: x)
out = add_one(x)
out = jax.device_get(out)
self.assertEqual(out, np.array(2))
self.assertEqual(count[0], 1)

out = add_one(x)
out = jax.device_get(out)
self.assertEqual(out, np.array(2))
self.assertEqual(count[0], 1)

# Different input tree structure and dtype/shape.
x = [np.array(1), (np.array(2), {"v": jnp.array(3)})]
x = [np.array(1), (np.array(2), {"v": np.array(3)})]
x = jax.device_put(x, jax.sharding.SingleDeviceSharding(cpu_devices[0]))

out = add_one(x)
out = jax.device_get(out)
self.assertEqual(out, [np.array(2), (np.array(3), {"v": np.array(4)})])
self.assertEqual(count[0], 2)

out = add_one(x)
out = jax.device_get(out)
self.assertEqual(out, [np.array(2), (np.array(3), {"v": np.array(4)})])
self.assertEqual(count[0], 2)

Expand Down

0 comments on commit e20a483

Please sign in to comment.