From ce6a0c43adf15f00fc1608910b823ab4290e2b55 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Fri, 6 Oct 2023 11:26:03 -0700 Subject: [PATCH] jax.lax: deprecate inadvertent exports & internal utilities --- CHANGELOG.md | 9 +++++ jax/lax/__init__.py | 83 +++++++++++++++++++++++++++++++++++++++------ 2 files changed, 82 insertions(+), 10 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e297b784529a..27e37761deb9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,15 @@ Remember to align the itemized text with the first line of an item within a list If using the recommended `cuda12_pip` installation, NCCL should be installed automatically. +* Deprecations + * A number of internal utilities and inadvertent exports in {mod}`jax.lax` have + been deprecated, and will be removed in a future release. + * `jax.lax.dtypes`: use `jax.dtypes` instead. + * `jax.lax.itertools`: use `itertools` instead. + * `naryop`, `naryop_dtype_rule`, `standard_abstract_eval`, `standard_naryop`, + `standard_primitive`, `standard_unop`, `unop`, and `unop_dtype_rule` are + internal utilities, now deprecated without replacement. + # jax 0.4.17 (Oct 3, 2023) * New features diff --git a/jax/lax/__init__.py b/jax/lax/__init__.py index 111233d721fa..04ab2f037dff 100644 --- a/jax/lax/__init__.py +++ b/jax/lax/__init__.py @@ -88,7 +88,7 @@ dot_general as dot_general, dot_general_p as dot_general_p, dtype as dtype, - dtypes as dtypes, + dtypes as _deprecated_dtypes, eq as eq, eq_p as eq_p, exp as exp, @@ -116,7 +116,7 @@ iota_p as iota_p, is_finite as is_finite, is_finite_p as is_finite_p, - itertools as itertools, + itertools as _deprecated_itertools, le as le, le_p as le_p, log as log, @@ -133,8 +133,8 @@ min_p as min_p, mul as mul, mul_p as mul_p, - naryop as naryop, - naryop_dtype_rule as naryop_dtype_rule, + naryop as _deprecated_naryop, + naryop_dtype_rule as _deprecated_naryop_dtype_rule, ne as ne, ne_p as ne_p, neg as neg, @@ -203,10 +203,10 @@ square as square, squeeze as squeeze, squeeze_p as squeeze_p, - standard_abstract_eval as standard_abstract_eval, - standard_naryop as standard_naryop, - standard_primitive as standard_primitive, - standard_unop as standard_unop, + standard_abstract_eval as _deprecated_standard_abstract_eval, + standard_naryop as _deprecated_standard_naryop, + standard_primitive as _deprecated_standard_primitive, + standard_unop as _deprecated_standard_unop, stop_gradient as stop_gradient, sub as sub, sub_p as sub_p, @@ -219,8 +219,8 @@ top_k_p as top_k_p, transpose as transpose, transpose_p as transpose_p, - unop as unop, - unop_dtype_rule as unop_dtype_rule, + unop as _deprecated_unop, + unop_dtype_rule as _deprecated_unop_dtype_rule, xor_p as xor_p, zeros_like_array as zeros_like_array, ) @@ -377,3 +377,66 @@ from jax._src.pjit import with_sharding_constraint as with_sharding_constraint from jax._src.pjit import sharding_constraint_p as sharding_constraint_p from jax._src.dispatch import device_put_p as device_put_p + + +_deprecations = { + # Added October 6 2023 + "dtypes": ( + "jax.lax.dtypes is deprecated: import jax.dtypes directly.", + _deprecated_dtypes, + ), + "itertools": ( + "jax.lax.itertools is deprecated: import itertools directly.", + _deprecated_itertools, + ), + "naryop": ( + "jax.lax.naryop is an internal API and has been deprecated.", + _deprecated_naryop, + ), + "naryop_dtype_rule": ( + "jax.lax.naryop_dtype_rule is an internal API and has been deprecated.", + _deprecated_naryop_dtype_rule, + ), + "standard_abstract_eval": ( + "jax.lax.standard_abstract_eval is an internal API and has been deprecated.", + _deprecated_standard_abstract_eval, + ), + "standard_naryop": ( + "jax.lax.standard_naryop is an internal API and has been deprecated.", + _deprecated_standard_naryop, + ), + "standard_primitive": ( + "jax.lax.standard_primitive is an internal API and has been deprecated.", + _deprecated_standard_primitive, + ), + "standard_unop": ( + "jax.lax.standard_unop is an internal API and has been deprecated.", + _deprecated_standard_unop, + ), + "unop": ( + "jax.lax.unop is an internal API and has been deprecated.", + _deprecated_unop, + ), + "unop_dtype_rule": ( + "jax.lax.unop_dtype_rule is an internal API and has been deprecated.", + _deprecated_unop_dtype_rule, + ), +} + +import typing as _typing +if _typing.TYPE_CHECKING: + dtypes = _deprecated_dtypes, + itertools = _deprecated_itertools, + naryop = _deprecated_naryop, + naryop_dtype_rule = _deprecated_naryop_dtype_rule, + standard_abstract_eval = _deprecated_standard_abstract_eval, + standard_naryop = _deprecated_standard_naryop, + standard_primitive = _deprecated_standard_primitive, + standard_unop = _deprecated_standard_unop, + unop = _deprecated_unop, + unop_dtype_rule = _deprecated_unop_dtype_rule, +else: + from jax._src.deprecations import deprecation_getattr as _deprecation_getattr + __getattr__ = _deprecation_getattr(__name__, _deprecations) + del _deprecation_getattr +del _typing