From 75f0edbe023dcb88c72b1151ac02331f2d838d43 Mon Sep 17 00:00:00 2001 From: Anush008 Date: Tue, 24 Dec 2024 22:28:24 +0530 Subject: [PATCH] refactor Signed-off-by: Anush008 --- ragna/source_storages/_qdrant.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/ragna/source_storages/_qdrant.py b/ragna/source_storages/_qdrant.py index aa2d2add..804ba6da 100644 --- a/ragna/source_storages/_qdrant.py +++ b/ragna/source_storages/_qdrant.py @@ -1,5 +1,6 @@ from __future__ import annotations +import os import uuid from collections import defaultdict from typing import TYPE_CHECKING, Any, Optional, cast @@ -43,7 +44,14 @@ def __init__(self) -> None: from qdrant_client import QdrantClient - self._client = QdrantClient(path=ragna.local_root() / "qdrant") + url = os.getenv("QDRANT_URL") + api_key = os.getenv("QDRANT_API_KEY") + path = path = ragna.local_root() / "qdrant" + + # Cannot pass both url and path + self._client = ( + QdrantClient(url=url, api_key=api_key) if url else QdrantClient(path=path) + ) def list_corpuses(self) -> list[str]: return [c.name for c in self._client.get_collections().collections] @@ -216,7 +224,11 @@ def retrieve( query_vector = self._embedding_function([prompt])[0] - search_filter = self._translate_metadata_filter(metadata_filter) + search_filter = ( + self._translate_metadata_filter(metadata_filter) + if metadata_filter + else None + ) if isinstance(search_filter, models.FieldCondition): search_filter = models.Filter(must=[search_filter])