Improvements on Sparse API. (#98)
The sparse API now takes `row_indices`, `col_indices` and `data`, which makes it
easier for the user to construct the `Sparse` operator. This is the COO ("ijv")
format. Since COO is not the most efficient layout for matrix multiplication, it is
transformed internally to the CSR format, which is what the API previously required
from the user.
AndPotap authored Aug 29, 2024
1 parent d9e12c0 commit 50dd4cc
Showing 5 changed files with 73 additions and 28 deletions.
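
A minimal sketch of the COO-to-CSR conversion described in the commit message, using
the same scipy.sparse routine the new `Sparse.__init__` calls internally; the values
are taken from the new docstring example below:

    import numpy as np
    from scipy.sparse import coo_array

    # COO ("ijv") triple: one (row, col, value) entry per nonzero
    data = np.array([1., 2., 3., 4., 5., 6.])
    row_indices = np.array([0, 0, 1, 2, 2, 2])
    col_indices = np.array([1, 3, 3, 0, 1, 2])

    # scipy derives the CSR row pointers the old API asked the user for directly
    A = coo_array((data, (row_indices, col_indices)), shape=(3, 4)).tocsr()
    print(A.indptr)   # [0 2 3 6] -> row pointers (the old `indptr` argument)
    print(A.indices)  # [1 3 3 0 1 2] -> column indices, grouped by row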
11 changes: 7 additions & 4 deletions cola/backends/jax_fns.py
@@ -5,7 +5,7 @@
 from jax import grad, jit, vjp, vmap
 from jax import numpy as jnp
 from jax import tree_util as tu
-from jax.experimental.sparse import CSR
+from jax.experimental.sparse import BCSR
 from jax.lax import conj as conj_lax
 from jax.lax import dynamic_slice, expand_dims
 from jax.lax import fori_loop as _for_loop
@@ -94,9 +94,12 @@
 finfo = jnp.finfo
 
 
-def sparse_csr(indptr, indices, data):
-    N = indptr.shape[0] - 1
-    out = CSR((data, indices, indptr), shape=(N, N))
+def to_np(array):
+    return jax.device_get(array)
+
+
+def sparse_csr(row_pointers, col_indices, data, shape):
+    out = BCSR((data, col_indices, row_pointers), shape=shape)
     return out
 
 
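For reference, a hedged sketch of how the new jax helper would be exercised: BCSR is
built from (data, column indices, row pointers), mirroring the `sparse_csr` call
above. The values correspond to the 3x4 example used elsewhere in this commit, and
the `@` matvec assumes a jax version where BCSR supports dense matmul:

    import jax.numpy as jnp
    from jax.experimental.sparse import BCSR

    data = jnp.array([1., 2., 3., 4., 5., 6.])
    col_indices = jnp.array([1, 3, 3, 0, 1, 2])
    row_pointers = jnp.array([0, 2, 3, 6])

    A = BCSR((data, col_indices, row_pointers), shape=(3, 4))
    x = jnp.array([1., 0., -1., 1.])
    print(A @ x)  # expected [2., 3., -2.]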
19 changes: 14 additions & 5 deletions cola/backends/torch_fns.py
@@ -1,11 +1,13 @@
 import hashlib
 import logging
-import torch
 
 import optree
-from torch.nn import Parameter
-from torch.func import vjp, jvp
-from torch.func import vmap as _vmap
+import torch
 from torch.func import grad as _grad
+from torch.func import jvp, vjp
+from torch.func import vmap as _vmap
+from torch.nn import Parameter
 
 from cola.utils.torch_tqdm import while_loop_winfo
 
 Parameter = Parameter
@@ -52,7 +54,6 @@
 is_array = torch.is_tensor
 autograd = torch.autograd
 argsort = torch.argsort
-sparse_csr = torch.sparse_csr_tensor
 roll = torch.roll
 maximum = torch.maximum
 isreal = torch.isreal
@@ -67,6 +68,14 @@
 iscomplexobj = torch.is_complex
 
 
+def to_np(array):
+    return array.detach().cpu().numpy()
+
+
+def sparse_csr(row_pointers, col_indices, data, shape):
+    return torch.sparse_csr_tensor(crow_indices=row_pointers, col_indices=col_indices, values=data, size=shape)
+
+
 def norm(array, axis=None, keepdims=False, ord=None):
     return torch.linalg.norm(array, dim=axis, keepdim=keepdims, ord=ord)
 
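Similarly, a small sketch of the torch path: `torch.sparse_csr_tensor` builds the
CSR tensor the operator multiplies with. Same 3x4 values as the tests; the `@`
matvec assumes a PyTorch version with sparse-CSR matmul support:

    import torch

    row_pointers = torch.tensor([0, 2, 3, 6])
    col_indices = torch.tensor([1, 3, 3, 0, 1, 2])
    data = torch.tensor([1., 2., 3., 4., 5., 6.])

    A = torch.sparse_csr_tensor(crow_indices=row_pointers, col_indices=col_indices,
                                values=data, size=(3, 4))
    x = torch.tensor([1., 0., -1., 1.])
    print(A @ x)  # expected tensor([ 2.,  3., -2.])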
7 changes: 6 additions & 1 deletion cola/fns.py
@@ -7,7 +7,7 @@
 from cola.ops import LinearOperator, Array
 from cola.ops import Dense
 from cola.ops import Kronecker, Product, KronSum, Sum
-from cola.ops import ScalarMul, Transpose, Adjoint
+from cola.ops import ScalarMul, Transpose, Adjoint, Sparse
 from cola.ops import BlockDiag, Diagonal, Triangular, Identity
 from cola.utils import export
 import cola
@@ -135,6 +135,11 @@ def transpose(A: Triangular):
     return Triangular(A.A.T, lower=not A.lower)
 
 
+@dispatch
+def transpose(A: Sparse):
+    return Sparse(A.data, A.col_indices, A.row_indices, shape=(A.shape[1], A.shape[0]))
+
+
 @dispatch
 def adjoint(A: LinearOperator):
     return Adjoint(A)
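The new `transpose` rule works because transposing a COO matrix is just swapping
the row and column index arrays (and the shape). A quick scipy sketch to illustrate,
not part of the commit:

    import numpy as np
    from scipy.sparse import coo_array

    data = np.array([1., 2., 3., 4., 5., 6.])
    rows = np.array([0, 0, 1, 2, 2, 2])
    cols = np.array([1, 3, 3, 0, 1, 2])

    A = coo_array((data, (rows, cols)), shape=(3, 4))
    At = coo_array((data, (cols, rows)), shape=(4, 3))  # index arrays swapped
    assert (A.toarray().T == At.toarray()).all()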
31 changes: 22 additions & 9 deletions cola/ops/operators.py
@@ -2,6 +2,7 @@
 
 import numpy as np
 from plum import parametric
+from scipy.sparse import coo_array
 
 import cola
 from cola.backends import get_library_fns
@@ -45,28 +46,40 @@ def __init__(self, A: Array, lower=True):


 class Sparse(LinearOperator):
-    """ Sparse CSR linear operator.
+    """ Sparse linear operator.
     Args:
         data (array_like): 1-D array representing the nonzero values of the sparse matrix.
-        indices (array_like): 1-D array representing the column indices of the nonzero values.
-        indptr (array_like): 1-D array representing the index pointers for the rows of the matrix.
+        row_indices (array_like): 1-D array representing the row indices of the nonzero values.
+        col_indices (array_like): 1-D array representing the column indices of the nonzero values.
         shape (tuple): Shape of the sparse matrix.
     Example:
         >>> data = jnp.array([1, 2, 3, 4, 5, 6])
-        >>> indices = jnp.array([0, 2, 1, 0, 2, 1])
-        >>> indptr = jnp.array([0, 2, 4, 6])
-        >>> shape = (3, 3)
-        >>> op = Sparse(data, indices, indptr, shape)
+        >>> row_indices = jnp.array([0, 0, 1, 2, 2, 2])
+        >>> col_indices = jnp.array([1, 3, 3, 0, 1, 2])
+        >>> shape = (3, 4)
+        >>> op = Sparse(data, row_indices, col_indices, shape)
     """
-    def __init__(self, data, indices, indptr, shape):
+    def __init__(self, data, row_indices, col_indices, shape):
         super().__init__(dtype=data.dtype, shape=shape)
-        self.A = self.xnp.sparse_csr(indptr, indices, data)
+        xnp = self.xnp
+        indx = xnp.argsort(row_indices)
+        self.data = data[indx]
+        self.row_indices = row_indices[indx]
+        self.col_indices = col_indices[indx]
+        A = coo_array((xnp.to_np(self.data), (xnp.to_np(self.row_indices), xnp.to_np(self.col_indices))),
+                      shape=shape).tocsr()
+        row_pointers = xnp.array(A.indptr, dtype=xnp.int32, device=data.device)
+        indices = xnp.array(A.indices, dtype=xnp.int32, device=data.device)
+        self.A = xnp.sparse_csr(row_pointers, indices, self.data, shape)
 
     def _matmat(self, V):
         return self.A @ V
 
+    def _rmatmat(self, V):
+        return (self.T @ V.T).T
+
 
 class ScalarMul(LinearOperator):
     """ Linear Operator representing scalar multiplication"""
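A hedged walk-through of what the new `__init__` does to the unsorted COO triple
used in the tests below: entries are first sorted by row (the commit uses the
backend's `argsort`; a stable sort is shown here), then scipy's COO-to-CSR
conversion supplies the row pointers handed to the backend `sparse_csr`:

    import numpy as np
    from scipy.sparse import coo_array

    data = np.array([4., 5., 6., 1., 2., 3.])
    row_indices = np.array([2, 2, 2, 0, 0, 1])
    col_indices = np.array([0, 1, 2, 1, 3, 3])

    indx = np.argsort(row_indices, kind='stable')  # [3, 4, 5, 0, 1, 2]
    data, rows, cols = data[indx], row_indices[indx], col_indices[indx]
    A = coo_array((data, (rows, cols)), shape=(3, 4)).tocsr()
    print(A.indptr)   # [0 2 3 6] -> row_pointers
    print(A.indices)  # [1 3 3 0 1 2] -> column indices, row-sorted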
33 changes: 24 additions & 9 deletions tests/test_operators.py
@@ -127,19 +127,34 @@ def f(x1, x2):
     assert diff < 1e-10
 
 
-@parametrize(['torch'])
+@parametrize(tracing_backends)
 def test_sparse(backend):
     xnp = get_xnp(backend)
     dtype = xnp.float32
-    A = [[0., 1., 0., 0., 0.], [0., 2., -1., 0., 0.], [0., 0., 0., 0., 0.], [6.6, 0., 0., 0., 1.4]]
+    A = [[0., 1., 0., 2.], [0., 0., 0., 3.], [4., 5., 6., 0.]]
     A = xnp.array(A, dtype=dtype, device=None)
-    data = xnp.array([1., 2., -1., 6.6, 1.4], dtype=dtype, device=None)
-    indices = xnp.array([1, 1, 2, 0, 4], dtype=xnp.int64, device=None)
-    indptr = xnp.array([0, 1, 3, 3, 5], dtype=xnp.int64, device=None)
-    shape = (4, 5)
-    As = Sparse(data, indices, indptr, shape)
-    x = xnp.array([0.29466099, 0.71853315, -0.06172857, -0.0432496, 0.44698924], dtype=dtype, device=None)
-    rel_error = relative_error(A @ x, As @ x)
+    x1 = xnp.array([1., 0., -1., 1.], dtype=dtype, device=None)
+    soln1 = xnp.array([2., 3., -2.], dtype=dtype, device=None)
+    x2 = xnp.array([1., 2., 3.], dtype=dtype, device=None)
+    soln2 = xnp.array([12., 16., 18., 8.], dtype=dtype, device=None)
+
+    data = xnp.array([4., 5., 6., 1., 2., 3.], dtype=dtype, device=None)
+    if backend == "torch":
+        data.requires_grad = True
+    row_indices = xnp.array([2, 2, 2, 0, 0, 1], dtype=xnp.int64, device=None)
+    col_indices = xnp.array([0, 1, 2, 1, 3, 3], dtype=xnp.int64, device=None)
+    Aop = Sparse(data, row_indices, col_indices, shape=(3, 4))
+
+    rel_error = relative_error(Aop @ x1, soln1)
     assert rel_error < _tol
+
+    rel_error = relative_error(x2 @ Aop, soln2)
+    assert rel_error < _tol
+
+    rel_error = relative_error(Aop.to_dense(), A)
+    assert rel_error < _tol
+
+    rel_error = relative_error(Aop.T.to_dense(), A.T)
+    assert rel_error < _tol


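A quick sanity check of the expected test values with plain numpy (not part of the
commit); `x2 @ Aop` is the case that exercises the new `_rmatmat`:

    import numpy as np

    A = np.array([[0., 1., 0., 2.], [0., 0., 0., 3.], [4., 5., 6., 0.]])
    x1 = np.array([1., 0., -1., 1.])
    x2 = np.array([1., 2., 3.])
    print(A @ x1)  # [ 2.  3. -2.]       -> soln1
    print(x2 @ A)  # [12. 16. 18.  8.]   -> soln2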
