Skip to content

Commit

Permalink
Fix: is_sorted argument in NeighborLoader (#4702)
Browse files Browse the repository at this point in the history
* fix: is_sorted

* changelog
  • Loading branch information
rusty1s authored May 24, 2022
1 parent 54103a5 commit f482cb7
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 4 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `HeteroData.node_items()` and `HeteroData.edge_items()` functionality ([#4644](https://github.com/pyg-team/pytorch_geometric/pull/4644))
- Added PyTorch Lightning support in GraphGym ([#4531](https://github.com/pyg-team/pytorch_geometric/pull/4531))
- Added support for returning embeddings in `MLP` models ([#4625](https://github.com/pyg-team/pytorch_geometric/pull/4625))
- Added faster initialization of `NeighborLoader` in case edge indices are already sorted (via `is_sorted=True`) ([#4620](https://github.com/pyg-team/pytorch_geometric/pull/4620))
- Added faster initialization of `NeighborLoader` in case edge indices are already sorted (via `is_sorted=True`) ([#4620](https://github.com/pyg-team/pytorch_geometric/pull/4620), [#4702](https://github.com/pyg-team/pytorch_geometric/pull/4702))
- Added `AddPositionalEncoding` transform ([#4521](https://github.com/pyg-team/pytorch_geometric/pull/4521))
- Added `HeteroData.is_undirected()` support ([#4604](https://github.com/pyg-team/pytorch_geometric/pull/4604))
- Added the `Genius` and `Wiki` datasets to `nn.datasets.LINKXDataset` ([#4570](https://github.com/pyg-team/pytorch_geometric/pull/4570), [#4600](https://github.com/pyg-team/pytorch_geometric/pull/4600))
Expand Down
5 changes: 2 additions & 3 deletions torch_geometric/loader/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,9 @@ def to_csc(
elif hasattr(data, 'edge_index'):
(row, col) = data.edge_index
if not is_sorted:
size = data.size()
perm = (col * size[0]).add_(row).argsort()
perm = (col * data.size(0)).add_(row).argsort()
row = row[perm]
colptr = torch.ops.torch_sparse.ind2ptr(col[perm], size[1])
colptr = torch.ops.torch_sparse.ind2ptr(col[perm], data.size(1))
else:
raise AttributeError("Data object does not contain attributes "
"'adj_t' or 'edge_index'")
Expand Down

0 comments on commit f482cb7

Please sign in to comment.