Skip to content

Commit

Permalink
Pass flags from kernel into HLO backend config.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 578390868
  • Loading branch information
jax authors committed Nov 1, 2023
1 parent 9f28512 commit a009f8d
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 1 deletion.
3 changes: 2 additions & 1 deletion jax/_src/pallas/mosaic/pallas_call_registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,8 @@ def _lower_fun(*args):
backend=ctx.module_context.backend,
kernel_name=name,
kernel_regeneration_metadata=kernel_regeneration_metadata,
cost_estimate=mosaic_params.get('cost_estimate', None),
cost_estimate=mosaic_params.get("cost_estimate", None),
flags=mosaic_params.get("flags", None),
)(*extra_args, *args)
return mlir.lower_fun(_lower_fun, multiple_results=True)(
ctx, *in_nodes)
Expand Down
26 changes: 26 additions & 0 deletions jax/_src/tpu_custom_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ class CustomCallBackendConfig:
collective_id: int | None
device_type: str | None
cost_estimate: CostEstimate | None
flags: dict[str, bool | int | float] | None

# We omit the body while printing, because primitive params get embedded
# in HLO metadata, and the body blows up its size.
Expand Down Expand Up @@ -122,6 +123,27 @@ def to_json(self) -> bytes:
config.write(
('"DEVICE_TYPE_' + self.device_type.upper() + '"').encode("ascii")
)
if self.flags is not None:
config.write(b', "flag_configs": [')
for i, (flag, value) in enumerate(self.flags.items()):
config.write(b'{"flag_type": "')
config.write(flag.encode("ascii"))
config.write(b'", value: {')
if isinstance(value, bool):
config.write(b'"boolean_value": ')
config.write(b"true" if value else b"false")
elif isinstance(value, int):
config.write(b'"integer_value": ')
config.write(str(value).encode("ascii"))
elif isinstance(value, float):
config.write(b'"double_value": ')
config.write(str(value).encode("ascii"))
else:
raise ValueError("invalid flag value: " + str(value))
config.write(b"}}")
if i + 1 != len(self.flags):
config.write(b",")
config.write(b"]")
config.write(b"}")
return config.getvalue()

Expand Down Expand Up @@ -371,6 +393,7 @@ def as_tpu_kernel(
device_type: str | None = None,
kernel_name: str | None = None,
kernel_regeneration_metadata: bytes | None = None,
flags: dict[str, bool | int | float] | None = None,
) -> Callable[..., Any]:
"""Turns an MLIR Mosaic kernel into a JAX-compatible function."""
# We use jax.jit to make sure we hit the fast compilation cache.
Expand Down Expand Up @@ -398,6 +421,7 @@ def as_tpu_kernel(
kernel_name=kernel_name,
kernel_regeneration_metadata=kernel_regeneration_metadata,
cost_estimate=cost_estimate,
flags=flags,
)


Expand All @@ -412,6 +436,7 @@ def _lowered_as_tpu_kernel(
has_custom_barrier: bool = False,
kernel_name: str | None = None,
kernel_regeneration_metadata: bytes | None = None,
flags: dict[str, bool | int | float] | None = None,
):
"""Turns a low-level MLIR Mosaic kernel into a JAX-compatible function."""
unpack = False
Expand All @@ -436,6 +461,7 @@ def apply_kernel(*args, collective_id: int | None = None):
collective_id,
device_type,
cost_estimate,
flags,
)
result = tpu_custom_call_p.bind(
*args,
Expand Down

0 comments on commit a009f8d

Please sign in to comment.