diff --git a/torch_geometric/datasets/web_qsp_dataset.py b/torch_geometric/datasets/web_qsp_dataset.py
index fd3caeb99b70..b6676742c448 100644
--- a/torch_geometric/datasets/web_qsp_dataset.py
+++ b/torch_geometric/datasets/web_qsp_dataset.py
@@ -155,7 +155,7 @@ class WebQSPDataset(InMemoryDataset):
     Args:
         root (str): Root directory where the dataset should be saved.
         split (str, optional): If :obj:`"train"`, loads the training dataset.
-            If :obj:`"val"`, loads the validation dataset.
+            If :obj:`"validation"`, loads the validation dataset.
             If :obj:`"test"`, loads the test dataset. (default: :obj:`"train"`)
         force_reload (bool, optional): Whether to re-process the dataset.
             (default: :obj:`False`)
@@ -187,11 +187,11 @@ def __init__(
         self.force_reload = force_reload
         super().__init__(root, force_reload=force_reload)
 
-        if split not in {'train', 'val', 'test'} and limit < 0:
+        if split not in set(self.raw_file_names):
             raise ValueError(f"Invalid 'split' argument (got {split})")
 
         self._load_raw_data()
-        self.load(self.processed_paths[0])
+        self.load(self.processed_paths[0] + "" * (self.limit >= 0))
 
     '''
     def _check_dependencies(self) -> None:
@@ -210,63 +210,42 @@ def _check_dependencies(self) -> None:
     @property
     def raw_file_names(self) -> List[str]:
-        return ["raw_data", "split_idxs"]
+        return ["train", "validation", "test"]
 
     @property
     def processed_file_names(self) -> List[str]:
         file_lst = [
             "train_data.pt",
-            "val_data.pt",
+            "validation_data.pt",
             "test_data.pt",
             "pre_filter.pt",
             "pre_transform.pt",
             "large_graph_indexer",
         ]
-        split_file = file_lst.pop(['train', 'val', 'test'].index(self.split))
+        split_file = file_lst.pop(self.raw_file_names.index(self.split))
         file_lst.insert(0, split_file)
         return file_lst
 
-    def _save_raw_data(self) -> None:
-        self.raw_dataset.save_to_disk(self.raw_paths[0])
-        torch.save(self.split_idxs, self.raw_paths[1])
+    def _save_raw_data(self, dataset) -> None:
+        for i, split in enumerate(self.raw_file_names):
+            dataset[split].save_to_disk(self.raw_paths[i])
 
     def _load_raw_data(self) -> None:
         import datasets
 
         if not hasattr(self, "raw_dataset"):
-            self.raw_dataset = datasets.load_from_disk(self.raw_paths[0])
-        if not hasattr(self, "split_idxs"):
-            self.split_idxs = torch.load(self.raw_paths[1])
+            self.raw_dataset = datasets.load_from_disk(
+                self.raw_paths[self.raw_file_names.index(self.split)])
+
+        if self.limit >= 0:
+            self.raw_dataset = self.raw_dataset.select(
+                range(min(self.limit, len(self.raw_dataset))))
 
     def download(self) -> None:
         import datasets
 
         dataset = datasets.load_dataset("rmanluo/RoG-webqsp")
-        self.raw_dataset = datasets.concatenate_datasets(
-            [dataset["train"], dataset["validation"], dataset["test"]])
-        self.split_idxs = {
-            "train":
-            torch.arange(len(dataset["train"])),
-            "val":
-            torch.arange(len(dataset["validation"])) + len(dataset["train"]),
-            "test":
-            torch.arange(len(dataset["test"])) + len(dataset["train"]) +
-            len(dataset["validation"]),
-        }
-
-        if self.limit >= 0:
-            self.raw_dataset = self.raw_dataset.select(range(self.limit))
-
-            # HACK
-            self.split_idxs = {
-                "train":
-                torch.arange(self.limit // 2),
-                "val":
-                torch.arange(self.limit // 4) + self.limit // 2,
-                "test":
-                torch.arange(self.limit // 4) + self.limit // 2 +
-                self.limit // 4,
-            }
-        self._save_raw_data()
+        self._save_raw_data(dataset)
+        self.raw_dataset = dataset[self.split]
 
     def _get_trips(self) -> Iterator[TripletLike]:
         return chain.from_iterable(
@@ -340,7 +319,8 @@ def _retrieve_subgraphs(self) -> None:
                 pcst_subgraph["desc"] = desc
                 list_of_graphs.append(pcst_subgraph.to("cpu"))
print("Saving subgraphs...") - self.save(list_of_graphs, self.processed_paths[0]) + self.save(list_of_graphs, + self.processed_paths[0] + "" * (self.limit >= 0)) def process(self) -> None: from pandas import DataFrame