Skip to content

Commit

Permalink
Indexes as backend specific arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
piercefreeman committed Apr 21, 2023
1 parent 5483487 commit e728315
Show file tree
Hide file tree
Showing 24 changed files with 234 additions and 111 deletions.
20 changes: 19 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ pinecone-client = "^2.2.1"
[tool.poetry.group.dev.dependencies]
pytest = "^7.3.1"
python-dotenv = "^1.0.0"
isort = "^5.12.0"

[build-system]
requires = ["poetry-core"]
Expand Down
17 changes: 12 additions & 5 deletions vectordb_orm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
from vectordb_orm.backends.milvus.indexes import (Milvus_BIN_FLAT,
Milvus_BIN_IVF_FLAT,
Milvus_FLAT, Milvus_HNSW,
Milvus_IVF_FLAT,
Milvus_IVF_PQ,
Milvus_IVF_SQ8)
from vectordb_orm.backends.milvus.milvus import MilvusBackend
from vectordb_orm.backends.pinecone.indexes import (PineconeIndex,
PineconeSimilarityMetric)
from vectordb_orm.backends.pinecone.pinecone import PineconeBackend
from vectordb_orm.base import VectorSchemaBase
from vectordb_orm.fields import EmbeddingField, VarCharField, PrimaryKeyField
from vectordb_orm.enums import ConsistencyType
from vectordb_orm.fields import EmbeddingField, PrimaryKeyField, VarCharField
from vectordb_orm.session import VectorSession
from vectordb_orm.indexes import FLAT, IVF_FLAT, IVF_SQ8, IVF_PQ, HNSW, BIN_FLAT, BIN_IVF_FLAT
from vectordb_orm.similarity import ConsistencyType
from vectordb_orm.backends.milvus import MilvusBackend
from vectordb_orm.backends.pinecone import PineconeBackend
2 changes: 1 addition & 1 deletion vectordb_orm/attributes.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from enum import Enum
from typing import Any, TYPE_CHECKING
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
from vectordb_orm.base import MilvusBase
Expand Down
7 changes: 5 additions & 2 deletions vectordb_orm/backends/base.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from abc import ABC, abstractmethod
from typing import Type
from vectordb_orm.base import VectorSchemaBase
from vectordb_orm.fields import PrimaryKeyField

import numpy as np

from vectordb_orm.attributes import AttributeCompare
from vectordb_orm.base import VectorSchemaBase
from vectordb_orm.fields import PrimaryKeyField
from vectordb_orm.results import QueryResult


class BackendBase(ABC):
max_fetch_size: int

Expand Down
3 changes: 3 additions & 0 deletions vectordb_orm/backends/milvus/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from vectordb_orm.backends.milvus.indexes import *
from vectordb_orm.backends.milvus.milvus import *
from vectordb_orm.backends.milvus.similarity import *
48 changes: 26 additions & 22 deletions vectordb_orm/indexes.py → vectordb_orm/backends/milvus/indexes.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
from abc import ABC, abstractmethod
from vectordb_orm.similarity import FloatSimilarityMetric, BinarySimilarityMetric

class IndexBase(ABC):
from vectordb_orm.backends.milvus.similarity import (
MilvusBinarySimilarityMetric, MilvusFloatSimilarityMetric)
from vectordb_orm.index import IndexBase


class MilvusIndexBase(IndexBase):
"""
Specify indexes used for embedding creation: https://milvus.io/docs/index.md
Individual docstrings for the index types are taken from this page of ideal scenarios.
Expand All @@ -11,14 +15,14 @@ class IndexBase(ABC):

def __init__(
self,
metric_type: FloatSimilarityMetric | BinarySimilarityMetric | None = None,
metric_type: MilvusFloatSimilarityMetric | MilvusBinarySimilarityMetric | None = None,
):
# Choose a reasonable default if metric_type is null, depending on the type of index
if metric_type is None:
if isinstance(self, tuple(FLOATING_INDEXES)):
metric_type = FloatSimilarityMetric.L2
metric_type = MilvusFloatSimilarityMetric.L2
elif isinstance(self, tuple(BINARY_INDEXES)):
metric_type = BinarySimilarityMetric.JACCARD
metric_type = MilvusBinarySimilarityMetric.JACCARD

self._assert_metric_type(metric_type)
self.metric_type = metric_type
Expand All @@ -35,17 +39,17 @@ def get_inference_parameters(self):
"""
pass

def _assert_metric_type(self, metric_type: FloatSimilarityMetric | BinarySimilarityMetric):
def _assert_metric_type(self, metric_type: MilvusFloatSimilarityMetric | MilvusBinarySimilarityMetric):
"""
Binary indexes only support binary metrics, and floating indexes only support floating metrics. Assert
that the combination of metric type and index is valid.
"""
# Only support valid combinations of metric type and index
if isinstance(metric_type, FloatSimilarityMetric):
if isinstance(metric_type, MilvusFloatSimilarityMetric):
if not isinstance(self, tuple(FLOATING_INDEXES)):
raise ValueError(f"Index type {self} is not supported for metric type {metric_type}")
elif isinstance(metric_type, BinarySimilarityMetric):
elif isinstance(metric_type, MilvusBinarySimilarityMetric):
if not isinstance (self, tuple(BINARY_INDEXES)):
raise ValueError(f"Index type {self} is not supported for metric type {metric_type}")

Expand All @@ -55,7 +59,7 @@ def _assert_cluster_units_and_inference_comparison(self, cluster_units: int, inf
if inference_comparison is not None and not (inference_comparison >= 1 and inference_comparison <= cluster_units):
raise ValueError("inference_comparison must be between 1 and cluster_units")

class FLAT(IndexBase):
class Milvus_FLAT(MilvusIndexBase):
"""
- Relatively small dataset
- Requires a 100% recall rate
Expand All @@ -69,7 +73,7 @@ def get_inference_parameters(self):
return {"metric_type": self.metric_type.name}


class IVF_FLAT(IndexBase):
class Milvus_IVF_FLAT(MilvusIndexBase):
"""
- High-speed query
- Requires a recall rate as high as possible
Expand All @@ -80,7 +84,7 @@ def __init__(
self,
cluster_units: int,
inference_comparison: int | None = None,
metric_type: FloatSimilarityMetric | BinarySimilarityMetric | None = None,
metric_type: MilvusFloatSimilarityMetric | MilvusBinarySimilarityMetric | None = None,
):
"""
:param cluster_units: Number of clusters (nlist in the docs)
Expand All @@ -101,7 +105,7 @@ def get_inference_parameters(self):
return {"nprobe": self.nprobe}


class IVF_SQ8(IndexBase):
class Milvus_IVF_SQ8(MilvusIndexBase):
"""
- High-speed query
- Limited memory resources
Expand All @@ -113,7 +117,7 @@ def __init__(
self,
cluster_units: int,
inference_comparison: int | None = None,
metric_type: FloatSimilarityMetric | BinarySimilarityMetric | None = None,
metric_type: MilvusFloatSimilarityMetric | MilvusBinarySimilarityMetric | None = None,
):
"""
:param cluster_units: Number of clusters (nlist in the docs)
Expand All @@ -134,7 +138,7 @@ def get_inference_parameters(self):
return {"nprobe": self.nprobe}


class IVF_PQ(IndexBase):
class Milvus_IVF_PQ(MilvusIndexBase):
"""
- Very high-speed query
- Limited memory resources
Expand All @@ -148,7 +152,7 @@ def __init__(
product_quantization: int | None = None,
inference_comparison: int | None = None,
low_dimension_bits: int | None = None,
metric_type: FloatSimilarityMetric | BinarySimilarityMetric | None = None,
metric_type: MilvusFloatSimilarityMetric | MilvusBinarySimilarityMetric | None = None,
):
"""
:param cluster_units: Number of clusters (nlist in the docs)
Expand All @@ -175,7 +179,7 @@ def _assert_low_dimension_bits(self, low_dimension_bits: int | None):
if low_dimension_bits is not None and not (low_dimension_bits >= 1 and low_dimension_bits <= 16):
raise ValueError("low_dimension_bits must be between 1 and 16")

class HNSW(IndexBase):
class Milvus_HNSW(MilvusIndexBase):
"""
- High-speed query
- Requires a recall rate as high as possible
Expand All @@ -188,7 +192,7 @@ def __init__(
max_degree: int,
search_scope_index: int,
search_scope_inference: int,
metric_type: FloatSimilarityMetric | BinarySimilarityMetric | None = None,
metric_type: MilvusFloatSimilarityMetric | MilvusBinarySimilarityMetric | None = None,
):
"""
:param max_degree: Maximum degree of the node
Expand Down Expand Up @@ -225,7 +229,7 @@ def _assert_search_scope_inference(self, search_scope_inference: int):
raise ValueError("search_scope must be between 1 and 32768")


class BIN_FLAT(IndexBase):
class Milvus_BIN_FLAT(MilvusIndexBase):
"""
- Relatively small dataset
- Requires a 100% recall rate
Expand All @@ -239,7 +243,7 @@ def get_inference_parameters(self):
return {"metric_type": self.metric_type.name}


class BIN_IVF_FLAT(IndexBase):
class Milvus_BIN_IVF_FLAT(MilvusIndexBase):
"""
- High-speed query
- Requires a recall rate as high as possible
Expand All @@ -250,7 +254,7 @@ def __init__(
self,
cluster_units: int,
inference_comparison: int | None = None,
metric_type: FloatSimilarityMetric | BinarySimilarityMetric | None = None,
metric_type: MilvusFloatSimilarityMetric | MilvusBinarySimilarityMetric | None = None,
):
"""
:param cluster_units: Number of clusters (nlist in the docs)
Expand All @@ -271,5 +275,5 @@ def get_inference_parameters(self):
return {"nprobe": self.nprobe, "metric_type": self.metric_type.name}


FLOATING_INDEXES : set[IndexBase] = {FLAT, IVF_FLAT, IVF_SQ8, IVF_PQ, HNSW}
BINARY_INDEXES : set[IndexBase] = {BIN_FLAT, BIN_IVF_FLAT}
FLOATING_INDEXES : set[MilvusIndexBase] = {Milvus_FLAT, Milvus_IVF_FLAT, Milvus_IVF_SQ8, Milvus_IVF_PQ, Milvus_HNSW}
BINARY_INDEXES : set[MilvusIndexBase] = {Milvus_BIN_FLAT, Milvus_BIN_IVF_FLAT}
Original file line number Diff line number Diff line change
@@ -1,23 +1,22 @@
from pymilvus import Milvus, Collection
from pymilvus.client.types import DataType
from pymilvus.orm.schema import CollectionSchema, FieldSchema
from vectordb_orm.backends.base import BackendBase
import numpy as np
from typing import get_args, get_origin
from logging import info
from vectordb_orm.base import VectorSchemaBase
from typing import Type
from pymilvus import Milvus, Collection
from typing import Any, Type, get_args, get_origin

import numpy as np
from pymilvus import Collection, Milvus
from pymilvus.client.abstract import ChunkedQueryResult
from pymilvus.client.types import DataType
from pymilvus.orm.schema import CollectionSchema, FieldSchema
from vectordb_orm.fields import EmbeddingField, VarCharField, BaseField, PrimaryKeyField
from vectordb_orm.indexes import FLOATING_INDEXES, BINARY_INDEXES
from typing import Any
from vectordb_orm.attributes import AttributeCompare

from pymilvus.client.abstract import ChunkedQueryResult
from vectordb_orm.attributes import AttributeCompare
from vectordb_orm.backends.base import BackendBase
from vectordb_orm.backends.milvus.indexes import (BINARY_INDEXES,
FLOATING_INDEXES)
from vectordb_orm.base import VectorSchemaBase
from vectordb_orm.fields import (BaseField, EmbeddingField, PrimaryKeyField,
VarCharField)
from vectordb_orm.results import QueryResult


class MilvusBackend(BackendBase):
# https://milvus.io/docs/search.md
max_fetch_size = 16384
Expand Down Expand Up @@ -101,6 +100,10 @@ def search(
if schema.consistency_type() is not None:
optional_args["consistency_level"] = schema.consistency_type().value

# Milvus supports to different quering behaviors depending on whether or not we are looking
# for vector similarities
# A `search` is used when we are looking for vector similarities, and a `query` is used more akin
# to a traditional relational database when we're just looking to filter on metadata
if search_embedding_key is not None:
embedding_configuration : EmbeddingField = schema._type_configuration.get(search_embedding_key)

Expand Down Expand Up @@ -245,6 +248,11 @@ def _assert_embedding_validity(self, schema: Type[VectorSchemaBase]):
elif vector_type == DataType.FLOAT_VECTOR and not isinstance(field_index, tuple(FLOATING_INDEXES)):
raise ValueError(f"Index type {field_index} is not compatible with float vectors.")

# Milvus max size
# https://milvus.io/docs/limitations.md
if field_configuration.dim > 32768:
raise ValueError(f"Milvus only supports vectors with dimensions under 32768. {attribute_name} is too large: {field_configuration.dim}.")

def _result_to_objects(
self,
schema: Type[VectorSchemaBase],
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from enum import Enum

from pymilvus.client.types import MetricType
from pymilvus.orm.types import CONSISTENCY_STRONG, CONSISTENCY_BOUNDED, CONSISTENCY_EVENTUALLY, CONSISTENCY_SESSION

class FloatSimilarityMetric(Enum):

class MilvusFloatSimilarityMetric(Enum):
"""
Specify the metric used for floating-point search. At inference time a query vector is broadcast to the vectors in the database
using this approach. The string values of this enums are directly used by Milvus, see here for more info: https://milvus.io/docs/metric.md
Expand All @@ -13,22 +14,11 @@ class FloatSimilarityMetric(Enum):
# Inner Product
IP = MetricType.IP.name

class BinarySimilarityMetric(Enum):
class MilvusBinarySimilarityMetric(Enum):
"""
Specify the metric used for binary search. These are distance metrics.
"""
JACCARD = MetricType.JACCARD.name
TANIMOTO = MetricType.TANIMOTO.name
HAMMING = MetricType.HAMMING.name

class ConsistencyType(Enum):
"""
Define the strength of the consistency within the distributed DB:
https://milvus.io/docs/consistency.md
"""
STRONG = CONSISTENCY_STRONG
BOUNDED = CONSISTENCY_BOUNDED
SESSION = CONSISTENCY_SESSION
EVENTUALLY = CONSISTENCY_EVENTUALLY
2 changes: 2 additions & 0 deletions vectordb_orm/backends/pinecone/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from vectordb_orm.backends.pinecone.indexes import *
from vectordb_orm.backends.pinecone.pinecone import *
29 changes: 29 additions & 0 deletions vectordb_orm/backends/pinecone/indexes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from enum import Enum

from vectordb_orm.index import IndexBase


class PineconeSimilarityMetric(Enum):
COSINE = "cosine"
EUCLIDEAN = "euclidean"
DOT_PRODUCT = "dotproduct"


class PineconeIndex(IndexBase):
"""
Pinecone only supports one type of index
"""
def __init__(self, metric_type: PineconeSimilarityMetric):
self.metric_type = metric_type
self._assert_metric_type(metric_type)

def get_index_parameters(self):
return {}

def get_inference_parameters(self):
return {"metric_type": self.metric_type.name}

def _assert_metric_type(self, metric_type: PineconeSimilarityMetric):
# Only support valid combinations of metric type and index
if isinstance(metric_type, PineconeSimilarityMetric):
raise ValueError(f"Index type {self} is not supported for metric type {metric_type}")
Loading

0 comments on commit e728315

Please sign in to comment.