diff --git a/CHANGELOG.md b/CHANGELOG.md index f22c4893b3ae..156610162688 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added - Added PyTorch 2.4 support ([#9594](https://github.com/pyg-team/pytorch_geometric/pull/9594)) +- Added `transforms.LineDiGraph` for applying line graph transformations on directed graphs ([#9592](https://github.com/pyg-team/pytorch_geometric/pull/9592)) - Added `utils.normalize_edge_index` for symmetric/asymmetric normalization of graph edges ([#9554](https://github.com/pyg-team/pytorch_geometric/pull/9554)) - Added the `RemoveSelfLoops` transformation ([#9562](https://github.com/pyg-team/pytorch_geometric/pull/9562)) - Added ONNX export for `scatter` with min/max reductions ([#9587](https://github.com/pyg-team/pytorch_geometric/pull/9587)) diff --git a/test/transforms/test_line_digraph.py b/test/transforms/test_line_digraph.py new file mode 100644 index 000000000000..4e86d1bdc4f0 --- /dev/null +++ b/test/transforms/test_line_digraph.py @@ -0,0 +1,33 @@ +import torch + +from torch_geometric.data import Data +from torch_geometric.transforms import LineDiGraph + + +def test_line_digraph(): + transform = LineDiGraph() + assert str(transform) == 'LineDiGraph()' + + edge_index = torch.tensor([ + [0, 1, 2, 2, 3], + [1, 2, 0, 3, 0], + ]) + data = Data(edge_index=edge_index, num_nodes=4) + data = transform(data) + assert data.edge_index.tolist() == [[0, 1, 1, 2, 3, 4], [1, 2, 3, 0, 4, 0]] + assert data.num_nodes == data.edge_index.max().item() + 1 + + edge_index = torch.tensor([[0, 0, 0, 1, 2, 2, 3, 3, 4], + [1, 2, 3, 4, 0, 3, 0, 4, 1]]) + edge_attr = torch.ones(edge_index.size(1)) + data = Data(edge_index=edge_index, edge_attr=edge_attr, num_nodes=5) + data = transform(data) + assert data.edge_index.max().item() + 1 == data.x.size(0) + assert data.edge_index.tolist() == [[ + 0, 1, 1, 2, 2, 3, 4, 4, 4, 5, 5, 6, 6, 6, 7, 8 + ], [3, 4, 5, 6, 7, 8, 0, 1, 2, 6, 7, 0, 1, 2, 8, 3]] + assert data.num_nodes <= data.edge_index.max().item() + 1 + + +if __name__ == "__main__": + test_line_digraph() diff --git a/torch_geometric/transforms/__init__.py b/torch_geometric/transforms/__init__.py index fa4098fcfec7..02175a2f0696 100644 --- a/torch_geometric/transforms/__init__.py +++ b/torch_geometric/transforms/__init__.py @@ -28,6 +28,7 @@ from .to_dense import ToDense from .two_hop import TwoHop from .line_graph import LineGraph +from .line_digraph import LineDiGraph from .laplacian_lambda_max import LaplacianLambdaMax from .gdc import GDC from .sign import SIGN diff --git a/torch_geometric/transforms/line_digraph.py b/torch_geometric/transforms/line_digraph.py new file mode 100644 index 000000000000..df94c58508c7 --- /dev/null +++ b/torch_geometric/transforms/line_digraph.py @@ -0,0 +1,42 @@ +import torch + +from torch_geometric.data import Data +from torch_geometric.data.datapipes import functional_transform +from torch_geometric.transforms import BaseTransform +from torch_geometric.utils import coalesce + + +@functional_transform('line_digraph') +class LineDiGraph(BaseTransform): + r"""Converts a graph to its corresponding line-digraph + (functional name: :obj:`line_digraph`). + + .. math:: + L(\mathcal{G}) &= (\mathcal{V}^{\prime}, \mathcal{E}^{\prime}) + + \mathcal{V}^{\prime} &= \mathcal{E} + + \mathcal{E}^{\prime} &= \{ ((u, v), (w, x)) : (u, v) \in \mathcal{E} + \land (w, x) \in \mathcal{E} \land v = w\} + + Line-digraph node indices are equal to indices in the original graph's + coalesced :obj:`edge_index`. + """ + def forward(self, data: Data) -> Data: + assert data.edge_index is not None + assert data.is_directed() + edge_index, edge_attr = data.edge_index, data.edge_attr + E = data.num_edges + + edge_index, edge_attr = coalesce(edge_index, edge_attr, data.num_nodes) + row, col = edge_index + + # Broadcast to create a mask for matching row and col elements + mask = row.unsqueeze(0) == col.unsqueeze(1) # (num_edges, num_edges) + new_edge_index = torch.nonzero(mask).T + + data.edge_index = new_edge_index + data.x = edge_attr + data.num_nodes = E + data.edge_attr = None + return data