diff --git a/pyttb/tensor.py b/pyttb/tensor.py index b7ed5c25..59b772b7 100644 --- a/pyttb/tensor.py +++ b/pyttb/tensor.py @@ -7,7 +7,6 @@ import warnings from itertools import permutations from math import factorial -from numbers import Real from typing import Any, Callable, List, Literal, Optional, Tuple, Union import numpy as np @@ -181,9 +180,9 @@ def collapse( self, dims: Optional[np.ndarray] = None, fun: Union[ - Literal["sum"], Callable[[np.ndarray], Union[Real, np.ndarray]] + Literal["sum"], Callable[[np.ndarray], Union[float, np.ndarray]] ] = "sum", - ) -> Union[Real, np.ndarray, tensor]: + ) -> Union[float, np.ndarray, tensor]: """ Collapse tensor along specified dimensions. @@ -389,7 +388,7 @@ def full(self) -> tensor: """ return ttb.tensor.from_data(self.data) - def innerprod(self, other: Union[tensor, ttb.sptensor, ttb.ktensor]) -> Real: + def innerprod(self, other: Union[tensor, ttb.sptensor, ttb.ktensor]) -> float: """ Efficient inner product with a tensor @@ -542,7 +541,7 @@ def issymmetric( return bool((all_diffs == 0).all()) return bool((all_diffs == 0).all()), all_diffs, all_perms - def logical_and(self, B: Union[Real, tensor]) -> tensor: + def logical_and(self, B: Union[float, tensor]) -> tensor: """ Logical and for tensors @@ -578,7 +577,7 @@ def logical_not(self) -> tensor: """ return ttb.tensor.from_data(np.logical_not(self.data)) - def logical_or(self, other: Union[Real, tensor]) -> tensor: + def logical_or(self, other: Union[float, tensor]) -> tensor: """ Logical or for tensors @@ -598,7 +597,7 @@ def tensor_or(x, y): return ttb.tt_tenfun(tensor_or, self, other) - def logical_xor(self, other: Union[Real, tensor]) -> tensor: + def logical_xor(self, other: Union[float, tensor]) -> tensor: """ Logical xor for tensors @@ -867,7 +866,7 @@ def reshape(self, shape: Tuple[int, ...]) -> tensor: return ttb.tensor.from_data(np.reshape(self.data, shape, order="F"), shape) - def squeeze(self) -> Union[tensor, np.ndarray, Real]: + def squeeze(self) -> Union[tensor, np.ndarray, float]: """ Removes singleton dimensions from a tensor @@ -1029,7 +1028,7 @@ def symmetrize( def ttm( self, matrix: Union[np.ndarray, List[np.ndarray]], - dims: Optional[Union[Real, np.ndarray]] = None, + dims: Optional[Union[float, np.ndarray]] = None, transpose: bool = False, ) -> tensor: """ @@ -1046,7 +1045,7 @@ def ttm( dims = np.arange(self.ndims) elif isinstance(dims, list): dims = np.array(dims) - elif isinstance(dims, Real): + elif isinstance(dims, (float, int, np.generic)): dims = np.array([dims]) if isinstance(matrix, list):