
Feature: remove vector points when unassociating datasource from collection #322

Closed
47 changes: 45 additions & 2 deletions backend/modules/vector_db/base.py
@@ -1,13 +1,16 @@
from abc import ABC, abstractmethod
from typing import List
from typing import Generator, List, Optional

from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings
from langchain.schema.vectorstore import VectorStore

from backend.constants import DEFAULT_BATCH_SIZE_FOR_VECTOR_STORE
from backend.logger import logger
from backend.types import DataPointVector

MAX_SCROLL_LIMIT = int(1e6)


class BaseVectorDB(ABC):
@abstractmethod
@@ -61,6 +64,17 @@ def get_vector_client(self):
raise NotImplementedError()

@abstractmethod
def yield_data_point_vector_batches(
self,
collection_name: str,
data_source_fqn: str,
batch_size: int = DEFAULT_BATCH_SIZE_FOR_VECTOR_STORE,
) -> Generator[List[DataPointVector], None, Optional[List[DataPointVector]]]:
"""
Yield batches of data point vectors for the given data source in the collection
"""
raise NotImplementedError()

def list_data_point_vectors(
self,
collection_name: str,
@@ -70,7 +84,18 @@ def list_data_point_vectors(
"""
List all data point vectors for the given data source, up to MAX_SCROLL_LIMIT
"""
raise NotImplementedError()
logger.debug(f"Listing all data point vectors for collection {collection_name}")
data_point_vectors = []
for batch in self.yield_data_point_vector_batches(
collection_name, data_source_fqn, batch_size
):
data_point_vectors.extend(batch)
if len(data_point_vectors) >= MAX_SCROLL_LIMIT:
break
logger.debug(
f"Listing {len(data_point_vectors)} data point vectors for collection {collection_name}"
)
return data_point_vectors

@abstractmethod
def delete_data_point_vectors(
@@ -83,3 +108,21 @@
Delete vectors from the collection
"""
raise NotImplementedError()

def delete_data_point_vectors_by_data_source(
self,
collection_name: str,
data_source_fqn: str,
batch_size: int = DEFAULT_BATCH_SIZE_FOR_VECTOR_STORE,
):
"""
Delete vectors from the collection based on data_source_fqn
"""
for data_points_batch in self.yield_data_point_vector_batches(
collection_name=collection_name,
data_source_fqn=data_source_fqn,
batch_size=batch_size,
):
self.delete_data_point_vectors(
collection_name=collection_name, data_point_vectors=data_points_batch
)
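
The base class now drives both `list_data_point_vectors` and `delete_data_point_vectors_by_data_source` through the abstract `yield_data_point_vector_batches` generator, so each backend only has to implement batched scrolling. A minimal in-memory sketch of that contract (the `InMemoryVectorDB` class and its storage layout are hypothetical, purely for illustration; other abstract methods are elided):

```python
from typing import Dict, Generator, List, Tuple

from backend.types import DataPointVector


class InMemoryVectorDB(BaseVectorDB):
    """Hypothetical toy backend illustrating the batch-generator contract."""

    def __init__(self):
        # collection_name -> list of (data_source_fqn, DataPointVector)
        self._store: Dict[str, List[Tuple[str, DataPointVector]]] = {}

    def yield_data_point_vector_batches(
        self, collection_name: str, data_source_fqn: str, batch_size: int = 1000
    ) -> Generator[List[DataPointVector], None, None]:
        matches = [
            v for fqn, v in self._store.get(collection_name, [])
            if fqn == data_source_fqn
        ]
        # Yield fixed-size slices so callers never hold the full result set
        for start in range(0, len(matches), batch_size):
            yield matches[start : start + batch_size]

    def delete_data_point_vectors(
        self, collection_name: str, data_point_vectors: List[DataPointVector]
    ):
        # `data_point_vector_id` is assumed to be the vector's unique id field
        ids = {v.data_point_vector_id for v in data_point_vectors}
        self._store[collection_name] = [
            (fqn, v) for fqn, v in self._store.get(collection_name, [])
            if v.data_point_vector_id not in ids
        ]
```

With this in place, the inherited `delete_data_point_vectors_by_data_source` deletes batch by batch without ever materializing the full result set.
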
21 changes: 6 additions & 15 deletions backend/modules/vector_db/qdrant.py
@@ -1,4 +1,4 @@
from typing import List
from typing import Generator, List
from urllib.parse import urlparse

from langchain.embeddings.base import Embeddings
@@ -8,10 +8,9 @@

from backend.constants import DATA_POINT_FQN_METADATA_KEY, DATA_POINT_HASH_METADATA_KEY
from backend.logger import logger
from backend.modules.vector_db.base import BaseVectorDB
from backend.modules.vector_db.base import MAX_SCROLL_LIMIT, BaseVectorDB
from backend.types import DataPointVector, QdrantClientConfig, VectorDBConfig

MAX_SCROLL_LIMIT = int(1e6)
BATCH_SIZE = 1000


@@ -188,12 +187,9 @@ def get_vector_client(self):
logger.debug(f"[Qdrant] Getting Qdrant client")
return self.qdrant_client

def list_data_point_vectors(
def yield_data_point_vector_batches(
self, collection_name: str, data_source_fqn: str, batch_size: int = BATCH_SIZE
) -> List[DataPointVector]:
logger.debug(
f"[Qdrant] Listing all data point vectors for collection {collection_name}"
)
) -> Generator[List[DataPointVector], None, None]:
stop = False
offset = None
data_point_vectors: List[DataPointVector] = []
@@ -232,17 +228,12 @@ def list_data_point_vectors(
data_point_hash=metadata.get(DATA_POINT_HASH_METADATA_KEY),
)
)
if len(data_point_vectors) > MAX_SCROLL_LIMIT:
stop = True
break
yield data_point_vectors
data_point_vectors = []
if next_offset is None:
stop = True
else:
offset = next_offset
logger.debug(
f"[Qdrant] Listing {len(data_point_vectors)} data point vectors for collection {collection_name}"
)
return data_point_vectors

def delete_data_point_vectors(
self,
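
The collapsed region above contains the scroll loop's setup; for reference, Qdrant's cursor-style pagination generally follows this pattern (a sketch assuming the standard `qdrant_client.scroll` API; the exact payload filter used by the PR is not visible in this diff):

```python
from qdrant_client import QdrantClient, models

client = QdrantClient(url="http://localhost:6333")  # placeholder endpoint

offset = None
while True:
    records, next_offset = client.scroll(
        collection_name="my_collection",
        scroll_filter=models.Filter(
            must=[
                models.FieldCondition(
                    # Assumed payload path; the real key comes from
                    # DATA_POINT_FQN_METADATA_KEY in backend.constants
                    key="metadata._data_point_fqn",
                    match=models.MatchText(text="my-data-source-fqn"),
                )
            ]
        ),
        limit=1000,            # batch_size
        offset=offset,         # None on the first call, cursor afterwards
        with_payload=True,
        with_vectors=False,    # ids + payload suffice for deletion
    )
    # ... build DataPointVector objects from `records` and yield the batch ...
    if next_offset is None:    # a null cursor means the scroll is exhausted
        break
    offset = next_offset
```
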
20 changes: 7 additions & 13 deletions backend/modules/vector_db/singlestore.py
@@ -1,5 +1,5 @@
import json
from typing import Any, Dict, Iterable, List, Optional
from typing import Any, Dict, Generator, Iterable, List, Optional

import singlestoredb as s2
from langchain.docstore.document import Document
@@ -8,10 +8,9 @@

from backend.constants import DATA_POINT_FQN_METADATA_KEY, DATA_POINT_HASH_METADATA_KEY
from backend.logger import logger
from backend.modules.vector_db.base import BaseVectorDB
from backend.modules.vector_db.base import MAX_SCROLL_LIMIT, BaseVectorDB
from backend.types import DataPointVector, VectorDBConfig

MAX_SCROLL_LIMIT = int(1e6)
BATCH_SIZE = 1000


@@ -218,15 +217,12 @@ def get_vector_store(self, collection_name: str, embeddings: Embeddings):
def get_vector_client(self):
return s2.connect(self.host)

def list_data_point_vectors(
def yield_data_point_vector_batches(
self,
collection_name: str,
data_source_fqn: str,
batch_size: int = BATCH_SIZE,
) -> List[DataPointVector]:
logger.debug(
f"[SingleStore] Listing all data point vectors for collection {collection_name}"
)
) -> Generator[List[DataPointVector], None, None]:
data_point_vectors: List[DataPointVector] = []
logger.debug(f"data_source_fqn: {data_source_fqn}")

@@ -236,6 +232,7 @@ def list_data_point_vectors(
curr = conn.cursor()

# Fetch all data point vectors with the same data_source_fqn
# TODO : Scroll in batches
curr.execute(
f"SELECT * FROM {collection_name} WHERE JSON_EXTRACT_JSON(metadata, '{DATA_POINT_FQN_METADATA_KEY}') LIKE '%{data_source_fqn}%' LIMIT {MAX_SCROLL_LIMIT}"
)
@@ -255,16 +252,13 @@
data_point_hash=metadata.get(DATA_POINT_HASH_METADATA_KEY),
)
)
yield data_point_vectors
data_point_vectors = []
except Exception as e:
logger.exception(f"[SingleStore] Failed to list data point vectors: {e}")
finally:
conn.close()

logger.debug(
f"[SingleStore] Listing {len(data_point_vectors)} data point vectors for collection {collection_name}"
)
return data_point_vectors

def delete_data_point_vectors(
self,
collection_name: str,
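
The `# TODO : Scroll in batches` above could be resolved with LIMIT/OFFSET pagination, which would also retire the single large `MAX_SCROLL_LIMIT` query. A sketch under the assumption that the table has an `id` column alongside the `metadata` JSON column used in the existing query, with the FQN bound as a parameter instead of interpolated into the SQL:

```python
import singlestoredb as s2

from backend.constants import DATA_POINT_FQN_METADATA_KEY


def iter_row_batches(host, collection_name, data_source_fqn, batch_size=1000):
    # Hypothetical sketch; the table layout is inferred from the diff's query
    conn = s2.connect(host)
    try:
        cur = conn.cursor()
        offset = 0
        while True:
            cur.execute(
                f"SELECT id, metadata FROM {collection_name} "
                f"WHERE JSON_EXTRACT_JSON(metadata, '{DATA_POINT_FQN_METADATA_KEY}') LIKE %s "
                f"LIMIT {batch_size} OFFSET {offset}",
                (f"%{data_source_fqn}%",),
            )
            rows = cur.fetchall()
            if not rows:
                break
            yield rows
            offset += batch_size
    finally:
        conn.close()
```

One caveat with this approach: if rows are deleted between batches, OFFSET will silently skip survivors; keyset pagination (`WHERE id > last_seen_id ORDER BY id`) avoids that.
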
6 changes: 3 additions & 3 deletions backend/modules/vector_db/weaviate.py
@@ -1,4 +1,4 @@
from typing import Any, Dict, List
from typing import Any, Dict, Generator, List, Optional

import weaviate
from langchain.embeddings.base import Embeddings
@@ -131,12 +131,12 @@ def delete_documents(self, collection_name: str, document_ids: List[str]):
def get_vector_client(self):
return self.weaviate_client

def list_data_point_vectors(
def yield_data_point_vector_batches(
self,
collection_name: str,
data_source_fqn: str,
batch_size: int = 1000,
) -> List[DataPointVector]:
) -> Generator[List[DataPointVector], None, Optional[List[DataPointVector]]]:
pass

def delete_data_point_vectors(
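
The Weaviate implementation is left as a `pass` stub in this PR. One possible shape, assuming the v3 `weaviate-client` GraphQL API with offset pagination and a `Like` filter on the stored FQN property (the property names and capitalized class-name convention here are assumptions, not part of the PR):

```python
from typing import Generator, List

import weaviate

from backend.types import DataPointVector


def yield_data_point_vector_batches(
    client: weaviate.Client,
    collection_name: str,
    data_source_fqn: str,
    batch_size: int = 1000,
) -> Generator[List[DataPointVector], None, None]:
    class_name = collection_name.capitalize()  # assumed naming convention
    offset = 0
    while True:
        result = (
            client.query.get(class_name, ["_data_point_fqn", "_data_point_hash"])
            .with_where({
                "path": ["_data_point_fqn"],
                "operator": "Like",
                "valueText": f"*{data_source_fqn}*",
            })
            .with_additional(["id"])
            .with_limit(batch_size)
            .with_offset(offset)
            .do()
        )
        objects = result.get("data", {}).get("Get", {}).get(class_name, [])
        if not objects:
            break
        yield [
            DataPointVector(
                data_point_vector_id=obj["_additional"]["id"],
                data_point_fqn=obj["_data_point_fqn"],
                data_point_hash=obj["_data_point_hash"],
            )
            for obj in objects
        ]
        offset += batch_size
```

Note that Weaviate caps offset pagination at `QUERY_MAXIMUM_RESULTS`; the cursor (`after`) API avoids that cap but cannot be combined with a `where` filter, so it would require client-side filtering instead.
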
10 changes: 10 additions & 0 deletions backend/server/routers/collection.py
@@ -16,6 +16,7 @@
ListDataIngestionRunsDto,
UnassociateDataSourceWithCollectionDto,
)
from backend.utils import run_in_executor

router = APIRouter(prefix="/v1/collections", tags=["collections"])

@@ -118,10 +119,19 @@ async def unassociate_data_source_from_collection(
):
"""Remove a data source to the collection"""
metadata_store_client = await get_client()
# Remove the association between datasource and collection
collection = await metadata_store_client.aunassociate_data_source_with_collection(
collection_name=request.collection_name,
data_source_fqn=request.data_source_fqn,
)
# If any vector points in the collection came from the now-unassociated
# data source, remove them from the vector DB in a worker thread so the
# blocking deletion doesn't stall the event loop
await run_in_executor(
executor=None,
func=VECTOR_STORE_CLIENT.delete_data_point_vectors_by_data_source,
collection_name=request.collection_name,
data_source_fqn=request.data_source_fqn,
)
return JSONResponse(content={"collection": collection.model_dump()})
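
`run_in_executor` comes from `backend.utils` and is not shown in this diff; it lets the blocking, synchronous vector-store deletion run on a thread pool instead of stalling the event loop. A minimal sketch of what such a helper typically looks like (an assumption, not the actual `backend.utils` implementation):

```python
import asyncio
import functools
from concurrent.futures import Executor
from typing import Callable, Optional, TypeVar

T = TypeVar("T")


async def run_in_executor(
    executor: Optional[Executor],
    func: Callable[..., T],
    *args,
    **kwargs,
) -> T:
    # loop.run_in_executor forwards positional args only, so keyword
    # args (collection_name=..., data_source_fqn=...) must be bound
    # with functools.partial first
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(
        executor,  # None -> asyncio's default ThreadPoolExecutor
        functools.partial(func, *args, **kwargs),
    )
```
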

