Skip to content

Commit

Permalink
jax.lax: deprecate inadvertent exports & internal utilities
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Oct 6, 2023
1 parent 4c37c79 commit 4c862dc
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 10 deletions.
9 changes: 9 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,15 @@ Remember to align the itemized text with the first line of an item within a list
If using the recommended `cuda12_pip` installation, NCCL should be installed
automatically.

* Deprecations
* A number of internal utilities and inadvertent exports in {mod}`jax.lax` have
been deprecated, and will be removed in a future release.
* `jax.lax.dtypes`: use `jax.dtypes` instead.
* `jax.lax.itertools`: use `itertools` instead.
* `naryop`, `naryop_dtype_rule`, `standard_abstract_eval`, `standard_naryop`,
`standard_primitive`, `standard_unop`, `unop`, and `unop_dtype_rule` are
internal utilities, now deprecated without replacement.

# jax 0.4.17 (Oct 3, 2023)

* New features
Expand Down
85 changes: 75 additions & 10 deletions jax/lax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@
dot_general as dot_general,
dot_general_p as dot_general_p,
dtype as dtype,
dtypes as dtypes,
dtypes as _deprecated_dtypes,
eq as eq,
eq_p as eq_p,
exp as exp,
Expand Down Expand Up @@ -116,7 +116,7 @@
iota_p as iota_p,
is_finite as is_finite,
is_finite_p as is_finite_p,
itertools as itertools,
itertools as _deprecated_itertools,
le as le,
le_p as le_p,
log as log,
Expand All @@ -133,8 +133,8 @@
min_p as min_p,
mul as mul,
mul_p as mul_p,
naryop as naryop,
naryop_dtype_rule as naryop_dtype_rule,
naryop as _deprecated_naryop,
naryop_dtype_rule as _deprecated_naryop_dtype_rule,
ne as ne,
ne_p as ne_p,
neg as neg,
Expand Down Expand Up @@ -203,10 +203,10 @@
square as square,
squeeze as squeeze,
squeeze_p as squeeze_p,
standard_abstract_eval as standard_abstract_eval,
standard_naryop as standard_naryop,
standard_primitive as standard_primitive,
standard_unop as standard_unop,
standard_abstract_eval as _deprecated_standard_abstract_eval,
standard_naryop as _deprecated_standard_naryop,
standard_primitive as _deprecated_standard_primitive,
standard_unop as _deprecated_standard_unop,
stop_gradient as stop_gradient,
sub as sub,
sub_p as sub_p,
Expand All @@ -219,8 +219,8 @@
top_k_p as top_k_p,
transpose as transpose,
transpose_p as transpose_p,
unop as unop,
unop_dtype_rule as unop_dtype_rule,
unop as _deprecated_unop,
unop_dtype_rule as _deprecated_unop_dtype_rule,
xor_p as xor_p,
zeros_like_array as zeros_like_array,
)
Expand Down Expand Up @@ -377,3 +377,68 @@
from jax._src.pjit import with_sharding_constraint as with_sharding_constraint
from jax._src.pjit import sharding_constraint_p as sharding_constraint_p
from jax._src.dispatch import device_put_p as device_put_p


_deprecations = {
# Added October 6 2023
"dtypes": (
"jax.lax.dtypes is deprecated: import jax.dtypes directly.",
_deprecated_dtypes,
),
"itertools": (
"jax.lax.itertools is deprecated: import itertools directly.",
_deprecated_itertools,
),
"naryop": (
"jax.lax.naryop is an internal API and has been deprecated.",
_deprecated_naryop,
),
"naryop_dtype_rule": (
"jax.lax.naryop_dtype_rule is an internal API and has been deprecated.",
_deprecated_naryop_dtype_rule,
),
"standard_abstract_eval": (
"jax.lax.standard_abstract_eval is an internal API and has been deprecated.",
_deprecated_standard_abstract_eval,
),
"standard_naryop": (
"jax.lax.standard_naryop is an internal API and has been deprecated.",
_deprecated_standard_naryop,
),
"standard_primitive": (
"jax.lax.standard_primitive is an internal API and has been deprecated.",
_deprecated_standard_primitive,
),
"standard_unop": (
"jax.lax.standard_unop is an internal API and has been deprecated.",
_deprecated_standard_unop,
),
"unop": (
"jax.lax.unop is an internal API and has been deprecated.",
_deprecated_unop,
),
"unop_dtype_rule": (
"jax.lax.unop_dtype_rule is an internal API and has been deprecated.",
_deprecated_unop_dtype_rule,
),
}

import typing as _typing
if _typing.TYPE_CHECKING:
from jax._src.lax import (
dtypes as dtypes,
itertools as itertools,
naryop as naryop,
naryop_dtype_rule as naryop_dtype_rule,
standard_abstract_eval as standard_abstract_eval,
standard_naryop as standard_naryop,
standard_primitive as standard_primitive,
standard_unop as standard_unop,
unop as unop,
unop_dtype_rule as unop_dtype_rule,
)
else:
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
__getattr__ = _deprecation_getattr(__name__, _deprecations)
del _deprecation_getattr
del _typing

0 comments on commit 4c862dc

Please sign in to comment.