update Jax-Metal CI runner with nightly jaxlib config #21070

Merged: 4 commits, May 6, 2024
.github/workflows/metal_plugin_ci.yml (5 changes: 3 additions & 2 deletions)
@@ -17,7 +17,7 @@ jobs:
    strategy:
      fail-fast: false # don't cancel all jobs on failure
      matrix:
-       jaxlib-version: ["plugin_latest"]
+       jaxlib-version: ["pypi_latest", "nightly"]
    name: "Jax-Metal plugin test (jaxlib=${{ matrix.jaxlib-version }})"
    runs-on: [self-hosted, macOS, ARM64]

@@ -32,13 +32,14 @@ jobs:
          python3 -m venv ${GITHUB_WORKSPACE}/jax-metal-venv
          source ${GITHUB_WORKSPACE}/jax-metal-venv/bin/activate
          pip install -U pip numpy wheel
-         pip install jax-metal absl-py pytest
+         pip install absl-py pytest
          if [[ "${{ matrix.jaxlib-version }}" == "nightly" ]]; then
            pip install --pre jaxlib \
              -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
          fi;
          cd jax
          pip install .
+         pip install jax-metal
      - name: Run test
        run: |
          source ${GITHUB_WORKSPACE}/jax-metal-venv/bin/activate
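Note the new install order: jax-metal is now installed only after jax itself is installed from source, presumably so that its jaxlib pin is resolved against whichever jaxlib the matrix entry selected. A quick sanity check of what actually landed in the venv might look like the following (a hypothetical snippet, not part of this PR; the exact device string depends on the plugin version):

# Hypothetical post-install sanity check: report the resolved jaxlib build
# and confirm that the Metal PJRT plugin registered a device.
import jax
import jaxlib

print("jaxlib:", jaxlib.__version__)  # nightly wheels carry a dated dev version
print("devices:", jax.devices())      # expect a METAL device when jax-metal is active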
tests/lax_metal_test.py (27 changes: 8 additions & 19 deletions)
@@ -302,13 +302,11 @@ def testCountNonzero(self, shape, dtype, axis):
    self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
    self._CompileAndCheck(jnp_fun, args_maker)

-  @jtu.sample_product(shape=all_shapes, dtype=all_dtypes)
+  @jtu.sample_product(shape=nonzerodim_shapes, dtype=all_dtypes)
Member:
OOC, why did you need to reduce coverage? Were new failures introduced in the jaxlib nightly?

Collaborator (author):
metal_plugin_ci has been failing at head for about two weeks; I suspect a change in JAX. The coverage changes here are synced from lax_numpy_test.

Member:
Ah. All the more reason to sync up the test files!
  def testNonzero(self, shape, dtype):
    rng = jtu.rand_some_zero(self.rng())
    args_maker = lambda: [rng(shape, dtype)]
-    with jtu.ignore_warning(category=DeprecationWarning,
-                            message="Calling nonzero on 0d arrays.*"):
-      self._CheckAgainstNumpy(np.nonzero, jnp.nonzero, args_maker, check_dtypes=False)
+    self._CheckAgainstNumpy(np.nonzero, jnp.nonzero, args_maker, check_dtypes=False)

  @jtu.sample_product(
    [dict(shape=shape, fill_value=fill_value)
@@ -370,20 +368,16 @@ def np_fun(x):
    self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
    self._CompileAndCheck(jnp_fun, args_maker)

-  @jtu.sample_product(shape=all_shapes, dtype=all_dtypes)
+  @jtu.sample_product(shape=nonzerodim_shapes, dtype=all_dtypes)
  def testArgWhere(self, shape, dtype):
    rng = jtu.rand_some_zero(self.rng())
    args_maker = lambda: [rng(shape, dtype)]
-    with jtu.ignore_warning(category=DeprecationWarning,
-                            message="Calling nonzero on 0d arrays.*"):
-      self._CheckAgainstNumpy(np.argwhere, jnp.argwhere, args_maker, check_dtypes=False)
+    self._CheckAgainstNumpy(np.argwhere, jnp.argwhere, args_maker, check_dtypes=False)

    # JIT compilation requires specifying a size statically. Full test of this
    # behavior is in testNonzeroSize().
    jnp_fun = lambda x: jnp.argwhere(x, size=np.size(x) // 2)
-    with jtu.ignore_warning(category=DeprecationWarning,
-                            message="Calling nonzero on 0d arrays.*"):
-      self._CompileAndCheck(jnp_fun, args_maker)
+    self._CompileAndCheck(jnp_fun, args_maker)

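The in-test comment above is the key constraint for the jit path: under jit, jnp.argwhere cannot return a data-dependent shape, so the size argument fixes the output length up front. A small sketch (input chosen so exactly size matches are found):

import jax
import jax.numpy as jnp
import numpy as np

x = np.array([0, 1, 0, 2])
# Under jit the result shape must be static; `size` pads or truncates the
# list of indices to a fixed length.
idx = jax.jit(lambda a: jnp.argwhere(a, size=2))(x)
print(idx)  # [[1] [3]], the indices of the two nonzero entries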
  @jtu.sample_product(
    [dict(shape=shape, fill_value=fill_value)
@@ -2055,7 +2049,6 @@ def attempt_sideeffect(x):
    self.assertAllClose(np_input, expected_np_input_after_call)
    self.assertAllClose(jnp_input, expected_jnp_input_after_call)

-  @unittest.skip("Jax-metal fail to convert 1D convolution op.")
  @jtu.sample_product(
    mode=['full', 'same', 'valid'],
    op=['convolve', 'correlate'],
@@ -2077,7 +2070,6 @@ def np_fun(x, y):
    self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True, tol=tol)
    self._CompileAndCheck(jnp_fun, args_maker)

-  @unittest.skip("Jax-metal fail to convert 1D convolution op.")
  @jtu.sample_product(
    mode=['full', 'same', 'valid'],
    op=['convolve', 'correlate'],
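With the two skip decorators removed, the 1D convolve/correlate tests run against the Metal plugin again. For reference, a minimal sketch of the operation these tests exercise (values illustrative only):

import jax.numpy as jnp

a = jnp.array([1.0, 2.0, 3.0])
v = jnp.array([0.0, 1.0, 0.5])
# mode='same' clips the full convolution to the length of the first input.
print(jnp.convolve(a, v, mode='same'))  # [1.  2.5 4. ]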
@@ -4431,15 +4423,12 @@ def args_maker(): return []
    self._CompileAndCheck(jnp_fun, args_maker)

  @jtu.sample_product(
-    shape=all_shapes, dtype=all_dtypes,
+    shape=nonzerodim_shapes, dtype=all_dtypes,
  )
  def testWhereOneArgument(self, shape, dtype):
    rng = jtu.rand_some_zero(self.rng())
    args_maker = lambda: [rng(shape, dtype)]

-    with jtu.ignore_warning(category=DeprecationWarning,
-                            message="Calling nonzero on 0d arrays.*"):
-      self._CheckAgainstNumpy(np.where, jnp.where, args_maker, check_dtypes=False)
+    self._CheckAgainstNumpy(np.where, jnp.where, args_maker, check_dtypes=False)

    # JIT compilation requires specifying a size statically. Full test of
    # this behavior is in testNonzeroSize().
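One-argument where is documented as equivalent to nonzero, so it carries the same static-size requirement under jit; a small sketch under that assumption:

import jax
import jax.numpy as jnp

x = jnp.array([0.0, 5.0, 0.0, 7.0])
# jnp.where(cond) with no branches is sugar for jnp.nonzero(cond); the
# `size` argument makes the output shape static for jit.
print(jax.jit(lambda a: jnp.where(a, size=2))(x))  # (Array([1, 3], dtype=int32),)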
@@ -5724,7 +5713,7 @@ def test_gather_ir(self):
    #loc = loc(unknown)
    module @jit_gather attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
      func.func public @main(%arg0: tensor<3x2x3xf32> {mhlo.sharding = "{replicated}"} loc(unknown), %arg1: tensor<3x2xi32> {mhlo.sharding = "{replicated}"} loc(unknown)) -> tensor<3x2xf32> {
-       %0 = "stablehlo.gather"(%arg0, %arg1) {dimension_numbers = #stablehlo.gather<offset_dims = [1], collapsed_slice_dims = [0, 2], start_index_map = [0, 2], index_vector_dim = 1>, indices_are_sorted = false, slice_sizes = dense<[1, 2, 1]> : tensor<3xi64>} : (tensor<3x2x3xf32>, tensor<3x2xi32>) -> tensor<3x2xf32> loc(#loc2)
+       %0 = "stablehlo.gather"(%arg0, %arg1) {dimension_numbers = #stablehlo.gather<offset_dims = [1], collapsed_slice_dims = [0, 2], start_index_map = [0, 2], index_vector_dim = 1>, indices_are_sorted = false, slice_sizes = array<i64: 1, 2, 1>} : (tensor<3x2x3xf32>, tensor<3x2xi32>) -> tensor<3x2xf32> loc(#loc2)
        return %0 : tensor<3x2xf32> loc(#loc)
      } loc(#loc)
    } loc(#loc)
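The expected-IR update tracks a StableHLO syntax change: slice_sizes moved from the dense<[1, 2, 1]> : tensor<3xi64> elements form to the array<i64: 1, 2, 1> dense-array form emitted by newer jaxlib. The lowered StableHLO of any jitted function can be inspected the same way the test does (a minimal sketch, assuming a recent JAX with the ahead-of-time lowering API):

import jax
import jax.numpy as jnp

def gather(x, idx):
    return x[idx]  # integer-array indexing lowers to stablehlo.gather

x = jnp.arange(12.0).reshape(3, 4)
idx = jnp.array([0, 2])
# Dump the StableHLO module and look for the slice_sizes attribute format.
print(jax.jit(gather).lower(x, idx).as_text())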