diff --git a/README.md b/README.md
index 4f58150..ee3befd 100644
--- a/README.md
+++ b/README.md
@@ -1,10 +1,12 @@
 # vectordb-orm
 
-`vectordb-orm` is an Object-Relational Mapping (ORM) library designed to work with vector databases. This project aims to provide a consistent and convenient interface for working with vector data, allowing you to interact with vector databases using familiar ORM concepts and syntax. Right now [Milvus](https://milvus.io/) and [Pinecone](https://www.pinecone.io/) are supported with more backend engines planned for the future.
+`vectordb-orm` is an Object-Relational Mapping (ORM) library for vector databases. Define your data as objects and query for them using familiar SQL syntax, with all the added power of lightning fast vector search.
+
+Right now [Milvus](https://milvus.io/) and [Pinecone](https://www.pinecone.io/) are supported, with more backend engines planned for the future.
 
 ## Getting Started
 
-Here are some simple examples demonstrating common behavior with vectordb-orm. First a note on structure. vectordb-orm is designed around the idea of a `Schema`, which is logically equivalent to a table in classic relational databases. This schema is marked up with python typehints that define the type of vector and metadata that will be stored alongisde the objects.
+Here are some simple examples demonstrating common behavior with vectordb-orm. First, a note on structure: vectordb-orm is designed around the idea of a `Schema`, which is logically equivalent to a table in classic relational databases. This schema is marked up with typehints that define the type of vector and metadata that will be stored alongside the objects.
 
 You create a class definition by subclassing `VectorSchemaBase` and providing typehints for the keys of your model, similar to pydantic. These fields also support custom initialization behavior if you want (or need) to modify their configuration options.
 
@@ -40,6 +42,42 @@ class MyObject(VectorSchemaBase):
     embedding: np.ndarray = EmbeddingField(dim=128, index=PineconeIndex(metric_type=PineconeSimilarityMetric.COSINE))
 ```
 
+### Indexing Data
+
+To insert objects into the database, create a new instance of your object class and insert it into the current session. The arguments to the init function mirror the typehinted schema that you defined above.
+
+```python
+obj = MyObject(text="my_text", embedding=np.array([1.0]*128))
+session.insert(obj)
+```
+
+Once inserted, this object will be populated with a new `id` by the database engine. At this point it should be queryable (modulo some backends taking time for eventual consistency across different shards).
+
+`vectordb-orm` also supports batch insertion. This is recommended when you have a lot of data to insert at one time, since latencies on individual datapoints can be significant.
+
+```python
+obj = MyObject(text="my_text", embedding=np.array([1.0]*128))
+session.insert_batch([obj], show_progress=True)
+```
+
+The optional `show_progress` flag displays a progress bar with the current status of the insertion and the estimated time remaining for the whole dataset.
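+
+Note that the Milvus backend performs batch insertion in a single operation, so it will raise an error if `show_progress` is requested; progress reporting is currently only supported by the Pinecone backend.
+
+As a minimal sketch (the object contents and batch size here are purely illustrative), a larger batch insert against a Pinecone-backed session might look like this:
+
+```python
+# Illustrative only: build a batch of objects with random embeddings
+objs = [
+    MyObject(text=f"text-{i}", embedding=np.random.rand(128))
+    for i in range(1000)
+]
+
+# Each returned object has its `id` populated by the backend
+objs = session.insert_batch(objs, show_progress=True)
+print(objs[0].id)
+```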
+
 ### Querying Syntax
 
 ```python
diff --git a/poetry.lock b/poetry.lock
index 08d8938..f230795 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -385,7 +385,7 @@ dev = ["black (>=19.3b0)", "pytest (>=4.6.2)"]
 [metadata]
 lock-version = "1.1"
 python-versions = "^3.10"
-content-hash = "fb360da7cdf923249ea6c8602f9c2e8268a2362d2a97aa8964985d167a16aa2b"
+content-hash = "277f76ff10a6edb7d38ee3a94aecd1218e13005a09b456a697b2314bef0e6eab"
 
 [metadata.files]
 certifi = [
diff --git a/pyproject.toml b/pyproject.toml
index fa048f9..376f13d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -12,6 +12,7 @@
 pymilvus = "^2.2.6"
 protobuf = "^4.22.3"
 numpy = "^1.24.2"
 pinecone-client = "^2.2.1"
+tqdm = "^4.65.0"
 
 [tool.poetry.group.dev.dependencies]
diff --git a/vectordb_orm/backends/base.py b/vectordb_orm/backends/base.py
index 757ffb8..1b22c49 100644
--- a/vectordb_orm/backends/base.py
+++ b/vectordb_orm/backends/base.py
@@ -28,6 +28,10 @@ def delete_collection(self, schema: Type[VectorSchemaBase]):
     def insert(self, entity: VectorSchemaBase) -> int:
         pass
 
+    @abstractmethod
+    def insert_batch(self, entities: list[VectorSchemaBase], show_progress: bool) -> list[int]:
+        pass
+
     @abstractmethod
     def delete(self, entity: VectorSchemaBase):
         pass
diff --git a/vectordb_orm/backends/milvus/milvus.py b/vectordb_orm/backends/milvus/milvus.py
index 26d81c2..836e633 100644
--- a/vectordb_orm/backends/milvus/milvus.py
+++ b/vectordb_orm/backends/milvus/milvus.py
@@ -1,3 +1,4 @@
+from collections import defaultdict
 from logging import info
 from typing import Any, Type, get_args, get_origin
 
@@ -72,6 +73,81 @@ def insert(self, entity: VectorSchemaBase) -> int:
         mutation_result = self.client.insert(collection_name=entity.__class__.collection_name(), entities=entities)
         return mutation_result.primary_keys[0]
 
+    def insert_batch(
+        self,
+        entities: list[VectorSchemaBase],
+        show_progress: bool,
+    ) -> list[int]:
+        if show_progress:
+            raise ValueError("Milvus backend does not support batch insertion progress logging because it is done in one operation.")
+
+        # Group by the schema type since we allow for the insertion of multiple different schemas
+        # `schema_to_original_index` - since we switch to a per-schema representation, keep track of a mapping
+        #    from the schema name to the original indexes in `entities`
+        # `schema_to_class` - map of schema name to the schema class
+        # `schema_to_ids` - map of schema to the resulting primary keys
+        schema_to_original_index = defaultdict(list)
+        schema_to_class = {}
+        schema_to_ids = {}
+
+        for i, entity in enumerate(entities):
+            schema = entity.__class__
+            schema_name = schema.collection_name()
+
+            schema_to_original_index[schema_name].append(i)
+            schema_to_class[schema_name] = schema
+
+        for schema_name, schema_indexes in schema_to_original_index.items():
+            schema = schema_to_class[schema_name]
+            schema_entities = [entities[index] for index in schema_indexes]
+
+            # The primary key should be null at this stage, so we ignore it
+            # during the insertion
+            ignore_keys = {
+                self._get_primary(schema)
+            }
+
+            # Group this schema's objects by their keys
+            by_key_values = defaultdict(list)
+            by_key_type = {}
+
+            for entity in schema_entities:
+                for attribute_name, type_hint in entity.__annotations__.items():
+                    value = getattr(entity, attribute_name)
+                    db_type, value = self._type_to_value(type_hint, value)
+                    by_key_values[attribute_name].append(value)
+                    by_key_type[attribute_name] = db_type
+
+            # Ensure each key in `by_key_values` matches the quantity of objects
+            # that should have been created. This *shouldn't* happen, but it's possible
+            # that some combination of programmatically deleting attributes or annotations will
+            # lead to this case. We proactively raise an error because this could result in
+            # data corruption.
+            all_lengths = {len(values) for values in by_key_values.values()}
+            if len(all_lengths) > 1:
+                raise ValueError(f"Inserted objects don't align for schema `{schema_name}`")
+
+            payload = [
+                {
+                    "name": attribute_name,
+                    "type": by_key_type[attribute_name],
+                    "values": values,
+                }
+                for attribute_name, values in by_key_values.items()
+                if attribute_name not in ignore_keys
+            ]
+
+            mutation_result = self.client.insert(collection_name=schema_name, entities=payload)
+            schema_to_ids[schema_name] = mutation_result.primary_keys
+
+        # Reorder ids to match the input entities
+        ordered_ids = [None] * len(entities)
+        for schema_name, primary_keys in schema_to_ids.items():
+            for i, original_index in enumerate(schema_to_original_index[schema_name]):
+                ordered_ids[original_index] = primary_keys[i]
+
+        return ordered_ids
+
     def delete(self, entity: VectorSchemaBase):
         schema = entity.__class__
         identifier_key = self._get_primary(schema)
@@ -228,7 +304,13 @@ def _dict_representation(self, entity: VectorSchemaBase):
             value = getattr(entity, attribute_name)
             if value is not None:
                 db_type, value = self._type_to_value(type_hint, value)
-                payload.append({"name": attribute_name, "type": db_type, "values": [value]})
+                payload.append(
+                    {
+                        "name": attribute_name,
+                        "type": db_type,
+                        "values": [value]
+                    }
+                )
         return payload
 
     def _assert_embedding_validity(self, schema: Type[VectorSchemaBase]):
diff --git a/vectordb_orm/backends/pinecone/pinecone.py b/vectordb_orm/backends/pinecone/pinecone.py
index ec8b4f5..ad64b26 100644
--- a/vectordb_orm/backends/pinecone/pinecone.py
+++ b/vectordb_orm/backends/pinecone/pinecone.py
@@ -1,3 +1,4 @@
+from collections import defaultdict
 from logging import info
 from re import match as re_match
 from typing import Type
@@ -5,14 +6,14 @@
 
 import numpy as np
 import pinecone
+from tqdm import tqdm
 
-from vectordb_orm.attributes import AttributeCompare
+from vectordb_orm.attributes import AttributeCompare, OperationType
 from vectordb_orm.backends.base import BackendBase
 from vectordb_orm.backends.pinecone.indexes import PineconeIndex
 from vectordb_orm.base import VectorSchemaBase
 from vectordb_orm.fields import EmbeddingField
 from vectordb_orm.results import QueryResult
-from vectordb_orm.attributes import OperationType
 
 
 class PineconeBackend(BackendBase):
@@ -99,32 +100,72 @@ def insert(self, entity: VectorSchemaBase) -> list:
         schema = entity.__class__
         collection_name = self.transform_collection_name(schema.collection_name())
 
-        schema = entity.__class__
         embedding_field_key, _ = self._get_embedding_field(schema)
         primary_key = self._get_primary(schema)
 
-        id = uuid4().int & (1<<64)-1
-
-        embedding_value : np.ndarray = getattr(entity, embedding_field_key)
-        metadata_fields = {
-            key: getattr(entity, key)
-            for key in schema.__annotations__.keys()
-            if key not in {embedding_field_key, primary_key}
-        }
+        identifier = uuid4().int & (1<<64)-1
 
         index = pinecone.Index(index_name=collection_name)
         index.upsert([
-            (
-                str(id),
-                embedding_value.tolist(),
-                {
-                    **metadata_fields,
-                    primary_key: id,
-                }
-            ),
+            self._prepare_upsert_tuple(
+                entity,
+                identifier,
+                embedding_field_key=embedding_field_key,
+                primary_key=primary_key,
+            )
         ])
 
-        return id
+        return identifier
+
+    def insert_batch(
+        self,
+        entities: list[VectorSchemaBase],
+        show_progress: bool,
+        batch_size: int = 100
+    ) -> list[int]:
+        identifiers = [
+            uuid4().int & (1<<64)-1
+            for _ in range(len(entities))
+        ]
+
+        # Group the input entities by their schema
+        schema_to_original_index = defaultdict(list)
+        schema_to_class = {}
+
+        for i, entity in enumerate(entities):
+            schema = entity.__class__
+            schema_name = schema.collection_name()
+
+            schema_to_original_index[schema_name].append(i)
+            schema_to_class[schema_name] = schema
+
+        for schema_name, original_indexes in schema_to_original_index.items():
+            schema = schema_to_class[schema_name]
+            collection_name = self.transform_collection_name(schema.collection_name())
+
+            index = pinecone.Index(index_name=collection_name)
+
+            embedding_field_key, _ = self._get_embedding_field(schema)
+            primary_key = self._get_primary(schema)
+
+            for i in tqdm(range(0, len(original_indexes), batch_size), disable=not show_progress):
+                batch_indexes = original_indexes[i:i+batch_size]
+                batch_entities = [entities[index] for index in batch_indexes]
+                batch_identifiers = [identifiers[index] for index in batch_indexes]
+
+                index.upsert(
+                    [
+                        self._prepare_upsert_tuple(
+                            entity,
+                            identifier,
+                            embedding_field_key=embedding_field_key,
+                            primary_key=primary_key,
+                        )
+                        for entity, identifier in zip(batch_entities, batch_identifiers)
+                    ]
+                )
+
+        return identifiers
 
     def delete(self, entity: VectorSchemaBase):
         schema = entity.__class__
@@ -254,3 +295,31 @@ def _attribute_to_value_payload(self, schema: Type[VectorSchemaBase], attribute:
         return {
             operation_type_maps[attribute.op]: attribute.value
         }
+
+    def _prepare_upsert_tuple(
+        self,
+        entity: VectorSchemaBase,
+        identifier: int,
+        embedding_field_key: str,
+        primary_key: str,
+    ):
+        """
+        Format the (id, embedding, metadata) tuple expected by Pinecone's upsert.
+        """
+        schema = entity.__class__
+
+        embedding_value: np.ndarray = getattr(entity, embedding_field_key)
+        metadata_fields = {
+            key: getattr(entity, key)
+            for key in schema.__annotations__.keys()
+            if key not in {embedding_field_key, primary_key}
+        }
+
+        return (
+            str(identifier),
+            embedding_value.tolist(),
+            {
+                **metadata_fields,
+                primary_key: identifier,
+            }
+        )
diff --git a/vectordb_orm/session.py b/vectordb_orm/session.py
index c4950da..5ee208b 100644
--- a/vectordb_orm/session.py
+++ b/vectordb_orm/session.py
@@ -27,11 +27,17 @@ def clear_collection(self, schema: Type[VectorSchemaBase]):
     def delete_collection(self, schema: Type[VectorSchemaBase]):
         return self.backend.delete_collection(schema)
 
-    def insert(self, obj: VectorSchemaBase) -> int:
+    def insert(self, obj: VectorSchemaBase) -> VectorSchemaBase:
         new_id = self.backend.insert(obj)
         obj.id = new_id
         return obj
 
+    def insert_batch(self, objs: list[VectorSchemaBase], show_progress: bool = False) -> list[VectorSchemaBase]:
+        new_ids = self.backend.insert_batch(objs, show_progress=show_progress)
+        for new_id, obj in zip(new_ids, objs):
+            obj.id = new_id
+        return objs
+
     def delete(self, obj: VectorSchemaBase) -> None:
         if not obj.id:
             raise ValueError("Cannot delete object that hasn't been inserted into the database")
diff --git a/vectordb_orm/tests/conftest.py b/vectordb_orm/tests/conftest.py
index b62751a..032f24c 100644
--- a/vectordb_orm/tests/conftest.py
+++ b/vectordb_orm/tests/conftest.py
@@ -6,7 +6,8 @@
 from pymilvus import Milvus, connections
 
 from vectordb_orm import MilvusBackend, PineconeBackend, VectorSession
-from vectordb_orm.tests.models import MilvusBinaryEmbeddingObject, MilvusMyObject, PineconeMyObject
+from vectordb_orm.tests.models import (MilvusBinaryEmbeddingObject,
+                                       MilvusMyObject, PineconeMyObject)
 
 
 @pytest.fixture()
diff --git a/vectordb_orm/tests/milvus/test_indexes.py b/vectordb_orm/tests/milvus/test_indexes.py
index 90f4b6f..031ad5c 100644
--- a/vectordb_orm/tests/milvus/test_indexes.py
+++ b/vectordb_orm/tests/milvus/test_indexes.py
@@ -7,7 +7,8 @@
 from vectordb_orm import (EmbeddingField, PrimaryKeyField, VectorSchemaBase,
                           VectorSession)
 from vectordb_orm.backends.milvus.indexes import (BINARY_INDEXES,
-                                                  FLOATING_INDEXES, MilvusIndexBase)
+                                                  FLOATING_INDEXES,
+                                                  MilvusIndexBase)
 from vectordb_orm.backends.milvus.similarity import (
     MilvusBinarySimilarityMetric, MilvusFloatSimilarityMetric)
 
diff --git a/vectordb_orm/tests/milvus/test_milvus.py b/vectordb_orm/tests/milvus/test_milvus.py
index 581047e..a8c7640 100644
--- a/vectordb_orm/tests/milvus/test_milvus.py
+++ b/vectordb_orm/tests/milvus/test_milvus.py
@@ -1,7 +1,9 @@
-from vectordb_orm.tests.models import MilvusBinaryEmbeddingObject
+import numpy as np
 import pytest
+
 from vectordb_orm import VectorSession
-import numpy as np
+from vectordb_orm.tests.models import MilvusBinaryEmbeddingObject
+
 
 # @pierce 04-21- 2023: Currently flaky
 # https://github.com/piercefreeman/vectordb-orm/pull/5
diff --git a/vectordb_orm/tests/models.py b/vectordb_orm/tests/models.py
index 5fe12f6..ad2e55f 100644
--- a/vectordb_orm/tests/models.py
+++ b/vectordb_orm/tests/models.py
@@ -1,8 +1,9 @@
 import numpy as np
 
 from vectordb_orm import (ConsistencyType, EmbeddingField, Milvus_BIN_FLAT,
-                          Milvus_IVF_FLAT, PineconeIndex, PrimaryKeyField,
-                          VarCharField, VectorSchemaBase, PineconeSimilarityMetric)
+                          Milvus_IVF_FLAT, PineconeIndex,
+                          PineconeSimilarityMetric, PrimaryKeyField,
+                          VarCharField, VectorSchemaBase)
 
 
 class MilvusMyObject(VectorSchemaBase):
diff --git a/vectordb_orm/tests/test_base.py b/vectordb_orm/tests/test_base.py
index 2248c26..2364bad 100644
--- a/vectordb_orm/tests/test_base.py
+++ b/vectordb_orm/tests/test_base.py
@@ -1,3 +1,5 @@
+from typing import Type
+
 import numpy as np
 import pytest
 
@@ -5,7 +7,6 @@
                            VectorSession)
 from vectordb_orm.backends.milvus.indexes import Milvus_IVF_FLAT
 from vectordb_orm.tests.conftest import SESSION_MODEL_PAIRS
-from typing import Type
 
 
@@ -34,6 +35,18 @@ def test_insert_object(session: str, model: Type[VectorSchemaBase], request):
     result : model = results[0].result
     assert result.text == my_object.text
 
+@pytest.mark.parametrize("session,model", SESSION_MODEL_PAIRS)
+def test_insert_batch(session: str, model: Type[VectorSchemaBase], request):
+    session : VectorSession = request.getfixturevalue(session)
+
+    obj1 = model(text='example1', embedding=np.array([1.0] * 128))
+    obj2 = model(text='example2', embedding=np.array([2.0] * 128))
+    obj3 = model(text='example3', embedding=np.array([3.0] * 128))
+
+    session.insert_batch([obj1, obj2, obj3])
+
+    for obj in [obj1, obj2, obj3]:
+        assert obj.id is not None
 
 @pytest.mark.parametrize("session,model", SESSION_MODEL_PAIRS)
 def test_delete_object(session: str, model: Type[VectorSchemaBase],request):
diff --git a/vectordb_orm/tests/test_query.py b/vectordb_orm/tests/test_query.py
index 220d496..6261642 100644
--- a/vectordb_orm/tests/test_query.py
+++ b/vectordb_orm/tests/test_query.py
@@ -1,9 +1,10 @@
+from typing import Type
+
 import numpy as np
 import pytest
 
-from vectordb_orm import VectorSession, VectorSchemaBase
+from vectordb_orm import VectorSchemaBase, VectorSession
 from vectordb_orm.tests.conftest import SESSION_MODEL_PAIRS
-from typing import Type
 
 
 @pytest.mark.parametrize("session,model", SESSION_MODEL_PAIRS)