From bc282305d2f57a476813f0cc5e899acf5b439de6 Mon Sep 17 00:00:00 2001 From: Maja Milinkovic Date: Wed, 6 Dec 2023 20:56:29 -0500 Subject: [PATCH 1/3] add support for initializing vecs client with custom schema --- src/tests/conftest.py | 6 +++--- src/tests/test_client.py | 7 ++++++- src/vecs/__init__.py | 4 ++-- src/vecs/client.py | 13 +++++++------ src/vecs/collection.py | 18 +++++++++--------- 5 files changed, 27 insertions(+), 21 deletions(-) diff --git a/src/tests/conftest.py b/src/tests/conftest.py index 4b6e3e7..65dcb79 100644 --- a/src/tests/conftest.py +++ b/src/tests/conftest.py @@ -13,7 +13,7 @@ import vecs PYTEST_DB = "postgresql://postgres:password@localhost:5611/vecs_db" - +PYTEST_SCHEMA = "test_schema" @pytest.fixture(scope="session") def maybe_start_pg() -> Generator[None, None, None]: @@ -94,12 +94,12 @@ def maybe_start_pg() -> Generator[None, None, None]: def clean_db(maybe_start_pg: None) -> Generator[str, None, None]: eng = create_engine(PYTEST_DB) with eng.begin() as connection: - connection.execute(text("drop schema if exists vecs cascade;")) + connection.execute(text(f"drop schema if exists {PYTEST_SCHEMA} cascade;")) yield PYTEST_DB eng.dispose() @pytest.fixture(scope="function") def client(clean_db: str) -> Generator[vecs.Client, None, None]: - client_ = vecs.create_client(clean_db) + client_ = vecs.create_client(clean_db, PYTEST_SCHEMA) yield client_ diff --git a/src/tests/test_client.py b/src/tests/test_client.py index 6e8694f..a5b29e8 100644 --- a/src/tests/test_client.py +++ b/src/tests/test_client.py @@ -1,7 +1,12 @@ import pytest - import vecs +def test_create_client(clean_db) -> None: + client = vecs.create_client(clean_db) + assert client.schema == "vecs" + + client = vecs.create_client(clean_db, "my_schema") + assert client.schema == "my_schema" def test_extracts_vector_version(client: vecs.Client) -> None: # pgvector version is sucessfully extracted diff --git a/src/vecs/__init__.py b/src/vecs/__init__.py index 55f80e5..a35dc02 100644 --- a/src/vecs/__init__.py +++ b/src/vecs/__init__.py @@ -23,6 +23,6 @@ ] -def create_client(connection_string: str) -> Client: +def create_client(connection_string: str, schema: str="vecs") -> Client: """Creates a client from a Postgres connection string""" - return Client(connection_string) + return Client(connection_string=connection_string, schema=schema) diff --git a/src/vecs/client.py b/src/vecs/client.py index 89bb3e3..d923b75 100644 --- a/src/vecs/client.py +++ b/src/vecs/client.py @@ -47,23 +47,24 @@ class Client: vx.disconnect() """ - def __init__(self, connection_string: str): + def __init__(self, connection_string: str, schema: str): """ Initialize a Client instance. Args: connection_string (str): A string representing the database connection information. - + schema (str): A string representing the database schema to connect to. Returns: None """ + self.schema = schema self.engine = create_engine(connection_string) - self.meta = MetaData(schema="vecs") + self.meta = MetaData(schema=self.schema) self.Session = sessionmaker(self.engine) with self.Session() as sess: with sess.begin(): - sess.execute(text("create schema if not exists vecs;")) + sess.execute(text(f"create schema if not exists {self.schema};")) sess.execute(text("create extension if not exists vector;")) self.vector_version: str = sess.execute( text( @@ -105,7 +106,7 @@ def get_or_create_collection( CollectionAlreadyExists: If a collection with the same name already exists """ from vecs.collection import Collection - + adapter_dimension = adapter.exported_dimension if adapter else None collection = Collection( @@ -162,7 +163,7 @@ def get_collection(self, name: str) -> Collection: join pg_attribute pa on pc.oid = pa.attrelid where - pc.relnamespace = 'vecs'::regnamespace + pc.relnamespace = '{self.schema}'::regnamespace and pc.relkind = 'r' and pa.attname = 'vec' and not pc.relname ^@ '_' diff --git a/src/vecs/collection.py b/src/vecs/collection.py index af35538..ba50527 100644 --- a/src/vecs/collection.py +++ b/src/vecs/collection.py @@ -235,7 +235,7 @@ def _create_if_not_exists(self): join pg_attribute pa on pc.oid = pa.attrelid where - pc.relnamespace = 'vecs'::regnamespace + pc.relnamespace = '{self.client.schema}'::regnamespace and pc.relkind = 'r' and pa.attname = 'vec' and not pc.relname ^@ '_' @@ -289,7 +289,7 @@ def _create(self): text( f""" create index ix_meta_{unique_string} - on vecs."{self.table.name}" + on {self.client.schema}."{self.table.name}" using gin ( metadata jsonb_path_ops ) """ ) @@ -576,7 +576,7 @@ def _list_collections(cls, client: "Client") -> List["Collection"]: """ query = text( - """ + f""" select relname as table_name, atttypmod as embedding_dim @@ -585,7 +585,7 @@ def _list_collections(cls, client: "Client") -> List["Collection"]: join pg_attribute pa on pc.oid = pa.attrelid where - pc.relnamespace = 'vecs'::regnamespace + pc.relnamespace = '{client.schema}'::regnamespace and pc.relkind = 'r' and pa.attname = 'vec' and not pc.relname ^@ '_' @@ -636,13 +636,13 @@ def index(self) -> Optional[str]: if self._index is None: query = text( - """ + f""" select relname as table_name from pg_class pc where - pc.relnamespace = 'vecs'::regnamespace + pc.relnamespace = '{self.client.schema}'::regnamespace and relname ilike 'ix_vector%' and pc.relkind = 'i' """ @@ -760,7 +760,7 @@ def create_index( with sess.begin(): if self.index is not None: if replace: - sess.execute(text(f'drop index vecs."{self.index}";')) + sess.execute(text(f'drop index "{self.client.schema}"."{self.index}";')) self._index = None else: raise ArgError("replace is set to False but an index exists") @@ -787,7 +787,7 @@ def create_index( text( f""" create index ix_{ops}_ivfflat_nl{n_lists}_{unique_string} - on vecs."{self.table.name}" + on {self.client.schema}."{self.table.name}" using ivfflat (vec {ops}) with (lists={n_lists}) """ ) @@ -806,7 +806,7 @@ def create_index( text( f""" create index ix_{ops}_hnsw_m{m}_efc{ef_construction}_{unique_string} - on vecs."{self.table.name}" + on {self.client.schema}."{self.table.name}" using hnsw (vec {ops}) WITH (m={m}, ef_construction={ef_construction}); """ ) From 07450933704deaad024dd7bdd865cbbb6b75b36e Mon Sep 17 00:00:00 2001 From: Maja Milinkovic Date: Mon, 15 Jan 2024 11:28:07 -0500 Subject: [PATCH 2/3] lint --- .pre-commit-config.yaml | 1 + src/tests/conftest.py | 4 ++- src/tests/test_client.py | 13 +++---- src/tests/test_collection.py | 66 ++++++++++++++++++++++++++++++++++++ src/vecs/__init__.py | 9 +++-- src/vecs/client.py | 25 +++++++------- src/vecs/collection.py | 31 +++++++++++------ 7 files changed, 116 insertions(+), 33 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1fcf653..ac35a43 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -20,6 +20,7 @@ repos: hooks: - id: autoflake args: ['--in-place', '--remove-all-unused-imports'] + language_version: python3.9 - repo: https://github.com/ambv/black rev: 22.10.0 diff --git a/src/tests/conftest.py b/src/tests/conftest.py index 65dcb79..2435714 100644 --- a/src/tests/conftest.py +++ b/src/tests/conftest.py @@ -15,6 +15,7 @@ PYTEST_DB = "postgresql://postgres:password@localhost:5611/vecs_db" PYTEST_SCHEMA = "test_schema" + @pytest.fixture(scope="session") def maybe_start_pg() -> Generator[None, None, None]: """Creates a postgres 15 docker container that can be connected @@ -94,6 +95,7 @@ def maybe_start_pg() -> Generator[None, None, None]: def clean_db(maybe_start_pg: None) -> Generator[str, None, None]: eng = create_engine(PYTEST_DB) with eng.begin() as connection: + connection.execute(text("drop schema if exists vecs cascade;")) connection.execute(text(f"drop schema if exists {PYTEST_SCHEMA} cascade;")) yield PYTEST_DB eng.dispose() @@ -101,5 +103,5 @@ def clean_db(maybe_start_pg: None) -> Generator[str, None, None]: @pytest.fixture(scope="function") def client(clean_db: str) -> Generator[vecs.Client, None, None]: - client_ = vecs.create_client(clean_db, PYTEST_SCHEMA) + client_ = vecs.create_client(clean_db) yield client_ diff --git a/src/tests/test_client.py b/src/tests/test_client.py index a5b29e8..db88c85 100644 --- a/src/tests/test_client.py +++ b/src/tests/test_client.py @@ -1,12 +1,7 @@ import pytest -import vecs -def test_create_client(clean_db) -> None: - client = vecs.create_client(clean_db) - assert client.schema == "vecs" +import vecs - client = vecs.create_client(clean_db, "my_schema") - assert client.schema == "my_schema" def test_extracts_vector_version(client: vecs.Client) -> None: # pgvector version is sucessfully extracted @@ -34,11 +29,17 @@ def test_get_collection(client: vecs.Client) -> None: def test_list_collections(client: vecs.Client) -> None: + """ + Test list_collections returns appropriate results for default schema (vecs) and custom schema + """ assert len(client.list_collections()) == 0 client.get_or_create_collection(name="docs", dimension=384) client.get_or_create_collection(name="books", dimension=1586) + client.get_or_create_collection(name="movies", schema="test_schema", dimension=384) collections = client.list_collections() + collections_test_schema = client.list_collections(schema="test_schema") assert len(collections) == 2 + assert len(collections_test_schema) == 1 def test_delete_collection(client: vecs.Client) -> None: diff --git a/src/tests/test_collection.py b/src/tests/test_collection.py index e7c4b38..be8ec38 100644 --- a/src/tests/test_collection.py +++ b/src/tests/test_collection.py @@ -815,3 +815,69 @@ def test_hnsw_unavailable_error(client: vecs.Client) -> None: bar = client.get_or_create_collection(name="bar", dimension=dim) with pytest.raises(ArgError): bar.create_index(method=IndexMethod.hnsw) + + +def test_get_or_create_with_schema(client: vecs.Client): + """ + Test that get_or_create_collection works when specifying custom schema + """ + + dim = 384 + + collection_1 = client.get_or_create_collection( + name="collection_1", schema="test_schema", dimension=dim + ) + collection_2 = client.get_or_create_collection( + name="collection_1", schema="test_schema", dimension=dim + ) + + assert collection_1.schema == "test_schema" + assert collection_1.schema == collection_2.schema + assert collection_1.name == collection_2.name + + +def test_upsert_with_schema(client: vecs.Client) -> None: + n_records = 100 + dim = 384 + + movies1 = client.get_or_create_collection( + name="ping", schema="test_schema", dimension=dim + ) + movies2 = client.get_or_create_collection(name="ping", schema="vecs", dimension=dim) + + # collection initially empty + assert len(movies1) == 0 + assert len(movies2) == 0 + + records = [ + ( + f"vec{ix}", + vec, + { + "genre": random.choice(["action", "rom-com", "drama"]), + "year": int(50 * random.random()) + 1970, + }, + ) + for ix, vec in enumerate(np.random.random((n_records, dim))) + ] + + # insert works + movies1.upsert(records) + assert len(movies1) == n_records + + movies2.upsert(records) + assert len(movies2) == n_records + + # upserting overwrites + new_record = ("vec0", np.zeros(384), {}) + movies1.upsert([new_record]) + db_record = movies1["vec0"] + db_record[0] == new_record[0] + db_record[1] == new_record[1] + db_record[2] == new_record[2] + + movies2.upsert([new_record]) + db_record = movies2["vec0"] + db_record[0] == new_record[0] + db_record[1] == new_record[1] + db_record[2] == new_record[2] diff --git a/src/vecs/__init__.py b/src/vecs/__init__.py index a35dc02..c226523 100644 --- a/src/vecs/__init__.py +++ b/src/vecs/__init__.py @@ -23,6 +23,9 @@ ] -def create_client(connection_string: str, schema: str="vecs") -> Client: - """Creates a client from a Postgres connection string""" - return Client(connection_string=connection_string, schema=schema) +def create_client(connection_string: str) -> Client: + """ + Creates a client from a Postgres connection string and optional schema. + Defaults to `vecs` schema. + """ + return Client(connection_string=connection_string) diff --git a/src/vecs/client.py b/src/vecs/client.py index d923b75..54b5c78 100644 --- a/src/vecs/client.py +++ b/src/vecs/client.py @@ -10,7 +10,7 @@ from typing import TYPE_CHECKING, List, Optional from deprecated import deprecated -from sqlalchemy import MetaData, create_engine, text +from sqlalchemy import create_engine, text from sqlalchemy.orm import sessionmaker from vecs.adapter import Adapter @@ -47,24 +47,21 @@ class Client: vx.disconnect() """ - def __init__(self, connection_string: str, schema: str): + def __init__(self, connection_string: str): """ Initialize a Client instance. Args: connection_string (str): A string representing the database connection information. - schema (str): A string representing the database schema to connect to. Returns: None """ - self.schema = schema self.engine = create_engine(connection_string) - self.meta = MetaData(schema=self.schema) self.Session = sessionmaker(self.engine) with self.Session() as sess: with sess.begin(): - sess.execute(text(f"create schema if not exists {self.schema};")) + sess.execute(text("create schema if not exists vecs;")) sess.execute(text("create extension if not exists vector;")) self.vector_version: str = sess.execute( text( @@ -84,6 +81,7 @@ def _supports_hnsw(self): def get_or_create_collection( self, name: str, + schema: str = "vecs", *, dimension: Optional[int] = None, adapter: Optional[Adapter] = None, @@ -106,7 +104,7 @@ def get_or_create_collection( CollectionAlreadyExists: If a collection with the same name already exists """ from vecs.collection import Collection - + adapter_dimension = adapter.exported_dimension if adapter else None collection = Collection( @@ -114,6 +112,7 @@ def get_or_create_collection( dimension=dimension or adapter_dimension, # type: ignore client=self, adapter=adapter, + schema=schema, ) return collection._create_if_not_exists() @@ -163,7 +162,7 @@ def get_collection(self, name: str) -> Collection: join pg_attribute pa on pc.oid = pa.attrelid where - pc.relnamespace = '{self.schema}'::regnamespace + pc.relnamespace = 'vecs'::regnamespace and pc.relkind = 'r' and pa.attname = 'vec' and not pc.relname ^@ '_' @@ -183,18 +182,18 @@ def get_collection(self, name: str) -> Collection: self, ) - def list_collections(self) -> List["Collection"]: + def list_collections(self, schema: str = "vecs") -> List["Collection"]: """ - List all vector collections. + List all vector collections by database schema. Returns: list[Collection]: A list of all collections. """ from vecs.collection import Collection - return Collection._list_collections(self) + return Collection._list_collections(self, schema) - def delete_collection(self, name: str) -> None: + def delete_collection(self, name: str, schema: str = "vecs") -> None: """ Delete a vector collection. @@ -208,7 +207,7 @@ def delete_collection(self, name: str) -> None: """ from vecs.collection import Collection - Collection(name, -1, self)._drop() + Collection(name, -1, self, schema=schema)._drop() return def disconnect(self) -> None: diff --git a/src/vecs/collection.py b/src/vecs/collection.py index ba50527..3367f0a 100644 --- a/src/vecs/collection.py +++ b/src/vecs/collection.py @@ -159,6 +159,7 @@ def __init__( dimension: int, client: Client, adapter: Optional[Adapter] = None, + schema: Optional[str] = "vecs", ): """ Initializes a new instance of the `Collection` class. @@ -174,7 +175,9 @@ def __init__( self.client = client self.name = name self.dimension = dimension - self.table = build_table(name, client.meta, dimension) + self.schema = schema + self.meta = MetaData(schema=self.schema) + self.table = build_table(name, self.meta, dimension) self._index: Optional[str] = None self.adapter = adapter or Adapter(steps=[NoOp(dimension=dimension)]) @@ -195,6 +198,10 @@ def __init__( "Dimensions reported by adapter, dimension, and collection do not match" ) + with self.client.Session() as sess: + with sess.begin(): + sess.execute(text(f"create schema if not exists {self.schema};")) + def __repr__(self): """ Returns a string representation of the `Collection` instance. @@ -235,7 +242,7 @@ def _create_if_not_exists(self): join pg_attribute pa on pc.oid = pa.attrelid where - pc.relnamespace = '{self.client.schema}'::regnamespace + pc.relnamespace = '{self.schema}'::regnamespace and pc.relkind = 'r' and pa.attname = 'vec' and not pc.relname ^@ '_' @@ -285,11 +292,12 @@ def _create(self): unique_string = str(uuid.uuid4()).replace("-", "_")[0:7] with self.client.Session() as sess: + sess.execute(text(f"create schema if not exists {self.schema};")) sess.execute( text( f""" create index ix_meta_{unique_string} - on {self.client.schema}."{self.table.name}" + on {self.schema}."{self.table.name}" using gin ( metadata jsonb_path_ops ) """ ) @@ -562,7 +570,7 @@ def query( return sess.execute(stmt).fetchall() or [] @classmethod - def _list_collections(cls, client: "Client") -> List["Collection"]: + def _list_collections(cls, client: "Client", schema: str) -> List["Collection"]: """ PRIVATE @@ -570,9 +578,10 @@ def _list_collections(cls, client: "Client") -> List["Collection"]: Args: client (Client): The database client. + schema (str): The database schema to query. Returns: - List[Collection]: A list of all existing collections. + List[Collection]: A list of all existing collections within the specified schema. """ query = text( @@ -585,7 +594,7 @@ def _list_collections(cls, client: "Client") -> List["Collection"]: join pg_attribute pa on pc.oid = pa.attrelid where - pc.relnamespace = '{client.schema}'::regnamespace + pc.relnamespace = '{schema}'::regnamespace and pc.relkind = 'r' and pa.attname = 'vec' and not pc.relname ^@ '_' @@ -642,7 +651,7 @@ def index(self) -> Optional[str]: from pg_class pc where - pc.relnamespace = '{self.client.schema}'::regnamespace + pc.relnamespace = '{self.schema}'::regnamespace and relname ilike 'ix_vector%' and pc.relkind = 'i' """ @@ -760,7 +769,9 @@ def create_index( with sess.begin(): if self.index is not None: if replace: - sess.execute(text(f'drop index "{self.client.schema}"."{self.index}";')) + sess.execute( + text(f'drop index "{self.schema}"."{self.index}";') + ) self._index = None else: raise ArgError("replace is set to False but an index exists") @@ -787,7 +798,7 @@ def create_index( text( f""" create index ix_{ops}_ivfflat_nl{n_lists}_{unique_string} - on {self.client.schema}."{self.table.name}" + on {self.schema}."{self.table.name}" using ivfflat (vec {ops}) with (lists={n_lists}) """ ) @@ -806,7 +817,7 @@ def create_index( text( f""" create index ix_{ops}_hnsw_m{m}_efc{ef_construction}_{unique_string} - on {self.client.schema}."{self.table.name}" + on {self.schema}."{self.table.name}" using hnsw (vec {ops}) WITH (m={m}, ef_construction={ef_construction}); """ ) From 06ec1bd7c544275f659be6c604c6bab157764961 Mon Sep 17 00:00:00 2001 From: Maja Milinkovic <38793515+majamil16@users.noreply.github.com> Date: Mon, 12 Feb 2024 19:34:18 -0500 Subject: [PATCH 3/3] updates based on feedback --- .pre-commit-config.yaml | 4 ++-- src/vecs/client.py | 7 ++++--- src/vecs/collection.py | 5 ++++- 3 files changed, 10 insertions(+), 6 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ac35a43..95b9cd5 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -20,10 +20,10 @@ repos: hooks: - id: autoflake args: ['--in-place', '--remove-all-unused-imports'] - language_version: python3.9 + language_version: python3.8 - repo: https://github.com/ambv/black rev: 22.10.0 hooks: - id: black - language_version: python3.9 + language_version: python3.8 diff --git a/src/vecs/client.py b/src/vecs/client.py index 54b5c78..be9430a 100644 --- a/src/vecs/client.py +++ b/src/vecs/client.py @@ -81,8 +81,8 @@ def _supports_hnsw(self): def get_or_create_collection( self, name: str, - schema: str = "vecs", *, + schema: str = "vecs", dimension: Optional[int] = None, adapter: Optional[Adapter] = None, ) -> Collection: @@ -182,7 +182,7 @@ def get_collection(self, name: str) -> Collection: self, ) - def list_collections(self, schema: str = "vecs") -> List["Collection"]: + def list_collections(self, *, schema: str = "vecs") -> List["Collection"]: """ List all vector collections by database schema. @@ -193,7 +193,7 @@ def list_collections(self, schema: str = "vecs") -> List["Collection"]: return Collection._list_collections(self, schema) - def delete_collection(self, name: str, schema: str = "vecs") -> None: + def delete_collection(self, name: str, *, schema: str = "vecs") -> None: """ Delete a vector collection. @@ -201,6 +201,7 @@ def delete_collection(self, name: str, schema: str = "vecs") -> None: Args: name (str): The name of the collection. + schema (str): Optional, the database schema. Defaults to `vecs`. Returns: None diff --git a/src/vecs/collection.py b/src/vecs/collection.py index 3367f0a..2442b28 100644 --- a/src/vecs/collection.py +++ b/src/vecs/collection.py @@ -175,7 +175,10 @@ def __init__( self.client = client self.name = name self.dimension = dimension - self.schema = schema + self._schema = schema + self.schema = self.client.engine.dialect.identifier_preparer.quote_schema( + self._schema + ) self.meta = MetaData(schema=self.schema) self.table = build_table(name, self.meta, dimension) self._index: Optional[str] = None