Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dispatch optimizations #55

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion cola/linalg/diag_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from cola.utils import export, dispatch
from cola.ops import LinearOperator, I_like, Diagonal, Identity
from cola.ops import BlockDiag, ScalarMul, Sum, Dense, Array
from cola.ops import Kronecker, KronSum
from cola.ops import Kronecker, KronSum, Product
from cola.algorithms import exact_diag, approx_diag


Expand Down Expand Up @@ -141,3 +141,8 @@ def trace(A: LinearOperator, **kwargs):
@dispatch
def trace(A: Kronecker, **kwargs):
return product([trace(M, **kwargs) for M in A.Ms])


@dispatch(cond=lambda A, **kwargs: A.Ms[0].shape[0] > A.Ms[0].shape[1])
def trace(A: Product[LinearOperator, LinearOperator]):
return trace(Product(*reversed(A.Ms)))
17 changes: 15 additions & 2 deletions cola/linalg/inv.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
from plum import dispatch
from cola.ops import LinearOperator
from cola.ops import Diagonal, Permutation
from cola.ops import Identity
from cola.ops import ScalarMul
from cola.ops import Identity, Dense
from cola.ops import ScalarMul, Sum
from cola.ops import BlockDiag, Triangular
from cola.ops import Kronecker, Product
from cola.algorithms.cg import cg
Expand Down Expand Up @@ -168,3 +168,16 @@ def inv(A: Diagonal, **kwargs):
@dispatch
def inv(A: Triangular, **kwargs):
return TriangularInv(A)


@dispatch
def inv(A: Sum[Product[Dense,Dense], Diagonal], **kwargs):
U, V = A.Ms[0].Ms
D_inv = inv(A.Ms[1])
I = Identity(shape=(V.shape[0], U.shape[1]), dtype=V.dtype)
return D_inv - D_inv @ U @ inv(V @ D_inv @ U + I) @ V @ D_inv


@dispatch
def inv(A: Sum[Diagonal, Product[Dense,Dense]], **kwargs):
return inv(Product(*A.Ms[1].Ms) + A.Ms[0])
17 changes: 16 additions & 1 deletion tests/linalg/operator_market.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from cola.ops import LinearOperator, Tridiagonal, Diagonal, Identity
from cola.ops import KronSum, Product
from cola.ops import Triangular, Kronecker, Permutation
from cola.ops import Dense, BlockDiag, Jacobian, Hessian
from cola.ops import Dense, BlockDiag, Jacobian, Hessian, Sparse
from cola.annotations import SelfAdjoint
from cola.annotations import PSD
from cola.utils_test import get_xnp
Expand Down Expand Up @@ -30,6 +30,8 @@
'square_product',
# 'square_sparse',
'square_tridiagonal',
'diagonal_plus_uv',
'uv_plus_diagonal',
}


Expand Down Expand Up @@ -73,6 +75,7 @@ def get_test_operator(backend: str, precision: str, op_name: str,
op = BlockDiag(M1, M2, multiplicities=[2, 3])
case 'prod':
op = M1 @ M1.T

case ('psd', 'kron'):
M1 = Dense(xnp.array([[6., 2], [2, 4]], dtype=dtype, device=device))
M2 = Dense(xnp.array([[7, 6], [6, 8]], dtype=dtype, device=device))
Expand Down Expand Up @@ -134,6 +137,18 @@ def get_test_operator(backend: str, precision: str, op_name: str,
shape = (3, 3)
sparse = Sparse(data, indices, indptr, shape)

case ('diagonal', 'plus_uv'):
D = Diagonal(xnp.array([.1, .5, .22, 8.], dtype=dtype, device=device))
U = Dense(xnp.array([[6., 2], [2, 4], [1.2, 2], [5, 10]], dtype=dtype, device=device))
V = Dense(xnp.array([[7, 6, 6, 8.7], [3., .2, 13, 4]], dtype=dtype, device=device))
op = D + U @ V

case ('uv', 'plus_diagonal'):
D = Diagonal(xnp.array([.1, .5, .22, 8.], dtype=dtype, device=device))
U = Dense(xnp.array([[6., 2], [2, 4], [1.2, 2], [5, 10]], dtype=dtype, device=device))
V = Dense(xnp.array([[7, 6, 6, 8.7], [3., .2, 13, 4]], dtype=dtype, device=device))
op = U @ V + D

# Check to sure that we hit a case statement
if op is None:
raise ValueError(op_name)
Expand Down
15 changes: 14 additions & 1 deletion tests/linalg/test_diagonal.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,4 +63,17 @@ def test_diagonal_diag(backend):
d1 = diag(M, u)
d2 = xnp.diag(M.to_dense(), u)
assert d1.shape == d2.shape
assert relative_error(d1, d2) < 1e-5
assert relative_error(d1, d2) < 1e-5

@parametrize(['torch', 'jax'], ['exact', 'approx']).excluding[:,'approx']
def test_cyclic_trace(backend, method):
xnp = get_xnp(backend)
dtype = xnp.float32
array_fat = xnp.fixed_normal_samples((100, 200), dtype=dtype, device=None)
U = Dense(array_fat)
array_tall = xnp.fixed_normal_samples((200, 100), dtype=dtype, device=None)
V = Dense(array_tall)
A = U @ V
d1 = trace(A, method=method, tol=2e-2)
d2 = xnp.diag(array_fat @ array_tall).sum()
assert relative_error(d1, d2) < (1e-1 if method == 'approx' else 1e-5)