
Add batch insert #9

Merged 5 commits on Apr 25, 2023
26 changes: 24 additions & 2 deletions README.md
@@ -1,10 +1,12 @@
# vectordb-orm

`vectordb-orm` is an Object-Relational Mapping (ORM) library designed to work with vector databases. This project aims to provide a consistent and convenient interface for working with vector data, allowing you to interact with vector databases using familiar ORM concepts and syntax. Right now [Milvus](https://milvus.io/) and [Pinecone](https://www.pinecone.io/) are supported with more backend engines planned for the future.
`vectordb-orm` is an Object-Relational Mapping (ORM) library for vector databases. Define your data as objects and query for them using familiar SQL syntax, with all the added power of lightning-fast vector search.

Right now [Milvus](https://milvus.io/) and [Pinecone](https://www.pinecone.io/) are supported with more backend engines planned for the future.

## Getting Started

Here are some simple examples demonstrating common behavior with vectordb-orm. First, a note on structure. vectordb-orm is designed around the idea of a `Schema`, which is logically equivalent to a table in classic relational databases. This schema is marked up with python typehints that define the type of vector and metadata that will be stored alongside the objects.
Here are some simple examples demonstrating common behavior with vectordb-orm. First, a note on structure. vectordb-orm is designed around the idea of a `Schema`, which is logically equivalent to a table in classic relational databases. This schema is marked up with typehints that define the type of vector and metadata that will be stored alongside the objects.

You create a class definition by subclassing `VectorSchemaBase` and providing typehints for the keys of your model, similar to pydantic. These fields also support custom initialization behavior if you want (or need) to modify their configuration options.
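The full class definition is collapsed in this diff, but based on the fields used throughout the examples it presumably looks something like the following minimal sketch; the exact `PrimaryKeyField` and `VarCharField` constructor arguments are assumptions:

```python
import numpy as np

from vectordb_orm import (EmbeddingField, PrimaryKeyField, VarCharField,
                          VectorSchemaBase)


class MyObject(VectorSchemaBase):
    # Each typehinted attribute becomes a column in the backing collection
    id: int = PrimaryKeyField()
    text: str = VarCharField(max_length=512)  # max_length is an assumed option
    embedding: np.ndarray = EmbeddingField(dim=128)
```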

@@ -40,6 +42,26 @@ class MyObject(VectorSchemaBase):
```python
embedding: np.ndarray = EmbeddingField(dim=128, index=PineconeIndex(metric_type=PineconeSimilarityMetric.COSINE))
```

### Indexing Data

To insert objects into the database, create a new instance of your object class and insert it into the current session. The arguments to the `__init__` function mirror the typehinted schema that you defined above.

```python
obj = MyObject(text="my_text", embedding=np.array([1.0]*128))
session.insert(obj)
```

Once inserted, this object will be populated with a new `id` by the database engine. At this point it should be queryable (modulo some backends taking time to reach eventual consistency across shards).
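Since `session.insert` returns the object with its primary key populated (see the `session.py` change below), the new `id` can be read back immediately:

```python
obj = session.insert(MyObject(text="my_text", embedding=np.array([1.0] * 128)))
print(obj.id)  # database-assigned primary key
```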

`vectordb-orm` also supports batch insertion. This is recommended when you have a lot of data to insert at once, since per-datapoint insertion latency can be significant.

```python
obj = MyObject(text="my_text", embedding=np.array([1.0]*128))
session.insert_batch([obj], show_progress=True)
```

The optional `show_progress` flag displays a progress bar with the current status of the insertion and the estimated time remaining for the whole dataset.
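Note that the Milvus backend performs the whole batch insert as a single operation and raises a `ValueError` if `show_progress` is requested, while the Pinecone backend upserts in chunks of 100 and reports progress through `tqdm`. As a sketch of a larger bulk load (the 1,000-document corpus here is purely illustrative):

```python
objs = [
    MyObject(text=f"doc-{i}", embedding=np.random.rand(128))
    for i in range(1000)
]
objs = session.insert_batch(objs, show_progress=True)
print(objs[0].id, objs[-1].id)  # every object now carries its assigned id
```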

### Querying Syntax

```python
...
```
2 changes: 1 addition & 1 deletion poetry.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
@@ -12,6 +12,7 @@ pymilvus = "^2.2.6"
protobuf = "^4.22.3"
numpy = "^1.24.2"
pinecone-client = "^2.2.1"
tqdm = "^4.65.0"


[tool.poetry.group.dev.dependencies]
4 changes: 4 additions & 0 deletions vectordb_orm/backends/base.py
@@ -28,6 +28,10 @@ def delete_collection(self, schema: Type[VectorSchemaBase]):
def insert(self, entity: VectorSchemaBase) -> int:
pass

@abstractmethod
def insert_batch(self, entities: list[VectorSchemaBase], show_progress: bool) -> list[int]:
pass

@abstractmethod
def delete(self, entity: VectorSchemaBase):
pass
84 changes: 83 additions & 1 deletion vectordb_orm/backends/milvus/milvus.py
@@ -1,3 +1,4 @@
from collections import defaultdict
from logging import info
from typing import Any, Type, get_args, get_origin

@@ -72,6 +73,81 @@ def insert(self, entity: VectorSchemaBase) -> int:
mutation_result = self.client.insert(collection_name=entity.__class__.collection_name(), entities=entities)
return mutation_result.primary_keys[0]

def insert_batch(
self,
entities: list[VectorSchemaBase],
show_progress: bool,
) -> list[int]:
if show_progress:
raise ValueError("Milvus backend does not support batch insertion progress logging because it is done in one operation.")

# Group by the schema type since we allow for the insertion of multiple different schemas
# `schema_to_entities` - input entities grouped by the schema name
# `schema_to_original_index` - since we switch to a per-schema representation, keep track of a mapping
# from the schema to the original index in `entities`
# `schema_to_ids` - map of schema to the resulting primary keys
schema_to_original_index = defaultdict(list)
schema_to_class = {}
schema_to_ids = {}

for i, entity in enumerate(entities):
schema = entity.__class__
schema_name = schema.collection_name()

schema_to_original_index[schema_name].append(i)
schema_to_class[schema_name] = schema

for schema_name, schema_indexes in schema_to_original_index.items():
schema = schema_to_class[schema_name]
schema_entities = [entities[index] for index in schema_indexes]

# The primary key should still be null at this stage, so we ignore it
# during the insertion
ignore_keys = {
self._get_primary(schema)
}

# Group this schema's objects by their keys
by_key_values = defaultdict(list)
by_key_type = {}

for entity in schema_entities:
for attribute_name, type_hint in entity.__annotations__.items():
value = getattr(entity, attribute_name)
db_type, value = self._type_to_value(type_hint, value)
by_key_values[attribute_name].append(value)
by_key_type[attribute_name] = db_type

# Ensure each key in `by_key_values` matches the quantity of objects
# that should have been created. This *shouldn't* happen but it's possible
# some combination of programmatically deleting attributes or annotations will
# lead to this case. We proactively raise an error because this could result in
# data corruption.
all_lengths = {len(values) for values in by_key_values.values()}
if len(all_lengths) > 1:
raise ValueError(f"Inserted objects don't align for schema `{schema_name}`")

payload = [
{
"name": attribute_name,
"type": by_key_type[attribute_name],
"values": values,
}
for attribute_name, values in by_key_values.items()
if attribute_name not in ignore_keys
]

mutation_result = self.client.insert(collection_name=schema_name, entities=payload)
schema_to_ids[schema_name] = mutation_result.primary_keys

# Reorder ids to match the input entities
ordered_ids = [None] * len(entities)
for schema_name, primary_keys in schema_to_ids.items():
for i, original_index in enumerate(schema_to_original_index[schema_name]):
ordered_ids[original_index] = primary_keys[i]

return ordered_ids

def delete(self, entity: VectorSchemaBase):
schema = entity.__class__
identifier_key = self._get_primary(schema)
@@ -228,7 +304,13 @@ def _dict_representation(self, entity: VectorSchemaBase):
value = getattr(entity, attribute_name)
if value is not None:
db_type, value = self._type_to_value(type_hint, value)
payload.append({"name": attribute_name, "type": db_type, "values": [value]})
payload.append(
{
"name": attribute_name,
"type": db_type,
"values": [value]
}
)
return payload

def _assert_embedding_validity(self, schema: Type[VectorSchemaBase]):
109 changes: 89 additions & 20 deletions vectordb_orm/backends/pinecone/pinecone.py
@@ -1,18 +1,19 @@
from collections import defaultdict
from logging import info
from re import match as re_match
from typing import Type
from uuid import uuid4

import numpy as np
import pinecone
from tqdm import tqdm

from vectordb_orm.attributes import AttributeCompare
from vectordb_orm.attributes import AttributeCompare, OperationType
from vectordb_orm.backends.base import BackendBase
from vectordb_orm.backends.pinecone.indexes import PineconeIndex
from vectordb_orm.base import VectorSchemaBase
from vectordb_orm.fields import EmbeddingField
from vectordb_orm.results import QueryResult
from vectordb_orm.attributes import OperationType


class PineconeBackend(BackendBase):
@@ -99,32 +100,72 @@ def insert(self, entity: VectorSchemaBase) -> list:
schema = entity.__class__
collection_name = self.transform_collection_name(schema.collection_name())

schema = entity.__class__
embedding_field_key, _ = self._get_embedding_field(schema)
primary_key = self._get_primary(schema)

id = uuid4().int & (1<<64)-1

embedding_value : np.ndarray = getattr(entity, embedding_field_key)
metadata_fields = {
key: getattr(entity, key)
for key in schema.__annotations__.keys()
if key not in {embedding_field_key, primary_key}
}
identifier = uuid4().int & (1<<64)-1

index = pinecone.Index(index_name=collection_name)
index.upsert([
(
str(id),
embedding_value.tolist(),
{
**metadata_fields,
primary_key: id,
}
),
self._prepare_upsert_tuple(
entity,
identifier,
embedding_field_key=embedding_field_key,
primary_key=primary_key,
)
])

return id
return identifier

def insert_batch(
self,
entities: list[VectorSchemaBase],
show_progress: bool,
batch_size: int = 100
) -> list[int]:
identifiers = [
uuid4().int & (1<<64)-1
for _ in range(len(entities))
]

# Group the entities by their schema
schema_to_original_index = defaultdict(list)
schema_to_class = {}

for i, entity in enumerate(entities):
schema = entity.__class__
schema_name = schema.collection_name()

schema_to_original_index[schema_name].append(i)
schema_to_class[schema_name] = schema

for schema_name, original_indexes in schema_to_original_index.items():
schema = schema_to_class[schema_name]
collection_name = self.transform_collection_name(schema.collection_name())

index = pinecone.Index(index_name=collection_name)

embedding_field_key, _ = self._get_embedding_field(schema)
primary_key = self._get_primary(schema)

for i in tqdm(range(0, len(original_indexes), batch_size)):
batch_indexes = original_indexes[i:i+batch_size]
batch_entities = [entities[index] for index in batch_indexes]
batch_identifiers = [identifiers[index] for index in batch_indexes]

index.upsert(
[
self._prepare_upsert_tuple(
entity,
identifier,
embedding_field_key=embedding_field_key,
primary_key=primary_key,
)
for entity, identifier in zip(batch_entities, batch_identifiers)
]
)

return identifiers

def delete(self, entity: VectorSchemaBase):
schema = entity.__class__
@@ -254,3 +295,31 @@ def _attribute_to_value_payload(self, schema: Type[VectorSchemaBase], attribute:
return {
operation_type_maps[attribute.op]: attribute.value
}

def _prepare_upsert_tuple(
self,
entity: VectorSchemaBase,
identifier: int,
embedding_field_key: str,
primary_key: str,
):
"""
Formats a tuple for upsert
"""
schema = entity.__class__

embedding_value : np.ndarray = getattr(entity, embedding_field_key)
metadata_fields = {
key: getattr(entity, key)
for key in schema.__annotations__.keys()
if key not in {embedding_field_key, primary_key}
}

return (
str(identifier),
embedding_value.tolist(),
{
**metadata_fields,
primary_key: identifier,
}
)
8 changes: 7 additions & 1 deletion vectordb_orm/session.py
@@ -27,11 +27,17 @@ def clear_collection(self, schema: Type[VectorSchemaBase]):
def delete_collection(self, schema: Type[VectorSchemaBase]):
return self.backend.delete_collection(schema)

def insert(self, obj: VectorSchemaBase) -> int:
def insert(self, obj: VectorSchemaBase) -> VectorSchemaBase:
new_id = self.backend.insert(obj)
obj.id = new_id
return obj

def insert_batch(self, objs: list[VectorSchemaBase], show_progress: bool = False) -> list[VectorSchemaBase]:
new_ids = self.backend.insert_batch(objs, show_progress=show_progress)
for new_id, obj in zip(new_ids, objs):
obj.id = new_id
return objs

def delete(self, obj: VectorSchemaBase) -> None:
if not obj.id:
raise ValueError("Cannot delete object that hasn't been inserted into the database")
3 changes: 2 additions & 1 deletion vectordb_orm/tests/conftest.py
@@ -6,7 +6,8 @@
from pymilvus import Milvus, connections

from vectordb_orm import MilvusBackend, PineconeBackend, VectorSession
from vectordb_orm.tests.models import MilvusBinaryEmbeddingObject, MilvusMyObject, PineconeMyObject
from vectordb_orm.tests.models import (MilvusBinaryEmbeddingObject,
MilvusMyObject, PineconeMyObject)


@pytest.fixture()
3 changes: 2 additions & 1 deletion vectordb_orm/tests/milvus/test_indexes.py
@@ -7,7 +7,8 @@
from vectordb_orm import (EmbeddingField, PrimaryKeyField, VectorSchemaBase,
VectorSession)
from vectordb_orm.backends.milvus.indexes import (BINARY_INDEXES,
FLOATING_INDEXES, MilvusIndexBase)
FLOATING_INDEXES,
MilvusIndexBase)
from vectordb_orm.backends.milvus.similarity import (
MilvusBinarySimilarityMetric, MilvusFloatSimilarityMetric)

6 changes: 4 additions & 2 deletions vectordb_orm/tests/milvus/test_milvus.py
@@ -1,7 +1,9 @@
from vectordb_orm.tests.models import MilvusBinaryEmbeddingObject
import numpy as np
import pytest

from vectordb_orm import VectorSession
import numpy as np
from vectordb_orm.tests.models import MilvusBinaryEmbeddingObject


# @pierce 04-21-2023: Currently flaky
# https://github.com/piercefreeman/vectordb-orm/pull/5
5 changes: 3 additions & 2 deletions vectordb_orm/tests/models.py
@@ -1,8 +1,9 @@
import numpy as np

from vectordb_orm import (ConsistencyType, EmbeddingField, Milvus_BIN_FLAT,
Milvus_IVF_FLAT, PineconeIndex, PrimaryKeyField,
VarCharField, VectorSchemaBase, PineconeSimilarityMetric)
Milvus_IVF_FLAT, PineconeIndex,
PineconeSimilarityMetric, PrimaryKeyField,
VarCharField, VectorSchemaBase)


class MilvusMyObject(VectorSchemaBase):