diff --git a/jax/lax/lax.py b/jax/lax/lax.py index 7da7622f6b73..076847ad2b12 100644 --- a/jax/lax/lax.py +++ b/jax/lax/lax.py @@ -20,7 +20,8 @@ import itertools import operator import string -from typing import Any, Callable +from typing import (Any, Callable, List, NamedTuple, Optional, Sequence, Union, + Tuple, Type) import warnings import numpy as onp @@ -56,6 +57,9 @@ _min = builtins.max _reduce = functools.reduce +Array = Any +DType = Any +Shape = Sequence[int] @cache() def broadcast_shapes(*shapes): @@ -76,11 +80,11 @@ def _identity(x): return x ### traceables -def neg(x): +def neg(x: Array) -> Array: r"""Elementwise negation: :math:`-x`.""" return neg_p.bind(x) -def sign(x): +def sign(x: Array) -> Array: r"""Elementwise sign. For floating-point inputs, returns @@ -104,26 +108,26 @@ def sign(x): """ return sign_p.bind(x) -def nextafter(x1, x2): +def nextafter(x1: Array, x2: Array) -> Array: r"""Returns the next representable value after `x1` in the direction of `x2`.""" return nextafter_p.bind(_brcast(x1, x2), _brcast(x2, x1)) -def floor(x): +def floor(x: Array) -> Array: r"""Elementwise floor: :math:`\left\lfloor x \right\rfloor`.""" return floor_p.bind(x) -def ceil(x): +def ceil(x: Array) -> Array: r"""Elementwise ceiling: :math:`\left\lceil x \right\rceil`.""" return ceil_p.bind(x) -def round(x): +def round(x: Array) -> Array: r"""Elementwise round. Rounds values to the nearest integer. Halfway values (e.g., `0.5`) are rounded away from zero.""" return round_p.bind(x) -def is_finite(x): +def is_finite(x: Array) -> Array: r"""Elementwise :math:`\mathrm{isfinite}`. For each element x returns `True` if and only if x is not :math:`\pm\infty` or @@ -131,215 +135,215 @@ def is_finite(x): """ return is_finite_p.bind(x) -def exp(x): +def exp(x: Array) -> Array: r"""Elementwise exponential: :math:`e^x`.""" return exp_p.bind(x) -def expm1(x): +def expm1(x: Array) -> Array: r"""Elementwise :math:`e^{x - 1}`.""" return expm1_p.bind(x) -def log(x): +def log(x: Array) -> Array: r"""Elementwise natural logarithm: :math:`\mathrm{log}(x)`.""" return log_p.bind(x) -def log1p(x): +def log1p(x: Array) -> Array: r"""Elementwise :math:`\mathrm{log}(1 + x)`.""" return log1p_p.bind(x) -def tanh(x): +def tanh(x: Array) -> Array: r"""Elementwise hyperbolic tangent: :math:`\mathrm{tanh}(x)`.""" return tanh_p.bind(x) -def sin(x): +def sin(x: Array) -> Array: r"""Elementwise sine: :math:`\mathrm{sin}(x)`.""" return sin_p.bind(x) -def cos(x): +def cos(x: Array) -> Array: r"""Elementwise cosine: :math:`\mathrm{cos}(x)`.""" return cos_p.bind(x) -def atan2(x, y): +def atan2(x: Array, y: Array) -> Array: r"""Elementwise arc tangent of two variables: :math:`\mathrm{atan}({x \over y})`.""" return atan2_p.bind(x, y) -def betainc(a, b, x): +def betainc(a: Array, b: Array, x: Array) -> Array: r"""Elementwise regularized incomplete beta integral.""" a = _brcast(_brcast(a, b), x) b = _brcast(b, a) x = _brcast(x, a) return regularized_incomplete_beta_p.bind(a, b, x) -def lgamma(x): +def lgamma(x: Array) -> Array: r"""Elementwise log gamma: :math:`\mathrm{log}(\Gamma(x))`.""" return lgamma_p.bind(x) -def digamma(x): +def digamma(x: Array) -> Array: r"""Elementwise digamma: :math:`\psi(x)`.""" return digamma_p.bind(x) -def igamma(a, x): +def igamma(a: Array, x: Array) -> Array: r"""Elementwise regularized incomplete gamma function.""" return igamma_p.bind(_brcast(a, x), _brcast(x, a)) -def igammac(a, x): +def igammac(a: Array, x: Array) -> Array: r"""Elementwise complementary regularized incomplete gamma function.""" return igammac_p.bind(_brcast(a, x), _brcast(x, a)) -def bessel_i0e(x): +def bessel_i0e(x: Array) -> Array: r"""Exponentially scaled modified Bessel function of order 0: :math:`\mathrm{i0e}(x) = e^{-|x|} \mathrm{i0}(x)` """ return bessel_i0e_p.bind(x) -def bessel_i1e(x): +def bessel_i1e(x: Array) -> Array: r"""Exponentially scaled modified Bessel function of order 1: :math:`\mathrm{i1e}(x) = e^{-|x|} \mathrm{i1}(x)` """ return bessel_i1e_p.bind(x) -def erf(x): +def erf(x: Array) -> Array: r"""Elementwise error function: :math:`\mathrm{erf}(x)`.""" return erf_p.bind(x) -def erfc(x): +def erfc(x: Array) -> Array: r"""Elementwise complementary error function: :math:`\mathrm{erfc}(x) = 1 - \mathrm{erf}(x)`.""" return erfc_p.bind(x) -def erf_inv(x): +def erf_inv(x: Array) -> Array: r"""Elementwise inverse error function: :math:`\mathrm{erf}^{-1}(x)`.""" return erf_inv_p.bind(x) -def real(x): +def real(x: Array) -> Array: r"""Elementwise extract real part: :math:`\mathrm{Re}(x)`. Returns the real part of a complex number. """ return real_p.bind(x) -def imag(x): +def imag(x: Array) -> Array: r"""Elementwise extract imaginary part: :math:`\mathrm{Im}(x)`. Returns the imaginary part of a complex number. """ return imag_p.bind(x) -def complex(x, y): +def complex(x: Array, y: Array) -> Array: r"""Elementwise make complex number: :math:`x + jy`. Builds a complex number from real and imaginary parts. """ return complex_p.bind(_brcast(x, y), _brcast(y, x)) -def conj(x): +def conj(x: Array) -> Array: r"""Elementwise complex conjugate function: :math:`\overline{x}`.""" return conj_p.bind(x, input_dtype=_dtype(x)) -def abs(x): +def abs(x: Array) -> Array: r"""Elementwise absolute value: :math:`|x|`.""" return abs_p.bind(x) -def pow(x, y): +def pow(x: Array, y: Array) -> Array: r"""Elementwise power: :math:`x^y`.""" return pow_p.bind(x, y) -def sqrt(x): +def sqrt(x: Array) -> Array: r"""Elementwise square root: :math:`\sqrt{x}`.""" return sqrt_p.bind(x) -def rsqrt(x): +def rsqrt(x: Array) -> Array: r"""Elementwise reciprocal square root: :math:`1 \over \sqrt{x}.""" return rsqrt_p.bind(x) -def bitwise_not(x): +def bitwise_not(x: Array) -> Array: r"""Elementwise NOT: :math:`\neg x`.""" return not_p.bind(x) -def bitwise_and(x, y): +def bitwise_and(x: Array, y: Array) -> Array: r"""Elementwise AND: :math:`x \wedge y`.""" return and_p.bind(x, y) -def bitwise_or(x, y): +def bitwise_or(x: Array, y: Array) -> Array: r"""Elementwise OR: :math:`x \vee y`.""" return or_p.bind(x, y) -def bitwise_xor(x, y): +def bitwise_xor(x: Array, y: Array) -> Array: r"""Elementwise exclusive OR: :math:`x \oplus y`.""" return xor_p.bind(x, y) -def add(x, y): +def add(x: Array, y: Array) -> Array: r"""Elementwise addition: :math:`x + y`.""" return add_p.bind(x, y) -def sub(x, y): +def sub(x: Array, y: Array) -> Array: r"""Elementwise subtraction: :math:`x - y`.""" return sub_p.bind(x, y) -def mul(x, y): +def mul(x: Array, y: Array) -> Array: r"""Elementwise multiplication: :math:`x \times y`.""" return mul_p.bind(x, y) -def div(x, y): +def div(x: Array, y: Array) -> Array: r"""Elementwise division: :math:`x \over y`.""" return div_p.bind(x, y) -def rem(x, y): +def rem(x: Array, y: Array) -> Array: r"""Elementwise remainder: :math:`x \bmod y`.""" return rem_p.bind(x, y) -def max(x, y): +def max(x: Array, y: Array) -> Array: r"""Elementwise maximum: :math:`\mathrm{max}(x, y)` For complex numbers, uses a lexicographic comparison on the `(real, imaginary)` pairs.""" return max_p.bind(x, y) -def min(x, y): +def min(x: Array, y: Array) -> Array: r"""Elementwise minimum: :math:`\mathrm{min}(x, y)` For complex numbers, uses a lexicographic comparison on the `(real, imaginary)` pairs.""" return min_p.bind(x, y) -def shift_left(x, y): +def shift_left(x: Array, y: Array) -> Array: r"""Elementwise left shift: :math:`x \ll y`.""" return shift_left_p.bind(x, y) -def shift_right_arithmetic(x, y): +def shift_right_arithmetic(x: Array, y: Array) -> Array: r"""Elementwise arithmetic right shift: :math:`x \gg y`.""" return shift_right_arithmetic_p.bind(x, y) -def shift_right_logical(x, y): +def shift_right_logical(x: Array, y: Array) -> Array: r"""Elementwise logical right shift: :math:`x \gg y`.""" return shift_right_logical_p.bind(x, y) -def eq(x, y): +def eq(x: Array, y: Array) -> Array: r"""Elementwise equals: :math:`x = y`.""" return eq_p.bind(x, y) -def ne(x, y): +def ne(x: Array, y: Array) -> Array: r"""Elementwise not-equals: :math:`x \neq y`.""" return ne_p.bind(x, y) -def ge(x, y): +def ge(x: Array, y: Array) -> Array: r"""Elementwise greater-than-or-equals: :math:`x \geq y`.""" return ge_p.bind(x, y) -def gt(x, y): +def gt(x: Array, y: Array) -> Array: r"""Elementwise greater-than: :math:`x > y`.""" return gt_p.bind(x, y) -def le(x, y): +def le(x: Array, y: Array) -> Array: r"""Elementwise less-than-or-equals: :math:`x \leq y`.""" return le_p.bind(x, y) -def lt(x, y): +def lt(x: Array, y: Array) -> Array: r"""Elementwise less-than: :math:`x < y`.""" return lt_p.bind(x, y) -def convert_element_type(operand, new_dtype): +def convert_element_type(operand: Array, new_dtype: DType) -> Array: """Elementwise cast. Wraps XLA's `ConvertElementType @@ -379,7 +383,7 @@ def convert_element_type(operand, new_dtype): return convert_element_type_p.bind( operand, new_dtype=new_dtype, old_dtype=old_dtype) -def bitcast_convert_type(operand, new_dtype): +def bitcast_convert_type(operand: Array, new_dtype: DType) -> Array: """Elementwise bitcast. Wraps XLA's `BitcastConvertType @@ -402,7 +406,7 @@ def bitcast_convert_type(operand, new_dtype): else: return operand -def clamp(min, x, max): +def clamp(min: Array, x: Array, max: Array) -> Array: r"""Elementwise clamp. Returns :math:`\mathrm{clamp}(x) = \begin{cases} @@ -413,7 +417,7 @@ def clamp(min, x, max): """ return clamp_p.bind(min, x, max) -def concatenate(operands, dimension): +def concatenate(operands: Sequence[Array], dimension: int) -> Array: """Concatenates a sequence of arrays along `dimension`. Wraps XLA's `Concatenate @@ -432,10 +436,33 @@ def concatenate(operands, dimension): Precision = xla_client.PrecisionConfig.Precision Precision.__str__ = lambda precision: precision.name +PrecisionType = Any -def conv_general_dilated(lhs, rhs, window_strides, padding, lhs_dilation=None, - rhs_dilation=None, dimension_numbers=None, - feature_group_count=1, precision=None): +class ConvDimensionNumbers(NamedTuple): + """Describes batch, spatial, and feature dimensions of a convolution. + + Args: + lhs_spec: a tuple of nonnegative integer dimension numbers containing + `(batch dimension, feature dimension, spatial dimensions...)`. + rhs_spec: a tuple of nonnegative integer dimension numbers containing + `(out feature dimension, in feature dimension, spatial dimensions...)`. + out_spec: a tuple of nonnegative integer dimension numbers containing + `(batch dimension, feature dimension, spatial dimensions...)`. + """ + lhs_spec: Sequence[int] + rhs_spec: Sequence[int] + out_spec: Sequence[int] + +ConvGeneralDilatedDimensionNumbers = Union[ + None, ConvDimensionNumbers, Tuple[str, str, str]] + +def conv_general_dilated( + lhs: Array, rhs: Array, window_strides: Sequence[int], + padding: Union[str, Sequence[Tuple[int, int]]], + lhs_dilation: Optional[Sequence[int]] = None, + rhs_dilation: Optional[Sequence[int]] = None, + dimension_numbers: ConvGeneralDilatedDimensionNumbers = None, + feature_group_count: int = 1, precision: Optional[PrecisionType] = None) -> Array: """General n-dimensional convolution operator, with optional dilation. Wraps XLA's `Conv @@ -490,9 +517,11 @@ def conv_general_dilated(lhs, rhs, window_strides, padding, lhs_dilation=None, If `dimension_numbers` is `None`, the default is `('NCHW', 'OIHW', 'NCHW')` (for a 2D convolution). """ - if type(dimension_numbers) is not ConvDimensionNumbers: - dimension_numbers = conv_dimension_numbers( - lhs.shape, rhs.shape, dimension_numbers) + dnums: ConvDimensionNumbers + if isinstance(dimension_numbers, ConvDimensionNumbers): + dnums = dimension_numbers + else: + dnums = conv_dimension_numbers(lhs.shape, rhs.shape, dimension_numbers) if lhs_dilation is None: lhs_dilation = (1,) * (lhs.ndim - 2) elif isinstance(padding, str) and not len(lhs_dilation) == lhs_dilation.count(1): @@ -503,7 +532,7 @@ def conv_general_dilated(lhs, rhs, window_strides, padding, lhs_dilation=None, if rhs_dilation is None: rhs_dilation = (1,) * (rhs.ndim - 2) if isinstance(padding, str): - lhs_perm, rhs_perm, _ = dimension_numbers + lhs_perm, rhs_perm, _ = dnums rhs_shape = onp.take(rhs.shape, rhs_perm)[2:] effective_rhs_shape = [(k-1) * r + 1 for k, r in zip(rhs_shape, rhs_dilation)] padding = padtype_to_pads( @@ -512,12 +541,12 @@ def conv_general_dilated(lhs, rhs, window_strides, padding, lhs_dilation=None, return conv_general_dilated_p.bind( lhs, rhs, window_strides=tuple(window_strides), padding=tuple(padding), lhs_dilation=tuple(lhs_dilation), rhs_dilation=tuple(rhs_dilation), - dimension_numbers=dimension_numbers, + dimension_numbers=dnums, feature_group_count=feature_group_count, lhs_shape=lhs.shape, rhs_shape=rhs.shape, precision=_canonicalize_precision(precision)) -def dot(lhs, rhs, precision=None): +def dot(lhs: Array, rhs: Array, precision: Optional[PrecisionType] = None) -> Array: """Vector/vector, matrix/vector, and matrix/matrix multiplication. Wraps XLA's `Dot @@ -542,7 +571,12 @@ def dot(lhs, rhs, precision=None): raise TypeError("Incompatible shapes for dot: got {} and {}.".format( lhs.shape, rhs.shape)) -def dot_general(lhs, rhs, dimension_numbers, precision=None): + +DotDimensionNumbers = Tuple[Tuple[Sequence[int], Sequence[int]], + Tuple[Sequence[int], Sequence[int]]] + +def dot_general(lhs: Array, rhs: Array, dimension_numbers: DotDimensionNumbers, + precision: Optional[PrecisionType] = None) -> Array: """More general contraction operator. Wraps XLA's `DotGeneral @@ -561,9 +595,9 @@ def dot_general(lhs, rhs, dimension_numbers, precision=None): Returns: An array containing the result. """ - contract_dims, batch_dims = dimension_numbers - contract_dims = tuple(map(tuple, contract_dims)) - batch_dims = tuple(map(tuple, batch_dims)) + contract_dims_seq, batch_dims_seq = dimension_numbers + contract_dims = tuple(map(lambda x: tuple(x), contract_dims_seq)) + batch_dims = tuple(map(lambda x: tuple(x), batch_dims_seq)) if not dtypes.issubdtype(lhs.dtype, onp.inexact): # TODO(b/134526360): XLA doesn't support bool or integer dots, so we emit a # sum of products instead. @@ -596,7 +630,7 @@ def dot_general(lhs, rhs, dimension_numbers, precision=None): dimension_numbers=(contract_dims, batch_dims), precision=_canonicalize_precision(precision)) -def broadcast(operand, sizes): +def broadcast(operand: Array, sizes: Sequence[int]) -> Array: """Broadcasts an array, adding new major dimensions. Wraps XLA's `Broadcast @@ -614,7 +648,8 @@ def broadcast(operand, sizes): dims = tuple(range(len(sizes), len(sizes) + onp.ndim(operand))) return broadcast_in_dim(operand, tuple(sizes) + onp.shape(operand), dims) -def broadcast_in_dim(operand, shape, broadcast_dimensions): +def broadcast_in_dim(operand: Array, shape: Shape, + broadcast_dimensions: Sequence[int]) -> Array: """Wraps XLA's `BroadcastInDim `_ operator. @@ -627,7 +662,8 @@ def broadcast_in_dim(operand, shape, broadcast_dimensions): operand, shape=tuple(shape), broadcast_dimensions=tuple(broadcast_dimensions)) -def reshape(operand, new_sizes, dimensions=None): +def reshape(operand: Array, new_sizes: Shape, + dimensions: Optional[Sequence[int]] = None) -> Array: """Wraps XLA's `Reshape `_ operator. @@ -639,31 +675,35 @@ def reshape(operand, new_sizes, dimensions=None): if onp.shape(operand) and same_shape and same_dims: return operand else: - return reshape_p.bind(operand, new_sizes=new_sizes, - dimensions=None if same_dims else tuple(dimensions)) + return reshape_p.bind( + operand, new_sizes=new_sizes, + dimensions=None if dimensions is None or same_dims else tuple(dimensions)) -def pad(operand, padding_value, padding_config): +def pad(operand: Array, padding_value: Array, + padding_config: Sequence[Tuple[int, int, int]]) -> Array: """Wraps XLA's `Pad `_ operator. """ return pad_p.bind(operand, padding_value, padding_config=tuple(padding_config)) -def rev(operand, dimensions): +def rev(operand: Array, dimensions: Sequence[int]) -> Array: """Wraps XLA's `Rev `_ operator. """ return rev_p.bind(operand, dimensions=tuple(dimensions)) -def select(pred, on_true, on_false): +def select(pred: Array, on_true: Array, on_false: Array) -> Array: """Wraps XLA's `Select `_ operator. """ return select_p.bind(pred, on_true, on_false) -def slice(operand: Any, start_indices, limit_indices, strides=None): +def slice(operand: Array, start_indices: Sequence[int], + limit_indices: Sequence[int], + strides: Optional[Sequence[int]] = None) -> Array: """Wraps XLA's `Slice `_ operator. @@ -677,7 +717,8 @@ def slice(operand: Any, start_indices, limit_indices, strides=None): limit_indices=tuple(limit_indices), strides=None if strides is None else tuple(strides)) -def dynamic_slice(operand, start_indices, slice_sizes): +def dynamic_slice(operand: Array, start_indices: Sequence[Array], + slice_sizes: Shape) -> Array: """Wraps XLA's `DynamicSlice `_ operator. @@ -695,7 +736,8 @@ def dynamic_slice(operand, start_indices, slice_sizes): return dynamic_slice_p.bind(operand, *start_indices, slice_sizes=tuple(slice_sizes)) -def dynamic_update_slice(operand, update, start_indices): +def dynamic_update_slice(operand: Array, update: Array, + start_indices: Array) -> Array: """Wraps XLA's `DynamicUpdateSlice `_ operator. @@ -711,7 +753,37 @@ def dynamic_update_slice(operand, update, start_indices): start_indices = _dynamic_slice_indices(operand, start_indices) return dynamic_update_slice_p.bind(operand, update, *start_indices) -def gather(operand, start_indices, dimension_numbers, slice_sizes): + +class GatherDimensionNumbers(NamedTuple): + """ + Describes the dimension number arguments to an `XLA's Gather operator + `_. See the XLA + documentation for more details of what the dimension numbers mean. + + Args: + offset_dims: the set of dimensions in the `gather` output that offset into + an array sliced from `operand`. Must be a tuple of integers in ascending + order, each representing a dimension number of the output. + collapsed_slice_dims: the set of dimensions `i` in `operand` that have + `slice_sizes[i] == 1` and that should not have a corresponding dimension + in the output of the gather. Must be a tuple of integers in ascending + order. + start_index_map: for each dimension in `start_indices`, gives the + corresponding dimension in `operand` that is to be sliced. Must be a + tuple of integers with size equal to `start_indices.shape[-1]`. + + Unlike XLA's `GatherDimensionNumbers` structure, `index_vector_dim` is + implicit; there is always an index vector dimension and it must always be the + last dimension. To gather scalar indices, add a trailing dimension of size 1. + """ + offset_dims: Sequence[int] + collapsed_slice_dims: Sequence[int] + start_index_map: Sequence[int] + + +def gather(operand: Array, start_indices: Array, + dimension_numbers: GatherDimensionNumbers, + slice_sizes: Shape) -> Array: """Gather operator. Wraps `XLA's Gather operator @@ -737,7 +809,35 @@ def gather(operand, start_indices, dimension_numbers, slice_sizes): operand, start_indices, dimension_numbers=dimension_numbers, slice_sizes=canonicalize_shape(slice_sizes)) -def scatter_add(operand, scatter_indices, updates, dimension_numbers): + +class ScatterDimensionNumbers(NamedTuple): + """ + Describes the dimension number arguments to an `XLA's Scatter operator + `_. See the XLA + documentation for more details of what the dimension numbers mean. + + Args: + update_window_dims: the set of dimensions in the `updates` that are window + dimensions. Must be a tuple of integers in ascending + order, each representing a dimension number. + inserted_window_dims: the set of size 1 window dimensions that must be inserted + into the shape of `updates`. Must be a tuple of integers in ascending + order, each representing a dimension number of the output. These are the + mirror image of `collapsed_slice_dims` in the case of `gather`. + scatter_dims_to_operand_dims: for each dimension in `scatter_indices`, gives + the corresponding dimension in `operand`. Must be a sequence of integers + with size equal to indices.shape[-1]. + + Unlike XLA's `ScatterDimensionNumbers` structure, `index_vector_dim` is + implicit; there is always an index vector dimension and it must always be the + last dimension. To scatter scalar indices, add a trailing dimension of size 1. + """ + update_window_dims: Sequence[int] + inserted_window_dims: Sequence[int] + scatter_dims_to_operand_dims: Sequence[int] + +def scatter_add(operand: Array, scatter_indices: Array, updates: Array, + dimension_numbers: ScatterDimensionNumbers) -> Array: """Scatter-add operator. Wraps `XLA's Scatter operator @@ -763,7 +863,8 @@ def scatter_add(operand, scatter_indices, updates, dimension_numbers): operand, scatter_indices, updates, update_jaxpr=jaxpr, update_consts=consts, dimension_numbers=dimension_numbers) -def scatter_min(operand, scatter_indices, updates, dimension_numbers): +def scatter_min(operand: Array, scatter_indices: Array, updates: Array, + dimension_numbers: ScatterDimensionNumbers) -> Array: """Scatter-min operator. Wraps `XLA's Scatter operator @@ -789,7 +890,8 @@ def scatter_min(operand, scatter_indices, updates, dimension_numbers): operand, scatter_indices, updates, update_jaxpr=jaxpr, update_consts=consts, dimension_numbers=dimension_numbers) -def scatter_max(operand, scatter_indices, updates, dimension_numbers): +def scatter_max(operand: Array, scatter_indices: Array, updates: Array, + dimension_numbers: ScatterDimensionNumbers) -> Array: """Scatter-max operator. Wraps `XLA's Scatter operator @@ -818,7 +920,8 @@ def scatter_max(operand, scatter_indices, updates, dimension_numbers): # Define this outside of scatter to ensure cache hits. _scatter_reduction_computation = lambda x, y: y -def scatter(operand, scatter_indices, updates, dimension_numbers): +def scatter(operand: Array, scatter_indices:Array, updates: Array, + dimension_numbers: ScatterDimensionNumbers) -> Array: """Scatter-update operator. Wraps `XLA's Scatter operator @@ -848,21 +951,21 @@ def scatter(operand, scatter_indices, updates, dimension_numbers): operand, scatter_indices, updates, update_jaxpr=jaxpr, update_consts=consts, dimension_numbers=dimension_numbers) -def index_take(src, idxs, axes): +def index_take(src: Array, idxs: Array, axes: Sequence[int]) -> Array: indices = concatenate([reshape(i, [i.shape[0], 1]) for i in idxs], 1) indices = indices % onp.array([src.shape[ax] for ax in axes]) slice_sizes = list(src.shape) for ax in axes: slice_sizes[ax] = 1 - slice_sizes = tuple(slice_sizes) offset_dims = tuple(range(1, src.ndim - indices.shape[1] + 1)) dnums = GatherDimensionNumbers( offset_dims=offset_dims, collapsed_slice_dims=axes, start_index_map=axes) - return gather(src, indices, dimension_numbers=dnums, slice_sizes=slice_sizes) + return gather(src, indices, dimension_numbers=dnums, + slice_sizes=tuple(slice_sizes)) -def transpose(operand, permutation): +def transpose(operand: Array, permutation: Sequence[int]) -> Array: """Wraps XLA's `Transpose `_ operator. @@ -873,7 +976,8 @@ def transpose(operand, permutation): else: return transpose_p.bind(operand, permutation=permutation) -def reduce(operand, init_value, computation, dimensions): +def reduce(operand: Array, init_value: Array, computation: Callable, + dimensions: Sequence[int]) -> Array: """Wraps XLA's `Reduce `_ operator. @@ -893,7 +997,7 @@ def _reduction_jaxpr(computation, aval): jaxpr, _, consts = pe.trace_to_jaxpr(comp, (pval, pval), instantiate=False) return jaxpr, consts -def _get_monoid_reducer(monoid_op, x): +def _get_monoid_reducer(monoid_op: Callable, x: Array) -> Optional[Callable]: aval = core.get_aval(x) dtype = _dtype(x) if (type(aval) is ConcreteArray) and aval.shape == (): @@ -909,8 +1013,9 @@ def _get_monoid_reducer(monoid_op, x): return aval.val == _get_max_identity(dtype) and _reduce_max elif monoid_op is min: return aval.val == _get_min_identity(dtype) and _reduce_min + return None -def _get_max_identity(dtype): +def _get_max_identity(dtype: DType) -> Array: if dtypes.issubdtype(dtype, onp.inexact): return onp.array(-onp.inf, dtype) elif dtypes.issubdtype(dtype, onp.integer): @@ -918,7 +1023,7 @@ def _get_max_identity(dtype): elif dtypes.issubdtype(dtype, onp.bool_): return onp.array(False, onp.bool_) -def _get_min_identity(dtype): +def _get_min_identity(dtype: DType) -> Array: if dtypes.issubdtype(dtype, onp.inexact): return onp.array(onp.inf, dtype) elif dtypes.issubdtype(dtype, onp.integer): @@ -926,26 +1031,27 @@ def _get_min_identity(dtype): elif dtypes.issubdtype(dtype, onp.bool_): return onp.array(True, onp.bool_) -def _reduce_sum(operand, axes): +def _reduce_sum(operand: Array, axes: Sequence[int]) -> Array: return reduce_sum_p.bind(operand, axes=tuple(axes)) -def _reduce_prod(operand, axes): +def _reduce_prod(operand: Array, axes: Sequence[int]) -> Array: return reduce_prod_p.bind(operand, axes=tuple(axes)) -def _reduce_max(operand, axes): +def _reduce_max(operand: Array, axes: Sequence[int]) -> Array: return reduce_max_p.bind(operand, axes=tuple(axes)) -def _reduce_min(operand, axes): +def _reduce_min(operand: Array, axes: Sequence[int]) -> Array: return reduce_min_p.bind(operand, axes=tuple(axes)) -def _reduce_or(operand, axes): +def _reduce_or(operand: Array, axes: Sequence[int]) -> Array: return reduce_or_p.bind(operand, axes=tuple(axes)) -def _reduce_and(operand, axes): +def _reduce_and(operand: Array, axes: Sequence[int]) -> Array: return reduce_and_p.bind(operand, axes=tuple(axes)) -def reduce_window(operand, init_value, computation, window_dimensions, - window_strides, padding): +def reduce_window(operand: Array, init_value: Array, computation: Callable, + window_dimensions: Shape, window_strides: Sequence[int], + padding: str) -> Array: """Wraps XLA's `ReduceWindow `_ operator. @@ -960,7 +1066,7 @@ def reduce_window(operand, init_value, computation, window_dimensions, window_dimensions=tuple(window_dimensions), window_strides=tuple(window_strides), padding=padding) -def _get_monoid_window_reducer(monoid_op, x): +def _get_monoid_window_reducer(monoid_op: Callable, x: Array) -> Optional[Callable]: aval = core.get_aval(x) if (type(aval) is ConcreteArray) and aval.shape == (): if monoid_op is add: @@ -969,13 +1075,16 @@ def _get_monoid_window_reducer(monoid_op, x): return aval.val == _get_max_identity(aval.dtype) and _reduce_window_max elif monoid_op is min: return aval.val == _get_min_identity(aval.dtype) and _reduce_window_min + return None -def _reduce_window_sum(operand, window_dimensions, window_strides, padding): +def _reduce_window_sum(operand: Array, window_dimensions: Shape, + window_strides: Sequence[int], padding: str) -> Array: return reduce_window_sum_p.bind( operand, window_dimensions=tuple(window_dimensions), window_strides=tuple(window_strides), padding=padding) -def _reduce_window_prod(operand, window_dimensions, window_strides, padding): +def _reduce_window_prod(operand: Array, window_dimensions: Shape, + window_strides: Sequence[int], padding: str) -> Array: init_value = _const(operand, 1) jaxpr, consts = _reduction_jaxpr(mul, _abstractify(init_value)) return reduce_window_p.bind( @@ -983,18 +1092,22 @@ def _reduce_window_prod(operand, window_dimensions, window_strides, padding): window_dimensions=tuple(window_dimensions), window_strides=tuple(window_strides), padding=padding) -def _reduce_window_max(operand, window_dimensions, window_strides, padding): +def _reduce_window_max(operand: Array, window_dimensions: Shape, + window_strides: Sequence[int], padding: str) -> Array: return reduce_window_max_p.bind( operand, window_dimensions=tuple(window_dimensions), window_strides=tuple(window_strides), padding=padding) -def _reduce_window_min(operand, window_dimensions, window_strides, padding): +def _reduce_window_min(operand: Array, window_dimensions: Shape, + window_strides: Sequence[int], padding: str) -> Array: return reduce_window_min_p.bind( operand, window_dimensions=tuple(window_dimensions), window_strides=tuple(window_strides), padding=padding) -def _select_and_scatter(operand, select, window_dimensions, window_strides, - padding, source, init_value, scatter): +def _select_and_scatter(operand: Array, select: Callable, + window_dimensions: Shape, window_strides: Sequence[int], + padding: str, source: Array, init_value: Array, + scatter: Callable) -> Array: select_jaxpr, select_consts = _reduction_jaxpr(select, _abstractify(init_value)) scatter_jaxpr, scatter_consts = _reduction_jaxpr(scatter, _abstractify(init_value)) return select_and_scatter_p.bind( @@ -1003,48 +1116,54 @@ def _select_and_scatter(operand, select, window_dimensions, window_strides, scatter_consts=scatter_consts, window_dimensions=tuple(window_dimensions), window_strides=tuple(window_strides), padding=padding) -def _select_and_scatter_add(source, operand, select_prim, window_dimensions, - window_strides, padding): +def _select_and_scatter_add(source: Array, operand: Array, + select_prim: core.Primitive, + window_dimensions: Shape, + window_strides: Sequence[int], + padding: str) -> Array: return select_and_scatter_add_p.bind( source, operand, select_prim=select_prim, window_dimensions=tuple(window_dimensions), window_strides=tuple(window_strides), padding=padding) -def _select_and_gather_add(tangents, operand, select_prim, window_dimensions, - window_strides, padding): +def _select_and_gather_add(tangents: Array, operand: Array, + select_prim: core.Primitive, + window_dimensions: Shape, + window_strides: Sequence[int], + padding: str) -> Array: return select_and_gather_add_p.bind( tangents, operand, select_prim=select_prim, window_dimensions=tuple(window_dimensions), window_strides=tuple(window_strides), padding=padding) -def cumsum(operand, axis: int): +def cumsum(operand: Array, axis: int) -> Array: """Computes a cumulative sum along `axis`.""" return cumsum_p.bind(operand, axis=int(axis)) -def cumprod(operand, axis: int): +def cumprod(operand: Array, axis: int) -> Array: """Computes a cumulative product along `axis`.""" return cumprod_p.bind(operand, axis=int(axis)) -def sort(operand, dimension=-1): +def sort(operand: Array, dimension: int = -1) -> Array: """Wraps XLA's `Sort `_ operator. """ return sort_p.bind(operand, dimension=dimension) -def sort_key_val(keys, values, dimension=-1): +def sort_key_val(keys: Array, values: Array, dimension: int = -1) -> Array: # TODO(mattjj): new sort_key_val is variadic result = sort_key_val_p.bind(keys, values, dimension=dimension) sorted_keys, sorted_values = result return sorted_keys, sorted_values -def top_k(operand, k): +def top_k(operand: Array, k: int) -> Array: k = int(k) if k < 0: raise ValueError("k argument to top_k must be nonnegative, got {}".format(k)) return top_k_p.bind(operand, k=k) -def tie_in(x, y): +def tie_in(x: Array, y: Array) -> Array: """Gives ``y`` a fake data dependence on ``x``. When staging to XLA (e.g. running under jit or pmap), values that don't depend @@ -1059,7 +1178,7 @@ def tie_in(x, y): return tie_in_p.bind(x, y) -def full(shape, fill_value, dtype=None): +def full(shape: Shape, fill_value: Array, dtype: Optional[DType] = None) -> Array: """Returns an array of `shape` filled with `fill_value`. Arguments: @@ -1077,7 +1196,7 @@ def full(shape, fill_value, dtype=None): fill_value = xla.device_put_p.bind(convert_element_type(fill_value, dtype)) return broadcast(fill_value, shape) -def iota(dtype, size): +def iota(dtype: DType, size: int) -> Array: """Wraps XLA's `Iota `_ operator. @@ -1088,14 +1207,14 @@ def iota(dtype, size): aval = ShapedArray((size,), dtype) return xla.DeviceArray(aval, None, lazy_expr, xla.DeviceConstant()) -def broadcasted_iota(dtype, shape, dimension): +def broadcasted_iota(dtype: DType, shape: Shape, dimension: int) -> Array: """Convenience wrapper around ``iota``.""" dtype = dtypes.canonicalize_dtype(dtype) shape = canonicalize_shape(shape) dimension = int(dimension) return broadcast_in_dim(iota(dtype, shape[dimension]), shape, [dimension]) -def _eye(dtype, shape, offset): +def _eye(dtype: DType, shape: Shape, offset: int) -> Array: """Like numpy.eye, create a 2D array with ones on a diagonal. This function exists for creating lazy identity matrices; that is, @@ -1108,7 +1227,7 @@ def _eye(dtype, shape, offset): aval = ShapedArray((N, M), dtype) return xla.DeviceArray(aval, None, lazy_expr, xla.DeviceConstant()) -def _delta(dtype, shape, axes): +def _delta(dtype: DType, shape: Shape, axes: Sequence[int]) -> Array: """This function exists for creating lazy Kronecker delta arrays, particularly for use in jax.numpy.einsum to express traces. It differs from ``eye`` in that it can create arrays of any rank, but doesn't allow offsets.""" @@ -1120,7 +1239,7 @@ def _delta(dtype, shape, axes): aval = ShapedArray(shape, dtype) return xla.DeviceArray(aval, None, lazy_expr, xla.DeviceConstant()) -def _tri(dtype, shape, offset): +def _tri(dtype: DType, shape: Shape, offset: int) -> Array: """Like numpy.tri, create a 2D array with ones below a diagonal. This function exists for creating lazy triangular matrices, particularly for use in jax.numpy.tri.""" @@ -1157,7 +1276,8 @@ def stop_gradient(x): ### convenience wrappers around traceables -def conv(lhs, rhs, window_strides, padding, precision=None): +def conv(lhs: Array, rhs: Array, window_strides: Sequence[int], + padding: str, precision: Optional[PrecisionType] = None) -> Array: """Convenience wrapper around `conv_general_dilated`. Args: @@ -1176,8 +1296,12 @@ def conv(lhs, rhs, window_strides, padding, precision=None): return conv_general_dilated(lhs, rhs, window_strides, padding, precision=precision) -def conv_with_general_padding(lhs, rhs, window_strides, padding, - lhs_dilation, rhs_dilation, precision=None): +def conv_with_general_padding(lhs: Array, rhs: Array, + window_strides: Sequence[int], + padding: Union[str, Sequence[Tuple[int, int]]], + lhs_dilation: Optional[Sequence[int]], + rhs_dilation: Optional[Sequence[int]], + precision: Optional[PrecisionType] = None) -> Array: """Convenience wrapper around `conv_general_dilated`. Args: @@ -1238,8 +1362,12 @@ def _flip_axes(x, axes): return x -def conv_transpose(lhs, rhs, strides, padding, rhs_dilation=None, - dimension_numbers=None, transpose_kernel=False, precision=None): +def conv_transpose(lhs: Array, rhs: Array, strides: Sequence[int], + padding: Union[str, Sequence[Tuple[int, int]]], + rhs_dilation: Optional[Sequence[int]] = None, + dimension_numbers: ConvGeneralDilatedDimensionNumbers = None, + transpose_kernel: bool = False, + precision: Optional[PrecisionType] = None) -> Array: """Convenience wrapper for calculating the N-d convolution "transpose". This function directly calculates a fractionally strided conv rather than @@ -1286,6 +1414,7 @@ def conv_transpose(lhs, rhs, strides, padding, rhs_dilation=None, k_shape = onp.take(rhs.shape, dn.rhs_spec) k_sdims = k_shape[2:] # Calculate correct output shape given padding and strides. + pads: Union[str, Sequence[Tuple[int, int]]] if padding in {'SAME', 'VALID'}: if rhs_dilation is None: rhs_dilation = (1,) * (rhs.ndim - 2) @@ -1302,7 +1431,8 @@ def conv_transpose(lhs, rhs, strides, padding, rhs_dilation=None, precision=precision) -def full_like(x, fill_value, dtype=None, shape=None): +def full_like(x: Array, fill_value: Array, dtype: Optional[DType] = None, + shape: Optional[Shape] = None) -> Array: """Create a full array like np.full based on the example array `x`. Args: @@ -1315,19 +1445,21 @@ def full_like(x, fill_value, dtype=None, shape=None): An ndarray with the same shape as `x` with its entries set equal to `fill_value`, similar to the output of np.full. """ - shape = onp.shape(x) if shape is None else canonicalize_shape(shape) + fill_shape = onp.shape(x) if shape is None else canonicalize_shape(shape) fill_value = tie_in(x, fill_value) - return full(shape, fill_value, dtype or _dtype(x)) + return full(fill_shape, fill_value, dtype or _dtype(x)) -def collapse(operand, start_dimension, stop_dimension): +def collapse(operand: Array, start_dimension: int, stop_dimension: int) -> Array: lo, hi = start_dimension, stop_dimension size = prod(operand.shape[lo:hi]) new_shape = operand.shape[:lo] + (size,) + operand.shape[hi:] return reshape(operand, new_shape) -def slice_in_dim(operand, start_index, limit_index, stride=1, axis=0): +def slice_in_dim(operand: Array, start_index: Optional[int], + limit_index: Optional[int], + stride: int = 1, axis: int = 0)-> Array: """Convenience wrapper around slice applying to only one dimension.""" start_indices = [0] * operand.ndim limit_indices = list(operand.shape) @@ -1335,24 +1467,25 @@ def slice_in_dim(operand, start_index, limit_index, stride=1, axis=0): # translate `None` len_axis = operand.shape[axis] - start_index = start_index if start_index is not None else 0 - limit_index = limit_index if limit_index is not None else len_axis + start_index_int = int(start_index) if start_index is not None else 0 + limit_index_int = int(limit_index) if limit_index is not None else len_axis # translate negative indices - if start_index < 0: - start_index = start_index + len_axis - if limit_index < 0: - limit_index = limit_index + len_axis + if start_index_int < 0: + start_index_int = start_index_int + len_axis + if limit_index_int < 0: + limit_index_int = limit_index_int + len_axis axis = int(axis) - start_indices[axis] = int(start_index) - limit_indices[axis] = int(limit_index) + start_indices[axis] = start_index_int + limit_indices[axis] = limit_index_int strides[axis] = int(stride) return slice(operand, start_indices, limit_indices, strides) -def index_in_dim(operand, index, axis=0, keepdims=True): +def index_in_dim(operand: Array, index: int, axis: int = 0, + keepdims: bool = True) -> Array: """Convenience wrapper around slice to perform int indexing.""" index, axis = int(index), int(axis) axis_size = operand.shape[axis] @@ -1367,7 +1500,8 @@ def index_in_dim(operand, index, axis=0, keepdims=True): return reshape(result, onp.delete(operand.shape, axis)) -def dynamic_slice_in_dim(operand, start_index, slice_size, axis=0): +def dynamic_slice_in_dim(operand: Array, start_index: Array, + slice_size: int, axis: int = 0) -> Array: """Convenience wrapper around dynamic_slice applying to one dimension.""" start_indices = [_zero(start_index)] * operand.ndim slice_sizes = list(operand.shape) @@ -1378,7 +1512,8 @@ def dynamic_slice_in_dim(operand, start_index, slice_size, axis=0): return dynamic_slice(operand, start_indices, slice_sizes) -def dynamic_index_in_dim(operand, index, axis=0, keepdims=True): +def dynamic_index_in_dim(operand: Array, index: Array, axis: int = 0, + keepdims: bool = True) -> Array: """Convenience wrapper around dynamic_slice to perform int indexing.""" result = dynamic_slice_in_dim(operand, index, 1, axis) if keepdims: @@ -1387,14 +1522,16 @@ def dynamic_index_in_dim(operand, index, axis=0, keepdims=True): return reshape(result, onp.delete(operand.shape, axis)) -def dynamic_update_slice_in_dim(operand, update, start_index, axis): +def dynamic_update_slice_in_dim(operand: Array, update: Array, + start_index: Array, axis: int) -> Array: axis = int(axis) start_indices = [_zero(start_index)] * _ndim(operand) start_indices[axis] = start_index return dynamic_update_slice(operand, update, start_indices) -def dynamic_update_index_in_dim(operand, update, index, axis): +def dynamic_update_index_in_dim(operand: Array, update: Array, index: Array, + axis: int) -> Array: axis = int(axis) if _ndim(update) != _ndim(operand): assert _ndim(update) + 1 == _ndim(operand) @@ -1403,7 +1540,8 @@ def dynamic_update_index_in_dim(operand, update, index, axis): return dynamic_update_slice_in_dim(operand, update, index, axis) -def batch_matmul(lhs, rhs, precision=None): +def batch_matmul(lhs: Array, rhs: Array, + precision: Optional[PrecisionType] = None) -> Array: """Batch matrix multiplication.""" if _min(lhs.ndim, rhs.ndim) < 2: raise ValueError('Arguments to batch_matmul must be at least 2D, got {}, {}' @@ -1414,18 +1552,18 @@ def batch_matmul(lhs, rhs, precision=None): lhs_contract = (lhs.ndim - 1,) rhs_contract = (rhs.ndim - 2,) batch = tuple(range(lhs.ndim - 2)) - return dot_general(lhs, rhs, [(lhs_contract, rhs_contract), (batch, batch)], + return dot_general(lhs, rhs, ((lhs_contract, rhs_contract), (batch, batch)), precision=precision) # These functions also exist in the XLA client library, but we treat them # as non-primitive to maintain a smaller set of autodiff primitives. -def square(x): +def square(x: Array) -> Array: r"""Elementwise square: :math:`x^2`.""" return mul(x, x) -def reciprocal(x): +def reciprocal(x: Array) -> Array: r"""Elementwise reciprocal: :math:`1 \over x`.""" return div(_const(x, 1), x) @@ -1442,18 +1580,18 @@ def f_wrapped(x): @api.jit @_upcast_fp16_for_computation -def tan(x): +def tan(x: Array) -> Array: r"""Elementwise tangent: :math:`\mathrm{tan}(x)`.""" return div(sin(x), cos(x)) @api.jit -def asin(x): +def asin(x: Array) -> Array: r"""Elementwise arc sine: :math:`\mathrm{asin}(x)`.""" return mul(_const(x, 2), atan2(x, add(_const(x, 1), sqrt(sub(_const(x, 1), square(x)))))) @api.jit -def acos(x): +def acos(x: Array) -> Array: r"""Elementwise arc cosine: :math:`\mathrm{acos}(x)`.""" return select( ne(x, _const(x, -1.0)), @@ -1461,27 +1599,27 @@ def acos(x): atan2(sqrt(sub(_const(x, 1), square(x))), add(_const(x, 1), x))), full_like(x, onp.pi)) -def atan(x): +def atan(x: Array) -> Array: r"""Elementwise arc tangent: :math:`\mathrm{atan}(x)`.""" return atan2(x, _const(x, 1)) -def sinh(x): +def sinh(x: Array) -> Array: r"""Elementwise hyperbolic sine: :math:`\mathrm{sinh}(x)`.""" return sinh_p.bind(x) -def cosh(x): +def cosh(x: Array) -> Array: r"""Elementwise hyperbolic cosine: :math:`\mathrm{cosh}(x)`.""" return cosh_p.bind(x) -def asinh(x): +def asinh(x: Array) -> Array: r"""Elementwise inverse hyperbolic sine: :math:`\mathrm{asinh}(x)`.""" return asinh_p.bind(x) -def acosh(x): +def acosh(x: Array) -> Array: r"""Elementwise inverse hyperbolic cosine: :math:`\mathrm{acosh}(x)`.""" return acosh_p.bind(x) -def atanh(x): +def atanh(x: Array) -> Array: r"""Elementwise inverse hyperbolic tangent: :math:`\mathrm{atanh}(x)`.""" return atanh_p.bind(x) @@ -3029,31 +3167,6 @@ def _dynamic_update_slice_batching_rule(batched_args, batch_dims): _dynamic_update_slice_batching_rule -class GatherDimensionNumbers(collections.namedtuple( - "GatherDimensionNumbers", - ["offset_dims", "collapsed_slice_dims", "start_index_map"])): - """ - Describes the dimension number arguments to an `XLA's Gather operator - `_. See the XLA - documentation for more details of what the dimension numbers mean. - - Args: - offset_dims: the set of dimensions in the `gather` output that offset into - an array sliced from `operand`. Must be a tuple of integers in ascending - order, each representing a dimension number of the output. - collapsed_slice_dims: the set of dimensions `i` in `operand` that have - `slice_sizes[i] == 1` and that should not have a corresponding dimension - in the output of the gather. Must be a tuple of integers in ascending - order. - start_index_map: for each dimension in `start_indices`, gives the - corresponding dimension in `operand` that is to be sliced. Must be a - tuple of integers with size equal to `start_indices.shape[-1]`. - - Unlike XLA's `GatherDimensionNumbers` structure, `index_vector_dim` is - implicit; there is always an index vector dimension and it must always be the - last dimension. To gather scalar indices, add a trailing dimension of size 1. - """ - def _gather_dimensions_proto(indices_shape, dimension_numbers): assert type(dimension_numbers) is GatherDimensionNumbers proto = xla_client.GatherDimensionNumbers() @@ -3168,32 +3281,6 @@ def _gather_batching_rule(batched_args, batch_dims, *, dimension_numbers, batching.primitive_batchers[gather_p] = _gather_batching_rule -class ScatterDimensionNumbers(collections.namedtuple( - "ScatterDimensionNumbers", - ["update_window_dims", "inserted_window_dims", - "scatter_dims_to_operand_dims"])): - """ - Describes the dimension number arguments to an `XLA's Scatter operator - `_. See the XLA - documentation for more details of what the dimension numbers mean. - - Args: - update_window_dims: the set of dimensions in the `updates` that are window - dimensions. Must be a tuple of integers in ascending - order, each representing a dimension number. - inserted_window_dims: the set of size 1 window dimensions that must be inserted - into the shape of `updates`. Must be a tuple of integers in ascending - order, each representing a dimension number of the output. These are the - mirror image of `collapsed_slice_dims` in the case of `gather`. - scatter_dims_to_operand_dims: for each dimension in `scatter_indices`, gives - the corresponding dimension in `operand`. Must be a sequence of integers - with size equal to indices.shape[-1]. - - Unlike XLA's `ScatterDimensionNumbers` structure, `index_vector_dim` is - implicit; there is always an index vector dimension and it must always be the - last dimension. To scatter scalar indices, add a trailing dimension of size 1. - """ - def _scatter_dimensions_proto(indices_shape, dimension_numbers): assert type(dimension_numbers) is ScatterDimensionNumbers proto = xla_client.ScatterDimensionNumbers() @@ -4653,23 +4740,6 @@ def _canonicalize_precision(precision): raise ValueError(msg.format(precision)) -# lhs_spec and out_spec are lists containing -# [batch dim, feature dim, spatial dims ...] -# rhs_spec is a list containing: -# [out feature dim, in feature dim, spatial dims ...] -class ConvDimensionNumbers(collections.namedtuple( - "ConvDimensionNumbers", ["lhs_spec", "rhs_spec", "out_spec"])): - """Describes batch, spatial, and feature dimensions of a convolution. - - Args: - lhs_spec: a tuple of nonnegative integer dimension numbers containing - `(batch dimension, feature dimension, spatial dimensions...)`. - rhs_spec: a tuple of nonnegative integer dimension numbers containing - `(out feature dimension, in feature dimension, spatial dimensions...)`. - out_spec: a tuple of nonnegative integer dimension numbers containing - `(batch dimension, feature dimension, spatial dimensions...)`. - """ - def conv_dimension_numbers(lhs_shape, rhs_shape, dimension_numbers): """Converts convolution `dimension_numbers` to a `ConvDimensionNumbers`.