Skip to content

Commit

Permalink
feat: PR redone
Browse files Browse the repository at this point in the history
- Enabled FK constraints on all connections
- Fixed FK for segments to collections
- Implemented a new mechanism for disabling/enabling FKs for migration files (required to retrospectively apply it to tenants and databases migration as it was breaking FK constraints)
- Removed unnecessary embedding metadata cleanup (taken care of by FKs)
- Added utils lib to generate correct (topologically sorted in reverse) DROP statements for tables according to FK constraints
- Fixed client_test.py failing test - the test server dir was not removed so subsequent tests were failing
- Fixed test_segment_manager.py where collection names were causing FK constraint failures
  • Loading branch information
tazarov committed Feb 12, 2024
1 parent 01369af commit 466dbae
Show file tree
Hide file tree
Showing 10 changed files with 125 additions and 33 deletions.
6 changes: 3 additions & 3 deletions chromadb/api/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,13 +335,13 @@ def delete_collection(
)

if existing:
self._sysdb.delete_collection(
existing[0]["id"], tenant=tenant, database=database
)
for s in self._manager.delete_segments(existing[0]["id"]):
self._sysdb.delete_segment(s)
if existing and existing[0]["id"] in self._collection_cache:
del self._collection_cache[existing[0]["id"]]
self._sysdb.delete_collection(
existing[0]["id"], tenant=tenant, database=database
)
else:
raise ValueError(f"Collection {name} does not exist.")

Expand Down
22 changes: 18 additions & 4 deletions chromadb/db/impl/sqlite.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from pathlib import Path

from chromadb.db.impl.sqlite_pool import Connection, LockPool, PerThreadPool, Pool
from chromadb.db.impl.sqlite_utils import get_drop_order
from chromadb.db.migrations import MigratableDB, Migration
from chromadb.config import System, Settings
import chromadb.db.base as base
Expand Down Expand Up @@ -34,6 +37,7 @@ def __init__(self, conn_pool: Pool, stack: local):
@override
def __enter__(self) -> base.Cursor:
if len(self._tx_stack.stack) == 0:
self._conn.execute("PRAGMA foreign_keys = ON")
self._conn.execute("BEGIN;")
self._tx_stack.stack.append(self)
return self._conn.cursor() # type: ignore
Expand Down Expand Up @@ -138,15 +142,16 @@ def reset_state(self) -> None:
"Resetting the database is not allowed. Set `allow_reset` to true in the config in tests or other non-production environments where reset should be permitted."
)
with self.tx() as cur:
# Drop all tables
cur.execute(
"""
SELECT name FROM sqlite_master
WHERE type='table'
"""
)
for row in cur.fetchall():
cur.execute(f"DROP TABLE IF EXISTS {row[0]}")
drop_statement = ""
for t in get_drop_order(cur):
drop_statement += f"DROP TABLE IF EXISTS {t};\n"
cur.executescript(drop_statement)
self._conn_pool.close()
self.start()
super().reset_state()
Expand Down Expand Up @@ -217,7 +222,16 @@ def db_migrations(self, dir: Traversable) -> Sequence[Migration]:

@override
def apply_migration(self, cur: base.Cursor, migration: Migration) -> None:
cur.executescript(migration["sql"])
if any(item.name == f".{migration['filename']}.disable_fk"
for traversable in self.migration_dirs()
for item in traversable.iterdir() if item.is_file()):
cur.executescript(
"PRAGMA foreign_keys = OFF;\n"
+ migration["sql"]
+ ";\nPRAGMA foreign_keys = ON;"
)
else:
cur.executescript(migration["sql"])
cur.execute(
"""
INSERT INTO migrations (dir, version, filename, sql, hash)
Expand Down
44 changes: 44 additions & 0 deletions chromadb/db/impl/sqlite_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from collections import defaultdict, deque
from graphlib import TopologicalSorter
from typing import List, Dict

from chromadb.db.base import Cursor


def fetch_tables(cursor: Cursor) -> List[str]:
    """Return the names of all tables currently defined in the database."""
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
    return [name for (name,) in cursor.fetchall()]


def fetch_foreign_keys(cursor: Cursor, table_name: str) -> List[str]:
    """Return the names of tables referenced by *table_name*'s foreign keys.

    Each row of ``PRAGMA foreign_key_list`` is
    ``(id, seq, table, from, to, on_update, on_delete, match)``; column 2 is
    the referenced (parent) table.

    PRAGMA statements cannot take bound parameters, so the identifier must be
    interpolated. It is double-quoted (with embedded quotes doubled) so that
    table names containing spaces or special characters do not break the
    statement or allow SQL injection.
    """
    quoted = table_name.replace('"', '""')
    cursor.execute(f'PRAGMA foreign_key_list("{quoted}");')
    return [row[2] for row in cursor.fetchall()]  # Table being referenced


def build_dependency_graph(tables: List[str], cursor: Cursor) -> Dict[str, List[str]]:
    """Map every table to the list of tables its foreign keys reference.

    Tables with no outgoing foreign keys still get an entry (an empty list)
    so that the topological sort sees every node.
    """
    graph: Dict[str, List[str]] = defaultdict(list)
    for table in tables:
        # Accessing graph[table] materializes the entry even when the table
        # references nothing, so isolated tables are kept in the graph.
        graph[table].extend(fetch_foreign_keys(cursor, table))
    return graph


def topological_sort(graph: Dict[str, List[str]]) -> List[str]:
    """Order tables so that each appears before any table it references.

    TopologicalSorter yields dependencies first (parents before children),
    but tables must be dropped child-first, so the order is reversed.
    """
    dependency_order = TopologicalSorter(graph).static_order()
    return list(reversed(list(dependency_order)))


def get_drop_order(cursor: Cursor) -> List[str]:
    """Return table names ordered so they can be dropped one by one without
    violating foreign-key constraints (referencing tables first).
    """
    # Skip embedding_fulltext_search_* tables — presumably shadow tables of
    # the full-text-search virtual table, handled when their owner is
    # dropped; TODO confirm against the FTS schema.
    relevant = [
        name
        for name in fetch_tables(cursor)
        if not name.startswith("embedding_fulltext_search_")
    ]
    return topological_sort(build_dependency_graph(relevant, cursor))
22 changes: 22 additions & 0 deletions chromadb/migrations/metadb/00006-em-fk.sqlite.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
-- Rebuild embedding_metadata so that its id column gains a foreign key to
-- embeddings(id) with ON DELETE CASCADE. SQLite cannot add a foreign key to
-- an existing table, so the table is recreated and the rows copied across.
-- Foreign key enforcement is disabled while the table is swapped out.
-- NOTE(review): per SQLite docs, PRAGMA foreign_keys is a no-op inside an
-- open transaction — confirm this script does not run within one.
PRAGMA foreign_keys = OFF;

CREATE TABLE embedding_metadata_temp (
    id INTEGER REFERENCES embeddings(id) ON DELETE CASCADE NOT NULL,
    key TEXT NOT NULL,
    string_value TEXT,
    int_value INTEGER,
    float_value REAL,
    bool_value INTEGER,
    PRIMARY KEY (id, key)
);

-- Copy every existing row into the new FK-constrained table.
INSERT INTO embedding_metadata_temp
SELECT id, key, string_value, int_value, float_value, bool_value
FROM embedding_metadata;

DROP TABLE embedding_metadata;

-- Swap the rebuilt table into place under the original name.
ALTER TABLE embedding_metadata_temp RENAME TO embedding_metadata;

PRAGMA foreign_keys = ON;
Empty file.
Empty file.
13 changes: 13 additions & 0 deletions chromadb/migrations/sysdb/00006-segments-fk.sqlite.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
-- Rebuild the segments table so that its collection column gains a foreign
-- key to collections(id). SQLite cannot add a foreign key to an existing
-- table, so the table is recreated and the rows copied across.
CREATE TABLE segments_temp (
    id TEXT PRIMARY KEY,
    type TEXT NOT NULL,
    scope TEXT NOT NULL,
    topic TEXT,
    collection TEXT REFERENCES collections(id)
);

-- Copy every existing row into the new FK-constrained table.
INSERT INTO segments_temp SELECT * FROM segments;

DROP TABLE segments;

-- Swap the rebuilt table into place under the original name.
ALTER TABLE segments_temp RENAME TO segments;
14 changes: 0 additions & 14 deletions chromadb/segment/impl/metadata/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,20 +416,6 @@ def _delete_record(self, cur: Cursor, record: EmbeddingRecord) -> None:
result = cur.execute(sql, params).fetchone()
if result is None:
logger.warning(f"Delete of nonexisting embedding ID: {record['id']}")
else:
id = result[0]

# Manually delete metadata; cannot use cascade because
# that triggers on replace
metadata_t = Table("embedding_metadata")
q = (
self._db.querybuilder()
.from_(metadata_t)
.where(metadata_t.id == ParameterValue(id))
.delete()
)
sql, params = get_sql(q)
cur.execute(sql, params)

@trace_method("SqliteMetadataSegment._update_record", OpenTelemetryGranularity.ALL)
def _update_record(self, cur: Cursor, record: EmbeddingRecord) -> None:
Expand Down
35 changes: 23 additions & 12 deletions chromadb/test/property/test_segment_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,7 @@
MultipleResults,
)
from typing import Dict
from chromadb.segment import (
VectorReader
)
from chromadb.segment import VectorReader
from chromadb.segment import SegmentManager

from chromadb.segment.impl.manager.local import LocalSegmentManager
Expand All @@ -30,6 +28,7 @@
# Memory limit use for testing
memory_limit = 100


# Helper class to keep track of the last used id
class LastUse:
def __init__(self, n: int):
Expand Down Expand Up @@ -72,15 +71,18 @@ def last_queried_segments_should_be_in_cache(self):
index = 0
for id in reversed(self.last_use.store):
cache_sum += self.collection_size_store[id]
if cache_sum >= memory_limit and index is not 0:
if cache_sum >= memory_limit and index != 0:
break
assert id in self.segment_manager.segment_cache[SegmentScope.VECTOR].cache
index += 1

@invariant()
@precondition(lambda self: self.system.settings.is_persistent is True)
def cache_should_not_be_bigger_than_settings(self):
segment_sizes = {id: self.collection_size_store[id] for id in self.segment_manager.segment_cache[SegmentScope.VECTOR].cache}
segment_sizes = {
id: self.collection_size_store[id]
for id in self.segment_manager.segment_cache[SegmentScope.VECTOR].cache
}
total_size = sum(segment_sizes.values())
if len(segment_sizes) != 1:
assert total_size <= memory_limit
Expand All @@ -95,8 +97,12 @@ def initialize(self) -> None:
@rule(target=collections, coll=strategies.collections())
@precondition(lambda self: self.collection_created_counter <= 50)
def create_segment(
self, coll: strategies.Collection
self, coll: strategies.Collection
) -> MultipleResults[strategies.Collection]:
coll.name = f"{coll.name}_{uuid.uuid4()}"
self.sysdb.create_collection(
name=coll.name, id=coll.id, metadata=coll.metadata, dimension=coll.dimension
)
segments = self.segment_manager.create_segments(asdict(coll))
for segment in segments:
self.sysdb.create_segment(segment)
Expand All @@ -107,22 +113,27 @@ def create_segment(

@rule(coll=collections)
def get_segment(self, coll: strategies.Collection) -> None:
segment = self.segment_manager.get_segment(collection_id=coll.id, type=VectorReader)
segment = self.segment_manager.get_segment(
collection_id=coll.id, type=VectorReader
)
self.last_use.add(coll.id)
assert segment is not None


@staticmethod
def mock_directory_size(directory: str):
path_id = directory.split("/").pop()
collection_id = SegmentManagerStateMachine.segment_collection[uuid.UUID(path_id)]
collection_id = SegmentManagerStateMachine.segment_collection[
uuid.UUID(path_id)
]
return SegmentManagerStateMachine.collection_size_store[collection_id]


@patch('chromadb.segment.impl.manager.local.get_directory_size', SegmentManagerStateMachine.mock_directory_size)
@patch(
"chromadb.segment.impl.manager.local.get_directory_size",
SegmentManagerStateMachine.mock_directory_size,
)
def test_segment_manager(caplog: pytest.LogCaptureFixture, system: System) -> None:
system.settings.chroma_memory_limit_bytes = memory_limit
system.settings.chroma_segment_cache_policy = "LRU"

run_state_machine_as_test(
lambda: SegmentManagerStateMachine(system=system))
run_state_machine_as_test(lambda: SegmentManagerStateMachine(system=system))
2 changes: 2 additions & 0 deletions chromadb/test/test_client.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import shutil
from typing import Generator
from unittest.mock import patch
import chromadb
Expand All @@ -17,6 +18,7 @@ def ephemeral_api() -> Generator[ClientAPI, None, None]:

@pytest.fixture
def persistent_api() -> Generator[ClientAPI, None, None]:
shutil.rmtree(tempfile.gettempdir() + "/test_server", ignore_errors=True)
client = chromadb.PersistentClient(
path=tempfile.gettempdir() + "/test_server",
)
Expand Down

0 comments on commit 466dbae

Please sign in to comment.