Skip to content

Commit

Permalink
Fix test inconsistency due to weaker consistency
Browse files Browse the repository at this point in the history
Milvus supports different levels of consistency depending on the demands
of the business case. The default setting is Bounded staleness, which
provides relatively good syncronization between replicas and much faster
inference times. This is undesirable for unit tests, however, because it
can result in stochastic behavior for some tests versus others.

Add a new customization parameter to the model definitions that allow
for specifying a different type of consistency for the schema.
  • Loading branch information
piercefreeman committed Apr 20, 2023
1 parent 0c1d4d6 commit bd6c18e
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 5 deletions.
1 change: 1 addition & 0 deletions vectordb_orm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@
from vectordb_orm.fields import EmbeddingField, VarCharField, PrimaryKeyField
from vectordb_orm.session import MilvusSession
from vectordb_orm.indexes import FLAT, IVF_FLAT, IVF_SQ8, IVF_PQ, HNSW, BIN_FLAT, BIN_IVF_FLAT
from vectordb_orm.similarity import ConsistencyType
from pymilvus import Milvus
7 changes: 7 additions & 0 deletions vectordb_orm/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from vectordb_orm.attributes import AttributeCompare
from vectordb_orm.fields import EmbeddingField, VarCharField, BaseField, PrimaryKeyField
from vectordb_orm.indexes import FLOATING_INDEXES, BINARY_INDEXES
from vectordb_orm.similarity import ConsistencyType
from typing import Any
import numpy as np
from typing import get_args, get_origin
Expand Down Expand Up @@ -66,6 +67,12 @@ def collection_name(self) -> str:
raise ValueError(f"Class {self.__name__} does not have a collection name, specify `__collection_name__` on the class definition.")
return self.__collection_name__

@classmethod
def consistency_type(self) -> ConsistencyType | None:
if not hasattr(self, '__consistency_type__'):
return None
return self.__consistency_type__

@classmethod
def _create_collection(cls, milvus_client: Milvus):
"""
Expand Down
8 changes: 6 additions & 2 deletions vectordb_orm/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,10 @@ def all(self):
# Sum of limit and offset should be less than MAX_MILVUS_INT
limit = self._limit if self._limit is not None else (MAX_MILVUS_INT - offset)

optional_args = dict()
if self.cls.consistency_type() is not None:
optional_args["consistency_level"] = self.cls.consistency_type().value

if self._similarity_attribute is not None:
embedding_field_name = self._similarity_attribute.attr
embedding_configuration : EmbeddingField = self.cls._type_configuration.get(self._similarity_attribute.attr)
Expand All @@ -93,6 +97,7 @@ def all(self):
collection_name=self.cls.collection_name(),
expression=filters,
output_fields=output_fields,
**optional_args,
)
else:
search_result = self.milvus_client.query(
Expand All @@ -101,6 +106,7 @@ def all(self):
limit=limit,
output_fields=output_fields,
collection_name=self.cls.collection_name(),
**optional_args,
)
return self._result_to_objects(search_result)

Expand Down Expand Up @@ -136,8 +142,6 @@ def _result_to_objects(self, search_result: ChunkedQueryResult | list[dict[str,
key: result.entity.get(key)
for key in result.entity.fields
}
print(entity)
print(dir(result.entity))
obj = self.cls.from_dict(entity)
query_results.append(QueryResult(obj, score=result.score, distance=result.distance))
else:
Expand Down
12 changes: 12 additions & 0 deletions vectordb_orm/similarity.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from enum import Enum
from pymilvus.client.types import MetricType
from pymilvus.orm.types import CONSISTENCY_STRONG, CONSISTENCY_BOUNDED, CONSISTENCY_EVENTUALLY, CONSISTENCY_SESSION

class FloatSimilarityMetric(Enum):
"""
Expand All @@ -20,3 +21,14 @@ class BinarySimilarityMetric(Enum):
JACCARD = MetricType.JACCARD.name
TANIMOTO = MetricType.TANIMOTO.name
HAMMING = MetricType.HAMMING.name

class ConsistencyType(Enum):
"""
Define the strength of the consistency within the distributed DB:
https://milvus.io/docs/consistency.md
"""
STRONG = CONSISTENCY_STRONG
BOUNDED = CONSISTENCY_BOUNDED
SESSION = CONSISTENCY_SESSION
EVENTUALLY = CONSISTENCY_EVENTUALLY
4 changes: 3 additions & 1 deletion vectordb_orm/tests/models.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from vectordb_orm import MilvusBase, EmbeddingField, VarCharField, PrimaryKeyField
from vectordb_orm import MilvusBase, EmbeddingField, VarCharField, PrimaryKeyField, ConsistencyType
from vectordb_orm.indexes import IVF_FLAT, BIN_FLAT
import numpy as np

class MyObject(MilvusBase):
__collection_name__ = 'my_collection'
__consistency_type__ = ConsistencyType.STRONG

id: int = PrimaryKeyField()
text: str = VarCharField(max_length=128)
Expand All @@ -12,6 +13,7 @@ class MyObject(MilvusBase):

class BinaryEmbeddingObject(MilvusBase):
__collection_name__ = 'binary_collection'
__consistency_type__ = ConsistencyType.STRONG

id: int = PrimaryKeyField()
embedding: np.ndarray[np.bool_] = EmbeddingField(dim=128, index=BIN_FLAT())
3 changes: 1 addition & 2 deletions vectordb_orm/tests/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from vectordb_orm import MilvusSession
from vectordb_orm.tests.models import MyObject, BinaryEmbeddingObject
import numpy as np
from time import sleep

def test_query(collection, milvus_client: Milvus, session: MilvusSession):
"""
Expand Down Expand Up @@ -46,12 +47,10 @@ def test_binary_collection_query(binary_collection, milvus_client: Milvus, sessi
# Test our ability to recall 1:1 the input content
results = session.query(BinaryEmbeddingObject).order_by_similarity(BinaryEmbeddingObject.embedding, np.array([True]*128)).limit(2).all()
assert len(results) == 2
print(results[0])
assert results[0].result.id == obj1.id

results = session.query(BinaryEmbeddingObject).order_by_similarity(BinaryEmbeddingObject.embedding, np.array([False]*128)).limit(2).all()
assert len(results) == 2
print(results[0])
assert results[0].result.id == obj2.id

def test_query_default_ignores_embeddings(collection, milvus_client: Milvus, session: MilvusSession):
Expand Down

0 comments on commit bd6c18e

Please sign in to comment.