diff --git a/nncf/tensor/functions/linalg.py b/nncf/tensor/functions/linalg.py index 441287f7d79..43d32d65c9b 100644 --- a/nncf/tensor/functions/linalg.py +++ b/nncf/tensor/functions/linalg.py @@ -113,6 +113,18 @@ def inv(a: Tensor) -> Tensor: return Tensor(inv(a.data)) +@functools.singledispatch +@tensor_guard +def pinv(a: Tensor) -> Tensor: + """ + Compute the (Moore-Penrose) pseudo-inverse of a matrix. + + :param a: The input tensor of shape (*, M, N) where * is zero or more batch dimensions. + :return: The pseudo-inverse of input tensor. + """ + return Tensor(pinv(a.data)) + + @functools.singledispatch @tensor_guard def lstsq(a: Tensor, b: Tensor, driver: Optional[str] = None) -> Tensor: diff --git a/nncf/tensor/functions/numpy_linalg.py b/nncf/tensor/functions/numpy_linalg.py index 4ecf4633725..4dfea30267b 100644 --- a/nncf/tensor/functions/numpy_linalg.py +++ b/nncf/tensor/functions/numpy_linalg.py @@ -50,6 +50,11 @@ def _(a: Union[np.ndarray, np.generic]) -> np.ndarray: return np.linalg.inv(a) +@register_numpy_types(linalg.pinv) +def _(a: Union[np.ndarray, np.generic]) -> np.ndarray: + return np.linalg.pinv(a) + + @register_numpy_types(linalg.lstsq) def _(a: Union[np.ndarray, np.generic], b: Union[np.ndarray, np.generic], driver: Optional[str] = None) -> np.ndarray: return lstsq(a, b, lapack_driver=driver)[0] diff --git a/nncf/tensor/functions/torch_linalg.py b/nncf/tensor/functions/torch_linalg.py index 19c88209d2f..d6a079d5a2d 100644 --- a/nncf/tensor/functions/torch_linalg.py +++ b/nncf/tensor/functions/torch_linalg.py @@ -40,6 +40,15 @@ def _(a: torch.Tensor) -> torch.Tensor: return torch.linalg.inv(a) +@linalg.pinv.register(torch.Tensor) +def _(a: torch.Tensor) -> torch.Tensor: + # Consider using torch.linalg.lstsq() if possible for multiplying a matrix on the left by the pseudo-inverse, as: + # torch.linalg.lstsq(A, B).solution == A.pinv() @ B + # It is always preferred to use lstsq() when possible, as it is faster and more numerically stable than computing + # the pseudo-inverse explicitly. + return torch.linalg.pinv(a) + + @linalg.lstsq.register(torch.Tensor) def _(a: torch.Tensor, b: torch.Tensor, driver: Optional[str] = None) -> torch.Tensor: return torch.linalg.lstsq(a, b, driver=driver).solution diff --git a/tests/shared/test_templates/template_test_nncf_tensor.py b/tests/shared/test_templates/template_test_nncf_tensor.py index 29b558b240b..13f2d6bc976 100644 --- a/tests/shared/test_templates/template_test_nncf_tensor.py +++ b/tests/shared/test_templates/template_test_nncf_tensor.py @@ -1349,6 +1349,15 @@ def test_fn_linalg_inv(self, a, ref): assert fns.allclose(res.data, ref_tensor) assert res.device == tensor_a.device + def test_fn_linalg_pinv(self): + a = [[1.0], [2.0]] + A = Tensor(self.to_tensor(a)) + B = fns.linalg.pinv(A) + assert isinstance(B, Tensor) + assert B.device == A.device + assert fns.allclose(A, A @ B @ A) + assert fns.allclose(B, B @ A @ B) + @pytest.mark.parametrize( "a, k, ref", (