Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: ChromaVectorStore can attempt to add in excess of chromadb batch… #8019

Merged
merged 5 commits into from
Oct 9, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 45 additions & 20 deletions llama_index/vector_stores/chroma.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Chroma vector store."""
import logging
import math
from typing import Any, Dict, List, Optional, cast
from typing import Any, Dict, Generator, List, Optional, cast

from llama_index.bridge.pydantic import Field, PrivateAttr
from llama_index.schema import BaseNode, MetadataMode, TextNode
Expand Down Expand Up @@ -31,6 +31,24 @@ def _to_chroma_filter(standard_filters: MetadataFilters) -> dict:

import_err_msg = "`chromadb` package not found, please run `pip install chromadb`"

# Upper bound on the number of nodes sent to the collection in a single
# `add` call. One less than the max chunk size for ChromaDB.
MAX_CHUNK_SIZE = 41665


def chunk_list(
    lst: List[BaseNode], max_chunk_size: int
) -> Generator[List[BaseNode], None, None]:
    """Yield successive chunks of at most max_chunk_size items from lst.

    Chunks are contiguous slices taken in order; every chunk except possibly
    the last has exactly max_chunk_size items. An empty lst yields nothing.

    Because this is a generator, the size check below runs when iteration
    starts, not when the function is called.

    Args:
        lst (List[BaseNode]): list of nodes with embeddings
        max_chunk_size (int): max number of nodes per chunk; must be positive

    Yields:
        List[BaseNode]: the next slice of lst, at most max_chunk_size long

    Raises:
        ValueError: if max_chunk_size is not a positive integer
    """
    # A zero step makes range() raise and a negative step makes it empty,
    # which would silently drop every node. Reject both up front.
    if max_chunk_size <= 0:
        raise ValueError(
            f"max_chunk_size must be a positive integer, got {max_chunk_size}"
        )

    for start in range(0, len(lst), max_chunk_size):
        yield lst[start : start + max_chunk_size]


class ChromaVectorStore(BasePydanticVectorStore):
"""Chroma vector store.
Expand Down Expand Up @@ -129,27 +147,34 @@ def add(self, nodes: List[BaseNode]) -> List[str]:
if not self._collection:
raise ValueError("Collection not initialized")

embeddings = []
metadatas = []
ids = []
documents = []
for node in nodes:
embeddings.append(node.get_embedding())
metadatas.append(
node_to_metadata_dict(
node, remove_text=True, flat_metadata=self.flat_metadata
max_chunk_size = MAX_CHUNK_SIZE
node_chunks = chunk_list(nodes, max_chunk_size)

all_ids = []
for node_chunk in node_chunks:
embeddings = []
metadatas = []
ids = []
documents = []
for node in node_chunk:
embeddings.append(node.get_embedding())
metadatas.append(
node_to_metadata_dict(
node, remove_text=True, flat_metadata=self.flat_metadata
)
)
ids.append(node.node_id)
documents.append(node.get_content(metadata_mode=MetadataMode.NONE))

self._collection.add(
embeddings=embeddings,
ids=ids,
metadatas=metadatas,
documents=documents,
)
ids.append(node.node_id)
documents.append(node.get_content(metadata_mode=MetadataMode.NONE))

self._collection.add(
embeddings=embeddings,
ids=ids,
metadatas=metadatas,
documents=documents,
)
return ids
all_ids.extend(ids)

return all_ids

def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None:
"""
Expand Down