Skip to content

Commit

Permalink
[shape_poly] Remove some deprecated kwargs
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 703116755
  • Loading branch information
gnecula authored and Google-ML-Automation committed Dec 5, 2024
1 parent e510295 commit 5fe5206
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 23 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
use `uses_global_constants`.
* the `lowering_platforms` kwarg for {func}`jax.export.export`: use
`platforms` instead.
* The kwargs `symbolic_scope` and `symbolic_constraints` from
{func}`jax.export.symbolic_args_specs` have been removed. They were
deprecated in June 2024. Use `scope` and `constraints` instead.
* Hashing of tracers, which has been deprecated since version 0.4.30, now
results in a `TypeError`.
* Refactor: JAX build CLI (build/build.py) now uses a subcommand structure and
Expand Down
23 changes: 0 additions & 23 deletions jax/_src/export/shape_poly.py
Original file line number Diff line number Diff line change
Expand Up @@ -1198,12 +1198,6 @@ def is_symbolic_dim(p: DimSize) -> bool:
"""
return isinstance(p, _DimExpr)

def is_poly_dim(p: DimSize) -> bool:
# TODO: deprecated January 2024, remove June 2024.
warnings.warn("is_poly_dim is deprecated, use export.is_symbolic_dim",
DeprecationWarning, stacklevel=2)
return is_symbolic_dim(p)

dtypes.python_scalar_dtypes[_DimExpr] = dtypes.python_scalar_dtypes[int]

def _einsum_contract_path(*operands, **kwargs):
Expand Down Expand Up @@ -1413,8 +1407,6 @@ def symbolic_args_specs(
shapes_specs, # prefix pytree of strings
constraints: Sequence[str] = (),
scope: SymbolicScope | None = None,
symbolic_constraints: Sequence[str] = (), # DEPRECATED on 6/14/24
symbolic_scope: SymbolicScope | None = None, # DEPRECATED on 6/14/24
):
"""Constructs a pytree of jax.ShapeDtypeSpec arguments specs for `export`.
Expand All @@ -1435,25 +1427,10 @@ def symbolic_args_specs(
arguments](https://jax.readthedocs.io/en/latest/pytrees.html#applying-optional-parameters-to-pytrees).
constraints: as for :func:`jax.export.symbolic_shape`.
scope: as for :func:`jax.export.symbolic_shape`.
symbolic_constraints: DEPRECATED, use `constraints`.
symbolic_scope: DEPRECATED, use `scope`.
Returns: a pytree of jax.ShapeDTypeStruct matching the `args` with the shapes
replaced with symbolic dimensions as specified by `shapes_specs`.
"""
if symbolic_constraints:
warnings.warn("symbolic_constraints is deprecated, use constraints",
DeprecationWarning, stacklevel=2)
if constraints:
raise ValueError("Cannot use both symbolic_constraints and constraints")
constraints = symbolic_constraints
if symbolic_scope is not None:
warnings.warn("symbolic_scope is deprecated, use scope",
DeprecationWarning, stacklevel=2)
if scope is not None:
raise ValueError("Cannot use both symbolic_scope and scope")
scope = symbolic_scope

polymorphic_shapes = shapes_specs
args_flat, args_tree = tree_util.tree_flatten(args)

Expand Down

0 comments on commit 5fe5206

Please sign in to comment.