Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
Ubuntu committed Mar 28, 2023
1 parent 89b321b commit 94f1565
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 8 deletions.
1 change: 1 addition & 0 deletions docs/source/api/python/dgl.data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ Datasets for node classification/regression tasks
PATTERNDataset
CLUSTERDataset
ChameleonDataset
SquirrelDataset

Edge Prediction Datasets
---------------------------------------
Expand Down
2 changes: 1 addition & 1 deletion python/dgl/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
from .utils import *
from .cluster import CLUSTERDataset
from .pattern import PATTERNDataset
from .wiki_network import ChameleonDataset
from .wiki_network import ChameleonDataset, SquirrelDataset
from .wikics import WikiCSDataset
from .yelp import YelpDataset
from .zinc import ZINCDataset
Expand Down
86 changes: 80 additions & 6 deletions python/dgl/data/wiki_network.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Wikipedia page-page networks on the chameleon topic.
Wikipedia page-page networks on two topics: chameleons and squirrels.
"""
import os

Expand All @@ -23,8 +23,7 @@ class WikiNetworkDataset(DGLBuiltinDataset):
raw_dir : str
Raw file directory to store the processed data.
force_reload : bool
Whether to always generate the data from scratch rather than load a
cached version.
Whether to re-download the data source.
verbose : bool
Whether to print progress information.
transform : callable
Expand Down Expand Up @@ -123,7 +122,7 @@ class ChameleonDataset(WikiNetworkDataset):
- Nodes: 2277
- Edges: 36101
- Number of Classes: 5
- 10 splits with 60/20/20 train/val/test ratio
- 10 train/val/test splits
- Train: 1092
- Val: 729
Expand All @@ -134,8 +133,7 @@ class ChameleonDataset(WikiNetworkDataset):
raw_dir : str, optional
Raw file directory to store the processed data. Default: ~/.dgl/
force_reload : bool, optional
Whether to always generate the data from scratch rather than load a
cached version. Default: False
Whether to re-download the data source. Default: False
verbose : bool, optional
Whether to print progress information. Default: True
transform : callable, optional
Expand Down Expand Up @@ -182,3 +180,79 @@ def __init__(
verbose=verbose,
transform=transform,
)


class SquirrelDataset(WikiNetworkDataset):
    r"""Wikipedia page-page network on squirrels from `Multi-scale Attributed
    Node Embedding <https://arxiv.org/abs/1909.13021>`__ and later modified by
    `Geom-GCN: Geometric Graph Convolutional Networks
    <https://arxiv.org/abs/2002.05287>`__.

    Nodes represent articles from the English Wikipedia, edges reflect mutual
    links between them. Node features indicate the presence of particular nouns
    in the articles. The nodes were classified into 5 classes in terms of their
    average monthly traffic.

    Statistics:

    - Nodes: 5201
    - Edges: 217073
    - Number of Classes: 5
    - 10 train/val/test splits

        - Train: 2496
        - Val: 1664
        - Test: 1041

    Parameters
    ----------
    raw_dir : str, optional
        Raw file directory to store the processed data. Default: ~/.dgl/
    force_reload : bool, optional
        Whether to re-download the data source. Default: False
    verbose : bool, optional
        Whether to print progress information. Default: True
    transform : callable, optional
        A transform that takes in a :class:`~dgl.DGLGraph` object and returns
        a transformed version. The :class:`~dgl.DGLGraph` object will be
        transformed before every access. Default: None

    Attributes
    ----------
    num_classes : int
        Number of node classes

    Notes
    -----
    The graph does not come with edges for both directions.

    Examples
    --------

    >>> from dgl.data import SquirrelDataset
    >>> dataset = SquirrelDataset()
    >>> g = dataset[0]
    >>> num_classes = dataset.num_classes

    >>> # get node features
    >>> feat = g.ndata["feat"]

    >>> # get data split
    >>> train_mask = g.ndata["train_mask"]
    >>> val_mask = g.ndata["val_mask"]
    >>> test_mask = g.ndata["test_mask"]

    >>> # get labels
    >>> label = g.ndata['label']
    """

    def __init__(
        self, raw_dir=None, force_reload=False, verbose=True, transform=None
    ):
        # The dataset name is fixed to "squirrel"; every other option is
        # forwarded unchanged to the shared WikiNetworkDataset constructor.
        super().__init__(
            name="squirrel",
            raw_dir=raw_dir,
            force_reload=force_reload,
            verbose=verbose,
            transform=transform,
        )
18 changes: 17 additions & 1 deletion tests/python/common/data/test_wiki_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,25 @@
def test_chameleon():
transform = dgl.AddSelfLoop(allow_duplicate=True)

# chameleon
g = dgl.data.ChameleonDataset(force_reload=True)[0]
assert g.num_nodes() == 2277
assert g.num_edges() == 36101
g2 = dgl.data.ChameleonDataset(force_reload=True, transform=transform)[0]
assert g2.num_edges() - g.num_edges() == g.num_nodes()


@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(
    dgl.backend.backend_name != "pytorch", reason="only supports pytorch"
)
def test_squirrel():
    """Check SquirrelDataset graph sizes with and without a transform."""
    # Untransformed graph: node and edge counts must match the expected sizes.
    graph = dgl.data.SquirrelDataset(force_reload=True)[0]
    assert graph.num_nodes() == 5201
    assert graph.num_edges() == 217073

    # With the self-loop transform, the edge count is expected to grow by
    # exactly one edge per node.
    add_self_loop = dgl.AddSelfLoop(allow_duplicate=True)
    looped = dgl.data.SquirrelDataset(
        force_reload=True, transform=add_self_loop
    )[0]
    assert looped.num_edges() - graph.num_edges() == graph.num_nodes()

0 comments on commit 94f1565

Please sign in to comment.