Use raw Tensor data as BLOB in SQLiteDatabase (#8054)
rusty1s authored and JakubPietrakIntel committed Sep 27, 2023
1 parent 79f97f6 commit ba74d6e
Showing 3 changed files with 147 additions and 37 deletions.
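
This commit extends the `Database` base class with an optional `schema` argument and teaches `SQLiteDatabase` to store tensors as raw bytes in typed `BLOB` columns instead of pickled payloads. A minimal usage sketch, assuming the constructor signature added in the diff below (file name, table name, and shapes are illustrative):

```python
import torch
from torch_geometric.data.database import SQLiteDatabase

# Illustrative single-column schema: float tensors of shape (*, 64).
# The dict is cast to a TensorInfo(dtype, size) entry internally.
db = SQLiteDatabase(
    path='data.sqlite',
    name='embeddings',
    schema={'dtype': torch.float, 'size': (-1, 64)},
)
db.insert(0, torch.randn(3, 64))  # stored as raw bytes in a BLOB column
out = db.get(0)                   # rebuilt via torch.frombuffer(...).view(-1, 64)
```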
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -7,7 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added a `Database` interface and `SQLiteDatabase`/`RocksDatabase` implementations ([#8028](https://github.com/pyg-team/pytorch_geometric/pull/8028), [#8044](https://github.com/pyg-team/pytorch_geometric/pull/8044), [#8046](https://github.com/pyg-team/pytorch_geometric/pull/8046), [#8051](https://github.com/pyg-team/pytorch_geometric/pull/8051), [#8052](https://github.com/pyg-team/pytorch_geometric/pull/8052))
- Added a `Database` interface and `SQLiteDatabase`/`RocksDatabase` implementations ([#8028](https://github.com/pyg-team/pytorch_geometric/pull/8028), [#8044](https://github.com/pyg-team/pytorch_geometric/pull/8044), [#8046](https://github.com/pyg-team/pytorch_geometric/pull/8046), [#8051](https://github.com/pyg-team/pytorch_geometric/pull/8051), [#8052](https://github.com/pyg-team/pytorch_geometric/pull/8052), [#8054](https://github.com/pyg-team/pytorch_geometric/pull/8054))
- Added support for weighted/biased sampling in `NeighborLoader`/`LinkNeighborLoader` ([#8038](https://github.com/pyg-team/pytorch_geometric/pull/8038))
- Added the `MixHopConv` layer and a corresponding example ([#8025](https://github.com/pyg-team/pytorch_geometric/pull/8025))
- Added the option to pass keyword arguments to the underlying normalization layers within `BasicGNN` and `MLP` ([#8024](https://github.com/pyg-team/pytorch_geometric/pull/8024), [#8033](https://github.com/pyg-team/pytorch_geometric/pull/8033))
4 changes: 2 additions & 2 deletions test/data/test_database.py
@@ -113,11 +113,11 @@ def test_database_syntactic_sugar(tmp_path):
print(f'Initialized RocksDB in {time.perf_counter() - t:.2f} seconds')

def in_memory_get(data):
index = torch.randint(0, args.numel, (128, ))
index = torch.randint(0, args.numel, (args.batch_size, ))
return data[index]

def db_get(db):
index = torch.randint(0, args.numel, (128, ))
index = torch.randint(0, args.numel, (args.batch_size, ))
return db[index]

benchmark(
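
For reference, the database reads benchmarked above rely on the indexing sugar of `Database`: a `Tensor` of indices dispatches to `multi_get()`. A short sketch, assuming a populated `db` as in the example near the top:

```python
import torch

index = torch.randint(0, len(db), (128, ))  # len(db) issues a COUNT(*) query
batch = db[index]                           # returns a list of deserialized rows
```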
178 changes: 144 additions & 34 deletions torch_geometric/data/database.py
@@ -1,14 +1,51 @@
import pickle
import warnings
from abc import ABC, abstractmethod
from typing import Any, Iterable, List, Optional, Union
from dataclasses import dataclass, field
from functools import cached_property
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
from uuid import uuid4

import torch
from torch import Tensor
from tqdm import tqdm

from torch_geometric.utils.mixin import CastMixin


@dataclass
class TensorInfo(CastMixin):
dtype: torch.dtype
size: Tuple[int, ...] = field(default_factory=lambda: (-1, ))


def maybe_cast_to_tensor_info(value: Any) -> Union[Any, TensorInfo]:
if not isinstance(value, dict):
return value
if len(value) < 1 or len(value) > 2:
return value
if len(value) == 1 and 'dtype' not in value:
return value
if len(value) == 2 and 'dtype' not in value and 'size' not in value:
return value
return TensorInfo.cast(value)
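# For illustration (not part of this diff): dicts such as
# {'dtype': torch.float} or {'dtype': torch.long, 'size': (-1, 2)} are
# cast to `TensorInfo`; any other dictionary is passed through unchanged.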


Schema = Union[Any, Dict[str, Any], Tuple[Any], List[Any]]


class Database(ABC):
r"""Base class for database."""
r"""Base class for inserting and retrieving data from a database."""
def __init__(self, schema: Schema = object):
schema = maybe_cast_to_tensor_info(schema)
schema = self._to_dict(schema)
schema = {
key: maybe_cast_to_tensor_info(value)
for key, value in schema.items()
}

self.schema: Dict[Union[str, int], Any] = schema

def connect(self):
pass

@@ -83,18 +120,13 @@ def _multi_get(self, indices: Union[Iterable[int], Tensor]) -> List[Any]:
# Helper functions ########################################################

@staticmethod
def serialize(data: Any) -> bytes:
r"""Serializes :obj:`data` into bytes."""
# Ensure that data is not a view of a larger tensor:
if isinstance(data, Tensor):
data = data.clone()

return pickle.dumps(data)

@staticmethod
def deserialize(data: bytes) -> Any:
r"""Deserializes bytes into the original data."""
return pickle.loads(data)
def _to_dict(value) -> Dict[Union[str, int], Any]:
if isinstance(value, dict):
return value
if isinstance(value, (tuple, list)):
return {i: v for i, v in enumerate(value)}
else:
return {0: value}
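# For illustration (not part of this diff):
# _to_dict((x, y)) -> {0: x, 1: y}; _to_dict(x) -> {0: x}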

def slice_to_range(self, indices: slice) -> range:
start = 0 if indices.start is None else indices.start
@@ -136,8 +168,10 @@ def __repr__(self) -> str:


class SQLiteDatabase(Database):
def __init__(self, path: str, name: str):
super().__init__()
def __init__(self, path: str, name: str, schema: Schema = object):
super().__init__(schema)

warnings.filterwarnings('ignore', '.*given buffer is not writable.*')

import sqlite3

@@ -149,9 +183,13 @@ def __init__(self, path: str, name: str):

self.connect()

sql_schema = ',\n'.join([
f' {col_name} {self._to_sql_type(type_info)} NOT NULL' for
col_name, type_info in zip(self._col_names, self.schema.values())
])
query = (f'CREATE TABLE IF NOT EXISTS {self.name} (\n'
f' id INTEGER PRIMARY KEY,\n'
f' data BLOB NOT NULL\n'
f'{sql_schema}\n'
f')')
self.cursor.execute(query)

@@ -174,8 +212,10 @@ def cursor(self) -> Any:
return self._cursor

def insert(self, index: int, data: Any):
query = f'INSERT INTO {self.name} (id, data) VALUES (?, ?)'
self.cursor.execute(query, (index, self.serialize(data)))
query = (f'INSERT INTO {self.name} '
f'(id, {self._joined_col_names}) '
f'VALUES (?, {self._dummies})')
self.cursor.execute(query, (index, self._serialize(data)))

def _multi_insert(
self,
@@ -185,15 +225,18 @@ def _multi_insert(
if isinstance(indices, Tensor):
indices = indices.tolist()

data_list = [self.serialize(data) for data in data_list]
data_list = [self._serialize(data) for data in data_list]

query = f'INSERT INTO {self.name} (id, data) VALUES (?, ?)'
query = (f'INSERT INTO {self.name} '
f'(id, {self._joined_col_names}) '
f'VALUES (?, {self._dummies})')
self.cursor.executemany(query, zip(indices, data_list))

def get(self, index: int) -> Any:
query = f'SELECT data FROM {self.name} WHERE id = ?'
query = (f'SELECT {self._joined_col_names} FROM {self.name} '
f'WHERE id = ?')
self.cursor.execute(query, (index, ))
return self.deserialize(self.cursor.fetchone()[0])
return self._deserialize(self.cursor.fetchone())

def multi_get(
self,
@@ -221,7 +264,7 @@ def multi_get(
query = f'SELECT * FROM {join_table_name}'
self.cursor.execute(query)

query = (f'SELECT {self.name}.data '
query = (f'SELECT {self._joined_col_names} '
f'FROM {self.name} INNER JOIN {join_table_name} '
f'ON {self.name}.id = {join_table_name}.id '
f'ORDER BY {join_table_name}.row_id')
@@ -240,17 +283,77 @@ def multi_get(
query = f'DROP TABLE {join_table_name}'
self.cursor.execute(query)

return [self.deserialize(data[0]) for data in data_list]
return [self._deserialize(data) for data in data_list]

def __len__(self) -> int:
query = f'SELECT COUNT(*) FROM {self.name}'
self.cursor.execute(query)
return self.cursor.fetchone()[0]

# Helper functions ########################################################

@cached_property
def _col_names(self) -> List[str]:
return [f'COL_{key}' for key in self.schema.keys()]

@cached_property
def _joined_col_names(self) -> str:
return ', '.join(self._col_names)

@cached_property
def _dummies(self) -> str:
return ', '.join(['?'] * len(self.schema.keys()))

def _to_sql_type(self, type_info: Any) -> str:
if type_info == int:
return 'INTEGER'
if type_info == float:
return 'FLOAT'
if type_info == str:
return 'TEXT'
else:
return 'BLOB'
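# For illustration (hypothetical schema, not part of this diff): with
# schema={'x': {'dtype': torch.float}, 'y': int}, `_col_names` yields
# ['COL_x', 'COL_y'], the CREATE TABLE statement above emits
# 'COL_x BLOB NOT NULL' and 'COL_y INTEGER NOT NULL', and `_dummies`
# renders as '?, ?'.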

def _serialize(self, row: Any) -> Union[Any, List[Any]]:
out_list: List[Any] = []
for key, col in self._to_dict(row).items():
if isinstance(self.schema[key], TensorInfo):
out = col.numpy().tobytes()
elif isinstance(col, Tensor):
self.schema[key] = TensorInfo(dtype=col.dtype)
out = col.numpy().tobytes()
elif self.schema[key] in {int, float, str}:
out = col
else:
out = pickle.dumps(col)

out_list.append(out)

return out_list if len(out_list) > 1 else out_list[0]

def _deserialize(self, row: Tuple[Any]) -> Any:
out_dict = {}
for i, (key, col_schema) in enumerate(self.schema.items()):
if isinstance(col_schema, TensorInfo):
out_dict[key] = torch.frombuffer(
row[i], dtype=col_schema.dtype).view(*col_schema.size)
elif col_schema in {int, float, str}:
out_dict[key] = row[i]
else:
out_dict[key] = pickle.loads(row[i])

if 0 in self.schema:
if len(self.schema) == 1:
return out_dict[0]
else:
return tuple(out_dict.values())
else:
return out_dict
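# For illustration (not part of this diff): a single-entry schema returns
# the bare value, a tuple/list schema returns a tuple, and a dict schema
# returns a dict keyed by the schema's keys.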


class RocksDatabase(Database):
def __init__(self, path: str):
super().__init__()
def __init__(self, path: str, schema: Schema = object):
super().__init__(schema)

import rocksdict

@@ -283,18 +386,25 @@ def to_key(index: int) -> bytes:
return index.to_bytes(8, byteorder='big', signed=True)

def insert(self, index: int, data: Any):
# Ensure that data is not a view of a larger tensor:
if isinstance(data, Tensor):
data = data.clone()

self.db[self.to_key(index)] = self.serialize(data)
self.db[self.to_key(index)] = self._serialize(data)

def get(self, index: int) -> Any:
return self.deserialize(self.db[self.to_key(index)])
return self._deserialize(self.db[self.to_key(index)])

def _multi_get(self, indices: Union[Iterable[int], Tensor]) -> List[Any]:
if isinstance(indices, Tensor):
indices = indices.tolist()
indices = [self.to_key(index) for index in indices]
data_list = self.db[indices]
return [self.deserialize(data) for data in data_list]
return [self._deserialize(data) for data in data_list]

# Helper functions ########################################################

def _serialize(self, row: Any) -> bytes:
# Ensure that data is not a view of a larger tensor:
if isinstance(row, Tensor):
row = row.clone()
return pickle.dumps(row)

def _deserialize(self, row: bytes) -> Any:
return pickle.loads(row)
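
The core of the change is the raw-bytes round trip that `SQLiteDatabase._serialize`/`_deserialize` now perform for tensor columns. A standalone sketch of that round trip (dtype and shape are illustrative):

```python
import torch

x = torch.arange(8, dtype=torch.float32).view(2, 4)
blob = x.numpy().tobytes()  # what _serialize() stores in the BLOB column

# torch.frombuffer() warns on read-only buffers, which is why the diff
# filters the 'given buffer is not writable' warning; a mutable copy
# sidesteps that here:
y = torch.frombuffer(bytearray(blob), dtype=torch.float32).view(-1, 4)
assert torch.equal(x, y)
```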
