Skip to content

Commit

Permalink
Add a linalg.pinv wrapper to common tensor (#2798)
Browse files Browse the repository at this point in the history
### Changes
The following functions were added:

- linalg.py: pinv

### Reason for changes
Part of int4 with LoRA adaptation implementation.

### Related tickets
CVS-135863

### Tests
tests/shared/test_templates/template_test_nncf_tensor.py
  • Loading branch information
ljaljushkin authored Jul 10, 2024
1 parent 42ae1f8 commit e9ae8f5
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 0 deletions.
12 changes: 12 additions & 0 deletions nncf/tensor/functions/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,18 @@ def inv(a: Tensor) -> Tensor:
return Tensor(inv(a.data))


@functools.singledispatch
@tensor_guard
def pinv(a: Tensor) -> Tensor:
"""
Compute the (Moore-Penrose) pseudo-inverse of a matrix.
:param a: The input tensor of shape (*, M, N) where * is zero or more batch dimensions.
:return: The pseudo-inverse of input tensor.
"""
return Tensor(pinv(a.data))


@functools.singledispatch
@tensor_guard
def lstsq(a: Tensor, b: Tensor, driver: Optional[str] = None) -> Tensor:
Expand Down
5 changes: 5 additions & 0 deletions nncf/tensor/functions/numpy_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,11 @@ def _(a: Union[np.ndarray, np.generic]) -> np.ndarray:
return np.linalg.inv(a)


@register_numpy_types(linalg.pinv)
def _(a: Union[np.ndarray, np.generic]) -> np.ndarray:
return np.linalg.pinv(a)


@register_numpy_types(linalg.lstsq)
def _(a: Union[np.ndarray, np.generic], b: Union[np.ndarray, np.generic], driver: Optional[str] = None) -> np.ndarray:
return lstsq(a, b, lapack_driver=driver)[0]
Expand Down
9 changes: 9 additions & 0 deletions nncf/tensor/functions/torch_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,15 @@ def _(a: torch.Tensor) -> torch.Tensor:
return torch.linalg.inv(a)


@linalg.pinv.register(torch.Tensor)
def _(a: torch.Tensor) -> torch.Tensor:
# Consider using torch.linalg.lstsq() if possible for multiplying a matrix on the left by the pseudo-inverse, as:
# torch.linalg.lstsq(A, B).solution == A.pinv() @ B
# It is always preferred to use lstsq() when possible, as it is faster and more numerically stable than computing
# the pseudo-inverse explicitly.
return torch.linalg.pinv(a)


@linalg.lstsq.register(torch.Tensor)
def _(a: torch.Tensor, b: torch.Tensor, driver: Optional[str] = None) -> torch.Tensor:
return torch.linalg.lstsq(a, b, driver=driver).solution
Expand Down
9 changes: 9 additions & 0 deletions tests/shared/test_templates/template_test_nncf_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1349,6 +1349,15 @@ def test_fn_linalg_inv(self, a, ref):
assert fns.allclose(res.data, ref_tensor)
assert res.device == tensor_a.device

def test_fn_linalg_pinv(self):
a = [[1.0], [2.0]]
A = Tensor(self.to_tensor(a))
B = fns.linalg.pinv(A)
assert isinstance(B, Tensor)
assert B.device == A.device
assert fns.allclose(A, A @ B @ A)
assert fns.allclose(B, B @ A @ B)

@pytest.mark.parametrize(
"a, k, ref",
(
Expand Down

0 comments on commit e9ae8f5

Please sign in to comment.