diff --git a/.github/workflows/metal_plugin_ci.yml b/.github/workflows/metal_plugin_ci.yml index a62b0a1b6eb2..6e67841460e9 100644 --- a/.github/workflows/metal_plugin_ci.yml +++ b/.github/workflows/metal_plugin_ci.yml @@ -17,7 +17,7 @@ jobs: strategy: fail-fast: false # don't cancel all jobs on failure matrix: - jaxlib-version: ["plugin_latest"] + jaxlib-version: ["pypi_latest", "nightly"] name: "Jax-Metal plugin test (jaxlib=${{ matrix.jaxlib-version }})" runs-on: [self-hosted, macOS, ARM64] @@ -32,13 +32,14 @@ jobs: python3 -m venv ${GITHUB_WORKSPACE}/jax-metal-venv source ${GITHUB_WORKSPACE}/jax-metal-venv/bin/activate pip install -U pip numpy wheel - pip install jax-metal absl-py pytest + pip install absl-py pytest if [[ "${{ matrix.jaxlib-version }}" == "nightly" ]]; then pip install --pre jaxlib \ -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html fi; cd jax pip install . + pip install jax-metal - name: Run test run: | source ${GITHUB_WORKSPACE}/jax-metal-venv/bin/activate diff --git a/tests/lax_metal_test.py b/tests/lax_metal_test.py index 87ca6eaee322..5069187d2334 100644 --- a/tests/lax_metal_test.py +++ b/tests/lax_metal_test.py @@ -302,13 +302,11 @@ def testCountNonzero(self, shape, dtype, axis): self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) self._CompileAndCheck(jnp_fun, args_maker) - @jtu.sample_product(shape=all_shapes, dtype=all_dtypes) + @jtu.sample_product(shape=nonzerodim_shapes, dtype=all_dtypes) def testNonzero(self, shape, dtype): rng = jtu.rand_some_zero(self.rng()) args_maker = lambda: [rng(shape, dtype)] - with jtu.ignore_warning(category=DeprecationWarning, - message="Calling nonzero on 0d arrays.*"): - self._CheckAgainstNumpy(np.nonzero, jnp.nonzero, args_maker, check_dtypes=False) + self._CheckAgainstNumpy(np.nonzero, jnp.nonzero, args_maker, check_dtypes=False) @jtu.sample_product( [dict(shape=shape, fill_value=fill_value) @@ -370,20 +368,16 @@ def np_fun(x): self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) self._CompileAndCheck(jnp_fun, args_maker) - @jtu.sample_product(shape=all_shapes, dtype=all_dtypes) + @jtu.sample_product(shape=nonzerodim_shapes, dtype=all_dtypes) def testArgWhere(self, shape, dtype): rng = jtu.rand_some_zero(self.rng()) args_maker = lambda: [rng(shape, dtype)] - with jtu.ignore_warning(category=DeprecationWarning, - message="Calling nonzero on 0d arrays.*"): - self._CheckAgainstNumpy(np.argwhere, jnp.argwhere, args_maker, check_dtypes=False) + self._CheckAgainstNumpy(np.argwhere, jnp.argwhere, args_maker, check_dtypes=False) # JIT compilation requires specifying a size statically. Full test of this # behavior is in testNonzeroSize(). jnp_fun = lambda x: jnp.argwhere(x, size=np.size(x) // 2) - with jtu.ignore_warning(category=DeprecationWarning, - message="Calling nonzero on 0d arrays.*"): - self._CompileAndCheck(jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) @jtu.sample_product( [dict(shape=shape, fill_value=fill_value) @@ -2055,7 +2049,6 @@ def attempt_sideeffect(x): self.assertAllClose(np_input, expected_np_input_after_call) self.assertAllClose(jnp_input, expected_jnp_input_after_call) - @unittest.skip("Jax-metal fail to convert 1D convolution op.") @jtu.sample_product( mode=['full', 'same', 'valid'], op=['convolve', 'correlate'], @@ -2077,7 +2070,6 @@ def np_fun(x, y): self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True, tol=tol) self._CompileAndCheck(jnp_fun, args_maker) - @unittest.skip("Jax-metal fail to convert 1D convolution op.") @jtu.sample_product( mode=['full', 'same', 'valid'], op=['convolve', 'correlate'], @@ -4431,15 +4423,12 @@ def args_maker(): return [] self._CompileAndCheck(jnp_fun, args_maker) @jtu.sample_product( - shape=all_shapes, dtype=all_dtypes, + shape=nonzerodim_shapes, dtype=all_dtypes, ) def testWhereOneArgument(self, shape, dtype): rng = jtu.rand_some_zero(self.rng()) args_maker = lambda: [rng(shape, dtype)] - - with jtu.ignore_warning(category=DeprecationWarning, - message="Calling nonzero on 0d arrays.*"): - self._CheckAgainstNumpy(np.where, jnp.where, args_maker, check_dtypes=False) + self._CheckAgainstNumpy(np.where, jnp.where, args_maker, check_dtypes=False) # JIT compilation requires specifying a size statically. Full test of # this behavior is in testNonzeroSize(). @@ -5724,7 +5713,7 @@ def test_gather_ir(self): #loc = loc(unknown) module @jit_gather attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { func.func public @main(%arg0: tensor<3x2x3xf32> {mhlo.sharding = "{replicated}"} loc(unknown), %arg1: tensor<3x2xi32> {mhlo.sharding = "{replicated}"} loc(unknown)) -> tensor<3x2xf32> { - %0 = "stablehlo.gather"(%arg0, %arg1) {dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = dense<[1, 2, 1]> : tensor<3xi64>} : (tensor<3x2x3xf32>, tensor<3x2xi32>) -> tensor<3x2xf32> loc(#loc2) + %0 = "stablehlo.gather"(%arg0, %arg1) {dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array} : (tensor<3x2x3xf32>, tensor<3x2xi32>) -> tensor<3x2xf32> loc(#loc2) return %0 : tensor<3x2xf32> loc(#loc) } loc(#loc) } loc(#loc)