Commit 3d78a6a
feat: Metadata persisted to protobuf binary file
- Also writes metadata to a side file (.new) first and copies it over the existing file only after the write succeeds, to guard against data corruption during interrupted writes
tazarov committed May 1, 2024
1 parent 7caf2b7 commit 3d78a6a
Showing 4 changed files with 170 additions and 157 deletions.
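The corruption handling introduced here never overwrites the live metadata file in place: new bytes go to a side file first, and the live file is replaced only after that write completes. A minimal sketch of the pattern in isolation (illustrative only; the helper name is hypothetical, not Chroma's API):

import os
import shutil

def write_with_backup(path: str, payload: bytes) -> None:
    tmp = path + ".new"
    with open(tmp, "wb") as f:
        f.write(payload)  # a crash here leaves the old file untouched
    shutil.copy(tmp, path)  # replace the live file only after a complete write
    os.unlink(tmp)  # a crash before this leaves the .new file behind for recovery

shutil.copy is itself not atomic, so the loader in this commit compensates: it prefers a leftover .new file and falls back to the original when the .new file fails to parse.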
9 changes: 5 additions & 4 deletions chromadb/proto/coordinator_pb2.py

Some generated files are not rendered by default.

169 changes: 76 additions & 93 deletions chromadb/segment/impl/vector/local_persistent_hnsw.py
@@ -1,8 +1,7 @@
-import orjson as json
import os
import shutil
-from uuid import UUID

+from google.protobuf import message
from overrides import override
import pickle
from typing import Dict, List, Optional, Sequence, Set, cast
@@ -11,7 +10,7 @@

from chromadb.config import System
from chromadb.db.base import ParameterValue, get_sql
-from chromadb.db.impl.sqlite import SqliteDB
+from chromadb.proto.chroma_pb2 import LocalSegmentMetadataTuple, LocalSegmentMetadata
from chromadb.segment.impl.metadata.sqlite import _encode_seq_id, _decode_seq_id
from chromadb.segment.impl.vector.batch import Batch
from chromadb.segment.impl.vector.hnsw_params import PersistentHnswParams
@@ -79,93 +78,69 @@ def load_from_file(filename: str) -> "PersistentData":
            ret = cast(PersistentData, pickle.load(f))
            return ret

-    @staticmethod
-    def load_from_sysdb(db: SqliteDB, segment_id: UUID) -> "PersistentData":
-        t2 = Table("segment_metadata")
-        q2 = (
-            db.querybuilder()
-            .from_(t2)
-            .select(t2.key, t2.int_value, t2.str_value)
-            .where(t2.segment_id == ParameterValue(db.uuid_to_db(segment_id)))
+    def store_to_proto(self, metadata_file: str) -> None:
+        result = LocalSegmentMetadata(
+            tuples=[
+                LocalSegmentMetadataTuple(
+                    embedding_id=_id,
+                    hnsw_label=self.id_to_label[_id],
+                    seq_id=self.id_to_seq_id[_id],
+                )
+                for _id in self.id_to_label
+            ],
+            max_seq_id=self.max_seq_id,
+            total_elements_added=self.total_elements_added,
+            dimensionality=self.dimensionality,
        )
-        sql2, params2 = get_sql(q2)
-        with db.tx() as cur:
-            result = cur.execute(sql2, params2).fetchall()
-        kdict = {r[0]: r[1] if r[1] is not None else r[2] for r in result}
-        _dimensionality = kdict.get("dimensionality")
-        _total_elements_added = kdict.get("total_elements_added")
-        _max_seq_id = kdict.get("max_seq_id")
-        id_label_seq_id_tuple_list = kdict.get("id_label_seq_id_tuple_list")
-        if (
-            _dimensionality is None
-            or _total_elements_added is None
-            or _max_seq_id is None
-        ):
-            raise ValueError("Missing required metadata in segment_metadata")
-        if id_label_seq_id_tuple_list is not None:
-            tuple_list = json.loads(id_label_seq_id_tuple_list)
-            _id_to_label = {r[0]: r[1] for r in tuple_list}
-            _id_to_seq_id = {r[0]: r[2] for r in tuple_list}
-            _label_to_id = {r[1]: r[0] for r in tuple_list}
-        else:
-            raise ValueError("Missing required metadata in segment_metadata")

-        return PersistentData(
-            dimensionality=_dimensionality,
-            total_elements_added=_total_elements_added,
-            max_seq_id=_max_seq_id,
-            id_to_label=_id_to_label,
-            label_to_id=_label_to_id,
-            id_to_seq_id=_id_to_seq_id,
-        )
+        with open(metadata_file + ".new", "wb") as f:
+            f.write(result.SerializeToString())
+        # we copy only when the new file is written successfully
+        shutil.copy(metadata_file + ".new", metadata_file)
+        os.unlink(metadata_file + ".new")

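The persistence above is plain protobuf serialization: build the message, SerializeToString to bytes on write, ParseFromString on read. A minimal round-trip sketch using the message and field names visible in this diff (the field values and types are assumptions for illustration; the authoritative schema is the .proto file, which is not part of this diff):

from chromadb.proto.chroma_pb2 import LocalSegmentMetadata, LocalSegmentMetadataTuple

meta = LocalSegmentMetadata(
    tuples=[
        # field types assumed here; see chroma_pb2 for the real definitions
        LocalSegmentMetadataTuple(embedding_id="emb-1", hnsw_label=0, seq_id=1),
    ],
    max_seq_id=1,
    total_elements_added=1,
    dimensionality=384,
)
blob = meta.SerializeToString()  # compact binary form written to the metadata file

parsed = LocalSegmentMetadata()
parsed.ParseFromString(blob)  # raises message.DecodeError on truncated/corrupt input
assert parsed.total_elements_added == 1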
-    def store_to_db(self, db: SqliteDB, segment_id: UUID) -> None:
-        with db.tx() as cur:
-            q1 = (
-                db.querybuilder()
-                .into(Table("segment_metadata"))
-                .columns("segment_id", "key", "int_value")
-                .insert(
-                    ParameterValue(db.uuid_to_db(segment_id)),
-                    "total_elements_added",
-                    self.total_elements_added,
-                )
-                .insert(
-                    ParameterValue(db.uuid_to_db(segment_id)),
-                    "dimensionality",
-                    self.dimensionality,
-                )
-                .insert(
-                    ParameterValue(db.uuid_to_db(segment_id)),
-                    "max_seq_id",
-                    self.max_seq_id,
-                )
+    @staticmethod
+    def load_from_proto(metadata_file: str) -> "PersistentData":
+        """Load persistent data from a protobuf file"""

+        def _load_from_file(metadata_file_to_load: str) -> LocalSegmentMetadata:
+            _result = LocalSegmentMetadata()
+            with open(metadata_file_to_load, "rb") as f:
+                _result.ParseFromString(f.read())
+            return _result

+        _new_metadata_file = metadata_file + ".new"
+        if os.path.exists(_new_metadata_file):
+            logger.warning(
+                f"Found new metadata file {metadata_file}.new, using it instead of {metadata_file}"
            )
-            sql, params = get_sql(q1)
-            sql = sql.replace("INSERT", "INSERT OR REPLACE")
-            cur.execute(sql, params)
-            result = [
-                (_id, self.id_to_label[_id], self.id_to_seq_id[_id])
-                for _id in self.id_to_label
-            ]
-            dumped_result = json.dumps(result)
-            q2 = (
-                db.querybuilder()
-                .into(Table("segment_metadata"))
-                .columns("segment_id", "key", "str_value")
-                .insert(
-                    ParameterValue(db.uuid_to_db(segment_id)),
-                    "id_label_seq_id_tuple_list",
-                    ParameterValue(dumped_result),
+            try:
+                result = _load_from_file(_new_metadata_file)
+            except message.DecodeError:
+                logger.warning(
+                    f"Failed to load metadata file {_new_metadata_file}, "
+                    f"falling back to original file {metadata_file}"
                )
-            )
-            sql, params = get_sql(q2)
-            sql = sql.replace("INSERT", "INSERT OR REPLACE")
-            cur.execute(sql, params)
+                result = _load_from_file(metadata_file)
+        else:
+            result = _load_from_file(metadata_file)

+        id_to_label = {r.embedding_id: r.hnsw_label for r in result.tuples}
+        id_to_seq_id = {r.embedding_id: r.seq_id for r in result.tuples}
+        label_to_id = {r.hnsw_label: r.embedding_id for r in result.tuples}

+        return PersistentData(
+            dimensionality=result.dimensionality,
+            total_elements_added=result.total_elements_added,
+            max_seq_id=result.max_seq_id,
+            id_to_label=id_to_label,
+            label_to_id=label_to_id,
+            id_to_seq_id=id_to_seq_id,
+        )


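To make the recovery ordering concrete: a leftover .new file at load time means a write was interrupted somewhere between serialization and cleanup, so it is tried first, and the original file is the fallback when it does not parse. A self-contained toy of that decision logic (not Chroma code; the empty-file check merely stands in for protobuf's DecodeError):

import os

def _parse(path: str) -> bytes:
    # stand-in for LocalSegmentMetadata.ParseFromString
    with open(path, "rb") as f:
        data = f.read()
    if len(data) == 0:
        raise ValueError("corrupt file")  # plays the role of message.DecodeError
    return data

def load_with_fallback(path: str) -> bytes:
    new_path = path + ".new"
    if os.path.exists(new_path):
        try:
            return _parse(new_path)  # prefer the newer write if it parses
        except ValueError:
            pass  # interrupted write; fall back to the last good file
    return _parse(path)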
class PersistentLocalHnswSegment(LocalHnswSegment):
-    METADATA_FILE: str = "index_metadata.pickle"
+    LEGACY_METADATA_FILE: str = "index_metadata.pickle"  # TODO remove in 0.5+
+    METADATA_FILE: str = "index_metadata.bin"
    # How many records to add to index at once, we do this because crossing the python/c++ boundary is expensive (for add())
    # When records are not added to the c++ index, they are buffered in memory and served
    # via brute force search.
@@ -197,17 +172,19 @@ def __init__(self, system: System, segment: Segment):
        os.makedirs(self._get_storage_folder(), exist_ok=True)
        # Load persist data if it exists already, otherwise create it
        if self._index_exists():
-            # migration from pickle file to sqlite
-            _migrated = False
-            if os.path.exists(self._get_metadata_file()):
+            # migration from pickle file to protobufs
+            _migrated = False  # TODO remove in 0.5+
+            if os.path.exists(self._get_legacy_metadata_file()):
                tmp_persist_data = PersistentData.load_from_file(
-                    self._get_metadata_file()
+                    self._get_legacy_metadata_file()
                )
-                tmp_persist_data.store_to_db(self._db, self._id)
+                tmp_persist_data.store_to_proto(self._get_metadata_file())
                _migrated = True
-            self._persist_data = PersistentData.load_from_sysdb(self._db, self._id)
-            if _migrated:
-                os.remove(self._get_metadata_file())
+            self._persist_data = PersistentData.load_from_proto(
+                self._get_metadata_file()
+            )
+            if _migrated:  # TODO remove in 0.5+
+                os.remove(self._get_legacy_metadata_file())
            self._dimensionality = self._persist_data.dimensionality
            self._total_elements_added = self._persist_data.total_elements_added
            self._max_seq_id = self._persist_data.max_seq_id
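Note the ordering of the migration above: the legacy pickle is removed only after store_to_proto has written the new binary file and load_from_proto has read it back, so a crash mid-migration leaves the pickle in place and the migration simply reruns on the next startup.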
@@ -237,12 +214,18 @@ def propagate_collection_metadata(metadata: Metadata) -> Optional[Metadata]:

    def _index_exists(self) -> bool:
        """Check if the index exists via the metadata file"""
-        return os.path.exists(self._get_metadata_file())
+        return os.path.exists(self._get_metadata_file()) or os.path.exists(
+            self._get_legacy_metadata_file()
+        )

    def _get_metadata_file(self) -> str:
        """Get the metadata file path"""
        return os.path.join(self._get_storage_folder(), self.METADATA_FILE)

+    def _get_legacy_metadata_file(self) -> str:
+        """Get the legacy metadata file path"""
+        return os.path.join(self._get_storage_folder(), self.LEGACY_METADATA_FILE)
+
    def _get_storage_folder(self) -> str:
        """Get the storage folder path"""
        folder = os.path.join(self._persist_directory, str(self._id))
@@ -317,7 +300,7 @@ def _persist(self) -> None:
            sql, params = get_sql(q)
            sql = sql.replace("INSERT", "INSERT OR REPLACE")
            cur.execute(sql, params)
-        self._persist_data.store_to_db(self._db, self._id)
+        self._persist_data.store_to_proto(self._get_metadata_file())

    @override
    def max_seqid(self) -> SeqId:
