Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add LineDigraph transformation #9592

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Added

- Added PyTorch 2.4 support ([#9594](https://github.com/pyg-team/pytorch_geometric/pull/9594))
- Added `transforms.LineDiGraph` for applying line graph transformations on directed graphs ([#9592](https://github.com/pyg-team/pytorch_geometric/pull/9592))
- Added `utils.normalize_edge_index` for symmetric/asymmetric normalization of graph edges ([#9554](https://github.com/pyg-team/pytorch_geometric/pull/9554))
- Added the `RemoveSelfLoops` transformation ([#9562](https://github.com/pyg-team/pytorch_geometric/pull/9562))
- Added ONNX export for `scatter` with min/max reductions ([#9587](https://github.com/pyg-team/pytorch_geometric/pull/9587))
Expand Down
33 changes: 33 additions & 0 deletions test/transforms/test_line_digraph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import torch

from torch_geometric.data import Data
from torch_geometric.transforms import LineDiGraph


def test_line_digraph():
transform = LineDiGraph()
assert str(transform) == 'LineDiGraph()'

edge_index = torch.tensor([
[0, 1, 2, 2, 3],
[1, 2, 0, 3, 0],
])
data = Data(edge_index=edge_index, num_nodes=4)
data = transform(data)
assert data.edge_index.tolist() == [[0, 1, 1, 2, 3, 4], [1, 2, 3, 0, 4, 0]]
assert data.num_nodes == data.edge_index.max().item() + 1

edge_index = torch.tensor([[0, 0, 0, 1, 2, 2, 3, 3, 4],
[1, 2, 3, 4, 0, 3, 0, 4, 1]])
edge_attr = torch.ones(edge_index.size(1))
data = Data(edge_index=edge_index, edge_attr=edge_attr, num_nodes=5)
data = transform(data)
assert data.edge_index.max().item() + 1 == data.x.size(0)
assert data.edge_index.tolist() == [[
0, 1, 1, 2, 2, 3, 4, 4, 4, 5, 5, 6, 6, 6, 7, 8
], [3, 4, 5, 6, 7, 8, 0, 1, 2, 6, 7, 0, 1, 2, 8, 3]]
assert data.num_nodes <= data.edge_index.max().item() + 1


if __name__ == "__main__":
test_line_digraph()
1 change: 1 addition & 0 deletions torch_geometric/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from .to_dense import ToDense
from .two_hop import TwoHop
from .line_graph import LineGraph
from .line_digraph import LineDiGraph
from .laplacian_lambda_max import LaplacianLambdaMax
from .gdc import GDC
from .sign import SIGN
Expand Down
42 changes: 42 additions & 0 deletions torch_geometric/transforms/line_digraph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import torch

from torch_geometric.data import Data
from torch_geometric.data.datapipes import functional_transform
from torch_geometric.transforms import BaseTransform
from torch_geometric.utils import coalesce


@functional_transform('line_digraph')
class LineDiGraph(BaseTransform):
r"""Converts a graph to its corresponding line-digraph
(functional name: :obj:`line_digraph`).

.. math::
L(\mathcal{G}) &= (\mathcal{V}^{\prime}, \mathcal{E}^{\prime})

\mathcal{V}^{\prime} &= \mathcal{E}

\mathcal{E}^{\prime} &= \{ ((u, v), (w, x)) : (u, v) \in \mathcal{E}
\land (w, x) \in \mathcal{E} \land v = w\}

Line-digraph node indices are equal to indices in the original graph's
coalesced :obj:`edge_index`.
"""
def forward(self, data: Data) -> Data:
assert data.edge_index is not None
assert data.is_directed()
edge_index, edge_attr = data.edge_index, data.edge_attr
E = data.num_edges

edge_index, edge_attr = coalesce(edge_index, edge_attr, data.num_nodes)
row, col = edge_index

# Broadcast to create a mask for matching row and col elements
mask = row.unsqueeze(0) == col.unsqueeze(1) # (num_edges, num_edges)
new_edge_index = torch.nonzero(mask).T

data.edge_index = new_edge_index
data.x = edge_attr
data.num_nodes = E
data.edge_attr = None
return data
Loading