Skip to content

Commit

Permalink
PGVector Support for Custom Connection Object (microsoft#2566)
Browse files Browse the repository at this point in the history
* Added fixes and tests for basic auth format

* User can provide their own connection object. Added test for it.

* Updated instructions on how to use. Fully tested all 3 authentication methods successfully.

* Get password from gitlab secrets.

* Hide passwords.

* Update notebook/agentchat_pgvector_RetrieveChat.ipynb

Co-authored-by: Li Jiang <[email protected]>

* Hide passwords.

* Added connection_string test. 3 tests total for auth.

* Fixed quotes on db config params. No other changes found.

* Ran notebook

* Ran pre-commits and updated setup to include psycopg[binary] for windows and mac.

* Corrected list extension.

* Separate connection establishment function. Testing pending.

* Fixed pgvectordb auth

* Update agentchat_pgvector_RetrieveChat.ipynb

Added autocommit=True in example

* Rerun notebook

---------

Co-authored-by: Li Jiang <[email protected]>
Co-authored-by: Li Jiang <[email protected]>
  • Loading branch information
3 people authored May 24, 2024
1 parent aad6a28 commit 40ee011
Show file tree
Hide file tree
Showing 5 changed files with 313 additions and 398 deletions.
88 changes: 88 additions & 0 deletions .gitattributes
Original file line number Diff line number Diff line change
@@ -1,3 +1,91 @@
# Source code
*.bash text eol=lf
*.bat text eol=crlf
*.cmd text eol=crlf
*.coffee text
*.css text diff=css eol=lf
*.htm text diff=html eol=lf
*.html text diff=html eol=lf
*.inc text
*.ini text
*.js text
*.json text eol=lf
*.jsx text
*.less text
*.ls text
*.map text -diff
*.od text
*.onlydata text
*.php text diff=php
*.pl text
*.ps1 text eol=crlf
*.py text diff=python eol=lf
*.rb text diff=ruby eol=lf
*.sass text
*.scm text
*.scss text diff=css
*.sh text eol=lf
.husky/* text eol=lf
*.sql text
*.styl text
*.tag text
*.ts text
*.tsx text
*.xml text
*.xhtml text diff=html

# Docker
Dockerfile text eol=lf

# Documentation
*.ipynb text
*.markdown text diff=markdown eol=lf
*.md text diff=markdown eol=lf
*.mdwn text diff=markdown eol=lf
*.mdown text diff=markdown eol=lf
*.mkd text diff=markdown eol=lf
*.mkdn text diff=markdown eol=lf
*.mdtxt text eol=lf
*.mdtext text eol=lf
*.txt text eol=lf
AUTHORS text eol=lf
CHANGELOG text eol=lf
CHANGES text eol=lf
CONTRIBUTING text eol=lf
COPYING text eol=lf
copyright text eol=lf
*COPYRIGHT* text eol=lf
INSTALL text eol=lf
license text eol=lf
LICENSE text eol=lf
NEWS text eol=lf
readme text eol=lf
*README* text eol=lf
TODO text

# Configs
*.cnf text eol=lf
*.conf text eol=lf
*.config text eol=lf
.editorconfig text
.env text eol=lf
.gitattributes text eol=lf
.gitconfig text eol=lf
.htaccess text
*.lock text -diff
package.json text eol=lf
package-lock.json text eol=lf -diff
pnpm-lock.yaml text eol=lf -diff
.prettierrc text
yarn.lock text -diff
*.toml text eol=lf
*.yaml text eol=lf
*.yml text eol=lf
browserslist text
Makefile text eol=lf
makefile text eol=lf

# Images
*.png filter=lfs diff=lfs merge=lfs -text
*.jpg filter=lfs diff=lfs merge=lfs -text
*.jpeg filter=lfs diff=lfs merge=lfs -text
164 changes: 118 additions & 46 deletions autogen/agentchat/contrib/vectordb/pgvectordb.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
import re
import urllib.parse
from typing import Callable, List
from typing import Callable, List, Optional, Union

import numpy as np
from sentence_transformers import SentenceTransformer
Expand Down Expand Up @@ -231,7 +231,14 @@ def table_exists(self, table_name: str) -> bool:
exists = cursor.fetchone()[0]
return exists

def get(self, ids=None, include=None, where=None, limit=None, offset=None) -> List[Document]:
def get(
self,
ids: Optional[str] = None,
include: Optional[str] = None,
where: Optional[str] = None,
limit: Optional[Union[int, str]] = None,
offset: Optional[Union[int, str]] = None,
) -> List[Document]:
"""
Retrieve documents from the collection.
Expand Down Expand Up @@ -272,7 +279,6 @@ def get(self, ids=None, include=None, where=None, limit=None, offset=None) -> Li

# Construct the full query
query = f"{select_clause} {from_clause} {where_clause} {limit_clause} {offset_clause}"

retrieved_documents = []
try:
# Execute the query with the appropriate values
Expand Down Expand Up @@ -380,11 +386,11 @@ def inner_product_distance(arr1: List[float], arr2: List[float]) -> float:
def query(
self,
query_texts: List[str],
collection_name: str = None,
n_results: int = 10,
distance_type: str = "euclidean",
distance_threshold: float = -1,
include_embedding: bool = False,
collection_name: Optional[str] = None,
n_results: Optional[int] = 10,
distance_type: Optional[str] = "euclidean",
distance_threshold: Optional[float] = -1,
include_embedding: Optional[bool] = False,
) -> QueryResults:
"""
Query documents in the collection.
Expand Down Expand Up @@ -450,7 +456,7 @@ def query(
return results

@staticmethod
def convert_string_to_array(array_string) -> List[float]:
def convert_string_to_array(array_string: str) -> List[float]:
"""
Convert a string representation of an array to a list of floats.
Expand All @@ -467,7 +473,7 @@ def convert_string_to_array(array_string) -> List[float]:
array = [float(num) for num in array_string.split()]
return array

def modify(self, metadata, collection_name: str = None) -> None:
def modify(self, metadata, collection_name: Optional[str] = None) -> None:
"""
Modify metadata for the collection.
Expand All @@ -486,7 +492,7 @@ def modify(self, metadata, collection_name: str = None) -> None:
)
cursor.close()

def delete(self, ids: List[ItemID], collection_name: str = None) -> None:
def delete(self, ids: List[ItemID], collection_name: Optional[str] = None) -> None:
"""
Delete documents from the collection.
Expand All @@ -504,7 +510,7 @@ def delete(self, ids: List[ItemID], collection_name: str = None) -> None:
cursor.execute(f"DELETE FROM {self.name} WHERE id IN ({id_placeholders});", ids)
cursor.close()

def delete_collection(self, collection_name: str = None) -> None:
def delete_collection(self, collection_name: Optional[str] = None) -> None:
"""
Delete the entire collection.
Expand All @@ -520,7 +526,7 @@ def delete_collection(self, collection_name: str = None) -> None:
cursor.execute(f"DROP TABLE IF EXISTS {self.name}")
cursor.close()

def create_collection(self, collection_name: str = None) -> None:
def create_collection(self, collection_name: Optional[str] = None) -> None:
"""
Create a new collection.
Expand Down Expand Up @@ -557,23 +563,27 @@ class PGVectorDB(VectorDB):
def __init__(
self,
*,
connection_string: str = None,
host: str = None,
port: int = None,
dbname: str = None,
username: str = None,
password: str = None,
connect_timeout: int = 10,
conn: Optional[psycopg.Connection] = None,
connection_string: Optional[str] = None,
host: Optional[str] = None,
port: Optional[Union[int, str]] = None,
dbname: Optional[str] = None,
username: Optional[str] = None,
password: Optional[str] = None,
connect_timeout: Optional[int] = 10,
embedding_function: Callable = None,
metadata: dict = None,
model_name: str = "all-MiniLM-L6-v2",
metadata: Optional[dict] = None,
model_name: Optional[str] = "all-MiniLM-L6-v2",
) -> None:
"""
Initialize the vector database.
Note: connection_string or host + port + dbname must be specified
Args:
conn: psycopg.Connection | A customer connection object to connect to the database.
A connection object may include additional key/values:
https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING
connection_string: "postgresql://username:password@hostname:port/database" | The PGVector connection string. Default is None.
host: str | The host to connect to. Default is None.
port: int | The port to connect to. Default is None.
Expand All @@ -593,46 +603,108 @@ def __init__(
Returns:
None
"""
self.client = self.establish_connection(
conn=conn,
connection_string=connection_string,
host=host,
port=port,
dbname=dbname,
username=username,
password=password,
connect_timeout=connect_timeout,
)
self.model_name = model_name
try:
if connection_string:
self.embedding_function = (
SentenceTransformer(self.model_name) if embedding_function is None else embedding_function
)
except Exception as e:
logger.error(
f"Validate the model name entered: {self.model_name} "
f"from https://huggingface.co/models?library=sentence-transformers\nError: {e}"
)
raise e
self.metadata = metadata
register_vector(self.client)
self.active_collection = None

def establish_connection(
self,
conn: Optional[psycopg.Connection] = None,
connection_string: Optional[str] = None,
host: Optional[str] = None,
port: Optional[Union[int, str]] = None,
dbname: Optional[str] = None,
username: Optional[str] = None,
password: Optional[str] = None,
connect_timeout: Optional[int] = 10,
) -> psycopg.Connection:
"""
Establishes a connection to a PostgreSQL database using psycopg.
Args:
conn: An existing psycopg connection object. If provided, this connection will be used.
connection_string: A string containing the connection information. If provided, a new connection will be established using this string.
host: The hostname of the PostgreSQL server. Used if connection_string is not provided.
port: The port number to connect to at the server host. Used if connection_string is not provided.
dbname: The database name. Used if connection_string is not provided.
username: The username to connect as. Used if connection_string is not provided.
password: The user's password. Used if connection_string is not provided.
connect_timeout: Maximum wait for connection, in seconds. The default is 10 seconds.
Returns:
A psycopg.Connection object representing the established connection.
Raises:
PermissionError if no credentials are supplied
psycopg.Error: If an error occurs while trying to connect to the database.
"""
try:
if conn:
self.client = conn
elif connection_string:
parsed_connection = urllib.parse.urlparse(connection_string)
encoded_username = urllib.parse.quote(parsed_connection.username, safe="")
encoded_password = urllib.parse.quote(parsed_connection.password, safe="")
encoded_password = f":{encoded_password}@"
encoded_host = urllib.parse.quote(parsed_connection.hostname, safe="")
encoded_port = f":{parsed_connection.port}"
encoded_database = urllib.parse.quote(parsed_connection.path[1:], safe="")
connection_string_encoded = (
f"{parsed_connection.scheme}://{encoded_username}:{encoded_password}"
f"@{encoded_host}:{parsed_connection.port}/{encoded_database}"
f"{parsed_connection.scheme}://{encoded_username}{encoded_password}"
f"{encoded_host}{encoded_port}/{encoded_database}"
)
self.client = psycopg.connect(conninfo=connection_string_encoded, autocommit=True)
elif host and port and dbname:
elif host:
connection_string = ""
if host:
encoded_host = urllib.parse.quote(host, safe="")
connection_string += f"host={encoded_host} "
if port:
connection_string += f"port={port} "
if dbname:
encoded_database = urllib.parse.quote(dbname, safe="")
connection_string += f"dbname={encoded_database} "
if username:
encoded_username = urllib.parse.quote(username, safe="")
connection_string += f"user={encoded_username} "
if password:
encoded_password = urllib.parse.quote(password, safe="")
connection_string += f"password={encoded_password} "

self.client = psycopg.connect(
host=host,
port=port,
dbname=dbname,
username=username,
password=password,
conninfo=connection_string,
connect_timeout=connect_timeout,
autocommit=True,
)
else:
logger.error("Credentials were not supplied...")
raise PermissionError
self.client.execute("CREATE EXTENSION IF NOT EXISTS vector")
except psycopg.Error as e:
logger.error("Error connecting to the database: ", e)
raise e
self.model_name = model_name
try:
self.embedding_function = (
SentenceTransformer(self.model_name) if embedding_function is None else embedding_function
)
except Exception as e:
logger.error(
f"Validate the model name entered: {self.model_name} "
f"from https://huggingface.co/models?library=sentence-transformers\nError: {e}"
)
raise e
self.metadata = metadata
self.client.execute("CREATE EXTENSION IF NOT EXISTS vector")
register_vector(self.client)
self.active_collection = None
return self.client

def create_collection(
self, collection_name: str, overwrite: bool = False, get_or_create: bool = True
Expand Down
Loading

0 comments on commit 40ee011

Please sign in to comment.