diff --git a/nncf/experimental/tensor/functions.py b/nncf/experimental/tensor/functions.py
index 8697966b645..7bc5e083d1b 100644
--- a/nncf/experimental/tensor/functions.py
+++ b/nncf/experimental/tensor/functions.py
@@ -36,7 +36,7 @@ def wrapper(*args, **kwargs):

 @functools.singledispatch
 @_tensor_guard
-def device(a: TTensor) -> TensorDeviceType:
+def device(a: Tensor) -> TensorDeviceType:
     """
     Return the device of the tensor.

@@ -48,7 +48,7 @@ def device(a: TTensor) -> TensorDeviceType:

 @functools.singledispatch
 @_tensor_guard
-def squeeze(a: TTensor, axis: Optional[Union[int, Tuple[int]]] = None) -> TTensor:
+def squeeze(a: Tensor, axis: Optional[Union[int, Tuple[int]]] = None) -> Tensor:
     """
     Remove axes of length one from a.

@@ -63,7 +63,7 @@ def squeeze(a: TTensor, axis: Optional[Union[int, Tuple[int]]] = None) -> TTenso

 @functools.singledispatch
 @_tensor_guard
-def flatten(a: TTensor) -> TTensor:
+def flatten(a: Tensor) -> Tensor:
     """
     Return a copy of the tensor collapsed into one dimension.

@@ -75,7 +75,7 @@ def flatten(a: TTensor) -> TTensor:

 @functools.singledispatch
 @_tensor_guard
-def max(a: TTensor, axis: Optional[Union[int, Tuple[int]]] = None) -> TTensor: # pylint: disable=redefined-builtin
+def max(a: Tensor, axis: Optional[Union[int, Tuple[int]]] = None) -> Tensor: # pylint: disable=redefined-builtin
     """
     Return the maximum of an array or maximum along an axis.

@@ -88,7 +88,7 @@ def max(a: TTensor, axis: Optional[Union[int, Tuple[int]]] = None) -> TTensor:

 @functools.singledispatch
 @_tensor_guard
-def min(a: TTensor, axis: Optional[Union[int, Tuple[int]]] = None) -> TTensor: # pylint: disable=redefined-builtin
+def min(a: Tensor, axis: Optional[Union[int, Tuple[int]]] = None) -> Tensor: # pylint: disable=redefined-builtin
     """
     Return the minimum of an array or minimum along an axis.

@@ -101,7 +101,7 @@ def min(a: TTensor, axis: Optional[Union[int, Tuple[int]]] = None) -> TTensor:

 @functools.singledispatch
 @_tensor_guard
-def abs(a: TTensor) -> Tensor: # pylint: disable=redefined-builtin
+def abs(a: Tensor) -> Tensor: # pylint: disable=redefined-builtin
     """
     Calculate the absolute value element-wise.

@@ -113,7 +113,7 @@ def abs(a: TTensor) -> Tensor: # pylint: disable=redefined-builtin

 @functools.singledispatch
 @_tensor_guard
-def astype(a: TTensor, data_type: TensorDataType) -> TTensor:
+def astype(a: Tensor, data_type: TensorDataType) -> Tensor:
     """
     Copy of the tensor, cast to a specified type.

@@ -127,7 +127,7 @@ def astype(a: TTensor, data_type: TensorDataType) -> TTensor:

 @functools.singledispatch
 @_tensor_guard
-def dtype(a: TTensor) -> TensorDataType:
+def dtype(a: Tensor) -> TensorDataType:
     """
     Return data type of the tensor.

@@ -139,7 +139,7 @@ def dtype(a: TTensor) -> TensorDataType:

 @functools.singledispatch
 @_tensor_guard
-def reshape(a: TTensor, shape: List[int]) -> TTensor:
+def reshape(a: Tensor, shape: List[int]) -> Tensor:
     """
     Gives a new shape to a tensor without changing its data.

@@ -152,7 +152,7 @@ def reshape(a: TTensor, shape: List[int]) -> TTensor:

 @functools.singledispatch
 @_tensor_guard
-def all(a: TTensor, axis: Optional[Union[int, Tuple[int]]] = None) -> TTensor: # pylint: disable=redefined-builtin
+def all(a: Tensor, axis: Optional[Union[int, Tuple[int]]] = None) -> Tensor: # pylint: disable=redefined-builtin
     """
     Test whether all tensor elements along a given axis evaluate to True.

@@ -165,7 +165,7 @@ def all(a: TTensor, axis: Optional[Union[int, Tuple[int]]] = None) -> TTensor:

 @functools.singledispatch
 @_tensor_guard
-def allclose(a: TTensor, b: TTensor, rtol: float = 1e-05, atol: float = 1e-08, equal_nan: bool = False) -> TTensor:
+def allclose(a: Tensor, b: TTensor, rtol: float = 1e-05, atol: float = 1e-08, equal_nan: bool = False) -> Tensor:
     """
     Returns True if two arrays are element-wise equal within a tolerance.

@@ -191,7 +191,7 @@ def allclose(a: TTensor, b: TTensor, rtol: float = 1e-05, atol: float = 1e-08, e

 @functools.singledispatch
 @_tensor_guard
-def any(a: TTensor, axis: Optional[Union[int, Tuple[int]]] = None) -> TTensor: # pylint: disable=redefined-builtin
+def any(a: Tensor, axis: Optional[Union[int, Tuple[int]]] = None) -> Tensor: # pylint: disable=redefined-builtin
     """
     Test whether any tensor elements along a given axis evaluate to True.

@@ -204,7 +204,7 @@ def any(a: TTensor, axis: Optional[Union[int, Tuple[int]]] = None) -> TTensor:

 @functools.singledispatch
 @_tensor_guard
-def count_nonzero(a: TTensor, axis: Optional[Union[int, Tuple[int]]] = None) -> TTensor:
+def count_nonzero(a: Tensor, axis: Optional[Union[int, Tuple[int]]] = None) -> Tensor:
     """
     Counts the number of non-zero values in the tensor input.

@@ -218,7 +218,7 @@ def count_nonzero(a: TTensor, axis: Optional[Union[int, Tuple[int]]] = None) ->

 @functools.singledispatch
 @_tensor_guard
-def isempty(a: TTensor) -> bool:
+def isempty(a: Tensor) -> bool:
     """
     Return True if input tensor is empty.

@@ -230,7 +230,7 @@ def isempty(a: TTensor) -> bool:

 @functools.singledispatch
 @_tensor_guard
-def isclose(a: TTensor, b: TTensor, rtol: float = 1e-05, atol: float = 1e-08, equal_nan: bool = False) -> TTensor:
+def isclose(a: Tensor, b: TTensor, rtol: float = 1e-05, atol: float = 1e-08, equal_nan: bool = False) -> Tensor:
     """
     Returns a boolean array where two arrays are element-wise equal within a tolerance.

@@ -256,7 +256,7 @@ def isclose(a: TTensor, b: TTensor, rtol: float = 1e-05, atol: float = 1e-08, eq

 @functools.singledispatch
 @_tensor_guard
-def maximum(x1: TTensor, x2: TTensor) -> TTensor:
+def maximum(x1: Tensor, x2: TTensor) -> Tensor:
     """
     Element-wise maximum of tensor elements.

@@ -269,7 +269,7 @@ def maximum(x1: TTensor, x2: TTensor) -> TTensor:

 @functools.singledispatch
 @_tensor_guard
-def minimum(x1: TTensor, x2: TTensor) -> TTensor:
+def minimum(x1: Tensor, x2: TTensor) -> Tensor:
     """
     Element-wise minimum of tensor elements.

@@ -282,7 +282,7 @@ def minimum(x1: TTensor, x2: TTensor) -> TTensor:

 @functools.singledispatch
 @_tensor_guard
-def ones_like(a: TTensor) -> TTensor:
+def ones_like(a: Tensor) -> Tensor:
     """
     Return a tensor of ones with the same shape and type as a given tensor.

@@ -294,7 +294,7 @@ def ones_like(a: TTensor) -> TTensor:

 @functools.singledispatch
 @_tensor_guard
-def where(condition: TTensor, x: TTensor, y: TTensor) -> TTensor:
+def where(condition: Tensor, x: TTensor, y: TTensor) -> Tensor:
     """
     Return elements chosen from x or y depending on condition.

@@ -314,7 +314,7 @@ def where(condition: TTensor, x: TTensor, y: TTensor) -> TTensor:

 @functools.singledispatch
 @_tensor_guard
-def zeros_like(a: TTensor) -> TTensor:
+def zeros_like(a: Tensor) -> Tensor:
     """
     Return an tensor of zeros with the same shape and type as a given tensor.

@@ -325,7 +325,7 @@ def zeros_like(a: TTensor) -> TTensor:


 @functools.singledispatch
-def stack(x: List[TTensor], axis: int = 0) -> TTensor:
+def stack(x: List[Tensor], axis: int = 0) -> Tensor:
     """
     Stacks a list or deque of Tensors rank-R tensors into one Tensor rank-(R+1) tensor.

@@ -343,7 +343,7 @@ def stack(x: List[TTensor], axis: int = 0) -> TTensor:

 @functools.singledispatch
 @_tensor_guard
-def unstack(a: TTensor, axis: int = 0) -> List[TTensor]:
+def unstack(a: Tensor, axis: int = 0) -> List[TTensor]:
     """
     Unstack a Tensor into list.

@@ -357,7 +357,7 @@ def unstack(a: TTensor, axis: int = 0) -> List[TTensor]:

 @functools.singledispatch
 @_tensor_guard
-def moveaxis(a: TTensor, source: Union[int, List[int]], destination: Union[int, List[int]]) -> TTensor:
+def moveaxis(a: Tensor, source: Union[int, List[int]], destination: Union[int, List[int]]) -> Tensor:
     """
     Move axes of an array to new positions.

@@ -371,7 +371,7 @@ def moveaxis(a: TTensor, source: Union[int, List[int]], destination: Union[int,

 @functools.singledispatch
 @_tensor_guard
-def mean(a: TTensor, axis: Union[int, List[int]] = None, keepdims: bool = False) -> TTensor:
+def mean(a: Tensor, axis: Union[int, List[int]] = None, keepdims: bool = False) -> Tensor:
     """
     Compute the arithmetic mean along the specified axis.

@@ -385,7 +385,7 @@ def mean(a: TTensor, axis: Union[int, List[int]] = None, keepdims: bool = False)

 @functools.singledispatch
 @_tensor_guard
-def round(a: TTensor, decimals=0) -> TTensor: # pylint: disable=redefined-builtin
+def round(a: Tensor, decimals=0) -> Tensor: # pylint: disable=redefined-builtin
     """
     Evenly round to the given number of decimals.

@@ -399,7 +399,7 @@ def round(a: TTensor, decimals=0) -> TTensor: # pylint: disable=redefined-built

 @functools.singledispatch
 @_tensor_guard
-def binary_op_nowarn(a: TTensor, b: TTensor, operator_fn: Callable) -> TTensor:
+def binary_op_nowarn(a: Tensor, b: TTensor, operator_fn: Callable) -> Tensor:
     """
     Applies a binary operation to two tensors with disable warnings.

@@ -413,7 +413,7 @@ def binary_op_nowarn(a: TTensor, b: TTensor, operator_fn: Callable) -> TTensor:

 @functools.singledispatch
 @_tensor_guard
-def binary_reverse_op_nowarn(a: TTensor, b: TTensor, operator_fn: Callable) -> TTensor:
+def binary_reverse_op_nowarn(a: Tensor, b: TTensor, operator_fn: Callable) -> Tensor:
     """
     Applies a binary reverse operation to two tensors with disable warnings.

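The signatures above are `functools.singledispatch` entry points: the public API takes and returns the `Tensor` wrapper, while each backend registers its own implementation against its raw array type. Below is a minimal, self-contained sketch of that dispatch pattern; the `flatten` stub and the stand-in `Tensor` class are illustrative only and are not the NNCF modules touched by this diff.

```python
# Illustrative sketch of the singledispatch pattern behind functions.py; not NNCF code.
import functools
from typing import Any

import numpy as np


class Tensor:  # stand-in wrapper with a .data attribute, like the real class
    def __init__(self, data: Any):
        self.data = data


@functools.singledispatch
def flatten(a: Tensor) -> Tensor:
    # Generic entry point: unwrap the wrapper and re-dispatch on the backend type.
    if isinstance(a, Tensor):
        return Tensor(flatten(a.data))
    raise NotImplementedError(f"flatten is not implemented for {type(a)}")


@flatten.register(np.ndarray)
def _(a: np.ndarray) -> np.ndarray:
    # Backend implementation registered on the raw array type.
    return a.reshape(-1)


print(flatten(Tensor(np.ones((2, 3)))).data.shape)  # (6,)
```

The `@_tensor_guard` decorator seen in the hunks presumably plays the role of the `isinstance` check here, keeping the generic implementation restricted to `Tensor` inputs.
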
diff --git a/nncf/experimental/tensor/tensor.py b/nncf/experimental/tensor/tensor.py
index 8e9023de66e..b199663e86e 100644
--- a/nncf/experimental/tensor/tensor.py
+++ b/nncf/experimental/tensor/tensor.py
@@ -8,7 +8,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+from __future__ import annotations
 import operator
 from typing import Any, List, Optional, Tuple, TypeVar, Union

@@ -53,7 +53,7 @@ def __bool__(self) -> bool:
     def __iter__(self):
         return TensorIterator(self.data)

-    def __getitem__(self, index: int) -> "Tensor":
+    def __getitem__(self, index: int) -> Tensor:
         return Tensor(self.data[index])

     def __str__(self) -> str:
@@ -64,77 +64,77 @@ def __repr__(self) -> str:

     # built-in operations

-    def __add__(self, other: TTensor) -> "Tensor":
+    def __add__(self, other: TTensor) -> Tensor:
         return Tensor(self.data + unwrap_tensor_data(other))

-    def __radd__(self, other: TTensor) -> "Tensor":
+    def __radd__(self, other: TTensor) -> Tensor:
         return Tensor(unwrap_tensor_data(other) + self.data)

-    def __sub__(self, other: TTensor) -> "Tensor":
+    def __sub__(self, other: TTensor) -> Tensor:
         return Tensor(self.data - unwrap_tensor_data(other))

-    def __rsub__(self, other: TTensor) -> "Tensor":
+    def __rsub__(self, other: TTensor) -> Tensor:
         return Tensor(unwrap_tensor_data(other) - self.data)

-    def __mul__(self, other: TTensor) -> "Tensor":
+    def __mul__(self, other: TTensor) -> Tensor:
         return Tensor(self.data * unwrap_tensor_data(other))

-    def __rmul__(self, other: TTensor) -> "Tensor":
+    def __rmul__(self, other: TTensor) -> Tensor:
         return Tensor(unwrap_tensor_data(other) * self.data)

-    def __pow__(self, other: TTensor) -> "Tensor":
+    def __pow__(self, other: TTensor) -> Tensor:
         return Tensor(self.data ** unwrap_tensor_data(other))

-    def __truediv__(self, other: TTensor) -> "Tensor":
-        return _call_function("_binary_op_nowarn", self, other, operator.truediv)
+    def __truediv__(self, other: TTensor) -> Tensor:
+        return _call_function("binary_op_nowarn", self, other, operator.truediv)

-    def __rtruediv__(self, other: TTensor) -> "Tensor":
-        return _call_function("_binary_reverse_op_nowarn", self, other, operator.truediv)
+    def __rtruediv__(self, other: TTensor) -> Tensor:
+        return _call_function("binary_reverse_op_nowarn", self, other, operator.truediv)

-    def __floordiv__(self, other: TTensor) -> "Tensor":
-        return _call_function("_binary_op_nowarn", self, other, operator.floordiv)
+    def __floordiv__(self, other: TTensor) -> Tensor:
+        return _call_function("binary_op_nowarn", self, other, operator.floordiv)

-    def __rfloordiv__(self, other: TTensor) -> "Tensor":
-        return _call_function("_binary_reverse_op_nowarn", self, other, operator.floordiv)
+    def __rfloordiv__(self, other: TTensor) -> Tensor:
+        return _call_function("binary_reverse_op_nowarn", self, other, operator.floordiv)

-    def __neg__(self) -> "Tensor":
+    def __neg__(self) -> Tensor:
         return Tensor(-self.data)

     # Comparison operators

-    def __lt__(self, other: TTensor) -> "Tensor":
+    def __lt__(self, other: TTensor) -> Tensor:
         return Tensor(self.data < unwrap_tensor_data(other))

-    def __le__(self, other: TTensor) -> "Tensor":
+    def __le__(self, other: TTensor) -> Tensor:
         return Tensor(self.data <= unwrap_tensor_data(other))

-    def __eq__(self, other: TTensor) -> "Tensor":
+    def __eq__(self, other: TTensor) -> Tensor:
         return Tensor(self.data == unwrap_tensor_data(other))

-    def __ne__(self, other: TTensor) -> "Tensor":
+    def __ne__(self, other: TTensor) -> Tensor:
         return Tensor(self.data != unwrap_tensor_data(other))

-    def __gt__(self, other: TTensor) -> "Tensor":
+    def __gt__(self, other: TTensor) -> Tensor:
         return Tensor(self.data > unwrap_tensor_data(other))

-    def __ge__(self, other: TTensor) -> "Tensor":
+    def __ge__(self, other: TTensor) -> Tensor:
         return Tensor(self.data >= unwrap_tensor_data(other))

     # Tensor functions

-    def squeeze(self, axis: Optional[Union[int, Tuple[int]]] = None) -> "Tensor":
+    def squeeze(self, axis: Optional[Union[int, Tuple[int]]] = None) -> Tensor:
         return _call_function("squeeze", self, axis)

-    def flatten(self) -> "Tensor":
+    def flatten(self) -> Tensor:
         return _call_function("flatten", self)

-    def max(self, axis: Optional[TTensor] = None) -> "Tensor":
+    def max(self, axis: Optional[TTensor] = None) -> Tensor:
         return _call_function("max", self, axis)

-    def min(self, axis: Optional[TTensor] = None) -> "Tensor":
+    def min(self, axis: Optional[TTensor] = None) -> Tensor:
         return _call_function("min", self, axis)

-    def abs(self) -> "Tensor":
+    def abs(self) -> Tensor:
         return _call_function("abs", self)

     def isempty(self) -> bool:
@@ -143,7 +143,7 @@ def isempty(self) -> bool:
     def astype(self, dtype: TensorDataType):
         return _call_function("astype", self, dtype)

-    def reshape(self, shape: List) -> "Tensor":
+    def reshape(self, shape: List) -> Tensor:
         return _call_function("reshape", self, shape)
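For context on the `from __future__ import annotations` hunk in `tensor.py`: with postponed evaluation of annotations (PEP 563), a method may annotate its own enclosing class without quoting the name, which is what lets the `"Tensor"` string annotations be replaced by plain `Tensor` above. A small, self-contained illustration (the `Box` class is hypothetical, not NNCF code):

```python
from __future__ import annotations


class Box:
    """Toy wrapper used only to show unquoted self-referencing annotations."""

    def __init__(self, value: float):
        self.value = value

    # Without the __future__ import, this annotation would have to be the string "Box",
    # because the class name is not bound yet while the class body executes.
    def __add__(self, other: Box) -> Box:
        other_value = other.value if isinstance(other, Box) else other
        return Box(self.value + other_value)


print((Box(1.0) + Box(2.0)).value)  # 3.0
```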