Commit 6914fb3: update
rusty1s committed Mar 17, 2023
1 parent d4b1297 commit 6914fb3
Showing 4 changed files with 43 additions and 20 deletions.
32 changes: 17 additions & 15 deletions test/nn/conv/test_message_passing.py
@@ -47,28 +47,30 @@ def message_and_aggregate(self, adj_t: SparseTensor,
        return spmm(adj_t, x[0], reduce=self.aggr)


def test_my_conv():
def test_my_conv_basic():
    x1 = torch.randn(4, 8)
    x2 = torch.randn(2, 16)
    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
    row, col = edge_index
    value = torch.randn(row.size(0))
    adj = SparseTensor(row=row, col=col, value=value, sparse_sizes=(4, 4))
    torch_adj = adj.to_torch_sparse_coo_tensor()
    torch_adj_t = adj.to_torch_sparse_csr_tensor().t()
    torch_adj_t = torch_adj_t.to_sparse(layout=torch.sparse_csr)

    conv = MyConv(8, 32)
    out = conv(x1, edge_index, value)
    assert out.size() == (4, 32)
    assert torch.allclose(conv(x1, edge_index, value, (4, 4)), out, atol=1e-6)
    assert torch.allclose(conv(x1, adj.t()), out, atol=1e-6)
    assert torch.allclose(conv(x1, torch_adj.t()), out, atol=1e-6)
    assert torch.allclose(conv(x1, torch_adj_t), out, atol=1e-6)
    conv.fuse = False
    assert torch.allclose(conv(x1, adj.t()), out)
    assert torch.allclose(conv(x1, torch_adj.t()), out, atol=1e-6)
    assert torch.allclose(conv(x1, torch_adj_t), out, atol=1e-6)
    conv.fuse = True

    adj = adj.sparse_resize((4, 2))
    torch_adj = adj.to_torch_sparse_coo_tensor()
    torch_adj_t = adj.to_torch_sparse_csr_tensor().t()
    torch_adj_t = torch_adj_t.to_sparse(layout=torch.sparse_csr)

    conv = MyConv((8, 16), 32)
    out1 = conv((x1, x2), edge_index, value)
@@ -77,21 +77,21 @@ def test_my_conv():
    assert out2.size() == (2, 32)
    assert torch.allclose(conv((x1, x2), edge_index, value, (4, 2)), out1)
    assert torch.allclose(conv((x1, x2), adj.t()), out1)
    assert torch.allclose(conv((x1, x2), torch_adj.t()), out1, atol=1e-6)
    assert torch.allclose(conv((x1, x2), torch_adj_t), out1, atol=1e-6)
    assert torch.allclose(conv((x1, None), adj.t()), out2)
    assert torch.allclose(conv((x1, None), torch_adj.t()), out2, atol=1e-6)
    assert torch.allclose(conv((x1, None), torch_adj_t), out2, atol=1e-6)
    conv.fuse = False
    assert torch.allclose(conv((x1, x2), adj.t()), out1)
    assert torch.allclose(conv((x1, x2), torch_adj.t()), out1, atol=1e-6)
    assert torch.allclose(conv((x1, x2), torch_adj_t), out1, atol=1e-6)
    assert torch.allclose(conv((x1, None), adj.t()), out2)
    assert torch.allclose(conv((x1, None), torch_adj.t()), out2, atol=1e-6)
    assert torch.allclose(conv((x1, None), torch_adj_t), out2, atol=1e-6)
    conv.fuse = True

    # Test backward compatibility for `torch.sparse` tensors:
    conv.fuse = True
    torch_adj = torch_adj.requires_grad_()
    conv((x1, x2), torch_adj.t()).sum().backward()
    assert torch_adj.grad is not None
    torch_adj_t = torch_adj_t.requires_grad_()
    conv((x1, x2), torch_adj_t).sum().backward()
    assert torch_adj_t.grad is not None


def test_my_conv_out_of_bounds():
@@ -211,7 +213,7 @@ def test_my_multiple_aggr_conv(multi_aggr_tuple):
    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
    row, col = edge_index
    adj = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))
    torch_adj = adj.to_torch_sparse_coo_tensor()
    torch_adj = adj.to_torch_sparse_csr_tensor()

    conv = MyMultipleAggrConv(aggr_kwargs=aggr_kwargs)
    out = conv(x, edge_index)
@@ -280,7 +282,7 @@ def test_my_edge_conv():
    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
    row, col = edge_index
    adj = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))
    torch_adj = adj.to_torch_sparse_coo_tensor()
    torch_adj = adj.to_torch_sparse_csr_tensor()

    expected = scatter(x[row] - x[col], col, dim=0, dim_size=4, reduce='sum')

@@ -443,7 +445,7 @@ def test_my_default_arg_conv():
    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
    row, col = edge_index
    adj = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))
    torch_adj = adj.to_torch_sparse_coo_tensor()
    torch_adj = adj.to_torch_sparse_csr_tensor()

    conv = MyDefaultArgConv()
    assert conv(x, edge_index).view(-1).tolist() == [0, 0, 0, 0]
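A note on the conversion chain the updated test relies on: in PyTorch 2, transposing a CSR tensor produces a tensor with `torch.sparse_csc` layout, which is why the test converts the result back to CSR explicitly. A minimal sketch of this behavior (assuming torch>=2.0 and torch_sparse are installed):

import torch
from torch_sparse import SparseTensor

adj = SparseTensor(row=torch.tensor([0, 1, 2, 3]),
                   col=torch.tensor([0, 0, 1, 1]),
                   value=torch.randn(4), sparse_sizes=(4, 4))
adj_t = adj.to_torch_sparse_csr_tensor().t()
assert adj_t.layout == torch.sparse_csc  # transposing CSR yields CSC
adj_t = adj_t.to_sparse(layout=torch.sparse_csr)  # back to CSR, as in the test
assert adj_t.layout == torch.sparse_csr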
21 changes: 20 additions & 1 deletion torch_geometric/nn/conv/message_passing.py
@@ -42,6 +42,11 @@
FUSE_AGGRS = {'add', 'sum', 'mean', 'min', 'max'}


def ptr2ind(ptr: Tensor) -> Tensor:
    ind = torch.arange(ptr.numel() - 1, device=ptr.device)
    return ind.repeat_interleave(ptr[1:] - ptr[:-1])


class MessagePassing(torch.nn.Module):
r"""Base class for creating message passing layers of the form
@@ -240,7 +245,21 @@ def __set_size__(self, size: List[Optional[int]], dim: int, src: Tensor):
    def __lift__(self, src, edge_index, dim):
        if is_torch_sparse_tensor(edge_index):
            assert dim == 0 or dim == 1
            index = edge_index._indices()[1 - dim]
            if edge_index.layout == torch.sparse_coo:
                index = edge_index._indices()[1 - dim]
            elif edge_index.layout == torch.sparse_csr:
                if dim == 0:
                    index = edge_index.col_indices()
                else:
                    index = ptr2ind(edge_index.crow_indices())
            elif edge_index.layout == torch.sparse_csc:
                if dim == 0:
                    index = ptr2ind(edge_index.ccol_indices())
                else:
                    index = edge_index.row_indices()
            else:
                raise ValueError(f"Unsupported sparse tensor layout "
                                 f"(got '{edge_index.layout}')")
            return src.index_select(self.node_dim, index)

        elif isinstance(edge_index, Tensor):
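To make the new `ptr2ind` helper concrete, here is a quick illustration (a sketch, not part of the commit) of how it expands a CSR-style `crow_indices` pointer vector into per-element row indices:

import torch

ptr = torch.tensor([0, 2, 3, 5])  # rows 0, 1, 2 hold 2, 1, 2 entries
ind = torch.arange(ptr.numel() - 1).repeat_interleave(ptr[1:] - ptr[:-1])
print(ind)  # tensor([0, 0, 1, 2, 2])

This is what lets `__lift__` recover the compressed dimension of a CSR or CSC tensor and gather source features with `index_select`, regardless of layout.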
2 changes: 2 additions & 0 deletions torch_geometric/utils/sparse.py
@@ -65,6 +65,8 @@ def is_torch_sparse_tensor(src: Any) -> bool:
            return True
        if src.layout == torch.sparse_csr:
            return True
        if src.layout == torch.sparse_csc:
            return True
    return False


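With this change, `is_torch_sparse_tensor` recognizes all three `torch.sparse` layouts. A quick check (a sketch assuming torch>=2.0):

import torch
from torch_geometric.utils import is_torch_sparse_tensor

dense = torch.eye(3)
assert is_torch_sparse_tensor(dense.to_sparse(layout=torch.sparse_coo))
assert is_torch_sparse_tensor(dense.to_sparse(layout=torch.sparse_csr))
assert is_torch_sparse_tensor(dense.to_sparse(layout=torch.sparse_csc))  # newly supported
assert not is_torch_sparse_tensor(dense)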
8 changes: 4 additions & 4 deletions torch_geometric/utils/spmm.py
@@ -55,10 +55,10 @@ def spmm(src: Adj, other: Tensor, reduce: str = "sum") -> Tensor:
    # This will currently throw an error for CUDA tensors.
    if torch_geometric.typing.WITH_PT2:
        if src.layout != torch.sparse_csr:
            warnings.warn("Converting sparse tensor to CSR format for more "
                          "efficient processing. Consider converting your "
                          "sparse tensor to CSR format beforehand to avoid "
                          "repeated conversion")
            warnings.warn(f"Converting sparse tensor to CSR format for more "
                          f"efficient processing. Consider converting your "
                          f"sparse tensor to CSR format beforehand to avoid "
                          f"repeated conversion (got '{src.layout}')")
            src = src.to_sparse_csr()
        if reduce == 'sum':
            return torch.sparse.mm(src, other)
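How the updated warning surfaces in practice (a sketch; shapes and values are illustrative): passing a non-CSR sparse tensor to `spmm` under PyTorch 2 triggers a CSR conversion on every call, so converting once up front avoids both the warning and the repeated work.

import torch
from torch_geometric.utils import spmm

src = torch.eye(4).to_sparse()        # sparse_coo layout
other = torch.randn(4, 8)
out = spmm(src, other, reduce='sum')  # warns, converts to CSR internally

src = src.to_sparse_csr()             # convert once beforehand
out = spmm(src, other, reduce='sum')  # no warning, no conversion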
