diff --git a/nncf/experimental/tensor/functions.py b/nncf/experimental/tensor/functions.py index 7bc5e083d1b..c29fd0e50a1 100644 --- a/nncf/experimental/tensor/functions.py +++ b/nncf/experimental/tensor/functions.py @@ -48,7 +48,7 @@ def device(a: Tensor) -> TensorDeviceType: @functools.singledispatch @_tensor_guard -def squeeze(a: Tensor, axis: Optional[Union[int, Tuple[int]]] = None) -> Tensor: +def squeeze(a: Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Tensor: """ Remove axes of length one from a. @@ -75,7 +75,7 @@ def flatten(a: Tensor) -> Tensor: @functools.singledispatch @_tensor_guard -def max(a: Tensor, axis: Optional[Union[int, Tuple[int]]] = None) -> Tensor: # pylint: disable=redefined-builtin +def max(a: Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Tensor: # pylint: disable=redefined-builtin """ Return the maximum of an array or maximum along an axis. @@ -88,7 +88,7 @@ def max(a: Tensor, axis: Optional[Union[int, Tuple[int]]] = None) -> Tensor: # @functools.singledispatch @_tensor_guard -def min(a: Tensor, axis: Optional[Union[int, Tuple[int]]] = None) -> Tensor: # pylint: disable=redefined-builtin +def min(a: Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Tensor: # pylint: disable=redefined-builtin """ Return the minimum of an array or minimum along an axis. @@ -139,7 +139,7 @@ def dtype(a: Tensor) -> TensorDataType: @functools.singledispatch @_tensor_guard -def reshape(a: Tensor, shape: List[int]) -> Tensor: +def reshape(a: Tensor, shape: Tuple[int, ...]) -> Tensor: """ Gives a new shape to a tensor without changing its data. @@ -152,7 +152,7 @@ def reshape(a: Tensor, shape: List[int]) -> Tensor: @functools.singledispatch @_tensor_guard -def all(a: Tensor, axis: Optional[Union[int, Tuple[int]]] = None) -> Tensor: # pylint: disable=redefined-builtin +def all(a: Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Tensor: # pylint: disable=redefined-builtin """ Test whether all tensor elements along a given axis evaluate to True. @@ -191,7 +191,7 @@ def allclose(a: Tensor, b: TTensor, rtol: float = 1e-05, atol: float = 1e-08, eq @functools.singledispatch @_tensor_guard -def any(a: Tensor, axis: Optional[Union[int, Tuple[int]]] = None) -> Tensor: # pylint: disable=redefined-builtin +def any(a: Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Tensor: # pylint: disable=redefined-builtin """ Test whether any tensor elements along a given axis evaluate to True. @@ -204,7 +204,7 @@ def any(a: Tensor, axis: Optional[Union[int, Tuple[int]]] = None) -> Tensor: # @functools.singledispatch @_tensor_guard -def count_nonzero(a: Tensor, axis: Optional[Union[int, Tuple[int]]] = None) -> Tensor: +def count_nonzero(a: Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Tensor: """ Counts the number of non-zero values in the tensor input. @@ -357,7 +357,7 @@ def unstack(a: Tensor, axis: int = 0) -> List[TTensor]: @functools.singledispatch @_tensor_guard -def moveaxis(a: Tensor, source: Union[int, List[int]], destination: Union[int, List[int]]) -> Tensor: +def moveaxis(a: Tensor, source: Union[int, Tuple[int, ...]], destination: Union[int, Tuple[int, ...]]) -> Tensor: """ Move axes of an array to new positions. @@ -371,7 +371,7 @@ def moveaxis(a: Tensor, source: Union[int, List[int]], destination: Union[int, L @functools.singledispatch @_tensor_guard -def mean(a: Tensor, axis: Union[int, List[int]] = None, keepdims: bool = False) -> Tensor: +def mean(a: Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> Tensor: """ Compute the arithmetic mean along the specified axis. @@ -399,30 +399,30 @@ def round(a: Tensor, decimals=0) -> Tensor: # pylint: disable=redefined-builtin @functools.singledispatch @_tensor_guard -def binary_op_nowarn(a: Tensor, b: TTensor, operator_fn: Callable) -> Tensor: +def _binary_op_nowarn(a: Tensor, b: TTensor, operator_fn: Callable) -> Tensor: """ - Applies a binary operation to two tensors with disable warnings. + Applies a binary operation with disable warnings. :param a: The first tensor. :param b: The second tensor. :param operator_fn: The binary operation function. :return: The result of the binary operation. """ - return Tensor(binary_op_nowarn(a.data, unwrap_tensor_data(b), operator_fn)) + return Tensor(_binary_op_nowarn(a.data, unwrap_tensor_data(b), operator_fn)) @functools.singledispatch @_tensor_guard -def binary_reverse_op_nowarn(a: Tensor, b: TTensor, operator_fn: Callable) -> Tensor: +def _binary_reverse_op_nowarn(a: Tensor, b: TTensor, operator_fn: Callable) -> Tensor: """ - Applies a binary reverse operation to two tensors with disable warnings. + Applies a binary reverse operation with disable warnings. :param a: The first tensor. :param b: The second tensor. :param operator_fn: The binary operation function. :return: The result of the binary operation. """ - return Tensor(binary_reverse_op_nowarn(a.data, unwrap_tensor_data(b), operator_fn)) + return Tensor(_binary_reverse_op_nowarn(a.data, unwrap_tensor_data(b), operator_fn)) def _initialize_backends(): diff --git a/nncf/experimental/tensor/numpy_functions.py b/nncf/experimental/tensor/numpy_functions.py index 9959ab5dcd4..0abac90d413 100644 --- a/nncf/experimental/tensor/numpy_functions.py +++ b/nncf/experimental/tensor/numpy_functions.py @@ -43,63 +43,64 @@ def inner(func): return inner -NUMPY_TYPES = Union[np.ndarray, np.generic] - - @_register_numpy_types(fns.device) -def _(a: NUMPY_TYPES) -> TensorDeviceType: +def _(a: Union[np.ndarray, np.generic]) -> TensorDeviceType: return TensorDeviceType.CPU @_register_numpy_types(fns.squeeze) -def _(a: NUMPY_TYPES, axis: Optional[Union[int, Tuple[int]]] = None) -> NUMPY_TYPES: +def _( + a: Union[np.ndarray, np.generic], axis: Optional[Union[int, Tuple[int, ...]]] = None +) -> Union[np.ndarray, np.generic]: return np.squeeze(a, axis=axis) @_register_numpy_types(fns.flatten) -def _(a: NUMPY_TYPES) -> np.ndarray: +def _(a: Union[np.ndarray, np.generic]) -> np.ndarray: return a.flatten() @_register_numpy_types(fns.max) -def _(a: NUMPY_TYPES, axis: Optional[Union[int, Tuple[int]]] = None) -> np.ndarray: +def _(a: Union[np.ndarray, np.generic], axis: Optional[Union[int, Tuple[int, ...]]] = None) -> np.ndarray: return np.max(a, axis=axis) @_register_numpy_types(fns.min) -def _(a: NUMPY_TYPES, axis: Optional[Union[int, Tuple[int]]] = None) -> NUMPY_TYPES: +def _( + a: Union[np.ndarray, np.generic], axis: Optional[Union[int, Tuple[int, ...]]] = None +) -> Union[np.ndarray, np.generic]: return np.min(a, axis=axis) @_register_numpy_types(fns.abs) -def _(a: NUMPY_TYPES) -> NUMPY_TYPES: +def _(a: Union[np.ndarray, np.generic]) -> Union[np.ndarray, np.generic]: return np.absolute(a) @_register_numpy_types(fns.astype) -def _(a: NUMPY_TYPES, dtype: TensorDataType) -> NUMPY_TYPES: +def _(a: Union[np.ndarray, np.generic], dtype: TensorDataType) -> Union[np.ndarray, np.generic]: return a.astype(DTYPE_MAP[dtype]) @_register_numpy_types(fns.dtype) -def _(a: NUMPY_TYPES) -> TensorDataType: +def _(a: Union[np.ndarray, np.generic]) -> TensorDataType: return DTYPE_MAP_REV[np.dtype(a.dtype)] @_register_numpy_types(fns.reshape) -def _(a: NUMPY_TYPES, shape: Union[int, Tuple[int]]) -> np.ndarray: +def _(a: Union[np.ndarray, np.generic], shape: Union[int, Tuple[int, ...]]) -> np.ndarray: return a.reshape(shape) @_register_numpy_types(fns.all) -def _(a: NUMPY_TYPES, axis: Optional[Union[int, Tuple[int]]] = None) -> Union[np.ndarray, bool]: +def _(a: Union[np.ndarray, np.generic], axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Union[np.ndarray, bool]: return np.all(a, axis=axis) @_register_numpy_types(fns.allclose) def _( - a: NUMPY_TYPES, - b: NUMPY_TYPES, + a: Union[np.ndarray, np.generic], + b: Union[np.ndarray, np.generic], rtol: float = 1e-05, atol: float = 1e-08, equal_nan: bool = False, @@ -108,24 +109,24 @@ def _( @_register_numpy_types(fns.any) -def _(a: NUMPY_TYPES, axis: Optional[Union[int, Tuple[int]]] = None) -> Union[np.ndarray, bool]: +def _(a: Union[np.ndarray, np.generic], axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Union[np.ndarray, bool]: return np.any(a, axis=axis) @_register_numpy_types(fns.count_nonzero) -def _(a: NUMPY_TYPES, axis: Optional[Union[int, Tuple[int]]] = None) -> np.ndarray: +def _(a: Union[np.ndarray, np.generic], axis: Optional[Union[int, Tuple[int, ...]]] = None) -> np.ndarray: return np.array(np.count_nonzero(a, axis=axis)) @_register_numpy_types(fns.isempty) -def _(a: NUMPY_TYPES) -> bool: +def _(a: Union[np.ndarray, np.generic]) -> bool: return a.size == 0 @_register_numpy_types(fns.isclose) def _( - a: NUMPY_TYPES, - b: NUMPY_TYPES, + a: Union[np.ndarray, np.generic], + b: Union[np.ndarray, np.generic], rtol: float = 1e-05, atol: float = 1e-08, equal_nan: bool = False, @@ -134,23 +135,23 @@ def _( @_register_numpy_types(fns.maximum) -def _(x1: NUMPY_TYPES, x2: NUMPY_TYPES) -> np.ndarray: +def _(x1: Union[np.ndarray, np.generic], x2: Union[np.ndarray, np.generic]) -> np.ndarray: return np.maximum(x1, x2) @_register_numpy_types(fns.minimum) -def _(x1: NUMPY_TYPES, x2: NUMPY_TYPES) -> np.ndarray: +def _(x1: Union[np.ndarray, np.generic], x2: Union[np.ndarray, np.generic]) -> np.ndarray: return np.minimum(x1, x2) @_register_numpy_types(fns.ones_like) -def _(a: NUMPY_TYPES) -> np.ndarray: +def _(a: Union[np.ndarray, np.generic]) -> np.ndarray: return np.ones_like(a) @_register_numpy_types(fns.where) def _( - condition: NUMPY_TYPES, + condition: Union[np.ndarray, np.generic], x: Union[np.ndarray, np.number, float, bool], y: Union[np.ndarray, float, bool], ) -> np.ndarray: @@ -158,44 +159,48 @@ def _( @_register_numpy_types(fns.zeros_like) -def _(a: NUMPY_TYPES) -> np.ndarray: +def _(a: Union[np.ndarray, np.generic]) -> np.ndarray: return np.zeros_like(a) @_register_numpy_types(fns.stack) -def _(x: NUMPY_TYPES, axis: int = 0) -> List[np.ndarray]: +def _(x: Union[np.ndarray, np.generic], axis: int = 0) -> List[np.ndarray]: return np.stack(x, axis=axis) @_register_numpy_types(fns.unstack) -def _(x: NUMPY_TYPES, axis: int = 0) -> List[np.ndarray]: +def _(x: Union[np.ndarray, np.generic], axis: int = 0) -> List[np.ndarray]: return [np.squeeze(e, axis) for e in np.split(x, x.shape[axis], axis=axis)] @_register_numpy_types(fns.moveaxis) -def _(a: np.ndarray, source: Union[int, List[int]], destination: Union[int, List[int]]) -> np.ndarray: +def _(a: np.ndarray, source: Union[int, Tuple[int, ...]], destination: Union[int, Tuple[int, ...]]) -> np.ndarray: return np.moveaxis(a, source, destination) @_register_numpy_types(fns.mean) -def _(a: NUMPY_TYPES, axis: Union[int, List[int]] = None, keepdims: bool = False) -> np.ndarray: +def _(a: Union[np.ndarray, np.generic], axis: Union[int, Tuple[int, ...]] = None, keepdims: bool = False) -> np.ndarray: return np.mean(a, axis=axis, keepdims=keepdims) @_register_numpy_types(fns.round) -def _(a: NUMPY_TYPES, decimals: int = 0) -> np.ndarray: +def _(a: Union[np.ndarray, np.generic], decimals: int = 0) -> np.ndarray: return np.round(a, decimals=decimals) -@_register_numpy_types(fns.binary_op_nowarn) -def _(a: NUMPY_TYPES, b: NUMPY_TYPES, operator_fn: Callable) -> NUMPY_TYPES: +@_register_numpy_types(fns._binary_op_nowarn) +def _( + a: Union[np.ndarray, np.generic], b: Union[np.ndarray, np.generic], operator_fn: Callable +) -> Union[np.ndarray, np.generic]: # Run operator with disabled warning with np.errstate(invalid="ignore", divide="ignore"): return operator_fn(a, b) -@_register_numpy_types(fns.binary_reverse_op_nowarn) -def _(a: NUMPY_TYPES, b: NUMPY_TYPES, operator_fn: Callable) -> NUMPY_TYPES: +@_register_numpy_types(fns._binary_reverse_op_nowarn) +def _( + a: Union[np.ndarray, np.generic], b: Union[np.ndarray, np.generic], operator_fn: Callable +) -> Union[np.ndarray, np.generic]: # Run operator with disabled warning with np.errstate(invalid="ignore", divide="ignore"): return operator_fn(b, a) diff --git a/nncf/experimental/tensor/tensor.py b/nncf/experimental/tensor/tensor.py index b199663e86e..e4df5c685c7 100644 --- a/nncf/experimental/tensor/tensor.py +++ b/nncf/experimental/tensor/tensor.py @@ -32,7 +32,7 @@ def data(self) -> TTensor: return self._data @property - def shape(self) -> List[int]: + def shape(self) -> Tuple[int, ...]: return tuple(self.data.shape) @property @@ -86,16 +86,16 @@ def __pow__(self, other: TTensor) -> Tensor: return Tensor(self.data ** unwrap_tensor_data(other)) def __truediv__(self, other: TTensor) -> Tensor: - return _call_function("binary_op_nowarn", self, other, operator.truediv) + return _call_function("_binary_op_nowarn", self, other, operator.truediv) def __rtruediv__(self, other: TTensor) -> Tensor: - return _call_function("binary_reverse_op_nowarn", self, other, operator.truediv) + return _call_function("_binary_reverse_op_nowarn", self, other, operator.truediv) def __floordiv__(self, other: TTensor) -> Tensor: - return _call_function("binary_op_nowarn", self, other, operator.floordiv) + return _call_function("_binary_op_nowarn", self, other, operator.floordiv) def __rfloordiv__(self, other: TTensor) -> Tensor: - return _call_function("binary_reverse_op_nowarn", self, other, operator.floordiv) + return _call_function("_binary_reverse_op_nowarn", self, other, operator.floordiv) def __neg__(self) -> Tensor: return Tensor(-self.data) @@ -122,7 +122,7 @@ def __ge__(self, other: TTensor) -> Tensor: # Tensor functions - def squeeze(self, axis: Optional[Union[int, Tuple[int]]] = None) -> Tensor: + def squeeze(self, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Tensor: return _call_function("squeeze", self, axis) def flatten(self) -> Tensor: @@ -143,7 +143,7 @@ def isempty(self) -> bool: def astype(self, dtype: TensorDataType): return _call_function("astype", self, dtype) - def reshape(self, shape: List) -> Tensor: + def reshape(self, shape: Tuple[int, ...]) -> Tensor: return _call_function("reshape", self, shape) diff --git a/nncf/experimental/tensor/torch_functions.py b/nncf/experimental/tensor/torch_functions.py index 0bd7b64a73a..91a5b0b9a40 100644 --- a/nncf/experimental/tensor/torch_functions.py +++ b/nncf/experimental/tensor/torch_functions.py @@ -38,9 +38,12 @@ def _(a: torch.Tensor) -> TensorDeviceType: @fns.squeeze.register(torch.Tensor) -def _(a: torch.Tensor, axis: Optional[Union[int, Tuple[int]]] = None) -> torch.Tensor: +def _(a: torch.Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> torch.Tensor: if axis is None: return a.squeeze() + if isinstance(axis, Tuple) and any([1 != a.shape[i] for i in axis]): + # Make Numpy behavior, torch.squeeze skips axes that are not equal to one.. + raise ValueError("Cannot select an axis to squeeze out which has size not equal to one") return a.squeeze(axis) @@ -50,7 +53,7 @@ def _(a: torch.Tensor) -> torch.Tensor: @fns.max.register(torch.Tensor) -def _(a: torch.Tensor, axis: Optional[Union[int, Tuple[int]]] = None) -> torch.Tensor: +def _(a: torch.Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> torch.Tensor: # Analog of numpy.max is torch.amax if axis is None: return torch.amax(a) @@ -58,7 +61,7 @@ def _(a: torch.Tensor, axis: Optional[Union[int, Tuple[int]]] = None) -> torch.T @fns.min.register(torch.Tensor) -def _(a: torch.Tensor, axis: Optional[Union[int, Tuple[int]]] = None) -> torch.Tensor: +def _(a: torch.Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> torch.Tensor: # Analog of numpy.min is torch.amin if axis is None: return torch.amin(a) @@ -81,12 +84,12 @@ def _(a: torch.Tensor) -> TensorDataType: @fns.reshape.register(torch.Tensor) -def _(a: torch.Tensor, shape: List[int]) -> torch.Tensor: +def _(a: torch.Tensor, shape: Tuple[int, ...]) -> torch.Tensor: return a.reshape(shape) @fns.all.register(torch.Tensor) -def _(a: torch.Tensor, axis: Optional[Union[int, Tuple[int]]] = None) -> Union[torch.Tensor, bool]: +def _(a: torch.Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Union[torch.Tensor, bool]: if axis is None: return torch.all(a) return torch.all(a, dim=axis) @@ -100,14 +103,14 @@ def _(a: torch.Tensor, b: torch.Tensor, rtol: float = 1e-05, atol: float = 1e-08 @fns.any.register(torch.Tensor) -def _(a: torch.Tensor, axis: Optional[Union[int, Tuple[int]]] = None) -> Union[torch.Tensor, bool]: +def _(a: torch.Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Union[torch.Tensor, bool]: if axis is None: return torch.any(a) return torch.any(a, dim=axis) @fns.count_nonzero.register(torch.Tensor) -def _(a: torch.Tensor, axis: Optional[Union[int, Tuple[int]]] = None) -> torch.Tensor: +def _(a: torch.Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> torch.Tensor: return torch.count_nonzero(a, dim=axis) @@ -167,12 +170,12 @@ def _(x: torch.Tensor, axis: int = 0) -> List[torch.Tensor]: @fns.moveaxis.register(torch.Tensor) -def _(a: torch.Tensor, source: Union[int, List[int]], destination: Union[int, List[int]]) -> torch.Tensor: +def _(a: torch.Tensor, source: Union[int, Tuple[int, ...]], destination: Union[int, Tuple[int, ...]]) -> torch.Tensor: return torch.moveaxis(a, source, destination) @fns.mean.register(torch.Tensor) -def _(a: torch.Tensor, axis: Union[int, List[int]] = None, keepdims: bool = False) -> torch.Tensor: +def _(a: torch.Tensor, axis: Union[int, Tuple[int, ...]] = None, keepdims: bool = False) -> torch.Tensor: return torch.mean(a, axis=axis, keepdims=keepdims) @@ -181,11 +184,11 @@ def _(a: torch.Tensor, decimals=0) -> torch.Tensor: return torch.round(a, decimals=decimals) -@fns.binary_op_nowarn.register(torch.Tensor) +@fns._binary_op_nowarn.register(torch.Tensor) def _(a: torch.Tensor, b: torch.Tensor, operator_fn: Callable) -> torch.Tensor: return operator_fn(a, b) -@fns.binary_reverse_op_nowarn.register(torch.Tensor) +@fns._binary_reverse_op_nowarn.register(torch.Tensor) def _(a: torch.Tensor, b: torch.Tensor, operator_fn: Callable) -> torch.Tensor: return operator_fn(b, a) diff --git a/tests/shared/test_templates/template_test_nncf_tensor.py b/tests/shared/test_templates/template_test_nncf_tensor.py index 2212daf4081..2b982380c26 100644 --- a/tests/shared/test_templates/template_test_nncf_tensor.py +++ b/tests/shared/test_templates/template_test_nncf_tensor.py @@ -151,6 +151,7 @@ def test_comparison_int_rev(self, op_name): ([[[[1], [2]], [[1], [2]]]], None, [[1, 2], [1, 2]]), ([[[[1], [2]], [[1], [2]]]], 0, [[[1], [2]], [[1], [2]]]), ([[[[1], [2]], [[1], [2]]]], -1, [[[1, 2], [1, 2]]]), + ([[[[1], [2]], [[1], [2]]]], (0, 3), [[1, 2], [1, 2]]), ), ) def test_squeeze(self, val, axis, ref): @@ -161,6 +162,20 @@ def test_squeeze(self, val, axis, ref): assert isinstance(res, Tensor) assert fns.allclose(res, ref_tensor) + @pytest.mark.parametrize( + "val, axis, exception_type, exception_match", + ( + ([[[[1], [2]], [[1], [2]]]], (0, 1), ValueError, "not equal to one"), + ([[[[1], [2]], [[1], [2]]]], 42, IndexError, "out of"), + ([[[[1], [2]], [[1], [2]]]], (0, 42), IndexError, "out of"), + ), + ) + def test_squeeze_axis_error(self, val, axis, exception_type, exception_match): + tensor = self.to_tensor(val) + nncf_tensor = Tensor(tensor) + with pytest.raises(exception_type, match=exception_match): + nncf_tensor.squeeze(axis=axis) + @pytest.mark.parametrize( "val, axis, ref", ( @@ -471,12 +486,12 @@ def test_fn_astype(self): def test_reshape(self): tensor = Tensor(self.to_tensor([1, 1])) assert tensor.shape == (2,) - assert tensor.reshape([1, 2]).shape == (1, 2) + assert tensor.reshape((1, 2)).shape == (1, 2) def test_fn_reshape(self): tensor = Tensor(self.to_tensor([1, 1])) assert tensor.shape == (2,) - assert fns.reshape(tensor, [1, 2]).shape == (1, 2) + assert fns.reshape(tensor, (1, 2)).shape == (1, 2) def test_not_implemented(self): with pytest.raises(NotImplementedError, match="is not implemented for"):