[ENH]: Support numpy data types for embeddings (#1448)
## Description of changes

*Summarize the changes made by this PR.*
 - Improvements & Bug fixes
    - Embeddings can now be passed as `ndarray` values with `np.integer` or `np.floating` dtypes
    - The conversion of `ndarray` inputs to plain Python lists happens within `Collection` (see the usage sketch after this list)
    - Minor bugfix for embeddings containing `bool` values: the existing `isinstance` check against `int` or `float` succeeds for booleans because `bool` is a subclass of `int`
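
A minimal usage sketch of what this change enables (the client setup, collection name, and data below are illustrative, not part of this PR):

```python
import numpy as np
import chromadb

client = chromadb.Client()
collection = client.create_collection(name="numpy-demo")

# A 2-D float32 ndarray of embeddings; Collection normalizes it to plain lists internally.
embeddings = np.random.uniform(low=-1.0, high=1.0, size=(3, 8)).astype(np.float32)
collection.add(ids=["a", "b", "c"], embeddings=embeddings)

# Query embeddings may also be passed as an ndarray.
results = collection.query(query_embeddings=embeddings[:1], n_results=2)
```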

## Test plan
*How are these changes tested?*

- [x] Tests pass locally with `pytest` for Python

## Documentation Changes
N/A

Supersedes #1014
tazarov authored Dec 8, 2023
1 parent 9264346 commit 8875603
Showing 5 changed files with 125 additions and 12 deletions.
61 changes: 52 additions & 9 deletions chromadb/api/models/Collection.py
@@ -1,4 +1,6 @@
from typing import TYPE_CHECKING, Optional, Tuple, Any
from typing import TYPE_CHECKING, Optional, Tuple, Any, Union

import numpy as np
from pydantic import BaseModel, PrivateAttr

from uuid import UUID
@@ -102,7 +104,12 @@ def count(self) -> int:
def add(
self,
ids: OneOrMany[ID],
embeddings: Optional[OneOrMany[Embedding]] = None,
embeddings: Optional[
Union[
OneOrMany[Embedding],
OneOrMany[np.ndarray],
]
] = None,
metadatas: Optional[OneOrMany[Metadata]] = None,
documents: Optional[OneOrMany[Document]] = None,
images: Optional[OneOrMany[Image]] = None,
@@ -238,7 +245,12 @@ def peek(self, limit: int = 10) -> GetResult:

def query(
self,
query_embeddings: Optional[OneOrMany[Embedding]] = None,
query_embeddings: Optional[
Union[
OneOrMany[Embedding],
OneOrMany[np.ndarray],
]
] = None,
query_texts: Optional[OneOrMany[Document]] = None,
query_images: Optional[OneOrMany[Image]] = None,
query_uris: Optional[OneOrMany[URI]] = None,
@@ -285,7 +297,11 @@ def query(
validate_where_document(where_document) if where_document else {}
)
valid_query_embeddings = (
validate_embeddings(maybe_cast_one_to_many_embedding(query_embeddings))
validate_embeddings(
self._normalize_embeddings(
maybe_cast_one_to_many_embedding(query_embeddings)
)
)
if query_embeddings is not None
else None
)
@@ -326,7 +342,6 @@ def query(

if "data" in include and "uris" not in include:
valid_include.append("uris")

query_results = self._client._query(
collection_id=self.id,
query_embeddings=valid_query_embeddings,
@@ -375,7 +390,12 @@ def modify(
def update(
self,
ids: OneOrMany[ID],
embeddings: Optional[OneOrMany[Embedding]] = None,
embeddings: Optional[
Union[
OneOrMany[Embedding],
OneOrMany[np.ndarray],
]
] = None,
metadatas: Optional[OneOrMany[Metadata]] = None,
documents: Optional[OneOrMany[Document]] = None,
images: Optional[OneOrMany[Image]] = None,
@@ -421,7 +441,12 @@ def update(
def upsert(
self,
ids: OneOrMany[ID],
embeddings: Optional[OneOrMany[Embedding]] = None,
embeddings: Optional[
Union[
OneOrMany[Embedding],
OneOrMany[np.ndarray],
]
] = None,
metadatas: Optional[OneOrMany[Metadata]] = None,
documents: Optional[OneOrMany[Document]] = None,
images: Optional[OneOrMany[Image]] = None,
@@ -495,7 +520,12 @@ def delete(
def _validate_embedding_set(
self,
ids: OneOrMany[ID],
embeddings: Optional[OneOrMany[Embedding]],
embeddings: Optional[
Union[
OneOrMany[Embedding],
OneOrMany[np.ndarray],
]
],
metadatas: Optional[OneOrMany[Metadata]],
documents: Optional[OneOrMany[Document]],
images: Optional[OneOrMany[Image]] = None,
@@ -511,7 +541,9 @@ def _validate_embedding_set(
]:
valid_ids = validate_ids(maybe_cast_one_to_many_ids(ids))
valid_embeddings = (
validate_embeddings(maybe_cast_one_to_many_embedding(embeddings))
validate_embeddings(
self._normalize_embeddings(maybe_cast_one_to_many_embedding(embeddings))
)
if embeddings is not None
else None
)
@@ -578,6 +610,17 @@ def _validate_embedding_set(
valid_uris,
)

@staticmethod
def _normalize_embeddings(
embeddings: Union[
OneOrMany[Embedding],
OneOrMany[np.ndarray],
]
) -> Embeddings:
if isinstance(embeddings, np.ndarray):
return embeddings.tolist()
return embeddings

def _embed(self, input: Any) -> Embeddings:
if self._embedding_function is None:
raise ValueError(
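
For context on the `_normalize_embeddings` helper added above: `ndarray.tolist()` converts numpy scalars (e.g. `np.float32`, `np.int64`) into native Python `float`/`int` values, which is what lets the normalized output pass `validate_embeddings`. A standalone sketch:

```python
import numpy as np

arr = np.random.uniform(low=-1.0, high=1.0, size=(2, 4)).astype(np.float32)

# tolist() yields nested Python lists of native floats rather than np.float32 scalars.
as_lists = arr.tolist()
print(type(as_lists), type(as_lists[0]), type(as_lists[0][0]))
# <class 'list'> <class 'list'> <class 'float'>
```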
7 changes: 6 additions & 1 deletion chromadb/api/types.py
@@ -479,7 +479,12 @@ def validate_embeddings(embeddings: Embeddings) -> Embeddings:
f"Expected each embedding in the embeddings to be a list, got {embeddings}"
)
for embedding in embeddings:
if not all([isinstance(value, (int, float)) for value in embedding]):
if not all(
[
isinstance(value, (int, float)) and not isinstance(value, bool)
for value in embedding
]
):
raise ValueError(
f"Expected each value in the embedding to be a int or float, got {embeddings}"
)
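
A quick interpreter check of why the extra `not isinstance(value, bool)` guard is needed:

```python
# bool is a subclass of int, so the previous check accepted booleans.
assert isinstance(True, int)
assert isinstance(True, (int, float))

# With the added guard, booleans no longer pass.
assert not (isinstance(True, (int, float)) and not isinstance(True, bool))
```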
12 changes: 12 additions & 0 deletions chromadb/test/property/strategies.py
@@ -184,6 +184,18 @@ def create_embeddings(
return embeddings


def create_embeddings_ndarray(
dim: int,
count: int,
dtype: npt.DTypeLike,
) -> np.typing.NDArray[Any]:
return np.random.uniform(
low=-1.0,
high=1.0,
size=(count, dim),
).astype(dtype)


class hashing_embedding_function(types.EmbeddingFunction[Documents]):
def __init__(self, dim: int, dtype: npt.DTypeLike) -> None:
self.dim = dim
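
A sketch of how a Hypothesis test might draw a dtype and exercise the new `create_embeddings_ndarray` helper (the test name and parameters are illustrative; the real tests added below also run the result through `Collection._normalize_embeddings`):

```python
import hypothesis.strategies as st
import numpy as np
from hypothesis import given

import chromadb.test.property.strategies as strategies


@given(dtype=st.sampled_from([np.float32, np.int32, np.int64]))
def test_ndarray_helper_shape(dtype) -> None:
    # size=(count, dim), so dim=4 and count=2 gives a (2, 4) array of the requested dtype.
    embds = strategies.create_embeddings_ndarray(dim=4, count=2, dtype=dtype)
    assert embds.shape == (2, 4)
    assert embds.dtype == dtype
```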
56 changes: 54 additions & 2 deletions chromadb/test/property/test_embeddings.py
@@ -1,9 +1,10 @@
import pytest
import logging
import hypothesis.strategies as st
from typing import Dict, Set, cast, Union, DefaultDict
from hypothesis import given
from typing import Dict, Set, cast, Union, DefaultDict, Any, List
from dataclasses import dataclass
from chromadb.api.types import ID, Include, IDs
from chromadb.api.types import ID, Include, IDs, validate_embeddings
import chromadb.errors as errors
from chromadb.api import ServerAPI
from chromadb.api.models.Collection import Collection
@@ -403,3 +404,54 @@ def test_delete_success(api: ServerAPI, kwargs: dict):
coll = api.create_collection(name="foo")
# Should not raise
coll.delete(**kwargs)


@given(supported_types=st.sampled_from([np.float32, np.int32, np.int64, int, float]))
def test_autocasting_validate_embeddings_for_compatible_types(
supported_types: List[Any],
) -> None:
embds = strategies.create_embeddings(10, 10, supported_types)
validated_embeddings = validate_embeddings(Collection._normalize_embeddings(embds))
assert all(
[
isinstance(value, list)
and all(
[
isinstance(vec, (int, float)) and not isinstance(vec, bool)
for vec in value
]
)
for value in validated_embeddings
]
)


@given(supported_types=st.sampled_from([np.float32, np.int32, np.int64, int, float]))
def test_autocasting_validate_embeddings_with_ndarray(
supported_types: List[Any],
) -> None:
embds = strategies.create_embeddings_ndarray(10, 10, supported_types)
validated_embeddings = validate_embeddings(Collection._normalize_embeddings(embds))
assert all(
[
isinstance(value, list)
and all(
[
isinstance(vec, (int, float)) and not isinstance(vec, bool)
for vec in value
]
)
for value in validated_embeddings
]
)


@given(unsupported_types=st.sampled_from([str, bool]))
def test_autocasting_validate_embeddings_incompatible_types(
unsupported_types: List[Any],
) -> None:
embds = strategies.create_embeddings(10, 10, unsupported_types)
with pytest.raises(ValueError) as e:
validate_embeddings(Collection._normalize_embeddings(embds))

assert "Expected each value in the embedding to be a int or float" in str(e)
1 change: 1 addition & 0 deletions chromadb/types.py
@@ -1,4 +1,5 @@
from typing import Optional, Union, Sequence, Dict, Mapping, List

from typing_extensions import Literal, TypedDict, TypeVar
from uuid import UUID
from enum import Enum
