Skip to content

Commit

Permalink
typehints
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexanderDokuchaev committed Sep 26, 2023
1 parent cc8d0ed commit 5cd62cf
Show file tree
Hide file tree
Showing 6 changed files with 93 additions and 70 deletions.
30 changes: 15 additions & 15 deletions nncf/experimental/tensor/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def device(a: Tensor) -> TensorDeviceType:

@functools.singledispatch
@_tensor_guard
def squeeze(a: Tensor, axis: Optional[Union[int, Tuple[int]]] = None) -> Tensor:
def squeeze(a: Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Tensor:
"""
Remove axes of length one from a.
Expand All @@ -75,7 +75,7 @@ def flatten(a: Tensor) -> Tensor:

@functools.singledispatch
@_tensor_guard
def max(a: Tensor, axis: Optional[Union[int, Tuple[int]]] = None) -> Tensor: # pylint: disable=redefined-builtin
def max(a: Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Tensor: # pylint: disable=redefined-builtin
"""
Return the maximum of an array or maximum along an axis.
Expand All @@ -88,7 +88,7 @@ def max(a: Tensor, axis: Optional[Union[int, Tuple[int]]] = None) -> Tensor: #

@functools.singledispatch
@_tensor_guard
def min(a: Tensor, axis: Optional[Union[int, Tuple[int]]] = None) -> Tensor: # pylint: disable=redefined-builtin
def min(a: Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Tensor: # pylint: disable=redefined-builtin
"""
Return the minimum of an array or minimum along an axis.
Expand Down Expand Up @@ -139,7 +139,7 @@ def dtype(a: Tensor) -> TensorDataType:

@functools.singledispatch
@_tensor_guard
def reshape(a: Tensor, shape: List[int]) -> Tensor:
def reshape(a: Tensor, shape: Tuple[int, ...]) -> Tensor:
"""
Gives a new shape to a tensor without changing its data.
Expand All @@ -152,7 +152,7 @@ def reshape(a: Tensor, shape: List[int]) -> Tensor:

@functools.singledispatch
@_tensor_guard
def all(a: Tensor, axis: Optional[Union[int, Tuple[int]]] = None) -> Tensor: # pylint: disable=redefined-builtin
def all(a: Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Tensor: # pylint: disable=redefined-builtin
"""
Test whether all tensor elements along a given axis evaluate to True.
Expand Down Expand Up @@ -191,7 +191,7 @@ def allclose(a: Tensor, b: TTensor, rtol: float = 1e-05, atol: float = 1e-08, eq

@functools.singledispatch
@_tensor_guard
def any(a: Tensor, axis: Optional[Union[int, Tuple[int]]] = None) -> Tensor: # pylint: disable=redefined-builtin
def any(a: Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Tensor: # pylint: disable=redefined-builtin
"""
Test whether any tensor elements along a given axis evaluate to True.
Expand All @@ -204,7 +204,7 @@ def any(a: Tensor, axis: Optional[Union[int, Tuple[int]]] = None) -> Tensor: #

@functools.singledispatch
@_tensor_guard
def count_nonzero(a: Tensor, axis: Optional[Union[int, Tuple[int]]] = None) -> Tensor:
def count_nonzero(a: Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Tensor:
"""
Counts the number of non-zero values in the tensor input.
Expand Down Expand Up @@ -357,7 +357,7 @@ def unstack(a: Tensor, axis: int = 0) -> List[TTensor]:

@functools.singledispatch
@_tensor_guard
def moveaxis(a: Tensor, source: Union[int, List[int]], destination: Union[int, List[int]]) -> Tensor:
def moveaxis(a: Tensor, source: Union[int, Tuple[int, ...]], destination: Union[int, Tuple[int, ...]]) -> Tensor:
"""
Move axes of an array to new positions.
Expand All @@ -371,7 +371,7 @@ def moveaxis(a: Tensor, source: Union[int, List[int]], destination: Union[int, L

@functools.singledispatch
@_tensor_guard
def mean(a: Tensor, axis: Union[int, List[int]] = None, keepdims: bool = False) -> Tensor:
def mean(a: Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> Tensor:
"""
Compute the arithmetic mean along the specified axis.
Expand Down Expand Up @@ -399,30 +399,30 @@ def round(a: Tensor, decimals=0) -> Tensor: # pylint: disable=redefined-builtin

@functools.singledispatch
@_tensor_guard
def binary_op_nowarn(a: Tensor, b: TTensor, operator_fn: Callable) -> Tensor:
def _binary_op_nowarn(a: Tensor, b: TTensor, operator_fn: Callable) -> Tensor:
"""
Applies a binary operation to two tensors with disable warnings.
Applies a binary operation with disable warnings.
:param a: The first tensor.
:param b: The second tensor.
:param operator_fn: The binary operation function.
:return: The result of the binary operation.
"""
return Tensor(binary_op_nowarn(a.data, unwrap_tensor_data(b), operator_fn))
return Tensor(_binary_op_nowarn(a.data, unwrap_tensor_data(b), operator_fn))


@functools.singledispatch
@_tensor_guard
def binary_reverse_op_nowarn(a: Tensor, b: TTensor, operator_fn: Callable) -> Tensor:
def _binary_reverse_op_nowarn(a: Tensor, b: TTensor, operator_fn: Callable) -> Tensor:
"""
Applies a binary reverse operation to two tensors with disable warnings.
Applies a binary reverse operation with disable warnings.
:param a: The first tensor.
:param b: The second tensor.
:param operator_fn: The binary operation function.
:return: The result of the binary operation.
"""
return Tensor(binary_reverse_op_nowarn(a.data, unwrap_tensor_data(b), operator_fn))
return Tensor(_binary_reverse_op_nowarn(a.data, unwrap_tensor_data(b), operator_fn))


def _initialize_backends():
Expand Down
73 changes: 39 additions & 34 deletions nncf/experimental/tensor/numpy_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,63 +43,64 @@ def inner(func):
return inner


NUMPY_TYPES = Union[np.ndarray, np.generic]


@_register_numpy_types(fns.device)
def _(a: NUMPY_TYPES) -> TensorDeviceType:
def _(a: Union[np.ndarray, np.generic]) -> TensorDeviceType:
return TensorDeviceType.CPU


@_register_numpy_types(fns.squeeze)
def _(a: NUMPY_TYPES, axis: Optional[Union[int, Tuple[int]]] = None) -> NUMPY_TYPES:
def _(
a: Union[np.ndarray, np.generic], axis: Optional[Union[int, Tuple[int, ...]]] = None
) -> Union[np.ndarray, np.generic]:
return np.squeeze(a, axis=axis)


@_register_numpy_types(fns.flatten)
def _(a: NUMPY_TYPES) -> np.ndarray:
def _(a: Union[np.ndarray, np.generic]) -> np.ndarray:
return a.flatten()


@_register_numpy_types(fns.max)
def _(a: NUMPY_TYPES, axis: Optional[Union[int, Tuple[int]]] = None) -> np.ndarray:
def _(a: Union[np.ndarray, np.generic], axis: Optional[Union[int, Tuple[int, ...]]] = None) -> np.ndarray:
return np.max(a, axis=axis)


@_register_numpy_types(fns.min)
def _(a: NUMPY_TYPES, axis: Optional[Union[int, Tuple[int]]] = None) -> NUMPY_TYPES:
def _(
a: Union[np.ndarray, np.generic], axis: Optional[Union[int, Tuple[int, ...]]] = None
) -> Union[np.ndarray, np.generic]:
return np.min(a, axis=axis)


@_register_numpy_types(fns.abs)
def _(a: NUMPY_TYPES) -> NUMPY_TYPES:
def _(a: Union[np.ndarray, np.generic]) -> Union[np.ndarray, np.generic]:
return np.absolute(a)


@_register_numpy_types(fns.astype)
def _(a: NUMPY_TYPES, dtype: TensorDataType) -> NUMPY_TYPES:
def _(a: Union[np.ndarray, np.generic], dtype: TensorDataType) -> Union[np.ndarray, np.generic]:
return a.astype(DTYPE_MAP[dtype])


@_register_numpy_types(fns.dtype)
def _(a: NUMPY_TYPES) -> TensorDataType:
def _(a: Union[np.ndarray, np.generic]) -> TensorDataType:
return DTYPE_MAP_REV[np.dtype(a.dtype)]


@_register_numpy_types(fns.reshape)
def _(a: NUMPY_TYPES, shape: Union[int, Tuple[int]]) -> np.ndarray:
def _(a: Union[np.ndarray, np.generic], shape: Union[int, Tuple[int, ...]]) -> np.ndarray:
return a.reshape(shape)


@_register_numpy_types(fns.all)
def _(a: NUMPY_TYPES, axis: Optional[Union[int, Tuple[int]]] = None) -> Union[np.ndarray, bool]:
def _(a: Union[np.ndarray, np.generic], axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Union[np.ndarray, bool]:
return np.all(a, axis=axis)


@_register_numpy_types(fns.allclose)
def _(
a: NUMPY_TYPES,
b: NUMPY_TYPES,
a: Union[np.ndarray, np.generic],
b: Union[np.ndarray, np.generic],
rtol: float = 1e-05,
atol: float = 1e-08,
equal_nan: bool = False,
Expand All @@ -108,24 +109,24 @@ def _(


@_register_numpy_types(fns.any)
def _(a: NUMPY_TYPES, axis: Optional[Union[int, Tuple[int]]] = None) -> Union[np.ndarray, bool]:
def _(a: Union[np.ndarray, np.generic], axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Union[np.ndarray, bool]:
return np.any(a, axis=axis)


@_register_numpy_types(fns.count_nonzero)
def _(a: NUMPY_TYPES, axis: Optional[Union[int, Tuple[int]]] = None) -> np.ndarray:
def _(a: Union[np.ndarray, np.generic], axis: Optional[Union[int, Tuple[int, ...]]] = None) -> np.ndarray:
return np.array(np.count_nonzero(a, axis=axis))


@_register_numpy_types(fns.isempty)
def _(a: NUMPY_TYPES) -> bool:
def _(a: Union[np.ndarray, np.generic]) -> bool:
return a.size == 0


@_register_numpy_types(fns.isclose)
def _(
a: NUMPY_TYPES,
b: NUMPY_TYPES,
a: Union[np.ndarray, np.generic],
b: Union[np.ndarray, np.generic],
rtol: float = 1e-05,
atol: float = 1e-08,
equal_nan: bool = False,
Expand All @@ -134,68 +135,72 @@ def _(


@_register_numpy_types(fns.maximum)
def _(x1: NUMPY_TYPES, x2: NUMPY_TYPES) -> np.ndarray:
def _(x1: Union[np.ndarray, np.generic], x2: Union[np.ndarray, np.generic]) -> np.ndarray:
return np.maximum(x1, x2)


@_register_numpy_types(fns.minimum)
def _(x1: NUMPY_TYPES, x2: NUMPY_TYPES) -> np.ndarray:
def _(x1: Union[np.ndarray, np.generic], x2: Union[np.ndarray, np.generic]) -> np.ndarray:
return np.minimum(x1, x2)


@_register_numpy_types(fns.ones_like)
def _(a: NUMPY_TYPES) -> np.ndarray:
def _(a: Union[np.ndarray, np.generic]) -> np.ndarray:
return np.ones_like(a)


@_register_numpy_types(fns.where)
def _(
condition: NUMPY_TYPES,
condition: Union[np.ndarray, np.generic],
x: Union[np.ndarray, np.number, float, bool],
y: Union[np.ndarray, float, bool],
) -> np.ndarray:
return np.where(condition, x, y)


@_register_numpy_types(fns.zeros_like)
def _(a: NUMPY_TYPES) -> np.ndarray:
def _(a: Union[np.ndarray, np.generic]) -> np.ndarray:
return np.zeros_like(a)


@_register_numpy_types(fns.stack)
def _(x: NUMPY_TYPES, axis: int = 0) -> List[np.ndarray]:
def _(x: Union[np.ndarray, np.generic], axis: int = 0) -> List[np.ndarray]:
return np.stack(x, axis=axis)


@_register_numpy_types(fns.unstack)
def _(x: NUMPY_TYPES, axis: int = 0) -> List[np.ndarray]:
def _(x: Union[np.ndarray, np.generic], axis: int = 0) -> List[np.ndarray]:
return [np.squeeze(e, axis) for e in np.split(x, x.shape[axis], axis=axis)]


@_register_numpy_types(fns.moveaxis)
def _(a: np.ndarray, source: Union[int, List[int]], destination: Union[int, List[int]]) -> np.ndarray:
def _(a: np.ndarray, source: Union[int, Tuple[int, ...]], destination: Union[int, Tuple[int, ...]]) -> np.ndarray:
return np.moveaxis(a, source, destination)


@_register_numpy_types(fns.mean)
def _(a: NUMPY_TYPES, axis: Union[int, List[int]] = None, keepdims: bool = False) -> np.ndarray:
def _(a: Union[np.ndarray, np.generic], axis: Union[int, Tuple[int, ...]] = None, keepdims: bool = False) -> np.ndarray:
return np.mean(a, axis=axis, keepdims=keepdims)


@_register_numpy_types(fns.round)
def _(a: NUMPY_TYPES, decimals: int = 0) -> np.ndarray:
def _(a: Union[np.ndarray, np.generic], decimals: int = 0) -> np.ndarray:
return np.round(a, decimals=decimals)


@_register_numpy_types(fns.binary_op_nowarn)
def _(a: NUMPY_TYPES, b: NUMPY_TYPES, operator_fn: Callable) -> NUMPY_TYPES:
@_register_numpy_types(fns._binary_op_nowarn)
def _(
a: Union[np.ndarray, np.generic], b: Union[np.ndarray, np.generic], operator_fn: Callable
) -> Union[np.ndarray, np.generic]:
# Run operator with disabled warning
with np.errstate(invalid="ignore", divide="ignore"):
return operator_fn(a, b)


@_register_numpy_types(fns.binary_reverse_op_nowarn)
def _(a: NUMPY_TYPES, b: NUMPY_TYPES, operator_fn: Callable) -> NUMPY_TYPES:
@_register_numpy_types(fns._binary_reverse_op_nowarn)
def _(
a: Union[np.ndarray, np.generic], b: Union[np.ndarray, np.generic], operator_fn: Callable
) -> Union[np.ndarray, np.generic]:
# Run operator with disabled warning
with np.errstate(invalid="ignore", divide="ignore"):
return operator_fn(b, a)
14 changes: 7 additions & 7 deletions nncf/experimental/tensor/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def data(self) -> TTensor:
return self._data

@property
def shape(self) -> List[int]:
def shape(self) -> Tuple[int, ...]:
return tuple(self.data.shape)

@property
Expand Down Expand Up @@ -86,16 +86,16 @@ def __pow__(self, other: TTensor) -> Tensor:
return Tensor(self.data ** unwrap_tensor_data(other))

def __truediv__(self, other: TTensor) -> Tensor:
return _call_function("binary_op_nowarn", self, other, operator.truediv)
return _call_function("_binary_op_nowarn", self, other, operator.truediv)

def __rtruediv__(self, other: TTensor) -> Tensor:
return _call_function("binary_reverse_op_nowarn", self, other, operator.truediv)
return _call_function("_binary_reverse_op_nowarn", self, other, operator.truediv)

def __floordiv__(self, other: TTensor) -> Tensor:
return _call_function("binary_op_nowarn", self, other, operator.floordiv)
return _call_function("_binary_op_nowarn", self, other, operator.floordiv)

def __rfloordiv__(self, other: TTensor) -> Tensor:
return _call_function("binary_reverse_op_nowarn", self, other, operator.floordiv)
return _call_function("_binary_reverse_op_nowarn", self, other, operator.floordiv)

def __neg__(self) -> Tensor:
return Tensor(-self.data)
Expand All @@ -122,7 +122,7 @@ def __ge__(self, other: TTensor) -> Tensor:

# Tensor functions

def squeeze(self, axis: Optional[Union[int, Tuple[int]]] = None) -> Tensor:
def squeeze(self, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Tensor:
return _call_function("squeeze", self, axis)

def flatten(self) -> Tensor:
Expand All @@ -143,7 +143,7 @@ def isempty(self) -> bool:
def astype(self, dtype: TensorDataType):
return _call_function("astype", self, dtype)

def reshape(self, shape: List) -> Tensor:
def reshape(self, shape: Tuple[int, ...]) -> Tensor:
return _call_function("reshape", self, shape)


Expand Down
Loading

0 comments on commit 5cd62cf

Please sign in to comment.