Skip to content

Commit

Permalink
Fix: ChromaVectorStore can attempt to add in excess of chromadb batch… (
Browse files Browse the repository at this point in the history
  • Loading branch information
Brad-Edwards authored Oct 9, 2023
1 parent 6c60b86 commit ec7f434
Showing 1 changed file with 45 additions and 20 deletions.
65 changes: 45 additions & 20 deletions llama_index/vector_stores/chroma.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Chroma vector store."""
import logging
import math
from typing import Any, Dict, List, Optional, cast
from typing import Any, Dict, Generator, List, Optional, cast

from llama_index.bridge.pydantic import Field, PrivateAttr
from llama_index.schema import BaseNode, MetadataMode, TextNode
Expand Down Expand Up @@ -31,6 +31,24 @@ def _to_chroma_filter(standard_filters: MetadataFilters) -> dict:

import_err_msg = "`chromadb` package not found, please run `pip install chromadb`"

MAX_CHUNK_SIZE = 41665 # One less than the max chunk size for ChromaDB


def chunk_list(
lst: List[BaseNode], max_chunk_size: int
) -> Generator[List[BaseNode], None, None]:
"""Yield successive max_chunk_size-sized chunks from lst.
Args:
lst (List[BaseNode]): list of nodes with embeddings
max_chunk_size (int): max chunk size
Yields:
Generator[List[BaseNode], None, None]: list of nodes with embeddings
"""
for i in range(0, len(lst), max_chunk_size):
yield lst[i : i + max_chunk_size]


class ChromaVectorStore(BasePydanticVectorStore):
"""Chroma vector store.
Expand Down Expand Up @@ -129,27 +147,34 @@ def add(self, nodes: List[BaseNode]) -> List[str]:
if not self._collection:
raise ValueError("Collection not initialized")

embeddings = []
metadatas = []
ids = []
documents = []
for node in nodes:
embeddings.append(node.get_embedding())
metadatas.append(
node_to_metadata_dict(
node, remove_text=True, flat_metadata=self.flat_metadata
max_chunk_size = MAX_CHUNK_SIZE
node_chunks = chunk_list(nodes, max_chunk_size)

all_ids = []
for node_chunk in node_chunks:
embeddings = []
metadatas = []
ids = []
documents = []
for node in node_chunk:
embeddings.append(node.get_embedding())
metadatas.append(
node_to_metadata_dict(
node, remove_text=True, flat_metadata=self.flat_metadata
)
)
ids.append(node.node_id)
documents.append(node.get_content(metadata_mode=MetadataMode.NONE))

self._collection.add(
embeddings=embeddings,
ids=ids,
metadatas=metadatas,
documents=documents,
)
ids.append(node.node_id)
documents.append(node.get_content(metadata_mode=MetadataMode.NONE))

self._collection.add(
embeddings=embeddings,
ids=ids,
metadatas=metadatas,
documents=documents,
)
return ids
all_ids.extend(ids)

return all_ids

def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None:
"""
Expand Down

0 comments on commit ec7f434

Please sign in to comment.