diff --git a/test/nn/conv/test_message_passing.py b/test/nn/conv/test_message_passing.py
index 324c2d7acaff..b55f412a33d6 100644
--- a/test/nn/conv/test_message_passing.py
+++ b/test/nn/conv/test_message_passing.py
@@ -47,28 +47,30 @@ def message_and_aggregate(self, adj_t: SparseTensor,
         return spmm(adj_t, x[0], reduce=self.aggr)
 
 
-def test_my_conv():
+def test_my_conv_basic():
     x1 = torch.randn(4, 8)
     x2 = torch.randn(2, 16)
     edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
     row, col = edge_index
     value = torch.randn(row.size(0))
     adj = SparseTensor(row=row, col=col, value=value, sparse_sizes=(4, 4))
-    torch_adj = adj.to_torch_sparse_coo_tensor()
+    torch_adj_t = adj.to_torch_sparse_csr_tensor().t()
+    torch_adj_t = torch_adj_t.to_sparse(layout=torch.sparse_csr)
 
     conv = MyConv(8, 32)
     out = conv(x1, edge_index, value)
     assert out.size() == (4, 32)
     assert torch.allclose(conv(x1, edge_index, value, (4, 4)), out, atol=1e-6)
     assert torch.allclose(conv(x1, adj.t()), out, atol=1e-6)
-    assert torch.allclose(conv(x1, torch_adj.t()), out, atol=1e-6)
+    assert torch.allclose(conv(x1, torch_adj_t), out, atol=1e-6)
     conv.fuse = False
     assert torch.allclose(conv(x1, adj.t()), out)
-    assert torch.allclose(conv(x1, torch_adj.t()), out, atol=1e-6)
+    assert torch.allclose(conv(x1, torch_adj_t), out, atol=1e-6)
     conv.fuse = True
 
     adj = adj.sparse_resize((4, 2))
-    torch_adj = adj.to_torch_sparse_coo_tensor()
+    torch_adj_t = adj.to_torch_sparse_csr_tensor().t()
+    torch_adj_t = torch_adj_t.to_sparse(layout=torch.sparse_csr)
 
     conv = MyConv((8, 16), 32)
     out1 = conv((x1, x2), edge_index, value)
@@ -77,21 +79,21 @@ def test_my_conv():
     assert out2.size() == (2, 32)
     assert torch.allclose(conv((x1, x2), edge_index, value, (4, 2)), out1)
     assert torch.allclose(conv((x1, x2), adj.t()), out1)
-    assert torch.allclose(conv((x1, x2), torch_adj.t()), out1, atol=1e-6)
+    assert torch.allclose(conv((x1, x2), torch_adj_t), out1, atol=1e-6)
     assert torch.allclose(conv((x1, None), adj.t()), out2)
-    assert torch.allclose(conv((x1, None), torch_adj.t()), out2, atol=1e-6)
+    assert torch.allclose(conv((x1, None), torch_adj_t), out2, atol=1e-6)
     conv.fuse = False
     assert torch.allclose(conv((x1, x2), adj.t()), out1)
-    assert torch.allclose(conv((x1, x2), torch_adj.t()), out1, atol=1e-6)
+    assert torch.allclose(conv((x1, x2), torch_adj_t), out1, atol=1e-6)
     assert torch.allclose(conv((x1, None), adj.t()), out2)
-    assert torch.allclose(conv((x1, None), torch_adj.t()), out2, atol=1e-6)
+    assert torch.allclose(conv((x1, None), torch_adj_t), out2, atol=1e-6)
     conv.fuse = True
 
     # Test backward compatibility for `torch.sparse` tensors:
     conv.fuse = True
-    torch_adj = torch_adj.requires_grad_()
-    conv((x1, x2), torch_adj.t()).sum().backward()
-    assert torch_adj.grad is not None
+    torch_adj_t = torch_adj_t.requires_grad_()
+    conv((x1, x2), torch_adj_t).sum().backward()
+    assert torch_adj_t.grad is not None
 
 
 def test_my_conv_out_of_bounds():
@@ -211,7 +213,7 @@ def test_my_multiple_aggr_conv(multi_aggr_tuple):
     edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
     row, col = edge_index
     adj = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))
-    torch_adj = adj.to_torch_sparse_coo_tensor()
+    torch_adj = adj.to_torch_sparse_csr_tensor()
 
     conv = MyMultipleAggrConv(aggr_kwargs=aggr_kwargs)
     out = conv(x, edge_index)
@@ -280,7 +282,7 @@ def test_my_edge_conv():
     edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
     row, col = edge_index
     adj = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))
-    torch_adj = adj.to_torch_sparse_coo_tensor()
+    torch_adj = adj.to_torch_sparse_csr_tensor()
 
     expected = scatter(x[row] - x[col], col, dim=0, dim_size=4, reduce='sum')
 
@@ -443,7 +445,7 @@ def test_my_default_arg_conv():
     edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
     row, col = edge_index
     adj = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))
-    torch_adj = adj.to_torch_sparse_coo_tensor()
+    torch_adj = adj.to_torch_sparse_csr_tensor()
 
     conv = MyDefaultArgConv()
     assert conv(x, edge_index).view(-1).tolist() == [0, 0, 0, 0]
diff --git a/torch_geometric/nn/conv/message_passing.py b/torch_geometric/nn/conv/message_passing.py
index dda25a5b0135..54c85ab00ced 100644
--- a/torch_geometric/nn/conv/message_passing.py
+++ b/torch_geometric/nn/conv/message_passing.py
@@ -42,6 +42,11 @@
 FUSE_AGGRS = {'add', 'sum', 'mean', 'min', 'max'}
 
 
+def ptr2ind(ptr: Tensor) -> Tensor:
+    ind = torch.arange(ptr.numel() - 1, device=ptr.device)
+    return ind.repeat_interleave(ptr[1:] - ptr[:-1])
+
+
 class MessagePassing(torch.nn.Module):
     r"""Base class for creating message passing layers of the form
 
@@ -240,7 +245,21 @@ def __set_size__(self, size: List[Optional[int]], dim: int, src: Tensor):
     def __lift__(self, src, edge_index, dim):
         if is_torch_sparse_tensor(edge_index):
             assert dim == 0 or dim == 1
-            index = edge_index._indices()[1 - dim]
+            if edge_index.layout == torch.sparse_coo:
+                index = edge_index._indices()[1 - dim]
+            elif edge_index.layout == torch.sparse_csr:
+                if dim == 0:
+                    index = edge_index.col_indices()
+                else:
+                    index = ptr2ind(edge_index.crow_indices())
+            elif edge_index.layout == torch.sparse_csc:
+                if dim == 0:
+                    index = ptr2ind(edge_index.ccol_indices())
+                else:
+                    index = edge_index.row_indices()
+            else:
+                raise ValueError(f"Unsupported sparse tensor layout "
+                                 f"(got '{edge_index.layout}')")
             return src.index_select(self.node_dim, index)
 
         elif isinstance(edge_index, Tensor):
diff --git a/torch_geometric/utils/sparse.py b/torch_geometric/utils/sparse.py
index 0ffff33fda46..4fbf9e694f02 100644
--- a/torch_geometric/utils/sparse.py
+++ b/torch_geometric/utils/sparse.py
@@ -65,6 +65,8 @@ def is_torch_sparse_tensor(src: Any) -> bool:
         return True
     if src.layout == torch.sparse_csr:
         return True
+    if src.layout == torch.sparse_csc:
+        return True
     return False
 
 
diff --git a/torch_geometric/utils/spmm.py b/torch_geometric/utils/spmm.py
index 65c6ed33becd..28fc6a1013a5 100644
--- a/torch_geometric/utils/spmm.py
+++ b/torch_geometric/utils/spmm.py
@@ -55,10 +55,10 @@ def spmm(src: Adj, other: Tensor, reduce: str = "sum") -> Tensor:
     # This will currently throw on error for CUDA tensors.
     if torch_geometric.typing.WITH_PT2:
         if src.layout != torch.sparse_csr:
-            warnings.warn("Converting sparse tensor to CSR format for more "
-                          "efficient processing. Consider converting your "
-                          "sparse tensor to CSR format beforehand to avoid "
-                          "repeated conversion")
+            warnings.warn(f"Converting sparse tensor to CSR format for more "
+                          f"efficient processing. Consider converting your "
+                          f"sparse tensor to CSR format beforehand to avoid "
+                          f"repeated conversion (got '{src.layout}')")
         src = src.to_sparse_csr()
     if reduce == 'sum':
         return torch.sparse.mm(src, other)
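
For context, `ptr2ind` is the inverse of CSR/CSC index compression: it expands a compressed pointer vector (e.g. `crow_indices()`) into one index per stored element, which is how the new `__lift__` branches recover row indices for CSR and column indices for CSC layouts. Below is a minimal standalone sketch; the helper body is taken verbatim from the patch, while the 3x3 example tensor is made up purely for illustration:

import torch

def ptr2ind(ptr: torch.Tensor) -> torch.Tensor:
    # Repeat each slice index i once per stored element in that slice,
    # i.e. ptr[i + 1] - ptr[i] times.
    ind = torch.arange(ptr.numel() - 1, device=ptr.device)
    return ind.repeat_interleave(ptr[1:] - ptr[:-1])

# A 3x3 CSR tensor with non-zeros at (0, 1), (1, 0), (1, 2), (2, 2):
crow = torch.tensor([0, 1, 3, 4])  # row 0 holds 1 entry, row 1 holds 2, row 2 holds 1
col = torch.tensor([1, 0, 2, 2])
adj = torch.sparse_csr_tensor(crow, col, torch.ones(4), size=(3, 3))

# Recover per-element row indices from the compressed row pointer:
assert ptr2ind(adj.crow_indices()).tolist() == [0, 1, 1, 2]
# Column indices are already stored uncompressed:
assert adj.col_indices().tolist() == [1, 0, 2, 2]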