restore access pattern from master
zacool64 committed Oct 25, 2024
1 parent 374ba5c commit a43ce62
Showing 1 changed file with 19 additions and 39 deletions.
58 changes: 19 additions & 39 deletions torch_geometric/datasets/web_qsp_dataset.py
@@ -155,7 +155,7 @@ class WebQSPDataset(InMemoryDataset):
     Args:
         root (str): Root directory where the dataset should be saved.
         split (str, optional): If :obj:`"train"`, loads the training dataset.
-            If :obj:`"val"`, loads the validation dataset.
+            If :obj:`"validation"`, loads the validation dataset.
             If :obj:`"test"`, loads the test dataset. (default: :obj:`"train"`)
         force_reload (bool, optional): Whether to re-process the dataset.
             (default: :obj:`False`)
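
For reference, a minimal construction sketch under the renamed split value; the root path here is hypothetical:

    from torch_geometric.datasets import WebQSPDataset

    # The validation split is now requested by its full name, not "val".
    train_ds = WebQSPDataset(root="data/WebQSP", split="train")
    val_ds = WebQSPDataset(root="data/WebQSP", split="validation")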
@@ -187,11 +187,11 @@ def __init__(
         self.force_reload = force_reload
         super().__init__(root, force_reload=force_reload)
 
-        if split not in {'train', 'val', 'test'} and limit < 0:
+        if split not in set(self.raw_file_names):
             raise ValueError(f"Invalid 'split' argument (got {split})")
 
         self._load_raw_data()
-        self.load(self.processed_paths[0])
+        self.load(self.processed_paths[0] + "" * (self.limit >= 0))
 
     '''
     def _check_dependencies(self) -> None:
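
A standalone sketch of the restored check, assuming raw_file_names returns the three split names introduced below; an unknown split now raises regardless of limit:

    raw_file_names = ["train", "validation", "test"]

    split = "val"  # stale name from the old API
    if split not in set(raw_file_names):
        raise ValueError(f"Invalid 'split' argument (got {split})")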
@@ -210,63 +210,42 @@ def _check_dependencies(self) -> None:
 
     @property
     def raw_file_names(self) -> List[str]:
-        return ["raw_data", "split_idxs"]
+        return ["train", "validation", "test"]
 
     @property
     def processed_file_names(self) -> List[str]:
         file_lst = [
             "train_data.pt",
-            "val_data.pt",
+            "validation_data.pt",
             "test_data.pt",
             "pre_filter.pt",
             "pre_transform.pt",
             "large_graph_indexer",
         ]
-        split_file = file_lst.pop(['train', 'val', 'test'].index(self.split))
+        split_file = file_lst.pop(self.raw_file_names.index(self.split))
         file_lst.insert(0, split_file)
         return file_lst
 
-    def _save_raw_data(self) -> None:
-        self.raw_dataset.save_to_disk(self.raw_paths[0])
-        torch.save(self.split_idxs, self.raw_paths[1])
+    def _save_raw_data(self, dataset) -> None:
+        for i, split in enumerate(self.raw_file_names):
+            dataset[split].save_to_disk(self.raw_paths[i])
 
     def _load_raw_data(self) -> None:
         import datasets
-        if not hasattr(self, "raw_dataset"):
-            self.raw_dataset = datasets.load_from_disk(self.raw_paths[0])
-        if not hasattr(self, "split_idxs"):
-            self.split_idxs = torch.load(self.raw_paths[1])
+        self.raw_dataset = datasets.load_from_disk(
+            self.raw_paths[self.raw_file_names.index(self.split)])
+
+        if self.limit >= 0:
+            self.raw_dataset = self.raw_dataset.select(
+                range(min(self.limit, len(self.raw_dataset))))
 
     def download(self) -> None:
         import datasets
 
         dataset = datasets.load_dataset("rmanluo/RoG-webqsp")
-        self.raw_dataset = datasets.concatenate_datasets(
-            [dataset["train"], dataset["validation"], dataset["test"]])
-        self.split_idxs = {
-            "train":
-            torch.arange(len(dataset["train"])),
-            "val":
-            torch.arange(len(dataset["validation"])) + len(dataset["train"]),
-            "test":
-            torch.arange(len(dataset["test"])) + len(dataset["train"]) +
-            len(dataset["validation"]),
-        }
 
-        if self.limit >= 0:
-            self.raw_dataset = self.raw_dataset.select(range(self.limit))
-
-            # HACK
-            self.split_idxs = {
-                "train":
-                torch.arange(self.limit // 2),
-                "val":
-                torch.arange(self.limit // 4) + self.limit // 2,
-                "test":
-                torch.arange(self.limit // 4) + self.limit // 2 +
-                self.limit // 4,
-            }
-        self._save_raw_data()
+        self._save_raw_data(dataset)
+        self.raw_dataset = dataset[self.split]
 
     def _get_trips(self) -> Iterator[TripletLike]:
         return chain.from_iterable(
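
A runnable sketch of the per-split round trip that replaces the concatenated dataset and the split_idxs bookkeeping, using the Hugging Face datasets API; the raw/ directories stand in for self.raw_paths:

    import datasets

    # download(): fetch once, save each split to its own directory.
    dataset = datasets.load_dataset("rmanluo/RoG-webqsp")
    for split in ["train", "validation", "test"]:
        dataset[split].save_to_disk(f"raw/{split}")

    # _load_raw_data(): read back only the requested split and apply the
    # optional row limit.
    limit = 100
    raw = datasets.load_from_disk("raw/validation")
    if limit >= 0:
        raw = raw.select(range(min(limit, len(raw))))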
@@ -340,7 +319,8 @@ def _retrieve_subgraphs(self) -> None:
             pcst_subgraph["desc"] = desc
             list_of_graphs.append(pcst_subgraph.to("cpu"))
         print("Saving subgraphs...")
-        self.save(list_of_graphs, self.processed_paths[0])
+        self.save(list_of_graphs,
+                  self.processed_paths[0] + "" * (self.limit >= 0))
 
     def process(self) -> None:
         from pandas import DataFrame
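
The suffix appended to processed_paths[0] above relies on string repetition by a bool: s * True is s and s * False is "", so a non-empty suffix would be appended only when self.limit >= 0. A small sketch of the idiom (the non-empty suffix is hypothetical; the commit itself appends an empty string):

    limit = 100
    path = "processed/train_data.pt"
    suffix = "_limited"
    print(path + suffix * (limit >= 0))  # processed/train_data.pt_limited
    print(path + suffix * (limit < 0))   # processed/train_data.pt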
