Skip to content

Commit

Permalink
update save/load
Browse files Browse the repository at this point in the history
  • Loading branch information
ekorman committed Jul 20, 2024
1 parent 2041130 commit 0c9d0f7
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 40 deletions.
37 changes: 12 additions & 25 deletions affine/engine.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import os
import pickle
from abc import ABC, abstractmethod
from collections import defaultdict
Expand Down Expand Up @@ -88,31 +87,25 @@ def delete(self, collection: type, id_: int) -> None:


class LocalEngine(Engine):
def __init__(
self, fp: str | Path | BinaryIO = None
) -> None: # maybe add option to the init for ANN algo
self.fp = fp
self._load_records()
# maybe pickle this too?
def __init__(self) -> None: # maybe add option to the init for ANN algo
self.records: dict[str, list[Collection]] = defaultdict(list)
self.build_collection_id_counter()

def build_collection_id_counter(self):
# maybe pickle this too on save?
self.collection_id_counter: dict[str, int] = defaultdict(int)
for k, recs in self.records.items():
if len(recs) > 0:
self.collection_id_counter[k] = max([r.id for r in recs])

def _load_records(self):
def load(self, fp: str | Path | BinaryIO) -> None:
self.records: dict[str, list[Collection]] = defaultdict(list)
if self.fp is None:
return

if isinstance(self.fp, (Path, str)):
if os.path.exists(self.fp):
with open(self.fp, "rb") as f:
self.records = pickle.load(f)
if isinstance(fp, (str, Path)):
with open(fp, "rb") as f:
self.records = pickle.load(f)
else:
self.fp.seek(0)
b = self.fp.read()
if len(b) > 0:
self.records = pickle.loads(b)
self.records = pickle.load(fp)
self.build_collection_id_counter()

def save(self, fp: str | Path | BinaryIO = None) -> None:
fp = fp or self.fp
Expand All @@ -130,23 +123,17 @@ def query(self, filter_set: FilterSet = None) -> list[Collection]:

return apply_filters_to_records(filter_set.filters, records)

def _maybe_save(self):
if self.fp is not None:
self.save()

def insert(self, record: Collection) -> int:
record.id = self.collection_id_counter[record.__class__.__name__] + 1
self.records[record.__class__.__name__].append(record)
self.collection_id_counter[record.__class__.__name__] = record.id
self._maybe_save()

return record.id

def delete(self, collection: type, id_: int) -> None:
for r in self.records[collection.__name__]:
if r.id == id_:
self.records[collection.__name__].remove(r)
self._maybe_save()
return
raise ValueError(
f"Record with id {id_} not found in collection {collection.__name__}"
Expand Down
26 changes: 11 additions & 15 deletions tests/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,26 +88,18 @@ def test_local_engine(data: list[Collection]):
assert q10[0].name == "Banana"


def test_local_engine_persistence(data: list[Collection], tmp_path):
db = LocalEngine(tmp_path / "db.affine")

assert len(db.query(Person.objects())) == 0
for rec in data:
db.insert(rec)

db2 = LocalEngine(tmp_path / "db.affine")
assert len(db2.query(Person.objects())) == 2


def test_local_engine_save_load(data: list[Collection], tmp_path):
db = LocalEngine()

for rec in data:
db.insert(rec)

db.save(tmp_path / "db.affine")
path = tmp_path / "db.affine"

db.save(path)

db2 = LocalEngine(tmp_path / "db.affine")
db2 = LocalEngine()
db2.load(path)

q1 = db2.query(Person.objects())
assert len(q1) == 2
Expand All @@ -124,11 +116,15 @@ def test_local_engine_save_load(data: list[Collection], tmp_path):
def test_save_load_from_buffer(data: list[Collection]):
f = io.BytesIO()

db = LocalEngine(f)
db = LocalEngine()

for rec in data:
db.insert(rec)

db2 = LocalEngine(f)
db.save(f)
f.seek(0)

db2 = LocalEngine()
db2.load(f)
assert len(db2.query(Person.objects())) == 2
assert len(db2.query(Product.objects())) == 1

0 comments on commit 0c9d0f7

Please sign in to comment.