diff --git a/CHANGELOG.md b/CHANGELOG.md index a0257f809d15..8e352af49689 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,7 +8,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `time_attr` argument to `LinkNeighborLoader` ([#4877](https://github.com/pyg-team/pytorch_geometric/pull/4877)) - Added a `filter_per_worker` argument to data loaders to allow filtering of data within sub-processes ([#4873](https://github.com/pyg-team/pytorch_geometric/pull/4873)) - Added a `NeighborLoader` benchmark script ([#4815](https://github.com/pyg-team/pytorch_geometric/pull/4815)) -- Added support for `FeatureStore` and `GraphStore` in `NeighborLoader` ([#4817](https://github.com/pyg-team/pytorch_geometric/pull/4817), [#4851](https://github.com/pyg-team/pytorch_geometric/pull/4851), [#4854](https://github.com/pyg-team/pytorch_geometric/pull/4854), [#4856](https://github.com/pyg-team/pytorch_geometric/pull/4856), [#4857](https://github.com/pyg-team/pytorch_geometric/pull/4857)) +- Added support for `FeatureStore` and `GraphStore` in `NeighborLoader` ([#4817](https://github.com/pyg-team/pytorch_geometric/pull/4817), [#4851](https://github.com/pyg-team/pytorch_geometric/pull/4851), [#4854](https://github.com/pyg-team/pytorch_geometric/pull/4854), [#4856](https://github.com/pyg-team/pytorch_geometric/pull/4856), [#4857](https://github.com/pyg-team/pytorch_geometric/pull/4857), [#4882](https://github.com/pyg-team/pytorch_geometric/pull/4882)) - Added a `normalize` parameter to `dense_diff_pool` ([#4847](https://github.com/pyg-team/pytorch_geometric/pull/4847)) - Added `size=None` explanation to jittable `MessagePassing` modules in the documentation ([#4850](https://github.com/pyg-team/pytorch_geometric/pull/4850)) - Added documentation to the `DataLoaderIterator` class ([#4838](https://github.com/pyg-team/pytorch_geometric/pull/4838)) diff --git a/test/data/test_data.py b/test/data/test_data.py index 5d14be740627..eac9fc8cea9f 100644 --- a/test/data/test_data.py +++ b/test/data/test_data.py @@ -248,6 +248,8 @@ def my_attr1(self, value): def test_basic_feature_store(): data = Data() x = torch.randn(20, 20) + data.not_a_tensor_attr = 10 # don't include, not a tensor attr + data.bad_attr = torch.randn(10, 20) # don't include, bad cat_dim # Put tensor: assert data.put_tensor(copy.deepcopy(x), attr_name='x', index=None) diff --git a/test/data/test_hetero_data.py b/test/data/test_hetero_data.py index 6b8d88735537..ae4ed8d0776c 100644 --- a/test/data/test_hetero_data.py +++ b/test/data/test_hetero_data.py @@ -428,6 +428,9 @@ def test_basic_feature_store(): assert data.get_tensor_size(group_name='paper', attr_name='x') == (20, 20) # Get tensor attrs: + data['paper'].num_nodes = 20 # don't include, not a tensor attr + data['paper'].bad_attr = torch.randn(10, 20) # don't include, bad cat_dim + tensor_attrs = data.get_all_tensor_attrs() assert len(tensor_attrs) == 1 assert tensor_attrs[0].group_name == 'paper' diff --git a/torch_geometric/data/data.py b/torch_geometric/data/data.py index 0737ee180f71..512ad581f53c 100644 --- a/torch_geometric/data/data.py +++ b/torch_geometric/data/data.py @@ -782,7 +782,10 @@ def _get_tensor_size(self, attr: TensorAttr) -> Tuple: def get_all_tensor_attrs(self) -> List[TensorAttr]: r"""Obtains all feature attributes stored in `Data`.""" - return [TensorAttr(attr_name=name) for name in self._store.keys()] + return [ + TensorAttr(attr_name=name) for name in self._store.keys() + if self._store.is_node_attr(name) + ] def __len__(self) -> int: return BaseData.__len__(self) diff --git a/torch_geometric/data/hetero_data.py b/torch_geometric/data/hetero_data.py index c343ae4330fb..c7a26686ec63 100644 --- a/torch_geometric/data/hetero_data.py +++ b/torch_geometric/data/hetero_data.py @@ -681,7 +681,8 @@ def get_all_tensor_attrs(self) -> List[TensorAttr]: out = [] for group_name, group in self.node_items(): for attr_name in group: - out.append(TensorAttr(group_name, attr_name)) + if group.is_node_attr(attr_name): + out.append(TensorAttr(group_name, attr_name)) return out def __len__(self) -> int: