Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[array API] clean up some unused/unnecessary code #22634

Merged
merged 1 commit into from
Jul 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions jax/experimental/array_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@
square as square,
squeeze as squeeze,
stack as stack,
std as std,
subtract as subtract,
sum as sum,
take as take,
Expand All @@ -174,6 +175,7 @@
unique_inverse as unique_inverse,
unique_values as unique_values,
unstack as unstack,
var as var,
vecdot as vecdot,
where as where,
zeros as zeros,
Expand All @@ -199,11 +201,6 @@
hypot as hypot,
)

from jax.experimental.array_api._statistical_functions import (
std as std,
var as var,
)

from jax.experimental.array_api._utility_functions import (
__array_namespace_info__ as __array_namespace_info__,
)
Expand Down
5 changes: 0 additions & 5 deletions jax/experimental/array_api/_array_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,9 @@

from __future__ import annotations

from typing import Any

import jax
from jax._src.array import ArrayImpl
from jax.experimental.array_api._version import __array_api_version__
from jax.sharding import Sharding

from jax._src.lib import xla_extension as xe


def _array_namespace(self, /, *, api_version: None | str = None):
Expand Down
28 changes: 0 additions & 28 deletions jax/experimental/array_api/_linear_algebra_functions.py

This file was deleted.

25 changes: 0 additions & 25 deletions jax/experimental/array_api/_statistical_functions.py

This file was deleted.

10 changes: 3 additions & 7 deletions jax/experimental/array_api/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,21 +23,17 @@
matmul as matmul,
matrix_norm as matrix_norm,
matrix_power as matrix_power,
matrix_rank as matrix_rank,
matrix_transpose as matrix_transpose,
outer as outer,
pinv as pinv,
qr as qr,
slogdet as slogdet,
solve as solve,
svd as svd,
svdvals as svdvals,
tensordot as tensordot,
trace as trace,
vecdot as vecdot,
vector_norm as vector_norm,
)

from jax.numpy.linalg import trace as trace

from jax.experimental.array_api._linear_algebra_functions import (
matrix_rank as matrix_rank,
pinv as pinv,
)