diff --git a/poetry.lock b/poetry.lock index 3c92a33..08d8938 100644 --- a/poetry.lock +++ b/poetry.lock @@ -77,6 +77,20 @@ category = "dev" optional = false python-versions = ">=3.7" +[[package]] +name = "isort" +version = "5.12.0" +description = "A Python utility / library to sort Python imports." +category = "dev" +optional = false +python-versions = ">=3.8.0" + +[package.extras] +colors = ["colorama (>=0.4.3)"] +pipfile-deprecated-finder = ["pip-shims (>=0.5.2)", "pipreqs", "requirementslib"] +plugins = ["setuptools"] +requirements-deprecated-finder = ["pip-api", "pipreqs"] + [[package]] name = "loguru" version = "0.7.0" @@ -371,7 +385,7 @@ dev = ["black (>=19.3b0)", "pytest (>=4.6.2)"] [metadata] lock-version = "1.1" python-versions = "^3.10" -content-hash = "5f1918e2ecdaf84015e723b7914acc95b884c2623c72c3b58ba5c375f9f5ca03" +content-hash = "fb360da7cdf923249ea6c8602f9c2e8268a2362d2a97aa8964985d167a16aa2b" [metadata.files] certifi = [ @@ -522,6 +536,10 @@ iniconfig = [ {file = "iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374"}, {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"}, ] +isort = [ + {file = "isort-5.12.0-py3-none-any.whl", hash = "sha256:f84c2818376e66cf843d497486ea8fed8700b340f308f076c6fb1229dff318b6"}, + {file = "isort-5.12.0.tar.gz", hash = "sha256:8bef7dde241278824a6d83f44a544709b065191b95b6e50894bdc722fcba0504"}, +] loguru = [ {file = "loguru-0.7.0-py3-none-any.whl", hash = "sha256:b93aa30099fa6860d4727f1b81f8718e965bb96253fa190fab2077aaad6d15d3"}, {file = "loguru-0.7.0.tar.gz", hash = "sha256:1612053ced6ae84d7959dd7d5e431a0532642237ec21f7fd83ac73fe539e03e1"}, diff --git a/pyproject.toml b/pyproject.toml index 898d43e..eecb386 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,6 +17,7 @@ pinecone-client = "^2.2.1" [tool.poetry.group.dev.dependencies] pytest = "^7.3.1" python-dotenv = "^1.0.0" +isort = "^5.12.0" [build-system] requires = ["poetry-core"] diff --git a/vectordb_orm/__init__.py b/vectordb_orm/__init__.py index 782b32a..d340600 100644 --- a/vectordb_orm/__init__.py +++ b/vectordb_orm/__init__.py @@ -1,7 +1,14 @@ +from vectordb_orm.backends.milvus.indexes import (Milvus_BIN_FLAT, + Milvus_BIN_IVF_FLAT, + Milvus_FLAT, Milvus_HNSW, + Milvus_IVF_FLAT, + Milvus_IVF_PQ, + Milvus_IVF_SQ8) +from vectordb_orm.backends.milvus.milvus import MilvusBackend +from vectordb_orm.backends.pinecone.indexes import (PineconeIndex, + PineconeSimilarityMetric) +from vectordb_orm.backends.pinecone.pinecone import PineconeBackend from vectordb_orm.base import VectorSchemaBase -from vectordb_orm.fields import EmbeddingField, VarCharField, PrimaryKeyField +from vectordb_orm.enums import ConsistencyType +from vectordb_orm.fields import EmbeddingField, PrimaryKeyField, VarCharField from vectordb_orm.session import VectorSession -from vectordb_orm.indexes import FLAT, IVF_FLAT, IVF_SQ8, IVF_PQ, HNSW, BIN_FLAT, BIN_IVF_FLAT -from vectordb_orm.similarity import ConsistencyType -from vectordb_orm.backends.milvus import MilvusBackend -from vectordb_orm.backends.pinecone import PineconeBackend diff --git a/vectordb_orm/attributes.py b/vectordb_orm/attributes.py index 99880e6..2d4c3e3 100644 --- a/vectordb_orm/attributes.py +++ b/vectordb_orm/attributes.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import Any, TYPE_CHECKING +from typing import TYPE_CHECKING, Any if TYPE_CHECKING: from vectordb_orm.base import MilvusBase diff --git a/vectordb_orm/backends/base.py b/vectordb_orm/backends/base.py index f323a2a..757ffb8 100644 --- a/vectordb_orm/backends/base.py +++ b/vectordb_orm/backends/base.py @@ -1,11 +1,14 @@ from abc import ABC, abstractmethod from typing import Type -from vectordb_orm.base import VectorSchemaBase -from vectordb_orm.fields import PrimaryKeyField + import numpy as np + from vectordb_orm.attributes import AttributeCompare +from vectordb_orm.base import VectorSchemaBase +from vectordb_orm.fields import PrimaryKeyField from vectordb_orm.results import QueryResult + class BackendBase(ABC): max_fetch_size: int diff --git a/vectordb_orm/backends/milvus/__init__.py b/vectordb_orm/backends/milvus/__init__.py new file mode 100644 index 0000000..4cb1293 --- /dev/null +++ b/vectordb_orm/backends/milvus/__init__.py @@ -0,0 +1,3 @@ +from vectordb_orm.backends.milvus.indexes import * +from vectordb_orm.backends.milvus.milvus import * +from vectordb_orm.backends.milvus.similarity import * diff --git a/vectordb_orm/indexes.py b/vectordb_orm/backends/milvus/indexes.py similarity index 84% rename from vectordb_orm/indexes.py rename to vectordb_orm/backends/milvus/indexes.py index b2f32b0..319eba9 100644 --- a/vectordb_orm/indexes.py +++ b/vectordb_orm/backends/milvus/indexes.py @@ -1,7 +1,11 @@ from abc import ABC, abstractmethod -from vectordb_orm.similarity import FloatSimilarityMetric, BinarySimilarityMetric -class IndexBase(ABC): +from vectordb_orm.backends.milvus.similarity import ( + MilvusBinarySimilarityMetric, MilvusFloatSimilarityMetric) +from vectordb_orm.index import IndexBase + + +class MilvusIndexBase(IndexBase): """ Specify indexes used for embedding creation: https://milvus.io/docs/index.md Individual docstrings for the index types are taken from this page of ideal scenarios. @@ -11,14 +15,14 @@ class IndexBase(ABC): def __init__( self, - metric_type: FloatSimilarityMetric | BinarySimilarityMetric | None = None, + metric_type: MilvusFloatSimilarityMetric | MilvusBinarySimilarityMetric | None = None, ): # Choose a reasonable default if metric_type is null, depending on the type of index if metric_type is None: if isinstance(self, tuple(FLOATING_INDEXES)): - metric_type = FloatSimilarityMetric.L2 + metric_type = MilvusFloatSimilarityMetric.L2 elif isinstance(self, tuple(BINARY_INDEXES)): - metric_type = BinarySimilarityMetric.JACCARD + metric_type = MilvusBinarySimilarityMetric.JACCARD self._assert_metric_type(metric_type) self.metric_type = metric_type @@ -35,17 +39,17 @@ def get_inference_parameters(self): """ pass - def _assert_metric_type(self, metric_type: FloatSimilarityMetric | BinarySimilarityMetric): + def _assert_metric_type(self, metric_type: MilvusFloatSimilarityMetric | MilvusBinarySimilarityMetric): """ Binary indexes only support binary metrics, and floating indexes only support floating metrics. Assert that the combination of metric type and index is valid. """ # Only support valid combinations of metric type and index - if isinstance(metric_type, FloatSimilarityMetric): + if isinstance(metric_type, MilvusFloatSimilarityMetric): if not isinstance(self, tuple(FLOATING_INDEXES)): raise ValueError(f"Index type {self} is not supported for metric type {metric_type}") - elif isinstance(metric_type, BinarySimilarityMetric): + elif isinstance(metric_type, MilvusBinarySimilarityMetric): if not isinstance (self, tuple(BINARY_INDEXES)): raise ValueError(f"Index type {self} is not supported for metric type {metric_type}") @@ -55,7 +59,7 @@ def _assert_cluster_units_and_inference_comparison(self, cluster_units: int, inf if inference_comparison is not None and not (inference_comparison >= 1 and inference_comparison <= cluster_units): raise ValueError("inference_comparison must be between 1 and cluster_units") -class FLAT(IndexBase): +class Milvus_FLAT(MilvusIndexBase): """ - Relatively small dataset - Requires a 100% recall rate @@ -69,7 +73,7 @@ def get_inference_parameters(self): return {"metric_type": self.metric_type.name} -class IVF_FLAT(IndexBase): +class Milvus_IVF_FLAT(MilvusIndexBase): """ - High-speed query - Requires a recall rate as high as possible @@ -80,7 +84,7 @@ def __init__( self, cluster_units: int, inference_comparison: int | None = None, - metric_type: FloatSimilarityMetric | BinarySimilarityMetric | None = None, + metric_type: MilvusFloatSimilarityMetric | MilvusBinarySimilarityMetric | None = None, ): """ :param cluster_units: Number of clusters (nlist in the docs) @@ -101,7 +105,7 @@ def get_inference_parameters(self): return {"nprobe": self.nprobe} -class IVF_SQ8(IndexBase): +class Milvus_IVF_SQ8(MilvusIndexBase): """ - High-speed query - Limited memory resources @@ -113,7 +117,7 @@ def __init__( self, cluster_units: int, inference_comparison: int | None = None, - metric_type: FloatSimilarityMetric | BinarySimilarityMetric | None = None, + metric_type: MilvusFloatSimilarityMetric | MilvusBinarySimilarityMetric | None = None, ): """ :param cluster_units: Number of clusters (nlist in the docs) @@ -134,7 +138,7 @@ def get_inference_parameters(self): return {"nprobe": self.nprobe} -class IVF_PQ(IndexBase): +class Milvus_IVF_PQ(MilvusIndexBase): """ - Very high-speed query - Limited memory resources @@ -148,7 +152,7 @@ def __init__( product_quantization: int | None = None, inference_comparison: int | None = None, low_dimension_bits: int | None = None, - metric_type: FloatSimilarityMetric | BinarySimilarityMetric | None = None, + metric_type: MilvusFloatSimilarityMetric | MilvusBinarySimilarityMetric | None = None, ): """ :param cluster_units: Number of clusters (nlist in the docs) @@ -175,7 +179,7 @@ def _assert_low_dimension_bits(self, low_dimension_bits: int | None): if low_dimension_bits is not None and not (low_dimension_bits >= 1 and low_dimension_bits <= 16): raise ValueError("low_dimension_bits must be between 1 and 16") -class HNSW(IndexBase): +class Milvus_HNSW(MilvusIndexBase): """ - High-speed query - Requires a recall rate as high as possible @@ -188,7 +192,7 @@ def __init__( max_degree: int, search_scope_index: int, search_scope_inference: int, - metric_type: FloatSimilarityMetric | BinarySimilarityMetric | None = None, + metric_type: MilvusFloatSimilarityMetric | MilvusBinarySimilarityMetric | None = None, ): """ :param max_degree: Maximum degree of the node @@ -225,7 +229,7 @@ def _assert_search_scope_inference(self, search_scope_inference: int): raise ValueError("search_scope must be between 1 and 32768") -class BIN_FLAT(IndexBase): +class Milvus_BIN_FLAT(MilvusIndexBase): """ - Relatively small dataset - Requires a 100% recall rate @@ -239,7 +243,7 @@ def get_inference_parameters(self): return {"metric_type": self.metric_type.name} -class BIN_IVF_FLAT(IndexBase): +class Milvus_BIN_IVF_FLAT(MilvusIndexBase): """ - High-speed query - Requires a recall rate as high as possible @@ -250,7 +254,7 @@ def __init__( self, cluster_units: int, inference_comparison: int | None = None, - metric_type: FloatSimilarityMetric | BinarySimilarityMetric | None = None, + metric_type: MilvusFloatSimilarityMetric | MilvusBinarySimilarityMetric | None = None, ): """ :param cluster_units: Number of clusters (nlist in the docs) @@ -271,5 +275,5 @@ def get_inference_parameters(self): return {"nprobe": self.nprobe, "metric_type": self.metric_type.name} -FLOATING_INDEXES : set[IndexBase] = {FLAT, IVF_FLAT, IVF_SQ8, IVF_PQ, HNSW} -BINARY_INDEXES : set[IndexBase] = {BIN_FLAT, BIN_IVF_FLAT} +FLOATING_INDEXES : set[MilvusIndexBase] = {Milvus_FLAT, Milvus_IVF_FLAT, Milvus_IVF_SQ8, Milvus_IVF_PQ, Milvus_HNSW} +BINARY_INDEXES : set[MilvusIndexBase] = {Milvus_BIN_FLAT, Milvus_BIN_IVF_FLAT} diff --git a/vectordb_orm/backends/milvus.py b/vectordb_orm/backends/milvus/milvus.py similarity index 92% rename from vectordb_orm/backends/milvus.py rename to vectordb_orm/backends/milvus/milvus.py index c9398e6..b34db09 100644 --- a/vectordb_orm/backends/milvus.py +++ b/vectordb_orm/backends/milvus/milvus.py @@ -1,23 +1,22 @@ -from pymilvus import Milvus, Collection -from pymilvus.client.types import DataType -from pymilvus.orm.schema import CollectionSchema, FieldSchema -from vectordb_orm.backends.base import BackendBase -import numpy as np -from typing import get_args, get_origin from logging import info -from vectordb_orm.base import VectorSchemaBase -from typing import Type -from pymilvus import Milvus, Collection +from typing import Any, Type, get_args, get_origin + +import numpy as np +from pymilvus import Collection, Milvus +from pymilvus.client.abstract import ChunkedQueryResult from pymilvus.client.types import DataType from pymilvus.orm.schema import CollectionSchema, FieldSchema -from vectordb_orm.fields import EmbeddingField, VarCharField, BaseField, PrimaryKeyField -from vectordb_orm.indexes import FLOATING_INDEXES, BINARY_INDEXES -from typing import Any -from vectordb_orm.attributes import AttributeCompare -from pymilvus.client.abstract import ChunkedQueryResult +from vectordb_orm.attributes import AttributeCompare +from vectordb_orm.backends.base import BackendBase +from vectordb_orm.backends.milvus.indexes import (BINARY_INDEXES, + FLOATING_INDEXES) +from vectordb_orm.base import VectorSchemaBase +from vectordb_orm.fields import (BaseField, EmbeddingField, PrimaryKeyField, + VarCharField) from vectordb_orm.results import QueryResult + class MilvusBackend(BackendBase): # https://milvus.io/docs/search.md max_fetch_size = 16384 @@ -101,6 +100,10 @@ def search( if schema.consistency_type() is not None: optional_args["consistency_level"] = schema.consistency_type().value + # Milvus supports to different quering behaviors depending on whether or not we are looking + # for vector similarities + # A `search` is used when we are looking for vector similarities, and a `query` is used more akin + # to a traditional relational database when we're just looking to filter on metadata if search_embedding_key is not None: embedding_configuration : EmbeddingField = schema._type_configuration.get(search_embedding_key) @@ -245,6 +248,11 @@ def _assert_embedding_validity(self, schema: Type[VectorSchemaBase]): elif vector_type == DataType.FLOAT_VECTOR and not isinstance(field_index, tuple(FLOATING_INDEXES)): raise ValueError(f"Index type {field_index} is not compatible with float vectors.") + # Milvus max size + # https://milvus.io/docs/limitations.md + if field_configuration.dim > 32768: + raise ValueError(f"Milvus only supports vectors with dimensions under 32768. {attribute_name} is too large: {field_configuration.dim}.") + def _result_to_objects( self, schema: Type[VectorSchemaBase], diff --git a/vectordb_orm/similarity.py b/vectordb_orm/backends/milvus/similarity.py similarity index 57% rename from vectordb_orm/similarity.py rename to vectordb_orm/backends/milvus/similarity.py index 2f76157..05ca808 100644 --- a/vectordb_orm/similarity.py +++ b/vectordb_orm/backends/milvus/similarity.py @@ -1,8 +1,9 @@ from enum import Enum + from pymilvus.client.types import MetricType -from pymilvus.orm.types import CONSISTENCY_STRONG, CONSISTENCY_BOUNDED, CONSISTENCY_EVENTUALLY, CONSISTENCY_SESSION -class FloatSimilarityMetric(Enum): + +class MilvusFloatSimilarityMetric(Enum): """ Specify the metric used for floating-point search. At inference time a query vector is broadcast to the vectors in the database using this approach. The string values of this enums are directly used by Milvus, see here for more info: https://milvus.io/docs/metric.md @@ -13,7 +14,7 @@ class FloatSimilarityMetric(Enum): # Inner Product IP = MetricType.IP.name -class BinarySimilarityMetric(Enum): +class MilvusBinarySimilarityMetric(Enum): """ Specify the metric used for binary search. These are distance metrics. @@ -21,14 +22,3 @@ class BinarySimilarityMetric(Enum): JACCARD = MetricType.JACCARD.name TANIMOTO = MetricType.TANIMOTO.name HAMMING = MetricType.HAMMING.name - -class ConsistencyType(Enum): - """ - Define the strength of the consistency within the distributed DB: - https://milvus.io/docs/consistency.md - - """ - STRONG = CONSISTENCY_STRONG - BOUNDED = CONSISTENCY_BOUNDED - SESSION = CONSISTENCY_SESSION - EVENTUALLY = CONSISTENCY_EVENTUALLY diff --git a/vectordb_orm/backends/pinecone/__init__.py b/vectordb_orm/backends/pinecone/__init__.py new file mode 100644 index 0000000..52ad9b2 --- /dev/null +++ b/vectordb_orm/backends/pinecone/__init__.py @@ -0,0 +1,2 @@ +from vectordb_orm.backends.pinecone.indexes import * +from vectordb_orm.backends.pinecone.pinecone import * diff --git a/vectordb_orm/backends/pinecone/indexes.py b/vectordb_orm/backends/pinecone/indexes.py new file mode 100644 index 0000000..7b66e60 --- /dev/null +++ b/vectordb_orm/backends/pinecone/indexes.py @@ -0,0 +1,29 @@ +from enum import Enum + +from vectordb_orm.index import IndexBase + + +class PineconeSimilarityMetric(Enum): + COSINE = "cosine" + EUCLIDEAN = "euclidean" + DOT_PRODUCT = "dotproduct" + + +class PineconeIndex(IndexBase): + """ + Pinecone only supports one type of index + """ + def __init__(self, metric_type: PineconeSimilarityMetric): + self.metric_type = metric_type + self._assert_metric_type(metric_type) + + def get_index_parameters(self): + return {} + + def get_inference_parameters(self): + return {"metric_type": self.metric_type.name} + + def _assert_metric_type(self, metric_type: PineconeSimilarityMetric): + # Only support valid combinations of metric type and index + if isinstance(metric_type, PineconeSimilarityMetric): + raise ValueError(f"Index type {self} is not supported for metric type {metric_type}") diff --git a/vectordb_orm/backends/pinecone.py b/vectordb_orm/backends/pinecone/pinecone.py similarity index 94% rename from vectordb_orm/backends/pinecone.py rename to vectordb_orm/backends/pinecone/pinecone.py index 044ada9..630879a 100644 --- a/vectordb_orm/backends/pinecone.py +++ b/vectordb_orm/backends/pinecone/pinecone.py @@ -1,15 +1,19 @@ -import pinecone -from vectordb_orm.backends.base import BackendBase +from logging import info +from re import match as re_match from typing import Type -from vectordb_orm.base import VectorSchemaBase +from uuid import uuid4 + import numpy as np +import pinecone + from vectordb_orm.attributes import AttributeCompare +from vectordb_orm.backends.base import BackendBase +from vectordb_orm.backends.pinecone.indexes import PineconeIndex +from vectordb_orm.base import VectorSchemaBase from vectordb_orm.fields import EmbeddingField -from re import match as re_match -from logging import info -from uuid import uuid4 from vectordb_orm.results import QueryResult + class PineconeBackend(BackendBase): max_fetch_size = 1000 @@ -34,6 +38,7 @@ def create_collection(self, schema: Type[VectorSchemaBase]): self._assert_valid_collection_name(collection_name) self._assert_has_primary(schema) + self._assert_valid_embedding_field(schema) # Pinecone allows for dynamic keys on each object # However we need to pre-provide the keys we want to search on @@ -215,3 +220,10 @@ def _get_embedding_field(self, schema: Type[VectorSchemaBase]): raise ValueError(f"Pinecone only supports one embedding field per collection. {schema} has {len(embedding_fields)} defined: {list(embedding_fields.keys())}.") return list(embedding_fields.items())[0] + + def _assert_valid_embedding_field(self, schema: Type[VectorSchemaBase]): + _, embedding_field = self._get_embedding_field(schema) + + # Ensure that we are using a supported index + if not isinstance(embedding_field.index, PineconeIndex): + raise ValueError("Pinecone only supports a basic `PineconeIndex`.") diff --git a/vectordb_orm/base.py b/vectordb_orm/base.py index aff1c43..f42ebea 100644 --- a/vectordb_orm/base.py +++ b/vectordb_orm/base.py @@ -1,6 +1,7 @@ -from vectordb_orm.attributes import AttributeCompare -from vectordb_orm.similarity import ConsistencyType from typing import Any + +from vectordb_orm.attributes import AttributeCompare +from vectordb_orm.enums import ConsistencyType from vectordb_orm.fields import BaseField, PrimaryKeyField diff --git a/vectordb_orm/enums.py b/vectordb_orm/enums.py new file mode 100644 index 0000000..557cd00 --- /dev/null +++ b/vectordb_orm/enums.py @@ -0,0 +1,16 @@ +from enum import Enum + +from pymilvus.orm.types import (CONSISTENCY_BOUNDED, CONSISTENCY_EVENTUALLY, + CONSISTENCY_SESSION, CONSISTENCY_STRONG) + + +class ConsistencyType(Enum): + """ + Define the strength of the consistency within the distributed DB: + https://milvus.io/docs/consistency.md + + """ + STRONG = CONSISTENCY_STRONG + BOUNDED = CONSISTENCY_BOUNDED + SESSION = CONSISTENCY_SESSION + EVENTUALLY = CONSISTENCY_EVENTUALLY diff --git a/vectordb_orm/fields.py b/vectordb_orm/fields.py index 455a4e1..d080bad 100644 --- a/vectordb_orm/fields.py +++ b/vectordb_orm/fields.py @@ -1,8 +1,9 @@ from abc import ABC -from vectordb_orm.indexes import IndexBase, FLOATING_INDEXES, BINARY_INDEXES -from vectordb_orm.similarity import FloatSimilarityMetric, BinarySimilarityMetric from typing import Any +from vectordb_orm.index import IndexBase + + class BaseField(ABC): """ BaseField is the superclass to all fields that require additional customization behavior. They @@ -19,8 +20,8 @@ def __init__(self, default: Any): class PrimaryKeyField(BaseField): - def __init__(self, default: Any = None): - super().__init__(default=default) + def __init__(self): + super().__init__(default=None) class EmbeddingField(BaseField): diff --git a/vectordb_orm/index.py b/vectordb_orm/index.py new file mode 100644 index 0000000..5f3fab9 --- /dev/null +++ b/vectordb_orm/index.py @@ -0,0 +1,2 @@ +class IndexBase: + pass diff --git a/vectordb_orm/query.py b/vectordb_orm/query.py index 83ffff3..1d41598 100644 --- a/vectordb_orm/query.py +++ b/vectordb_orm/query.py @@ -1,8 +1,10 @@ -from vectordb_orm.backends.base import BackendBase +from typing import Any + from vectordb_orm.attributes import AttributeCompare -from vectordb_orm.fields import EmbeddingField +from vectordb_orm.backends.base import BackendBase from vectordb_orm.base import VectorSchemaBase -from typing import Any +from vectordb_orm.fields import EmbeddingField + class VectorQueryBuilder: """ diff --git a/vectordb_orm/results.py b/vectordb_orm/results.py index 74b3e05..9cfbee1 100644 --- a/vectordb_orm/results.py +++ b/vectordb_orm/results.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional if TYPE_CHECKING: from vectordb_orm.base import VectorSchemaBase @@ -9,5 +9,5 @@ class QueryResult: result: "VectorSchemaBase" # Score and distance is only returned for queries requesting vector-similarity - score: float | None = None - distance: float | None = None + score: Optional[float] = None + distance: Optional[float] = None diff --git a/vectordb_orm/session.py b/vectordb_orm/session.py index d3dfb54..c4950da 100644 --- a/vectordb_orm/session.py +++ b/vectordb_orm/session.py @@ -1,8 +1,11 @@ +from typing import Type + +from pymilvus import Milvus + +from vectordb_orm.backends.base import BackendBase from vectordb_orm.base import VectorSchemaBase from vectordb_orm.query import VectorQueryBuilder -from vectordb_orm.backends.base import BackendBase -from pymilvus import Milvus -from typing import Type + class VectorSession: """ diff --git a/vectordb_orm/tests/conftest.py b/vectordb_orm/tests/conftest.py index adec9ec..ad15329 100644 --- a/vectordb_orm/tests/conftest.py +++ b/vectordb_orm/tests/conftest.py @@ -1,10 +1,12 @@ -import pytest -from pymilvus import Milvus, connections -from vectordb_orm import VectorSession, MilvusBackend, PineconeBackend -from vectordb_orm.tests.models import MyObject, BinaryEmbeddingObject -from time import sleep from os import getenv +from time import sleep + +import pytest from dotenv import load_dotenv +from pymilvus import Milvus, connections + +from vectordb_orm import MilvusBackend, PineconeBackend, VectorSession +from vectordb_orm.tests.models import BinaryEmbeddingObject, MyObject @pytest.fixture() diff --git a/vectordb_orm/tests/models.py b/vectordb_orm/tests/models.py index f529086..3375070 100644 --- a/vectordb_orm/tests/models.py +++ b/vectordb_orm/tests/models.py @@ -1,14 +1,18 @@ -from vectordb_orm import VectorSchemaBase, EmbeddingField, VarCharField, PrimaryKeyField, ConsistencyType -from vectordb_orm.indexes import IVF_FLAT, BIN_FLAT import numpy as np +from vectordb_orm import (ConsistencyType, EmbeddingField, PrimaryKeyField, + VarCharField, VectorSchemaBase) +from vectordb_orm.backends.milvus.indexes import (Milvus_BIN_FLAT, + Milvus_IVF_FLAT) + + class MyObject(VectorSchemaBase): __collection_name__ = 'my_collection' __consistency_type__ = ConsistencyType.STRONG id: int = PrimaryKeyField() text: str = VarCharField(max_length=128) - embedding: np.ndarray = EmbeddingField(dim=128, index=IVF_FLAT(cluster_units=128)) + embedding: np.ndarray = EmbeddingField(dim=128, index=Milvus_IVF_FLAT(cluster_units=128)) class BinaryEmbeddingObject(VectorSchemaBase): @@ -16,4 +20,4 @@ class BinaryEmbeddingObject(VectorSchemaBase): __consistency_type__ = ConsistencyType.STRONG id: int = PrimaryKeyField() - embedding: np.ndarray[np.bool_] = EmbeddingField(dim=128, index=BIN_FLAT()) + embedding: np.ndarray[np.bool_] = EmbeddingField(dim=128, index=Milvus_BIN_FLAT()) diff --git a/vectordb_orm/tests/test_base.py b/vectordb_orm/tests/test_base.py index 8bafc0c..eb925d7 100644 --- a/vectordb_orm/tests/test_base.py +++ b/vectordb_orm/tests/test_base.py @@ -1,11 +1,12 @@ -import pytest -from vectordb_orm import VectorSession -from vectordb_orm.tests.models import MyObject import numpy as np -from time import sleep -from vectordb_orm import VectorSchemaBase, EmbeddingField, PrimaryKeyField -from vectordb_orm.indexes import IVF_FLAT +import pytest + +from vectordb_orm import (EmbeddingField, PrimaryKeyField, VectorSchemaBase, + VectorSession) +from vectordb_orm.backends.milvus.indexes import Milvus_IVF_FLAT from vectordb_orm.tests.conftest import SESSION_FIXTURE_KEYS +from vectordb_orm.tests.models import MyObject + def test_create_object(): my_object = MyObject(text='example', embedding=np.array([1.0] * 128)) @@ -65,7 +66,7 @@ class TestInvalidObject(VectorSchemaBase): __collection_name__ = 'invalid_collection' id: int = PrimaryKeyField() - embedding: np.ndarray[np.bool_] = EmbeddingField(dim=128, index=IVF_FLAT(cluster_units=128)) + embedding: np.ndarray[np.bool_] = EmbeddingField(dim=128, index=Milvus_IVF_FLAT(cluster_units=128)) with pytest.raises(ValueError, match="not compatible with binary vectors"): milvus_session.create_collection(TestInvalidObject) diff --git a/vectordb_orm/tests/test_indexes.py b/vectordb_orm/tests/test_indexes.py index a4b2a71..d306e6c 100644 --- a/vectordb_orm/tests/test_indexes.py +++ b/vectordb_orm/tests/test_indexes.py @@ -1,10 +1,15 @@ +from itertools import product + +import numpy as np import pytest from pymilvus import Milvus -from vectordb_orm import VectorSession, EmbeddingField, PrimaryKeyField, VectorSchemaBase -import numpy as np -from vectordb_orm.indexes import IndexBase, FLOATING_INDEXES, BINARY_INDEXES -from vectordb_orm.similarity import FloatSimilarityMetric, BinarySimilarityMetric -from itertools import product + +from vectordb_orm import (EmbeddingField, PrimaryKeyField, VectorSchemaBase, + VectorSession) +from vectordb_orm.backends.milvus.indexes import (BINARY_INDEXES, + FLOATING_INDEXES, IndexBase) +from vectordb_orm.backends.milvus.similarity import ( + BinarySimilarityMetric, MilvusFloatSimilarityMetric) # Different index definitions require different kwarg arguments; we centralize # them here for ease of accessing them during different test runs @@ -34,13 +39,13 @@ "index_cls,metric_type", product( FLOATING_INDEXES, - [item for item in FloatSimilarityMetric], + [item for item in MilvusFloatSimilarityMetric], ) ) def test_floating_index( session: VectorSession, index_cls: IndexBase, - metric_type: FloatSimilarityMetric, + metric_type: MilvusFloatSimilarityMetric, ): class IndexSubclassObject(VectorSchemaBase): __collection_name__ = 'index_collection' diff --git a/vectordb_orm/tests/test_query.py b/vectordb_orm/tests/test_query.py index 2b07b2c..1438773 100644 --- a/vectordb_orm/tests/test_query.py +++ b/vectordb_orm/tests/test_query.py @@ -1,14 +1,18 @@ +import numpy as np import pytest -from pymilvus import Milvus, connections + from vectordb_orm import VectorSession -from vectordb_orm.tests.models import MyObject, BinaryEmbeddingObject -import numpy as np -from time import sleep +from vectordb_orm.tests.conftest import SESSION_FIXTURE_KEYS +from vectordb_orm.tests.models import BinaryEmbeddingObject, MyObject + -def test_query(session: VectorSession): +@pytest.mark.parametrize("session", SESSION_FIXTURE_KEYS) +def test_query(session: str, request): """ General test of querying and query chaining """ + session : VectorSession = request.getfixturevalue(session) + # Create some MyObject instances obj1 = MyObject(text="foo", embedding=np.array([1.0] * 128)) obj2 = MyObject(text="bar", embedding=np.array([4.0] * 128)) @@ -56,11 +60,14 @@ def test_binary_collection_query(session: VectorSession): assert len(results) == 2 assert results[0].result.id == obj2.id -def test_query_default_ignores_embeddings(session: VectorSession): +@pytest.mark.parametrize("session", SESSION_FIXTURE_KEYS) +def test_query_default_ignores_embeddings(session: str, request): """ Ensure that querying on the class by default ignores embeddings that are included within the type definition. """ + session : VectorSession = request.getfixturevalue(session) + obj1 = MyObject(text="foo", embedding=np.array([1.0] * 128)) session.insert(obj1) @@ -74,11 +81,13 @@ def test_query_default_ignores_embeddings(session: VectorSession): result : MyObject = results[0].result assert result.embedding is None - -def test_query_with_fields(session: VectorSession): +@pytest.mark.parametrize("session", SESSION_FIXTURE_KEYS) +def test_query_with_fields(session: str, request): """ Test querying with specific fields """ + session : VectorSession = request.getfixturevalue(session) + obj1 = MyObject(text="foo", embedding=np.array([1.0] * 128)) session.insert(obj1)