-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
added qdrant label collection uploader
- Loading branch information
1 parent
55680c2
commit e127a24
Showing
2 changed files
with
111 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
from qdrant_client import QdrantClient | ||
from qdrant_client.models import Distance, PointStruct, VectorParams | ||
import pickle | ||
import os | ||
from qdrant_client.http import models | ||
from geniml.search.backends.dbbackend import QdrantBackend | ||
|
||
|
||
# Name of the Qdrant collection that holds the text embeddings.
TEXT_QDRANT_COLLECTION_NAME = "bed_text"

# Scalar quantization settings: INT8 storage, 0.99 quantile, quantized
# vectors always kept in RAM.
DEFAULT_QUANTIZATION_CONFIG = models.ScalarQuantization(
    scalar=models.ScalarQuantizationConfig(
        type=models.ScalarType.INT8, quantile=0.99, always_ram=True
    )
)
|
||
|
||
def upload_text_embeddings():
    """Upload pickled text embeddings and their payloads to Qdrant.

    Reads ``./text_loading.pkl``, which must unpickle to a
    ``(text_vectors, payloads)`` pair, builds one point per payload with
    sequential integer ids starting at 0, and upserts all points into the
    ``TEXT_QDRANT_COLLECTION_NAME`` collection in a single request.

    The Qdrant host/URL and API key are empty strings here -- presumably
    placeholders to be filled in before running.
    """
    # Lab Qdrant client, configured from the environment:
    # qc = QdrantClient(
    #     host=os.environ.get("QDRATN_HOST"),
    #     api_key=os.environ.get("QDRANT_API_KEY"),
    # )
    # NOTE(review): "QDRATN_HOST" looks like a typo for "QDRANT_HOST" -- confirm
    # against the deployment environment before re-enabling.

    # The backend object itself is never used afterwards, so it is not bound
    # to a name.
    # NOTE(review): presumably it is constructed so that geniml sets up the
    # collection (dim=384) when it does not exist yet -- confirm before
    # removing this call.
    QdrantBackend(
        dim=384,
        collection=TEXT_QDRANT_COLLECTION_NAME,
        qdrant_host="",
        qdrant_api_key="",
    )
    client = QdrantClient(
        url="",
        api_key="",
    )

    # Load metadata embeddings for the new collection.
    # SECURITY: unpickling can execute arbitrary code; only load a pickle file
    # that was produced locally by a trusted process.
    with open("./text_loading.pkl", "rb") as f:
        text_vectors, payloads = pickle.load(f)

    # Point ids are the positions in ``payloads``; each vector is converted
    # with ``.tolist()`` (assumes array-like vectors, e.g. numpy -- TODO confirm).
    points = [
        PointStruct(id=i, vector=text_vectors[i].tolist(), payload=payload)
        for i, payload in enumerate(payloads)
    ]

    client.upsert(collection_name=TEXT_QDRANT_COLLECTION_NAME, points=points)


if __name__ == "__main__":
    upload_text_embeddings()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
from geniml.search.backends import BiVectorBackend, QdrantBackend | ||
from geniml.search.interfaces import BiVectorSearchInterface | ||
|
||
|
||
def search_test():
    """Run one bi-vector text search against Qdrant and print its latency.

    Builds a text backend (collection ``"bed_text"``) and a BED backend
    (collection ``"bedbase2"``), combines them in a ``BiVectorBackend``, and
    issues a single ``query_search`` for ``"leukemia"`` through a
    ``BiVectorSearchInterface``. Only the elapsed time of that call, in
    seconds, is printed; the search result itself is discarded.

    The Qdrant host and API key are empty strings here -- presumably
    placeholders to be filled in before running.
    """
    # perf_counter is monotonic and high-resolution, which makes it the
    # appropriate clock for measuring a short elapsed interval.
    from time import perf_counter

    # Backends for text embeddings and BED embeddings.
    text_backend = QdrantBackend(
        dim=384,  # dim of sentence-transformers embedding output
        collection="bed_text",
        qdrant_host="",
        qdrant_api_key="",
    )
    bed_backend = QdrantBackend(
        collection="bedbase2",
        qdrant_host="",
        qdrant_api_key="",
    )

    # The search backend.
    search_backend = BiVectorBackend(text_backend, bed_backend)

    # The search interface; constructed before the timer starts so that its
    # setup cost is not included in the reported latency.
    search_interface = BiVectorSearchInterface(
        backend=search_backend, query2vec="sentence-transformers/all-MiniLM-L6-v2"
    )

    # Time only the actual search. For a per-function breakdown, wrap this
    # call in ``cProfile.Profile()`` and inspect the output with ``pstats``.
    started = perf_counter()
    search_interface.query_search(
        query="leukemia",
        limit=500,
        with_payload=True,
        with_vectors=False,
        p=1.0,
        q=1.0,
        distance=False,  # QdrantBackend returns similarity as the score, not distance
    )
    elapsed = perf_counter() - started
    print(elapsed)


if __name__ == "__main__":
    search_test()