From ba74d6e2ce785cb59ab1c6ef92f065a4a1fa1ea6 Mon Sep 17 00:00:00 2001
From: Matthias Fey
Date: Mon, 18 Sep 2023 15:58:33 +0200
Subject: [PATCH] Use raw `Tensor` data as `BLOB` in `SQLiteDatabase` (#8054)

---
 CHANGELOG.md                     |   2 +-
 test/data/test_database.py       |   4 +-
 torch_geometric/data/database.py | 178 +++++++++++++++++++++++++------
 3 files changed, 147 insertions(+), 37 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index f635619420af..0017c0643ed7 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,7 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
-- Added a `Database` interface and `SQLiteDatabase`/`RocksDatabase` implementations ([#8028](https://github.com/pyg-team/pytorch_geometric/pull/8028), [#8044](https://github.com/pyg-team/pytorch_geometric/pull/8044), [#8046](https://github.com/pyg-team/pytorch_geometric/pull/8046), [#8051](https://github.com/pyg-team/pytorch_geometric/pull/8051), [#8052](https://github.com/pyg-team/pytorch_geometric/pull/8052))
+- Added a `Database` interface and `SQLiteDatabase`/`RocksDatabase` implementations ([#8028](https://github.com/pyg-team/pytorch_geometric/pull/8028), [#8044](https://github.com/pyg-team/pytorch_geometric/pull/8044), [#8046](https://github.com/pyg-team/pytorch_geometric/pull/8046), [#8051](https://github.com/pyg-team/pytorch_geometric/pull/8051), [#8052](https://github.com/pyg-team/pytorch_geometric/pull/8052), [#8054](https://github.com/pyg-team/pytorch_geometric/pull/8054))
 - Added support for weighted/biased sampling in `NeighborLoader`/`LinkNeighborLoader` ([#8038](https://github.com/pyg-team/pytorch_geometric/pull/8038))
 - Added the `MixHopConv` layer and a corresponding example ([#8025](https://github.com/pyg-team/pytorch_geometric/pull/8025))
 - Added the option to pass keyword arguments to the underlying normalization layers within `BasicGNN` and `MLP` ([#8024](https://github.com/pyg-team/pytorch_geometric/pull/8024), [#8033](https://github.com/pyg-team/pytorch_geometric/pull/8033))
diff --git a/test/data/test_database.py b/test/data/test_database.py
index 0a9a59f18db1..63a603aca0b5 100644
--- a/test/data/test_database.py
+++ b/test/data/test_database.py
@@ -113,11 +113,11 @@ def test_database_syntactic_sugar(tmp_path):
     print(f'Initialized RocksDB in {time.perf_counter() - t:.2f} seconds')
 
     def in_memory_get(data):
-        index = torch.randint(0, args.numel, (128, ))
+        index = torch.randint(0, args.numel, (args.batch_size, ))
         return data[index]
 
     def db_get(db):
-        index = torch.randint(0, args.numel, (128, ))
+        index = torch.randint(0, args.numel, (args.batch_size, ))
        return db[index]
 
    benchmark(
diff --git a/torch_geometric/data/database.py b/torch_geometric/data/database.py
index f8df0cb280c4..992354672e89 100644
--- a/torch_geometric/data/database.py
+++ b/torch_geometric/data/database.py
@@ -1,14 +1,51 @@
 import pickle
+import warnings
 from abc import ABC, abstractmethod
-from typing import Any, Iterable, List, Optional, Union
+from dataclasses import dataclass, field
+from functools import cached_property
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
 from uuid import uuid4
 
+import torch
 from torch import Tensor
 from tqdm import tqdm
 
+from torch_geometric.utils.mixin import CastMixin
+
+
+@dataclass
+class TensorInfo(CastMixin):
+    dtype: torch.dtype
+    size: Tuple[int, ...] = field(default_factory=lambda: (-1, ))
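+
+    # Illustrative: a column typed `TensorInfo(dtype=torch.long, size=(-1, 2))`
+    # stores raw tensor bytes and is restored on retrieval via
+    # `torch.frombuffer(blob, dtype=torch.long).view(-1, 2)`; the default
+    # `size=(-1, )` restores a flat one-dimensional tensor.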
+
+
+def maybe_cast_to_tensor_info(value: Any) -> Union[Any, TensorInfo]:
+    if not isinstance(value, dict):
+        return value
+    if len(value) < 1 or len(value) > 2:
+        return value
+    if len(value) == 1 and 'dtype' not in value:
+        return value
+    if len(value) == 2 and 'dtype' not in value and 'size' not in value:
+        return value
+    return TensorInfo.cast(value)
+
+
+Schema = Union[Any, Dict[str, Any], Tuple[Any], List[Any]]
+
 
 class Database(ABC):
-    r"""Base class for database."""
+    r"""Base class for inserting and retrieving data from a database."""
+    def __init__(self, schema: Schema = object):
+        schema = maybe_cast_to_tensor_info(schema)
+        schema = self._to_dict(schema)
+        schema = {
+            key: maybe_cast_to_tensor_info(value)
+            for key, value in schema.items()
+        }
+
+        self.schema: Dict[Union[str, int], Any] = schema
+
     def connect(self):
         pass
 
@@ -83,18 +120,13 @@ def _multi_get(self, indices: Union[Iterable[int], Tensor]) -> List[Any]:
     # Helper functions ########################################################
 
     @staticmethod
-    def serialize(data: Any) -> bytes:
-        r"""Serializes :obj:`data` into bytes."""
-        # Ensure that data is not a view of a larger tensor:
-        if isinstance(data, Tensor):
-            data = data.clone()
-
-        return pickle.dumps(data)
-
-    @staticmethod
-    def deserialize(data: bytes) -> Any:
-        r"""Deserializes bytes into the original data."""
-        return pickle.loads(data)
+    def _to_dict(value: Any) -> Dict[Union[str, int], Any]:
+        if isinstance(value, dict):
+            return value
+        if isinstance(value, (tuple, list)):
+            return {i: v for i, v in enumerate(value)}
+        else:
+            return {0: value}
 
     def slice_to_range(self, indices: slice) -> range:
         start = 0 if indices.start is None else indices.start
@@ -136,8 +168,10 @@ def __repr__(self) -> str:
 
 
 class SQLiteDatabase(Database):
-    def __init__(self, path: str, name: str):
-        super().__init__()
+    def __init__(self, path: str, name: str, schema: Schema = object):
+        super().__init__(schema)
+
+        warnings.filterwarnings('ignore', '.*given buffer is not writable.*')
 
         import sqlite3
 
@@ -149,9 +183,13 @@ def __init__(self, path: str, name: str):
 
         self.connect()
 
+        sql_schema = ',\n'.join([
+            f'  {col_name} {self._to_sql_type(type_info)} NOT NULL' for
+            col_name, type_info in zip(self._col_names, self.schema.values())
+        ])
         query = (f'CREATE TABLE IF NOT EXISTS {self.name} (\n'
                  f'  id INTEGER PRIMARY KEY,\n'
-                 f'  data BLOB NOT NULL\n'
+                 f'{sql_schema}\n'
                  f')')
         self.cursor.execute(query)
 
@@ -174,8 +212,10 @@ def cursor(self) -> Any:
         return self._cursor
 
     def insert(self, index: int, data: Any):
-        query = f'INSERT INTO {self.name} (id, data) VALUES (?, ?)'
-        self.cursor.execute(query, (index, self.serialize(data)))
+        query = (f'INSERT INTO {self.name} '
+                 f'(id, {self._joined_col_names}) '
+                 f'VALUES (?, {self._dummies})')
+        self.cursor.execute(query, (index, *self._serialize(data)))
 
     def _multi_insert(
         self,
@@ -185,15 +225,18 @@ def _multi_insert(
         if isinstance(indices, Tensor):
             indices = indices.tolist()
 
-        data_list = [self.serialize(data) for data in data_list]
-
-        query = f'INSERT INTO {self.name} (id, data) VALUES (?, ?)'
-        self.cursor.executemany(query, zip(indices, data_list))
+        data_list = [(index, *self._serialize(data))
+                     for index, data in zip(indices, data_list)]
+
+        query = (f'INSERT INTO {self.name} '
+                 f'(id, {self._joined_col_names}) '
+                 f'VALUES (?, {self._dummies})')
+        self.cursor.executemany(query, data_list)
 
     def get(self, index: int) -> Any:
-        query = f'SELECT data FROM {self.name} WHERE id = ?'
+        query = (f'SELECT {self._joined_col_names} FROM {self.name} '
+                 f'WHERE id = ?')
         self.cursor.execute(query, (index, ))
-        return self.deserialize(self.cursor.fetchone()[0])
+        return self._deserialize(self.cursor.fetchone())
 
     def multi_get(
         self,
@@ -221,7 +264,7 @@ def multi_get(
         query = f'SELECT * FROM {join_table_name}'
         self.cursor.execute(query)
 
-        query = (f'SELECT {self.name}.data '
+        query = (f'SELECT {self._joined_col_names} '
                  f'FROM {self.name} INNER JOIN {join_table_name} '
                  f'ON {self.name}.id = {join_table_name}.id '
                  f'ORDER BY {join_table_name}.row_id')
@@ -240,17 +283,77 @@ def multi_get(
         query = f'DROP TABLE {join_table_name}'
         self.cursor.execute(query)
 
-        return [self.deserialize(data[0]) for data in data_list]
+        return [self._deserialize(data) for data in data_list]
 
     def __len__(self) -> int:
         query = f'SELECT COUNT(*) FROM {self.name}'
         self.cursor.execute(query)
         return self.cursor.fetchone()[0]
 
+    # Helper functions ########################################################
+
+    @cached_property
+    def _col_names(self) -> List[str]:
+        return [f'COL_{key}' for key in self.schema.keys()]
+
+    @cached_property
+    def _joined_col_names(self) -> str:
+        return ', '.join(self._col_names)
+
+    @cached_property
+    def _dummies(self) -> str:
+        return ', '.join(['?'] * len(self.schema.keys()))
+
+    def _to_sql_type(self, type_info: Any) -> str:
+        if type_info == int:
+            return 'INTEGER'
+        if type_info == float:
+            return 'FLOAT'
+        if type_info == str:
+            return 'TEXT'
+        else:
+            return 'BLOB'
+
+    def _serialize(self, row: Any) -> List[Any]:
+        # Serializes the columns of `row` according to `schema`:
+        out_list: List[Any] = []
+        for key, col in self._to_dict(row).items():
+            if isinstance(self.schema[key], TensorInfo):
+                out = col.numpy().tobytes()
+            elif isinstance(col, Tensor):
+                # Upgrade non-tensor schema entries to raw tensor storage
+                # on-the-fly whenever a tensor is inserted:
+                self.schema[key] = TensorInfo(dtype=col.dtype)
+                out = col.numpy().tobytes()
+            elif self.schema[key] in {int, float, str}:
+                out = col
+            else:
+                out = pickle.dumps(col)
+
+            out_list.append(out)
+
+        return out_list
+
+    def _deserialize(self, row: Tuple[Any]) -> Any:
+        out_dict = {}
+        for i, (key, col_schema) in enumerate(self.schema.items()):
+            if isinstance(col_schema, TensorInfo):
+                # `frombuffer` shares the read-only BLOB buffer; the related
+                # "buffer is not writable" warning is silenced in `__init__`:
+                out_dict[key] = torch.frombuffer(
+                    row[i], dtype=col_schema.dtype).view(*col_schema.size)
+            elif col_schema in {int, float, str}:
+                out_dict[key] = row[i]
+            else:
+                out_dict[key] = pickle.loads(row[i])
+
+        # If `0` is a key, the schema was given as a single entry or sequence:
+        if 0 in self.schema:
+            if len(self.schema) == 1:
+                return out_dict[0]
+            else:
+                return tuple(out_dict.values())
+        else:
+            return out_dict
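+
+
+# A minimal round trip for the default `object` schema (illustrative sketch;
+# `path` is a placeholder):
+#
+#     db = SQLiteDatabase(path, name='example')  # single BLOB column `COL_0`
+#     db.insert(0, torch.arange(4))              # stored as raw `int64` bytes
+#     assert db.get(0).tolist() == [0, 1, 2, 3]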
 
 
 class RocksDatabase(Database):
-    def __init__(self, path: str):
-        super().__init__()
+    def __init__(self, path: str, schema: Schema = object):
+        super().__init__(schema)
 
         import rocksdict
 
@@ -283,18 +386,25 @@ def to_key(index: int) -> bytes:
         return index.to_bytes(8, byteorder='big', signed=True)
 
     def insert(self, index: int, data: Any):
-        # Ensure that data is not a view of a larger tensor:
-        if isinstance(data, Tensor):
-            data = data.clone()
-
-        self.db[self.to_key(index)] = self.serialize(data)
+        self.db[self.to_key(index)] = self._serialize(data)
 
     def get(self, index: int) -> Any:
-        return self.deserialize(self.db[self.to_key(index)])
+        return self._deserialize(self.db[self.to_key(index)])
 
     def _multi_get(self, indices: Union[Iterable[int], Tensor]) -> List[Any]:
         if isinstance(indices, Tensor):
            indices = indices.tolist()
         indices = [self.to_key(index) for index in indices]
         data_list = self.db[indices]
-        return [self.deserialize(data) for data in data_list]
+        return [self._deserialize(data) for data in data_list]
+
+    # Helper functions ########################################################
+
+    def _serialize(self, row: Any) -> bytes:
+        # Ensure that data is not a view of a larger tensor:
+        if isinstance(row, Tensor):
+            row = row.clone()
+        return pickle.dumps(row)
+
+    def _deserialize(self, row: bytes) -> Any:
+        return pickle.loads(row)
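
A minimal usage sketch of the schema-aware API added above; `data.db` and
`test_table` are placeholder names, and the snippet assumes the
`SQLiteDatabase` behavior shown in this patch:

    import torch
    from torch_geometric.data.database import SQLiteDatabase

    # One column is created per schema entry; a dict holding 'dtype'/'size'
    # keys is cast to `TensorInfo`, so `COL_x` stores raw tensor bytes as a
    # BLOB while `COL_y` becomes a plain INTEGER column:
    db = SQLiteDatabase(
        path='data.db',      # placeholder file path
        name='test_table',   # placeholder table name
        schema={'x': dict(dtype=torch.float, size=(-1, 16)), 'y': int},
    )
    db.insert(0, {'x': torch.randn(4, 16), 'y': 7})
    out = db.get(0)  # => {'x': float tensor of shape [4, 16], 'y': 7}
    db.close()

Retrieval reconstructs each tensor column with `torch.frombuffer(...).view(...)`
rather than unpickling; only columns that are neither tensors nor `int`,
`float`, or `str` fall back to `pickle`.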