Skip to content

Commit

Permalink
Add object deletion in ORM
Browse files Browse the repository at this point in the history
  • Loading branch information
piercefreeman committed Apr 18, 2023
1 parent 8f173ee commit eb3b753
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 25 deletions.
28 changes: 22 additions & 6 deletions vectordb_orm/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,17 @@ def insert(self, milvus_client: Milvus) -> int:
self.id = mutation_result.primary_keys[0]
return self.id

def delete(self, milvus_client: Milvus) -> None:
    """
    Remove this object from its Milvus collection and reset its local id.

    :param milvus_client: client used to issue the delete request.
    :raises ValueError: if the object has no id, i.e. it has not been inserted.
    """
    # Compare against None explicitly: an id of 0 is falsy, but it is still
    # an assigned identifier and must not be mistaken for "never inserted".
    if self.id is None:
        raise ValueError("Cannot delete object that hasn't been inserted into the database")

    identifier_key = self._get_primary()
    # Milvus only supports deleting entities with the `in` conditional; equality doesn't work
    delete_expression = f"{identifier_key} in [{self.id}]"
    milvus_client.delete(collection_name=self.collection_name(), expr=delete_expression)

    # Clear the id so the in-memory object is back in its "not inserted" state
    self.id = None

def _dict_representation(self):
type_converters = {
np.ndarray: DataType.FLOAT_VECTOR,
Expand Down Expand Up @@ -148,15 +159,20 @@ def from_dict(cls, data: dict):
setattr(obj, attribute_name, value)
return obj

@classmethod
def _get_primary(cls):
    """
    Look up the name of the attribute configured as the primary key.

    Returns the first annotated attribute whose type configuration is a
    `PrimaryKeyField`, or None when the class declares no such attribute.
    """
    primary_candidates = (
        name
        for name in cls.__annotations__
        if isinstance(cls._type_configuration.get(name), PrimaryKeyField)
    )
    return next(primary_candidates, None)

@classmethod
def _assert_has_primary(cls):
    """
    Ensure we have a primary key, this is the only field that's fully required.

    :raises ValueError: if no annotated attribute is configured as a `PrimaryKeyField`.
    """
    # `_get_primary` already scans the annotations for a `PrimaryKeyField`,
    # so its result alone decides whether the class is valid.
    if cls._get_primary() is None:
        raise ValueError(f"Class {cls.__name__} does not have a primary key, specify `PrimaryKeyField` on the class definition.")
22 changes: 22 additions & 0 deletions vectordb_orm/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import pytest
from pymilvus import Milvus, connections
from vectordb_orm import MilvusSession
from vectordb_orm.tests.models import MyObject

@pytest.fixture()
def milvus_client():
    """Provide a `Milvus` client built with default arguments for each test."""
    client = Milvus()
    return client

@pytest.fixture()
def session(milvus_client):
    """
    Provide a `MilvusSession` wrapping the test client.

    Also opens the "default" pymilvus connection against localhost:19530
    before handing the session to the test.
    """
    milvus_session = MilvusSession(milvus_client)
    connections.connect("default", host="localhost", port="19530")
    return milvus_session

@pytest.fixture()
def collection(session: MilvusSession, milvus_client: Milvus):
    """
    Give each test a freshly created `MyObject` collection.

    NOTE(review): `session` is requested but not used in the body; presumably
    it is listed so the connection opened by that fixture exists first - confirm.
    """
    target_name = MyObject.collection_name()

    # Drop whatever an earlier test left behind
    milvus_client.drop_collection(target_name)

    # Recreate the collection and hand it to the test
    fresh_collection = MyObject._create_collection(milvus_client)
    return fresh_collection
49 changes: 49 additions & 0 deletions vectordb_orm/tests/test_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import pytest
from pymilvus import Milvus
from vectordb_orm import MilvusSession
from vectordb_orm.tests.models import MyObject
import numpy as np
from time import sleep

def test_create_object(collection):
    """Constructing an object keeps its field values and leaves the id unset."""
    expected_embedding = np.array([1.0] * 128)
    created = MyObject(text='example', embedding=np.array([1.0] * 128))

    assert created.text == 'example'
    assert np.array_equal(created.embedding, expected_embedding)
    # Nothing has been inserted yet, so no id has been assigned
    assert created.id is None


def test_insert_object(collection, milvus_client: Milvus, session: MilvusSession):
    """Inserting an object assigns an id and makes it retrievable by that id."""
    inserted = MyObject(text='example', embedding=np.array([1.0] * 128))
    inserted.insert(milvus_client)
    assert inserted.id is not None

    # Flush and load the collection before querying it
    collection.flush()
    collection.load()

    # Retrieve the object and ensure the values are equivalent
    id_query = session.query(MyObject).filter(MyObject.id == inserted.id)
    matches = id_query.all()
    assert len(matches) == 1

    fetched: MyObject = matches[0].result
    assert fetched.text == inserted.text


def test_delete_object(collection, milvus_client: Milvus, session: MilvusSession):
    """Deleting an inserted object removes it from subsequent query results."""

    def find_examples():
        # Same lookup is run before and after the delete
        return session.query(MyObject).filter(MyObject.text == "example").all()

    doomed = MyObject(text='example', embedding=np.array([1.0] * 128))
    doomed.insert(milvus_client)

    collection.flush()
    collection.load()

    # Sanity check: the object is visible before it is deleted
    assert len(find_examples()) == 1

    doomed.delete(milvus_client)

    # Allow enough time to become consistent
    collection.flush()
    collection.load()
    sleep(1)

    assert len(find_examples()) == 0
22 changes: 3 additions & 19 deletions vectordb_orm/tests/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,10 @@
from vectordb_orm.tests.models import MyObject
import numpy as np

milvus_client = Milvus()
session = MilvusSession(milvus_client)
connections.connect("default", host="localhost", port="19530")

@pytest.fixture()
def drop_collection():
# Wipe the collection
milvus_client.drop_collection(MyObject.collection_name())


def test_query(drop_collection):
def test_query(collection, milvus_client: Milvus, session: MilvusSession):
"""
General test of querying and query chaining
"""
collection = MyObject._create_collection(milvus_client)

# Create some MyObject instances
obj1 = MyObject(text="foo", embedding=np.array([1.0] * 128))
obj2 = MyObject(text="bar", embedding=np.array([4.0] * 128))
Expand All @@ -44,13 +32,11 @@ def test_query(drop_collection):
assert results[0].result.id == obj3.id


def test_query_default_ignores_embeddings(drop_collection):
def test_query_default_ignores_embeddings(collection, milvus_client: Milvus, session: MilvusSession):
"""
Ensure that querying on the class by default ignores embeddings that are included
within the type definition.
"""
collection = MyObject._create_collection(milvus_client)

obj1 = MyObject(text="foo", embedding=np.array([1.0] * 128))
obj1.insert(milvus_client)

Expand All @@ -65,12 +51,10 @@ def test_query_default_ignores_embeddings(drop_collection):
assert result.embedding is None


def test_query_with_fields(drop_collection):
def test_query_with_fields(collection, milvus_client: Milvus, session: MilvusSession):
"""
Test querying with specific fields
"""
collection = MyObject._create_collection(milvus_client)

obj1 = MyObject(text="foo", embedding=np.array([1.0] * 128))
obj1.insert(milvus_client)

Expand Down

0 comments on commit eb3b753

Please sign in to comment.