Merge pull request #17760 from superbobry:array-any
PiperOrigin-RevId: 570400629
jax authors committed Oct 3, 2023
2 parents b2ac2de + 5ab05e4 commit c3e73c6
Showing 12 changed files with 138 additions and 117 deletions.
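The change repeated across these files is mechanical: each module's local `Array = Any` placeholder (and, in convolution.py, `DType = Any`) is dropped in favor of the shared aliases in `jax._src.typing`. A minimal sketch of the pattern, using a hypothetical module and function for illustration:

```python
# Before: an untyped, per-module placeholder.
#   Array = Any

# After: the shared static alias used throughout this commit.
from jax._src.typing import Array

def scale(x: Array, factor: float) -> Array:
    # The annotation now refers to a concrete array type rather than Any.
    return x * factor
```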
1 change: 1 addition & 0 deletions jax/BUILD
@@ -666,6 +666,7 @@ pytype_strict_library(
":core",
":effects",
":pretty_printer",
":typing",
":util",
],
)
10 changes: 6 additions & 4 deletions jax/_src/interpreters/batching.py
@@ -33,12 +33,12 @@
from jax._src.interpreters import partial_eval as pe
from jax._src.tree_util import (tree_unflatten, tree_flatten,
register_pytree_node)
from jax._src.typing import Array
from jax._src.util import (unzip2, unzip3, safe_map, safe_zip, split_list,
canonicalize_axis, moveaxis, as_hashable_function,
curry, memoize, weakref_lru_cache)


Array = Any
map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip

@@ -116,7 +116,7 @@ class RaggedAxis:
# For each axis, we store its index and the corresponding segment lengths.
# For example, the jumble i:(Fin 3) => f32[lens1.i, 7, lens2.i]
# would be represented with ragged_axes = [(1, lens1), (3, lens2)]
ragged_axes: tuple[tuple[int, Array], ...]
ragged_axes: tuple[tuple[int, Any], ...]

@property
def size(self):
@@ -148,8 +148,10 @@ def _sorted_ragged_axis(stacked_axis, ragged_axes):
return RaggedAxis(stacked_axis, tuple(sorted(ragged_axes, key=lambda p: p[0])))

def make_batch_axis(
ndim: int, stacked_axis: int, ragged_axes: list[tuple[int, Array]]
) -> int | RaggedAxis:
ndim: int,
stacked_axis: int,
ragged_axes: list[tuple[int, Array | core.Var]],
) -> int | RaggedAxis:
if ragged_axes:
canonical = [(canonicalize_axis(ax, ndim), sz) for ax, sz in ragged_axes]
return _sorted_ragged_axis(canonicalize_axis(stacked_axis, ndim), canonical)
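For context on the `RaggedAxis` bookkeeping touched above, the comment's jumble example can be reproduced with `make_batch_axis`. A rough sketch against internal APIs (the length arrays are invented for illustration):

```python
import jax.numpy as jnp
from jax._src.interpreters import batching

# Element type f32[lens1.i, 7, lens2.i] for i in Fin 3, stacked on a new
# leading axis: axes 1 and 3 of the stacked array carry per-element lengths.
lens1 = jnp.array([2, 5, 3])
lens2 = jnp.array([4, 4, 1])
axis = batching.make_batch_axis(ndim=4, stacked_axis=0,
                                ragged_axes=[(1, lens1), (3, lens2)])
# axis == RaggedAxis(stacked_axis=0, ragged_axes=((1, lens1), (3, lens2)))
```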
5 changes: 1 addition & 4 deletions jax/_src/lax/ann.py
@@ -70,7 +70,6 @@ def pmap_mips(qy, db, db_offset, db_size, k, recall_target):
"""

from functools import partial
from typing import Any

import numpy as np

@@ -88,9 +87,7 @@ def pmap_mips(qy, db, db_offset, db_size, k, recall_target):
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import func
from jax._src.lib.mlir.dialects import hlo


Array = Any
from jax._src.typing import Array


def approx_max_k(operand: Array,
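`approx_max_k` keeps the same public behavior; only its annotations now use the shared `Array` alias. A small usage sketch of the public wrapper with made-up values:

```python
import jax.numpy as jnp
from jax import lax

scores = jnp.array([0.1, 0.9, 0.4, 0.8, 0.2, 0.7])
# Approximate top-k: returns (values, indices); recall_target trades
# accuracy against speed on accelerators.
values, indices = lax.approx_max_k(scores, k=2, recall_target=0.95)
```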
2 changes: 1 addition & 1 deletion jax/_src/lax/control_flow/for_loop.py
@@ -39,6 +39,7 @@
from jax._src.state import primitives as state_primitives
from jax._src.state import utils as state_utils
from jax._src.state import types as state_types
from jax._src.typing import Array
from jax._src.util import (partition_list, merge_lists, safe_map, safe_zip,
split_list, split_dict)
from jax._src.lax.control_flow import loops
@@ -53,7 +54,6 @@
S = TypeVar('S')
T = TypeVar('T')
class Ref(Generic[T]): pass
Array = Any

ref_set = state_primitives.ref_set
ref_get = state_primitives.ref_get
2 changes: 1 addition & 1 deletion jax/_src/lax/control_flow/loops.py
@@ -52,6 +52,7 @@
from jax._src.state import discharge as state_discharge
from jax._src.numpy.ufuncs import logaddexp
from jax._src.traceback_util import api_boundary
from jax._src.typing import Array
from jax._src.util import (partition_list, safe_map, safe_zip, split_list,
unzip2, weakref_lru_cache, merge_lists)
import numpy as np
@@ -64,7 +65,6 @@
zip = safe_zip

T = TypeVar('T')
Array = Any
BooleanNumeric = Any # A bool, or a Boolean array.

### Helper functions
22 changes: 8 additions & 14 deletions jax/_src/lax/convolution.py
@@ -12,11 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import builtins
from collections.abc import Sequence
from functools import partial
import operator
from typing import Any, NamedTuple, Optional, Union
from typing import NamedTuple, Optional, Union

import numpy as np

@@ -28,14 +27,9 @@
from jax._src.interpreters import mlir
from jax._src.lax import lax
from jax._src.lib.mlir.dialects import hlo
from jax._src.typing import Array, DTypeLike


_max = builtins.max

Array = Any
DType = Any
Shape = core.Shape

class ConvDimensionNumbers(NamedTuple):
"""Describes batch, spatial, and feature dimensions of a convolution.
@@ -62,7 +56,7 @@ def conv_general_dilated(
dimension_numbers: ConvGeneralDilatedDimensionNumbers = None,
feature_group_count: int = 1, batch_group_count: int = 1,
precision: lax.PrecisionLike = None,
preferred_element_type: Optional[DType] = None) -> Array:
preferred_element_type: Optional[DTypeLike] = None) -> Array:
"""General n-dimensional convolution operator, with optional dilation.
Wraps XLA's `Conv
@@ -174,7 +168,7 @@ def conv_general_dilated(

def conv(lhs: Array, rhs: Array, window_strides: Sequence[int],
padding: str, precision: lax.PrecisionLike = None,
preferred_element_type: Optional[DType] = None) -> Array:
preferred_element_type: Optional[DTypeLike] = None) -> Array:
"""Convenience wrapper around `conv_general_dilated`.
Args:
@@ -204,7 +198,7 @@ def conv_with_general_padding(lhs: Array, rhs: Array,
lhs_dilation: Optional[Sequence[int]],
rhs_dilation: Optional[Sequence[int]],
precision: lax.PrecisionLike = None,
preferred_element_type: Optional[DType] = None) -> Array:
preferred_element_type: Optional[DTypeLike] = None) -> Array:
"""Convenience wrapper around `conv_general_dilated`.
Args:
@@ -256,7 +250,7 @@ def _conv_transpose_padding(k, s, padding):
else:
pad_a = int(np.ceil(pad_len / 2))
elif padding == 'VALID':
pad_len = k + s - 2 + _max(k - s, 0)
pad_len = k + s - 2 + max(k - s, 0)
pad_a = k - 1
else:
raise ValueError('Padding mode must be `SAME` or `VALID`.')
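To make the `VALID` branch above concrete, here is a small standalone sketch of the same arithmetic (the helper name and the (before, after) split are assumptions for illustration):

```python
def valid_transpose_padding(k: int, s: int) -> tuple[int, int]:
    # Total padding needed by a transposed conv with kernel size k, stride s.
    pad_len = k + s - 2 + max(k - s, 0)
    pad_a = k - 1                    # padding placed before the data
    return pad_a, pad_len - pad_a    # the remainder goes after (assumed split)

# Example: k=4, s=2 gives pad_len = 4 + 2 - 2 + 2 = 6, so (3, 3).
```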
@@ -277,7 +271,7 @@ def conv_transpose(lhs: Array, rhs: Array, strides: Sequence[int],
dimension_numbers: ConvGeneralDilatedDimensionNumbers = None,
transpose_kernel: bool = False,
precision: lax.PrecisionLike = None,
preferred_element_type: Optional[DType] = None) -> Array:
preferred_element_type: Optional[DTypeLike] = None) -> Array:
"""Convenience wrapper for calculating the N-d convolution "transpose".
This function directly calculates a fractionally strided conv rather than
@@ -343,7 +337,7 @@ def conv_transpose(lhs: Array, rhs: Array, strides: Sequence[int],
if transpose_kernel:
# flip spatial dims and swap input / output channel axes
rhs = _flip_axes(rhs, np.array(dn.rhs_spec)[2:])
rhs = np.swapaxes(rhs, dn.rhs_spec[0], dn.rhs_spec[1])
rhs = rhs.swapaxes(dn.rhs_spec[0], dn.rhs_spec[1])
return conv_general_dilated(lhs, rhs, one, pads, strides, rhs_dilation, dn,
precision=precision,
preferred_element_type=preferred_element_type)
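The convolution wrappers above now type `preferred_element_type` as `DTypeLike`. A brief, illustrative sketch of the public API with a dtype-like accumulation type (shapes and values are made up):

```python
import jax.numpy as jnp
from jax import lax

lhs = jnp.ones((1, 1, 8, 8), dtype=jnp.bfloat16)  # NCHW input
rhs = jnp.ones((4, 1, 3, 3), dtype=jnp.bfloat16)  # OIHW kernel
out = lax.conv_general_dilated(
    lhs, rhs, window_strides=(1, 1), padding='SAME',
    # Any DTypeLike is accepted here, e.g. jnp.float32 or a numpy dtype.
    preferred_element_type=jnp.float32)
```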
5 changes: 2 additions & 3 deletions jax/_src/lax/windowed_reductions.py
@@ -14,7 +14,7 @@

from collections.abc import Sequence
from functools import partial
from typing import Any, Callable, Optional, Union
from typing import Callable, Optional, Union
import warnings

import numpy as np
@@ -36,12 +36,11 @@
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
from jax._src.numpy.ufuncs import logaddexp
from jax._src.typing import Array

map = util.safe_map
zip = util.safe_zip

Array = Any


def reduce_window(operand, init_value, computation: Callable,
window_dimensions: core.Shape, window_strides: Sequence[int],
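`reduce_window` is the primitive behind pooling-style reductions; the module now imports the shared `Array` alias as well. A minimal max-pooling sketch with the public wrapper (values are illustrative):

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(8.0)
# Non-overlapping max pooling: window of 2, stride of 2, no padding.
pooled = lax.reduce_window(x, -jnp.inf, lax.max,
                           window_dimensions=(2,), window_strides=(2,),
                           padding='VALID')
# pooled -> [1., 3., 5., 7.]
```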