Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

jax.lax: deprecate inadvertent exports & internal utilities #17987

Merged
merged 1 commit into from
Oct 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,15 @@ Remember to align the itemized text with the first line of an item within a list
If using the recommended `cuda12_pip` installation, NCCL should be installed
automatically.

* Deprecations
* A number of internal utilities and inadvertent exports in {mod}`jax.lax` have
been deprecated, and will be removed in a future release.
* `jax.lax.dtypes`: use `jax.dtypes` instead.
* `jax.lax.itertools`: use `itertools` instead.
* `naryop`, `naryop_dtype_rule`, `standard_abstract_eval`, `standard_naryop`,
`standard_primitive`, `standard_unop`, `unop`, and `unop_dtype_rule` are
internal utilities, now deprecated without replacement.

# jax 0.4.17 (Oct 3, 2023)

* New features
Expand Down
83 changes: 73 additions & 10 deletions jax/lax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@
dot_general as dot_general,
dot_general_p as dot_general_p,
dtype as dtype,
dtypes as dtypes,
dtypes as _deprecated_dtypes,
eq as eq,
eq_p as eq_p,
exp as exp,
Expand Down Expand Up @@ -116,7 +116,7 @@
iota_p as iota_p,
is_finite as is_finite,
is_finite_p as is_finite_p,
itertools as itertools,
itertools as _deprecated_itertools,
le as le,
le_p as le_p,
log as log,
Expand All @@ -133,8 +133,8 @@
min_p as min_p,
mul as mul,
mul_p as mul_p,
naryop as naryop,
naryop_dtype_rule as naryop_dtype_rule,
naryop as _deprecated_naryop,
naryop_dtype_rule as _deprecated_naryop_dtype_rule,
ne as ne,
ne_p as ne_p,
neg as neg,
Expand Down Expand Up @@ -203,10 +203,10 @@
square as square,
squeeze as squeeze,
squeeze_p as squeeze_p,
standard_abstract_eval as standard_abstract_eval,
standard_naryop as standard_naryop,
standard_primitive as standard_primitive,
standard_unop as standard_unop,
standard_abstract_eval as _deprecated_standard_abstract_eval,
standard_naryop as _deprecated_standard_naryop,
standard_primitive as _deprecated_standard_primitive,
standard_unop as _deprecated_standard_unop,
stop_gradient as stop_gradient,
sub as sub,
sub_p as sub_p,
Expand All @@ -219,8 +219,8 @@
top_k_p as top_k_p,
transpose as transpose,
transpose_p as transpose_p,
unop as unop,
unop_dtype_rule as unop_dtype_rule,
unop as _deprecated_unop,
unop_dtype_rule as _deprecated_unop_dtype_rule,
xor_p as xor_p,
zeros_like_array as zeros_like_array,
)
Expand Down Expand Up @@ -377,3 +377,66 @@
from jax._src.pjit import with_sharding_constraint as with_sharding_constraint
from jax._src.pjit import sharding_constraint_p as sharding_constraint_p
from jax._src.dispatch import device_put_p as device_put_p


_deprecations = {
# Added October 6 2023
"dtypes": (
"jax.lax.dtypes is deprecated: import jax.dtypes directly.",
_deprecated_dtypes,
),
"itertools": (
"jax.lax.itertools is deprecated: import itertools directly.",
_deprecated_itertools,
),
"naryop": (
"jax.lax.naryop is an internal API and has been deprecated.",
_deprecated_naryop,
),
"naryop_dtype_rule": (
"jax.lax.naryop_dtype_rule is an internal API and has been deprecated.",
_deprecated_naryop_dtype_rule,
),
"standard_abstract_eval": (
"jax.lax.standard_abstract_eval is an internal API and has been deprecated.",
_deprecated_standard_abstract_eval,
),
"standard_naryop": (
"jax.lax.standard_naryop is an internal API and has been deprecated.",
_deprecated_standard_naryop,
),
"standard_primitive": (
"jax.lax.standard_primitive is an internal API and has been deprecated.",
_deprecated_standard_primitive,
),
"standard_unop": (
"jax.lax.standard_unop is an internal API and has been deprecated.",
_deprecated_standard_unop,
),
"unop": (
"jax.lax.unop is an internal API and has been deprecated.",
_deprecated_unop,
),
"unop_dtype_rule": (
"jax.lax.unop_dtype_rule is an internal API and has been deprecated.",
_deprecated_unop_dtype_rule,
),
}

import typing as _typing
if _typing.TYPE_CHECKING:
dtypes = _deprecated_dtypes,
itertools = _deprecated_itertools,
naryop = _deprecated_naryop,
naryop_dtype_rule = _deprecated_naryop_dtype_rule,
standard_abstract_eval = _deprecated_standard_abstract_eval,
standard_naryop = _deprecated_standard_naryop,
standard_primitive = _deprecated_standard_primitive,
standard_unop = _deprecated_standard_unop,
unop = _deprecated_unop,
unop_dtype_rule = _deprecated_unop_dtype_rule,
else:
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
__getattr__ = _deprecation_getattr(__name__, _deprecations)
del _deprecation_getattr
del _typing