Skip to content

Commit

Permalink
test fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
azat-manukyan committed Nov 14, 2024
1 parent 262fa81 commit d8eeb30
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,10 @@
import deeplake

if deeplake.__version__.startswith("3."):
DEEPLAKE_V4 = False
from deeplake.core.vectorstore import VectorStore
else:
DEEPLAKE_V4 = True

class VectorStore:
def __init__(
Expand Down Expand Up @@ -119,19 +121,19 @@ def search_tql(

def search(
self,
embedding: Union[str, List[float]],
k: int,
distance_metric: str,
filter: Optional[Dict[str, Any]],
exec_option: Optional[str],
deep_memory: Optional[bool],
return_tensors: Optional[List[str]],
embedding: Union[None, str, List[float]] = None,
k: Optional[int] = None,
distance_metric: Optional[str] = None,
filter: Optional[Dict[str, Any]] = None,
exec_option: Optional[str] = None,
deep_memory: Optional[bool] = None,
return_tensors: Optional[List[str]] = None,
query: Optional[str] = None,
) -> Dict[str, Any]:
if query is None and embedding is None:
if query is None and embedding is None and filter is None:
raise ValueError(
"Both `embedding` and `query` were specified."
" Please specify either one or the other."
"all, `filter` , `embedding` and `query` were not specified."
" Please specify at least one."
)
if query is not None:
return self.search_tql(query, exec_option)
Expand All @@ -142,7 +144,11 @@ def search(
"embedding_function is required when embedding is a string"
)
embedding = self.embedding_function.embed_documents([embedding])[0]
emb_str = ", ".join([str(e) for e in embedding])
emb_str = (
None
if embedding is None
else ", ".join([str(e) for e in embedding])
)

column_list = " * " if not return_tensors else ", ".join(return_tensors)

Expand All @@ -151,18 +157,41 @@ def search(
if metric == "cosine_similarity":
order_by = " DESC "
dp = f"(embedding, ARRAY[{emb_str}])"
column_list += (
f", {self.__metric_to_function(distance_metric)}{dp} as score"
)
if emb_str is not None:
column_list += (
f", {self.__metric_to_function(distance_metric)}{dp} as score"
)
mf = self.__metric_to_function(distance_metric)
query = f"SELECT {column_list} ORDER BY {mf}{dp} {order_by} LIMIT {k}"

order_by_clause = (
"" if emb_str is None else f"ORDER BY {mf}{dp} {order_by}"
)
where_clause = self.__generate_where_clause(filter)
limit_clause = "" if k is None else f"LIMIT {k}"

query = f"SELECT {column_list} {where_clause} {order_by_clause} {limit_clause}"
print(">>>>>>>>>>>>>", query)
view = self.ds.query(query)
return self.__view_to_docs(view)

def delete(
self, ids: List[str], filter: Dict[str, Any], delete_all: bool
self,
ids: List[str],
filter: Optional[Dict[str, Any]] = None,
delete_all: Optional[bool] = None,
) -> None:
raise NotImplementedError
if ids is not None:
print(
f"SELECT * from (select *,ROW_NUMBER() as r_id) where id IN ({str(ids)[1:-1]})"
)
view = self.ds.query(
f"SELECT * from (select *,ROW_NUMBER() as r_id) where id IN ({str(ids)[1:-1]})"
)
dlist = view["r_id"][:].tolist()
dlist.reverse()
print(dlist)
for _id in dlist:
self.ds.delete(int(_id))

def dataset(self) -> Any:
return self.ds
Expand Down Expand Up @@ -195,6 +224,17 @@ def __metric_to_function(self, metric: str) -> str:
"['cosine', 'cosine_similarity', 'l2', 'l2_norm']"
)

def __generate_where_clause(self, filter: Optional[Dict[str, Any]]) -> str:
    """Build the ``WHERE`` clause of a query from a metadata filter.

    Each ``key: value`` pair becomes one condition and all conditions are
    joined with ``AND``:

    * a ``list`` value becomes ``key IN (item1, item2, ...)``, with the
      items rendered through their ``repr`` (strings end up quoted);
    * any other value becomes ``key == value``.

    Args:
        filter: Mapping of column name to the required value(s), or
            ``None`` when no filtering is requested.

    Returns:
        The clause starting with ``"WHERE "``, or an empty string when
        ``filter`` is ``None`` or empty.
    """
    # An empty dict must yield no clause at all; slicing the bare "WHERE "
    # prefix would otherwise leave a stray "W" in the final query.
    if not filter:
        return ""

    conditions = []
    for key, value in filter.items():
        if isinstance(value, list):
            # str(list)[1:-1] strips the surrounding brackets and keeps the
            # comma-separated, repr-formatted items.
            conditions.append(f"{key} IN ({str(value)[1:-1]})")
        else:
            # NOTE(review): scalar values are interpolated verbatim, so a
            # plain string is emitted unquoted (unlike list items) --
            # confirm whether callers are expected to pre-quote strings.
            conditions.append(f"{key} == {value}")
    return "WHERE " + " AND ".join(conditions)

def __create_dataset(self, emb_size=None) -> None:
if emb_size is None:
if self.embedding_function is None:
Expand Down Expand Up @@ -402,7 +442,11 @@ def delete_nodes(

def clear(self) -> None:
"""Clear the vector store."""
self.vectorstore.delete(filter=lambda x: True)
if DEEPLAKE_V4:
for i in range(len(self.vectorstore.ds) - 1, -1, -1):
self.vectorstore.ds.delete(i)
else:
self.vectorstore.delete(filter=lambda x: True)

def add(self, nodes: List[BaseNode], **add_kwargs: Any) -> List[str]:
"""Add the embeddings and their nodes into DeepLake.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import pytest
import jwt # noqa

from llama_index.core import Document
Expand Down Expand Up @@ -125,6 +124,3 @@ def test_e2e():
assert [x.text for x in vs.get_nodes()] == ["Doc 2", "Doc 4"]

vs.clear()
with pytest.raises(ValueError) as e:
vs.get_nodes()
assert str(e.value) == "specified dataset is empty"

0 comments on commit d8eeb30

Please sign in to comment.