diff --git a/cola/linalg/diag_trace.py b/cola/linalg/diag_trace.py
index 195cac48..08296bba 100644
--- a/cola/linalg/diag_trace.py
+++ b/cola/linalg/diag_trace.py
@@ -3,7 +3,7 @@
 from cola.utils import export, dispatch
 from cola.ops import LinearOperator, I_like, Diagonal, Identity
 from cola.ops import BlockDiag, ScalarMul, Sum, Dense, Array
-from cola.ops import Kronecker, KronSum
+from cola.ops import Kronecker, KronSum, Product
 from cola.algorithms import exact_diag, approx_diag
 
 
@@ -141,3 +141,10 @@ def trace(A: LinearOperator, **kwargs):
 @dispatch
 def trace(A: Kronecker, **kwargs):
     return product([trace(M, **kwargs) for M in A.Ms])
+
+
+@dispatch(cond=lambda A, **kwargs: A.Ms[0].shape[0] > A.Ms[0].shape[1])
+def trace(A: Product[LinearOperator, LinearOperator], **kwargs):
+    # tr(M1 @ M2) == tr(M2 @ M1): when M1 is tall, the reversed product is the
+    # smaller square operator, so take the trace of that one instead.
+    return trace(Product(*reversed(A.Ms)), **kwargs)
diff --git a/cola/linalg/inv.py b/cola/linalg/inv.py
index 4a72365d..6e4243d0 100644
--- a/cola/linalg/inv.py
+++ b/cola/linalg/inv.py
@@ -2,8 +2,8 @@
 from plum import dispatch
 from cola.ops import LinearOperator
 from cola.ops import Diagonal, Permutation
-from cola.ops import Identity
-from cola.ops import ScalarMul
+from cola.ops import Identity, Dense
+from cola.ops import ScalarMul, Sum
 from cola.ops import BlockDiag, Triangular
 from cola.ops import Kronecker, Product
 from cola.algorithms.cg import cg
@@ -168,3 +168,20 @@ def inv(A: Diagonal, **kwargs):
 @dispatch
 def inv(A: Triangular, **kwargs):
     return TriangularInv(A)
+
+
+@dispatch
+def inv(A: Sum[Product[Dense, Dense], Diagonal], **kwargs):
+    # Woodbury identity for a diagonal plus a product of two dense factors:
+    # (U V + D)^{-1} = D^{-1} - D^{-1} U (V D^{-1} U + I)^{-1} V D^{-1},
+    # so the inner inverse acts on a (V.shape[0] x U.shape[1]) operator.
+    U, V = A.Ms[0].Ms
+    D_inv = inv(A.Ms[1], **kwargs)
+    eye = Identity(shape=(V.shape[0], U.shape[1]), dtype=V.dtype)
+    return D_inv - D_inv @ U @ inv(V @ D_inv @ U + eye, **kwargs) @ V @ D_inv
+
+
+@dispatch
+def inv(A: Sum[Diagonal, Product[Dense, Dense]], **kwargs):
+    # Addition commutes: swap the two terms and reuse the overload above.
+    return inv(A.Ms[1] + A.Ms[0], **kwargs)
diff --git a/tests/linalg/operator_market.py b/tests/linalg/operator_market.py
index 534270f4..d779186a 100644
--- a/tests/linalg/operator_market.py
+++ b/tests/linalg/operator_market.py
@@ -2,7 +2,7 @@
 from cola.ops import LinearOperator, Tridiagonal, Diagonal, Identity
 from cola.ops import KronSum, Product
 from cola.ops import Triangular, Kronecker, Permutation
-from cola.ops import Dense, BlockDiag, Jacobian, Hessian
+from cola.ops import Dense, BlockDiag, Jacobian, Hessian, Sparse
 from cola.annotations import SelfAdjoint
 from cola.annotations import PSD
 from cola.utils_test import get_xnp
@@ -30,6 +30,8 @@
     'square_product',
     # 'square_sparse',
     'square_tridiagonal',
+    'diagonal_plus_uv',
+    'uv_plus_diagonal',
 }
 
 
@@ -73,6 +75,7 @@ def get_test_operator(backend: str, precision: str, op_name: str,
                     op = BlockDiag(M1, M2, multiplicities=[2, 3])
                 case 'prod':
                     op = M1 @ M1.T
+
         case ('psd', 'kron'):
             M1 = Dense(xnp.array([[6., 2], [2, 4]], dtype=dtype, device=device))
             M2 = Dense(xnp.array([[7, 6], [6, 8]], dtype=dtype, device=device))
@@ -134,6 +137,18 @@ def get_test_operator(backend: str, precision: str, op_name: str,
             shape = (3, 3)
             sparse = Sparse(data, indices, indptr, shape)
 
+        case ('diagonal', 'plus_uv'):
+            D = Diagonal(xnp.array([.1, .5, .22, 8.], dtype=dtype, device=device))
+            U = Dense(xnp.array([[6., 2], [2, 4], [1.2, 2], [5, 10]], dtype=dtype, device=device))
+            V = Dense(xnp.array([[7, 6, 6, 8.7], [3., .2, 13, 4]], dtype=dtype, device=device))
+            op = D + U @ V
+
+        case ('uv', 'plus_diagonal'):
+            D = Diagonal(xnp.array([.1, .5, .22, 8.], dtype=dtype, device=device))
+            U = Dense(xnp.array([[6., 2], [2, 4], [1.2, 2], [5, 10]], dtype=dtype, device=device))
+            V = Dense(xnp.array([[7, 6, 6, 8.7], [3., .2, 13, 4]], dtype=dtype, device=device))
+            op = U @ V + D
+
     # Check to sure that we hit a case statement
     if op is None:
         raise ValueError(op_name)
diff --git a/tests/linalg/test_diagonal.py b/tests/linalg/test_diagonal.py
index e2511f55..4b4d8f83 100644
--- a/tests/linalg/test_diagonal.py
+++ b/tests/linalg/test_diagonal.py
@@ -63,4 +63,21 @@ def test_diagonal_diag(backend):
     d1 = diag(M, u)
     d2 = xnp.diag(M.to_dense(), u)
     assert d1.shape == d2.shape
-    assert relative_error(d1, d2) < 1e-5
\ No newline at end of file
+    assert relative_error(d1, d2) < 1e-5
+
+
+@parametrize(['torch', 'jax'], ['exact', 'approx']).excluding[:, 'approx']
+def test_cyclic_trace(backend, method):
+    xnp = get_xnp(backend)
+    dtype = xnp.float32
+    array_fat = xnp.fixed_normal_samples((100, 200), dtype=dtype, device=None)
+    array_tall = xnp.fixed_normal_samples((200, 100), dtype=dtype, device=None)
+    fat, tall = Dense(array_fat), Dense(array_tall)
+    rtol = 1e-1 if method == 'approx' else 1e-5
+    # fat @ tall is already the smaller product; tall @ fat has a tall first
+    # factor, which is the case the cyclic trace overload is conditioned on.
+    for op, dense in ((fat @ tall, array_fat @ array_tall),
+                      (tall @ fat, array_tall @ array_fat)):
+        d1 = trace(op, method=method, tol=2e-2)
+        d2 = xnp.diag(dense).sum()
+        assert relative_error(d1, d2) < rtol