diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index f9458caa3980..2adeb4b16cd9 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -834,6 +834,12 @@ def register_lowering(prim: core.Primitive, rule: LoweringRule, if platform is None: _lowerings[prim] = rule else: + if not xb.is_known_platform(platform): + known_platforms = sorted(xb.known_platforms()) + raise NotImplementedError( + f"Registering an MLIR lowering rule for primitive {prim}" + f" for an unknown platform {platform}. Known platforms are:" + f" {', '.join(known_platforms)}.") # For backward compatibility reasons, we allow rules to be registered # under "gpu" even though the platforms are now called "cuda" and "rocm". # TODO(phawkins): fix up users to specify either "cuda" or "rocm" and remove diff --git a/jax/_src/xla_bridge.py b/jax/_src/xla_bridge.py index 796093b6225f..23b255ef1750 100644 --- a/jax/_src/xla_bridge.py +++ b/jax/_src/xla_bridge.py @@ -202,6 +202,9 @@ class BackendRegistration: # for unimplemented features. Wrong outputs are not acceptable. _nonexperimental_plugins: set[str] = {'cuda', 'rocm'} +# The set of known experimental plugins that have registrations in JAX codebase. +_experimental_plugins: set[str] = {"METAL"} + def register_backend_factory(name: str, factory: BackendFactory, *, priority: int = 0, fail_quietly: bool = True, @@ -774,12 +777,20 @@ def _discover_and_register_pjrt_plugins(): _alias_to_platforms.setdefault(_alias, []).append(_platform) +def known_platforms() -> set[str]: + platforms = set() + platforms |= set(_nonexperimental_plugins) + platforms |= set(_experimental_plugins) + platforms |= set(_backend_factories.keys()) + platforms |= set(_platform_aliases.values()) + return platforms + + def is_known_platform(platform: str) -> bool: # A platform is valid if there is a registered factory for it. It does not # matter if we were unable to initialize that platform; we only care that # we've heard of it and it isn't, e.g., a typo. - return (platform in _backend_factories.keys() or - platform in _platform_aliases.keys()) + return platform in known_platforms() def canonicalize_platform(platform: str) -> str: diff --git a/tests/extend_test.py b/tests/extend_test.py index ce38091a618f..0fc8821f1984 100644 --- a/tests/extend_test.py +++ b/tests/extend_test.py @@ -296,5 +296,15 @@ def ffi_call_lu_pivots_to_permutation(pivots, permutation_size, **kwargs): )(pivots) +class MlirRegisterLoweringTest(jtu.JaxTestCase): + + def test_unknown_platform_error(self): + with self.assertRaisesRegex( + NotImplementedError, + "Registering an MLIR lowering rule for primitive .+ for an unknown " + "platform foo. Known platforms are: .+."): + mlir.register_lowering(prim=None, rule=None, platform="foo") + + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader())