-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdb_client.py
31 lines (24 loc) · 899 Bytes
/
db_client.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import chromadb
from config import config
from logger import setup_logger
# Module-level logger obtained from the project's setup_logger helper,
# named after this module; used below to report vector-db client errors.
logger = setup_logger(__name__)
class DbClient:
    """Shared (singleton) wrapper around a ChromaDB HTTP client.

    Every ``DbClient()`` call returns the same instance. The underlying
    ``chromadb.HttpClient`` is created once, on the first successful
    instantiation, from ``config["db_host"]`` and ``config["db_port"]``.

    Note: instance creation is not guarded by a lock, so concurrent first
    calls from several threads could each build a client.
    """

    def __new__(cls):
        """Return the shared instance, creating it and its client on first use.

        Raises:
            Whatever ``chromadb.HttpClient`` raises while being constructed;
            the error is logged before being re-raised, and nothing is
            cached, so the next ``DbClient()`` call tries again.
        """
        # Check cls.__dict__ instead of hasattr() so a subclass does not
        # inherit (and hand out) its parent's cached instance.
        if "instance" not in cls.__dict__:
            instance = super().__new__(cls)
            host = config["db_host"]
            port = config["db_port"]
            # The client is built here rather than in __init__: Python re-runs
            # __init__ on every DbClient() call, which would replace the
            # shared client each time.
            try:
                instance.client = chromadb.HttpClient(host=host, port=port)
            except Exception:
                logger.exception("Error connecting to vector db.")
                raise
            # Cache only a fully initialised instance.
            cls.instance = instance
        return cls.instance

    def query_docs(self, query, collection, embed_func, num_docs=4):
        """Run a single text query against a named collection.

        Args:
            query: Query text; sent as the only entry of ``query_texts``.
            collection: Name of the collection, passed to
                ``client.get_collection``.
            embed_func: Embedding function passed to ``get_collection`` as
                ``embedding_function``.
            num_docs: Number of results to request (``n_results``).

        Returns:
            The result of the collection's ``query`` call, with
            ``"documents"`` and ``"metadatas"`` included.
        """
        coll = self.client.get_collection(
            collection, embedding_function=embed_func
        )
        return coll.query(
            query_texts=[query],
            include=["documents", "metadatas"],
            n_results=num_docs,
        )