Skip to content

Commit

Permalink
Raise an error when registering a lowering for an unknown platform
Browse files Browse the repository at this point in the history
  • Loading branch information
andportnoy committed Oct 22, 2024
1 parent 801fe87 commit 2aaa108
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 2 deletions.
6 changes: 6 additions & 0 deletions jax/_src/interpreters/mlir.py
Original file line number Diff line number Diff line change
Expand Up @@ -834,6 +834,12 @@ def register_lowering(prim: core.Primitive, rule: LoweringRule,
if platform is None:
_lowerings[prim] = rule
else:
if not xb.is_known_platform(platform):
known_platforms = sorted(xb.known_platforms())
raise NotImplementedError(
f"Registering an MLIR lowering rule for primitive {prim}"
f" for an unknown platform {platform}. Known platforms are:"
f" {', '.join(known_platforms)}.")
# For backward compatibility reasons, we allow rules to be registered
# under "gpu" even though the platforms are now called "cuda" and "rocm".
# TODO(phawkins): fix up users to specify either "cuda" or "rocm" and remove
Expand Down
15 changes: 13 additions & 2 deletions jax/_src/xla_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,9 @@ class BackendRegistration:
# for unimplemented features. Wrong outputs are not acceptable.
_nonexperimental_plugins: set[str] = {'cuda', 'rocm'}

# The set of known experimental plugins that have registrations in JAX codebase.
_experimental_plugins: set[str] = {"METAL"}

def register_backend_factory(name: str, factory: BackendFactory, *,
priority: int = 0,
fail_quietly: bool = True,
Expand Down Expand Up @@ -774,12 +777,20 @@ def _discover_and_register_pjrt_plugins():
_alias_to_platforms.setdefault(_alias, []).append(_platform)


def known_platforms() -> set[str]:
platforms = set()
platforms |= set(_nonexperimental_plugins)
platforms |= set(_experimental_plugins)
platforms |= set(_backend_factories.keys())
platforms |= set(_platform_aliases.values())
return platforms


def is_known_platform(platform: str) -> bool:
# A platform is valid if there is a registered factory for it. It does not
# matter if we were unable to initialize that platform; we only care that
# we've heard of it and it isn't, e.g., a typo.
return (platform in _backend_factories.keys() or
platform in _platform_aliases.keys())
return platform in known_platforms()


def canonicalize_platform(platform: str) -> str:
Expand Down
10 changes: 10 additions & 0 deletions tests/extend_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,5 +291,15 @@ def ffi_call_lu_pivots_to_permutation(pivots, permutation_size, **kwargs):
)


class MlirRegisterLoweringTest(jtu.JaxTestCase):

def test_unknown_platform_error(self):
with self.assertRaisesRegex(
NotImplementedError,
"Registering an MLIR lowering rule for primitive .+ for an unknown "
"platform foo. Known platforms are: .+."):
mlir.register_lowering(prim=None, rule=None, platform="foo")


if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())

0 comments on commit 2aaa108

Please sign in to comment.