From 466dbaea5b30981fcd20f431d72d5af40674f5d8 Mon Sep 17 00:00:00 2001 From: Trayan Azarov Date: Mon, 12 Feb 2024 11:18:26 +0200 Subject: [PATCH] feat: PR redone - Enabled FK constraints on all connections - Fixed FK for segments to collections - Implemented a new mechanism for disabling/enabling FKs for migration files (required to retrospectively apply it to tenants and databases migration as it was breaking FK constraints) - Removed unnecessary embedding metadata cleanup (taken care of by FKs) - Added utils lib to generate correct (topologically sorted in reverse) DROP statements for tables according to FK constraints - Fixed test_client.py failing test - the test server dir was not removed so subsequent tests were failing - Fixed test_segment_manager.py where collection names were causing FK constraint failures --- chromadb/api/segment.py | 6 +-- chromadb/db/impl/sqlite.py | 22 ++++++++-- chromadb/db/impl/sqlite_utils.py | 44 +++++++++++++++++++ .../migrations/metadb/00006-em-fk.sqlite.sql | 22 ++++++++++ ...04-tenants-databases.sqlite.sql.disable_fk | 0 .../.00006-segments-fk.sqlite.sql.disable_fk | 0 .../sysdb/00006-segments-fk.sqlite.sql | 13 ++++++ chromadb/segment/impl/metadata/sqlite.py | 14 ------ .../test/property/test_segment_manager.py | 35 ++++++++++----- chromadb/test/test_client.py | 2 + 10 files changed, 125 insertions(+), 33 deletions(-) create mode 100644 chromadb/db/impl/sqlite_utils.py create mode 100644 chromadb/migrations/metadb/00006-em-fk.sqlite.sql create mode 100644 chromadb/migrations/sysdb/.00004-tenants-databases.sqlite.sql.disable_fk create mode 100644 chromadb/migrations/sysdb/.00006-segments-fk.sqlite.sql.disable_fk create mode 100644 chromadb/migrations/sysdb/00006-segments-fk.sqlite.sql diff --git a/chromadb/api/segment.py b/chromadb/api/segment.py index 72df138d9bec..019d8a87948e 100644 --- a/chromadb/api/segment.py +++ b/chromadb/api/segment.py @@ -335,13 +335,13 @@ def delete_collection( ) if existing: - 
self._sysdb.delete_collection( - existing[0]["id"], tenant=tenant, database=database - ) for s in self._manager.delete_segments(existing[0]["id"]): self._sysdb.delete_segment(s) if existing and existing[0]["id"] in self._collection_cache: del self._collection_cache[existing[0]["id"]] + self._sysdb.delete_collection( + existing[0]["id"], tenant=tenant, database=database + ) else: raise ValueError(f"Collection {name} does not exist.") diff --git a/chromadb/db/impl/sqlite.py b/chromadb/db/impl/sqlite.py index c7cdb3063246..3db807d79521 100644 --- a/chromadb/db/impl/sqlite.py +++ b/chromadb/db/impl/sqlite.py @@ -1,4 +1,7 @@ +from pathlib import Path + from chromadb.db.impl.sqlite_pool import Connection, LockPool, PerThreadPool, Pool +from chromadb.db.impl.sqlite_utils import get_drop_order from chromadb.db.migrations import MigratableDB, Migration from chromadb.config import System, Settings import chromadb.db.base as base @@ -34,6 +37,7 @@ def __init__(self, conn_pool: Pool, stack: local): @override def __enter__(self) -> base.Cursor: if len(self._tx_stack.stack) == 0: + self._conn.execute("PRAGMA foreign_keys = ON") self._conn.execute("BEGIN;") self._tx_stack.stack.append(self) return self._conn.cursor() # type: ignore @@ -138,15 +142,16 @@ def reset_state(self) -> None: "Resetting the database is not allowed. Set `allow_reset` to true in the config in tests or other non-production environments where reset should be permitted." 
) with self.tx() as cur: - # Drop all tables cur.execute( """ SELECT name FROM sqlite_master WHERE type='table' """ ) - for row in cur.fetchall(): - cur.execute(f"DROP TABLE IF EXISTS {row[0]}") + drop_statement = "" + for t in get_drop_order(cur): + drop_statement += f"DROP TABLE IF EXISTS {t};\n" + cur.executescript(drop_statement) self._conn_pool.close() self.start() super().reset_state() @@ -217,7 +222,16 @@ def db_migrations(self, dir: Traversable) -> Sequence[Migration]: @override def apply_migration(self, cur: base.Cursor, migration: Migration) -> None: - cur.executescript(migration["sql"]) + if any(item.name == f".{migration['filename']}.disable_fk" + for traversable in self.migration_dirs() + for item in traversable.iterdir() if item.is_file()): + cur.executescript( + "PRAGMA foreign_keys = OFF;\n" + + migration["sql"] + + ";\nPRAGMA foreign_keys = ON;" + ) + else: + cur.executescript(migration["sql"]) cur.execute( """ INSERT INTO migrations (dir, version, filename, sql, hash) diff --git a/chromadb/db/impl/sqlite_utils.py b/chromadb/db/impl/sqlite_utils.py new file mode 100644 index 000000000000..66265f8c8dc2 --- /dev/null +++ b/chromadb/db/impl/sqlite_utils.py @@ -0,0 +1,44 @@ +from collections import defaultdict, deque +from graphlib import TopologicalSorter +from typing import List, Dict + +from chromadb.db.base import Cursor + + +def fetch_tables(cursor: Cursor) -> List[str]: + cursor.execute("SELECT name FROM sqlite_master WHERE type='table';") + return [row[0] for row in cursor.fetchall()] + + +def fetch_foreign_keys(cursor: Cursor, table_name: str) -> List[str]: + cursor.execute(f"PRAGMA foreign_key_list({table_name});") + return [row[2] for row in cursor.fetchall()] # Table being referenced + + +def build_dependency_graph(tables: List[str], cursor: Cursor) -> Dict[str, List[str]]: + graph = defaultdict(list) + for table in tables: + foreign_keys = fetch_foreign_keys(cursor, table) + for fk_table in foreign_keys: + graph[table].append(fk_table) + if 
not foreign_keys and table not in graph.keys(): + graph[table] = [] + + return graph + + +def topological_sort(graph: Dict[str, List[str]]) -> List[str]: + ts = TopologicalSorter(graph) + # Reverse the order since TopologicalSorter gives the order of dependencies, + # but we want to drop tables in reverse dependency order + return list(ts.static_order())[::-1] + + +def get_drop_order(cursor: Cursor) -> List[str]: + tables = fetch_tables(cursor) + filtered_tables = [ + table for table in tables if not table.startswith("embedding_fulltext_search_") + ] + graph = build_dependency_graph(filtered_tables, cursor) + drop_order = topological_sort(graph) + return drop_order diff --git a/chromadb/migrations/metadb/00006-em-fk.sqlite.sql b/chromadb/migrations/metadb/00006-em-fk.sqlite.sql new file mode 100644 index 000000000000..0fc9c46cb86f --- /dev/null +++ b/chromadb/migrations/metadb/00006-em-fk.sqlite.sql @@ -0,0 +1,22 @@ +-- Disable foreign key constraints to allow us to rebuild the embedding_metadata table +PRAGMA foreign_keys = OFF; + +CREATE TABLE embedding_metadata_temp ( + id INTEGER REFERENCES embeddings(id) ON DELETE CASCADE NOT NULL, + key TEXT NOT NULL, + string_value TEXT, + int_value INTEGER, + float_value REAL, + bool_value INTEGER, + PRIMARY KEY (id, key) +); + +INSERT INTO embedding_metadata_temp +SELECT id, key, string_value, int_value, float_value, bool_value +FROM embedding_metadata; + +DROP TABLE embedding_metadata; + +ALTER TABLE embedding_metadata_temp RENAME TO embedding_metadata; + +PRAGMA foreign_keys = ON; diff --git a/chromadb/migrations/sysdb/.00004-tenants-databases.sqlite.sql.disable_fk b/chromadb/migrations/sysdb/.00004-tenants-databases.sqlite.sql.disable_fk new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/chromadb/migrations/sysdb/.00006-segments-fk.sqlite.sql.disable_fk b/chromadb/migrations/sysdb/.00006-segments-fk.sqlite.sql.disable_fk new file mode 100644 index 000000000000..e69de29bb2d1 diff --git 
a/chromadb/migrations/sysdb/00006-segments-fk.sqlite.sql b/chromadb/migrations/sysdb/00006-segments-fk.sqlite.sql new file mode 100644 index 000000000000..29a72f845bf6 --- /dev/null +++ b/chromadb/migrations/sysdb/00006-segments-fk.sqlite.sql @@ -0,0 +1,13 @@ +CREATE TABLE segments_temp ( + id TEXT PRIMARY KEY, + type TEXT NOT NULL, + scope TEXT NOT NULL, + topic TEXT, + collection TEXT REFERENCES collections(id) +); + +INSERT INTO segments_temp SELECT * FROM segments; + +DROP TABLE segments; + +ALTER TABLE segments_temp RENAME TO segments; diff --git a/chromadb/segment/impl/metadata/sqlite.py b/chromadb/segment/impl/metadata/sqlite.py index 2e5af88b0d05..db950ab0e2b3 100644 --- a/chromadb/segment/impl/metadata/sqlite.py +++ b/chromadb/segment/impl/metadata/sqlite.py @@ -416,20 +416,6 @@ def _delete_record(self, cur: Cursor, record: EmbeddingRecord) -> None: result = cur.execute(sql, params).fetchone() if result is None: logger.warning(f"Delete of nonexisting embedding ID: {record['id']}") - else: - id = result[0] - - # Manually delete metadata; cannot use cascade because - # that triggers on replace - metadata_t = Table("embedding_metadata") - q = ( - self._db.querybuilder() - .from_(metadata_t) - .where(metadata_t.id == ParameterValue(id)) - .delete() - ) - sql, params = get_sql(q) - cur.execute(sql, params) @trace_method("SqliteMetadataSegment._update_record", OpenTelemetryGranularity.ALL) def _update_record(self, cur: Cursor, record: EmbeddingRecord) -> None: diff --git a/chromadb/test/property/test_segment_manager.py b/chromadb/test/property/test_segment_manager.py index ff5e057dff4c..ed8ff5466693 100644 --- a/chromadb/test/property/test_segment_manager.py +++ b/chromadb/test/property/test_segment_manager.py @@ -17,9 +17,7 @@ MultipleResults, ) from typing import Dict -from chromadb.segment import ( - VectorReader -) +from chromadb.segment import VectorReader from chromadb.segment import SegmentManager from chromadb.segment.impl.manager.local import 
LocalSegmentManager @@ -30,6 +28,7 @@ # Memory limit use for testing memory_limit = 100 + # Helper class to keep tract of the last use id class LastUse: def __init__(self, n: int): @@ -72,7 +71,7 @@ def last_queried_segments_should_be_in_cache(self): index = 0 for id in reversed(self.last_use.store): cache_sum += self.collection_size_store[id] - if cache_sum >= memory_limit and index is not 0: + if cache_sum >= memory_limit and index != 0: break assert id in self.segment_manager.segment_cache[SegmentScope.VECTOR].cache index += 1 @@ -80,7 +79,10 @@ def last_queried_segments_should_be_in_cache(self): @invariant() @precondition(lambda self: self.system.settings.is_persistent is True) def cache_should_not_be_bigger_than_settings(self): - segment_sizes = {id: self.collection_size_store[id] for id in self.segment_manager.segment_cache[SegmentScope.VECTOR].cache} + segment_sizes = { + id: self.collection_size_store[id] + for id in self.segment_manager.segment_cache[SegmentScope.VECTOR].cache + } total_size = sum(segment_sizes.values()) if len(segment_sizes) != 1: assert total_size <= memory_limit @@ -95,8 +97,12 @@ def initialize(self) -> None: @rule(target=collections, coll=strategies.collections()) @precondition(lambda self: self.collection_created_counter <= 50) def create_segment( - self, coll: strategies.Collection + self, coll: strategies.Collection ) -> MultipleResults[strategies.Collection]: + coll.name = f"{coll.name}_{uuid.uuid4()}" + self.sysdb.create_collection( + name=coll.name, id=coll.id, metadata=coll.metadata, dimension=coll.dimension + ) segments = self.segment_manager.create_segments(asdict(coll)) for segment in segments: self.sysdb.create_segment(segment) @@ -107,22 +113,27 @@ def create_segment( @rule(coll=collections) def get_segment(self, coll: strategies.Collection) -> None: - segment = self.segment_manager.get_segment(collection_id=coll.id, type=VectorReader) + segment = self.segment_manager.get_segment( + collection_id=coll.id, type=VectorReader 
+ ) self.last_use.add(coll.id) assert segment is not None - @staticmethod def mock_directory_size(directory: str): path_id = directory.split("/").pop() - collection_id = SegmentManagerStateMachine.segment_collection[uuid.UUID(path_id)] + collection_id = SegmentManagerStateMachine.segment_collection[ + uuid.UUID(path_id) + ] return SegmentManagerStateMachine.collection_size_store[collection_id] -@patch('chromadb.segment.impl.manager.local.get_directory_size', SegmentManagerStateMachine.mock_directory_size) +@patch( + "chromadb.segment.impl.manager.local.get_directory_size", + SegmentManagerStateMachine.mock_directory_size, +) def test_segment_manager(caplog: pytest.LogCaptureFixture, system: System) -> None: system.settings.chroma_memory_limit_bytes = memory_limit system.settings.chroma_segment_cache_policy = "LRU" - run_state_machine_as_test( - lambda: SegmentManagerStateMachine(system=system)) + run_state_machine_as_test(lambda: SegmentManagerStateMachine(system=system)) diff --git a/chromadb/test/test_client.py b/chromadb/test/test_client.py index f67293d85864..4ba88decd315 100644 --- a/chromadb/test/test_client.py +++ b/chromadb/test/test_client.py @@ -1,3 +1,4 @@ +import shutil from typing import Generator from unittest.mock import patch import chromadb @@ -17,6 +18,7 @@ def ephemeral_api() -> Generator[ClientAPI, None, None]: @pytest.fixture def persistent_api() -> Generator[ClientAPI, None, None]: + shutil.rmtree(tempfile.gettempdir() + "/test_server", ignore_errors=True) client = chromadb.PersistentClient( path=tempfile.gettempdir() + "/test_server", )