Skip to content

Commit

Permalink
[Sparse] Support conversion to/from torch sparse tensor.
Browse files Browse the repository at this point in the history
  • Loading branch information
czkkkkkk committed Feb 27, 2023
1 parent c396942 commit 4f0b446
Show file tree
Hide file tree
Showing 7 changed files with 334 additions and 39 deletions.
7 changes: 4 additions & 3 deletions dgl_sparse/include/sparse/sparse_matrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,15 +79,14 @@ class SparseMatrix : public torch::CustomClassHolder {

/**
* @brief Create a SparseMatrix from tensors in COO format.
* @param row Row indices of the COO.
* @param col Column indices of the COO.
* @param indices COO coordinates with shape (2, nnz).
* @param value Values of the sparse matrix.
* @param shape Shape of the sparse matrix.
*
* @return SparseMatrix
*/
static c10::intrusive_ptr<SparseMatrix> FromCOO(
torch::Tensor row, torch::Tensor col, torch::Tensor value,
torch::Tensor indices, torch::Tensor value,
const std::vector<int64_t>& shape);

/**
Expand Down Expand Up @@ -153,6 +152,8 @@ class SparseMatrix : public torch::CustomClassHolder {

/**
 * @return {row, col} tensors in the COO format: the row-index tensor first,
 * followed by the column-index tensor.
 */
std::tuple<torch::Tensor, torch::Tensor> COOTensors();
/**
 * @return Stacked row and col tensors in the COO format, as one tensor with
 * shape (2, nnz): entry 0 holds the row indices and entry 1 the column
 * indices.
 */
torch::Tensor Indices();
/**
 * @return {row, col, value_indices} tensors in the CSR format. The third
 * element is optional and may be absent; when present, callers index the
 * value tensor with it to obtain values in CSR order.
 */
std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>>
CSRTensors();
Expand Down
17 changes: 4 additions & 13 deletions dgl_sparse/src/elemenwise_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,10 @@ c10::intrusive_ptr<SparseMatrix> SpSpAdd(
const c10::intrusive_ptr<SparseMatrix>& A,
const c10::intrusive_ptr<SparseMatrix>& B) {
ElementwiseOpSanityCheck(A, B);
torch::Tensor sum;
{
// TODO(#5145) This is a workaround to reduce peak memory usage. It is no
// longer needed after we address #5145.
auto torch_A = COOToTorchCOO(A->COOPtr(), A->value());
auto torch_B = COOToTorchCOO(B->COOPtr(), B->value());
sum = torch_A + torch_B;
}
sum = sum.coalesce();
auto indices = sum.indices();
auto row = indices[0];
auto col = indices[1];
return SparseMatrix::FromCOO(row, col, sum.values(), A->shape());
auto torch_A = COOToTorchCOO(A->COOPtr(), A->value());
auto torch_B = COOToTorchCOO(B->COOPtr(), B->value());
auto sum = (torch_A + torch_B).coalesce();
return SparseMatrix::FromCOO(sum.indices(), sum.values(), A->shape());
}

} // namespace sparse
Expand Down
1 change: 1 addition & 0 deletions dgl_sparse/src/python_binding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ TORCH_LIBRARY(dgl_sparse, m) {
.def("device", &SparseMatrix::device)
.def("shape", &SparseMatrix::shape)
.def("coo", &SparseMatrix::COOTensors)
.def("indices", &SparseMatrix::Indices)
.def("csr", &SparseMatrix::CSRTensors)
.def("csc", &SparseMatrix::CSCTensors)
.def("transpose", &SparseMatrix::Transpose)
Expand Down
12 changes: 8 additions & 4 deletions dgl_sparse/src/sparse_matrix.cc
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,10 @@ c10::intrusive_ptr<SparseMatrix> SparseMatrix::FromCSCPointer(
}

c10::intrusive_ptr<SparseMatrix> SparseMatrix::FromCOO(
torch::Tensor row, torch::Tensor col, torch::Tensor value,
torch::Tensor indices, torch::Tensor value,
const std::vector<int64_t>& shape) {
auto coo = std::make_shared<COO>(
COO{shape[0], shape[1], torch::stack({row, col}), false, false});
auto coo =
std::make_shared<COO>(COO{shape[0], shape[1], indices, false, false});
return SparseMatrix::FromCOOPointer(coo, value, shape);
}

Expand Down Expand Up @@ -138,10 +138,14 @@ std::shared_ptr<CSR> SparseMatrix::CSCPtr() {

/**
 * @brief Return the COO row and column index tensors as separate tensors.
 *
 * The COO representation is obtained through COOPtr(); its stacked `indices`
 * tensor is split into entry 0 (rows) and entry 1 (columns). The value tensor
 * is not part of the result, so it is not fetched here.
 *
 * @return {row, col} tensors in the COO format.
 */
std::tuple<torch::Tensor, torch::Tensor> SparseMatrix::COOTensors() {
  auto coo = COOPtr();
  return std::make_tuple(coo->indices.index({0}), coo->indices.index({1}));
}

/**
 * @brief Return the stacked COO index tensor of this matrix.
 *
 * @return The `indices` tensor held by the COO representation obtained from
 * COOPtr() (row indices in entry 0, column indices in entry 1).
 */
torch::Tensor SparseMatrix::Indices() {
  return COOPtr()->indices;
}

std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>>
SparseMatrix::CSRTensors() {
auto csr = CSRPtr();
Expand Down
6 changes: 2 additions & 4 deletions dgl_sparse/src/sparse_matrix_coalesce.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,8 @@ namespace sparse {
c10::intrusive_ptr<SparseMatrix> SparseMatrix::Coalesce() {
auto torch_coo = COOToTorchCOO(this->COOPtr(), this->value());
auto coalesced_coo = torch_coo.coalesce();
torch::Tensor indices = coalesced_coo.indices();
torch::Tensor row = indices[0];
torch::Tensor col = indices[1];
return SparseMatrix::FromCOO(row, col, coalesced_coo.values(), this->shape());
return SparseMatrix::FromCOO(
coalesced_coo.indices(), coalesced_coo.values(), this->shape());
}

bool SparseMatrix::HasDuplicate() {
Expand Down
216 changes: 202 additions & 14 deletions python/dgl/sparse/sparse_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,12 +109,35 @@ def coo(self) -> Tuple[torch.Tensor, torch.Tensor]:
--------
>>> indices = torch.tensor([[1, 2, 1], [2, 4, 3]])
>>> A = from_coo(dst, src)
>>> A = dglsp.spmatrix(indices)
>>> A.coo()
(tensor([1, 2, 1]), tensor([2, 4, 3]))
"""
return self.c_sparse_matrix.coo()

def indices(self) -> torch.Tensor:
    r"""Returns the coordinate list (COO) representation of the sparse
    matrix as a single tensor with shape ``(2, nnz)``.

    See `COO in Wikipedia <https://en.wikipedia.org/wiki/
    Sparse_matrix#Coordinate_list_(COO)>`_.

    Returns
    -------
    torch.Tensor
        Stacked COO tensor with shape ``(2, nnz)``.

    Examples
    --------

    >>> indices = torch.tensor([[1, 2, 1], [2, 4, 3]])
    >>> A = dglsp.spmatrix(indices)
    >>> A.indices()
    tensor([[1, 2, 1],
            [2, 4, 3]])
    """
    # Delegate to the underlying C++ sparse matrix object.
    backend = self.c_sparse_matrix
    return backend.indices()

def csr(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
r"""Returns the compressed sparse row (CSR) representation of the sparse
matrix.
Expand All @@ -140,7 +163,7 @@ def csr(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
--------
>>> indices = torch.tensor([[1, 2, 1], [2, 4, 3]])
>>> A = from_coo(dst, src)
>>> A = dglsp.spmatrix(indices)
>>> A.csr()
(tensor([0, 0, 2, 3]), tensor([2, 3, 4]), tensor([0, 2, 1]))
"""
Expand Down Expand Up @@ -171,7 +194,7 @@ def csc(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
--------
>>> indices = torch.tensor([[1, 2, 1], [2, 4, 3]])
>>> A = from_coo(dst, src)
>>> A = dglsp.spmatrix(indices)
>>> A.csc()
(tensor([0, 0, 0, 1, 2, 3]), tensor([1, 1, 2]), tensor([0, 2, 1]))
"""
Expand Down Expand Up @@ -521,7 +544,18 @@ def spmatrix(
[3., 3.]]),
shape=(3, 5), nnz=3, val_size=(2,))
"""
return from_coo(indices[0], indices[1], val, shape)
if shape is None:
shape = (
torch.max(indices[0]).item() + 1,
torch.max(indices[1]).item() + 1,
)
if val is None:
val = torch.ones(indices.shape[1]).to(indices.device)

assert (
val.dim() <= 2
), "The values of a SparseMatrix can only be scalars or vectors."
return SparseMatrix(torch.ops.dgl_sparse.from_coo(indices, val, shape))


def from_coo(
Expand Down Expand Up @@ -599,16 +633,8 @@ def from_coo(
[3., 3.]]),
shape=(3, 5), nnz=3, val_size=(2,))
"""
if shape is None:
shape = (torch.max(row).item() + 1, torch.max(col).item() + 1)
if val is None:
val = torch.ones(row.shape[0]).to(row.device)

assert (
val.dim() <= 2
), "The values of a SparseMatrix can only be scalars or vectors."

return SparseMatrix(torch.ops.dgl_sparse.from_coo(row, col, val, shape))
assert row.shape[0] == col.shape[0]
return spmatrix(torch.stack([row, col]), val, shape)


def from_csr(
Expand Down Expand Up @@ -833,6 +859,168 @@ def val_like(mat: SparseMatrix, val: torch.Tensor) -> SparseMatrix:
return SparseMatrix(torch.ops.dgl_sparse.val_like(mat.c_sparse_matrix, val))


def from_torch_sparse(torch_sparse_tensor: torch.Tensor) -> SparseMatrix:
    """Creates a sparse matrix from a torch sparse tensor, which can have coo,
    csr, or csc layout.

    Parameters
    ----------
    torch_sparse_tensor : torch.Tensor
        Torch sparse tensor

    Returns
    -------
    SparseMatrix
        Sparse matrix

    Raises
    ------
    AssertionError
        If the layout of ``torch_sparse_tensor`` is not one of
        ``torch.sparse_coo``, ``torch.sparse_csr`` or ``torch.sparse_csc``.

    Examples
    --------

    >>> indices = torch.tensor([[1, 1, 2], [2, 4, 3]])
    >>> val = torch.ones(3)
    >>> torch_coo = torch.sparse_coo_tensor(indices, val)
    >>> dglsp.from_torch_sparse(torch_coo)
    SparseMatrix(indices=tensor([[1, 1, 2],
                                 [2, 4, 3]]),
                 values=tensor([1., 1., 1.]),
                 shape=(3, 5), nnz=3)
    """
    layout = torch_sparse_tensor.layout
    # Only the first two dimensions form the sparse matrix shape; any
    # remaining dimensions belong to the values.
    shape = torch_sparse_tensor.shape[:2]
    if layout == torch.sparse_coo:
        # Use ._indices() and ._values() to access uncoalesced indices and
        # values.
        return spmatrix(
            torch_sparse_tensor._indices(),
            torch_sparse_tensor._values(),
            shape,
        )
    if layout == torch.sparse_csr:
        return from_csr(
            torch_sparse_tensor.crow_indices(),
            torch_sparse_tensor.col_indices(),
            torch_sparse_tensor.values(),
            shape,
        )
    if layout == torch.sparse_csc:
        return from_csc(
            torch_sparse_tensor.ccol_indices(),
            torch_sparse_tensor.row_indices(),
            torch_sparse_tensor.values(),
            shape,
        )
    # Raise explicitly instead of `assert False` so that an unsupported layout
    # is still rejected when Python runs with assertions disabled (-O).
    raise AssertionError(
        "Cannot convert Pytorch sparse tensor "
        f"{torch_sparse_tensor} to DGL sparse."
    )


def to_torch_sparse_coo(spmat: SparseMatrix) -> torch.Tensor:
    """Creates a torch sparse coo tensor from a sparse matrix.

    Parameters
    ----------
    spmat : SparseMatrix
        Sparse matrix

    Returns
    -------
    torch.Tensor
        torch tensor with torch.sparse_coo layout

    Examples
    --------

    >>> indices = torch.tensor([[1, 1, 2], [2, 4, 3]])
    >>> val = torch.ones(3)
    >>> spmat = dglsp.spmatrix(indices, val)
    >>> dglsp.to_torch_sparse_coo(spmat)
    tensor(indices=tensor([[1, 1, 2],
                           [2, 4, 3]]),
           values=tensor([1., 1., 1.]),
           size=(3, 5), nnz=3, layout=torch.sparse_coo)
    """
    full_shape = spmat.shape
    values = spmat.val
    # Vector-valued entries contribute their trailing dimensions to the
    # shape of the torch tensor.
    if values.dim() > 1:
        full_shape += values.shape[1:]
    coordinates = spmat.indices()
    return torch.sparse_coo_tensor(coordinates, values, full_shape)


def to_torch_sparse_csr(spmat: SparseMatrix) -> torch.Tensor:
    """Creates a torch sparse csr tensor from a sparse matrix.

    Note that converting a sparse matrix to torch csr tensor could change the
    order of non-zero values.

    Parameters
    ----------
    spmat : SparseMatrix
        Sparse matrix

    Returns
    -------
    torch.Tensor
        Torch tensor with torch.sparse_csr layout

    Examples
    --------

    >>> indices = torch.tensor([[1, 2, 1], [2, 4, 3]])
    >>> val = torch.arange(3)
    >>> spmat = dglsp.spmatrix(indices, val)
    >>> dglsp.to_torch_sparse_csr(spmat)
    tensor(crow_indices=tensor([0, 0, 2, 3]),
           col_indices=tensor([2, 3, 4]),
           values=tensor([0, 2, 1]), size=(3, 5), nnz=3,
           layout=torch.sparse_csr)
    """
    full_shape = spmat.shape
    # Vector-valued entries contribute their trailing dimensions to the
    # shape of the torch tensor.
    if spmat.val.dim() > 1:
        full_shape += spmat.val.shape[1:]
    crow, col, order = spmat.csr()
    values = spmat.val
    # When an index mapping is returned, reorder the values to CSR order.
    ordered = values if order is None else values[order]
    return torch.sparse_csr_tensor(crow, col, ordered, full_shape)


def to_torch_sparse_csc(spmat: SparseMatrix) -> torch.Tensor:
    """Creates a torch sparse csc tensor from a sparse matrix.

    Note that converting a sparse matrix to torch csc tensor could change the
    order of non-zero values.

    Parameters
    ----------
    spmat : SparseMatrix
        Sparse matrix

    Returns
    -------
    torch.Tensor
        Torch tensor with torch.sparse_csc layout

    Examples
    --------

    >>> indices = torch.tensor([[1, 2, 1], [2, 4, 3]])
    >>> val = torch.arange(3)
    >>> spmat = dglsp.spmatrix(indices, val)
    >>> dglsp.to_torch_sparse_csc(spmat)
    tensor(ccol_indices=tensor([0, 0, 0, 1, 2, 3]),
           row_indices=tensor([1, 1, 2]),
           values=tensor([0, 2, 1]), size=(3, 5), nnz=3,
           layout=torch.sparse_csc)
    """
    full_shape = spmat.shape
    # Vector-valued entries contribute their trailing dimensions to the
    # shape of the torch tensor.
    if spmat.val.dim() > 1:
        full_shape += spmat.val.shape[1:]
    ccol, row, order = spmat.csc()
    values = spmat.val
    # When an index mapping is returned, reorder the values to CSC order.
    ordered = values if order is None else values[order]
    return torch.sparse_csc_tensor(ccol, row, ordered, full_shape)


def _sparse_matrix_str(spmat: SparseMatrix) -> str:
"""Internal function for converting a sparse matrix to string
representation.
Expand Down
Loading

0 comments on commit 4f0b446

Please sign in to comment.