Skip to content

Commit

Permalink
typehints
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexanderDokuchaev committed Sep 26, 2023
1 parent 55a0d86 commit cc8d0ed
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 57 deletions.
54 changes: 27 additions & 27 deletions nncf/experimental/tensor/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def wrapper(*args, **kwargs):

@functools.singledispatch
@_tensor_guard
def device(a: TTensor) -> TensorDeviceType:
def device(a: Tensor) -> TensorDeviceType:
"""
Return the device of the tensor.
Expand All @@ -48,7 +48,7 @@ def device(a: TTensor) -> TensorDeviceType:

@functools.singledispatch
@_tensor_guard
def squeeze(a: TTensor, axis: Optional[Union[int, Tuple[int]]] = None) -> TTensor:
def squeeze(a: Tensor, axis: Optional[Union[int, Tuple[int]]] = None) -> Tensor:
"""
Remove axes of length one from a.
Expand All @@ -63,7 +63,7 @@ def squeeze(a: TTensor, axis: Optional[Union[int, Tuple[int]]] = None) -> TTenso

@functools.singledispatch
@_tensor_guard
def flatten(a: TTensor) -> TTensor:
def flatten(a: Tensor) -> Tensor:
"""
Return a copy of the tensor collapsed into one dimension.
Expand All @@ -75,7 +75,7 @@ def flatten(a: TTensor) -> TTensor:

@functools.singledispatch
@_tensor_guard
def max(a: TTensor, axis: Optional[Union[int, Tuple[int]]] = None) -> TTensor: # pylint: disable=redefined-builtin
def max(a: Tensor, axis: Optional[Union[int, Tuple[int]]] = None) -> Tensor: # pylint: disable=redefined-builtin
"""
Return the maximum of an array or maximum along an axis.
Expand All @@ -88,7 +88,7 @@ def max(a: TTensor, axis: Optional[Union[int, Tuple[int]]] = None) -> TTensor:

@functools.singledispatch
@_tensor_guard
def min(a: TTensor, axis: Optional[Union[int, Tuple[int]]] = None) -> TTensor: # pylint: disable=redefined-builtin
def min(a: Tensor, axis: Optional[Union[int, Tuple[int]]] = None) -> Tensor: # pylint: disable=redefined-builtin
"""
Return the minimum of an array or minimum along an axis.
Expand All @@ -101,7 +101,7 @@ def min(a: TTensor, axis: Optional[Union[int, Tuple[int]]] = None) -> TTensor:

@functools.singledispatch
@_tensor_guard
def abs(a: TTensor) -> Tensor: # pylint: disable=redefined-builtin
def abs(a: Tensor) -> Tensor: # pylint: disable=redefined-builtin
"""
Calculate the absolute value element-wise.
Expand All @@ -113,7 +113,7 @@ def abs(a: TTensor) -> Tensor: # pylint: disable=redefined-builtin

@functools.singledispatch
@_tensor_guard
def astype(a: TTensor, data_type: TensorDataType) -> TTensor:
def astype(a: Tensor, data_type: TensorDataType) -> Tensor:
"""
Copy of the tensor, cast to a specified type.
Expand All @@ -127,7 +127,7 @@ def astype(a: TTensor, data_type: TensorDataType) -> TTensor:

@functools.singledispatch
@_tensor_guard
def dtype(a: TTensor) -> TensorDataType:
def dtype(a: Tensor) -> TensorDataType:
"""
Return data type of the tensor.
Expand All @@ -139,7 +139,7 @@ def dtype(a: TTensor) -> TensorDataType:

@functools.singledispatch
@_tensor_guard
def reshape(a: TTensor, shape: List[int]) -> TTensor:
def reshape(a: Tensor, shape: List[int]) -> Tensor:
"""
Gives a new shape to a tensor without changing its data.
Expand All @@ -152,7 +152,7 @@ def reshape(a: TTensor, shape: List[int]) -> TTensor:

@functools.singledispatch
@_tensor_guard
def all(a: TTensor, axis: Optional[Union[int, Tuple[int]]] = None) -> TTensor: # pylint: disable=redefined-builtin
def all(a: Tensor, axis: Optional[Union[int, Tuple[int]]] = None) -> Tensor: # pylint: disable=redefined-builtin
"""
Test whether all tensor elements along a given axis evaluate to True.
Expand All @@ -165,7 +165,7 @@ def all(a: TTensor, axis: Optional[Union[int, Tuple[int]]] = None) -> TTensor:

@functools.singledispatch
@_tensor_guard
def allclose(a: TTensor, b: TTensor, rtol: float = 1e-05, atol: float = 1e-08, equal_nan: bool = False) -> TTensor:
def allclose(a: Tensor, b: TTensor, rtol: float = 1e-05, atol: float = 1e-08, equal_nan: bool = False) -> Tensor:
"""
Returns True if two arrays are element-wise equal within a tolerance.
Expand All @@ -191,7 +191,7 @@ def allclose(a: TTensor, b: TTensor, rtol: float = 1e-05, atol: float = 1e-08, e

@functools.singledispatch
@_tensor_guard
def any(a: TTensor, axis: Optional[Union[int, Tuple[int]]] = None) -> TTensor: # pylint: disable=redefined-builtin
def any(a: Tensor, axis: Optional[Union[int, Tuple[int]]] = None) -> Tensor: # pylint: disable=redefined-builtin
"""
Test whether any tensor elements along a given axis evaluate to True.
Expand All @@ -204,7 +204,7 @@ def any(a: TTensor, axis: Optional[Union[int, Tuple[int]]] = None) -> TTensor:

@functools.singledispatch
@_tensor_guard
def count_nonzero(a: TTensor, axis: Optional[Union[int, Tuple[int]]] = None) -> TTensor:
def count_nonzero(a: Tensor, axis: Optional[Union[int, Tuple[int]]] = None) -> Tensor:
"""
Counts the number of non-zero values in the tensor input.
Expand All @@ -218,7 +218,7 @@ def count_nonzero(a: TTensor, axis: Optional[Union[int, Tuple[int]]] = None) ->

@functools.singledispatch
@_tensor_guard
def isempty(a: TTensor) -> bool:
def isempty(a: Tensor) -> bool:
"""
Return True if input tensor is empty.
Expand All @@ -230,7 +230,7 @@ def isempty(a: TTensor) -> bool:

@functools.singledispatch
@_tensor_guard
def isclose(a: TTensor, b: TTensor, rtol: float = 1e-05, atol: float = 1e-08, equal_nan: bool = False) -> TTensor:
def isclose(a: Tensor, b: TTensor, rtol: float = 1e-05, atol: float = 1e-08, equal_nan: bool = False) -> Tensor:
"""
Returns a boolean array where two arrays are element-wise equal within a tolerance.
Expand All @@ -256,7 +256,7 @@ def isclose(a: TTensor, b: TTensor, rtol: float = 1e-05, atol: float = 1e-08, eq

@functools.singledispatch
@_tensor_guard
def maximum(x1: TTensor, x2: TTensor) -> TTensor:
def maximum(x1: Tensor, x2: TTensor) -> Tensor:
"""
Element-wise maximum of tensor elements.
Expand All @@ -269,7 +269,7 @@ def maximum(x1: TTensor, x2: TTensor) -> TTensor:

@functools.singledispatch
@_tensor_guard
def minimum(x1: TTensor, x2: TTensor) -> TTensor:
def minimum(x1: Tensor, x2: TTensor) -> Tensor:
"""
Element-wise minimum of tensor elements.
Expand All @@ -282,7 +282,7 @@ def minimum(x1: TTensor, x2: TTensor) -> TTensor:

@functools.singledispatch
@_tensor_guard
def ones_like(a: TTensor) -> TTensor:
def ones_like(a: Tensor) -> Tensor:
"""
Return a tensor of ones with the same shape and type as a given tensor.
Expand All @@ -294,7 +294,7 @@ def ones_like(a: TTensor) -> TTensor:

@functools.singledispatch
@_tensor_guard
def where(condition: TTensor, x: TTensor, y: TTensor) -> TTensor:
def where(condition: Tensor, x: TTensor, y: TTensor) -> Tensor:
"""
Return elements chosen from x or y depending on condition.
Expand All @@ -314,7 +314,7 @@ def where(condition: TTensor, x: TTensor, y: TTensor) -> TTensor:

@functools.singledispatch
@_tensor_guard
def zeros_like(a: TTensor) -> TTensor:
def zeros_like(a: Tensor) -> Tensor:
"""
Return an tensor of zeros with the same shape and type as a given tensor.
Expand All @@ -325,7 +325,7 @@ def zeros_like(a: TTensor) -> TTensor:


@functools.singledispatch
def stack(x: List[TTensor], axis: int = 0) -> TTensor:
def stack(x: List[Tensor], axis: int = 0) -> Tensor:
"""
Stacks a list or deque of Tensors rank-R tensors into one Tensor rank-(R+1) tensor.
Expand All @@ -343,7 +343,7 @@ def stack(x: List[TTensor], axis: int = 0) -> TTensor:

@functools.singledispatch
@_tensor_guard
def unstack(a: TTensor, axis: int = 0) -> List[TTensor]:
def unstack(a: Tensor, axis: int = 0) -> List[TTensor]:
"""
Unstack a Tensor into list.
Expand All @@ -357,7 +357,7 @@ def unstack(a: TTensor, axis: int = 0) -> List[TTensor]:

@functools.singledispatch
@_tensor_guard
def moveaxis(a: TTensor, source: Union[int, List[int]], destination: Union[int, List[int]]) -> TTensor:
def moveaxis(a: Tensor, source: Union[int, List[int]], destination: Union[int, List[int]]) -> Tensor:
"""
Move axes of an array to new positions.
Expand All @@ -371,7 +371,7 @@ def moveaxis(a: TTensor, source: Union[int, List[int]], destination: Union[int,

@functools.singledispatch
@_tensor_guard
def mean(a: TTensor, axis: Union[int, List[int]] = None, keepdims: bool = False) -> TTensor:
def mean(a: Tensor, axis: Union[int, List[int]] = None, keepdims: bool = False) -> Tensor:
"""
Compute the arithmetic mean along the specified axis.
Expand All @@ -385,7 +385,7 @@ def mean(a: TTensor, axis: Union[int, List[int]] = None, keepdims: bool = False)

@functools.singledispatch
@_tensor_guard
def round(a: TTensor, decimals=0) -> TTensor: # pylint: disable=redefined-builtin
def round(a: Tensor, decimals=0) -> Tensor: # pylint: disable=redefined-builtin
"""
Evenly round to the given number of decimals.
Expand All @@ -399,7 +399,7 @@ def round(a: TTensor, decimals=0) -> TTensor: # pylint: disable=redefined-built

@functools.singledispatch
@_tensor_guard
def binary_op_nowarn(a: TTensor, b: TTensor, operator_fn: Callable) -> TTensor:
def binary_op_nowarn(a: Tensor, b: TTensor, operator_fn: Callable) -> Tensor:
"""
Applies a binary operation to two tensors with disable warnings.
Expand All @@ -413,7 +413,7 @@ def binary_op_nowarn(a: TTensor, b: TTensor, operator_fn: Callable) -> TTensor:

@functools.singledispatch
@_tensor_guard
def binary_reverse_op_nowarn(a: TTensor, b: TTensor, operator_fn: Callable) -> TTensor:
def binary_reverse_op_nowarn(a: Tensor, b: TTensor, operator_fn: Callable) -> Tensor:
"""
Applies a binary reverse operation to two tensors with disable warnings.
Expand Down
60 changes: 30 additions & 30 deletions nncf/experimental/tensor/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import operator
from typing import Any, List, Optional, Tuple, TypeVar, Union
Expand Down Expand Up @@ -53,7 +53,7 @@ def __bool__(self) -> bool:
def __iter__(self):
return TensorIterator(self.data)

def __getitem__(self, index: int) -> "Tensor":
def __getitem__(self, index: int) -> Tensor:
return Tensor(self.data[index])

def __str__(self) -> str:
Expand All @@ -64,77 +64,77 @@ def __repr__(self) -> str:

# built-in operations

def __add__(self, other: TTensor) -> "Tensor":
def __add__(self, other: TTensor) -> Tensor:
return Tensor(self.data + unwrap_tensor_data(other))

def __radd__(self, other: TTensor) -> "Tensor":
def __radd__(self, other: TTensor) -> Tensor:
return Tensor(unwrap_tensor_data(other) + self.data)

def __sub__(self, other: TTensor) -> "Tensor":
def __sub__(self, other: TTensor) -> Tensor:
return Tensor(self.data - unwrap_tensor_data(other))

def __rsub__(self, other: TTensor) -> "Tensor":
def __rsub__(self, other: TTensor) -> Tensor:
return Tensor(unwrap_tensor_data(other) - self.data)

def __mul__(self, other: TTensor) -> "Tensor":
def __mul__(self, other: TTensor) -> Tensor:
return Tensor(self.data * unwrap_tensor_data(other))

def __rmul__(self, other: TTensor) -> "Tensor":
def __rmul__(self, other: TTensor) -> Tensor:
return Tensor(unwrap_tensor_data(other) * self.data)

def __pow__(self, other: TTensor) -> "Tensor":
def __pow__(self, other: TTensor) -> Tensor:
return Tensor(self.data ** unwrap_tensor_data(other))

def __truediv__(self, other: TTensor) -> "Tensor":
return _call_function("_binary_op_nowarn", self, other, operator.truediv)
def __truediv__(self, other: TTensor) -> Tensor:
return _call_function("binary_op_nowarn", self, other, operator.truediv)

def __rtruediv__(self, other: TTensor) -> "Tensor":
return _call_function("_binary_reverse_op_nowarn", self, other, operator.truediv)
def __rtruediv__(self, other: TTensor) -> Tensor:
return _call_function("binary_reverse_op_nowarn", self, other, operator.truediv)

def __floordiv__(self, other: TTensor) -> "Tensor":
return _call_function("_binary_op_nowarn", self, other, operator.floordiv)
def __floordiv__(self, other: TTensor) -> Tensor:
return _call_function("binary_op_nowarn", self, other, operator.floordiv)

def __rfloordiv__(self, other: TTensor) -> "Tensor":
return _call_function("_binary_reverse_op_nowarn", self, other, operator.floordiv)
def __rfloordiv__(self, other: TTensor) -> Tensor:
return _call_function("binary_reverse_op_nowarn", self, other, operator.floordiv)

def __neg__(self) -> "Tensor":
def __neg__(self) -> Tensor:
return Tensor(-self.data)

# Comparison operators

def __lt__(self, other: TTensor) -> "Tensor":
def __lt__(self, other: TTensor) -> Tensor:
return Tensor(self.data < unwrap_tensor_data(other))

def __le__(self, other: TTensor) -> "Tensor":
def __le__(self, other: TTensor) -> Tensor:
return Tensor(self.data <= unwrap_tensor_data(other))

def __eq__(self, other: TTensor) -> "Tensor":
def __eq__(self, other: TTensor) -> Tensor:
return Tensor(self.data == unwrap_tensor_data(other))

def __ne__(self, other: TTensor) -> "Tensor":
def __ne__(self, other: TTensor) -> Tensor:
return Tensor(self.data != unwrap_tensor_data(other))

def __gt__(self, other: TTensor) -> "Tensor":
def __gt__(self, other: TTensor) -> Tensor:
return Tensor(self.data > unwrap_tensor_data(other))

def __ge__(self, other: TTensor) -> "Tensor":
def __ge__(self, other: TTensor) -> Tensor:
return Tensor(self.data >= unwrap_tensor_data(other))

# Tensor functions

def squeeze(self, axis: Optional[Union[int, Tuple[int]]] = None) -> "Tensor":
def squeeze(self, axis: Optional[Union[int, Tuple[int]]] = None) -> Tensor:
return _call_function("squeeze", self, axis)

def flatten(self) -> "Tensor":
def flatten(self) -> Tensor:
return _call_function("flatten", self)

def max(self, axis: Optional[TTensor] = None) -> "Tensor":
def max(self, axis: Optional[TTensor] = None) -> Tensor:
return _call_function("max", self, axis)

def min(self, axis: Optional[TTensor] = None) -> "Tensor":
def min(self, axis: Optional[TTensor] = None) -> Tensor:
return _call_function("min", self, axis)

def abs(self) -> "Tensor":
def abs(self) -> Tensor:
return _call_function("abs", self)

def isempty(self) -> bool:
Expand All @@ -143,7 +143,7 @@ def isempty(self) -> bool:
def astype(self, dtype: TensorDataType):
return _call_function("astype", self, dtype)

def reshape(self, shape: List) -> "Tensor":
def reshape(self, shape: List) -> Tensor:
return _call_function("reshape", self, shape)


Expand Down

0 comments on commit cc8d0ed

Please sign in to comment.