Skip to content

Commit

Permalink
Fix dimension in edge filter selection (#4629)
Browse files Browse the repository at this point in the history
* fix dimension in edge filter

* update changelog

* Update CHANGELOG.md

* update

Co-authored-by: Matthias Fey <[email protected]>
  • Loading branch information
Padarn and rusty1s authored May 12, 2022
1 parent bbff5b7 commit 6fd6f5b
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 4 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,5 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Removed unnecessary colons and fixed typos in the documentation ([#4616](https://github.com/pyg-team/pytorch_geometric/pull/4616))
- The `bias` argument in `TAGConv` is now actually applied ([#4597](https://github.com/pyg-team/pytorch_geometric/pull/4597))
- Fixed subclass behaviour of `process` and `download` in `Datsaet` ([#4586](https://github.com/pyg-team/pytorch_geometric/pull/4586))
- Fixed filtering of attributes for loaders in case `__cat_dim__ != 0` ([#4629](https://github.com/pyg-team/pytorch_geometric/pull/4629))
### Removed
10 changes: 6 additions & 4 deletions torch_geometric/loader/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def index_select(value: Tensor, index: Tensor, dim: int = 0) -> Tensor:
numel = math.prod(size)
storage = value.storage()._new_shared(numel)
out = value.new(storage).view(size)
return torch.index_select(value, 0, index, out=out)
return torch.index_select(value, dim, index, out=out)


def edge_type_to_str(edge_type: Union[EdgeType, str]) -> str:
Expand Down Expand Up @@ -101,7 +101,8 @@ def filter_node_store_(store: NodeStorage, out_store: NodeStorage,

elif store.is_node_attr(key):
index = index.to(value.device)
out_store[key] = index_select(value, index, dim=0)
dim = store._parent().__cat_dim__(key, value, store)
out_store[key] = index_select(value, index, dim=dim)

return store

Expand Down Expand Up @@ -132,13 +133,14 @@ def filter_edge_store_(store: EdgeStorage, out_store: EdgeStorage, row: Tensor,
is_sorted=False, trust_data=True)

elif store.is_edge_attr(key):
dim = store._parent().__cat_dim__(key, value, store)
if perm is None:
index = index.to(value.device)
out_store[key] = index_select(value, index, dim=0)
out_store[key] = index_select(value, index, dim=dim)
else:
perm = perm.to(value.device)
index = index.to(value.device)
out_store[key] = index_select(value, perm[index], dim=0)
out_store[key] = index_select(value, perm[index], dim=dim)

return store

Expand Down

0 comments on commit 6fd6f5b

Please sign in to comment.