From bb410c8cc63804329104b91ea509ba8b0cdb16e0 Mon Sep 17 00:00:00 2001 From: Manan Shah Date: Tue, 28 Jun 2022 22:08:52 +0000 Subject: [PATCH 1/3] init --- test/data/test_data.py | 2 ++ test/data/test_hetero_data.py | 3 +++ torch_geometric/data/data.py | 5 ++++- torch_geometric/data/hetero_data.py | 5 ++++- 4 files changed, 13 insertions(+), 2 deletions(-) diff --git a/test/data/test_data.py b/test/data/test_data.py index 5d14be740627..eac9fc8cea9f 100644 --- a/test/data/test_data.py +++ b/test/data/test_data.py @@ -248,6 +248,8 @@ def my_attr1(self, value): def test_basic_feature_store(): data = Data() x = torch.randn(20, 20) + data.not_a_tensor_attr = 10 # don't include, not a tensor attr + data.bad_attr = torch.randn(10, 20) # don't include, bad cat_dim # Put tensor: assert data.put_tensor(copy.deepcopy(x), attr_name='x', index=None) diff --git a/test/data/test_hetero_data.py b/test/data/test_hetero_data.py index 6b8d88735537..ae4ed8d0776c 100644 --- a/test/data/test_hetero_data.py +++ b/test/data/test_hetero_data.py @@ -428,6 +428,9 @@ def test_basic_feature_store(): assert data.get_tensor_size(group_name='paper', attr_name='x') == (20, 20) # Get tensor attrs: + data['paper'].num_nodes = 20 # don't include, not a tensor attr + data['paper'].bad_attr = torch.randn(10, 20) # don't include, bad cat_dim + tensor_attrs = data.get_all_tensor_attrs() assert len(tensor_attrs) == 1 assert tensor_attrs[0].group_name == 'paper' diff --git a/torch_geometric/data/data.py b/torch_geometric/data/data.py index 0737ee180f71..512ad581f53c 100644 --- a/torch_geometric/data/data.py +++ b/torch_geometric/data/data.py @@ -782,7 +782,10 @@ def _get_tensor_size(self, attr: TensorAttr) -> Tuple: def get_all_tensor_attrs(self) -> List[TensorAttr]: r"""Obtains all feature attributes stored in `Data`.""" - return [TensorAttr(attr_name=name) for name in self._store.keys()] + return [ + TensorAttr(attr_name=name) for name in self._store.keys() + if self._store.is_node_attr(name) + ] def __len__(self) -> int: return BaseData.__len__(self) diff --git a/torch_geometric/data/hetero_data.py b/torch_geometric/data/hetero_data.py index c343ae4330fb..242849c570a3 100644 --- a/torch_geometric/data/hetero_data.py +++ b/torch_geometric/data/hetero_data.py @@ -680,8 +680,11 @@ def _get_tensor_size(self, attr: TensorAttr) -> Tuple: def get_all_tensor_attrs(self) -> List[TensorAttr]: out = [] for group_name, group in self.node_items(): + print(group_name, group) for attr_name in group: - out.append(TensorAttr(group_name, attr_name)) + print(attr_name) + if group.is_node_attr(attr_name): + out.append(TensorAttr(group_name, attr_name)) return out def __len__(self) -> int: From 1bf110b8c1715bae1283fc4651efe1a9f934afb1 Mon Sep 17 00:00:00 2001 From: Manan Shah Date: Tue, 28 Jun 2022 22:11:15 +0000 Subject: [PATCH 2/3] fix --- torch_geometric/data/hetero_data.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/torch_geometric/data/hetero_data.py b/torch_geometric/data/hetero_data.py index 242849c570a3..c7a26686ec63 100644 --- a/torch_geometric/data/hetero_data.py +++ b/torch_geometric/data/hetero_data.py @@ -680,9 +680,7 @@ def _get_tensor_size(self, attr: TensorAttr) -> Tuple: def get_all_tensor_attrs(self) -> List[TensorAttr]: out = [] for group_name, group in self.node_items(): - print(group_name, group) for attr_name in group: - print(attr_name) if group.is_node_attr(attr_name): out.append(TensorAttr(group_name, attr_name)) return out From 0b4894ff252f8b8f01a3f0dbb7998c30e41d8a4c Mon Sep 17 00:00:00 2001 From: Manan Shah Date: Tue, 28 Jun 2022 22:13:45 +0000 Subject: [PATCH 3/3] CHANGELOG --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a0257f809d15..8e352af49689 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,7 +8,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `time_attr` argument to `LinkNeighborLoader` ([#4877](https://github.com/pyg-team/pytorch_geometric/pull/4877)) - Added a `filter_per_worker` argument to data loaders to allow filtering of data within sub-processes ([#4873](https://github.com/pyg-team/pytorch_geometric/pull/4873)) - Added a `NeighborLoader` benchmark script ([#4815](https://github.com/pyg-team/pytorch_geometric/pull/4815)) -- Added support for `FeatureStore` and `GraphStore` in `NeighborLoader` ([#4817](https://github.com/pyg-team/pytorch_geometric/pull/4817), [#4851](https://github.com/pyg-team/pytorch_geometric/pull/4851), [#4854](https://github.com/pyg-team/pytorch_geometric/pull/4854), [#4856](https://github.com/pyg-team/pytorch_geometric/pull/4856), [#4857](https://github.com/pyg-team/pytorch_geometric/pull/4857)) +- Added support for `FeatureStore` and `GraphStore` in `NeighborLoader` ([#4817](https://github.com/pyg-team/pytorch_geometric/pull/4817), [#4851](https://github.com/pyg-team/pytorch_geometric/pull/4851), [#4854](https://github.com/pyg-team/pytorch_geometric/pull/4854), [#4856](https://github.com/pyg-team/pytorch_geometric/pull/4856), [#4857](https://github.com/pyg-team/pytorch_geometric/pull/4857), [#4882](https://github.com/pyg-team/pytorch_geometric/pull/4882)) - Added a `normalize` parameter to `dense_diff_pool` ([#4847](https://github.com/pyg-team/pytorch_geometric/pull/4847)) - Added `size=None` explanation to jittable `MessagePassing` modules in the documentation ([#4850](https://github.com/pyg-team/pytorch_geometric/pull/4850)) - Added documentation to the `DataLoaderIterator` class ([#4838](https://github.com/pyg-team/pytorch_geometric/pull/4838))