Skip to content

Commit

Permalink
feat: Simpler and more elegant way to deal with the problem (1-liner)
Browse files Browse the repository at this point in the history
  • Loading branch information
tazarov committed May 3, 2024
1 parent 0ac51ac commit 8f661b6
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 30 deletions.
11 changes: 3 additions & 8 deletions chromadb/segment/impl/vector/local_persistent_hnsw.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,21 +242,16 @@ def _write_records(self, records: Sequence[LogRecord]) -> None:
self._max_seq_id = max(self._max_seq_id, record["log_offset"])
id = record["operation_record"]["id"]
op = record["operation_record"]["operation"]
exists_in_index = self._id_to_label.get(
id, None
) is not None or self._brute_force_index.has_id(id)
exists_in_index = id not in self._curr_batch._deleted_ids and (
self._brute_force_index.has_id(id) or id in self._id_to_label.keys()
)
exists_in_bf_index = self._brute_force_index.has_id(id)

if op == Operation.DELETE:
if exists_in_index:
self._curr_batch.apply(record)
if exists_in_bf_index:
self._brute_force_index.delete([record])
else:
_label = self._id_to_label.pop(id)
self._label_to_id.pop(_label)
self._id_to_seq_id.pop(id)
self._index.mark_deleted(_label)
else:
logger.warning(f"Delete of nonexisting embedding ID: {id}")

Expand Down
23 changes: 1 addition & 22 deletions chromadb/test/property/test_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from hypothesis import given
from typing import Dict, Set, cast, Union, DefaultDict, Any, List
from dataclasses import dataclass
import random

from chromadb.api.fastapi import FastAPI
from chromadb.api.types import ID, Include, IDs, validate_embeddings
Expand Down Expand Up @@ -216,26 +215,6 @@ def upsert_embeddings(self, record_set: strategies.RecordSet) -> None:
self.collection.upsert(**record_set)
self._upsert_embeddings(record_set)

@precondition(
lambda self: "hnsw:batch_size" in self._metadata
and len(self.record_set_state["ids"]) >= self._metadata["hnsw:batch_size"]
)
@rule()
def swap_embeddings(self) -> None:
trace("swap embeddings")
docs = self.collection.get(include=["embeddings", "documents", "metadatas"])
ids_to_swap = random.sample(docs["ids"], min(5, len(docs["ids"])))
indices_to_swap = [docs["ids"].index(id) for id in ids_to_swap]
record_set = {
"ids": [docs["ids"][i] for i in indices_to_swap],
"metadatas": [docs["metadatas"][i] for i in indices_to_swap],
"documents": [docs["documents"][i] for i in indices_to_swap],
"embeddings": [docs["embeddings"][i] for i in indices_to_swap],
}
self.collection.delete(ids=ids_to_swap)
self.collection.add(**record_set)
self._upsert_embeddings(record_set)

@invariant()
def count(self) -> None:
invariants.count(
Expand Down Expand Up @@ -528,7 +507,7 @@ def batching_params(draw: st.DrawFn) -> BatchParams:


@given(batching_params=batching_params())
def test_get_vector(batching_params: BatchParams, api: ServerAPI) -> None:
def test_batching(batching_params: BatchParams, api: ServerAPI) -> None:
error_distribution = {"IndexError": 0, "TypeError": 0, "NoError": 0}
rounds = 100
if isinstance(api, FastAPI) or not api.get_settings().is_persistent:
Expand Down

0 comments on commit 8f661b6

Please sign in to comment.