From 0c6645ca339a211acf0ae0b78038fc9c2992b483 Mon Sep 17 00:00:00 2001 From: "Eric O. Korman" Date: Wed, 10 Jul 2024 08:13:29 -0500 Subject: [PATCH] fix bug when starting with new db file (#3) --- affine/engine.py | 3 ++- tests/test_engine.py | 11 +++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/affine/engine.py b/affine/engine.py index ee486c0..7953dec 100644 --- a/affine/engine.py +++ b/affine/engine.py @@ -1,3 +1,4 @@ +import os import pickle from abc import ABC, abstractmethod from collections import defaultdict @@ -91,7 +92,7 @@ def __init__( ) -> None: # maybe add option to the init for ANN algo self.path = path self.records: dict[str, list[Collection]] = defaultdict(list) - if self.path is not None: + if self.path is not None and os.path.exists(self.path): with open(self.path, "rb") as f: self.records = pickle.load(f) self.collection_id_counter: dict[str, int] = defaultdict( diff --git a/tests/test_engine.py b/tests/test_engine.py index 3d46bf9..20ba6b4 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -81,6 +81,17 @@ def test_local_engine(data: list[Collection]): assert db.insert(Product(name="Banana", price=2.0)) == 2 +def test_local_engine_persistence(data: list[Collection], tmp_path): + db = LocalEngine(tmp_path / "db.affine") + + assert len(db.query(Person.objects())) == 0 + for rec in data: + db.insert(rec) + + db2 = LocalEngine(tmp_path / "db.affine") + assert len(db2.query(Person.objects())) == 2 + + def test_local_engine_save_load(data: list[Collection], tmp_path): db = LocalEngine()