Improvements on Sparse API. #98

Merged: 4 commits, Aug 29, 2024
11 changes: 7 additions & 4 deletions cola/backends/jax_fns.py
@@ -5,7 +5,7 @@
 from jax import grad, jit, vjp, vmap
 from jax import numpy as jnp
 from jax import tree_util as tu
-from jax.experimental.sparse import CSR
+from jax.experimental.sparse import BCSR
 from jax.lax import conj as conj_lax
 from jax.lax import dynamic_slice, expand_dims
 from jax.lax import fori_loop as _for_loop
@@ -94,9 +94,12 @@
 finfo = jnp.finfo


-def sparse_csr(indptr, indices, data):
-    N = indptr.shape[0] - 1
-    out = CSR((data, indices, indptr), shape=(N, N))
+def to_np(array):
+    return jax.device_get(array)
+
+
+def sparse_csr(row_pointers, col_indices, data, shape):
+    out = BCSR((data, col_indices, row_pointers), shape=shape)
     return out
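The switch from CSR to BCSR gives the JAX backend the better-supported batched sparse format, and the helper now takes an explicit shape instead of inferring a square matrix from the index pointers. A minimal sketch of the layout the new helper wraps, assuming jax.experimental.sparse is available (the 3x4 matrix is illustrative, not from the PR):

import jax.numpy as jnp
from jax.experimental.sparse import BCSR

# CSR layout: row i owns data[row_pointers[i]:row_pointers[i+1]]
data = jnp.array([1., 2., 3.])          # nonzero values in row-major order
col_indices = jnp.array([1, 3, 0])      # column of each nonzero
row_pointers = jnp.array([0, 2, 2, 3])  # repeated pointer means row 1 is empty
A = BCSR((data, col_indices, row_pointers), shape=(3, 4))

x = jnp.ones(4)
print(A @ x)  # dense matvec: [3., 0., 3.]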
19 changes: 14 additions & 5 deletions cola/backends/torch_fns.py
@@ -1,11 +1,13 @@
 import hashlib
 import logging
-import torch

 import optree
-from torch.nn import Parameter
-from torch.func import vjp, jvp
-from torch.func import vmap as _vmap
+import torch
 from torch.func import grad as _grad
+from torch.func import jvp, vjp
+from torch.func import vmap as _vmap
+from torch.nn import Parameter
+
+from cola.utils.torch_tqdm import while_loop_winfo

 Parameter = Parameter
@@ -52,7 +54,6 @@
 is_array = torch.is_tensor
 autograd = torch.autograd
 argsort = torch.argsort
-sparse_csr = torch.sparse_csr_tensor
 roll = torch.roll
 maximum = torch.maximum
 isreal = torch.isreal
@@ -67,6 +68,14 @@
 iscomplexobj = torch.is_complex


+def to_np(array):
+    return array.detach().cpu().numpy()
+
+
+def sparse_csr(row_pointers, col_indices, data, shape):
+    return torch.sparse_csr_tensor(crow_indices=row_pointers, col_indices=col_indices, values=data, size=shape)
+
+
 def norm(array, axis=None, keepdims=False, ord=None):
     return torch.linalg.norm(array, dim=axis, keepdim=keepdims, ord=ord)
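On the torch side, the bare alias to torch.sparse_csr_tensor is replaced by a wrapper with the same (row_pointers, col_indices, data, shape) signature as the JAX backend, so callers no longer depend on each library's argument order. A small usage sketch of the underlying public PyTorch API (matrix values are illustrative, not from the PR):

import torch

# row i owns values[crow_indices[i]:crow_indices[i+1]]
crow_indices = torch.tensor([0, 2, 2, 3])
col_indices = torch.tensor([1, 3, 0])
values = torch.tensor([1., 2., 3.])
A = torch.sparse_csr_tensor(crow_indices, col_indices, values, size=(3, 4))

x = torch.ones(4)
print(A @ x)         # tensor([3., 0., 3.])
print(A.to_dense())  # recover the dense 3x4 matrix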
7 changes: 6 additions & 1 deletion cola/fns.py
@@ -7,7 +7,7 @@
 from cola.ops import LinearOperator, Array
 from cola.ops import Dense
 from cola.ops import Kronecker, Product, KronSum, Sum
-from cola.ops import ScalarMul, Transpose, Adjoint
+from cola.ops import ScalarMul, Transpose, Adjoint, Sparse
 from cola.ops import BlockDiag, Diagonal, Triangular, Identity
 from cola.utils import export
 import cola
@@ -135,6 +135,11 @@ def transpose(A: Triangular):
     return Triangular(A.A.T, lower=not A.lower)


+@dispatch
+def transpose(A: Sparse):
+    return Sparse(A.data, A.col_indices, A.row_indices, shape=(A.shape[1], A.shape[0]))
+
+
 @dispatch
 def adjoint(A: LinearOperator):
     return Adjoint(A)
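The new transpose rule is a pure relabeling: the operator stores COO-style triplets, and entry k lives at (row_indices[k], col_indices[k]), so swapping the two index arrays and flipping the shape moves every entry to its transposed position without touching the data. A quick scipy check of that identity, independent of the PR code:

import numpy as np
from scipy.sparse import coo_array

data = np.array([1., 2., 3.])
rows = np.array([0, 0, 2])
cols = np.array([1, 3, 0])
A = coo_array((data, (rows, cols)), shape=(3, 4))
At = coo_array((data, (cols, rows)), shape=(4, 3))  # indices swapped, shape flipped
assert np.allclose(A.toarray().T, At.toarray())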
31 changes: 22 additions & 9 deletions cola/ops/operators.py
@@ -2,6 +2,7 @@

 import numpy as np
 from plum import parametric
+from scipy.sparse import coo_array

 import cola
 from cola.backends import get_library_fns
@@ -45,28 +46,40 @@ def __init__(self, A: Array, lower=True):


 class Sparse(LinearOperator):
-    """ Sparse CSR linear operator.
+    """ Sparse linear operator.

     Args:
         data (array_like): 1-D array representing the nonzero values of the sparse matrix.
-        indices (array_like): 1-D array representing the column indices of the nonzero values.
-        indptr (array_like): 1-D array representing the index pointers for the rows of the matrix.
+        row_indices (array_like): 1-D array representing the row indices of the nonzero values.
+        col_indices (array_like): 1-D array representing the column indices of the nonzero values.
         shape (tuple): Shape of the sparse matrix.

     Example:
         >>> data = jnp.array([1, 2, 3, 4, 5, 6])
-        >>> indices = jnp.array([0, 2, 1, 0, 2, 1])
-        >>> indptr = jnp.array([0, 2, 4, 6])
-        >>> shape = (3, 3)
-        >>> op = Sparse(data, indices, indptr, shape)
+        >>> row_indices = jnp.array([0, 0, 1, 2, 2, 2])
+        >>> col_indices = jnp.array([1, 3, 3, 0, 1, 2])
+        >>> shape = (3, 4)
+        >>> op = Sparse(data, row_indices, col_indices, shape)
     """
-    def __init__(self, data, indices, indptr, shape):
+    def __init__(self, data, row_indices, col_indices, shape):
         super().__init__(dtype=data.dtype, shape=shape)
-        self.A = self.xnp.sparse_csr(indptr, indices, data)
+        xnp = self.xnp
+        indx = xnp.argsort(row_indices)
+        self.data = data[indx]
+        self.row_indices = row_indices[indx]
+        self.col_indices = col_indices[indx]
+        A = coo_array((xnp.to_np(self.data), (xnp.to_np(self.row_indices), xnp.to_np(self.col_indices))),
+                      shape=shape).tocsr()
+        row_pointers = xnp.array(A.indptr, dtype=xnp.int32, device=data.device)
+        indices = xnp.array(A.indices, dtype=xnp.int32, device=data.device)
+        self.A = xnp.sparse_csr(row_pointers, indices, self.data, shape)

     def _matmat(self, V):
         return self.A @ V

+    def _rmatmat(self, V):
+        return (self.T @ V.T).T
+

 class ScalarMul(LinearOperator):
     """ Linear Operator representing scalar multiplication"""
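Taken together, the constructor sorts the triplets by row, borrows scipy's COO-to-CSR conversion on the host to derive the row pointers, and then assembles a backend-native CSR tensor on the original device. An end-to-end sketch, assuming cola with the torch backend is installed (the matrix matches the docstring example):

import torch
from cola.ops import Sparse

data = torch.tensor([1., 2., 3., 4., 5., 6.])
row_indices = torch.tensor([0, 0, 1, 2, 2, 2])
col_indices = torch.tensor([1, 3, 3, 0, 1, 2])
op = Sparse(data, row_indices, col_indices, shape=(3, 4))

v = torch.ones(4)
print(op @ v)                # [3., 3., 15.] via the backend CSR tensor
print(op.T @ torch.ones(3))  # [4., 6., 6., 5.], using the swapped-index transpose rule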
33 changes: 24 additions & 9 deletions tests/test_operators.py
@@ -127,19 +127,34 @@ def f(x1, x2):
     assert diff < 1e-10


-@parametrize(['torch'])
+@parametrize(tracing_backends)
 def test_sparse(backend):
     xnp = get_xnp(backend)
     dtype = xnp.float32
-    A = [[0., 1., 0., 0., 0.], [0., 2., -1., 0., 0.], [0., 0., 0., 0., 0.], [6.6, 0., 0., 0., 1.4]]
+    A = [[0., 1., 0., 2.], [0., 0., 0., 3.], [4., 5., 6., 0.]]
     A = xnp.array(A, dtype=dtype, device=None)
-    data = xnp.array([1., 2., -1., 6.6, 1.4], dtype=dtype, device=None)
-    indices = xnp.array([1, 1, 2, 0, 4], dtype=xnp.int64, device=None)
-    indptr = xnp.array([0, 1, 3, 3, 5], dtype=xnp.int64, device=None)
-    shape = (4, 5)
-    As = Sparse(data, indices, indptr, shape)
-    x = xnp.array([0.29466099, 0.71853315, -0.06172857, -0.0432496, 0.44698924], dtype=dtype, device=None)
-    rel_error = relative_error(A @ x, As @ x)
+    x1 = xnp.array([1., 0., -1., 1.], dtype=dtype, device=None)
+    soln1 = xnp.array([2., 3., -2.], dtype=dtype, device=None)
+    x2 = xnp.array([1., 2., 3.], dtype=dtype, device=None)
+    soln2 = xnp.array([12., 16., 18., 8.], dtype=dtype, device=None)
+
+    data = xnp.array([4., 5., 6., 1., 2., 3.], dtype=dtype, device=None)
+    if backend == "torch":
+        data.requires_grad = True
+    row_indices = xnp.array([2, 2, 2, 0, 0, 1], dtype=xnp.int64, device=None)
+    col_indices = xnp.array([0, 1, 2, 1, 3, 3], dtype=xnp.int64, device=None)
+    Aop = Sparse(data, row_indices, col_indices, shape=(3, 4))
+
+    rel_error = relative_error(Aop @ x1, soln1)
     assert rel_error < _tol
+
+    rel_error = relative_error(x2 @ Aop, soln2)
+    assert rel_error < _tol
+
+    rel_error = relative_error(Aop.to_dense(), A)
+    assert rel_error < _tol
+
+    rel_error = relative_error(Aop.T.to_dense(), A.T)
+    assert rel_error < _tol
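The requires_grad flag in the torch branch is what the backend's to_np (with its detach()) exists for: only the integer index structure takes the scipy round-trip, while the values tensor is handed straight to torch.sparse_csr_tensor and stays on the autograd tape. A sketch of what that enables, assuming a PyTorch version with autograd coverage for sparse-CSR matmul (this usage goes beyond what the test itself checks):

import torch
from cola.ops import Sparse

data = torch.tensor([4., 5., 6., 1., 2., 3.], requires_grad=True)
row_indices = torch.tensor([2, 2, 2, 0, 0, 1])
col_indices = torch.tensor([0, 1, 2, 1, 3, 3])
op = Sparse(data, row_indices, col_indices, shape=(3, 4))

loss = (op @ torch.ones(4)).sum()
loss.backward()
print(data.grad)  # one gradient entry per stored nonzero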