From f30f1ade91a4cf8d830390f7c0e867767190f920 Mon Sep 17 00:00:00 2001 From: Max Jakob Date: Wed, 17 Apr 2024 11:01:14 +0200 Subject: [PATCH 01/36] ElasticsearchStore --- elasticsearch/store/__init__.py | 0 elasticsearch/store/_utilities.py | 153 +++ elasticsearch/store/embedding_service.py | 132 +++ elasticsearch/store/store.py | 399 ++++++++ elasticsearch/store/strategies.py | 563 +++++++++++ .../test_store_integration/__init__.py | 0 .../test_store_integration/_test_utilities.py | 130 +++ .../test_store_integration/docker-compose.yml | 34 + .../test_embedding_service.py | 47 + .../test_store_integration/test_store.py | 939 ++++++++++++++++++ 10 files changed, 2397 insertions(+) create mode 100644 elasticsearch/store/__init__.py create mode 100644 elasticsearch/store/_utilities.py create mode 100644 elasticsearch/store/embedding_service.py create mode 100644 elasticsearch/store/store.py create mode 100644 elasticsearch/store/strategies.py create mode 100644 test_elasticsearch/test_store_integration/__init__.py create mode 100644 test_elasticsearch/test_store_integration/_test_utilities.py create mode 100644 test_elasticsearch/test_store_integration/docker-compose.yml create mode 100644 test_elasticsearch/test_store_integration/test_embedding_service.py create mode 100644 test_elasticsearch/test_store_integration/test_store.py diff --git a/elasticsearch/store/__init__.py b/elasticsearch/store/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/elasticsearch/store/_utilities.py b/elasticsearch/store/_utilities.py new file mode 100644 index 000000000..a3b6c36c7 --- /dev/null +++ b/elasticsearch/store/_utilities.py @@ -0,0 +1,153 @@ +from typing import Any, Dict, List, Optional, Union + +import numpy as np +from elasticsearch import ( + AsyncElasticsearch, + BadRequestError, + ConflictError, + NotFoundError, +) + +Matrix = Union[List[List[float]], List[np.ndarray], np.ndarray] + + +def create_elasticsearch_client( + agent_header: str, + client: 
Optional[AsyncElasticsearch] = None, + url: Optional[str] = None, + cloud_id: Optional[str] = None, + api_key: Optional[str] = None, + username: Optional[str] = None, + password: Optional[str] = None, + client_params: Optional[Dict[str, Any]] = None, +) -> AsyncElasticsearch: + if not client: + if url and cloud_id: + raise ValueError( + "Both es_url and cloud_id are defined. Please provide only one." + ) + + connection_params: Dict[str, Any] = {} + + if url: + connection_params["hosts"] = [url] + elif cloud_id: + connection_params["cloud_id"] = cloud_id + else: + raise ValueError("Please provide either elasticsearch_url or cloud_id.") + + if api_key: + connection_params["api_key"] = api_key + elif username and password: + connection_params["basic_auth"] = (username, password) + + if client_params is not None: + connection_params.update(client_params) + + client = AsyncElasticsearch(**connection_params) + + if not isinstance(client, AsyncElasticsearch): + raise TypeError("Elasticsearch client must be AsyncElasticsearch client") + + # Add integration-specific usage header for tracking usage in Elastic Cloud. + # client.options preserces existing (non-user-agent) headers. + client = client.options(headers={"User-Agent": agent_header}) + + return client + + +async def model_must_be_deployed_async( + client: AsyncElasticsearch, model_id: str +) -> None: + try: + dummy = {"x": "y"} + await client.ml.infer_trained_model(model_id=model_id, docs=[dummy]) + except NotFoundError as err: + raise err + except ConflictError as err: + raise NotFoundError( + f"model '{model_id}' not found, please deploy it first", + meta=err.meta, + body=err.body, + ) from err + except BadRequestError: + # This error is expected because we do not know the expected document + # shape and just use a dummy doc above. 
+ pass + + return None + + +async def model_is_deployed_async(es_client: AsyncElasticsearch, model_id: str) -> bool: + try: + await model_must_be_deployed_async(es_client, model_id) + return True + except NotFoundError: + return False + + +def maximal_marginal_relevance( + query_embedding: list, + embedding_list: list, + lambda_mult: float = 0.5, + k: int = 4, +) -> List[int]: + """Calculate maximal marginal relevance.""" + query_embedding_arr = np.array(query_embedding) + + if min(k, len(embedding_list)) <= 0: + return [] + if query_embedding_arr.ndim == 1: + query_embedding_arr = np.expand_dims(query_embedding_arr, axis=0) + similarity_to_query = _cosine_similarity(query_embedding_arr, embedding_list)[0] + most_similar = int(np.argmax(similarity_to_query)) + idxs = [most_similar] + selected = np.array([embedding_list[most_similar]]) + while len(idxs) < min(k, len(embedding_list)): + best_score = -np.inf + idx_to_add = -1 + similarity_to_selected = _cosine_similarity(embedding_list, selected) + for i, query_score in enumerate(similarity_to_query): + if i in idxs: + continue + redundant_score = max(similarity_to_selected[i]) + equation_score = ( + lambda_mult * query_score - (1 - lambda_mult) * redundant_score + ) + if equation_score > best_score: + best_score = equation_score + idx_to_add = i + idxs.append(idx_to_add) + selected = np.append(selected, [embedding_list[idx_to_add]], axis=0) + return idxs + + +def _cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray: + """Row-wise cosine similarity between two equal-width matrices.""" + if len(X) == 0 or len(Y) == 0: + return np.array([]) + + X = np.array(X) + Y = np.array(Y) + if X.shape[1] != Y.shape[1]: + raise ValueError( + f"Number of columns in X and Y must be the same. X has shape {X.shape} " + f"and Y has shape {Y.shape}." 
+ ) + try: + import simsimd as simd # type: ignore + + X = np.array(X, dtype=np.float32) + Y = np.array(Y, dtype=np.float32) + Z = 1 - simd.cdist(X, Y, metric="cosine") + if isinstance(Z, float): + return np.array([Z]) + return np.array(Z) + except ImportError: + X_norm = np.linalg.norm(X, axis=1) + Y_norm = np.linalg.norm(Y, axis=1) + # Ignore divide by zero errors run time warnings as those are handled below. + with np.errstate(divide="ignore", invalid="ignore"): + similarity = np.dot(X, Y.T) / np.outer(X_norm, Y_norm) + similarity[np.isnan(similarity) | np.isinf(similarity)] = 0.0 + return similarity diff --git a/elasticsearch/store/embedding_service.py b/elasticsearch/store/embedding_service.py new file mode 100644 index 000000000..e6c5a470c --- /dev/null +++ b/elasticsearch/store/embedding_service.py @@ -0,0 +1,132 @@ +import asyncio +from abc import ABC, abstractmethod +from typing import List, Optional + +import nest_asyncio # type: ignore +from elasticsearch import AsyncElasticsearch + +from elasticsearch.store._utilities import create_elasticsearch_client + + +class EmbeddingService(ABC): + @abstractmethod + async def embed_documents_async(self, texts: List[str]) -> List[List[float]]: + """Generate embeddings for a list of documents. + + Args: + texts: A list of document strings to generate embeddings for. + + Returns: + A list of embeddings, one for each document in the input. + """ + + @abstractmethod + def embed_documents(self, texts: List[str]) -> List[List[float]]: + """Generate embeddings for a list of documents. + + Args: + texts: A list of document strings to generate embeddings for. + + Returns: + A list of embeddings, one for each document in the input. + """ + + @abstractmethod + async def embed_query_async(self, query: str) -> List[float]: + """Generate an embedding for a single query text. + + Args: + text: The query text to generate an embedding for. + + Returns: + The embedding for the input query text. 
+ """ + + @abstractmethod + def embed_query(self, query: str) -> List[float]: + """Generate an embedding for a single query text. + + Args: + text: The query text to generate an embedding for. + + Returns: + The embedding for the input query text. + """ + + +class ElasticsearchEmbeddings(EmbeddingService): + """Elasticsearch as a service for embedding model inference. + + You need to have an embedding model downloaded and deployed in Elasticsearch: + - https://www.elastic.co/guide/en/elasticsearch/reference/current/infer-trained-model.html + - https://www.elastic.co/guide/en/machine-learning/current/ml-nlp-deploy-models.html + """ # noqa: E501 + + def __init__( + self, + agent_header: str, + model_id: str, + input_field: str = "text_field", + num_dimensions: Optional[int] = None, + # Connection params + es_client: Optional[AsyncElasticsearch] = None, + es_url: Optional[str] = None, + es_cloud_id: Optional[str] = None, + es_api_key: Optional[str] = None, + es_user: Optional[str] = None, + es_password: Optional[str] = None, + ): + """ + Args: + agent_header: user agent header specific to the 3rd party integration. + Used for usage tracking in Elastic Cloud. + model_id: The model_id of the model deployed in the Elasticsearch cluster. + input_field: The name of the key for the input text field in the + document. Defaults to 'text_field'. + num_dimensions: The number of embedding dimensions. If None, then dimensions + will be infer from an example inference call. + es_client: Elasticsearch client connection. Alternatively specify the + Elasticsearch connection with the other es_* parameters. 
+ """ + nest_asyncio.apply() + + client = create_elasticsearch_client( + agent_header=agent_header, + client=es_client, + url=es_url, + cloud_id=es_cloud_id, + api_key=es_api_key, + username=es_user, + password=es_password, + ) + + self.client = client.ml + self.model_id = model_id + self.input_field = input_field + self._num_dimensions = num_dimensions + + async def embed_documents_async(self, texts: List[str]) -> List[List[float]]: + result = await self._embedding_func_async(texts) + return result + + def embed_documents(self, texts: List[str]) -> List[List[float]]: + return asyncio.get_event_loop().run_until_complete( + self.embed_documents_async(texts) + ) + + async def embed_query_async(self, text: str) -> List[float]: + result = await self._embedding_func_async([text]) + return result[0] + + def embed_query(self, query: str) -> List[float]: + return asyncio.get_event_loop().run_until_complete( + self.embed_query_async(query) + ) + + async def _embedding_func_async(self, texts: List[str]) -> List[List[float]]: + response = await self.client.infer_trained_model( + model_id=self.model_id, docs=[{self.input_field: text} for text in texts] + ) + + embeddings = [doc["predicted_value"] for doc in response["inference_results"]] + return embeddings diff --git a/elasticsearch/store/store.py b/elasticsearch/store/store.py new file mode 100644 index 000000000..da61ee4f4 --- /dev/null +++ b/elasticsearch/store/store.py @@ -0,0 +1,399 @@ +import asyncio +import logging +import uuid +from typing import Any, Callable, Dict, List, Optional + +import nest_asyncio # type: ignore +from elasticsearch import AsyncElasticsearch +from elasticsearch.helpers import BulkIndexError, async_bulk + +from elasticsearch.store._utilities import ( + create_elasticsearch_client, + maximal_marginal_relevance, +) +from elasticsearch.store.embedding_service import EmbeddingService +from elasticsearch.store.strategies import RetrievalStrategy + +logger = logging.getLogger(__name__) + + +class 
ElasticsearchStore: + """ElasticsearchStore is a higher-level abstraction of indexing and search. + Users can pick from available retrieval strategies. + + Documents are flat text documents. Depending on the strategy, vector embeddings are + - created by the user beforehand + - created by this class in Python + - created in-stack by inference pipelines. + """ + + def __init__( + self, + agent_header: str, + index_name: str, + retrieval_strategy: RetrievalStrategy, + text_field: str = "text_field", + vector_field: str = "vector_field", + metadata_mapping: Optional[dict[str, str]] = None, + # Connection params + es_client: Optional[AsyncElasticsearch] = None, + es_url: Optional[str] = None, + es_cloud_id: Optional[str] = None, + es_api_key: Optional[str] = None, + es_user: Optional[str] = None, + es_password: Optional[str] = None, + es_params: Optional[Dict[str, Any]] = None, + ) -> None: + """ + Args: + agent_header: user agent header specific to the 3rd party integration. + Used for usage tracking in Elastic Cloud. + index_name: The name of the index to query. + retrieval_strategy: how to index and search the data. See the strategies + module for availble strategies. + text_field: Name of the field with the textual data. + vector_field: For strategies that perform embedding inference in Python, + the embedding vector goes in this field. + es_client: Elasticsearch client connection. Alternatively specify the + Elasticsearch connection with the other es_* parameters. 
+ """ + nest_asyncio.apply() + + self.es_client = create_elasticsearch_client( + agent_header=agent_header, + client=es_client, + url=es_url, + cloud_id=es_cloud_id, + api_key=es_api_key, + username=es_user, + password=es_password, + client_params=es_params, + ) + + if hasattr(retrieval_strategy, "text_field"): + retrieval_strategy.text_field = text_field + if hasattr(retrieval_strategy, "vector_field"): + retrieval_strategy.vector_field = vector_field + + self.index_name = index_name + self.retrieval_strategy = retrieval_strategy + self.text_field = text_field + self.vector_field = vector_field + self.metadata_mapping = metadata_mapping + + def close(self): + return asyncio.get_event_loop().run_until_complete(self.es_client.close()) + + async def add_texts_async( + self, + texts: List[str], + metadatas: Optional[List[Dict[str, Any]]] = None, + vectors: Optional[List[List[float]]] = None, + ids: Optional[List[str]] = None, + refresh_indices: bool = True, + create_index_if_not_exists: bool = True, + bulk_kwargs: Optional[Dict[str, Any]] = None, + ) -> List[str]: + """Add documents to the Elasticsearch index. + + Args: + texts: List of text documents. + metadata: Optional list of document metadata. Must be of same length as + texts. + vectors: Optional list of embedding vectors. Must be of same length as + texts. + ids: Optional list of ID strings. Must be of same length as texts. + refresh_indices: Whether to refresh the index after deleting documents. + Defaults to True. + create_index_if_not_exists: Whether to create the index if it does not + exist. Defaults to True. + bulk_kwargs: Arguments to pass to the bulk function when indexing + (for example chunk_size). + + Returns: + List of IDs of the created documents, either echoing the provided one + or returning newly created ones. 
async def delete_async(
    self,
    ids: Optional[List[str]] = None,
    query: Optional[Dict[str, Any]] = None,
    refresh_indices: bool = True,
    **delete_kwargs,
) -> bool:
    """Delete documents from the Elasticsearch index.

    Exactly one of ``ids`` and ``query`` has to be given.

    Args:
        ids: IDs of the documents to delete. An empty list issues no delete
            operations.
        query: Elasticsearch query; matching documents are removed with
            ``delete_by_query``.
        refresh_indices: Whether to refresh the index after deleting documents.
            Defaults to True.
        **delete_kwargs: Extra keyword arguments forwarded to ``async_bulk``
            (deletion by ID) or to ``delete_by_query`` (deletion by query).

    Returns:
        True if the deletion request completed without raising.

    Raises:
        ValueError: If both or neither of ``ids`` and ``query`` are given.
        BulkIndexError: If the bulk deletion fails; logged and re-raised.
    """
    if ids is not None and query is not None:
        raise ValueError("specify either ids or query, not both")
    if ids is None and query is None:
        raise ValueError("either specify ids or query")

    try:
        # Test against None (not truthiness) so that an empty ID list stays on
        # the bulk path instead of falling through to delete_by_query(query=None).
        if ids is not None:
            body = [
                {"_op_type": "delete", "_index": self.index_name, "_id": _id}
                for _id in ids
            ]
            await async_bulk(
                self.es_client,
                body,
                refresh=refresh_indices,
                ignore_status=404,
                **delete_kwargs,
            )
            logger.debug(f"Deleted {len(body)} texts from index")
        else:
            await self.es_client.delete_by_query(
                index=self.index_name,
                query=query,
                refresh=refresh_indices,
                **delete_kwargs,
            )
    except BulkIndexError as e:
        logger.error(f"Error deleting texts: {e}")
        # Bulk error items are keyed by their operation type; these are delete
        # operations, so do not look the details up under "index".
        first_item = e.errors[0] if e.errors else {}
        details = next(iter(first_item.values()), {}) if first_item else {}
        first_error = details.get("error", {}) if isinstance(details, dict) else {}
        reason = (
            first_error.get("reason")
            if isinstance(first_error, dict)
            else first_error
        )
        logger.error(f"First error reason: {reason}")
        raise

    return True
+ custom_query: Function to modify the Elasticsearch query body before it is + sent to Elasticsearch. + + Returns: + List of document hits. Includes _index, _id, _score and _source. + """ + if fields is None: + fields = [] + if "metadata" not in fields: + fields.append("metadata") + if self.text_field not in fields: + fields.append(self.text_field) + + query_body = self.retrieval_strategy.es_query( + query=query, + k=k, + num_candidates=num_candidates, + filter=filter or [], + query_vector=query_vector, + ) + logger.debug(f"Query body: {query_body}") + + if custom_query is not None: + query_body = custom_query(query_body, query) + logger.debug(f"Calling custom_query, Query body now: {query_body}") + + response = await self.es_client.search( + index=self.index_name, + **query_body, + size=k, + source=True, + source_includes=fields, + ) + + return response["hits"]["hits"] + + def search( + self, + query: Optional[str], + query_vector: Optional[List[float]] = None, + k: int = 4, + num_candidates: int = 50, + fields: Optional[List[str]] = None, + filter: Optional[List[dict]] = None, + custom_query: Optional[Callable[[Dict, Optional[str]], Dict]] = None, + ) -> List[Dict[str, Any]]: + return asyncio.get_event_loop().run_until_complete( + self.search_async( + query=query, + query_vector=query_vector, + k=k, + num_candidates=num_candidates, + fields=fields, + filter=filter, + custom_query=custom_query, + ) + ) + + async def _create_index_if_not_exists(self) -> None: + exists = await self.es_client.indices.exists(index=self.index_name) + if exists.meta.status == 200: + logger.debug(f"Index {self.index_name} already exists. 
def max_marginal_relevance_search(
    self,
    embedding_service: EmbeddingService,
    query: str,
    vector_field: str,
    k: int = 4,
    num_candidates: int = 20,
    lambda_mult: float = 0.5,
    fields: Optional[List[str]] = None,
    custom_query: Optional[Callable[[Dict, Optional[str]], Dict]] = None,
) -> List[Dict]:
    """Return hits selected using maximal marginal relevance (MMR).

    Maximal marginal relevance optimizes for similarity to the query AND
    diversity among the selected documents.

    Args:
        embedding_service: Service used to embed ``query``.
        query: Text to look up documents similar to.
        vector_field: Name of the document field that holds the embedding; it
            is read from each hit's ``_source``.
        k: Number of hits to return. Defaults to 4.
        num_candidates: Number of hits to fetch and pass to the MMR algorithm.
        lambda_mult: Number between 0 and 1 that determines the degree of
            diversity among the results, with 0 corresponding to maximum
            diversity and 1 to minimum diversity. Defaults to 0.5.
        fields: Other fields to get from the Elasticsearch source. The list
            passed in is not modified.
        custom_query: Function to modify the Elasticsearch query body before it
            is sent to Elasticsearch.

    Returns:
        The selected hits, in the order chosen by the MMR algorithm. The vector
        field is removed from each hit's ``_source`` unless it was explicitly
        requested through ``fields``.
    """
    # Work on a copy: the vector field is added for internal use only and must
    # not leak into the caller's list.
    requested_fields = list(fields) if fields is not None else []
    strip_vector_field = vector_field not in requested_fields
    if strip_vector_field:
        requested_fields.append(vector_field)

    # Embed the query
    query_embedding = embedding_service.embed_query(query)

    # Fetch the initial documents
    got_hits = self.search(
        query=None,
        query_vector=query_embedding,
        k=num_candidates,
        fields=requested_fields,
        custom_query=custom_query,
    )

    # Get the embeddings for the fetched documents
    got_embeddings = [hit["_source"][vector_field] for hit in got_hits]

    # Select documents using maximal marginal relevance
    selected_indices = maximal_marginal_relevance(
        query_embedding, got_embeddings, lambda_mult=lambda_mult, k=k
    )
    selected_hits = [got_hits[i] for i in selected_indices]

    if strip_vector_field:
        for hit in selected_hits:
            del hit["_source"][vector_field]

    return selected_hits
body for the given parameters. + The store will execute the query. + + Args: + query: The text query. Can be None if query_vector is given. + k: The total number of results to retrieve. + num_candidates: The number of results to fetch initially in knn search. + filter: List of filter clauses to apply to the query. + query_vector: The query vector. Can be None if a query string is given. + + Returns: + Dict: The Elasticsearch query body. + """ + + @abstractmethod + async def create_index( + self, + client: AsyncElasticsearch, + index_name: str, + metadata_mapping: Optional[dict[str, str]], + ) -> None: + """ + Create the required index and do necessary preliminary work, like + creating inference pipelines or checking if a required model was deployed. + + Args: + client: Elasticsearch client connection. + index_name: The name of the Elasticsearch index to create. + metadata_mapping: Flat dictionary with field and field type pairs that + describe the schema of the metadata. + """ + + def embed_for_indexing(self, text: str) -> Dict[str, Any]: + """ + If this strategy creates vector embeddings in Python (not in Elasticsearch), + this method is used to apply the inference. + The output is a dictionary with the vector field and the vector embedding. + It is merged in the ElasticserachStore with the rest of the document (text data, + metadata) before indexing. + + Args: + text: Text input that can be used as input for inference. + + Returns: + Dict: field and value pairs that extend the document to be indexed. 
class SparseVector(RetrievalStrategy):
    """Sparse retrieval strategy using the `text_expansion` processor."""

    def __init__(
        self,
        model_id: str = ".elser_model_2",
        text_field: str = "text_field",
        vector_field: str = "vector_field",
    ):
        """
        Args:
            model_id: ID of the deployed model used for inference at index and
                at query time.
            text_field: Name of the document field holding the text that the
                ingest pipeline feeds to the model.
            vector_field: Name of the field the ingest pipeline writes the
                inference result to.
        """
        self.model_id = model_id
        self.text_field = text_field
        self.vector_field = vector_field
        self._tokens_field = "tokens"

    def es_query(
        self,
        query: Optional[str],
        k: int,
        num_candidates: int,
        filter: List[dict] = [],
        query_vector: Optional[List[float]] = None,
    ) -> Dict:
        """Build a `text_expansion` query against the tokens sub-field.

        The result size is deliberately not part of the returned body:
        ElasticsearchStore.search_async unpacks the body into the search call
        and passes ``size=k`` itself, so a "size" key here would be a duplicate
        keyword argument.

        Raises:
            ValueError: If a query_vector is given or no query string is given.
        """
        if query_vector:
            raise ValueError(
                "Cannot do sparse retrieval with a query_vector. "
                "Inference is currently always applied in Elasticsearch."
            )
        if query is None:
            raise ValueError("please specify a query string")

        return {
            "query": {
                "bool": {
                    "must": [
                        {
                            "text_expansion": {
                                f"{self.vector_field}.{self._tokens_field}": {
                                    "model_id": self.model_id,
                                    "model_text": query,
                                }
                            }
                        }
                    ],
                    "filter": filter,
                }
            },
        }

    async def create_index(
        self,
        client: AsyncElasticsearch,
        index_name: str,
        metadata_mapping: Optional[dict[str, str]],
    ) -> None:
        """Create the ingest pipeline and an index that uses it by default.

        Args:
            client: Elasticsearch client connection.
            index_name: The name of the Elasticsearch index to create.
            metadata_mapping: Flat dictionary with field and field type pairs
                that describe the schema of the metadata.
        """
        pipeline_name = f"{self.model_id}_sparse_embedding"

        if self.model_id:
            await model_must_be_deployed_async(client, self.model_id)

            # Create a pipeline for the model
            await client.ingest.put_pipeline(
                id=pipeline_name,
                description="Embedding pipeline for ElasticsearchStore",
                processors=[
                    {
                        "inference": {
                            "model_id": self.model_id,
                            "target_field": self.vector_field,
                            "field_map": {self.text_field: "text_field"},
                            "inference_config": {
                                "text_expansion": {
                                    "results_field": self._tokens_field
                                }
                            },
                        }
                    }
                ],
            )

        mappings: Dict[str, Any] = {
            "properties": {
                self.vector_field: {
                    "properties": {self._tokens_field: {"type": "rank_features"}}
                }
            }
        }
        if metadata_mapping:
            mappings["properties"]["metadata"] = {"properties": metadata_mapping}
        settings = {"default_pipeline": pipeline_name}

        await client.indices.create(
            index=index_name, mappings=mappings, settings=settings
        )
def es_query(
    self,
    query: Optional[str],
    k: int,
    num_candidates: int,
    filter: Optional[List[dict]] = None,
    query_vector: Optional[List[float]] = None,
) -> Dict:
    """Build the knn (optionally hybrid) search body.

    The query vector comes from, in order of precedence: the given
    ``query_vector``, the Python-side ``embedding_service``, or a
    ``query_vector_builder`` that lets Elasticsearch embed ``query`` with
    ``model_id``.

    Args:
        query: The text query. Can be None if query_vector is given.
        k: The total number of results to retrieve.
        num_candidates: The number of results to fetch initially in knn search.
        filter: List of filter clauses to apply to the query. Defaults to no
            filters.
        query_vector: The query vector. Can be None if a query string is given.

    Returns:
        Dict: ``{"knn": ...}``, or the hybrid body built by ``_hybrid`` when
        ``hybrid`` is enabled.
    """
    # Copy the clauses: the list ends up inside the returned body, so sharing a
    # default (or the caller's) list would let later edits of the body leak
    # into other calls.
    filters = list(filter) if filter else []

    knn = {
        "filter": filters,
        "field": self.vector_field,
        "k": k,
        "num_candidates": num_candidates,
    }

    if query_vector:
        knn["query_vector"] = query_vector
    elif self.embedding_service:
        knn["query_vector"] = self.embedding_service.embed_query(cast(str, query))
    else:
        # Inference in Elasticsearch. When initializing we make sure to always
        # have a model_id if we don't have an embedding_service.
        knn["query_vector_builder"] = {
            "text_embedding": {
                "model_id": self.model_id,
                "model_text": query,
            }
        }

    if self.hybrid:
        return self._hybrid(query=cast(str, query), knn=knn, filter=filters)

    return {"knn": knn}
class DenseVectorScriptScore(RetrievalStrategy):
    """Exact nearest neighbors retrieval using the `script_score` query."""

    def __init__(
        self,
        embedding_service: EmbeddingService,
        vector_field: str = "vector_field",
        distance: DistanceMetric = DistanceMetric.COSINE,
        num_dimensions: Optional[int] = None,
    ) -> None:
        """
        Args:
            embedding_service: Service used to embed texts and queries in Python.
            vector_field: Name of the field that stores the embedding vector.
            distance: Distance metric used to score documents.
            num_dimensions: Number of embedding dimensions. If not given, it is
                derived from an example embedding when the index is created.
        """
        self.vector_field = vector_field
        self.distance = distance
        self.embedding_service = embedding_service
        self.num_dimensions = num_dimensions

    def es_query(
        self,
        query: Optional[str],
        k: int,
        num_candidates: int,
        filter: List[dict] = [],
        query_vector: Optional[List[float]] = None,
    ) -> Dict:
        """Build a `script_score` query that scores documents by vector similarity.

        Args:
            query: The text query. Embedded with the embedding service when no
                query_vector is given.
            k: Unused by this strategy; the store sets the result size.
            num_candidates: Unused by this strategy.
            filter: List of filter clauses restricting the scored documents.
            query_vector: The query vector. Takes precedence over ``query``.

        Returns:
            Dict: The Elasticsearch query body.

        Raises:
            ValueError: If the distance metric is not supported, or if neither
                a query vector nor an embeddable query string is available.
        """
        if self.distance is DistanceMetric.COSINE:
            similarity_script = (
                f"cosineSimilarity(params.query_vector, '{self.vector_field}') + 1.0"
            )
        elif self.distance is DistanceMetric.EUCLIDEAN_DISTANCE:
            similarity_script = (
                f"1 / (1 + l2norm(params.query_vector, '{self.vector_field}'))"
            )
        elif self.distance is DistanceMetric.DOT_PRODUCT:
            similarity_script = f"""
            double value = dotProduct(params.query_vector, '{self.vector_field}');
            return sigmoid(1, Math.E, -value);
            """
        elif self.distance is DistanceMetric.MAX_INNER_PRODUCT:
            # The result of dotProduct(...) is stored in `value`; the script has
            # to keep using that variable, `dotProduct` itself is not a value.
            similarity_script = f"""
            double value = dotProduct(params.query_vector, '{self.vector_field}');
            if (value < 0) {{
                return 1 / (1 + -1 * value);
            }}
            return value + 1;
            """
        else:
            raise ValueError(f"Similarity {self.distance} not supported.")

        # Score every document unless filter clauses narrow the candidates.
        base_query: Dict = {"match_all": {}}
        if filter:
            base_query = {"bool": {"filter": filter}}

        if not query_vector:
            if not self.embedding_service:
                raise ValueError(
                    "if no embedding_service is given, you need to "
                    "provide a query_vector"
                )
            if not query:
                raise ValueError("either specify a query string or a query_vector")
            query_vector = self.embedding_service.embed_query(query)

        return {
            "query": {
                "script_score": {
                    "query": base_query,
                    "script": {
                        "source": similarity_script,
                        "params": {"query_vector": query_vector},
                    },
                },
            }
        }

    async def create_index(
        self,
        client: AsyncElasticsearch,
        index_name: str,
        metadata_mapping: Optional[dict[str, str]],
    ) -> None:
        """Create an index with a non-indexed `dense_vector` field.

        Args:
            client: Elasticsearch client connection.
            index_name: The name of the Elasticsearch index to create.
            metadata_mapping: Flat dictionary with field and field type pairs
                that describe the schema of the metadata.
        """
        if not self.num_dimensions:
            # Derive the dimensionality from an example embedding.
            self.num_dimensions = len(
                self.embedding_service.embed_query("get number of dimensions")
            )

        mappings: Dict[str, Any] = {
            "properties": {
                self.vector_field: {
                    "type": "dense_vector",
                    "dims": self.num_dimensions,
                    "index": False,
                }
            }
        }
        if metadata_mapping:
            mappings["properties"]["metadata"] = {"properties": metadata_mapping}

        await client.indices.create(index=index_name, mappings=mappings)

    def embed_for_indexing(self, text: str) -> Dict[str, Any]:
        """Return the vector field with the embedding of ``text``."""
        return {self.vector_field: self.embedding_service.embed_query(text)}
metadata_mapping: Optional[dict[str, str]], + ) -> None: + similarity_name = "custom_bm25" + + mappings: Dict = { + "properties": { + self.text_field: { + "type": "text", + "similarity": similarity_name, + }, + }, + } + if metadata_mapping: + mappings["properties"]["metadata"] = {"properties": metadata_mapping} + + bm25: Dict = { + "type": "BM25", + } + if self.k1 is not None: + bm25["k1"] = self.k1 + if self.b is not None: + bm25["b"] = self.b + settings = { + "similarity": { + similarity_name: bm25, + } + } + + await client.indices.create( + index=index_name, mappings=mappings, settings=settings + ) + + return None diff --git a/test_elasticsearch/test_store_integration/__init__.py b/test_elasticsearch/test_store_integration/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/test_elasticsearch/test_store_integration/_test_utilities.py b/test_elasticsearch/test_store_integration/_test_utilities.py new file mode 100644 index 000000000..e9158eed9 --- /dev/null +++ b/test_elasticsearch/test_store_integration/_test_utilities.py @@ -0,0 +1,130 @@ +import asyncio +import os +from typing import Any, Dict, List, Optional + +import nest_asyncio # type: ignore +from elastic_transport import AsyncTransport +from elasticsearch import AsyncElasticsearch + +from elasticsearch.store._utilities import model_is_deployed_async +from elasticsearch.store.embedding_service import EmbeddingService + + +class FakeEmbeddings(EmbeddingService): + """Fake embeddings functionality for testing.""" + + def __init__(self, dimensionality: int = 10) -> None: + nest_asyncio.apply() + + self.dimensionality = dimensionality + + def num_dimensions(self) -> int: + return self.dimensionality + + async def embed_documents_async(self, texts: List[str]) -> List[List[float]]: + """Return simple embeddings. 
class ConsistentFakeEmbeddings(FakeEmbeddings):
    """Fake embeddings that give the same vector to the same text.

    Texts are remembered in first-seen order in ``known_texts``; the last
    component of a text's vector is its position in that list.
    """

    def __init__(self, dimensionality: int = 10) -> None:
        self.known_texts: List[str] = []
        self.dimensionality = dimensionality

    def num_dimensions(self) -> int:
        return self.dimensionality

    def _vector_for(self, text: str) -> List[float]:
        # Look the text up; register it at the end of the list on first sight.
        try:
            position = self.known_texts.index(text)
        except ValueError:
            position = len(self.known_texts)
            self.known_texts.append(text)
        return [1.0] * (self.dimensionality - 1) + [float(position)]

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Return one vector per text, registering texts not seen before."""
        return [self._vector_for(text) for text in texts]

    def embed_query(self, text: str) -> List[float]:
        """Return the vector of ``text``, registering it if it is new."""
        return self.embed_documents([text])[0]
def create_es_client(
    es_params: Optional[Dict[str, str]] = None,
    es_kwargs: Optional[Dict] = None,
) -> AsyncElasticsearch:
    """Create an ``AsyncElasticsearch`` client for the tests.

    :param es_params: connection parameters; read from the environment via
        ``read_env()`` when omitted. If it contains ``es_cloud_id`` the client
        connects with ``es_cloud_id``/``es_api_key``, otherwise with
        ``es_url``.
    :param es_kwargs: extra keyword arguments forwarded to the client
        constructor. Defaults to ``None`` instead of a shared mutable dict.
    :return: the new client.
    """
    if es_params is None:
        es_params = read_env()
    if not es_kwargs:
        es_kwargs = {}

    if "es_cloud_id" in es_params:
        return AsyncElasticsearch(
            cloud_id=es_params["es_cloud_id"],
            api_key=es_params["es_api_key"],
            **es_kwargs,
        )
    return AsyncElasticsearch(hosts=[es_params["es_url"]], **es_kwargs)
# deployed with
# https://www.elastic.co/guide/en/machine-learning/current/ml-nlp-text-emb-vector-search-example.html
MODEL_ID = os.getenv("MODEL_ID", "sentence-transformers__msmarco-minilm-l-12-v3")
# The variable used to be read only under the misspelled name
# NUM_DIMENTIONS; the correct spelling wins, the old one is kept as a
# fallback so existing setups keep working.
NUM_DIMENSIONS = int(
    os.getenv("NUM_DIMENSIONS", os.getenv("NUM_DIMENTIONS", "384"))
)
reason=f"{MODEL_ID} model is not deployed in ML Node, skipping test", +) +def test_elasticsearch_embedding_query() -> None: + """Test Elasticsearch embedding query.""" + document = "foo bar" + embedding = ElasticsearchEmbeddings( + agent_header="test", model_id=MODEL_ID, es_url=ES_URL + ) + output = embedding.embed_query(document) + assert len(output) == NUM_DIMENSIONS diff --git a/test_elasticsearch/test_store_integration/test_store.py b/test_elasticsearch/test_store_integration/test_store.py new file mode 100644 index 000000000..552a61fb7 --- /dev/null +++ b/test_elasticsearch/test_store_integration/test_store.py @@ -0,0 +1,939 @@ +import logging +import uuid +from functools import partial +from typing import Any, AsyncGenerator, List, Optional, Union, cast + +import pytest +import pytest_asyncio +from elasticsearch import AsyncElasticsearch, NotFoundError +from elasticsearch.helpers import BulkIndexError + +from elasticsearch.store.store import ElasticsearchStore +from elasticsearch.store.strategies import ( + BM25, + DenseVector, + DenseVectorScriptScore, + DistanceMetric, + Semantic, +) + +from ._test_utilities import ( + ConsistentFakeEmbeddings, + FakeEmbeddings, + RequestSavingTransport, + clear_test_indices, + create_es_client, + create_requests_saving_client, + model_is_deployed, + read_env, +) + +logging.basicConfig(level=logging.DEBUG) + +""" +docker-compose up elasticsearch + +By default runs against local docker instance of Elasticsearch. +To run against Elastic Cloud, set the following environment variables: +- ES_CLOUD_ID +- ES_API_KEY + +Some of the tests require the following models to be deployed in the ML Node: +- elser (can be downloaded and deployed through Kibana and trained models UI) +- sentence-transformers__all-minilm-l6-v2 (can be deployed through the API, + loaded via eland) + +These tests that require the models to be deployed are skipped by default. +Enable them by adding the model name to the modelsDeployed list below. 
+""" + +ELSER_MODEL_ID = ".elser_model_2" +TRANSFORMER_MODEL_ID = "sentence-transformers__all-minilm-l6-v2" + + +class TestElasticsearch: + @pytest_asyncio.fixture(autouse=True) + async def es_client(self) -> AsyncGenerator[AsyncElasticsearch, None]: + params = read_env() + client = create_es_client(params) + + yield client + + # clear indices + await clear_test_indices(client) + + # clear all test pipelines + try: + response = await client.ingest.get_pipeline(id="test_*,*_sparse_embedding") + + for pipeline_id, _ in response.items(): + try: + await client.ingest.delete_pipeline(id=pipeline_id) + print(f"Deleted pipeline: {pipeline_id}") # noqa: T201 + except Exception as e: + print(f"Pipeline error: {e}") # noqa: T201 + + except Exception: + pass + finally: + await client.close() + + @pytest_asyncio.fixture(autouse=True) + async def requests_saving_client(self) -> AsyncGenerator[AsyncElasticsearch, None]: + client = create_requests_saving_client() + try: + yield client + finally: + await client.close() + + @pytest.fixture(scope="function") + def index_name(self) -> str: + """Return the index name.""" + return f"test_{uuid.uuid4().hex}" + + def test_initialize_from_params(self, index_name: str) -> None: + params = read_env() + agent_header = "test initialize from params" + store = ElasticsearchStore( + agent_header=agent_header, + index_name=index_name, + retrieval_strategy=BM25(), + **params, + ) + + assert store.es_client._headers["User-Agent"] == agent_header + + texts = ["foo", "bar", "baz"] + store.add_texts(texts) + + output = store.search("foo", k=1) + assert [doc["_source"]["text_field"] for doc in output] == ["foo"] + + @pytest.mark.asyncio + async def test_search_without_metadata( + self, es_client: AsyncElasticsearch, index_name: str + ) -> None: + """Test end to end construction and search without metadata.""" + + def assert_query(query_body: dict, query: Optional[str]) -> dict: + assert query_body == { + "knn": { + "field": "vector_field", + "filter": 
[], + "k": 1, + "num_candidates": 50, + "query_vector": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + } + } + return query_body + + store = ElasticsearchStore( + agent_header="test", + index_name=index_name, + retrieval_strategy=DenseVector(embedding_service=FakeEmbeddings()), + es_client=es_client, + ) + + texts = ["foo", "bar", "baz"] + store.add_texts(texts) + + output = store.search("foo", k=1, custom_query=assert_query) + assert [doc["_source"]["text_field"] for doc in output] == ["foo"] + + @pytest.mark.asyncio + async def test_search_without_metadata_async( + self, es_client: AsyncElasticsearch, index_name: str + ) -> None: + """Test end to end construction and search without metadata.""" + store = ElasticsearchStore( + agent_header="test", + index_name=index_name, + retrieval_strategy=DenseVector(embedding_service=FakeEmbeddings()), + es_client=es_client, + ) + + texts = ["foo", "bar", "baz"] + store.add_texts(texts) + + output = await store.search_async("foo", k=1) + assert [doc["_source"]["text_field"] for doc in output] == ["foo"] + + def test_add_vectors(self, es_client: AsyncElasticsearch, index_name: str) -> None: + """ + Test adding pre-built embeddings instead of using inference for the texts. + This allows you to separate the embeddings text and the page_content + for better proximity between user's question and embedded text. + For example, your embedding text can be a question, whereas page_content + is the answer. 
+ """ + embeddings = ConsistentFakeEmbeddings() + texts = ["foo1", "foo2", "foo3"] + metadatas = [{"page": i} for i in range(len(texts))] + + """In real use case, embedding_input can be questions for each text""" + embedding_vectors = embeddings.embed_documents(texts) + + store = ElasticsearchStore( + agent_header="test", + index_name=index_name, + retrieval_strategy=DenseVector(embedding_service=embeddings), + es_client=es_client, + ) + + store.add_texts(texts=texts, vectors=embedding_vectors, metadatas=metadatas) + output = store.search("foo1", k=1) + assert [doc["_source"]["text_field"] for doc in output] == ["foo1"] + assert [doc["_source"]["metadata"]["page"] for doc in output] == [0] + + def test_search_with_metadata( + self, es_client: AsyncElasticsearch, index_name: str + ) -> None: + """Test end to end construction and search with metadata.""" + store = ElasticsearchStore( + agent_header="test", + index_name=index_name, + retrieval_strategy=DenseVector( + embedding_service=ConsistentFakeEmbeddings() + ), + es_client=es_client, + ) + + texts = ["foo", "bar", "baz"] + metadatas = [{"page": i} for i in range(len(texts))] + store.add_texts(texts=texts, metadatas=metadatas) + + output = store.search("foo", k=1) + assert [doc["_source"]["text_field"] for doc in output] == ["foo"] + assert [doc["_source"]["metadata"]["page"] for doc in output] == [0] + + output = store.search("bar", k=1) + assert [doc["_source"]["text_field"] for doc in output] == ["bar"] + assert [doc["_source"]["metadata"]["page"] for doc in output] == [1] + + def test_search_with_filter( + self, es_client: AsyncElasticsearch, index_name: str + ) -> None: + """Test end to end construction and search with metadata.""" + store = ElasticsearchStore( + agent_header="test", + index_name=index_name, + retrieval_strategy=DenseVector(embedding_service=FakeEmbeddings()), + es_client=es_client, + ) + + texts = ["foo", "foo", "foo"] + metadatas = [{"page": i} for i in range(len(texts))] + 
store.add_texts(texts=texts, metadatas=metadatas) + + def assert_query(query_body: dict, query: Optional[str]) -> dict: + assert query_body == { + "knn": { + "field": "vector_field", + "filter": [{"term": {"metadata.page": "1"}}], + "k": 3, + "num_candidates": 50, + "query_vector": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + } + } + return query_body + + output = store.search( + query="foo", + k=3, + filter=[{"term": {"metadata.page": "1"}}], + custom_query=assert_query, + ) + assert [doc["_source"]["text_field"] for doc in output] == ["foo"] + assert [doc["_source"]["metadata"]["page"] for doc in output] == [1] + + def test_search_script_score( + self, es_client: AsyncElasticsearch, index_name: str + ) -> None: + """Test end to end construction and search with metadata.""" + store = ElasticsearchStore( + agent_header="test", + index_name=index_name, + retrieval_strategy=DenseVectorScriptScore( + embedding_service=FakeEmbeddings() + ), + es_client=es_client, + ) + + texts = ["foo", "bar", "baz"] + store.add_texts(texts) + + expected_query = { + "query": { + "script_score": { + "query": {"match_all": {}}, + "script": { + "source": "cosineSimilarity(params.query_vector, 'vector_field') + 1.0", # noqa: E501 + "params": { + "query_vector": [ + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 0.0, + ] + }, + }, + } + } + } + + def assert_query(query_body: dict, query: Optional[str]) -> dict: + assert query_body == expected_query + return query_body + + output = store.search("foo", k=1, custom_query=assert_query) + assert [doc["_source"]["text_field"] for doc in output] == ["foo"] + + def test_search_script_score_with_filter( + self, es_client: AsyncElasticsearch, index_name: str + ) -> None: + """Test end to end construction and search with metadata.""" + store = ElasticsearchStore( + agent_header="test", + index_name=index_name, + retrieval_strategy=DenseVectorScriptScore( + embedding_service=FakeEmbeddings() + ), + es_client=es_client, + ) + + 
texts = ["foo", "bar", "baz"] + metadatas = [{"page": i} for i in range(len(texts))] + store.add_texts(texts=texts, metadatas=metadatas) + + def assert_query(query_body: dict, query: Optional[str]) -> dict: + expected_query = { + "query": { + "script_score": { + "query": {"bool": {"filter": [{"term": {"metadata.page": 0}}]}}, + "script": { + "source": "cosineSimilarity(params.query_vector, 'vector_field') + 1.0", # noqa: E501 + "params": { + "query_vector": [ + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 0.0, + ] + }, + }, + } + } + } + assert query_body == expected_query + return query_body + + output = store.search( + "foo", + k=1, + custom_query=assert_query, + filter=[{"term": {"metadata.page": 0}}], + ) + assert [doc["_source"]["text_field"] for doc in output] == ["foo"] + assert [doc["_source"]["metadata"]["page"] for doc in output] == [0] + + def test_search_script_score_distance_dot_product( + self, es_client: AsyncElasticsearch, index_name: str + ) -> None: + """Test end to end construction and search with metadata.""" + store = ElasticsearchStore( + agent_header="test", + index_name=index_name, + retrieval_strategy=DenseVectorScriptScore( + embedding_service=FakeEmbeddings(), + distance=DistanceMetric.DOT_PRODUCT, + ), + es_client=es_client, + ) + + texts = ["foo", "bar", "baz"] + store.add_texts(texts) + + def assert_query(query_body: dict, query: Optional[str]) -> dict: + assert query_body == { + "query": { + "script_score": { + "query": {"match_all": {}}, + "script": { + "source": """ + double value = dotProduct(params.query_vector, 'vector_field'); + return sigmoid(1, Math.E, -value); + """, + "params": { + "query_vector": [ + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 0.0, + ] + }, + }, + } + } + } + return query_body + + output = store.search("foo", k=1, custom_query=assert_query) + assert [doc["_source"]["text_field"] for doc in output] == ["foo"] + + def test_search_knn_with_hybrid_search( + self, es_client: 
AsyncElasticsearch, index_name: str + ) -> None: + """Test end to end construction and search with metadata.""" + store = ElasticsearchStore( + agent_header="test", + index_name=index_name, + retrieval_strategy=DenseVector( + embedding_service=FakeEmbeddings(), + hybrid=True, + ), + es_client=es_client, + ) + + texts = ["foo", "bar", "baz"] + store.add_texts(texts) + + def assert_query(query_body: dict, query: Optional[str]) -> dict: + assert query_body == { + "knn": { + "field": "vector_field", + "filter": [], + "k": 1, + "num_candidates": 50, + "query_vector": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + }, + "query": { + "bool": { + "filter": [], + "must": [{"match": {"text_field": {"query": "foo"}}}], + } + }, + "rank": {"rrf": {}}, + } + return query_body + + output = store.search("foo", k=1, custom_query=assert_query) + assert [doc["_source"]["text_field"] for doc in output] == ["foo"] + + @pytest.mark.asyncio + async def test_search_knn_with_hybrid_search_rrf( + self, es_client: AsyncElasticsearch, index_name: str + ) -> None: + """Test end to end construction and rrf hybrid search with metadata.""" + texts = ["foo", "bar", "baz"] + + def assert_query( + query_body: dict, + query: Optional[str], + expected_rrf: Union[dict, bool], + ) -> dict: + cmp_query_body = { + "knn": { + "field": "vector_field", + "filter": [], + "k": 3, + "num_candidates": 50, + "query_vector": [ + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 0.0, + ], + }, + "query": { + "bool": { + "filter": [], + "must": [{"match": {"text_field": {"query": "foo"}}}], + } + }, + } + + if isinstance(expected_rrf, dict): + cmp_query_body["rank"] = {"rrf": expected_rrf} + elif isinstance(expected_rrf, bool) and expected_rrf is True: + cmp_query_body["rank"] = {"rrf": {}} + + assert query_body == cmp_query_body + + return query_body + + # 1. 
check query_body is okay + rrf_test_cases: List[Union[dict, bool]] = [ + True, + False, + {"rank_constant": 1, "window_size": 5}, + ] + for rrf_test_case in rrf_test_cases: + store = ElasticsearchStore( + agent_header="test", + index_name=index_name, + retrieval_strategy=DenseVector( + embedding_service=FakeEmbeddings(), + hybrid=True, + rrf=rrf_test_case, + ), + es_client=es_client, + ) + store.add_texts(texts) + + ## without fetch_k parameter + output = store.search( + "foo", + k=3, + custom_query=partial(assert_query, expected_rrf=rrf_test_case), + ) + + # 2. check query result is okay + es_output = await store.es_client.search( + index=index_name, + query={ + "bool": { + "filter": [], + "must": [{"match": {"text_field": {"query": "foo"}}}], + } + }, + knn={ + "field": "vector_field", + "filter": [], + "k": 3, + "num_candidates": 50, + "query_vector": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + }, + size=3, + rank={"rrf": {"rank_constant": 1, "window_size": 5}}, + ) + + assert [o["_source"]["text_field"] for o in output] == [ + e["_source"]["text_field"] for e in es_output["hits"]["hits"] + ] + + # 3. 
check rrf default option is okay + store = ElasticsearchStore( + agent_header="test", + index_name=f"{index_name}_default", + retrieval_strategy=DenseVector( + embedding_service=FakeEmbeddings(), + hybrid=True, + ), + es_client=es_client, + ) + store.add_texts(texts) + + ## with fetch_k parameter + output = store.search( + "foo", + k=3, + num_candidates=50, + custom_query=partial(assert_query, expected_rrf={}), + ) + + def test_search_knn_with_custom_query_fn( + self, es_client: AsyncElasticsearch, index_name: str + ) -> None: + """test that custom query function is called + with the query string and query body""" + store = ElasticsearchStore( + agent_header="test", + index_name=index_name, + retrieval_strategy=DenseVector(embedding_service=FakeEmbeddings()), + es_client=es_client, + ) + + def my_custom_query(query_body: dict, query: Optional[str]) -> dict: + assert query == "foo" + assert query_body == { + "knn": { + "field": "vector_field", + "filter": [], + "k": 1, + "num_candidates": 50, + "query_vector": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + } + } + return {"query": {"match": {"text_field": {"query": "bar"}}}} + + """Test end to end construction and search with metadata.""" + texts = ["foo", "bar", "baz"] + store.add_texts(texts) + + output = store.search("foo", k=1, custom_query=my_custom_query) + assert [doc["_source"]["text_field"] for doc in output] == ["bar"] + + @pytest.mark.asyncio + @pytest.mark.skipif( + not model_is_deployed(create_es_client(), TRANSFORMER_MODEL_ID), + reason=f"{TRANSFORMER_MODEL_ID} model not deployed in ML Node, " + "skipping test", + ) + async def test_search_with_knn_infer_instack( + self, es_client: AsyncElasticsearch, index_name: str + ) -> None: + """test end to end with knn retrieval strategy and inference in-stack""" + text_field = "text_field" + + store = ElasticsearchStore( + agent_header="test", + index_name=index_name, + retrieval_strategy=Semantic( + model_id="sentence-transformers__all-minilm-l6-v2", + 
text_field=text_field, + ), + es_client=es_client, + ) + + # setting up the pipeline for inference + await store.es_client.ingest.put_pipeline( + id="test_pipeline", + processors=[ + { + "inference": { + "model_id": TRANSFORMER_MODEL_ID, + "field_map": {"query_field": text_field}, + "target_field": "vector_query_field", + } + } + ], + ) + + # creating a new index with the pipeline, + # not relying on langchain to create the index + await store.es_client.indices.create( + index=index_name, + mappings={ + "properties": { + text_field: {"type": "text_field"}, + "vector_query_field": { + "properties": { + "predicted_value": { + "type": "dense_vector", + "dims": 384, + "index": True, + "similarity": "l2_norm", + } + } + }, + } + }, + settings={"index": {"default_pipeline": "test_pipeline"}}, + ) + + # adding documents to the index + texts = ["foo", "bar", "baz"] + + for i, text in enumerate(texts): + await store.es_client.create( + index=index_name, + id=str(i), + document={text_field: text, "metadata": {}}, + ) + + await store.es_client.indices.refresh(index=index_name) + + def assert_query(query_body: dict, query: Optional[str]) -> dict: + assert query_body == { + "knn": { + "filter": [], + "field": "vector_query_field.predicted_value", + "k": 1, + "num_candidates": 50, + "query_vector_builder": { + "text_embedding": { + "model_id": TRANSFORMER_MODEL_ID, + "model_text": "foo", + } + }, + } + } + return query_body + + output = store.search("foo", k=1, custom_query=assert_query) + assert [doc["_source"]["text_field"] for doc in output] == ["foo"] + + output = store.search("bar", k=1) + assert [doc["_source"]["text_field"] for doc in output] == ["bar"] + + @pytest.mark.skipif( + not model_is_deployed(create_es_client(), ELSER_MODEL_ID), + reason=f"{ELSER_MODEL_ID} model not deployed in ML Node, skipping test", + ) + def test_search_with_sparse_infer_instack( + self, es_client: AsyncElasticsearch, index_name: str + ) -> None: + """test end to end with sparse retrieval 
strategy and inference in-stack""" + store = ElasticsearchStore( + agent_header="test", + index_name=index_name, + retrieval_strategy=Semantic(model_id=ELSER_MODEL_ID), + es_client=es_client, + ) + + texts = ["foo", "bar", "baz"] + store.add_texts(texts) + + output = store.search("foo", k=1) + assert [doc["_source"]["text_field"] for doc in output] == ["foo"] + + def test_deployed_model_check_fails_semantic( + self, es_client: AsyncElasticsearch, index_name: str + ) -> None: + """test that exceptions are raised if a specified model is not deployed""" + with pytest.raises(NotFoundError): + store = ElasticsearchStore( + agent_header="test", + index_name=index_name, + retrieval_strategy=Semantic(model_id="non-existing model ID"), + es_client=es_client, + ) + store.add_texts(["foo", "bar", "baz"]) + + def test_search_bm25(self, es_client: AsyncElasticsearch, index_name: str) -> None: + """Test end to end using the BM25 retrieval strategy.""" + store = ElasticsearchStore( + agent_header="test", + index_name=index_name, + retrieval_strategy=BM25(), + es_client=es_client, + ) + + texts = ["foo", "bar", "baz"] + store.add_texts(texts) + + def assert_query(query_body: dict, query: Optional[str]) -> dict: + assert query_body == { + "query": { + "bool": { + "must": [{"match": {"text_field": {"query": "foo"}}}], + "filter": [], + } + } + } + return query_body + + output = store.search("foo", k=1, custom_query=assert_query) + assert [doc["_source"]["text_field"] for doc in output] == ["foo"] + + def test_search_bm25_with_filter( + self, es_client: AsyncElasticsearch, index_name: str + ) -> None: + """Test end to using the BM25 retrieval strategy with metadata.""" + store = ElasticsearchStore( + agent_header="test", + index_name=index_name, + retrieval_strategy=BM25(), + es_client=es_client, + ) + + texts = ["foo", "foo", "foo"] + metadatas = [{"page": i} for i in range(len(texts))] + store.add_texts(texts=texts, metadatas=metadatas) + + def assert_query(query_body: dict, query: 
Optional[str]) -> dict: + assert query_body == { + "query": { + "bool": { + "must": [{"match": {"text_field": {"query": "foo"}}}], + "filter": [{"term": {"metadata.page": 1}}], + } + } + } + return query_body + + output = store.search( + "foo", + k=3, + custom_query=assert_query, + filter=[{"term": {"metadata.page": 1}}], + ) + assert [doc["_source"]["text_field"] for doc in output] == ["foo"] + assert [doc["_source"]["metadata"]["page"] for doc in output] == [1] + + def test_delete(self, es_client: AsyncElasticsearch, index_name: str) -> None: + """Test delete methods from vector store.""" + store = ElasticsearchStore( + agent_header="test", + index_name=index_name, + retrieval_strategy=DenseVector(embedding_service=FakeEmbeddings()), + es_client=es_client, + ) + + texts = ["foo", "bar", "baz", "gni"] + metadatas = [{"page": i} for i in range(len(texts))] + ids = store.add_texts(texts=texts, metadatas=metadatas) + + output = store.search("foo", k=10) + assert len(output) == 4 + + store.delete(ids[1:3]) + output = store.search("foo", k=10) + assert len(output) == 2 + + store.delete(["not-existing"]) + output = store.search("foo", k=10) + assert len(output) == 2 + + store.delete([ids[0]]) + output = store.search("foo", k=10) + assert len(output) == 1 + + store.delete([ids[3]]) + output = store.search("gni", k=10) + assert len(output) == 0 + + @pytest.mark.asyncio + async def test_indexing_exception_error( + self, + es_client: AsyncElasticsearch, + index_name: str, + caplog: pytest.LogCaptureFixture, + ) -> None: + """Test bulk exception logging is giving better hints.""" + store = ElasticsearchStore( + agent_header="test", + index_name=index_name, + retrieval_strategy=BM25(), + es_client=es_client, + ) + + await store.es_client.indices.create( + index=index_name, + mappings={"properties": {}}, + settings={"index": {"default_pipeline": "not-existing-pipeline"}}, + ) + + texts = ["foo"] + + with pytest.raises(BulkIndexError): + store.add_texts(texts) + + error_reason 
= "pipeline with id [not-existing-pipeline] does not exist" + log_message = f"First error reason: {error_reason}" + + assert log_message in caplog.text + + def test_user_agent( + self, requests_saving_client: AsyncElasticsearch, index_name: str + ) -> None: + """Test to make sure the user-agent is set correctly.""" + agent_header = "this is THE agent_header!" + store = ElasticsearchStore( + agent_header=agent_header, + index_name=index_name, + retrieval_strategy=BM25(), + es_client=requests_saving_client, + ) + + assert store.es_client._headers["User-Agent"] == agent_header + + texts = ["foo", "bob", "baz"] + store.add_texts(texts) + + transport = cast(RequestSavingTransport, store.es_client.transport) + + for request in transport.requests: + assert request["headers"]["User-Agent"] == agent_header + + def test_bulk_args(self, requests_saving_client: Any, index_name: str) -> None: + """Test to make sure the bulk arguments work as expected.""" + store = ElasticsearchStore( + agent_header="test", + index_name=index_name, + retrieval_strategy=BM25(), + es_client=requests_saving_client, + ) + + texts = ["foo", "bob", "baz"] + store.add_texts(texts, bulk_kwargs={"chunk_size": 1}) + + # 1 for index exist, 1 for index create, 3 to index docs + assert len(store.es_client.transport.requests) == 5 # type: ignore + + def test_max_marginal_relevance_search( + self, es_client: AsyncElasticsearch, index_name: str + ) -> None: + """Test max marginal relevance search.""" + texts = ["foo", "bar", "baz"] + vector_field = "vector_field" + text_field = "text_field" + embedding_service = ConsistentFakeEmbeddings() + store = ElasticsearchStore( + agent_header="test", + index_name=index_name, + retrieval_strategy=DenseVectorScriptScore( + embedding_service=embedding_service + ), + vector_field=vector_field, + text_field=text_field, + es_client=es_client, + ) + store.add_texts(texts) + + mmr_output = store.max_marginal_relevance_search( + embedding_service, + texts[0], + 
vector_field=vector_field, + k=3, + num_candidates=3, + ) + sim_output = store.search(texts[0], k=3) + assert mmr_output == sim_output + + mmr_output = store.max_marginal_relevance_search( + embedding_service, + texts[0], + vector_field=vector_field, + k=2, + num_candidates=3, + ) + assert len(mmr_output) == 2 + assert mmr_output[0]["_source"][text_field] == texts[0] + assert mmr_output[1]["_source"][text_field] == texts[1] + + mmr_output = store.max_marginal_relevance_search( + embedding_service, + texts[0], + vector_field=vector_field, + k=2, + num_candidates=3, + lambda_mult=0.1, # more diversity + ) + assert len(mmr_output) == 2 + assert mmr_output[0]["_source"][text_field] == texts[0] + assert mmr_output[1]["_source"][text_field] == texts[2] + + # if fetch_k < k, then the output will be less than k + mmr_output = store.max_marginal_relevance_search( + embedding_service, + texts[0], + vector_field=vector_field, + k=3, + num_candidates=2, + ) + assert len(mmr_output) == 2 From e03a17f94b8ba29f7a004d9d66f0d4af136fa70b Mon Sep 17 00:00:00 2001 From: Max Jakob Date: Wed, 17 Apr 2024 16:44:43 +0200 Subject: [PATCH 02/36] Update elasticsearch/store/_utilities.py Co-authored-by: Quentin Pradet --- elasticsearch/store/_utilities.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/elasticsearch/store/_utilities.py b/elasticsearch/store/_utilities.py index a3b6c36c7..f1f8d7110 100644 --- a/elasticsearch/store/_utilities.py +++ b/elasticsearch/store/_utilities.py @@ -50,7 +50,7 @@ def create_elasticsearch_client( raise TypeError("Elasticsearch client must be AsyncElasticsearch client") # Add integration-specific usage header for tracking usage in Elastic Cloud. - # client.options preserces existing (non-user-agent) headers. + # client.options preserves existing (non-user-agent) headers. 
client = client.options(headers={"User-Agent": agent_header}) return client From 8ff1c7c083cfce99bae8ead72016e9899882d356 Mon Sep 17 00:00:00 2001 From: Max Jakob Date: Thu, 18 Apr 2024 15:06:24 +0200 Subject: [PATCH 03/36] rename; depend on client; async only --- elasticsearch/store/embedding_service.py | 132 --- .../{store => vectorstore}/__init__.py | 0 elasticsearch/vectorstore/_async/__init__.py | 5 + elasticsearch/vectorstore/_async/_utils.py | 36 + .../vectorstore/_async/embedding_service.py | 80 ++ .../_async}/strategies.py | 61 +- .../_async/vectorestore.py} | 124 +-- .../_utilities.py => vectorstore/_utils.py} | 83 +- .../test_vectorstore/_async}/__init__.py | 0 .../test_vectorstore/_async/_test_utils.py} | 69 +- .../_async/test_embedding_service.py | 62 ++ .../_async/test_vectorestore.py} | 377 ++++--- .../test_vectorstore}/docker-compose.yml | 0 .../test_vectorestore copy.py1 | 986 ++++++++++++++++++ .../test_embedding_service.py | 47 - 15 files changed, 1445 insertions(+), 617 deletions(-) delete mode 100644 elasticsearch/store/embedding_service.py rename elasticsearch/{store => vectorstore}/__init__.py (100%) create mode 100644 elasticsearch/vectorstore/_async/__init__.py create mode 100644 elasticsearch/vectorstore/_async/_utils.py create mode 100644 elasticsearch/vectorstore/_async/embedding_service.py rename elasticsearch/{store => vectorstore/_async}/strategies.py (91%) rename elasticsearch/{store/store.py => vectorstore/_async/vectorestore.py} (74%) rename elasticsearch/{store/_utilities.py => vectorstore/_utils.py} (50%) rename test_elasticsearch/{test_store_integration => test_server/test_vectorstore/_async}/__init__.py (100%) rename test_elasticsearch/{test_store_integration/_test_utilities.py => test_server/test_vectorstore/_async/_test_utils.py} (68%) create mode 100644 test_elasticsearch/test_server/test_vectorstore/_async/test_embedding_service.py rename test_elasticsearch/{test_store_integration/test_store.py => 
test_server/test_vectorstore/_async/test_vectorestore.py} (75%) rename test_elasticsearch/{test_store_integration => test_server/test_vectorstore}/docker-compose.yml (100%) create mode 100644 test_elasticsearch/test_server/test_vectorstore/test_vectorestore copy.py1 delete mode 100644 test_elasticsearch/test_store_integration/test_embedding_service.py diff --git a/elasticsearch/store/embedding_service.py b/elasticsearch/store/embedding_service.py deleted file mode 100644 index e6c5a470c..000000000 --- a/elasticsearch/store/embedding_service.py +++ /dev/null @@ -1,132 +0,0 @@ -import asyncio -from abc import ABC, abstractmethod -from typing import List, Optional - -import nest_asyncio # type: ignore -from elasticsearch import AsyncElasticsearch - -from elasticsearch.store._utilities import create_elasticsearch_client - - -class EmbeddingService(ABC): - @abstractmethod - async def embed_documents_async(self, texts: List[str]) -> List[List[float]]: - """Generate embeddings for a list of documents. - - Args: - texts: A list of document strings to generate embeddings for. - - Returns: - A list of embeddings, one for each document in the input. - """ - - @abstractmethod - def embed_documents(self, texts: List[str]) -> List[List[float]]: - """Generate embeddings for a list of documents. - - Args: - texts: A list of document strings to generate embeddings for. - - Returns: - A list of embeddings, one for each document in the input. - """ - - @abstractmethod - async def embed_query_async(self, query: str) -> List[float]: - """Generate an embedding for a single query text. - - Args: - text: The query text to generate an embedding for. - - Returns: - The embedding for the input query text. - """ - - @abstractmethod - def embed_query(self, query: str) -> List[float]: - """Generate an embedding for a single query text. - - Args: - text: The query text to generate an embedding for. - - Returns: - The embedding for the input query text. 
- """ - - -class ElasticsearchEmbeddings(EmbeddingService): - """Elasticsearch as a service for embedding model inference. - - You need to have an embedding model downloaded and deployed in Elasticsearch: - - https://www.elastic.co/guide/en/elasticsearch/reference/current/infer-trained-model.html - - https://www.elastic.co/guide/en/machine-learning/current/ml-nlp-deploy-models.html - """ # noqa: E501 - - def __init__( - self, - agent_header: str, - model_id: str, - input_field: str = "text_field", - num_dimensions: Optional[int] = None, - # Connection params - es_client: Optional[AsyncElasticsearch] = None, - es_url: Optional[str] = None, - es_cloud_id: Optional[str] = None, - es_api_key: Optional[str] = None, - es_user: Optional[str] = None, - es_password: Optional[str] = None, - ): - """ - Args: - agent_header: user agent header specific to the 3rd party integration. - Used for usage tracking in Elastic Cloud. - model_id: The model_id of the model deployed in the Elasticsearch cluster. - input_field: The name of the key for the input text field in the - document. Defaults to 'text_field'. - num_dimensions: The number of embedding dimensions. If None, then dimensions - will be infer from an example inference call. - es_client: Elasticsearch client connection. Alternatively specify the - Elasticsearch connection with the other es_* parameters. 
- """ - nest_asyncio.apply() - - client = create_elasticsearch_client( - agent_header=agent_header, - client=es_client, - url=es_url, - cloud_id=es_cloud_id, - api_key=es_api_key, - username=es_user, - password=es_password, - ) - - self.client = client.ml - self.model_id = model_id - self.input_field = input_field - self._num_dimensions = num_dimensions - - async def embed_documents_async(self, texts: List[str]) -> List[List[float]]: - result = await self._embedding_func_async(texts) - return result - - def embed_documents(self, texts: List[str]) -> List[List[float]]: - return asyncio.get_event_loop().run_until_complete( - self.embed_documents_async(texts) - ) - - async def embed_query_async(self, text: str) -> List[float]: - result = await self._embedding_func_async([text]) - return result[0] - - def embed_query(self, query: str) -> List[float]: - return asyncio.get_event_loop().run_until_complete( - self.embed_query_async(query) - ) - - async def _embedding_func_async(self, texts: List[str]) -> List[List[float]]: - response = await self.client.infer_trained_model( - model_id=self.model_id, docs=[{self.input_field: text} for text in texts] - ) - - embeddings = [doc["predicted_value"] for doc in response["inference_results"]] - return embeddings diff --git a/elasticsearch/store/__init__.py b/elasticsearch/vectorstore/__init__.py similarity index 100% rename from elasticsearch/store/__init__.py rename to elasticsearch/vectorstore/__init__.py diff --git a/elasticsearch/vectorstore/_async/__init__.py b/elasticsearch/vectorstore/_async/__init__.py new file mode 100644 index 000000000..135dccf64 --- /dev/null +++ b/elasticsearch/vectorstore/_async/__init__.py @@ -0,0 +1,5 @@ +from elasticsearch.vectorstore._async.vectorestore import AsyncVectorStore + +__all__ = [ + "AsyncVectorStore", +] diff --git a/elasticsearch/vectorstore/_async/_utils.py b/elasticsearch/vectorstore/_async/_utils.py new file mode 100644 index 000000000..5b5aacdbc --- /dev/null +++ 
b/elasticsearch/vectorstore/_async/_utils.py @@ -0,0 +1,36 @@ +from elasticsearch import ( + AsyncElasticsearch, + BadRequestError, + ConflictError, + NotFoundError, +) + + +async def async_model_must_be_deployed( + client: AsyncElasticsearch, model_id: str +) -> None: + try: + dummy = {"x": "y"} + await client.ml.infer_trained_model(model_id=model_id, docs=[dummy]) + except NotFoundError as err: + raise err + except ConflictError as err: + raise NotFoundError( + f"model '{model_id}' not found, please deploy it first", + meta=err.meta, + body=err.body, + ) from err + except BadRequestError: + # This error is expected because we do not know the expected document + # shape and just use a dummy doc above. + pass + + return None + + +async def async_model_is_deployed(es_client: AsyncElasticsearch, model_id: str) -> bool: + try: + await async_model_must_be_deployed(es_client, model_id) + return True + except NotFoundError: + return False diff --git a/elasticsearch/vectorstore/_async/embedding_service.py b/elasticsearch/vectorstore/_async/embedding_service.py new file mode 100644 index 000000000..00611e9cd --- /dev/null +++ b/elasticsearch/vectorstore/_async/embedding_service.py @@ -0,0 +1,80 @@ +from abc import ABC, abstractmethod +from typing import List, Optional + +from elasticsearch import AsyncElasticsearch + + +class AsyncEmbeddingService(ABC): + @abstractmethod + async def embed_documents(self, texts: List[str]) -> List[List[float]]: + """Generate embeddings for a list of documents. + + Args: + texts: A list of document strings to generate embeddings for. + + Returns: + A list of embeddings, one for each document in the input. + """ + + @abstractmethod + async def embed_query(self, query: str) -> List[float]: + """Generate an embedding for a single query text. + + Args: + text: The query text to generate an embedding for. + + Returns: + The embedding for the input query text. 
+ """ + + +class AsyncElasticsearchEmbeddings(AsyncEmbeddingService): + """Elasticsearch as a service for embedding model inference. + + You need to have an embedding model downloaded and deployed in Elasticsearch: + - https://www.elastic.co/guide/en/elasticsearch/reference/current/infer-trained-model.html + - https://www.elastic.co/guide/en/machine-learning/current/ml-nlp-deploy-models.html + """ # noqa: E501 + + def __init__( + self, + es_client: AsyncElasticsearch, + user_agent: str, + model_id: str, + input_field: str = "text_field", + num_dimensions: Optional[int] = None, + ): + """ + Args: + user_agent: user agent header specific to the 3rd party integration. + Used for usage tracking in Elastic Cloud. + model_id: The model_id of the model deployed in the Elasticsearch cluster. + input_field: The name of the key for the input text field in the + document. Defaults to 'text_field'. + num_dimensions: The number of embedding dimensions. If None, then dimensions + will be inferred from an example inference call. + es_client: Elasticsearch client connection. Alternatively specify the + Elasticsearch connection with the other es_* parameters. + """ + # Add integration-specific usage header for tracking usage in Elastic Cloud. + # client.options preserves existing (non-user-agent) headers. 
+ es_client = es_client.options(headers={"User-Agent": user_agent}) + + self.client = es_client.ml + self.model_id = model_id + self.input_field = input_field + self._num_dimensions = num_dimensions + + async def embed_documents(self, texts: List[str]) -> List[List[float]]: + result = await self._embedding_func(texts) + return result + + async def embed_query(self, text: str) -> List[float]: + result = await self._embedding_func([text]) + return result[0] + + async def _embedding_func(self, texts: List[str]) -> List[List[float]]: + response = await self.client.infer_trained_model( + model_id=self.model_id, docs=[{self.input_field: text} for text in texts] + ) + return [doc["predicted_value"] for doc in response["inference_results"]] diff --git a/elasticsearch/store/strategies.py b/elasticsearch/vectorstore/_async/strategies.py similarity index 91% rename from elasticsearch/store/strategies.py rename to elasticsearch/vectorstore/_async/strategies.py index 4daa0d097..22d9ab66b 100644 --- a/elasticsearch/store/strategies.py +++ b/elasticsearch/vectorstore/_async/strategies.py @@ -4,8 +4,8 @@ from elasticsearch import AsyncElasticsearch -from elasticsearch.store._utilities import model_must_be_deployed_async -from elasticsearch.store.embedding_service import EmbeddingService +from elasticsearch.vectorstore._async._utils import async_model_must_be_deployed +from elasticsearch.vectorstore._async.embedding_service import AsyncEmbeddingService class DistanceMetric(str, Enum): @@ -19,7 +19,7 @@ class DistanceMetric(str, Enum): class RetrievalStrategy(ABC): @abstractmethod - def es_query( + async def es_query( self, query: Optional[str], k: int, @@ -60,7 +60,7 @@ async def create_index( describe the schema of the metadata. """ - def embed_for_indexing(self, text: str) -> Dict[str, Any]: + async def embed_for_indexing(self, text: str) -> Dict[str, Any]: """ If this strategy creates vector embeddings in Python (not in Elasticsearch), this method is used to apply the inference. 
@@ -91,7 +91,7 @@ def __init__( self.text_field = text_field self.inference_field = inference_field - def es_query( + async def es_query( self, query: Optional[str], k: int, @@ -121,7 +121,7 @@ async def create_index( metadata_mapping: Optional[dict[str, str]], ) -> None: if self.model_id: - await model_must_be_deployed_async(client, self.model_id) + await async_model_must_be_deployed(client, self.model_id) mappings: dict[str, Any] = { "properties": { @@ -151,7 +151,7 @@ def __init__( self.vector_field = vector_field self._tokens_field = "tokens" - def es_query( + async def es_query( self, query: Optional[str], k: int, @@ -195,12 +195,12 @@ async def create_index( pipeline_name = f"{self.model_id}_sparse_embedding" if self.model_id: - await model_must_be_deployed_async(client, self.model_id) + await async_model_must_be_deployed(client, self.model_id) # Create a pipeline for the model await client.ingest.put_pipeline( id=pipeline_name, - description="Embedding pipeline for ElasticsearchStore", + description="Embedding pipeline for Python VectorStore", processors=[ { "inference": { @@ -241,7 +241,7 @@ def __init__( knn_type: Literal["hnsw", "int8_hnsw", "flat", "int8_flat"] = "hnsw", vector_field: str = "vector_field", distance: DistanceMetric = DistanceMetric.COSINE, - embedding_service: Optional[EmbeddingService] = None, + embedding_service: Optional[AsyncEmbeddingService] = None, model_id: Optional[str] = None, num_dimensions: Optional[int] = None, hybrid: bool = False, @@ -269,7 +269,7 @@ def __init__( self.rrf = rrf self.text_field = text_field - def es_query( + async def es_query( self, query: Optional[str], k: int, @@ -287,7 +287,9 @@ def es_query( if query_vector: knn["query_vector"] = query_vector elif self.embedding_service: - knn["query_vector"] = self.embedding_service.embed_query(cast(str, query)) + knn["query_vector"] = await self.embedding_service.embed_query( + cast(str, query) + ) else: # Inference in Elasticsearch. 
When initializing we make sure to always have # a model_id if don't have an embedding_service. @@ -299,9 +301,7 @@ def es_query( } if self.hybrid: - x = self._hybrid(query=cast(str, query), knn=knn, filter=filter) - print(x) - return x + return self._hybrid(query=cast(str, query), knn=knn, filter=filter) return {"knn": knn} @@ -313,11 +313,11 @@ async def create_index( ) -> None: if self.embedding_service and not self.num_dimensions: self.num_dimensions = len( - self.embedding_service.embed_query("get number of dimensions") + await self.embedding_service.embed_query("get number of dimensions") ) if self.model_id: - await model_must_be_deployed_async(client, self.model_id) + await async_model_must_be_deployed(client, self.model_id) if self.distance is DistanceMetric.COSINE: similarityAlgo = "cosine" @@ -330,7 +330,7 @@ async def create_index( else: raise ValueError(f"Similarity {self.distance} not supported.") - mappings = { + mappings: Dict[str, Any] = { "properties": { self.vector_field: { "type": "dense_vector", @@ -343,13 +343,12 @@ async def create_index( if metadata_mapping: mappings["properties"]["metadata"] = {"properties": metadata_mapping} - await client.indices.create(index=index_name, mappings=mappings) + r = await client.indices.create(index=index_name, mappings=mappings) + print(r) - return None - - def embed_for_indexing(self, text: str) -> Dict[str, Any]: + async def embed_for_indexing(self, text: str) -> Dict[str, Any]: if self.embedding_service: - vector = self.embedding_service.embed_query(text) + vector = await self.embedding_service.embed_query(text) return {self.vector_field: vector} return {} @@ -389,7 +388,7 @@ class DenseVectorScriptScore(RetrievalStrategy): def __init__( self, - embedding_service: EmbeddingService, + embedding_service: AsyncEmbeddingService, vector_field: str = "vector_field", distance: DistanceMetric = DistanceMetric.COSINE, num_dimensions: Optional[int] = None, @@ -399,7 +398,7 @@ def __init__( self.embedding_service = 
embedding_service self.num_dimensions = num_dimensions - def es_query( + async def es_query( self, query: Optional[str], k: int, @@ -443,7 +442,7 @@ def es_query( ) if not query: raise ValueError("either specify a query string or a query_vector") - query_vector = self.embedding_service.embed_query(query) + query_vector = await self.embedding_service.embed_query(query) return { "query": { @@ -465,7 +464,7 @@ async def create_index( ) -> None: if not self.num_dimensions: self.num_dimensions = len( - self.embedding_service.embed_query("get number of dimensions") + await self.embedding_service.embed_query("get number of dimensions") ) mappings = { @@ -484,8 +483,8 @@ async def create_index( return None - def embed_for_indexing(self, text: str) -> Dict[str, Any]: - return {self.vector_field: self.embedding_service.embed_query(text)} + async def embed_for_indexing(self, text: str) -> Dict[str, Any]: + return {self.vector_field: await self.embedding_service.embed_query(text)} class BM25(RetrievalStrategy): @@ -499,7 +498,7 @@ def __init__( self.k1 = k1 self.b = b - def es_query( + async def es_query( self, query: Optional[str], k: int, @@ -559,5 +558,3 @@ async def create_index( await client.indices.create( index=index_name, mappings=mappings, settings=settings ) - - return None diff --git a/elasticsearch/store/store.py b/elasticsearch/vectorstore/_async/vectorestore.py similarity index 74% rename from elasticsearch/store/store.py rename to elasticsearch/vectorstore/_async/vectorestore.py index da61ee4f4..ba9243e85 100644 --- a/elasticsearch/store/store.py +++ b/elasticsearch/vectorstore/_async/vectorestore.py @@ -1,24 +1,21 @@ -import asyncio import logging import uuid from typing import Any, Callable, Dict, List, Optional -import nest_asyncio # type: ignore from elasticsearch import AsyncElasticsearch from elasticsearch.helpers import BulkIndexError, async_bulk -from elasticsearch.store._utilities import ( - create_elasticsearch_client, +from 
elasticsearch.vectorstore._utils import ( maximal_marginal_relevance, ) -from elasticsearch.store.embedding_service import EmbeddingService -from elasticsearch.store.strategies import RetrievalStrategy +from elasticsearch.vectorstore._async.embedding_service import AsyncEmbeddingService +from elasticsearch.vectorstore._async.strategies import RetrievalStrategy logger = logging.getLogger(__name__) -class ElasticsearchStore: - """ElasticsearchStore is a higher-level abstraction of indexing and search. +class AsyncVectorStore: + """VectorStore is a higher-level abstraction of indexing and search. Users can pick from available retrieval strategies. Documents are flat text documents. Depending on the strategy, vector embeddings are @@ -29,24 +26,17 @@ class ElasticsearchStore: def __init__( self, - agent_header: str, + es_client: AsyncElasticsearch, + user_agent: str, index_name: str, retrieval_strategy: RetrievalStrategy, text_field: str = "text_field", vector_field: str = "vector_field", metadata_mapping: Optional[dict[str, str]] = None, - # Connection params - es_client: Optional[AsyncElasticsearch] = None, - es_url: Optional[str] = None, - es_cloud_id: Optional[str] = None, - es_api_key: Optional[str] = None, - es_user: Optional[str] = None, - es_password: Optional[str] = None, - es_params: Optional[Dict[str, Any]] = None, ) -> None: """ Args: - agent_header: user agent header specific to the 3rd party integration. + user_agent: user agent header specific to the 3rd party integration. Used for usage tracking in Elastic Cloud. index_name: The name of the index to query. retrieval_strategy: how to index and search the data. See the strategies @@ -57,34 +47,27 @@ def __init__( es_client: Elasticsearch client connection. Alternatively specify the Elasticsearch connection with the other es_* parameters. 
""" - nest_asyncio.apply() - - self.es_client = create_elasticsearch_client( - agent_header=agent_header, - client=es_client, - url=es_url, - cloud_id=es_cloud_id, - api_key=es_api_key, - username=es_user, - password=es_password, - client_params=es_params, - ) + + # Add integration-specific usage header for tracking usage in Elastic Cloud. + # client.options preserves existing (non-user-agent) headers. + es_client = es_client.options(headers={"User-Agent": user_agent}) if hasattr(retrieval_strategy, "text_field"): retrieval_strategy.text_field = text_field if hasattr(retrieval_strategy, "vector_field"): retrieval_strategy.vector_field = vector_field + self.es_client = es_client self.index_name = index_name self.retrieval_strategy = retrieval_strategy self.text_field = text_field self.vector_field = vector_field self.metadata_mapping = metadata_mapping - def close(self): - return asyncio.get_event_loop().run_until_complete(self.es_client.close()) + async def close(self): + return await self.es_client.close() - async def add_texts_async( + async def add_texts( self, texts: List[str], metadatas: Optional[List[Dict[str, Any]]] = None, @@ -135,7 +118,7 @@ async def add_texts_async( if vectors: request[self.vector_field] = vectors[i] - request.update(self.retrieval_strategy.embed_for_indexing(text)) + request.update(await self.retrieval_strategy.embed_for_indexing(text)) requests.append(request) if len(requests) > 0: @@ -147,10 +130,6 @@ async def add_texts_async( refresh=refresh_indices, **bulk_kwargs, ) - logger.debug( - f"Added {success} and failed to add {failed} texts to index" - ) - logger.debug(f"added texts {ids} to index") return ids except BulkIndexError as e: @@ -163,29 +142,7 @@ async def add_texts_async( logger.debug("No texts to add to index") return [] - def add_texts( - self, - texts: List[str], - metadatas: Optional[List[Dict[str, Any]]] = None, - vectors: Optional[List[List[float]]] = None, - ids: Optional[List[str]] = None, - refresh_indices: bool = 
True, - create_index_if_not_exists: bool = True, - bulk_kwargs: Optional[Dict[str, Any]] = None, - ) -> List[str]: - return asyncio.get_event_loop().run_until_complete( - self.add_texts_async( - texts=texts, - metadatas=metadatas, - vectors=vectors, - ids=ids, - refresh_indices=refresh_indices, - create_index_if_not_exists=create_index_if_not_exists, - bulk_kwargs=bulk_kwargs, - ) - ) - - async def delete_async( + async def delete( self, ids: Optional[List[str]] = None, query: Optional[Dict[str, Any]] = None, @@ -235,17 +192,7 @@ async def delete_async( return True - def delete( - self, - ids: Optional[List[str]] = None, - query: Optional[Dict[str, Any]] = None, - refresh_indices: bool = True, - ) -> bool: - return asyncio.get_event_loop().run_until_complete( - self.delete_async(ids=ids, query=query, refresh_indices=refresh_indices) - ) - - async def search_async( + async def search( self, query: Optional[str], query_vector: Optional[List[float]] = None, @@ -277,14 +224,13 @@ async def search_async( if self.text_field not in fields: fields.append(self.text_field) - query_body = self.retrieval_strategy.es_query( + query_body = await self.retrieval_strategy.es_query( query=query, k=k, num_candidates=num_candidates, filter=filter or [], query_vector=query_vector, ) - logger.debug(f"Query body: {query_body}") if custom_query is not None: query_body = custom_query(query_body, query) @@ -300,28 +246,6 @@ async def search_async( return response["hits"]["hits"] - def search( - self, - query: Optional[str], - query_vector: Optional[List[float]] = None, - k: int = 4, - num_candidates: int = 50, - fields: Optional[List[str]] = None, - filter: Optional[List[dict]] = None, - custom_query: Optional[Callable[[Dict, Optional[str]], Dict]] = None, - ) -> List[Dict[str, Any]]: - return asyncio.get_event_loop().run_until_complete( - self.search_async( - query=query, - query_vector=query_vector, - k=k, - num_candidates=num_candidates, - fields=fields, - filter=filter, - 
custom_query=custom_query, - ) - ) - async def _create_index_if_not_exists(self) -> None: exists = await self.es_client.indices.exists(index=self.index_name) if exists.meta.status == 200: @@ -333,9 +257,9 @@ async def _create_index_if_not_exists(self) -> None: metadata_mapping=self.metadata_mapping, ) - def max_marginal_relevance_search( + async def max_marginal_relevance_search( self, - embedding_service: EmbeddingService, + embedding_service: AsyncEmbeddingService, query: str, vector_field: str, k: int = 4, @@ -372,10 +296,10 @@ def max_marginal_relevance_search( remove_vector_query_field_from_metadata = False # Embed the query - query_embedding = embedding_service.embed_query(query) + query_embedding = await embedding_service.embed_query(query) # Fetch the initial documents - got_hits = self.search( + got_hits = await self.search( query=None, query_vector=query_embedding, k=num_candidates, diff --git a/elasticsearch/store/_utilities.py b/elasticsearch/vectorstore/_utils.py similarity index 50% rename from elasticsearch/store/_utilities.py rename to elasticsearch/vectorstore/_utils.py index f1f8d7110..b0e4f1372 100644 --- a/elasticsearch/store/_utilities.py +++ b/elasticsearch/vectorstore/_utils.py @@ -1,91 +1,10 @@ -from typing import Any, Dict, List, Optional, Union +from typing import List, Union import numpy as np -from elasticsearch import ( - AsyncElasticsearch, - BadRequestError, - ConflictError, - NotFoundError, -) Matrix = Union[List[List[float]], List[np.ndarray], np.ndarray] -def create_elasticsearch_client( - agent_header: str, - client: Optional[AsyncElasticsearch] = None, - url: Optional[str] = None, - cloud_id: Optional[str] = None, - api_key: Optional[str] = None, - username: Optional[str] = None, - password: Optional[str] = None, - client_params: Optional[Dict[str, Any]] = None, -) -> AsyncElasticsearch: - if not client: - if url and cloud_id: - raise ValueError( - "Both es_url and cloud_id are defined. Please provide only one." 
- ) - - connection_params: Dict[str, Any] = {} - - if url: - connection_params["hosts"] = [url] - elif cloud_id: - connection_params["cloud_id"] = cloud_id - else: - raise ValueError("Please provide either elasticsearch_url or cloud_id.") - - if api_key: - connection_params["api_key"] = api_key - elif username and password: - connection_params["basic_auth"] = (username, password) - - if client_params is not None: - connection_params.update(client_params) - - client = AsyncElasticsearch(**connection_params) - - if not isinstance(client, AsyncElasticsearch): - raise TypeError("Elasticsearch client must be AsyncElasticsearch client") - - # Add integration-specific usage header for tracking usage in Elastic Cloud. - # client.options preserves existing (non-user-agent) headers. - client = client.options(headers={"User-Agent": agent_header}) - - return client - - -async def model_must_be_deployed_async( - client: AsyncElasticsearch, model_id: str -) -> None: - try: - dummy = {"x": "y"} - await client.ml.infer_trained_model(model_id=model_id, docs=[dummy]) - except NotFoundError as err: - raise err - except ConflictError as err: - raise NotFoundError( - f"model '{model_id}' not found, please deploy it first", - meta=err.meta, - body=err.body, - ) from err - except BadRequestError: - # This error is expected because we do not know the expected document - # shape and just use a dummy doc above. 
- pass - - return None - - -async def model_is_deployed_async(es_client: AsyncElasticsearch, model_id: str) -> bool: - try: - await model_must_be_deployed_async(es_client, model_id) - return True - except NotFoundError: - return False - - def maximal_marginal_relevance( query_embedding: list, embedding_list: list, diff --git a/test_elasticsearch/test_store_integration/__init__.py b/test_elasticsearch/test_server/test_vectorstore/_async/__init__.py similarity index 100% rename from test_elasticsearch/test_store_integration/__init__.py rename to test_elasticsearch/test_server/test_vectorstore/_async/__init__.py diff --git a/test_elasticsearch/test_store_integration/_test_utilities.py b/test_elasticsearch/test_server/test_vectorstore/_async/_test_utils.py similarity index 68% rename from test_elasticsearch/test_store_integration/_test_utilities.py rename to test_elasticsearch/test_server/test_vectorstore/_async/_test_utils.py index e9158eed9..9a69c8b42 100644 --- a/test_elasticsearch/test_store_integration/_test_utilities.py +++ b/test_elasticsearch/test_server/test_vectorstore/_async/_test_utils.py @@ -1,34 +1,30 @@ -import asyncio import os -from typing import Any, Dict, List, Optional +import pytest_asyncio +from typing import Any, Dict, List, Optional, AsyncGenerator -import nest_asyncio # type: ignore from elastic_transport import AsyncTransport from elasticsearch import AsyncElasticsearch -from elasticsearch.store._utilities import model_is_deployed_async -from elasticsearch.store.embedding_service import EmbeddingService +from elasticsearch.vectorstore._async.embedding_service import AsyncEmbeddingService -class FakeEmbeddings(EmbeddingService): +class AsyncFakeEmbeddings(AsyncEmbeddingService): """Fake embeddings functionality for testing.""" def __init__(self, dimensionality: int = 10) -> None: - nest_asyncio.apply() - self.dimensionality = dimensionality def num_dimensions(self) -> int: return self.dimensionality - async def embed_documents_async(self, 
texts: List[str]) -> List[List[float]]: + async def embed_documents(self, texts: List[str]) -> List[List[float]]: """Return simple embeddings. Embeddings encode each text as its index.""" return [ [float(1.0)] * (self.dimensionality - 1) + [float(i)] for i in range(len(texts)) ] - async def embed_query_async(self, text: str) -> List[float]: + async def embed_query(self, text: str) -> List[float]: """Return constant query embeddings. Embeddings are identical to embed_documents(texts)[0]. Distance to each text will be that text's index, @@ -36,16 +32,8 @@ async def embed_query_async(self, text: str) -> List[float]: """ return [float(1.0)] * (self.dimensionality - 1) + [float(0.0)] - def embed_query(self, text: str) -> List[float]: - return asyncio.get_event_loop().run_until_complete(self.embed_query_async(text)) - - def embed_documents(self, texts: List[str]) -> List[List[float]]: - return asyncio.get_event_loop().run_until_complete( - self.embed_documents_async(texts) - ) - -class ConsistentFakeEmbeddings(FakeEmbeddings): +class AsyncConsistentFakeEmbeddings(AsyncFakeEmbeddings): """Fake embeddings which remember all the texts seen so far to return consistent vectors for the same texts.""" @@ -56,7 +44,7 @@ def __init__(self, dimensionality: int = 10) -> None: def num_dimensions(self) -> int: return self.dimensionality - def embed_documents(self, texts: List[str]) -> List[List[float]]: + async def embed_documents(self, texts: List[str]) -> List[List[float]]: """Return consistent embeddings for each text seen so far.""" out_vectors = [] for text in texts: @@ -68,13 +56,14 @@ def embed_documents(self, texts: List[str]) -> List[List[float]]: out_vectors.append(vector) return out_vectors - def embed_query(self, text: str) -> List[float]: + async def embed_query(self, text: str) -> List[float]: """Return consistent embeddings for the text, if seen before, or a constant one if the text is unknown.""" - return self.embed_documents([text])[0] + result = await 
self.embed_documents([text]) + return result[0] -class RequestSavingTransport(AsyncTransport): +class AsyncRequestSavingTransport(AsyncTransport): def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self.requests: List[Dict] = [] @@ -102,7 +91,33 @@ def create_es_client( def create_requests_saving_client() -> AsyncElasticsearch: - return create_es_client(es_kwargs={"transport_class": RequestSavingTransport}) + return create_es_client(es_kwargs={"transport_class": AsyncRequestSavingTransport}) + + +async def es_client_fixture() -> AsyncGenerator[AsyncElasticsearch, None]: + params = read_env() + client = create_es_client(params) + + yield client + + # clear indices + await clear_test_indices(client) + + # clear all test pipelines + try: + response = await client.ingest.get_pipeline(id="test_*,*_sparse_embedding") + + for pipeline_id, _ in response.items(): + try: + await client.ingest.delete_pipeline(id=pipeline_id) + print(f"Deleted pipeline: {pipeline_id}") # noqa: T201 + except Exception as e: + print(f"Pipeline error: {e}") # noqa: T201 + + except Exception: + pass + finally: + await client.close() async def clear_test_indices(client: AsyncElasticsearch) -> None: @@ -114,12 +129,6 @@ async def clear_test_indices(client: AsyncElasticsearch) -> None: await client.indices.refresh(index="_all") -def model_is_deployed(es_client: AsyncElasticsearch, model_id: str) -> bool: - return asyncio.get_event_loop().run_until_complete( - model_is_deployed_async(es_client, model_id) - ) - - def read_env() -> Dict: url = os.environ.get("ES_URL", "http://localhost:9200") cloud_id = os.environ.get("ES_CLOUD_ID") diff --git a/test_elasticsearch/test_server/test_vectorstore/_async/test_embedding_service.py b/test_elasticsearch/test_server/test_vectorstore/_async/test_embedding_service.py new file mode 100644 index 000000000..b349b724c --- /dev/null +++ b/test_elasticsearch/test_server/test_vectorstore/_async/test_embedding_service.py @@ -0,0 +1,62 
@@ +import os + +import pytest + +import pytest_asyncio +from elasticsearch import AsyncElasticsearch + +from typing import AsyncGenerator + +from elasticsearch.vectorstore._async._utils import async_model_is_deployed + +from ._test_utils import ( + es_client_fixture, +) + +from elasticsearch.vectorstore._async.embedding_service import ( + AsyncElasticsearchEmbeddings, +) + +# deployed with +# https://www.elastic.co/guide/en/machine-learning/current/ml-nlp-text-emb-vector-search-example.html +MODEL_ID = os.getenv("MODEL_ID", "sentence-transformers__msmarco-minilm-l-12-v3") +NUM_DIMENSIONS = int(os.getenv("NUM_DIMENTIONS", "384")) + + +@pytest_asyncio.fixture(autouse=True) +async def es_client() -> AsyncGenerator[AsyncElasticsearch, None]: + async for x in es_client_fixture(): + yield x + + +@pytest.mark.asyncio +async def test_elasticsearch_embedding_documents(es_client: AsyncElasticsearch) -> None: + """Test Elasticsearch embedding documents.""" + + if not await async_model_is_deployed(es_client, MODEL_ID): + pytest.skip(f"{MODEL_ID} model is not deployed in ML Node, skipping test") + + documents = ["foo bar", "bar foo", "foo"] + embedding = AsyncElasticsearchEmbeddings( + es_client=es_client, user_agent="test", model_id=MODEL_ID + ) + output = await embedding.embed_documents(documents) + assert len(output) == 3 + assert len(output[0]) == NUM_DIMENSIONS + assert len(output[1]) == NUM_DIMENSIONS + assert len(output[2]) == NUM_DIMENSIONS + + +@pytest.mark.asyncio +async def test_elasticsearch_embedding_query(es_client: AsyncElasticsearch) -> None: + """Test Elasticsearch embedding query.""" + + if not await async_model_is_deployed(es_client, MODEL_ID): + pytest.skip(f"{MODEL_ID} model is not deployed in ML Node, skipping test") + + document = "foo bar" + embedding = AsyncElasticsearchEmbeddings( + es_client=es_client, user_agent="test", model_id=MODEL_ID + ) + output = await embedding.embed_query(document) + assert len(output) == NUM_DIMENSIONS diff --git 
a/test_elasticsearch/test_store_integration/test_store.py b/test_elasticsearch/test_server/test_vectorstore/_async/test_vectorestore.py similarity index 75% rename from test_elasticsearch/test_store_integration/test_store.py rename to test_elasticsearch/test_server/test_vectorstore/_async/test_vectorestore.py index 552a61fb7..7dab019a6 100644 --- a/test_elasticsearch/test_store_integration/test_store.py +++ b/test_elasticsearch/test_server/test_vectorstore/_async/test_vectorestore.py @@ -1,15 +1,19 @@ import logging import uuid +from typing import AsyncGenerator +from typing import Any, List, Optional, Union, cast from functools import partial -from typing import Any, AsyncGenerator, List, Optional, Union, cast import pytest import pytest_asyncio -from elasticsearch import AsyncElasticsearch, NotFoundError +from elasticsearch import AsyncElasticsearch + +from elasticsearch import NotFoundError from elasticsearch.helpers import BulkIndexError -from elasticsearch.store.store import ElasticsearchStore -from elasticsearch.store.strategies import ( +from elasticsearch.vectorstore._async import AsyncVectorStore +from elasticsearch.vectorstore._async._utils import async_model_is_deployed +from elasticsearch.vectorstore._async.strategies import ( BM25, DenseVector, DenseVectorScriptScore, @@ -17,15 +21,12 @@ Semantic, ) -from ._test_utilities import ( - ConsistentFakeEmbeddings, - FakeEmbeddings, - RequestSavingTransport, - clear_test_indices, - create_es_client, +from ._test_utils import ( create_requests_saving_client, - model_is_deployed, - read_env, + es_client_fixture, + AsyncConsistentFakeEmbeddings, + AsyncFakeEmbeddings, + AsyncRequestSavingTransport, ) logging.basicConfig(level=logging.DEBUG) @@ -54,29 +55,8 @@ class TestElasticsearch: @pytest_asyncio.fixture(autouse=True) async def es_client(self) -> AsyncGenerator[AsyncElasticsearch, None]: - params = read_env() - client = create_es_client(params) - - yield client - - # clear indices - await 
clear_test_indices(client) - - # clear all test pipelines - try: - response = await client.ingest.get_pipeline(id="test_*,*_sparse_embedding") - - for pipeline_id, _ in response.items(): - try: - await client.ingest.delete_pipeline(id=pipeline_id) - print(f"Deleted pipeline: {pipeline_id}") # noqa: T201 - except Exception as e: - print(f"Pipeline error: {e}") # noqa: T201 - - except Exception: - pass - finally: - await client.close() + async for x in es_client_fixture(): + yield x @pytest_asyncio.fixture(autouse=True) async def requests_saving_client(self) -> AsyncGenerator[AsyncElasticsearch, None]: @@ -91,24 +71,6 @@ def index_name(self) -> str: """Return the index name.""" return f"test_{uuid.uuid4().hex}" - def test_initialize_from_params(self, index_name: str) -> None: - params = read_env() - agent_header = "test initialize from params" - store = ElasticsearchStore( - agent_header=agent_header, - index_name=index_name, - retrieval_strategy=BM25(), - **params, - ) - - assert store.es_client._headers["User-Agent"] == agent_header - - texts = ["foo", "bar", "baz"] - store.add_texts(texts) - - output = store.search("foo", k=1) - assert [doc["_source"]["text_field"] for doc in output] == ["foo"] - @pytest.mark.asyncio async def test_search_without_metadata( self, es_client: AsyncElasticsearch, index_name: str @@ -127,17 +89,17 @@ def assert_query(query_body: dict, query: Optional[str]) -> dict: } return query_body - store = ElasticsearchStore( - agent_header="test", + store = AsyncVectorStore( + user_agent="test", index_name=index_name, - retrieval_strategy=DenseVector(embedding_service=FakeEmbeddings()), + retrieval_strategy=DenseVector(embedding_service=AsyncFakeEmbeddings()), es_client=es_client, ) texts = ["foo", "bar", "baz"] - store.add_texts(texts) + await store.add_texts(texts) - output = store.search("foo", k=1, custom_query=assert_query) + output = await store.search("foo", k=1, custom_query=assert_query) assert [doc["_source"]["text_field"] for doc in 
output] == ["foo"] @pytest.mark.asyncio @@ -145,20 +107,23 @@ async def test_search_without_metadata_async( self, es_client: AsyncElasticsearch, index_name: str ) -> None: """Test end to end construction and search without metadata.""" - store = ElasticsearchStore( - agent_header="test", + store = AsyncVectorStore( + user_agent="test", index_name=index_name, - retrieval_strategy=DenseVector(embedding_service=FakeEmbeddings()), + retrieval_strategy=DenseVector(embedding_service=AsyncFakeEmbeddings()), es_client=es_client, ) texts = ["foo", "bar", "baz"] - store.add_texts(texts) + await store.add_texts(texts) - output = await store.search_async("foo", k=1) + output = await store.search("foo", k=1) assert [doc["_source"]["text_field"] for doc in output] == ["foo"] - def test_add_vectors(self, es_client: AsyncElasticsearch, index_name: str) -> None: + @pytest.mark.asyncio + async def test_add_vectors( + self, es_client: AsyncElasticsearch, index_name: str + ) -> None: """ Test adding pre-built embeddings instead of using inference for the texts. This allows you to separate the embeddings text and the page_content @@ -166,64 +131,68 @@ def test_add_vectors(self, es_client: AsyncElasticsearch, index_name: str) -> No For example, your embedding text can be a question, whereas page_content is the answer. 
""" - embeddings = ConsistentFakeEmbeddings() + embeddings = AsyncConsistentFakeEmbeddings() texts = ["foo1", "foo2", "foo3"] metadatas = [{"page": i} for i in range(len(texts))] """In real use case, embedding_input can be questions for each text""" - embedding_vectors = embeddings.embed_documents(texts) + embedding_vectors = await embeddings.embed_documents(texts) - store = ElasticsearchStore( - agent_header="test", + store = AsyncVectorStore( + user_agent="test", index_name=index_name, retrieval_strategy=DenseVector(embedding_service=embeddings), es_client=es_client, ) - store.add_texts(texts=texts, vectors=embedding_vectors, metadatas=metadatas) - output = store.search("foo1", k=1) + await store.add_texts( + texts=texts, vectors=embedding_vectors, metadatas=metadatas + ) + output = await store.search("foo1", k=1) assert [doc["_source"]["text_field"] for doc in output] == ["foo1"] assert [doc["_source"]["metadata"]["page"] for doc in output] == [0] - def test_search_with_metadata( + @pytest.mark.asyncio + async def test_search_with_metadata( self, es_client: AsyncElasticsearch, index_name: str ) -> None: """Test end to end construction and search with metadata.""" - store = ElasticsearchStore( - agent_header="test", + store = AsyncVectorStore( + user_agent="test", index_name=index_name, retrieval_strategy=DenseVector( - embedding_service=ConsistentFakeEmbeddings() + embedding_service=AsyncConsistentFakeEmbeddings() ), es_client=es_client, ) texts = ["foo", "bar", "baz"] metadatas = [{"page": i} for i in range(len(texts))] - store.add_texts(texts=texts, metadatas=metadatas) + await store.add_texts(texts=texts, metadatas=metadatas) - output = store.search("foo", k=1) + output = await store.search("foo", k=1) assert [doc["_source"]["text_field"] for doc in output] == ["foo"] assert [doc["_source"]["metadata"]["page"] for doc in output] == [0] - output = store.search("bar", k=1) + output = await store.search("bar", k=1) assert [doc["_source"]["text_field"] for doc in 
output] == ["bar"] assert [doc["_source"]["metadata"]["page"] for doc in output] == [1] - def test_search_with_filter( + @pytest.mark.asyncio + async def test_search_with_filter( self, es_client: AsyncElasticsearch, index_name: str ) -> None: """Test end to end construction and search with metadata.""" - store = ElasticsearchStore( - agent_header="test", + store = AsyncVectorStore( + user_agent="test", index_name=index_name, - retrieval_strategy=DenseVector(embedding_service=FakeEmbeddings()), + retrieval_strategy=DenseVector(embedding_service=AsyncFakeEmbeddings()), es_client=es_client, ) texts = ["foo", "foo", "foo"] metadatas = [{"page": i} for i in range(len(texts))] - store.add_texts(texts=texts, metadatas=metadatas) + await store.add_texts(texts=texts, metadatas=metadatas) def assert_query(query_body: dict, query: Optional[str]) -> dict: assert query_body == { @@ -237,7 +206,7 @@ def assert_query(query_body: dict, query: Optional[str]) -> dict: } return query_body - output = store.search( + output = await store.search( query="foo", k=3, filter=[{"term": {"metadata.page": "1"}}], @@ -246,21 +215,22 @@ def assert_query(query_body: dict, query: Optional[str]) -> dict: assert [doc["_source"]["text_field"] for doc in output] == ["foo"] assert [doc["_source"]["metadata"]["page"] for doc in output] == [1] - def test_search_script_score( + @pytest.mark.asyncio + async def test_search_script_score( self, es_client: AsyncElasticsearch, index_name: str ) -> None: """Test end to end construction and search with metadata.""" - store = ElasticsearchStore( - agent_header="test", + store = AsyncVectorStore( + user_agent="test", index_name=index_name, retrieval_strategy=DenseVectorScriptScore( - embedding_service=FakeEmbeddings() + embedding_service=AsyncFakeEmbeddings() ), es_client=es_client, ) texts = ["foo", "bar", "baz"] - store.add_texts(texts) + await store.add_texts(texts) expected_query = { "query": { @@ -291,25 +261,26 @@ def assert_query(query_body: dict, query: 
Optional[str]) -> dict: assert query_body == expected_query return query_body - output = store.search("foo", k=1, custom_query=assert_query) + output = await store.search("foo", k=1, custom_query=assert_query) assert [doc["_source"]["text_field"] for doc in output] == ["foo"] - def test_search_script_score_with_filter( + @pytest.mark.asyncio + async def test_search_script_score_with_filter( self, es_client: AsyncElasticsearch, index_name: str ) -> None: """Test end to end construction and search with metadata.""" - store = ElasticsearchStore( - agent_header="test", + store = AsyncVectorStore( + user_agent="test", index_name=index_name, retrieval_strategy=DenseVectorScriptScore( - embedding_service=FakeEmbeddings() + embedding_service=AsyncFakeEmbeddings() ), es_client=es_client, ) texts = ["foo", "bar", "baz"] metadatas = [{"page": i} for i in range(len(texts))] - store.add_texts(texts=texts, metadatas=metadatas) + await store.add_texts(texts=texts, metadatas=metadatas) def assert_query(query_body: dict, query: Optional[str]) -> dict: expected_query = { @@ -339,7 +310,7 @@ def assert_query(query_body: dict, query: Optional[str]) -> dict: assert query_body == expected_query return query_body - output = store.search( + output = await store.search( "foo", k=1, custom_query=assert_query, @@ -348,22 +319,23 @@ def assert_query(query_body: dict, query: Optional[str]) -> dict: assert [doc["_source"]["text_field"] for doc in output] == ["foo"] assert [doc["_source"]["metadata"]["page"] for doc in output] == [0] - def test_search_script_score_distance_dot_product( + @pytest.mark.asyncio + async def test_search_script_score_distance_dot_product( self, es_client: AsyncElasticsearch, index_name: str ) -> None: """Test end to end construction and search with metadata.""" - store = ElasticsearchStore( - agent_header="test", + store = AsyncVectorStore( + user_agent="test", index_name=index_name, retrieval_strategy=DenseVectorScriptScore( - embedding_service=FakeEmbeddings(), + 
embedding_service=AsyncFakeEmbeddings(), distance=DistanceMetric.DOT_PRODUCT, ), es_client=es_client, ) texts = ["foo", "bar", "baz"] - store.add_texts(texts) + await store.add_texts(texts) def assert_query(query_body: dict, query: Optional[str]) -> dict: assert query_body == { @@ -395,25 +367,26 @@ def assert_query(query_body: dict, query: Optional[str]) -> dict: } return query_body - output = store.search("foo", k=1, custom_query=assert_query) + output = await store.search("foo", k=1, custom_query=assert_query) assert [doc["_source"]["text_field"] for doc in output] == ["foo"] - def test_search_knn_with_hybrid_search( + @pytest.mark.asyncio + async def test_search_knn_with_hybrid_search( self, es_client: AsyncElasticsearch, index_name: str ) -> None: """Test end to end construction and search with metadata.""" - store = ElasticsearchStore( - agent_header="test", + store = AsyncVectorStore( + user_agent="test", index_name=index_name, retrieval_strategy=DenseVector( - embedding_service=FakeEmbeddings(), + embedding_service=AsyncFakeEmbeddings(), hybrid=True, ), es_client=es_client, ) texts = ["foo", "bar", "baz"] - store.add_texts(texts) + await store.add_texts(texts) def assert_query(query_body: dict, query: Optional[str]) -> dict: assert query_body == { @@ -434,7 +407,7 @@ def assert_query(query_body: dict, query: Optional[str]) -> dict: } return query_body - output = store.search("foo", k=1, custom_query=assert_query) + output = await store.search("foo", k=1, custom_query=assert_query) assert [doc["_source"]["text_field"] for doc in output] == ["foo"] @pytest.mark.asyncio @@ -492,20 +465,20 @@ def assert_query( {"rank_constant": 1, "window_size": 5}, ] for rrf_test_case in rrf_test_cases: - store = ElasticsearchStore( - agent_header="test", + store = AsyncVectorStore( + user_agent="test", index_name=index_name, retrieval_strategy=DenseVector( - embedding_service=FakeEmbeddings(), + embedding_service=AsyncFakeEmbeddings(), hybrid=True, rrf=rrf_test_case, ), 
es_client=es_client, ) - store.add_texts(texts) + await store.add_texts(texts) ## without fetch_k parameter - output = store.search( + output = await store.search( "foo", k=3, custom_query=partial(assert_query, expected_rrf=rrf_test_case), @@ -536,34 +509,35 @@ def assert_query( ] # 3. check rrf default option is okay - store = ElasticsearchStore( - agent_header="test", + store = AsyncVectorStore( + user_agent="test", index_name=f"{index_name}_default", retrieval_strategy=DenseVector( - embedding_service=FakeEmbeddings(), + embedding_service=AsyncFakeEmbeddings(), hybrid=True, ), es_client=es_client, ) - store.add_texts(texts) + await store.add_texts(texts) ## with fetch_k parameter - output = store.search( + output = await store.search( "foo", k=3, num_candidates=50, custom_query=partial(assert_query, expected_rrf={}), ) - def test_search_knn_with_custom_query_fn( + @pytest.mark.asyncio + async def test_search_knn_with_custom_query_fn( self, es_client: AsyncElasticsearch, index_name: str ) -> None: """test that custom query function is called with the query string and query body""" - store = ElasticsearchStore( - agent_header="test", + store = AsyncVectorStore( + user_agent="test", index_name=index_name, - retrieval_strategy=DenseVector(embedding_service=FakeEmbeddings()), + retrieval_strategy=DenseVector(embedding_service=AsyncFakeEmbeddings()), es_client=es_client, ) @@ -582,25 +556,26 @@ def my_custom_query(query_body: dict, query: Optional[str]) -> dict: """Test end to end construction and search with metadata.""" texts = ["foo", "bar", "baz"] - store.add_texts(texts) + await store.add_texts(texts) - output = store.search("foo", k=1, custom_query=my_custom_query) + output = await store.search("foo", k=1, custom_query=my_custom_query) assert [doc["_source"]["text_field"] for doc in output] == ["bar"] @pytest.mark.asyncio - @pytest.mark.skipif( - not model_is_deployed(create_es_client(), TRANSFORMER_MODEL_ID), - reason=f"{TRANSFORMER_MODEL_ID} model not deployed 
in ML Node, " - "skipping test", - ) async def test_search_with_knn_infer_instack( self, es_client: AsyncElasticsearch, index_name: str ) -> None: """test end to end with knn retrieval strategy and inference in-stack""" + + if not await async_model_is_deployed(es_client, TRANSFORMER_MODEL_ID): + pytest.skip( + f"{TRANSFORMER_MODEL_ID} model not deployed in ML Node skipping test" + ) + text_field = "text_field" - store = ElasticsearchStore( - agent_header="test", + store = AsyncVectorStore( + user_agent="test", index_name=index_name, retrieval_strategy=Semantic( model_id="sentence-transformers__all-minilm-l6-v2", @@ -674,57 +649,64 @@ def assert_query(query_body: dict, query: Optional[str]) -> dict: } return query_body - output = store.search("foo", k=1, custom_query=assert_query) + output = await store.search("foo", k=1, custom_query=assert_query) assert [doc["_source"]["text_field"] for doc in output] == ["foo"] - output = store.search("bar", k=1) + output = await store.search("bar", k=1) assert [doc["_source"]["text_field"] for doc in output] == ["bar"] - @pytest.mark.skipif( - not model_is_deployed(create_es_client(), ELSER_MODEL_ID), - reason=f"{ELSER_MODEL_ID} model not deployed in ML Node, skipping test", - ) - def test_search_with_sparse_infer_instack( + @pytest.mark.asyncio + async def test_search_with_sparse_infer_instack( self, es_client: AsyncElasticsearch, index_name: str ) -> None: """test end to end with sparse retrieval strategy and inference in-stack""" - store = ElasticsearchStore( - agent_header="test", + + if not await async_model_is_deployed(es_client, ELSER_MODEL_ID): + reason = f"{ELSER_MODEL_ID} model not deployed in ML Node, skipping test" + + pytest.skip(reason) + + store = AsyncVectorStore( + user_agent="test", index_name=index_name, retrieval_strategy=Semantic(model_id=ELSER_MODEL_ID), es_client=es_client, ) texts = ["foo", "bar", "baz"] - store.add_texts(texts) + await store.add_texts(texts) - output = store.search("foo", k=1) + output = 
await store.search("foo", k=1) assert [doc["_source"]["text_field"] for doc in output] == ["foo"] - def test_deployed_model_check_fails_semantic( + @pytest.mark.asyncio + async def test_deployed_model_check_fails_semantic( self, es_client: AsyncElasticsearch, index_name: str ) -> None: """test that exceptions are raised if a specified model is not deployed""" with pytest.raises(NotFoundError): - store = ElasticsearchStore( - agent_header="test", + store = AsyncVectorStore( + user_agent="test", index_name=index_name, retrieval_strategy=Semantic(model_id="non-existing model ID"), es_client=es_client, ) - store.add_texts(["foo", "bar", "baz"]) + await store.add_texts(["foo", "bar", "baz"]) - def test_search_bm25(self, es_client: AsyncElasticsearch, index_name: str) -> None: + @pytest.mark.asyncio + async def test_search_bm25( + self, es_client: AsyncElasticsearch, index_name: str + ) -> None: """Test end to end using the BM25 retrieval strategy.""" - store = ElasticsearchStore( - agent_header="test", + store = AsyncVectorStore( + user_agent="test", index_name=index_name, retrieval_strategy=BM25(), es_client=es_client, ) texts = ["foo", "bar", "baz"] - store.add_texts(texts) + await store.add_texts(texts) def assert_query(query_body: dict, query: Optional[str]) -> dict: assert query_body == { @@ -737,15 +719,16 @@ def assert_query(query_body: dict, query: Optional[str]) -> dict: } return query_body - output = store.search("foo", k=1, custom_query=assert_query) + output = await store.search("foo", k=1, custom_query=assert_query) assert [doc["_source"]["text_field"] for doc in output] == ["foo"] - def test_search_bm25_with_filter( + @pytest.mark.asyncio + async def test_search_bm25_with_filter( self, es_client: AsyncElasticsearch, index_name: str ) -> None: """Test end to using the BM25 retrieval strategy with metadata.""" - store = ElasticsearchStore( - agent_header="test", + store = AsyncVectorStore( + user_agent="test", index_name=index_name, 
retrieval_strategy=BM25(), es_client=es_client, @@ -753,7 +736,7 @@ def test_search_bm25_with_filter( texts = ["foo", "foo", "foo"] metadatas = [{"page": i} for i in range(len(texts))] - store.add_texts(texts=texts, metadatas=metadatas) + await store.add_texts(texts=texts, metadatas=metadatas) def assert_query(query_body: dict, query: Optional[str]) -> dict: assert query_body == { @@ -766,7 +749,7 @@ def assert_query(query_body: dict, query: Optional[str]) -> dict: } return query_body - output = store.search( + output = await store.search( "foo", k=3, custom_query=assert_query, @@ -775,36 +758,37 @@ def assert_query(query_body: dict, query: Optional[str]) -> dict: assert [doc["_source"]["text_field"] for doc in output] == ["foo"] assert [doc["_source"]["metadata"]["page"] for doc in output] == [1] - def test_delete(self, es_client: AsyncElasticsearch, index_name: str) -> None: + @pytest.mark.asyncio + async def test_delete(self, es_client: AsyncElasticsearch, index_name: str) -> None: """Test delete methods from vector store.""" - store = ElasticsearchStore( - agent_header="test", + store = AsyncVectorStore( + user_agent="test", index_name=index_name, - retrieval_strategy=DenseVector(embedding_service=FakeEmbeddings()), + retrieval_strategy=DenseVector(embedding_service=AsyncFakeEmbeddings()), es_client=es_client, ) texts = ["foo", "bar", "baz", "gni"] metadatas = [{"page": i} for i in range(len(texts))] - ids = store.add_texts(texts=texts, metadatas=metadatas) + ids = await store.add_texts(texts=texts, metadatas=metadatas) - output = store.search("foo", k=10) + output = await store.search("foo", k=10) assert len(output) == 4 - store.delete(ids[1:3]) - output = store.search("foo", k=10) + await store.delete(ids[1:3]) + output = await store.search("foo", k=10) assert len(output) == 2 - store.delete(["not-existing"]) - output = store.search("foo", k=10) + await store.delete(["not-existing"]) + output = await store.search("foo", k=10) assert len(output) == 2 - 
store.delete([ids[0]]) - output = store.search("foo", k=10) + await store.delete([ids[0]]) + output = await store.search("foo", k=10) assert len(output) == 1 - store.delete([ids[3]]) - output = store.search("gni", k=10) + await store.delete([ids[3]]) + output = await store.search("gni", k=10) assert len(output) == 0 @pytest.mark.asyncio @@ -815,8 +799,8 @@ async def test_indexing_exception_error( caplog: pytest.LogCaptureFixture, ) -> None: """Test bulk exception logging is giving better hints.""" - store = ElasticsearchStore( - agent_header="test", + store = AsyncVectorStore( + user_agent="test", index_name=index_name, retrieval_strategy=BM25(), es_client=es_client, @@ -831,60 +815,65 @@ async def test_indexing_exception_error( texts = ["foo"] with pytest.raises(BulkIndexError): - store.add_texts(texts) + await store.add_texts(texts) error_reason = "pipeline with id [not-existing-pipeline] does not exist" log_message = f"First error reason: {error_reason}" assert log_message in caplog.text - def test_user_agent( + @pytest.mark.asyncio + async def test_user_agent( self, requests_saving_client: AsyncElasticsearch, index_name: str ) -> None: """Test to make sure the user-agent is set correctly.""" - agent_header = "this is THE agent_header!" - store = ElasticsearchStore( - agent_header=agent_header, + user_agent = "this is THE user_agent!" 
+ store = AsyncVectorStore( + user_agent=user_agent, index_name=index_name, retrieval_strategy=BM25(), es_client=requests_saving_client, ) - assert store.es_client._headers["User-Agent"] == agent_header + assert store.es_client._headers["User-Agent"] == user_agent texts = ["foo", "bob", "baz"] - store.add_texts(texts) + await store.add_texts(texts) - transport = cast(RequestSavingTransport, store.es_client.transport) + transport = cast(AsyncRequestSavingTransport, store.es_client.transport) for request in transport.requests: - assert request["headers"]["User-Agent"] == agent_header + assert request["headers"]["User-Agent"] == user_agent - def test_bulk_args(self, requests_saving_client: Any, index_name: str) -> None: + @pytest.mark.asyncio + async def test_bulk_args( + self, requests_saving_client: Any, index_name: str + ) -> None: """Test to make sure the bulk arguments work as expected.""" - store = ElasticsearchStore( - agent_header="test", + store = AsyncVectorStore( + user_agent="test", index_name=index_name, retrieval_strategy=BM25(), es_client=requests_saving_client, ) texts = ["foo", "bob", "baz"] - store.add_texts(texts, bulk_kwargs={"chunk_size": 1}) + await store.add_texts(texts, bulk_kwargs={"chunk_size": 1}) # 1 for index exist, 1 for index create, 3 to index docs assert len(store.es_client.transport.requests) == 5 # type: ignore - def test_max_marginal_relevance_search( + @pytest.mark.asyncio + async def test_max_marginal_relevance_search( self, es_client: AsyncElasticsearch, index_name: str ) -> None: """Test max marginal relevance search.""" texts = ["foo", "bar", "baz"] vector_field = "vector_field" text_field = "text_field" - embedding_service = ConsistentFakeEmbeddings() - store = ElasticsearchStore( - agent_header="test", + embedding_service = AsyncConsistentFakeEmbeddings() + store = AsyncVectorStore( + user_agent="test", index_name=index_name, retrieval_strategy=DenseVectorScriptScore( embedding_service=embedding_service @@ -893,19 +882,19 @@ 
def test_max_marginal_relevance_search( text_field=text_field, es_client=es_client, ) - store.add_texts(texts) + await store.add_texts(texts) - mmr_output = store.max_marginal_relevance_search( + mmr_output = await store.max_marginal_relevance_search( embedding_service, texts[0], vector_field=vector_field, k=3, num_candidates=3, ) - sim_output = store.search(texts[0], k=3) + sim_output = await store.search(texts[0], k=3) assert mmr_output == sim_output - mmr_output = store.max_marginal_relevance_search( + mmr_output = await store.max_marginal_relevance_search( embedding_service, texts[0], vector_field=vector_field, @@ -916,7 +905,7 @@ def test_max_marginal_relevance_search( assert mmr_output[0]["_source"][text_field] == texts[0] assert mmr_output[1]["_source"][text_field] == texts[1] - mmr_output = store.max_marginal_relevance_search( + mmr_output = await store.max_marginal_relevance_search( embedding_service, texts[0], vector_field=vector_field, @@ -929,7 +918,7 @@ def test_max_marginal_relevance_search( assert mmr_output[1]["_source"][text_field] == texts[2] # if fetch_k < k, then the output will be less than k - mmr_output = store.max_marginal_relevance_search( + mmr_output = await store.max_marginal_relevance_search( embedding_service, texts[0], vector_field=vector_field, diff --git a/test_elasticsearch/test_store_integration/docker-compose.yml b/test_elasticsearch/test_server/test_vectorstore/docker-compose.yml similarity index 100% rename from test_elasticsearch/test_store_integration/docker-compose.yml rename to test_elasticsearch/test_server/test_vectorstore/docker-compose.yml diff --git a/test_elasticsearch/test_server/test_vectorstore/test_vectorestore copy.py1 b/test_elasticsearch/test_server/test_vectorstore/test_vectorestore copy.py1 new file mode 100644 index 000000000..c432582d5 --- /dev/null +++ b/test_elasticsearch/test_server/test_vectorstore/test_vectorestore copy.py1 @@ -0,0 +1,986 @@ +import logging +import uuid +from typing import 
AsyncGenerator + +import pytest +import pytest_asyncio +from elasticsearch_serverless import AsyncElasticsearch +from elasticsearch_serverless.helpers import async_bulk + + +from ._test_utilities import ( + create_es_client, + create_requests_saving_client, + clear_test_indices, + read_env, +) + +logging.basicConfig(level=logging.DEBUG) + +""" +docker-compose up elasticsearch + +By default runs against local docker instance of Elasticsearch. +To run against Elastic Cloud, set the following environment variables: +- ES_CLOUD_ID +- ES_API_KEY + +Some of the tests require the following models to be deployed in the ML Node: +- elser (can be downloaded and deployed through Kibana and trained models UI) +- sentence-transformers__all-minilm-l6-v2 (can be deployed through the API, + loaded via eland) + +These tests that require the models to be deployed are skipped by default. +Enable them by adding the model name to the modelsDeployed list below. +""" + +ELSER_MODEL_ID = ".elser_model_2" +TRANSFORMER_MODEL_ID = "sentence-transformers__all-minilm-l6-v2" + + +class TestElasticsearch: + @pytest_asyncio.fixture(autouse=True) + async def es_client(self) -> AsyncGenerator[AsyncElasticsearch, None]: + params = read_env() + client = create_es_client(params) + + yield client + + # clear indices + if False: + await clear_test_indices(client) + + # clear all test pipelines + try: + response = await client.ingest.get_pipeline(id="test_*,*_sparse_embedding") + + for pipeline_id, _ in response.items(): + try: + await client.ingest.delete_pipeline(id=pipeline_id) + print(f"Deleted pipeline: {pipeline_id}") # noqa: T201 + except Exception as e: + print(f"Pipeline error: {e}") # noqa: T201 + + except Exception: + pass + finally: + await client.close() + + @pytest_asyncio.fixture(autouse=True) + async def requests_saving_client(self) -> AsyncGenerator[AsyncElasticsearch, None]: + client = create_requests_saving_client() + try: + yield client + finally: + await client.close() + + 
@pytest.fixture(scope="function") + def index_name(self) -> str: + """Return the index name.""" + return f"test_{uuid.uuid4().hex}" + + # def test_initialize_from_params(self, index_name: str) -> None: + # params = read_env() + # agent_header = "test initialize from params" + # store = VectorStore( + # agent_header=agent_header, + # index_name=index_name, + # retrieval_strategy=BM25(), + # **params, + # ) + + # assert store.es_client._headers["User-Agent"] == agent_header + + # texts = ["foo", "bar", "baz"] + # store.add_texts(texts) + + # output = store.search("foo", k=1) + # assert [doc["_source"]["text_field"] for doc in output] == ["foo"] + + # @pytest.mark.asyncio + # async def test_search_without_metadata( + # self, es_client: AsyncElasticsearch, index_name: str + # ) -> None: + # """Test end to end construction and search without metadata.""" + + # def assert_query(query_body: dict, query: Optional[str]) -> dict: + # assert query_body == { + # "knn": { + # "field": "vector_field", + # "filter": [], + # "k": 1, + # "num_candidates": 50, + # "query_vector": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + # } + # } + # return query_body + + # store = VectorStore( + # agent_header="test", + # index_name=index_name, + # retrieval_strategy=DenseVector(embedding_service=FakeEmbeddings()), + # es_client=es_client, + # ) + + # texts = ["foo", "bar", "baz"] + # store.add_texts(texts) + + # output = store.search("foo", k=1, custom_query=assert_query) + # assert [doc["_source"]["text_field"] for doc in output] == ["foo"] + + @pytest.mark.asyncio + async def test_add_vectors( + self, es_client: AsyncElasticsearch, index_name: str + ) -> None: + """ + Test adding pre-built embeddings instead of using inference for the texts. + This allows you to separate the embeddings text and the page_content + for better proximity between user's question and embedded text. + For example, your embedding text can be a question, whereas page_content + is the answer. 
+ """ + docs = [ + ("foo1", [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0]), + ("foo2", [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]), + ("foo3", [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 2.0]), + ] + + texts = [t for t, _ in docs] + # embeddings = ConsistentFakeEmbeddings() + # texts = ["foo1", "foo2", "foo3"] + metadatas = [{"page": i} for i in range(len(texts))] + + """In real use case, embedding_input can be questions for each text""" + # embedding_vectors = embeddings.embed_documents(texts) + + embedding_vectors = [e for _, e in docs] + + index_name = "test_2" + + # store = VectorStore( + # agent_header="test", + # index_name=index_name, + # retrieval_strategy=DenseVector(num_dimensions=10), + # es_client=es_client, + # ) + + # store.add_texts(texts=texts, vectors=embedding_vectors, metadatas=metadatas) + + #################### + + mappings = { + "properties": { + "vector_field": { + "type": "dense_vector", + "dims": len(docs[0][1]), + "index": True, + "similarity": "cosine", + }, + } + } + await es_client.indices.create(index=index_name, mappings=mappings) + + indexing_requests = [ + { + "_op_type": "index", + "_index": index_name, + "_id": str(uuid.uuid4()), + "text_field": doc_id, + "metadata": metadatas[i], + "vector_field": vector, + } + for i, (doc_id, vector) in enumerate(docs) + ] + await async_bulk(es_client, indexing_requests, refresh=True) + + #################### + + # query_vector = embedding_vectors[0] + query_vector = docs[0][1] + assert query_vector == [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0] + query_body = { + "knn": { + "filter": [], + "field": "vector_field", + "k": 1, + "num_candidates": 50, + "query_vector": query_vector, + } + } + print(query_body) + + # if custom_query is not None: + # query_body = custom_query(query_body, query) + # logger.debug(f"Calling custom_query, Query body now: {query_body}") + + output = await es_client.search( + index=index_name, + **query_body, + size=1, + source=True, + ) + output = 
output["hits"]["hits"] + + # output = store.search(query=None, query_vector=query_vector, k=1) + + # print("\n".join([str(v) for v in embedding_vectors])) + print(query_vector) + print(output) + print([doc["_score"] for doc in output]) + + assert [doc["_source"]["text_field"] for doc in output] == ["foo1"] + assert [doc["_score"] for doc in output] == [1.0] + # assert [doc["_source"]["metadata"]["page"] for doc in output] == [0] + + # def test_search_with_metadata( + # self, es_client: AsyncElasticsearch, index_name: str + # ) -> None: + # """Test end to end construction and search with metadata.""" + # store = VectorStore( + # agent_header="test", + # index_name=index_name, + # retrieval_strategy=DenseVector( + # embedding_service=ConsistentFakeEmbeddings() + # ), + # es_client=es_client, + # ) + + # texts = ["foo", "bar", "baz"] + # metadatas = [{"page": i} for i in range(len(texts))] + # store.add_texts(texts=texts, metadatas=metadatas) + + # output = store.search("foo", k=1) + # assert [doc["_source"]["text_field"] for doc in output] == ["foo"] + # assert [doc["_source"]["metadata"]["page"] for doc in output] == [0] + + # output = store.search("bar", k=1) + # assert [doc["_source"]["text_field"] for doc in output] == ["bar"] + # assert [doc["_source"]["metadata"]["page"] for doc in output] == [1] + + # def test_search_with_filter( + # self, es_client: AsyncElasticsearch, index_name: str + # ) -> None: + # """Test end to end construction and search with metadata.""" + # store = VectorStore( + # agent_header="test", + # index_name=index_name, + # retrieval_strategy=DenseVector(embedding_service=FakeEmbeddings()), + # es_client=es_client, + # ) + + # texts = ["foo", "foo", "foo"] + # metadatas = [{"page": i} for i in range(len(texts))] + # store.add_texts(texts=texts, metadatas=metadatas) + + # def assert_query(query_body: dict, query: Optional[str]) -> dict: + # assert query_body == { + # "knn": { + # "field": "vector_field", + # "filter": [{"term": 
{"metadata.page": "1"}}], + # "k": 3, + # "num_candidates": 50, + # "query_vector": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + # } + # } + # return query_body + + # output = store.search( + # query="foo", + # k=3, + # filter=[{"term": {"metadata.page": "1"}}], + # custom_query=assert_query, + # ) + # assert [doc["_source"]["text_field"] for doc in output] == ["foo"] + # assert [doc["_source"]["metadata"]["page"] for doc in output] == [1] + + # def test_search_script_score( + # self, es_client: AsyncElasticsearch, index_name: str + # ) -> None: + # """Test end to end construction and search with metadata.""" + # store = VectorStore( + # agent_header="test", + # index_name=index_name, + # retrieval_strategy=DenseVectorScriptScore( + # embedding_service=FakeEmbeddings() + # ), + # es_client=es_client, + # ) + + # texts = ["foo", "bar", "baz"] + # store.add_texts(texts) + + # expected_query = { + # "query": { + # "script_score": { + # "query": {"match_all": {}}, + # "script": { + # "source": "cosineSimilarity(params.query_vector, 'vector_field') + 1.0", # noqa: E501 + # "params": { + # "query_vector": [ + # 1.0, + # 1.0, + # 1.0, + # 1.0, + # 1.0, + # 1.0, + # 1.0, + # 1.0, + # 1.0, + # 0.0, + # ] + # }, + # }, + # } + # } + # } + + # def assert_query(query_body: dict, query: Optional[str]) -> dict: + # assert query_body == expected_query + # return query_body + + # output = store.search("foo", k=1, custom_query=assert_query) + # assert [doc["_source"]["text_field"] for doc in output] == ["foo"] + + # def test_search_script_score_with_filter( + # self, es_client: AsyncElasticsearch, index_name: str + # ) -> None: + # """Test end to end construction and search with metadata.""" + # store = VectorStore( + # agent_header="test", + # index_name=index_name, + # retrieval_strategy=DenseVectorScriptScore( + # embedding_service=FakeEmbeddings() + # ), + # es_client=es_client, + # ) + + # texts = ["foo", "bar", "baz"] + # metadatas = [{"page": i} for i in 
range(len(texts))] + # store.add_texts(texts=texts, metadatas=metadatas) + + # def assert_query(query_body: dict, query: Optional[str]) -> dict: + # expected_query = { + # "query": { + # "script_score": { + # "query": {"bool": {"filter": [{"term": {"metadata.page": 0}}]}}, + # "script": { + # "source": "cosineSimilarity(params.query_vector, 'vector_field') + 1.0", # noqa: E501 + # "params": { + # "query_vector": [ + # 1.0, + # 1.0, + # 1.0, + # 1.0, + # 1.0, + # 1.0, + # 1.0, + # 1.0, + # 1.0, + # 0.0, + # ] + # }, + # }, + # } + # } + # } + # assert query_body == expected_query + # return query_body + + # output = store.search( + # "foo", + # k=1, + # custom_query=assert_query, + # filter=[{"term": {"metadata.page": 0}}], + # ) + # assert [doc["_source"]["text_field"] for doc in output] == ["foo"] + # assert [doc["_source"]["metadata"]["page"] for doc in output] == [0] + + # def test_search_script_score_distance_dot_product( + # self, es_client: AsyncElasticsearch, index_name: str + # ) -> None: + # """Test end to end construction and search with metadata.""" + # store = VectorStore( + # agent_header="test", + # index_name=index_name, + # retrieval_strategy=DenseVectorScriptScore( + # embedding_service=FakeEmbeddings(), + # distance=DistanceMetric.DOT_PRODUCT, + # ), + # es_client=es_client, + # ) + + # texts = ["foo", "bar", "baz"] + # store.add_texts(texts) + + # def assert_query(query_body: dict, query: Optional[str]) -> dict: + # assert query_body == { + # "query": { + # "script_score": { + # "query": {"match_all": {}}, + # "script": { + # "source": """ + # double value = dotProduct(params.query_vector, 'vector_field'); + # return sigmoid(1, Math.E, -value); + # """, + # "params": { + # "query_vector": [ + # 1.0, + # 1.0, + # 1.0, + # 1.0, + # 1.0, + # 1.0, + # 1.0, + # 1.0, + # 1.0, + # 0.0, + # ] + # }, + # }, + # } + # } + # } + # return query_body + + # output = store.search("foo", k=1, custom_query=assert_query) + # assert [doc["_source"]["text_field"] 
for doc in output] == ["foo"] + + # def test_search_knn_with_hybrid_search( + # self, es_client: AsyncElasticsearch, index_name: str + # ) -> None: + # """Test end to end construction and search with metadata.""" + # store = VectorStore( + # agent_header="test", + # index_name=index_name, + # retrieval_strategy=DenseVector( + # embedding_service=FakeEmbeddings(), + # hybrid=True, + # ), + # es_client=es_client, + # ) + + # texts = ["foo", "bar", "baz"] + # store.add_texts(texts) + + # def assert_query(query_body: dict, query: Optional[str]) -> dict: + # assert query_body == { + # "knn": { + # "field": "vector_field", + # "filter": [], + # "k": 1, + # "num_candidates": 50, + # "query_vector": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + # }, + # "query": { + # "bool": { + # "filter": [], + # "must": [{"match": {"text_field": {"query": "foo"}}}], + # } + # }, + # "rank": {"rrf": {}}, + # } + # return query_body + + # output = store.search("foo", k=1, custom_query=assert_query) + # assert [doc["_source"]["text_field"] for doc in output] == ["foo"] + + # @pytest.mark.asyncio + # async def test_search_knn_with_hybrid_search_rrf( + # self, es_client: AsyncElasticsearch, index_name: str + # ) -> None: + # """Test end to end construction and rrf hybrid search with metadata.""" + # texts = ["foo", "bar", "baz"] + + # def assert_query( + # query_body: dict, + # query: Optional[str], + # expected_rrf: Union[dict, bool], + # ) -> dict: + # cmp_query_body = { + # "knn": { + # "field": "vector_field", + # "filter": [], + # "k": 3, + # "num_candidates": 50, + # "query_vector": [ + # 1.0, + # 1.0, + # 1.0, + # 1.0, + # 1.0, + # 1.0, + # 1.0, + # 1.0, + # 1.0, + # 0.0, + # ], + # }, + # "query": { + # "bool": { + # "filter": [], + # "must": [{"match": {"text_field": {"query": "foo"}}}], + # } + # }, + # } + + # if isinstance(expected_rrf, dict): + # cmp_query_body["rank"] = {"rrf": expected_rrf} + # elif isinstance(expected_rrf, bool) and expected_rrf is True: + # 
cmp_query_body["rank"] = {"rrf": {}} + + # assert query_body == cmp_query_body + + # return query_body + + # # 1. check query_body is okay + # rrf_test_cases: List[Union[dict, bool]] = [ + # True, + # False, + # {"rank_constant": 1, "window_size": 5}, + # ] + # for rrf_test_case in rrf_test_cases: + # store = VectorStore( + # agent_header="test", + # index_name=index_name, + # retrieval_strategy=DenseVector( + # embedding_service=FakeEmbeddings(), + # hybrid=True, + # rrf=rrf_test_case, + # ), + # es_client=es_client, + # ) + # store.add_texts(texts) + + # ## without fetch_k parameter + # output = store.search( + # "foo", + # k=3, + # custom_query=partial(assert_query, expected_rrf=rrf_test_case), + # ) + + # # 2. check query result is okay + # es_output = await store.es_client.search( + # index=index_name, + # query={ + # "bool": { + # "filter": [], + # "must": [{"match": {"text_field": {"query": "foo"}}}], + # } + # }, + # knn={ + # "field": "vector_field", + # "filter": [], + # "k": 3, + # "num_candidates": 50, + # "query_vector": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + # }, + # size=3, + # rank={"rrf": {"rank_constant": 1, "window_size": 5}}, + # ) + + # assert [o["_source"]["text_field"] for o in output] == [ + # e["_source"]["text_field"] for e in es_output["hits"]["hits"] + # ] + + # # 3. 
check rrf default option is okay + # store = VectorStore( + # agent_header="test", + # index_name=f"{index_name}_default", + # retrieval_strategy=DenseVector( + # embedding_service=FakeEmbeddings(), + # hybrid=True, + # ), + # es_client=es_client, + # ) + # store.add_texts(texts) + + # ## with fetch_k parameter + # output = store.search( + # "foo", + # k=3, + # num_candidates=50, + # custom_query=partial(assert_query, expected_rrf={}), + # ) + + # def test_search_knn_with_custom_query_fn( + # self, es_client: AsyncElasticsearch, index_name: str + # ) -> None: + # """test that custom query function is called + # with the query string and query body""" + # store = VectorStore( + # agent_header="test", + # index_name=index_name, + # retrieval_strategy=DenseVector(embedding_service=FakeEmbeddings()), + # es_client=es_client, + # ) + + # def my_custom_query(query_body: dict, query: Optional[str]) -> dict: + # assert query == "foo" + # assert query_body == { + # "knn": { + # "field": "vector_field", + # "filter": [], + # "k": 1, + # "num_candidates": 50, + # "query_vector": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + # } + # } + # return {"query": {"match": {"text_field": {"query": "bar"}}}} + + # """Test end to end construction and search with metadata.""" + # texts = ["foo", "bar", "baz"] + # store.add_texts(texts) + + # output = store.search("foo", k=1, custom_query=my_custom_query) + # assert [doc["_source"]["text_field"] for doc in output] == ["bar"] + + # @pytest.mark.asyncio + # @pytest.mark.skipif( + # not model_is_deployed(create_es_client(), TRANSFORMER_MODEL_ID), + # reason=f"{TRANSFORMER_MODEL_ID} model not deployed in ML Node, " + # "skipping test", + # ) + # async def test_search_with_knn_infer_instack( + # self, es_client: AsyncElasticsearch, index_name: str + # ) -> None: + # """test end to end with knn retrieval strategy and inference in-stack""" + # text_field = "text_field" + + # store = VectorStore( + # agent_header="test", + # 
index_name=index_name, + # retrieval_strategy=Semantic( + # model_id="sentence-transformers__all-minilm-l6-v2", + # text_field=text_field, + # ), + # es_client=es_client, + # ) + + # # setting up the pipeline for inference + # await store.es_client.ingest.put_pipeline( + # id="test_pipeline", + # processors=[ + # { + # "inference": { + # "model_id": TRANSFORMER_MODEL_ID, + # "field_map": {"query_field": text_field}, + # "target_field": "vector_query_field", + # } + # } + # ], + # ) + + # # creating a new index with the pipeline, + # # not relying on langchain to create the index + # await store.es_client.indices.create( + # index=index_name, + # mappings={ + # "properties": { + # text_field: {"type": "text_field"}, + # "vector_query_field": { + # "properties": { + # "predicted_value": { + # "type": "dense_vector", + # "dims": 384, + # "index": True, + # "similarity": "l2_norm", + # } + # } + # }, + # } + # }, + # settings={"index": {"default_pipeline": "test_pipeline"}}, + # ) + + # # adding documents to the index + # texts = ["foo", "bar", "baz"] + + # for i, text in enumerate(texts): + # await store.es_client.create( + # index=index_name, + # id=str(i), + # document={text_field: text, "metadata": {}}, + # ) + + # await store.es_client.indices.refresh(index=index_name) + + # def assert_query(query_body: dict, query: Optional[str]) -> dict: + # assert query_body == { + # "knn": { + # "filter": [], + # "field": "vector_query_field.predicted_value", + # "k": 1, + # "num_candidates": 50, + # "query_vector_builder": { + # "text_embedding": { + # "model_id": TRANSFORMER_MODEL_ID, + # "model_text": "foo", + # } + # }, + # } + # } + # return query_body + + # output = store.search("foo", k=1, custom_query=assert_query) + # assert [doc["_source"]["text_field"] for doc in output] == ["foo"] + + # output = store.search("bar", k=1) + # assert [doc["_source"]["text_field"] for doc in output] == ["bar"] + + # @pytest.mark.skipif( + # not model_is_deployed(create_es_client(), 
ELSER_MODEL_ID), + # reason=f"{ELSER_MODEL_ID} model not deployed in ML Node, skipping test", + # ) + # def test_search_with_sparse_infer_instack( + # self, es_client: AsyncElasticsearch, index_name: str + # ) -> None: + # """test end to end with sparse retrieval strategy and inference in-stack""" + # store = VectorStore( + # agent_header="test", + # index_name=index_name, + # retrieval_strategy=Semantic(model_id=ELSER_MODEL_ID), + # es_client=es_client, + # ) + + # texts = ["foo", "bar", "baz"] + # store.add_texts(texts) + + # output = store.search("foo", k=1) + # assert [doc["_source"]["text_field"] for doc in output] == ["foo"] + + # def test_deployed_model_check_fails_semantic( + # self, es_client: AsyncElasticsearch, index_name: str + # ) -> None: + # """test that exceptions are raised if a specified model is not deployed""" + # with pytest.raises(NotFoundError): + # store = VectorStore( + # agent_header="test", + # index_name=index_name, + # retrieval_strategy=Semantic(model_id="non-existing model ID"), + # es_client=es_client, + # ) + # store.add_texts(["foo", "bar", "baz"]) + + # def test_search_bm25(self, es_client: AsyncElasticsearch, index_name: str) -> None: + # """Test end to end using the BM25 retrieval strategy.""" + # store = VectorStore( + # agent_header="test", + # index_name=index_name, + # retrieval_strategy=BM25(), + # es_client=es_client, + # ) + + # texts = ["foo", "bar", "baz"] + # store.add_texts(texts) + + # def assert_query(query_body: dict, query: Optional[str]) -> dict: + # assert query_body == { + # "query": { + # "bool": { + # "must": [{"match": {"text_field": {"query": "foo"}}}], + # "filter": [], + # } + # } + # } + # return query_body + + # output = store.search("foo", k=1, custom_query=assert_query) + # assert [doc["_source"]["text_field"] for doc in output] == ["foo"] + + # def test_search_bm25_with_filter( + # self, es_client: AsyncElasticsearch, index_name: str + # ) -> None: + # """Test end to using the BM25 retrieval strategy 
with metadata.""" + # store = VectorStore( + # agent_header="test", + # index_name=index_name, + # retrieval_strategy=BM25(), + # es_client=es_client, + # ) + + # texts = ["foo", "foo", "foo"] + # metadatas = [{"page": i} for i in range(len(texts))] + # store.add_texts(texts=texts, metadatas=metadatas) + + # def assert_query(query_body: dict, query: Optional[str]) -> dict: + # assert query_body == { + # "query": { + # "bool": { + # "must": [{"match": {"text_field": {"query": "foo"}}}], + # "filter": [{"term": {"metadata.page": 1}}], + # } + # } + # } + # return query_body + + # output = store.search( + # "foo", + # k=3, + # custom_query=assert_query, + # filter=[{"term": {"metadata.page": 1}}], + # ) + # assert [doc["_source"]["text_field"] for doc in output] == ["foo"] + # assert [doc["_source"]["metadata"]["page"] for doc in output] == [1] + + # def test_delete(self, es_client: AsyncElasticsearch, index_name: str) -> None: + # """Test delete methods from vector store.""" + # store = VectorStore( + # agent_header="test", + # index_name=index_name, + # retrieval_strategy=DenseVector(embedding_service=FakeEmbeddings()), + # es_client=es_client, + # ) + + # texts = ["foo", "bar", "baz", "gni"] + # metadatas = [{"page": i} for i in range(len(texts))] + # ids = store.add_texts(texts=texts, metadatas=metadatas) + + # output = store.search("foo", k=10) + # assert len(output) == 4 + + # store.delete(ids[1:3]) + # output = store.search("foo", k=10) + # assert len(output) == 2 + + # store.delete(["not-existing"]) + # output = store.search("foo", k=10) + # assert len(output) == 2 + + # store.delete([ids[0]]) + # output = store.search("foo", k=10) + # assert len(output) == 1 + + # store.delete([ids[3]]) + # output = store.search("gni", k=10) + # assert len(output) == 0 + + # @pytest.mark.asyncio + # async def test_indexing_exception_error( + # self, + # es_client: AsyncElasticsearch, + # index_name: str, + # caplog: pytest.LogCaptureFixture, + # ) -> None: + # """Test bulk 
exception logging is giving better hints.""" + # store = VectorStore( + # agent_header="test", + # index_name=index_name, + # retrieval_strategy=BM25(), + # es_client=es_client, + # ) + + # await store.es_client.indices.create( + # index=index_name, + # mappings={"properties": {}}, + # settings={"index": {"default_pipeline": "not-existing-pipeline"}}, + # ) + + # texts = ["foo"] + + # with pytest.raises(BulkIndexError): + # store.add_texts(texts) + + # error_reason = "pipeline with id [not-existing-pipeline] does not exist" + # log_message = f"First error reason: {error_reason}" + + # assert log_message in caplog.text + + # def test_user_agent( + # self, requests_saving_client: AsyncElasticsearch, index_name: str + # ) -> None: + # """Test to make sure the user-agent is set correctly.""" + # agent_header = "this is THE agent_header!" + # store = VectorStore( + # agent_header=agent_header, + # index_name=index_name, + # retrieval_strategy=BM25(), + # es_client=requests_saving_client, + # ) + + # assert store.es_client._headers["User-Agent"] == agent_header + + # texts = ["foo", "bob", "baz"] + # store.add_texts(texts) + + # transport = cast(RequestSavingTransport, store.es_client.transport) + + # for request in transport.requests: + # assert request["headers"]["User-Agent"] == agent_header + + # def test_bulk_args(self, requests_saving_client: Any, index_name: str) -> None: + # """Test to make sure the bulk arguments work as expected.""" + # store = VectorStore( + # agent_header="test", + # index_name=index_name, + # retrieval_strategy=BM25(), + # es_client=requests_saving_client, + # ) + + # texts = ["foo", "bob", "baz"] + # store.add_texts(texts, bulk_kwargs={"chunk_size": 1}) + + # # 1 for index exist, 1 for index create, 3 to index docs + # assert len(store.es_client.transport.requests) == 5 # type: ignore + + # def test_max_marginal_relevance_search( + # self, es_client: AsyncElasticsearch, index_name: str + # ) -> None: + # """Test max marginal relevance 
search.""" + # texts = ["foo", "bar", "baz"] + # vector_field = "vector_field" + # text_field = "text_field" + # embedding_service = ConsistentFakeEmbeddings() + # store = VectorStore( + # agent_header="test", + # index_name=index_name, + # retrieval_strategy=DenseVectorScriptScore( + # embedding_service=embedding_service + # ), + # vector_field=vector_field, + # text_field=text_field, + # es_client=es_client, + # ) + # store.add_texts(texts) + + # mmr_output = store.max_marginal_relevance_search( + # embedding_service, + # texts[0], + # vector_field=vector_field, + # k=3, + # num_candidates=3, + # ) + # sim_output = store.search(texts[0], k=3) + # assert mmr_output == sim_output + + # mmr_output = store.max_marginal_relevance_search( + # embedding_service, + # texts[0], + # vector_field=vector_field, + # k=2, + # num_candidates=3, + # ) + # assert len(mmr_output) == 2 + # assert mmr_output[0]["_source"][text_field] == texts[0] + # assert mmr_output[1]["_source"][text_field] == texts[1] + + # mmr_output = store.max_marginal_relevance_search( + # embedding_service, + # texts[0], + # vector_field=vector_field, + # k=2, + # num_candidates=3, + # lambda_mult=0.1, # more diversity + # ) + # assert len(mmr_output) == 2 + # assert mmr_output[0]["_source"][text_field] == texts[0] + # assert mmr_output[1]["_source"][text_field] == texts[2] + + # # if fetch_k < k, then the output will be less than k + # mmr_output = store.max_marginal_relevance_search( + # embedding_service, + # texts[0], + # vector_field=vector_field, + # k=3, + # num_candidates=2, + # ) + # assert len(mmr_output) == 2 diff --git a/test_elasticsearch/test_store_integration/test_embedding_service.py b/test_elasticsearch/test_store_integration/test_embedding_service.py deleted file mode 100644 index 61119396c..000000000 --- a/test_elasticsearch/test_store_integration/test_embedding_service.py +++ /dev/null @@ -1,47 +0,0 @@ -import os - -import pytest -from elasticsearch import AsyncElasticsearch - -from 
elasticsearch.store.embedding_service import ElasticsearchEmbeddings - -from ._test_utilities import model_is_deployed - -# deployed with -# https://www.elastic.co/guide/en/machine-learning/current/ml-nlp-text-emb-vector-search-example.html -MODEL_ID = os.getenv("MODEL_ID", "sentence-transformers__msmarco-minilm-l-12-v3") -NUM_DIMENSIONS = int(os.getenv("NUM_DIMENTIONS", "384")) - -ES_URL = os.environ.get("ES_URL", "http://localhost:9200") -ES_CLIENT = AsyncElasticsearch(hosts=[ES_URL]) - - -@pytest.mark.skipif( - not model_is_deployed(ES_CLIENT, MODEL_ID), - reason=f"{MODEL_ID} model is not deployed in ML Node, skipping test", -) -def test_elasticsearch_embedding_documents() -> None: - """Test Elasticsearch embedding documents.""" - documents = ["foo bar", "bar foo", "foo"] - embedding = ElasticsearchEmbeddings( - agent_header="test", model_id=MODEL_ID, es_url=ES_URL - ) - output = embedding.embed_documents(documents) - assert len(output) == 3 - assert len(output[0]) == NUM_DIMENSIONS - assert len(output[1]) == NUM_DIMENSIONS - assert len(output[2]) == NUM_DIMENSIONS - - -@pytest.mark.skipif( - not model_is_deployed(ES_CLIENT, MODEL_ID), - reason=f"{MODEL_ID} model is not deployed in ML Node, skipping test", -) -def test_elasticsearch_embedding_query() -> None: - """Test Elasticsearch embedding query.""" - document = "foo bar" - embedding = ElasticsearchEmbeddings( - agent_header="test", model_id=MODEL_ID, es_url=ES_URL - ) - output = embedding.embed_query(document) - assert len(output) == NUM_DIMENSIONS From 9be44fdc18a226c651e03301518859ad7ca9965f Mon Sep 17 00:00:00 2001 From: Max Jakob Date: Thu, 18 Apr 2024 15:48:05 +0200 Subject: [PATCH 04/36] generate _sync files --- elasticsearch/vectorstore/_async/_utils.py | 8 +- .../vectorstore/_async/strategies.py | 8 +- elasticsearch/vectorstore/_sync/__init__.py | 5 + elasticsearch/vectorstore/_sync/_utils.py | 34 + .../vectorstore/_sync/embedding_service.py | 80 ++ elasticsearch/vectorstore/_sync/strategies.py | 560 
+++++++++++ .../vectorstore/_sync/vectorestore.py | 323 ++++++ .../test_vectorstore/_async/_test_utils.py | 5 +- .../_async/test_embedding_service.py | 10 +- .../_async/test_vectorestore.py | 12 +- .../test_vectorstore/_sync/__init__.py | 0 .../test_vectorstore/_sync/_test_utils.py | 138 +++ .../_sync/test_embedding_service.py | 62 ++ .../_sync/test_vectorestore.py | 928 ++++++++++++++++++ utils/run-unasync.py | 83 +- 15 files changed, 2211 insertions(+), 45 deletions(-) create mode 100644 elasticsearch/vectorstore/_sync/__init__.py create mode 100644 elasticsearch/vectorstore/_sync/_utils.py create mode 100644 elasticsearch/vectorstore/_sync/embedding_service.py create mode 100644 elasticsearch/vectorstore/_sync/strategies.py create mode 100644 elasticsearch/vectorstore/_sync/vectorestore.py create mode 100644 test_elasticsearch/test_server/test_vectorstore/_sync/__init__.py create mode 100644 test_elasticsearch/test_server/test_vectorstore/_sync/_test_utils.py create mode 100644 test_elasticsearch/test_server/test_vectorstore/_sync/test_embedding_service.py create mode 100644 test_elasticsearch/test_server/test_vectorstore/_sync/test_vectorestore.py diff --git a/elasticsearch/vectorstore/_async/_utils.py b/elasticsearch/vectorstore/_async/_utils.py index 5b5aacdbc..dac58715d 100644 --- a/elasticsearch/vectorstore/_async/_utils.py +++ b/elasticsearch/vectorstore/_async/_utils.py @@ -6,9 +6,7 @@ ) -async def async_model_must_be_deployed( - client: AsyncElasticsearch, model_id: str -) -> None: +async def model_must_be_deployed(client: AsyncElasticsearch, model_id: str) -> None: try: dummy = {"x": "y"} await client.ml.infer_trained_model(model_id=model_id, docs=[dummy]) @@ -28,9 +26,9 @@ async def async_model_must_be_deployed( return None -async def async_model_is_deployed(es_client: AsyncElasticsearch, model_id: str) -> bool: +async def model_is_deployed(es_client: AsyncElasticsearch, model_id: str) -> bool: try: - await async_model_must_be_deployed(es_client, 
model_id) + await model_must_be_deployed(es_client, model_id) return True except NotFoundError: return False diff --git a/elasticsearch/vectorstore/_async/strategies.py b/elasticsearch/vectorstore/_async/strategies.py index 22d9ab66b..b67b6cffc 100644 --- a/elasticsearch/vectorstore/_async/strategies.py +++ b/elasticsearch/vectorstore/_async/strategies.py @@ -4,7 +4,7 @@ from elasticsearch import AsyncElasticsearch -from elasticsearch.vectorstore._async._utils import async_model_must_be_deployed +from elasticsearch.vectorstore._async._utils import model_must_be_deployed from elasticsearch.vectorstore._async.embedding_service import AsyncEmbeddingService @@ -121,7 +121,7 @@ async def create_index( metadata_mapping: Optional[dict[str, str]], ) -> None: if self.model_id: - await async_model_must_be_deployed(client, self.model_id) + await model_must_be_deployed(client, self.model_id) mappings: dict[str, Any] = { "properties": { @@ -195,7 +195,7 @@ async def create_index( pipeline_name = f"{self.model_id}_sparse_embedding" if self.model_id: - await async_model_must_be_deployed(client, self.model_id) + await model_must_be_deployed(client, self.model_id) # Create a pipeline for the model await client.ingest.put_pipeline( @@ -317,7 +317,7 @@ async def create_index( ) if self.model_id: - await async_model_must_be_deployed(client, self.model_id) + await model_must_be_deployed(client, self.model_id) if self.distance is DistanceMetric.COSINE: similarityAlgo = "cosine" diff --git a/elasticsearch/vectorstore/_sync/__init__.py b/elasticsearch/vectorstore/_sync/__init__.py new file mode 100644 index 000000000..3079492b6 --- /dev/null +++ b/elasticsearch/vectorstore/_sync/__init__.py @@ -0,0 +1,5 @@ +from elasticsearch.vectorstore._sync.vectorestore import VectorStore + +__all__ = [ + "VectorStore", +] diff --git a/elasticsearch/vectorstore/_sync/_utils.py b/elasticsearch/vectorstore/_sync/_utils.py new file mode 100644 index 000000000..e400f9a09 --- /dev/null +++ 
b/elasticsearch/vectorstore/_sync/_utils.py @@ -0,0 +1,34 @@ +from elasticsearch import ( + Elasticsearch, + BadRequestError, + ConflictError, + NotFoundError, +) + + +def model_must_be_deployed(client: Elasticsearch, model_id: str) -> None: + try: + dummy = {"x": "y"} + client.ml.infer_trained_model(model_id=model_id, docs=[dummy]) + except NotFoundError as err: + raise err + except ConflictError as err: + raise NotFoundError( + f"model '{model_id}' not found, please deploy it first", + meta=err.meta, + body=err.body, + ) from err + except BadRequestError: + # This error is expected because we do not know the expected document + # shape and just use a dummy doc above. + pass + + return None + + +def model_is_deployed(es_client: Elasticsearch, model_id: str) -> bool: + try: + model_must_be_deployed(es_client, model_id) + return True + except NotFoundError: + return False diff --git a/elasticsearch/vectorstore/_sync/embedding_service.py b/elasticsearch/vectorstore/_sync/embedding_service.py new file mode 100644 index 000000000..1c8d39e3a --- /dev/null +++ b/elasticsearch/vectorstore/_sync/embedding_service.py @@ -0,0 +1,80 @@ +from abc import ABC, abstractmethod +from typing import List, Optional + +from elasticsearch import Elasticsearch + + +class EmbeddingService(ABC): + @abstractmethod + def embed_documents(self, texts: List[str]) -> List[List[float]]: + """Generate embeddings for a list of documents. + + Args: + texts: A list of document strings to generate embeddings for. + + Returns: + A list of embeddings, one for each document in the input. + """ + + @abstractmethod + def embed_query(self, query: str) -> List[float]: + """Generate an embedding for a single query text. + + Args: + text: The query text to generate an embedding for. + + Returns: + The embedding for the input query text. + """ + + +class ElasticsearchEmbeddings(EmbeddingService): + """Elasticsearch as a service for embedding model inference. 
+ + You need to have an embedding model downloaded and deployed in Elasticsearch: + - https://www.elastic.co/guide/en/elasticsearch/reference/current/infer-trained-model.html + - https://www.elastic.co/guide/en/machine-learning/current/ml-nlp-deploy-models.html + """ # noqa: E501 + + def __init__( + self, + es_client: Elasticsearch, + user_agent: str, + model_id: str, + input_field: str = "text_field", + num_dimensions: Optional[int] = None, + ): + """ + Args: + user_agent: user agent header specific to the 3rd party integration. + Used for usage tracking in Elastic Cloud. + model_id: The model_id of the model deployed in the Elasticsearch cluster. + input_field: The name of the key for the input text field in the + document. Defaults to 'text_field'. + num_dimensions: The number of embedding dimensions. If None, then dimensions + will be inferred from an example inference call. + es_client: Elasticsearch client connection. Alternatively specify the + Elasticsearch connection with the other es_* parameters. + """ + # Add integration-specific usage header for tracking usage in Elastic Cloud. + # client.options preserves existing (non-user-agent) headers.
+ es_client = es_client.options(headers={"User-Agent": user_agent}) + + self.client = es_client.ml + self.model_id = model_id + self.input_field = input_field + self._num_dimensions = num_dimensions + + def embed_documents(self, texts: List[str]) -> List[List[float]]: + result = self._embedding_func(texts) + return result + + def embed_query(self, text: str) -> List[float]: + result = self._embedding_func([text]) + return result[0] + + def _embedding_func(self, texts: List[str]) -> List[List[float]]: + response = self.client.infer_trained_model( + model_id=self.model_id, docs=[{self.input_field: text} for text in texts] + ) + return [doc["predicted_value"] for doc in response["inference_results"]] diff --git a/elasticsearch/vectorstore/_sync/strategies.py b/elasticsearch/vectorstore/_sync/strategies.py new file mode 100644 index 000000000..3dd182c80 --- /dev/null +++ b/elasticsearch/vectorstore/_sync/strategies.py @@ -0,0 +1,560 @@ +from abc import ABC, abstractmethod +from enum import Enum +from typing import Any, Dict, List, Literal, Optional, Union, cast + +from elasticsearch import Elasticsearch + +from elasticsearch.vectorstore._sync._utils import model_must_be_deployed +from elasticsearch.vectorstore._sync.embedding_service import EmbeddingService + + +class DistanceMetric(str, Enum): + """Enumerator of all Elasticsearch dense vector distance metrics.""" + + COSINE = "COSINE" + DOT_PRODUCT = "DOT_PRODUCT" + EUCLIDEAN_DISTANCE = "EUCLIDEAN_DISTANCE" + MAX_INNER_PRODUCT = "MAX_INNER_PRODUCT" + + +class RetrievalStrategy(ABC): + @abstractmethod + def es_query( + self, + query: Optional[str], + k: int, + num_candidates: int, + filter: List[dict] = [], + query_vector: Optional[List[float]] = None, + ) -> Dict: + """ + Returns the Elasticsearch query body for the given parameters. + The store will execute the query. + + Args: + query: The text query. Can be None if query_vector is given. + k: The total number of results to retrieve. 
+ num_candidates: The number of results to fetch initially in knn search. + filter: List of filter clauses to apply to the query. + query_vector: The query vector. Can be None if a query string is given. + + Returns: + Dict: The Elasticsearch query body. + """ + + @abstractmethod + def create_index( + self, + client: Elasticsearch, + index_name: str, + metadata_mapping: Optional[dict[str, str]], + ) -> None: + """ + Create the required index and do necessary preliminary work, like + creating inference pipelines or checking if a required model was deployed. + + Args: + client: Elasticsearch client connection. + index_name: The name of the Elasticsearch index to create. + metadata_mapping: Flat dictionary with field and field type pairs that + describe the schema of the metadata. + """ + + def embed_for_indexing(self, text: str) -> Dict[str, Any]: + """ + If this strategy creates vector embeddings in Python (not in Elasticsearch), + this method is used to apply the inference. + The output is a dictionary with the vector field and the vector embedding. + It is merged in the VectorStore with the rest of the document (text data, + metadata) before indexing. + + Args: + text: Text input that can be used as input for inference. + + Returns: + Dict: field and value pairs that extend the document to be indexed.
+ """ + return {} + + +# TODO test when repsective image is released +class Semantic(RetrievalStrategy): + """Dense or sparse retrieval with in-stack inference using semantic_text fields.""" + + def __init__( + self, + model_id: str, + text_field: str = "text_field", + inference_field: str = "text_semantic", + ): + self.model_id = model_id + self.text_field = text_field + self.inference_field = inference_field + + def es_query( + self, + query: Optional[str], + k: int, + num_candidates: int, + filter: List[dict] = [], + query_vector: Optional[List[float]] = None, + ) -> Dict: + if query_vector: + raise ValueError( + "Cannot do sparse retrieval with a query_vector. " + "Inference is currently always applied in-stack." + ) + + return { + "query": { + "semantic": { + self.text_field: query, + }, + }, + "filter": filter, + } + + def create_index( + self, + client: Elasticsearch, + index_name: str, + metadata_mapping: Optional[dict[str, str]], + ) -> None: + if self.model_id: + model_must_be_deployed(client, self.model_id) + + mappings: dict[str, Any] = { + "properties": { + self.inference_field: { + "type": "semantic_text", + "model_id": self.model_id, + } + } + } + if metadata_mapping: + mappings["properties"]["metadata"] = {"properties": metadata_mapping} + + client.indices.create(index=index_name, mappings=mappings) + + +class SparseVector(RetrievalStrategy): + """Sparse retrieval strategy using the `text_expansion` processor.""" + + def __init__( + self, + model_id: str = ".elser_model_2", + text_field: str = "text_field", + vector_field: str = "vector_field", + ): + self.model_id = model_id + self.text_field = text_field + self.vector_field = vector_field + self._tokens_field = "tokens" + + def es_query( + self, + query: Optional[str], + k: int, + num_candidates: int, + filter: List[dict] = [], + query_vector: Optional[List[float]] = None, + ) -> Dict: + if query_vector: + raise ValueError( + "Cannot do sparse retrieval with a query_vector. 
" + "Inference is currently always applied in Elasticsearch." + ) + if query is None: + raise ValueError("please specify a query string") + + return { + "query": { + "bool": { + "must": [ + { + "text_expansion": { + f"{self.vector_field}.{self._tokens_field}": { + "model_id": self.model_id, + "model_text": query, + } + } + } + ], + "filter": filter, + } + }, + "size": k, + } + + def create_index( + self, + client: Elasticsearch, + index_name: str, + metadata_mapping: Optional[dict[str, str]], + ) -> None: + pipeline_name = f"{self.model_id}_sparse_embedding" + + if self.model_id: + model_must_be_deployed(client, self.model_id) + + # Create a pipeline for the model + client.ingest.put_pipeline( + id=pipeline_name, + description="Embedding pipeline for Python VectorStore", + processors=[ + { + "inference": { + "model_id": self.model_id, + "target_field": self.vector_field, + "field_map": {self.text_field: "text_field"}, + "inference_config": { + "text_expansion": {"results_field": self._tokens_field} + }, + } + } + ], + ) + + mappings = { + "properties": { + self.vector_field: { + "properties": {self._tokens_field: {"type": "rank_features"}} + } + } + } + if metadata_mapping: + mappings["properties"]["metadata"] = {"properties": metadata_mapping} + settings = {"default_pipeline": pipeline_name} + + client.indices.create( + index=index_name, mappings=mappings, settings=settings + ) + + return None + + +class DenseVector(RetrievalStrategy): + """K-nearest-neighbors retrieval.""" + + def __init__( + self, + knn_type: Literal["hnsw", "int8_hnsw", "flat", "int8_flat"] = "hnsw", + vector_field: str = "vector_field", + distance: DistanceMetric = DistanceMetric.COSINE, + embedding_service: Optional[EmbeddingService] = None, + model_id: Optional[str] = None, + num_dimensions: Optional[int] = None, + hybrid: bool = False, + rrf: Union[bool, dict] = True, + text_field: Optional[str] = "text_field", + ): + if embedding_service and model_id: + raise ValueError("either specify 
embedding_service or model_id, not both") + if model_id and not num_dimensions: + raise ValueError( + "if model_id is specified, num_dimensions must also be specified" + ) + if hybrid and not text_field: + raise ValueError( + "to enable hybrid you have to specify a text_field (for BM25 matching)" + ) + + self.knn_type = knn_type + self.vector_field = vector_field + self.distance = distance + self.embedding_service = embedding_service + self.model_id = model_id + self.num_dimensions = num_dimensions + self.hybrid = hybrid + self.rrf = rrf + self.text_field = text_field + + def es_query( + self, + query: Optional[str], + k: int, + num_candidates: int, + filter: List[dict] = [], + query_vector: Optional[List[float]] = None, + ) -> Dict: + knn = { + "filter": filter, + "field": self.vector_field, + "k": k, + "num_candidates": num_candidates, + } + + if query_vector: + knn["query_vector"] = query_vector + elif self.embedding_service: + knn["query_vector"] = self.embedding_service.embed_query( + cast(str, query) + ) + else: + # Inference in Elasticsearch. When initializing we make sure to always have + # a model_id if don't have an embedding_service. 
+ knn["query_vector_builder"] = { + "text_embedding": { + "model_id": self.model_id, + "model_text": query, + } + } + + if self.hybrid: + return self._hybrid(query=cast(str, query), knn=knn, filter=filter) + + return {"knn": knn} + + def create_index( + self, + client: Elasticsearch, + index_name: str, + metadata_mapping: Optional[dict[str, str]], + ) -> None: + if self.embedding_service and not self.num_dimensions: + self.num_dimensions = len( + self.embedding_service.embed_query("get number of dimensions") + ) + + if self.model_id: + model_must_be_deployed(client, self.model_id) + + if self.distance is DistanceMetric.COSINE: + similarityAlgo = "cosine" + elif self.distance is DistanceMetric.EUCLIDEAN_DISTANCE: + similarityAlgo = "l2_norm" + elif self.distance is DistanceMetric.DOT_PRODUCT: + similarityAlgo = "dot_product" + elif self.distance is DistanceMetric.MAX_INNER_PRODUCT: + similarityAlgo = "max_inner_product" + else: + raise ValueError(f"Similarity {self.distance} not supported.") + + mappings: Dict[str, Any] = { + "properties": { + self.vector_field: { + "type": "dense_vector", + "dims": self.num_dimensions, + "index": True, + "similarity": similarityAlgo, + }, + } + } + if metadata_mapping: + mappings["properties"]["metadata"] = {"properties": metadata_mapping} + + r = client.indices.create(index=index_name, mappings=mappings) + print(r) + + def embed_for_indexing(self, text: str) -> Dict[str, Any]: + if self.embedding_service: + vector = self.embedding_service.embed_query(text) + return {self.vector_field: vector} + return {} + + def _hybrid(self, query: str, knn: dict, filter: list): + # Add a query to the knn query. 
+ # RRF is used to even the score from the knn query and text query + # RRF has two optional parameters: {'rank_constant':int, 'window_size':int} + # https://www.elastic.co/guide/en/elasticsearch/reference/current/rrf.html + query_body = { + "knn": knn, + "query": { + "bool": { + "must": [ + { + "match": { + self.text_field: { + "query": query, + } + } + } + ], + "filter": filter, + } + }, + } + + if isinstance(self.rrf, dict): + query_body["rank"] = {"rrf": self.rrf} + elif isinstance(self.rrf, bool) and self.rrf is True: + query_body["rank"] = {"rrf": {}} + + return query_body + + +class DenseVectorScriptScore(RetrievalStrategy): + """Exact nearest neighbors retrieval using the `script_score` query.""" + + def __init__( + self, + embedding_service: EmbeddingService, + vector_field: str = "vector_field", + distance: DistanceMetric = DistanceMetric.COSINE, + num_dimensions: Optional[int] = None, + ) -> None: + self.vector_field = vector_field + self.distance = distance + self.embedding_service = embedding_service + self.num_dimensions = num_dimensions + + def es_query( + self, + query: Optional[str], + k: int, + num_candidates: int, + filter: List[dict] = [], + query_vector: Optional[List[float]] = None, + ) -> Dict: + if self.distance is DistanceMetric.COSINE: + similarityAlgo = ( + f"cosineSimilarity(params.query_vector, '{self.vector_field}') + 1.0" + ) + elif self.distance is DistanceMetric.EUCLIDEAN_DISTANCE: + similarityAlgo = ( + f"1 / (1 + l2norm(params.query_vector, '{self.vector_field}'))" + ) + elif self.distance is DistanceMetric.DOT_PRODUCT: + similarityAlgo = f""" + double value = dotProduct(params.query_vector, '{self.vector_field}'); + return sigmoid(1, Math.E, -value); + """ + elif self.distance is DistanceMetric.MAX_INNER_PRODUCT: + similarityAlgo = f""" + double value = dotProduct(params.query_vector, '{self.vector_field}'); + if (dotProduct < 0) {{ + return 1 / (1 + -1 * dotProduct); + }} + return dotProduct + 1; + """ + else: + raise 
ValueError(f"Similarity {self.distance} not supported.") + + queryBool: Dict = {"match_all": {}} + if filter: + queryBool = {"bool": {"filter": filter}} + + if not query_vector: + if not self.embedding_service: + raise ValueError( + "if not embedding_service is given, you need to " + "procive a query_vector" + ) + if not query: + raise ValueError("either specify a query string or a query_vector") + query_vector = self.embedding_service.embed_query(query) + + return { + "query": { + "script_score": { + "query": queryBool, + "script": { + "source": similarityAlgo, + "params": {"query_vector": query_vector}, + }, + }, + } + } + + def create_index( + self, + client: Elasticsearch, + index_name: str, + metadata_mapping: Optional[dict[str, str]], + ) -> None: + if not self.num_dimensions: + self.num_dimensions = len( + self.embedding_service.embed_query("get number of dimensions") + ) + + mappings = { + "properties": { + self.vector_field: { + "type": "dense_vector", + "dims": self.num_dimensions, + "index": False, + } + } + } + if metadata_mapping: + mappings["properties"]["metadata"] = {"properties": metadata_mapping} + + client.indices.create(index=index_name, mappings=mappings) + + return None + + def embed_for_indexing(self, text: str) -> Dict[str, Any]: + return {self.vector_field: self.embedding_service.embed_query(text)} + + +class BM25(RetrievalStrategy): + def __init__( + self, + text_field: str = "text_field", + k1: Optional[float] = None, + b: Optional[float] = None, + ): + self.text_field = text_field + self.k1 = k1 + self.b = b + + def es_query( + self, + query: Optional[str], + k: int, + num_candidates: int, + filter: List[dict] = [], + query_vector: Optional[List[float]] = None, + ) -> Dict: + return { + "query": { + "bool": { + "must": [ + { + "match": { + self.text_field: { + "query": query, + } + }, + }, + ], + "filter": filter, + }, + }, + } + + def create_index( + self, + client: Elasticsearch, + index_name: str, + metadata_mapping: 
Optional[dict[str, str]], + ) -> None: + similarity_name = "custom_bm25" + + mappings: Dict = { + "properties": { + self.text_field: { + "type": "text", + "similarity": similarity_name, + }, + }, + } + if metadata_mapping: + mappings["properties"]["metadata"] = {"properties": metadata_mapping} + + bm25: Dict = { + "type": "BM25", + } + if self.k1 is not None: + bm25["k1"] = self.k1 + if self.b is not None: + bm25["b"] = self.b + settings = { + "similarity": { + similarity_name: bm25, + } + } + + client.indices.create( + index=index_name, mappings=mappings, settings=settings + ) diff --git a/elasticsearch/vectorstore/_sync/vectorestore.py b/elasticsearch/vectorstore/_sync/vectorestore.py new file mode 100644 index 000000000..2d2338ee8 --- /dev/null +++ b/elasticsearch/vectorstore/_sync/vectorestore.py @@ -0,0 +1,323 @@ +import logging +import uuid +from typing import Any, Callable, Dict, List, Optional + +from elasticsearch import Elasticsearch +from elasticsearch.helpers import BulkIndexError, bulk + +from elasticsearch.vectorstore._utils import ( + maximal_marginal_relevance, +) +from elasticsearch.vectorstore._sync.embedding_service import EmbeddingService +from elasticsearch.vectorstore._sync.strategies import RetrievalStrategy + +logger = logging.getLogger(__name__) + + +class VectorStore: + """VectorStore is a higher-level abstraction of indexing and search. + Users can pick from available retrieval strategies. + + Documents are flat text documents. Depending on the strategy, vector embeddings are + - created by the user beforehand + - created by this class in Python + - created in-stack by inference pipelines. 
+ """ + + def __init__( + self, + es_client: Elasticsearch, + user_agent: str, + index_name: str, + retrieval_strategy: RetrievalStrategy, + text_field: str = "text_field", + vector_field: str = "vector_field", + metadata_mapping: Optional[dict[str, str]] = None, + ) -> None: + """ + Args: + user_agent: user agent header specific to the 3rd party integration. + Used for usage tracking in Elastic Cloud. + index_name: The name of the index to query. + retrieval_strategy: how to index and search the data. See the strategies + module for available strategies. + text_field: Name of the field with the textual data. + vector_field: For strategies that perform embedding inference in Python, + the embedding vector goes in this field. + es_client: Elasticsearch client connection. Alternatively specify the + Elasticsearch connection with the other es_* parameters. + """ + + # Add integration-specific usage header for tracking usage in Elastic Cloud. + # client.options preserves existing (non-user-agent) headers. + es_client = es_client.options(headers={"User-Agent": user_agent}) + + if hasattr(retrieval_strategy, "text_field"): + retrieval_strategy.text_field = text_field + if hasattr(retrieval_strategy, "vector_field"): + retrieval_strategy.vector_field = vector_field + + self.es_client = es_client + self.index_name = index_name + self.retrieval_strategy = retrieval_strategy + self.text_field = text_field + self.vector_field = vector_field + self.metadata_mapping = metadata_mapping + + def close(self): + return self.es_client.close() + + def add_texts( + self, + texts: List[str], + metadatas: Optional[List[Dict[str, Any]]] = None, + vectors: Optional[List[List[float]]] = None, + ids: Optional[List[str]] = None, + refresh_indices: bool = True, + create_index_if_not_exists: bool = True, + bulk_kwargs: Optional[Dict[str, Any]] = None, + ) -> List[str]: + """Add documents to the Elasticsearch index. + + Args: + texts: List of text documents.
+ metadata: Optional list of document metadata. Must be of same length as + texts. + vectors: Optional list of embedding vectors. Must be of same length as + texts. + ids: Optional list of ID strings. Must be of same length as texts. + refresh_indices: Whether to refresh the index after deleting documents. + Defaults to True. + create_index_if_not_exists: Whether to create the index if it does not + exist. Defaults to True. + bulk_kwargs: Arguments to pass to the bulk function when indexing + (for example chunk_size). + + Returns: + List of IDs of the created documents, either echoing the provided one + or returning newly created ones. + """ + bulk_kwargs = bulk_kwargs or {} + ids = ids or [str(uuid.uuid4()) for _ in texts] + requests = [] + + if create_index_if_not_exists: + self._create_index_if_not_exists() + + for i, text in enumerate(texts): + metadata = metadatas[i] if metadatas else {} + + request: Dict[str, Any] = { + "_op_type": "index", + "_index": self.index_name, + self.text_field: text, + "metadata": metadata, + "_id": ids[i], + } + + if vectors: + request[self.vector_field] = vectors[i] + + request.update(self.retrieval_strategy.embed_for_indexing(text)) + requests.append(request) + + if len(requests) > 0: + try: + success, failed = bulk( + self.es_client, + requests, + stats_only=True, + refresh=refresh_indices, + **bulk_kwargs, + ) + logger.debug(f"added texts {ids} to index") + return ids + except BulkIndexError as e: + logger.error(f"Error adding texts: {e}") + firstError = e.errors[0].get("index", {}).get("error", {}) + logger.error(f"First error reason: {firstError.get('reason')}") + raise e + + else: + logger.debug("No texts to add to index") + return [] + + def delete( + self, + ids: Optional[List[str]] = None, + query: Optional[Dict[str, Any]] = None, + refresh_indices: bool = True, + **delete_kwargs, + ) -> bool: + """Delete documents from the Elasticsearch index. + + Args: + ids: List of IDs of documents to delete. 
+ refresh_indices: Whether to refresh the index after deleting documents. + Defaults to True. + """ + if ids is not None and query is not None: + raise ValueError("one of ids or query must be specified") + elif ids is None and query is None: + raise ValueError("either specify ids or query") + + try: + if ids: + body = [ + {"_op_type": "delete", "_index": self.index_name, "_id": _id} + for _id in ids + ] + bulk( + self.es_client, + body, + refresh=refresh_indices, + ignore_status=404, + **delete_kwargs, + ) + logger.debug(f"Deleted {len(body)} texts from index") + + else: + self.es_client.delete_by_query( + index=self.index_name, + query=query, + refresh=refresh_indices, + **delete_kwargs, + ) + + except BulkIndexError as e: + logger.error(f"Error deleting texts: {e}") + firstError = e.errors[0].get("index", {}).get("error", {}) + logger.error(f"First error reason: {firstError.get('reason')}") + raise e + + return True + + def search( + self, + query: Optional[str], + query_vector: Optional[List[float]] = None, + k: int = 4, + num_candidates: int = 50, + fields: Optional[List[str]] = None, + filter: Optional[List[dict]] = None, + custom_query: Optional[Callable[[Dict, Optional[str]], Dict]] = None, + ) -> List[Dict[str, Any]]: + """ + Args: + query: Input query string. + query_vector: Input embedding vector. If given, input query string is + ignored. + k: Number of returned results. + num_candidates: Number of candidates to fetch from data nodes in knn. + fields: List of field names to return. + filter: Elasticsearch filters to apply. + custom_query: Function to modify the Elasticsearch query body before it is + sent to Elasticsearch. + + Returns: + List of document hits. Includes _index, _id, _score and _source. 
+ """ + if fields is None: + fields = [] + if "metadata" not in fields: + fields.append("metadata") + if self.text_field not in fields: + fields.append(self.text_field) + + query_body = self.retrieval_strategy.es_query( + query=query, + k=k, + num_candidates=num_candidates, + filter=filter or [], + query_vector=query_vector, + ) + + if custom_query is not None: + query_body = custom_query(query_body, query) + logger.debug(f"Calling custom_query, Query body now: {query_body}") + + response = self.es_client.search( + index=self.index_name, + **query_body, + size=k, + source=True, + source_includes=fields, + ) + + return response["hits"]["hits"] + + def _create_index_if_not_exists(self) -> None: + exists = self.es_client.indices.exists(index=self.index_name) + if exists.meta.status == 200: + logger.debug(f"Index {self.index_name} already exists. Skipping creation.") + else: + self.retrieval_strategy.create_index( + client=self.es_client, + index_name=self.index_name, + metadata_mapping=self.metadata_mapping, + ) + + def max_marginal_relevance_search( + self, + embedding_service: EmbeddingService, + query: str, + vector_field: str, + k: int = 4, + num_candidates: int = 20, + lambda_mult: float = 0.5, + fields: Optional[List[str]] = None, + custom_query: Optional[Callable[[Dict, Optional[str]], Dict]] = None, + ) -> List[Dict]: + """Return docs selected using the maximal marginal relevance. + + Maximal marginal relevance optimizes for similarity to query AND diversity + among selected documents. + + Args: + query (str): Text to look up documents similar to. + k (int): Number of Documents to return. Defaults to 4. + fetch_k (int): Number of Documents to fetch to pass to MMR algorithm. + lambda_mult (float): Number between 0 and 1 that determines the degree + of diversity among the results with 0 corresponding + to maximum diversity and 1 to minimum diversity. + Defaults to 0.5. + fields: Other fields to get from elasticsearch source. 
These fields + will be added to the document metadata. + + Returns: + List[Document]: A list of Documents selected by maximal marginal relevance. + """ + remove_vector_query_field_from_metadata = True + if fields is None: + fields = [vector_field] + elif vector_field not in fields: + fields.append(vector_field) + else: + remove_vector_query_field_from_metadata = False + + # Embed the query + query_embedding = embedding_service.embed_query(query) + + # Fetch the initial documents + got_hits = self.search( + query=None, + query_vector=query_embedding, + k=num_candidates, + fields=fields, + custom_query=custom_query, + ) + + # Get the embeddings for the fetched documents + got_embeddings = [hit["_source"][vector_field] for hit in got_hits] + + # Select documents using maximal marginal relevance + selected_indices = maximal_marginal_relevance( + query_embedding, got_embeddings, lambda_mult=lambda_mult, k=k + ) + selected_hits = [got_hits[i] for i in selected_indices] + + if remove_vector_query_field_from_metadata: + for hit in selected_hits: + del hit["_source"][vector_field] + + return selected_hits diff --git a/test_elasticsearch/test_server/test_vectorstore/_async/_test_utils.py b/test_elasticsearch/test_server/test_vectorstore/_async/_test_utils.py index 9a69c8b42..eb6245b20 100644 --- a/test_elasticsearch/test_server/test_vectorstore/_async/_test_utils.py +++ b/test_elasticsearch/test_server/test_vectorstore/_async/_test_utils.py @@ -1,6 +1,5 @@ import os -import pytest_asyncio -from typing import Any, Dict, List, Optional, AsyncGenerator +from typing import Any, Dict, List, Optional, AsyncIterator from elastic_transport import AsyncTransport from elasticsearch import AsyncElasticsearch @@ -94,7 +93,7 @@ def create_requests_saving_client() -> AsyncElasticsearch: return create_es_client(es_kwargs={"transport_class": AsyncRequestSavingTransport}) -async def es_client_fixture() -> AsyncGenerator[AsyncElasticsearch, None]: +async def es_client_fixture() -> 
AsyncIterator[AsyncElasticsearch]: params = read_env() client = create_es_client(params) diff --git a/test_elasticsearch/test_server/test_vectorstore/_async/test_embedding_service.py b/test_elasticsearch/test_server/test_vectorstore/_async/test_embedding_service.py index b349b724c..ef41df681 100644 --- a/test_elasticsearch/test_server/test_vectorstore/_async/test_embedding_service.py +++ b/test_elasticsearch/test_server/test_vectorstore/_async/test_embedding_service.py @@ -5,9 +5,9 @@ import pytest_asyncio from elasticsearch import AsyncElasticsearch -from typing import AsyncGenerator +from typing import AsyncIterator -from elasticsearch.vectorstore._async._utils import async_model_is_deployed +from elasticsearch.vectorstore._async._utils import model_is_deployed from ._test_utils import ( es_client_fixture, @@ -24,7 +24,7 @@ @pytest_asyncio.fixture(autouse=True) -async def es_client() -> AsyncGenerator[AsyncElasticsearch, None]: +async def es_client() -> AsyncIterator[AsyncElasticsearch]: async for x in es_client_fixture(): yield x @@ -33,7 +33,7 @@ async def es_client() -> AsyncGenerator[AsyncElasticsearch, None]: async def test_elasticsearch_embedding_documents(es_client: AsyncElasticsearch) -> None: """Test Elasticsearch embedding documents.""" - if not await async_model_is_deployed(es_client, MODEL_ID): + if not await model_is_deployed(es_client, MODEL_ID): pytest.skip(f"{MODEL_ID} model is not deployed in ML Node, skipping test") documents = ["foo bar", "bar foo", "foo"] @@ -51,7 +51,7 @@ async def test_elasticsearch_embedding_documents(es_client: AsyncElasticsearch) async def test_elasticsearch_embedding_query(es_client: AsyncElasticsearch) -> None: """Test Elasticsearch embedding query.""" - if not await async_model_is_deployed(es_client, MODEL_ID): + if not await model_is_deployed(es_client, MODEL_ID): pytest.skip(f"{MODEL_ID} model is not deployed in ML Node, skipping test") document = "foo bar" diff --git 
a/test_elasticsearch/test_server/test_vectorstore/_async/test_vectorestore.py b/test_elasticsearch/test_server/test_vectorstore/_async/test_vectorestore.py index 7dab019a6..f64c81034 100644 --- a/test_elasticsearch/test_server/test_vectorstore/_async/test_vectorestore.py +++ b/test_elasticsearch/test_server/test_vectorstore/_async/test_vectorestore.py @@ -1,6 +1,6 @@ import logging import uuid -from typing import AsyncGenerator +from typing import AsyncIterator from typing import Any, List, Optional, Union, cast from functools import partial @@ -12,7 +12,7 @@ from elasticsearch.helpers import BulkIndexError from elasticsearch.vectorstore._async import AsyncVectorStore -from elasticsearch.vectorstore._async._utils import async_model_is_deployed +from elasticsearch.vectorstore._async._utils import model_is_deployed from elasticsearch.vectorstore._async.strategies import ( BM25, DenseVector, @@ -54,12 +54,12 @@ class TestElasticsearch: @pytest_asyncio.fixture(autouse=True) - async def es_client(self) -> AsyncGenerator[AsyncElasticsearch, None]: + async def es_client(self) -> AsyncIterator[AsyncElasticsearch]: async for x in es_client_fixture(): yield x @pytest_asyncio.fixture(autouse=True) - async def requests_saving_client(self) -> AsyncGenerator[AsyncElasticsearch, None]: + async def requests_saving_client(self) -> AsyncIterator[AsyncElasticsearch]: client = create_requests_saving_client() try: yield client @@ -567,7 +567,7 @@ async def test_search_with_knn_infer_instack( ) -> None: """test end to end with knn retrieval strategy and inference in-stack""" - if not await async_model_is_deployed(es_client, TRANSFORMER_MODEL_ID): + if not await model_is_deployed(es_client, TRANSFORMER_MODEL_ID): pytest.skip( f"{TRANSFORMER_MODEL_ID} model not deployed in ML Node skipping test" ) @@ -661,7 +661,7 @@ async def test_search_with_sparse_infer_instack( ) -> None: """test end to end with sparse retrieval strategy and inference in-stack""" - if not await 
async_model_is_deployed(es_client, ELSER_MODEL_ID): + if not await model_is_deployed(es_client, ELSER_MODEL_ID): reason = f"{ELSER_MODEL_ID} model not deployed in ML Node, skipping test" pytest.skip(reason) diff --git a/test_elasticsearch/test_server/test_vectorstore/_sync/__init__.py b/test_elasticsearch/test_server/test_vectorstore/_sync/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/test_elasticsearch/test_server/test_vectorstore/_sync/_test_utils.py b/test_elasticsearch/test_server/test_vectorstore/_sync/_test_utils.py new file mode 100644 index 000000000..774a580f6 --- /dev/null +++ b/test_elasticsearch/test_server/test_vectorstore/_sync/_test_utils.py @@ -0,0 +1,138 @@ +import os +from typing import Any, Dict, List, Optional, Iterator + +from elastic_transport import Transport +from elasticsearch import Elasticsearch + +from elasticsearch.vectorstore._sync.embedding_service import EmbeddingService + + +class FakeEmbeddings(EmbeddingService): + """Fake embeddings functionality for testing.""" + + def __init__(self, dimensionality: int = 10) -> None: + self.dimensionality = dimensionality + + def num_dimensions(self) -> int: + return self.dimensionality + + def embed_documents(self, texts: List[str]) -> List[List[float]]: + """Return simple embeddings. Embeddings encode each text as its index.""" + return [ + [float(1.0)] * (self.dimensionality - 1) + [float(i)] + for i in range(len(texts)) + ] + + def embed_query(self, text: str) -> List[float]: + """Return constant query embeddings. + Embeddings are identical to embed_documents(texts)[0]. + Distance to each text will be that text's index, + as it was passed to embed_documents. 
+ """ + return [float(1.0)] * (self.dimensionality - 1) + [float(0.0)] + + +class ConsistentFakeEmbeddings(FakeEmbeddings): + """Fake embeddings which remember all the texts seen so far to return consistent + vectors for the same texts.""" + + def __init__(self, dimensionality: int = 10) -> None: + self.known_texts: List[str] = [] + self.dimensionality = dimensionality + + def num_dimensions(self) -> int: + return self.dimensionality + + def embed_documents(self, texts: List[str]) -> List[List[float]]: + """Return consistent embeddings for each text seen so far.""" + out_vectors = [] + for text in texts: + if text not in self.known_texts: + self.known_texts.append(text) + vector = [float(1.0)] * (self.dimensionality - 1) + [ + float(self.known_texts.index(text)) + ] + out_vectors.append(vector) + return out_vectors + + def embed_query(self, text: str) -> List[float]: + """Return consistent embeddings for the text, if seen before, or a constant + one if the text is unknown.""" + result = self.embed_documents([text]) + return result[0] + + +class RequestSavingTransport(Transport): + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self.requests: List[Dict] = [] + + def perform_request(self, *args, **kwargs): # type: ignore + self.requests.append(kwargs) + return super().perform_request(*args, **kwargs) + + +def create_es_client( + es_params: Optional[Dict[str, str]] = None, es_kwargs: Dict = {} +) -> Elasticsearch: + if es_params is None: + es_params = read_env() + if not es_kwargs: + es_kwargs = {} + + if "es_cloud_id" in es_params: + return Elasticsearch( + cloud_id=es_params["es_cloud_id"], + api_key=es_params["es_api_key"], + **es_kwargs, + ) + return Elasticsearch(hosts=[es_params["es_url"]], **es_kwargs) + + +def create_requests_saving_client() -> Elasticsearch: + return create_es_client(es_kwargs={"transport_class": RequestSavingTransport}) + + +def es_client_fixture() -> Iterator[Elasticsearch]: + params = 
read_env() + client = create_es_client(params) + + yield client + + # clear indices + clear_test_indices(client) + + # clear all test pipelines + try: + response = client.ingest.get_pipeline(id="test_*,*_sparse_embedding") + + for pipeline_id, _ in response.items(): + try: + client.ingest.delete_pipeline(id=pipeline_id) + print(f"Deleted pipeline: {pipeline_id}") # noqa: T201 + except Exception as e: + print(f"Pipeline error: {e}") # noqa: T201 + + except Exception: + pass + finally: + client.close() + + +def clear_test_indices(client: Elasticsearch) -> None: + response = client.indices.get(index="_all") + index_names = response.keys() + for index_name in index_names: + if index_name.startswith("test_"): + client.indices.delete(index=index_name) + client.indices.refresh(index="_all") + + +def read_env() -> Dict: + url = os.environ.get("ES_URL", "http://localhost:9200") + cloud_id = os.environ.get("ES_CLOUD_ID") + api_key = os.environ.get("ES_API_KEY") + + if cloud_id: + return {"es_cloud_id": cloud_id, "es_api_key": api_key} + return {"es_url": url} diff --git a/test_elasticsearch/test_server/test_vectorstore/_sync/test_embedding_service.py b/test_elasticsearch/test_server/test_vectorstore/_sync/test_embedding_service.py new file mode 100644 index 000000000..a68a479d7 --- /dev/null +++ b/test_elasticsearch/test_server/test_vectorstore/_sync/test_embedding_service.py @@ -0,0 +1,62 @@ +import os + +import pytest + +import pytest_asyncio +from elasticsearch import Elasticsearch + +from typing import Iterator + +from elasticsearch.vectorstore._sync._utils import model_is_deployed + +from ._test_utils import ( + es_client_fixture, +) + +from elasticsearch.vectorstore._sync.embedding_service import ( + ElasticsearchEmbeddings, +) + +# deployed with +# https://www.elastic.co/guide/en/machine-learning/current/ml-nlp-text-emb-vector-search-example.html +MODEL_ID = os.getenv("MODEL_ID", "sentence-transformers__msmarco-minilm-l-12-v3") +NUM_DIMENSIONS = 
int(os.getenv("NUM_DIMENTIONS", "384")) + + +@pytest_asyncio.fixture(autouse=True) +def es_client() -> Iterator[Elasticsearch]: + for x in es_client_fixture(): + yield x + + +@pytest.mark.asyncio +def test_elasticsearch_embedding_documents(es_client: Elasticsearch) -> None: + """Test Elasticsearch embedding documents.""" + + if not model_is_deployed(es_client, MODEL_ID): + pytest.skip(f"{MODEL_ID} model is not deployed in ML Node, skipping test") + + documents = ["foo bar", "bar foo", "foo"] + embedding = ElasticsearchEmbeddings( + es_client=es_client, user_agent="test", model_id=MODEL_ID + ) + output = embedding.embed_documents(documents) + assert len(output) == 3 + assert len(output[0]) == NUM_DIMENSIONS + assert len(output[1]) == NUM_DIMENSIONS + assert len(output[2]) == NUM_DIMENSIONS + + +@pytest.mark.asyncio +def test_elasticsearch_embedding_query(es_client: Elasticsearch) -> None: + """Test Elasticsearch embedding query.""" + + if not model_is_deployed(es_client, MODEL_ID): + pytest.skip(f"{MODEL_ID} model is not deployed in ML Node, skipping test") + + document = "foo bar" + embedding = ElasticsearchEmbeddings( + es_client=es_client, user_agent="test", model_id=MODEL_ID + ) + output = embedding.embed_query(document) + assert len(output) == NUM_DIMENSIONS diff --git a/test_elasticsearch/test_server/test_vectorstore/_sync/test_vectorestore.py b/test_elasticsearch/test_server/test_vectorstore/_sync/test_vectorestore.py new file mode 100644 index 000000000..ec01f8147 --- /dev/null +++ b/test_elasticsearch/test_server/test_vectorstore/_sync/test_vectorestore.py @@ -0,0 +1,928 @@ +import logging +import uuid +from typing import Iterator +from typing import Any, List, Optional, Union, cast +from functools import partial + +import pytest +import pytest_asyncio +from elasticsearch import Elasticsearch + +from elasticsearch import NotFoundError +from elasticsearch.helpers import BulkIndexError + +from elasticsearch.vectorstore._sync import VectorStore +from 
elasticsearch.vectorstore._sync._utils import model_is_deployed +from elasticsearch.vectorstore._sync.strategies import ( + BM25, + DenseVector, + DenseVectorScriptScore, + DistanceMetric, + Semantic, +) + +from ._test_utils import ( + create_requests_saving_client, + es_client_fixture, + ConsistentFakeEmbeddings, + FakeEmbeddings, + RequestSavingTransport, +) + +logging.basicConfig(level=logging.DEBUG) + +""" +docker-compose up elasticsearch + +By default runs against local docker instance of Elasticsearch. +To run against Elastic Cloud, set the following environment variables: +- ES_CLOUD_ID +- ES_API_KEY + +Some of the tests require the following models to be deployed in the ML Node: +- elser (can be downloaded and deployed through Kibana and trained models UI) +- sentence-transformers__all-minilm-l6-v2 (can be deployed through the API, + loaded via eland) + +These tests that require the models to be deployed are skipped by default. +Enable them by adding the model name to the modelsDeployed list below. 
+""" + +ELSER_MODEL_ID = ".elser_model_2" +TRANSFORMER_MODEL_ID = "sentence-transformers__all-minilm-l6-v2" + + +class TestElasticsearch: + @pytest_asyncio.fixture(autouse=True) + def es_client(self) -> Iterator[Elasticsearch]: + for x in es_client_fixture(): + yield x + + @pytest_asyncio.fixture(autouse=True) + def requests_saving_client(self) -> Iterator[Elasticsearch]: + client = create_requests_saving_client() + try: + yield client + finally: + client.close() + + @pytest.fixture(scope="function") + def index_name(self) -> str: + """Return the index name.""" + return f"test_{uuid.uuid4().hex}" + + @pytest.mark.asyncio + def test_search_without_metadata( + self, es_client: Elasticsearch, index_name: str + ) -> None: + """Test end to end construction and search without metadata.""" + + def assert_query(query_body: dict, query: Optional[str]) -> dict: + assert query_body == { + "knn": { + "field": "vector_field", + "filter": [], + "k": 1, + "num_candidates": 50, + "query_vector": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + } + } + return query_body + + store = VectorStore( + user_agent="test", + index_name=index_name, + retrieval_strategy=DenseVector(embedding_service=FakeEmbeddings()), + es_client=es_client, + ) + + texts = ["foo", "bar", "baz"] + store.add_texts(texts) + + output = store.search("foo", k=1, custom_query=assert_query) + assert [doc["_source"]["text_field"] for doc in output] == ["foo"] + + @pytest.mark.asyncio + def test_search_without_metadata_async( + self, es_client: Elasticsearch, index_name: str + ) -> None: + """Test end to end construction and search without metadata.""" + store = VectorStore( + user_agent="test", + index_name=index_name, + retrieval_strategy=DenseVector(embedding_service=FakeEmbeddings()), + es_client=es_client, + ) + + texts = ["foo", "bar", "baz"] + store.add_texts(texts) + + output = store.search("foo", k=1) + assert [doc["_source"]["text_field"] for doc in output] == ["foo"] + + @pytest.mark.asyncio + def 
test_add_vectors( + self, es_client: Elasticsearch, index_name: str + ) -> None: + """ + Test adding pre-built embeddings instead of using inference for the texts. + This allows you to separate the embeddings text and the page_content + for better proximity between user's question and embedded text. + For example, your embedding text can be a question, whereas page_content + is the answer. + """ + embeddings = ConsistentFakeEmbeddings() + texts = ["foo1", "foo2", "foo3"] + metadatas = [{"page": i} for i in range(len(texts))] + + """In real use case, embedding_input can be questions for each text""" + embedding_vectors = embeddings.embed_documents(texts) + + store = VectorStore( + user_agent="test", + index_name=index_name, + retrieval_strategy=DenseVector(embedding_service=embeddings), + es_client=es_client, + ) + + store.add_texts( + texts=texts, vectors=embedding_vectors, metadatas=metadatas + ) + output = store.search("foo1", k=1) + assert [doc["_source"]["text_field"] for doc in output] == ["foo1"] + assert [doc["_source"]["metadata"]["page"] for doc in output] == [0] + + @pytest.mark.asyncio + def test_search_with_metadata( + self, es_client: Elasticsearch, index_name: str + ) -> None: + """Test end to end construction and search with metadata.""" + store = VectorStore( + user_agent="test", + index_name=index_name, + retrieval_strategy=DenseVector( + embedding_service=ConsistentFakeEmbeddings() + ), + es_client=es_client, + ) + + texts = ["foo", "bar", "baz"] + metadatas = [{"page": i} for i in range(len(texts))] + store.add_texts(texts=texts, metadatas=metadatas) + + output = store.search("foo", k=1) + assert [doc["_source"]["text_field"] for doc in output] == ["foo"] + assert [doc["_source"]["metadata"]["page"] for doc in output] == [0] + + output = store.search("bar", k=1) + assert [doc["_source"]["text_field"] for doc in output] == ["bar"] + assert [doc["_source"]["metadata"]["page"] for doc in output] == [1] + + @pytest.mark.asyncio + def 
test_search_with_filter( + self, es_client: Elasticsearch, index_name: str + ) -> None: + """Test end to end construction and search with metadata.""" + store = VectorStore( + user_agent="test", + index_name=index_name, + retrieval_strategy=DenseVector(embedding_service=FakeEmbeddings()), + es_client=es_client, + ) + + texts = ["foo", "foo", "foo"] + metadatas = [{"page": i} for i in range(len(texts))] + store.add_texts(texts=texts, metadatas=metadatas) + + def assert_query(query_body: dict, query: Optional[str]) -> dict: + assert query_body == { + "knn": { + "field": "vector_field", + "filter": [{"term": {"metadata.page": "1"}}], + "k": 3, + "num_candidates": 50, + "query_vector": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + } + } + return query_body + + output = store.search( + query="foo", + k=3, + filter=[{"term": {"metadata.page": "1"}}], + custom_query=assert_query, + ) + assert [doc["_source"]["text_field"] for doc in output] == ["foo"] + assert [doc["_source"]["metadata"]["page"] for doc in output] == [1] + + @pytest.mark.asyncio + def test_search_script_score( + self, es_client: Elasticsearch, index_name: str + ) -> None: + """Test end to end construction and search with metadata.""" + store = VectorStore( + user_agent="test", + index_name=index_name, + retrieval_strategy=DenseVectorScriptScore( + embedding_service=FakeEmbeddings() + ), + es_client=es_client, + ) + + texts = ["foo", "bar", "baz"] + store.add_texts(texts) + + expected_query = { + "query": { + "script_score": { + "query": {"match_all": {}}, + "script": { + "source": "cosineSimilarity(params.query_vector, 'vector_field') + 1.0", # noqa: E501 + "params": { + "query_vector": [ + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 0.0, + ] + }, + }, + } + } + } + + def assert_query(query_body: dict, query: Optional[str]) -> dict: + assert query_body == expected_query + return query_body + + output = store.search("foo", k=1, custom_query=assert_query) + assert 
[doc["_source"]["text_field"] for doc in output] == ["foo"] + + @pytest.mark.asyncio + def test_search_script_score_with_filter( + self, es_client: Elasticsearch, index_name: str + ) -> None: + """Test end to end construction and search with metadata.""" + store = VectorStore( + user_agent="test", + index_name=index_name, + retrieval_strategy=DenseVectorScriptScore( + embedding_service=FakeEmbeddings() + ), + es_client=es_client, + ) + + texts = ["foo", "bar", "baz"] + metadatas = [{"page": i} for i in range(len(texts))] + store.add_texts(texts=texts, metadatas=metadatas) + + def assert_query(query_body: dict, query: Optional[str]) -> dict: + expected_query = { + "query": { + "script_score": { + "query": {"bool": {"filter": [{"term": {"metadata.page": 0}}]}}, + "script": { + "source": "cosineSimilarity(params.query_vector, 'vector_field') + 1.0", # noqa: E501 + "params": { + "query_vector": [ + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 0.0, + ] + }, + }, + } + } + } + assert query_body == expected_query + return query_body + + output = store.search( + "foo", + k=1, + custom_query=assert_query, + filter=[{"term": {"metadata.page": 0}}], + ) + assert [doc["_source"]["text_field"] for doc in output] == ["foo"] + assert [doc["_source"]["metadata"]["page"] for doc in output] == [0] + + @pytest.mark.asyncio + def test_search_script_score_distance_dot_product( + self, es_client: Elasticsearch, index_name: str + ) -> None: + """Test end to end construction and search with metadata.""" + store = VectorStore( + user_agent="test", + index_name=index_name, + retrieval_strategy=DenseVectorScriptScore( + embedding_service=FakeEmbeddings(), + distance=DistanceMetric.DOT_PRODUCT, + ), + es_client=es_client, + ) + + texts = ["foo", "bar", "baz"] + store.add_texts(texts) + + def assert_query(query_body: dict, query: Optional[str]) -> dict: + assert query_body == { + "query": { + "script_score": { + "query": {"match_all": {}}, + "script": { + "source": """ + 
double value = dotProduct(params.query_vector, 'vector_field'); + return sigmoid(1, Math.E, -value); + """, + "params": { + "query_vector": [ + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 0.0, + ] + }, + }, + } + } + } + return query_body + + output = store.search("foo", k=1, custom_query=assert_query) + assert [doc["_source"]["text_field"] for doc in output] == ["foo"] + + @pytest.mark.asyncio + def test_search_knn_with_hybrid_search( + self, es_client: Elasticsearch, index_name: str + ) -> None: + """Test end to end construction and search with metadata.""" + store = VectorStore( + user_agent="test", + index_name=index_name, + retrieval_strategy=DenseVector( + embedding_service=FakeEmbeddings(), + hybrid=True, + ), + es_client=es_client, + ) + + texts = ["foo", "bar", "baz"] + store.add_texts(texts) + + def assert_query(query_body: dict, query: Optional[str]) -> dict: + assert query_body == { + "knn": { + "field": "vector_field", + "filter": [], + "k": 1, + "num_candidates": 50, + "query_vector": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + }, + "query": { + "bool": { + "filter": [], + "must": [{"match": {"text_field": {"query": "foo"}}}], + } + }, + "rank": {"rrf": {}}, + } + return query_body + + output = store.search("foo", k=1, custom_query=assert_query) + assert [doc["_source"]["text_field"] for doc in output] == ["foo"] + + @pytest.mark.asyncio + def test_search_knn_with_hybrid_search_rrf( + self, es_client: Elasticsearch, index_name: str + ) -> None: + """Test end to end construction and rrf hybrid search with metadata.""" + texts = ["foo", "bar", "baz"] + + def assert_query( + query_body: dict, + query: Optional[str], + expected_rrf: Union[dict, bool], + ) -> dict: + cmp_query_body = { + "knn": { + "field": "vector_field", + "filter": [], + "k": 3, + "num_candidates": 50, + "query_vector": [ + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 0.0, + ], + }, + "query": { + "bool": { + "filter": [], + "must": 
[{"match": {"text_field": {"query": "foo"}}}], + } + }, + } + + if isinstance(expected_rrf, dict): + cmp_query_body["rank"] = {"rrf": expected_rrf} + elif isinstance(expected_rrf, bool) and expected_rrf is True: + cmp_query_body["rank"] = {"rrf": {}} + + assert query_body == cmp_query_body + + return query_body + + # 1. check query_body is okay + rrf_test_cases: List[Union[dict, bool]] = [ + True, + False, + {"rank_constant": 1, "window_size": 5}, + ] + for rrf_test_case in rrf_test_cases: + store = VectorStore( + user_agent="test", + index_name=index_name, + retrieval_strategy=DenseVector( + embedding_service=FakeEmbeddings(), + hybrid=True, + rrf=rrf_test_case, + ), + es_client=es_client, + ) + store.add_texts(texts) + + ## without fetch_k parameter + output = store.search( + "foo", + k=3, + custom_query=partial(assert_query, expected_rrf=rrf_test_case), + ) + + # 2. check query result is okay + es_output = store.es_client.search( + index=index_name, + query={ + "bool": { + "filter": [], + "must": [{"match": {"text_field": {"query": "foo"}}}], + } + }, + knn={ + "field": "vector_field", + "filter": [], + "k": 3, + "num_candidates": 50, + "query_vector": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + }, + size=3, + rank={"rrf": {"rank_constant": 1, "window_size": 5}}, + ) + + assert [o["_source"]["text_field"] for o in output] == [ + e["_source"]["text_field"] for e in es_output["hits"]["hits"] + ] + + # 3. 
check rrf default option is okay + store = VectorStore( + user_agent="test", + index_name=f"{index_name}_default", + retrieval_strategy=DenseVector( + embedding_service=FakeEmbeddings(), + hybrid=True, + ), + es_client=es_client, + ) + store.add_texts(texts) + + ## with fetch_k parameter + output = store.search( + "foo", + k=3, + num_candidates=50, + custom_query=partial(assert_query, expected_rrf={}), + ) + + @pytest.mark.asyncio + def test_search_knn_with_custom_query_fn( + self, es_client: Elasticsearch, index_name: str + ) -> None: + """test that custom query function is called + with the query string and query body""" + store = VectorStore( + user_agent="test", + index_name=index_name, + retrieval_strategy=DenseVector(embedding_service=FakeEmbeddings()), + es_client=es_client, + ) + + def my_custom_query(query_body: dict, query: Optional[str]) -> dict: + assert query == "foo" + assert query_body == { + "knn": { + "field": "vector_field", + "filter": [], + "k": 1, + "num_candidates": 50, + "query_vector": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + } + } + return {"query": {"match": {"text_field": {"query": "bar"}}}} + + """Test end to end construction and search with metadata.""" + texts = ["foo", "bar", "baz"] + store.add_texts(texts) + + output = store.search("foo", k=1, custom_query=my_custom_query) + assert [doc["_source"]["text_field"] for doc in output] == ["bar"] + + @pytest.mark.asyncio + def test_search_with_knn_infer_instack( + self, es_client: Elasticsearch, index_name: str + ) -> None: + """test end to end with knn retrieval strategy and inference in-stack""" + + if not model_is_deployed(es_client, TRANSFORMER_MODEL_ID): + pytest.skip( + f"{TRANSFORMER_MODEL_ID} model not deployed in ML Node skipping test" + ) + + text_field = "text_field" + + store = VectorStore( + user_agent="test", + index_name=index_name, + retrieval_strategy=Semantic( + model_id="sentence-transformers__all-minilm-l6-v2", + text_field=text_field, + ), + 
es_client=es_client, + ) + + # setting up the pipeline for inference + store.es_client.ingest.put_pipeline( + id="test_pipeline", + processors=[ + { + "inference": { + "model_id": TRANSFORMER_MODEL_ID, + "field_map": {"query_field": text_field}, + "target_field": "vector_query_field", + } + } + ], + ) + + # creating a new index with the pipeline, + # not relying on langchain to create the index + store.es_client.indices.create( + index=index_name, + mappings={ + "properties": { + text_field: {"type": "text_field"}, + "vector_query_field": { + "properties": { + "predicted_value": { + "type": "dense_vector", + "dims": 384, + "index": True, + "similarity": "l2_norm", + } + } + }, + } + }, + settings={"index": {"default_pipeline": "test_pipeline"}}, + ) + + # adding documents to the index + texts = ["foo", "bar", "baz"] + + for i, text in enumerate(texts): + store.es_client.create( + index=index_name, + id=str(i), + document={text_field: text, "metadata": {}}, + ) + + store.es_client.indices.refresh(index=index_name) + + def assert_query(query_body: dict, query: Optional[str]) -> dict: + assert query_body == { + "knn": { + "filter": [], + "field": "vector_query_field.predicted_value", + "k": 1, + "num_candidates": 50, + "query_vector_builder": { + "text_embedding": { + "model_id": TRANSFORMER_MODEL_ID, + "model_text": "foo", + } + }, + } + } + return query_body + + output = store.search("foo", k=1, custom_query=assert_query) + assert [doc["_source"]["text_field"] for doc in output] == ["foo"] + + output = store.search("bar", k=1) + assert [doc["_source"]["text_field"] for doc in output] == ["bar"] + + @pytest.mark.asyncio + def test_search_with_sparse_infer_instack( + self, es_client: Elasticsearch, index_name: str + ) -> None: + """test end to end with sparse retrieval strategy and inference in-stack""" + + if not model_is_deployed(es_client, ELSER_MODEL_ID): + reason = f"{ELSER_MODEL_ID} model not deployed in ML Node, skipping test" + + pytest.skip(reason) + + store 
= VectorStore( + user_agent="test", + index_name=index_name, + retrieval_strategy=Semantic(model_id=ELSER_MODEL_ID), + es_client=es_client, + ) + + texts = ["foo", "bar", "baz"] + store.add_texts(texts) + + output = store.search("foo", k=1) + assert [doc["_source"]["text_field"] for doc in output] == ["foo"] + + @pytest.mark.asyncio + def test_deployed_model_check_fails_semantic( + self, es_client: Elasticsearch, index_name: str + ) -> None: + """test that exceptions are raised if a specified model is not deployed""" + with pytest.raises(NotFoundError): + store = VectorStore( + user_agent="test", + index_name=index_name, + retrieval_strategy=Semantic(model_id="non-existing model ID"), + es_client=es_client, + ) + store.add_texts(["foo", "bar", "baz"]) + + @pytest.mark.asyncio + def test_search_bm25( + self, es_client: Elasticsearch, index_name: str + ) -> None: + """Test end to end using the BM25 retrieval strategy.""" + store = VectorStore( + user_agent="test", + index_name=index_name, + retrieval_strategy=BM25(), + es_client=es_client, + ) + + texts = ["foo", "bar", "baz"] + store.add_texts(texts) + + def assert_query(query_body: dict, query: Optional[str]) -> dict: + assert query_body == { + "query": { + "bool": { + "must": [{"match": {"text_field": {"query": "foo"}}}], + "filter": [], + } + } + } + return query_body + + output = store.search("foo", k=1, custom_query=assert_query) + assert [doc["_source"]["text_field"] for doc in output] == ["foo"] + + @pytest.mark.asyncio + def test_search_bm25_with_filter( + self, es_client: Elasticsearch, index_name: str + ) -> None: + """Test end to using the BM25 retrieval strategy with metadata.""" + store = VectorStore( + user_agent="test", + index_name=index_name, + retrieval_strategy=BM25(), + es_client=es_client, + ) + + texts = ["foo", "foo", "foo"] + metadatas = [{"page": i} for i in range(len(texts))] + store.add_texts(texts=texts, metadatas=metadatas) + + def assert_query(query_body: dict, query: Optional[str]) -> 
dict: + assert query_body == { + "query": { + "bool": { + "must": [{"match": {"text_field": {"query": "foo"}}}], + "filter": [{"term": {"metadata.page": 1}}], + } + } + } + return query_body + + output = store.search( + "foo", + k=3, + custom_query=assert_query, + filter=[{"term": {"metadata.page": 1}}], + ) + assert [doc["_source"]["text_field"] for doc in output] == ["foo"] + assert [doc["_source"]["metadata"]["page"] for doc in output] == [1] + + @pytest.mark.asyncio + def test_delete(self, es_client: Elasticsearch, index_name: str) -> None: + """Test delete methods from vector store.""" + store = VectorStore( + user_agent="test", + index_name=index_name, + retrieval_strategy=DenseVector(embedding_service=FakeEmbeddings()), + es_client=es_client, + ) + + texts = ["foo", "bar", "baz", "gni"] + metadatas = [{"page": i} for i in range(len(texts))] + ids = store.add_texts(texts=texts, metadatas=metadatas) + + output = store.search("foo", k=10) + assert len(output) == 4 + + store.delete(ids[1:3]) + output = store.search("foo", k=10) + assert len(output) == 2 + + store.delete(["not-existing"]) + output = store.search("foo", k=10) + assert len(output) == 2 + + store.delete([ids[0]]) + output = store.search("foo", k=10) + assert len(output) == 1 + + store.delete([ids[3]]) + output = store.search("gni", k=10) + assert len(output) == 0 + + @pytest.mark.asyncio + def test_indexing_exception_error( + self, + es_client: Elasticsearch, + index_name: str, + caplog: pytest.LogCaptureFixture, + ) -> None: + """Test bulk exception logging is giving better hints.""" + store = VectorStore( + user_agent="test", + index_name=index_name, + retrieval_strategy=BM25(), + es_client=es_client, + ) + + store.es_client.indices.create( + index=index_name, + mappings={"properties": {}}, + settings={"index": {"default_pipeline": "not-existing-pipeline"}}, + ) + + texts = ["foo"] + + with pytest.raises(BulkIndexError): + store.add_texts(texts) + + error_reason = "pipeline with id 
[not-existing-pipeline] does not exist" + log_message = f"First error reason: {error_reason}" + + assert log_message in caplog.text + + @pytest.mark.asyncio + def test_user_agent( + self, requests_saving_client: Elasticsearch, index_name: str + ) -> None: + """Test to make sure the user-agent is set correctly.""" + user_agent = "this is THE user_agent!" + store = VectorStore( + user_agent=user_agent, + index_name=index_name, + retrieval_strategy=BM25(), + es_client=requests_saving_client, + ) + + assert store.es_client._headers["User-Agent"] == user_agent + + texts = ["foo", "bob", "baz"] + store.add_texts(texts) + + transport = cast(RequestSavingTransport, store.es_client.transport) + + for request in transport.requests: + assert request["headers"]["User-Agent"] == user_agent + + @pytest.mark.asyncio + def test_bulk_args( + self, requests_saving_client: Any, index_name: str + ) -> None: + """Test to make sure the bulk arguments work as expected.""" + store = VectorStore( + user_agent="test", + index_name=index_name, + retrieval_strategy=BM25(), + es_client=requests_saving_client, + ) + + texts = ["foo", "bob", "baz"] + store.add_texts(texts, bulk_kwargs={"chunk_size": 1}) + + # 1 for index exist, 1 for index create, 3 to index docs + assert len(store.es_client.transport.requests) == 5 # type: ignore + + @pytest.mark.asyncio + def test_max_marginal_relevance_search( + self, es_client: Elasticsearch, index_name: str + ) -> None: + """Test max marginal relevance search.""" + texts = ["foo", "bar", "baz"] + vector_field = "vector_field" + text_field = "text_field" + embedding_service = ConsistentFakeEmbeddings() + store = VectorStore( + user_agent="test", + index_name=index_name, + retrieval_strategy=DenseVectorScriptScore( + embedding_service=embedding_service + ), + vector_field=vector_field, + text_field=text_field, + es_client=es_client, + ) + store.add_texts(texts) + + mmr_output = store.max_marginal_relevance_search( + embedding_service, + texts[0], + 
vector_field=vector_field, + k=3, + num_candidates=3, + ) + sim_output = store.search(texts[0], k=3) + assert mmr_output == sim_output + + mmr_output = store.max_marginal_relevance_search( + embedding_service, + texts[0], + vector_field=vector_field, + k=2, + num_candidates=3, + ) + assert len(mmr_output) == 2 + assert mmr_output[0]["_source"][text_field] == texts[0] + assert mmr_output[1]["_source"][text_field] == texts[1] + + mmr_output = store.max_marginal_relevance_search( + embedding_service, + texts[0], + vector_field=vector_field, + k=2, + num_candidates=3, + lambda_mult=0.1, # more diversity + ) + assert len(mmr_output) == 2 + assert mmr_output[0]["_source"][text_field] == texts[0] + assert mmr_output[1]["_source"][text_field] == texts[2] + + # if fetch_k < k, then the output will be less than k + mmr_output = store.max_marginal_relevance_search( + embedding_service, + texts[0], + vector_field=vector_field, + k=3, + num_candidates=2, + ) + assert len(mmr_output) == 2 diff --git a/utils/run-unasync.py b/utils/run-unasync.py index 122ba621f..e73d8dc55 100644 --- a/utils/run-unasync.py +++ b/utils/run-unasync.py @@ -21,29 +21,11 @@ import unasync -def main(): - # Unasync all the generated async code - additional_replacements = { - # We want to rewrite to 'Transport' instead of 'SyncTransport', etc - "AsyncTransport": "Transport", - "AsyncElasticsearch": "Elasticsearch", - # We don't want to rewrite this class - "AsyncSearchClient": "AsyncSearchClient", - # Handling typing.Awaitable[...] isn't done yet by unasync. 
- "_TYPE_ASYNC_SNIFF_CALLBACK": "_TYPE_SYNC_SNIFF_CALLBACK", - } - rules = [ - unasync.Rule( - fromdir="/elasticsearch/_async/client/", - todir="/elasticsearch/_sync/client/", - additional_replacements=additional_replacements, - ), - ] +def run(rule: unasync.Rule): + root = Path(__file__).absolute().parent.parent filepaths = [] - for root, _, filenames in os.walk( - Path(__file__).absolute().parent.parent / "elasticsearch/_async" - ): + for root, _, filenames in os.walk(root / rule.fromdir): for filename in filenames: if filename.rpartition(".")[-1] in ( "py", @@ -51,7 +33,64 @@ def main(): ) and not filename.startswith("utils.py"): filepaths.append(os.path.join(root, filename)) - unasync.unasync_files(filepaths, rules) + unasync.unasync_files(filepaths, [rule]) + + +def main(): + + run( + rule=unasync.Rule( + fromdir="/elasticsearch/_async/client/", + todir="/elasticsearch/_sync/client/", + additional_replacements={ + # We want to rewrite to 'Transport' instead of 'SyncTransport', etc + "AsyncTransport": "Transport", + "AsyncElasticsearch": "Elasticsearch", + # We don't want to rewrite this class + "AsyncSearchClient": "AsyncSearchClient", + # Handling typing.Awaitable[...] isn't done yet by unasync. 
+ "_TYPE_ASYNC_SNIFF_CALLBACK": "_TYPE_SYNC_SNIFF_CALLBACK", + }, + ), + ) + + run( + rule=unasync.Rule( + fromdir="/elasticsearch/vectorstore/_async/", + todir="/elasticsearch/vectorstore/_sync/", + additional_replacements={ + "_async": "_sync", + "async_bulk": "bulk", + "AsyncElasticsearch": "Elasticsearch", + "AsyncElasticsearchEmbeddings": "ElasticsearchEmbeddings", + "AsyncEmbeddingService": "EmbeddingService", + "AsyncTransport": "Transport", + "AsyncVectorStore": "VectorStore", + }, + ), + ) + + run( + rule=unasync.Rule( + fromdir="test_elasticsearch/test_server/test_vectorstore/_async/", + todir="test_elasticsearch/test_server/test_vectorstore/_sync/", + additional_replacements={ + # Main + "_async": "_sync", + "async_bulk": "bulk", + "AsyncElasticsearch": "Elasticsearch", + "AsyncElasticsearchEmbeddings": "ElasticsearchEmbeddings", + "AsyncEmbeddingService": "EmbeddingService", + "AsyncTransport": "Transport", + "AsyncVectorStore": "VectorStore", + # Tests-specific + "AsyncConsistentFakeEmbeddings": "ConsistentFakeEmbeddings", + "AsyncFakeEmbeddings": "FakeEmbeddings", + "AsyncGenerator": "Generator", + "AsyncRequestSavingTransport": "RequestSavingTransport", + }, + ), + ) if __name__ == "__main__": From 7ee38460e49d63738c98c939d1daf10a47b9ae60 Mon Sep 17 00:00:00 2001 From: Max Jakob Date: Thu, 18 Apr 2024 16:27:23 +0200 Subject: [PATCH 05/36] add cleanup step for _sync generation --- elasticsearch/vectorstore/_sync/_utils.py | 7 +-- elasticsearch/vectorstore/_sync/strategies.py | 13 ++--- .../vectorstore/_sync/vectorestore.py | 5 +- .../_async/test_embedding_service.py | 2 +- .../_async/test_vectorestore.py | 4 +- .../test_vectorstore/_sync/_test_utils.py | 4 +- .../_sync/test_embedding_service.py | 18 ++----- .../_sync/test_vectorestore.py | 53 ++++--------------- utils/run-unasync.py | 37 +++++++++++-- 9 files changed, 57 insertions(+), 86 deletions(-) diff --git a/elasticsearch/vectorstore/_sync/_utils.py b/elasticsearch/vectorstore/_sync/_utils.py 
index e400f9a09..5a85fcaac 100644 --- a/elasticsearch/vectorstore/_sync/_utils.py +++ b/elasticsearch/vectorstore/_sync/_utils.py @@ -1,9 +1,4 @@ -from elasticsearch import ( - Elasticsearch, - BadRequestError, - ConflictError, - NotFoundError, -) +from elasticsearch import BadRequestError, ConflictError, Elasticsearch, NotFoundError def model_must_be_deployed(client: Elasticsearch, model_id: str) -> None: diff --git a/elasticsearch/vectorstore/_sync/strategies.py b/elasticsearch/vectorstore/_sync/strategies.py index 3dd182c80..1e54767b2 100644 --- a/elasticsearch/vectorstore/_sync/strategies.py +++ b/elasticsearch/vectorstore/_sync/strategies.py @@ -3,7 +3,6 @@ from typing import Any, Dict, List, Literal, Optional, Union, cast from elasticsearch import Elasticsearch - from elasticsearch.vectorstore._sync._utils import model_must_be_deployed from elasticsearch.vectorstore._sync.embedding_service import EmbeddingService @@ -226,9 +225,7 @@ def create_index( mappings["properties"]["metadata"] = {"properties": metadata_mapping} settings = {"default_pipeline": pipeline_name} - client.indices.create( - index=index_name, mappings=mappings, settings=settings - ) + client.indices.create(index=index_name, mappings=mappings, settings=settings) return None @@ -287,9 +284,7 @@ def es_query( if query_vector: knn["query_vector"] = query_vector elif self.embedding_service: - knn["query_vector"] = self.embedding_service.embed_query( - cast(str, query) - ) + knn["query_vector"] = self.embedding_service.embed_query(cast(str, query)) else: # Inference in Elasticsearch. When initializing we make sure to always have # a model_id if don't have an embedding_service. 
@@ -555,6 +550,4 @@ def create_index( } } - client.indices.create( - index=index_name, mappings=mappings, settings=settings - ) + client.indices.create(index=index_name, mappings=mappings, settings=settings) diff --git a/elasticsearch/vectorstore/_sync/vectorestore.py b/elasticsearch/vectorstore/_sync/vectorestore.py index 2d2338ee8..ba1fabdb3 100644 --- a/elasticsearch/vectorstore/_sync/vectorestore.py +++ b/elasticsearch/vectorstore/_sync/vectorestore.py @@ -4,12 +4,9 @@ from elasticsearch import Elasticsearch from elasticsearch.helpers import BulkIndexError, bulk - -from elasticsearch.vectorstore._utils import ( - maximal_marginal_relevance, -) from elasticsearch.vectorstore._sync.embedding_service import EmbeddingService from elasticsearch.vectorstore._sync.strategies import RetrievalStrategy +from elasticsearch.vectorstore._utils import maximal_marginal_relevance logger = logging.getLogger(__name__) diff --git a/test_elasticsearch/test_server/test_vectorstore/_async/test_embedding_service.py b/test_elasticsearch/test_server/test_vectorstore/_async/test_embedding_service.py index ef41df681..ea04f08df 100644 --- a/test_elasticsearch/test_server/test_vectorstore/_async/test_embedding_service.py +++ b/test_elasticsearch/test_server/test_vectorstore/_async/test_embedding_service.py @@ -23,7 +23,7 @@ NUM_DIMENSIONS = int(os.getenv("NUM_DIMENTIONS", "384")) -@pytest_asyncio.fixture(autouse=True) +@pytest_asyncio.fixture async def es_client() -> AsyncIterator[AsyncElasticsearch]: async for x in es_client_fixture(): yield x diff --git a/test_elasticsearch/test_server/test_vectorstore/_async/test_vectorestore.py b/test_elasticsearch/test_server/test_vectorstore/_async/test_vectorestore.py index f64c81034..c548b1afc 100644 --- a/test_elasticsearch/test_server/test_vectorstore/_async/test_vectorestore.py +++ b/test_elasticsearch/test_server/test_vectorstore/_async/test_vectorestore.py @@ -53,12 +53,12 @@ class TestElasticsearch: - @pytest_asyncio.fixture(autouse=True) + 
@pytest_asyncio.fixture async def es_client(self) -> AsyncIterator[AsyncElasticsearch]: async for x in es_client_fixture(): yield x - @pytest_asyncio.fixture(autouse=True) + @pytest_asyncio.fixture async def requests_saving_client(self) -> AsyncIterator[AsyncElasticsearch]: client = create_requests_saving_client() try: diff --git a/test_elasticsearch/test_server/test_vectorstore/_sync/_test_utils.py b/test_elasticsearch/test_server/test_vectorstore/_sync/_test_utils.py index 774a580f6..ead24f154 100644 --- a/test_elasticsearch/test_server/test_vectorstore/_sync/_test_utils.py +++ b/test_elasticsearch/test_server/test_vectorstore/_sync/_test_utils.py @@ -1,9 +1,9 @@ import os -from typing import Any, Dict, List, Optional, Iterator +from typing import Any, Dict, Iterator, List, Optional from elastic_transport import Transport -from elasticsearch import Elasticsearch +from elasticsearch import Elasticsearch from elasticsearch.vectorstore._sync.embedding_service import EmbeddingService diff --git a/test_elasticsearch/test_server/test_vectorstore/_sync/test_embedding_service.py b/test_elasticsearch/test_server/test_vectorstore/_sync/test_embedding_service.py index a68a479d7..002ae57b1 100644 --- a/test_elasticsearch/test_server/test_vectorstore/_sync/test_embedding_service.py +++ b/test_elasticsearch/test_server/test_vectorstore/_sync/test_embedding_service.py @@ -1,21 +1,13 @@ import os +from typing import Iterator import pytest -import pytest_asyncio from elasticsearch import Elasticsearch - -from typing import Iterator - from elasticsearch.vectorstore._sync._utils import model_is_deployed +from elasticsearch.vectorstore._sync.embedding_service import ElasticsearchEmbeddings -from ._test_utils import ( - es_client_fixture, -) - -from elasticsearch.vectorstore._sync.embedding_service import ( - ElasticsearchEmbeddings, -) +from ._test_utils import es_client_fixture # deployed with # 
https://www.elastic.co/guide/en/machine-learning/current/ml-nlp-text-emb-vector-search-example.html @@ -23,13 +15,12 @@ NUM_DIMENSIONS = int(os.getenv("NUM_DIMENTIONS", "384")) -@pytest_asyncio.fixture(autouse=True) +@pytest.fixture def es_client() -> Iterator[Elasticsearch]: for x in es_client_fixture(): yield x -@pytest.mark.asyncio def test_elasticsearch_embedding_documents(es_client: Elasticsearch) -> None: """Test Elasticsearch embedding documents.""" @@ -47,7 +38,6 @@ def test_elasticsearch_embedding_documents(es_client: Elasticsearch) -> None: assert len(output[2]) == NUM_DIMENSIONS -@pytest.mark.asyncio def test_elasticsearch_embedding_query(es_client: Elasticsearch) -> None: """Test Elasticsearch embedding query.""" diff --git a/test_elasticsearch/test_server/test_vectorstore/_sync/test_vectorestore.py b/test_elasticsearch/test_server/test_vectorstore/_sync/test_vectorestore.py index ec01f8147..b496b409d 100644 --- a/test_elasticsearch/test_server/test_vectorstore/_sync/test_vectorestore.py +++ b/test_elasticsearch/test_server/test_vectorstore/_sync/test_vectorestore.py @@ -1,16 +1,12 @@ import logging import uuid -from typing import Iterator -from typing import Any, List, Optional, Union, cast from functools import partial +from typing import Any, Iterator, List, Optional, Union, cast import pytest -import pytest_asyncio -from elasticsearch import Elasticsearch -from elasticsearch import NotFoundError +from elasticsearch import Elasticsearch, NotFoundError from elasticsearch.helpers import BulkIndexError - from elasticsearch.vectorstore._sync import VectorStore from elasticsearch.vectorstore._sync._utils import model_is_deployed from elasticsearch.vectorstore._sync.strategies import ( @@ -22,11 +18,11 @@ ) from ._test_utils import ( - create_requests_saving_client, - es_client_fixture, ConsistentFakeEmbeddings, FakeEmbeddings, RequestSavingTransport, + create_requests_saving_client, + es_client_fixture, ) logging.basicConfig(level=logging.DEBUG) @@ -53,12 
+49,12 @@ class TestElasticsearch: - @pytest_asyncio.fixture(autouse=True) + @pytest.fixture def es_client(self) -> Iterator[Elasticsearch]: for x in es_client_fixture(): yield x - @pytest_asyncio.fixture(autouse=True) + @pytest.fixture def requests_saving_client(self) -> Iterator[Elasticsearch]: client = create_requests_saving_client() try: @@ -71,7 +67,6 @@ def index_name(self) -> str: """Return the index name.""" return f"test_{uuid.uuid4().hex}" - @pytest.mark.asyncio def test_search_without_metadata( self, es_client: Elasticsearch, index_name: str ) -> None: @@ -102,7 +97,6 @@ def assert_query(query_body: dict, query: Optional[str]) -> dict: output = store.search("foo", k=1, custom_query=assert_query) assert [doc["_source"]["text_field"] for doc in output] == ["foo"] - @pytest.mark.asyncio def test_search_without_metadata_async( self, es_client: Elasticsearch, index_name: str ) -> None: @@ -120,10 +114,7 @@ def test_search_without_metadata_async( output = store.search("foo", k=1) assert [doc["_source"]["text_field"] for doc in output] == ["foo"] - @pytest.mark.asyncio - def test_add_vectors( - self, es_client: Elasticsearch, index_name: str - ) -> None: + def test_add_vectors(self, es_client: Elasticsearch, index_name: str) -> None: """ Test adding pre-built embeddings instead of using inference for the texts. 
This allows you to separate the embeddings text and the page_content @@ -145,14 +136,11 @@ def test_add_vectors( es_client=es_client, ) - store.add_texts( - texts=texts, vectors=embedding_vectors, metadatas=metadatas - ) + store.add_texts(texts=texts, vectors=embedding_vectors, metadatas=metadatas) output = store.search("foo1", k=1) assert [doc["_source"]["text_field"] for doc in output] == ["foo1"] assert [doc["_source"]["metadata"]["page"] for doc in output] == [0] - @pytest.mark.asyncio def test_search_with_metadata( self, es_client: Elasticsearch, index_name: str ) -> None: @@ -178,7 +166,6 @@ def test_search_with_metadata( assert [doc["_source"]["text_field"] for doc in output] == ["bar"] assert [doc["_source"]["metadata"]["page"] for doc in output] == [1] - @pytest.mark.asyncio def test_search_with_filter( self, es_client: Elasticsearch, index_name: str ) -> None: @@ -215,7 +202,6 @@ def assert_query(query_body: dict, query: Optional[str]) -> dict: assert [doc["_source"]["text_field"] for doc in output] == ["foo"] assert [doc["_source"]["metadata"]["page"] for doc in output] == [1] - @pytest.mark.asyncio def test_search_script_score( self, es_client: Elasticsearch, index_name: str ) -> None: @@ -264,7 +250,6 @@ def assert_query(query_body: dict, query: Optional[str]) -> dict: output = store.search("foo", k=1, custom_query=assert_query) assert [doc["_source"]["text_field"] for doc in output] == ["foo"] - @pytest.mark.asyncio def test_search_script_score_with_filter( self, es_client: Elasticsearch, index_name: str ) -> None: @@ -319,7 +304,6 @@ def assert_query(query_body: dict, query: Optional[str]) -> dict: assert [doc["_source"]["text_field"] for doc in output] == ["foo"] assert [doc["_source"]["metadata"]["page"] for doc in output] == [0] - @pytest.mark.asyncio def test_search_script_score_distance_dot_product( self, es_client: Elasticsearch, index_name: str ) -> None: @@ -370,7 +354,6 @@ def assert_query(query_body: dict, query: Optional[str]) -> dict: 
output = store.search("foo", k=1, custom_query=assert_query) assert [doc["_source"]["text_field"] for doc in output] == ["foo"] - @pytest.mark.asyncio def test_search_knn_with_hybrid_search( self, es_client: Elasticsearch, index_name: str ) -> None: @@ -410,7 +393,6 @@ def assert_query(query_body: dict, query: Optional[str]) -> dict: output = store.search("foo", k=1, custom_query=assert_query) assert [doc["_source"]["text_field"] for doc in output] == ["foo"] - @pytest.mark.asyncio def test_search_knn_with_hybrid_search_rrf( self, es_client: Elasticsearch, index_name: str ) -> None: @@ -528,7 +510,6 @@ def assert_query( custom_query=partial(assert_query, expected_rrf={}), ) - @pytest.mark.asyncio def test_search_knn_with_custom_query_fn( self, es_client: Elasticsearch, index_name: str ) -> None: @@ -561,7 +542,6 @@ def my_custom_query(query_body: dict, query: Optional[str]) -> dict: output = store.search("foo", k=1, custom_query=my_custom_query) assert [doc["_source"]["text_field"] for doc in output] == ["bar"] - @pytest.mark.asyncio def test_search_with_knn_infer_instack( self, es_client: Elasticsearch, index_name: str ) -> None: @@ -655,7 +635,6 @@ def assert_query(query_body: dict, query: Optional[str]) -> dict: output = store.search("bar", k=1) assert [doc["_source"]["text_field"] for doc in output] == ["bar"] - @pytest.mark.asyncio def test_search_with_sparse_infer_instack( self, es_client: Elasticsearch, index_name: str ) -> None: @@ -679,7 +658,6 @@ def test_search_with_sparse_infer_instack( output = store.search("foo", k=1) assert [doc["_source"]["text_field"] for doc in output] == ["foo"] - @pytest.mark.asyncio def test_deployed_model_check_fails_semantic( self, es_client: Elasticsearch, index_name: str ) -> None: @@ -693,10 +671,7 @@ def test_deployed_model_check_fails_semantic( ) store.add_texts(["foo", "bar", "baz"]) - @pytest.mark.asyncio - def test_search_bm25( - self, es_client: Elasticsearch, index_name: str - ) -> None: + def test_search_bm25(self, 
es_client: Elasticsearch, index_name: str) -> None: """Test end to end using the BM25 retrieval strategy.""" store = VectorStore( user_agent="test", @@ -722,7 +697,6 @@ def assert_query(query_body: dict, query: Optional[str]) -> dict: output = store.search("foo", k=1, custom_query=assert_query) assert [doc["_source"]["text_field"] for doc in output] == ["foo"] - @pytest.mark.asyncio def test_search_bm25_with_filter( self, es_client: Elasticsearch, index_name: str ) -> None: @@ -758,7 +732,6 @@ def assert_query(query_body: dict, query: Optional[str]) -> dict: assert [doc["_source"]["text_field"] for doc in output] == ["foo"] assert [doc["_source"]["metadata"]["page"] for doc in output] == [1] - @pytest.mark.asyncio def test_delete(self, es_client: Elasticsearch, index_name: str) -> None: """Test delete methods from vector store.""" store = VectorStore( @@ -791,7 +764,6 @@ def test_delete(self, es_client: Elasticsearch, index_name: str) -> None: output = store.search("gni", k=10) assert len(output) == 0 - @pytest.mark.asyncio def test_indexing_exception_error( self, es_client: Elasticsearch, @@ -822,7 +794,6 @@ def test_indexing_exception_error( assert log_message in caplog.text - @pytest.mark.asyncio def test_user_agent( self, requests_saving_client: Elasticsearch, index_name: str ) -> None: @@ -845,10 +816,7 @@ def test_user_agent( for request in transport.requests: assert request["headers"]["User-Agent"] == user_agent - @pytest.mark.asyncio - def test_bulk_args( - self, requests_saving_client: Any, index_name: str - ) -> None: + def test_bulk_args(self, requests_saving_client: Any, index_name: str) -> None: """Test to make sure the bulk arguments work as expected.""" store = VectorStore( user_agent="test", @@ -863,7 +831,6 @@ def test_bulk_args( # 1 for index exist, 1 for index create, 3 to index docs assert len(store.es_client.transport.requests) == 5 # type: ignore - @pytest.mark.asyncio def test_max_marginal_relevance_search( self, es_client: Elasticsearch, 
index_name: str ) -> None: diff --git a/utils/run-unasync.py b/utils/run-unasync.py index e73d8dc55..8e4cd513c 100644 --- a/utils/run-unasync.py +++ b/utils/run-unasync.py @@ -16,16 +16,31 @@ # under the License. import os +import subprocess from pathlib import Path +from glob import glob import unasync -def run(rule: unasync.Rule): +def cleanup(source_dir: Path, output_dir: Path, patterns: list[str]): + subprocess.check_call(["black", "--target-version=py38", output_dir]) + subprocess.check_call(["isort", output_dir]) + + for file in glob("*.py", root_dir=source_dir): + path = Path(output_dir) / file + for pattern in patterns: + subprocess.check_call(["sed", "-i.bak", pattern, str(path)]) + subprocess.check_call(["rm", f"{path}.bak"]) + + +def run(rule: unasync.Rule, cleanup_patterns: list[str] = []): root = Path(__file__).absolute().parent.parent + source_dir = root / rule.fromdir.lstrip("/") + output_dir = root / rule.todir.lstrip("/") filepaths = [] - for root, _, filenames in os.walk(root / rule.fromdir): + for root, _, filenames in os.walk(source_dir): for filename in filenames: if filename.rpartition(".")[-1] in ( "py", @@ -35,6 +50,9 @@ def run(rule: unasync.Rule): unasync.unasync_files(filepaths, [rule]) + if cleanup_patterns: + cleanup(source_dir, output_dir, cleanup_patterns) + def main(): @@ -56,8 +74,8 @@ def main(): run( rule=unasync.Rule( - fromdir="/elasticsearch/vectorstore/_async/", - todir="/elasticsearch/vectorstore/_sync/", + fromdir="elasticsearch/vectorstore/_async/", + todir="elasticsearch/vectorstore/_sync/", additional_replacements={ "_async": "_sync", "async_bulk": "bulk", @@ -68,6 +86,11 @@ def main(): "AsyncVectorStore": "VectorStore", }, ), + cleanup_patterns=[ + "/^import asyncio$/d", + "/^import pytest_asyncio*/d", + "/ *@pytest.mark.asyncio$/d", + ], ) run( @@ -88,8 +111,14 @@ def main(): "AsyncFakeEmbeddings": "FakeEmbeddings", "AsyncGenerator": "Generator", "AsyncRequestSavingTransport": "RequestSavingTransport", + 
"pytest_asyncio": "pytest", }, ), + cleanup_patterns=[ + "/^import asyncio$/d", + "/^import pytest_asyncio*/d", + "/ *@pytest.mark.asyncio$/d", + ], ) From 2fd89bd834e9b9b692bee9d368a2ac443f50ca0c Mon Sep 17 00:00:00 2001 From: Max Jakob Date: Thu, 18 Apr 2024 16:35:46 +0200 Subject: [PATCH 06/36] fix formatting --- elasticsearch/vectorstore/_async/strategies.py | 1 - elasticsearch/vectorstore/_async/vectorestore.py | 5 +---- examples/bulk-ingest/bulk-ingest.py | 9 ++++++--- examples/fastapi-apm/app.py | 11 +++++++---- examples/fastapi-apm/ping.py | 2 +- .../test_vectorstore/_async/_test_utils.py | 4 ++-- .../_async/test_embedding_service.py | 13 ++++--------- .../test_vectorstore/_async/test_vectorestore.py | 13 +++++-------- utils/run-unasync.py | 4 +--- 9 files changed, 27 insertions(+), 35 deletions(-) diff --git a/elasticsearch/vectorstore/_async/strategies.py b/elasticsearch/vectorstore/_async/strategies.py index b67b6cffc..3a5ba63d7 100644 --- a/elasticsearch/vectorstore/_async/strategies.py +++ b/elasticsearch/vectorstore/_async/strategies.py @@ -3,7 +3,6 @@ from typing import Any, Dict, List, Literal, Optional, Union, cast from elasticsearch import AsyncElasticsearch - from elasticsearch.vectorstore._async._utils import model_must_be_deployed from elasticsearch.vectorstore._async.embedding_service import AsyncEmbeddingService diff --git a/elasticsearch/vectorstore/_async/vectorestore.py b/elasticsearch/vectorstore/_async/vectorestore.py index ba9243e85..b1f2438c8 100644 --- a/elasticsearch/vectorstore/_async/vectorestore.py +++ b/elasticsearch/vectorstore/_async/vectorestore.py @@ -4,12 +4,9 @@ from elasticsearch import AsyncElasticsearch from elasticsearch.helpers import BulkIndexError, async_bulk - -from elasticsearch.vectorstore._utils import ( - maximal_marginal_relevance, -) from elasticsearch.vectorstore._async.embedding_service import AsyncEmbeddingService from elasticsearch.vectorstore._async.strategies import RetrievalStrategy +from 
elasticsearch.vectorstore._utils import maximal_marginal_relevance logger = logging.getLogger(__name__) diff --git a/examples/bulk-ingest/bulk-ingest.py b/examples/bulk-ingest/bulk-ingest.py index 4c5c34c86..0481a70fa 100644 --- a/examples/bulk-ingest/bulk-ingest.py +++ b/examples/bulk-ingest/bulk-ingest.py @@ -6,13 +6,14 @@ """Script that downloads a public dataset and streams it to an Elasticsearch cluster""" import csv -from os.path import abspath, join, dirname, exists +from os.path import abspath, dirname, exists, join + import tqdm import urllib3 + from elasticsearch import Elasticsearch from elasticsearch.helpers import streaming_bulk - NYC_RESTAURANTS = ( "https://data.cityofnewyork.us/api/views/43nn-pn8j/rows.csv?accessType=DOWNLOAD" ) @@ -99,7 +100,9 @@ def main(): progress = tqdm.tqdm(unit="docs", total=number_of_docs) successes = 0 for ok, action in streaming_bulk( - client=client, index="nyc-restaurants", actions=generate_actions(), + client=client, + index="nyc-restaurants", + actions=generate_actions(), ): progress.update(1) successes += ok diff --git a/examples/fastapi-apm/app.py b/examples/fastapi-apm/app.py index 60a482448..d96c1cd33 100644 --- a/examples/fastapi-apm/app.py +++ b/examples/fastapi-apm/app.py @@ -2,15 +2,16 @@ # Elasticsearch B.V licenses this file to you under the Apache 2.0 License. 
# See the LICENSE file in the project root for more information -import aiohttp import datetime import os + +import aiohttp +from elasticapm.contrib.starlette import ElasticAPM, make_apm_client from fastapi import FastAPI from fastapi.encoders import jsonable_encoder + from elasticsearch import AsyncElasticsearch, NotFoundError from elasticsearch.helpers import async_streaming_bulk -from elasticapm.contrib.starlette import ElasticAPM, make_apm_client - apm = make_apm_client( {"SERVICE_NAME": "fastapi-app", "SERVER_URL": "http://apm-server:8200"} @@ -60,7 +61,9 @@ async def search(query): @app.get("/delete") async def delete(): - return await client.delete_by_query(index="games", body={"query": {"match_all": {}}}) + return await client.delete_by_query( + index="games", body={"query": {"match_all": {}}} + ) @app.get("/delete/{id}") diff --git a/examples/fastapi-apm/ping.py b/examples/fastapi-apm/ping.py index 94a364f93..83d67b6d5 100644 --- a/examples/fastapi-apm/ping.py +++ b/examples/fastapi-apm/ping.py @@ -3,9 +3,9 @@ # See the LICENSE file in the project root for more information import random -import urllib3 import time +import urllib3 endpoints = [ "http://app:9292/", diff --git a/test_elasticsearch/test_server/test_vectorstore/_async/_test_utils.py b/test_elasticsearch/test_server/test_vectorstore/_async/_test_utils.py index eb6245b20..6f51e1a48 100644 --- a/test_elasticsearch/test_server/test_vectorstore/_async/_test_utils.py +++ b/test_elasticsearch/test_server/test_vectorstore/_async/_test_utils.py @@ -1,9 +1,9 @@ import os -from typing import Any, Dict, List, Optional, AsyncIterator +from typing import Any, AsyncIterator, Dict, List, Optional from elastic_transport import AsyncTransport -from elasticsearch import AsyncElasticsearch +from elasticsearch import AsyncElasticsearch from elasticsearch.vectorstore._async.embedding_service import AsyncEmbeddingService diff --git a/test_elasticsearch/test_server/test_vectorstore/_async/test_embedding_service.py 
b/test_elasticsearch/test_server/test_vectorstore/_async/test_embedding_service.py index ea04f08df..9602737ab 100644 --- a/test_elasticsearch/test_server/test_vectorstore/_async/test_embedding_service.py +++ b/test_elasticsearch/test_server/test_vectorstore/_async/test_embedding_service.py @@ -1,22 +1,17 @@ import os +from typing import AsyncIterator import pytest - import pytest_asyncio -from elasticsearch import AsyncElasticsearch - -from typing import AsyncIterator +from elasticsearch import AsyncElasticsearch from elasticsearch.vectorstore._async._utils import model_is_deployed - -from ._test_utils import ( - es_client_fixture, -) - from elasticsearch.vectorstore._async.embedding_service import ( AsyncElasticsearchEmbeddings, ) +from ._test_utils import es_client_fixture + # deployed with # https://www.elastic.co/guide/en/machine-learning/current/ml-nlp-text-emb-vector-search-example.html MODEL_ID = os.getenv("MODEL_ID", "sentence-transformers__msmarco-minilm-l-12-v3") diff --git a/test_elasticsearch/test_server/test_vectorstore/_async/test_vectorestore.py b/test_elasticsearch/test_server/test_vectorstore/_async/test_vectorestore.py index c548b1afc..e2c58a415 100644 --- a/test_elasticsearch/test_server/test_vectorstore/_async/test_vectorestore.py +++ b/test_elasticsearch/test_server/test_vectorstore/_async/test_vectorestore.py @@ -1,16 +1,13 @@ import logging import uuid -from typing import AsyncIterator -from typing import Any, List, Optional, Union, cast from functools import partial +from typing import Any, AsyncIterator, List, Optional, Union, cast import pytest import pytest_asyncio -from elasticsearch import AsyncElasticsearch -from elasticsearch import NotFoundError +from elasticsearch import AsyncElasticsearch, NotFoundError from elasticsearch.helpers import BulkIndexError - from elasticsearch.vectorstore._async import AsyncVectorStore from elasticsearch.vectorstore._async._utils import model_is_deployed from elasticsearch.vectorstore._async.strategies 
import ( @@ -22,11 +19,11 @@ ) from ._test_utils import ( - create_requests_saving_client, - es_client_fixture, AsyncConsistentFakeEmbeddings, AsyncFakeEmbeddings, AsyncRequestSavingTransport, + create_requests_saving_client, + es_client_fixture, ) logging.basicConfig(level=logging.DEBUG) @@ -44,7 +41,7 @@ - sentence-transformers__all-minilm-l6-v2 (can be deployed through the API, loaded via eland) -These tests that require the models to be deployed are skipped by default. +These tests that require the models to be deployed are skipped by default. Enable them by adding the model name to the modelsDeployed list below. """ diff --git a/utils/run-unasync.py b/utils/run-unasync.py index 8e4cd513c..fd2c0babb 100644 --- a/utils/run-unasync.py +++ b/utils/run-unasync.py @@ -17,8 +17,8 @@ import os import subprocess -from pathlib import Path from glob import glob +from pathlib import Path import unasync @@ -55,7 +55,6 @@ def run(rule: unasync.Rule, cleanup_patterns: list[str] = []): def main(): - run( rule=unasync.Rule( fromdir="/elasticsearch/_async/client/", @@ -98,7 +97,6 @@ def main(): fromdir="test_elasticsearch/test_server/test_vectorstore/_async/", todir="test_elasticsearch/test_server/test_vectorstore/_sync/", additional_replacements={ - # Main "_async": "_sync", "async_bulk": "bulk", "AsyncElasticsearch": "Elasticsearch", From 9387b74ea9329048ec0072eb69656cf213dadb07 Mon Sep 17 00:00:00 2001 From: Max Jakob Date: Thu, 18 Apr 2024 16:43:35 +0200 Subject: [PATCH 07/36] more linting fixes --- elasticsearch/vectorstore/__init__.py | 16 +++++ elasticsearch/vectorstore/_async/__init__.py | 17 +++++ elasticsearch/vectorstore/_async/_utils.py | 17 +++++ .../vectorstore/_async/embedding_service.py | 17 +++++ .../vectorstore/_async/strategies.py | 71 ++++++++++++------- .../vectorstore/_async/vectorestore.py | 33 +++++++-- elasticsearch/vectorstore/_sync/__init__.py | 17 +++++ elasticsearch/vectorstore/_sync/_utils.py | 17 +++++ .../vectorstore/_sync/embedding_service.py | 
17 +++++ elasticsearch/vectorstore/_sync/strategies.py | 71 ++++++++++++------- .../vectorstore/_sync/vectorestore.py | 33 +++++++-- elasticsearch/vectorstore/_utils.py | 23 +++++- .../test_vectorstore/_async/__init__.py | 16 +++++ .../test_vectorstore/_async/_test_utils.py | 17 +++++ .../_async/test_embedding_service.py | 17 +++++ .../_async/test_vectorestore.py | 17 +++++ .../test_vectorstore/_sync/__init__.py | 16 +++++ .../test_vectorstore/_sync/_test_utils.py | 17 +++++ .../_sync/test_embedding_service.py | 17 +++++ .../_sync/test_vectorestore.py | 19 ++++- utils/run-unasync.py | 32 ++++++--- 21 files changed, 439 insertions(+), 78 deletions(-) diff --git a/elasticsearch/vectorstore/__init__.py b/elasticsearch/vectorstore/__init__.py index e69de29bb..2a87d183f 100644 --- a/elasticsearch/vectorstore/__init__.py +++ b/elasticsearch/vectorstore/__init__.py @@ -0,0 +1,16 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/elasticsearch/vectorstore/_async/__init__.py b/elasticsearch/vectorstore/_async/__init__.py index 135dccf64..0d0e0ac9f 100644 --- a/elasticsearch/vectorstore/_async/__init__.py +++ b/elasticsearch/vectorstore/_async/__init__.py @@ -1,3 +1,20 @@ +# Licensed to Elasticsearch B.V. 
under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + from elasticsearch.vectorstore._async.vectorestore import AsyncVectorStore __all__ = [ diff --git a/elasticsearch/vectorstore/_async/_utils.py b/elasticsearch/vectorstore/_async/_utils.py index dac58715d..ad8794def 100644 --- a/elasticsearch/vectorstore/_async/_utils.py +++ b/elasticsearch/vectorstore/_async/_utils.py @@ -1,3 +1,20 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ from elasticsearch import ( AsyncElasticsearch, BadRequestError, diff --git a/elasticsearch/vectorstore/_async/embedding_service.py b/elasticsearch/vectorstore/_async/embedding_service.py index 00611e9cd..1027d4f12 100644 --- a/elasticsearch/vectorstore/_async/embedding_service.py +++ b/elasticsearch/vectorstore/_async/embedding_service.py @@ -1,3 +1,20 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + from abc import ABC, abstractmethod from typing import List, Optional diff --git a/elasticsearch/vectorstore/_async/strategies.py b/elasticsearch/vectorstore/_async/strategies.py index 3a5ba63d7..83e582726 100644 --- a/elasticsearch/vectorstore/_async/strategies.py +++ b/elasticsearch/vectorstore/_async/strategies.py @@ -1,3 +1,20 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + from abc import ABC, abstractmethod from enum import Enum from typing import Any, Dict, List, Literal, Optional, Union, cast @@ -23,9 +40,9 @@ async def es_query( query: Optional[str], k: int, num_candidates: int, - filter: List[dict] = [], + filter: List[Dict[str, Any]] = [], query_vector: Optional[List[float]] = None, - ) -> Dict: + ) -> Dict[str, Any]: """ Returns the Elasticsearch query body for the given parameters. The store will execute the query. @@ -46,7 +63,7 @@ async def create_index( self, client: AsyncElasticsearch, index_name: str, - metadata_mapping: Optional[dict[str, str]], + metadata_mapping: Optional[Dict[str, str]], ) -> None: """ Create the required index and do necessary preliminary work, like @@ -95,9 +112,9 @@ async def es_query( query: Optional[str], k: int, num_candidates: int, - filter: List[dict] = [], + filter: List[Dict[str, Any]] = [], query_vector: Optional[List[float]] = None, - ) -> Dict: + ) -> Dict[str, Any]: if query_vector: raise ValueError( "Cannot do sparse retrieval with a query_vector. 
" @@ -117,12 +134,12 @@ async def create_index( self, client: AsyncElasticsearch, index_name: str, - metadata_mapping: Optional[dict[str, str]], + metadata_mapping: Optional[Dict[str, str]], ) -> None: if self.model_id: await model_must_be_deployed(client, self.model_id) - mappings: dict[str, Any] = { + mappings: Dict[str, Any] = { "properties": { self.inference_field: { "type": "semantic_text", @@ -155,9 +172,9 @@ async def es_query( query: Optional[str], k: int, num_candidates: int, - filter: List[dict] = [], + filter: List[Dict[str, Any]] = [], query_vector: Optional[List[float]] = None, - ) -> Dict: + ) -> Dict[str, Any]: if query_vector: raise ValueError( "Cannot do sparse retrieval with a query_vector. " @@ -189,7 +206,7 @@ async def create_index( self, client: AsyncElasticsearch, index_name: str, - metadata_mapping: Optional[dict[str, str]], + metadata_mapping: Optional[Dict[str, str]], ) -> None: pipeline_name = f"{self.model_id}_sparse_embedding" @@ -214,7 +231,7 @@ async def create_index( ], ) - mappings = { + mappings: Dict[str, Any] = { "properties": { self.vector_field: { "properties": {self._tokens_field: {"type": "rank_features"}} @@ -244,7 +261,7 @@ def __init__( model_id: Optional[str] = None, num_dimensions: Optional[int] = None, hybrid: bool = False, - rrf: Union[bool, dict] = True, + rrf: Union[bool, Dict[str, Any]] = True, text_field: Optional[str] = "text_field", ): if embedding_service and model_id: @@ -273,9 +290,9 @@ async def es_query( query: Optional[str], k: int, num_candidates: int, - filter: List[dict] = [], + filter: List[Dict[str, Any]] = [], query_vector: Optional[List[float]] = None, - ) -> Dict: + ) -> Dict[str, Any]: knn = { "filter": filter, "field": self.vector_field, @@ -308,7 +325,7 @@ async def create_index( self, client: AsyncElasticsearch, index_name: str, - metadata_mapping: Optional[dict[str, str]], + metadata_mapping: Optional[Dict[str, str]], ) -> None: if self.embedding_service and not self.num_dimensions: 
self.num_dimensions = len( @@ -351,7 +368,9 @@ async def embed_for_indexing(self, text: str) -> Dict[str, Any]: return {self.vector_field: vector} return {} - def _hybrid(self, query: str, knn: dict, filter: list): + def _hybrid( + self, query: str, knn: Dict[str, Any], filter: List[Dict[str, Any]] + ) -> Dict[str, Any]: # Add a query to the knn query. # RRF is used to even the score from the knn query and text query # RRF has two optional parameters: {'rank_constant':int, 'window_size':int} @@ -374,7 +393,7 @@ def _hybrid(self, query: str, knn: dict, filter: list): }, } - if isinstance(self.rrf, dict): + if isinstance(self.rrf, Dict[str, Any]): query_body["rank"] = {"rrf": self.rrf} elif isinstance(self.rrf, bool) and self.rrf is True: query_body["rank"] = {"rrf": {}} @@ -402,9 +421,9 @@ async def es_query( query: Optional[str], k: int, num_candidates: int, - filter: List[dict] = [], + filter: List[Dict[str, Any]] = [], query_vector: Optional[List[float]] = None, - ) -> Dict: + ) -> Dict[str, Any]: if self.distance is DistanceMetric.COSINE: similarityAlgo = ( f"cosineSimilarity(params.query_vector, '{self.vector_field}') + 1.0" @@ -429,7 +448,7 @@ async def es_query( else: raise ValueError(f"Similarity {self.distance} not supported.") - queryBool: Dict = {"match_all": {}} + queryBool: Dict[str, Any] = {"match_all": {}} if filter: queryBool = {"bool": {"filter": filter}} @@ -459,7 +478,7 @@ async def create_index( self, client: AsyncElasticsearch, index_name: str, - metadata_mapping: Optional[dict[str, str]], + metadata_mapping: Optional[Dict[str, str]], ) -> None: if not self.num_dimensions: self.num_dimensions = len( @@ -502,9 +521,9 @@ async def es_query( query: Optional[str], k: int, num_candidates: int, - filter: List[dict] = [], + filter: List[Dict[str, Any]] = [], query_vector: Optional[List[float]] = None, - ) -> Dict: + ) -> Dict[str, Any]: return { "query": { "bool": { @@ -526,11 +545,11 @@ async def create_index( self, client: AsyncElasticsearch, 
index_name: str, - metadata_mapping: Optional[dict[str, str]], + metadata_mapping: Optional[Dict[str, str]], ) -> None: similarity_name = "custom_bm25" - mappings: Dict = { + mappings: Dict[str, Any] = { "properties": { self.text_field: { "type": "text", @@ -541,7 +560,7 @@ async def create_index( if metadata_mapping: mappings["properties"]["metadata"] = {"properties": metadata_mapping} - bm25: Dict = { + bm25: Dict[str, Any] = { "type": "BM25", } if self.k1 is not None: diff --git a/elasticsearch/vectorstore/_async/vectorestore.py b/elasticsearch/vectorstore/_async/vectorestore.py index b1f2438c8..670c1edfa 100644 --- a/elasticsearch/vectorstore/_async/vectorestore.py +++ b/elasticsearch/vectorstore/_async/vectorestore.py @@ -1,3 +1,20 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ import logging import uuid from typing import Any, Callable, Dict, List, Optional @@ -29,7 +46,7 @@ def __init__( retrieval_strategy: RetrievalStrategy, text_field: str = "text_field", vector_field: str = "vector_field", - metadata_mapping: Optional[dict[str, str]] = None, + metadata_mapping: Optional[Dict[str, str]] = None, ) -> None: """ Args: @@ -61,7 +78,7 @@ def __init__( self.vector_field = vector_field self.metadata_mapping = metadata_mapping - async def close(self): + async def close(self) -> None: return await self.es_client.close() async def add_texts( @@ -196,8 +213,10 @@ async def search( k: int = 4, num_candidates: int = 50, fields: Optional[List[str]] = None, - filter: Optional[List[dict]] = None, - custom_query: Optional[Callable[[Dict, Optional[str]], Dict]] = None, + filter: Optional[List[Dict[str, Any]]] = None, + custom_query: Optional[ + Callable[[Dict[str, Any], Optional[str]], Dict[str, Any]] + ] = None, ) -> List[Dict[str, Any]]: """ Args: @@ -263,8 +282,10 @@ async def max_marginal_relevance_search( num_candidates: int = 20, lambda_mult: float = 0.5, fields: Optional[List[str]] = None, - custom_query: Optional[Callable[[Dict, Optional[str]], Dict]] = None, - ) -> List[Dict]: + custom_query: Optional[ + Callable[[Dict[str, Any], Optional[str]], Dict[str, Any]] + ] = None, + ) -> List[Dict[str, Any]]: """Return docs selected using the maximal marginal relevance. Maximal marginal relevance optimizes for similarity to query AND diversity diff --git a/elasticsearch/vectorstore/_sync/__init__.py b/elasticsearch/vectorstore/_sync/__init__.py index 3079492b6..903dc00ed 100644 --- a/elasticsearch/vectorstore/_sync/__init__.py +++ b/elasticsearch/vectorstore/_sync/__init__.py @@ -1,3 +1,20 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. 
licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + from elasticsearch.vectorstore._sync.vectorestore import VectorStore __all__ = [ diff --git a/elasticsearch/vectorstore/_sync/_utils.py b/elasticsearch/vectorstore/_sync/_utils.py index 5a85fcaac..ad77be5aa 100644 --- a/elasticsearch/vectorstore/_sync/_utils.py +++ b/elasticsearch/vectorstore/_sync/_utils.py @@ -1,3 +1,20 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ from elasticsearch import BadRequestError, ConflictError, Elasticsearch, NotFoundError diff --git a/elasticsearch/vectorstore/_sync/embedding_service.py b/elasticsearch/vectorstore/_sync/embedding_service.py index 1c8d39e3a..272c34214 100644 --- a/elasticsearch/vectorstore/_sync/embedding_service.py +++ b/elasticsearch/vectorstore/_sync/embedding_service.py @@ -1,3 +1,20 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + from abc import ABC, abstractmethod from typing import List, Optional diff --git a/elasticsearch/vectorstore/_sync/strategies.py b/elasticsearch/vectorstore/_sync/strategies.py index 1e54767b2..ea155628d 100644 --- a/elasticsearch/vectorstore/_sync/strategies.py +++ b/elasticsearch/vectorstore/_sync/strategies.py @@ -1,3 +1,20 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + from abc import ABC, abstractmethod from enum import Enum from typing import Any, Dict, List, Literal, Optional, Union, cast @@ -23,9 +40,9 @@ def es_query( query: Optional[str], k: int, num_candidates: int, - filter: List[dict] = [], + filter: List[Dict[str, Any]] = [], query_vector: Optional[List[float]] = None, - ) -> Dict: + ) -> Dict[str, Any]: """ Returns the Elasticsearch query body for the given parameters. The store will execute the query. @@ -46,7 +63,7 @@ def create_index( self, client: Elasticsearch, index_name: str, - metadata_mapping: Optional[dict[str, str]], + metadata_mapping: Optional[Dict[str, str]], ) -> None: """ Create the required index and do necessary preliminary work, like @@ -95,9 +112,9 @@ def es_query( query: Optional[str], k: int, num_candidates: int, - filter: List[dict] = [], + filter: List[Dict[str, Any]] = [], query_vector: Optional[List[float]] = None, - ) -> Dict: + ) -> Dict[str, Any]: if query_vector: raise ValueError( "Cannot do sparse retrieval with a query_vector. 
" @@ -117,12 +134,12 @@ def create_index( self, client: Elasticsearch, index_name: str, - metadata_mapping: Optional[dict[str, str]], + metadata_mapping: Optional[Dict[str, str]], ) -> None: if self.model_id: model_must_be_deployed(client, self.model_id) - mappings: dict[str, Any] = { + mappings: Dict[str, Any] = { "properties": { self.inference_field: { "type": "semantic_text", @@ -155,9 +172,9 @@ def es_query( query: Optional[str], k: int, num_candidates: int, - filter: List[dict] = [], + filter: List[Dict[str, Any]] = [], query_vector: Optional[List[float]] = None, - ) -> Dict: + ) -> Dict[str, Any]: if query_vector: raise ValueError( "Cannot do sparse retrieval with a query_vector. " @@ -189,7 +206,7 @@ def create_index( self, client: Elasticsearch, index_name: str, - metadata_mapping: Optional[dict[str, str]], + metadata_mapping: Optional[Dict[str, str]], ) -> None: pipeline_name = f"{self.model_id}_sparse_embedding" @@ -214,7 +231,7 @@ def create_index( ], ) - mappings = { + mappings: Dict[str, Any] = { "properties": { self.vector_field: { "properties": {self._tokens_field: {"type": "rank_features"}} @@ -242,7 +259,7 @@ def __init__( model_id: Optional[str] = None, num_dimensions: Optional[int] = None, hybrid: bool = False, - rrf: Union[bool, dict] = True, + rrf: Union[bool, Dict[str, Any]] = True, text_field: Optional[str] = "text_field", ): if embedding_service and model_id: @@ -271,9 +288,9 @@ def es_query( query: Optional[str], k: int, num_candidates: int, - filter: List[dict] = [], + filter: List[Dict[str, Any]] = [], query_vector: Optional[List[float]] = None, - ) -> Dict: + ) -> Dict[str, Any]: knn = { "filter": filter, "field": self.vector_field, @@ -304,7 +321,7 @@ def create_index( self, client: Elasticsearch, index_name: str, - metadata_mapping: Optional[dict[str, str]], + metadata_mapping: Optional[Dict[str, str]], ) -> None: if self.embedding_service and not self.num_dimensions: self.num_dimensions = len( @@ -347,7 +364,9 @@ def 
embed_for_indexing(self, text: str) -> Dict[str, Any]: return {self.vector_field: vector} return {} - def _hybrid(self, query: str, knn: dict, filter: list): + def _hybrid( + self, query: str, knn: Dict[str, Any], filter: List[Dict[str, Any]] + ) -> Dict[str, Any]: # Add a query to the knn query. # RRF is used to even the score from the knn query and text query # RRF has two optional parameters: {'rank_constant':int, 'window_size':int} @@ -370,7 +389,7 @@ def _hybrid(self, query: str, knn: dict, filter: list): }, } - if isinstance(self.rrf, dict): + if isinstance(self.rrf, Dict[str, Any]): query_body["rank"] = {"rrf": self.rrf} elif isinstance(self.rrf, bool) and self.rrf is True: query_body["rank"] = {"rrf": {}} @@ -398,9 +417,9 @@ def es_query( query: Optional[str], k: int, num_candidates: int, - filter: List[dict] = [], + filter: List[Dict[str, Any]] = [], query_vector: Optional[List[float]] = None, - ) -> Dict: + ) -> Dict[str, Any]: if self.distance is DistanceMetric.COSINE: similarityAlgo = ( f"cosineSimilarity(params.query_vector, '{self.vector_field}') + 1.0" @@ -425,7 +444,7 @@ def es_query( else: raise ValueError(f"Similarity {self.distance} not supported.") - queryBool: Dict = {"match_all": {}} + queryBool: Dict[str, Any] = {"match_all": {}} if filter: queryBool = {"bool": {"filter": filter}} @@ -455,7 +474,7 @@ def create_index( self, client: Elasticsearch, index_name: str, - metadata_mapping: Optional[dict[str, str]], + metadata_mapping: Optional[Dict[str, str]], ) -> None: if not self.num_dimensions: self.num_dimensions = len( @@ -498,9 +517,9 @@ def es_query( query: Optional[str], k: int, num_candidates: int, - filter: List[dict] = [], + filter: List[Dict[str, Any]] = [], query_vector: Optional[List[float]] = None, - ) -> Dict: + ) -> Dict[str, Any]: return { "query": { "bool": { @@ -522,11 +541,11 @@ def create_index( self, client: Elasticsearch, index_name: str, - metadata_mapping: Optional[dict[str, str]], + metadata_mapping: Optional[Dict[str, 
str]], ) -> None: similarity_name = "custom_bm25" - mappings: Dict = { + mappings: Dict[str, Any] = { "properties": { self.text_field: { "type": "text", @@ -537,7 +556,7 @@ def create_index( if metadata_mapping: mappings["properties"]["metadata"] = {"properties": metadata_mapping} - bm25: Dict = { + bm25: Dict[str, Any] = { "type": "BM25", } if self.k1 is not None: diff --git a/elasticsearch/vectorstore/_sync/vectorestore.py b/elasticsearch/vectorstore/_sync/vectorestore.py index ba1fabdb3..fd465111f 100644 --- a/elasticsearch/vectorstore/_sync/vectorestore.py +++ b/elasticsearch/vectorstore/_sync/vectorestore.py @@ -1,3 +1,20 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ import logging import uuid from typing import Any, Callable, Dict, List, Optional @@ -29,7 +46,7 @@ def __init__( retrieval_strategy: RetrievalStrategy, text_field: str = "text_field", vector_field: str = "vector_field", - metadata_mapping: Optional[dict[str, str]] = None, + metadata_mapping: Optional[Dict[str, str]] = None, ) -> None: """ Args: @@ -61,7 +78,7 @@ def __init__( self.vector_field = vector_field self.metadata_mapping = metadata_mapping - def close(self): + def close(self) -> None: return self.es_client.close() def add_texts( @@ -196,8 +213,10 @@ def search( k: int = 4, num_candidates: int = 50, fields: Optional[List[str]] = None, - filter: Optional[List[dict]] = None, - custom_query: Optional[Callable[[Dict, Optional[str]], Dict]] = None, + filter: Optional[List[Dict[str, Any]]] = None, + custom_query: Optional[ + Callable[[Dict[str, Any], Optional[str]], Dict[str, Any]] + ] = None, ) -> List[Dict[str, Any]]: """ Args: @@ -263,8 +282,10 @@ def max_marginal_relevance_search( num_candidates: int = 20, lambda_mult: float = 0.5, fields: Optional[List[str]] = None, - custom_query: Optional[Callable[[Dict, Optional[str]], Dict]] = None, - ) -> List[Dict]: + custom_query: Optional[ + Callable[[Dict[str, Any], Optional[str]], Dict[str, Any]] + ] = None, + ) -> List[Dict[str, Any]]: """Return docs selected using the maximal marginal relevance. Maximal marginal relevance optimizes for similarity to query AND diversity diff --git a/elasticsearch/vectorstore/_utils.py b/elasticsearch/vectorstore/_utils.py index b0e4f1372..1eb7e9026 100644 --- a/elasticsearch/vectorstore/_utils.py +++ b/elasticsearch/vectorstore/_utils.py @@ -1,3 +1,20 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. 
licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + from typing import List, Union import numpy as np @@ -6,8 +23,8 @@ def maximal_marginal_relevance( - query_embedding: list, - embedding_list: list, + query_embedding: list[float], + embedding_list: list[list[float]], lambda_mult: float = 0.5, k: int = 4, ) -> List[int]: @@ -54,7 +71,7 @@ def _cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray: f"and Y has shape {Y.shape}." ) try: - import simsimd as simd # type: ignore + import simsimd as simd X = np.array(X, dtype=np.float32) Y = np.array(Y, dtype=np.float32) diff --git a/test_elasticsearch/test_server/test_vectorstore/_async/__init__.py b/test_elasticsearch/test_server/test_vectorstore/_async/__init__.py index e69de29bb..2a87d183f 100644 --- a/test_elasticsearch/test_server/test_vectorstore/_async/__init__.py +++ b/test_elasticsearch/test_server/test_vectorstore/_async/__init__.py @@ -0,0 +1,16 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/test_elasticsearch/test_server/test_vectorstore/_async/_test_utils.py b/test_elasticsearch/test_server/test_vectorstore/_async/_test_utils.py index 6f51e1a48..87b8b7a96 100644 --- a/test_elasticsearch/test_server/test_vectorstore/_async/_test_utils.py +++ b/test_elasticsearch/test_server/test_vectorstore/_async/_test_utils.py @@ -1,3 +1,20 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ import os from typing import Any, AsyncIterator, Dict, List, Optional diff --git a/test_elasticsearch/test_server/test_vectorstore/_async/test_embedding_service.py b/test_elasticsearch/test_server/test_vectorstore/_async/test_embedding_service.py index 9602737ab..924339683 100644 --- a/test_elasticsearch/test_server/test_vectorstore/_async/test_embedding_service.py +++ b/test_elasticsearch/test_server/test_vectorstore/_async/test_embedding_service.py @@ -1,3 +1,20 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + import os from typing import AsyncIterator diff --git a/test_elasticsearch/test_server/test_vectorstore/_async/test_vectorestore.py b/test_elasticsearch/test_server/test_vectorstore/_async/test_vectorestore.py index e2c58a415..dbf3d2578 100644 --- a/test_elasticsearch/test_server/test_vectorstore/_async/test_vectorestore.py +++ b/test_elasticsearch/test_server/test_vectorstore/_async/test_vectorestore.py @@ -1,3 +1,20 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. 
licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + import logging import uuid from functools import partial diff --git a/test_elasticsearch/test_server/test_vectorstore/_sync/__init__.py b/test_elasticsearch/test_server/test_vectorstore/_sync/__init__.py index e69de29bb..2a87d183f 100644 --- a/test_elasticsearch/test_server/test_vectorstore/_sync/__init__.py +++ b/test_elasticsearch/test_server/test_vectorstore/_sync/__init__.py @@ -0,0 +1,16 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
diff --git a/test_elasticsearch/test_server/test_vectorstore/_sync/_test_utils.py b/test_elasticsearch/test_server/test_vectorstore/_sync/_test_utils.py index ead24f154..e43dbeffd 100644 --- a/test_elasticsearch/test_server/test_vectorstore/_sync/_test_utils.py +++ b/test_elasticsearch/test_server/test_vectorstore/_sync/_test_utils.py @@ -1,3 +1,20 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + import os from typing import Any, Dict, Iterator, List, Optional diff --git a/test_elasticsearch/test_server/test_vectorstore/_sync/test_embedding_service.py b/test_elasticsearch/test_server/test_vectorstore/_sync/test_embedding_service.py index 002ae57b1..979096d39 100644 --- a/test_elasticsearch/test_server/test_vectorstore/_sync/test_embedding_service.py +++ b/test_elasticsearch/test_server/test_vectorstore/_sync/test_embedding_service.py @@ -1,3 +1,20 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + import os from typing import Iterator diff --git a/test_elasticsearch/test_server/test_vectorstore/_sync/test_vectorestore.py b/test_elasticsearch/test_server/test_vectorstore/_sync/test_vectorestore.py index b496b409d..5564f8b6f 100644 --- a/test_elasticsearch/test_server/test_vectorstore/_sync/test_vectorestore.py +++ b/test_elasticsearch/test_server/test_vectorstore/_sync/test_vectorestore.py @@ -1,3 +1,20 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + import logging import uuid from functools import partial @@ -40,7 +57,7 @@ - sentence-transformers__all-minilm-l6-v2 (can be deployed through the API, loaded via eland) -These tests that require the models to be deployed are skipped by default. +These tests that require the models to be deployed are skipped by default. 
Enable them by adding the model name to the modelsDeployed list below. """ diff --git a/utils/run-unasync.py b/utils/run-unasync.py index fd2c0babb..2d8156721 100644 --- a/utils/run-unasync.py +++ b/utils/run-unasync.py @@ -24,17 +24,24 @@ def cleanup(source_dir: Path, output_dir: Path, patterns: list[str]): - subprocess.check_call(["black", "--target-version=py38", output_dir]) - subprocess.check_call(["isort", output_dir]) + if patterns: + for file in glob("*.py", root_dir=source_dir): + path = Path(output_dir) / file + for pattern in patterns: + subprocess.check_call(["sed", "-i.bak", pattern, str(path)]) + subprocess.check_call(["rm", f"{path}.bak"]) - for file in glob("*.py", root_dir=source_dir): - path = Path(output_dir) / file - for pattern in patterns: - subprocess.check_call(["sed", "-i.bak", pattern, str(path)]) - subprocess.check_call(["rm", f"{path}.bak"]) +def format_dir(dir: Path): + subprocess.check_call(["isort", "--profile=black", dir]) + subprocess.check_call(["black", dir]) -def run(rule: unasync.Rule, cleanup_patterns: list[str] = []): + +def run( + rule: unasync.Rule, + cleanup_patterns: list[str] = [], + format: bool = False, +): root = Path(__file__).absolute().parent.parent source_dir = root / rule.fromdir.lstrip("/") output_dir = root / rule.todir.lstrip("/") @@ -50,8 +57,11 @@ def run(rule: unasync.Rule, cleanup_patterns: list[str] = []): unasync.unasync_files(filepaths, [rule]) - if cleanup_patterns: - cleanup(source_dir, output_dir, cleanup_patterns) + cleanup(source_dir, output_dir, cleanup_patterns) + + if format: + format_dir(source_dir) + format_dir(output_dir) def main(): @@ -90,6 +100,7 @@ def main(): "/^import pytest_asyncio*/d", "/ *@pytest.mark.asyncio$/d", ], + format=True, ) run( @@ -117,6 +128,7 @@ def main(): "/^import pytest_asyncio*/d", "/ *@pytest.mark.asyncio$/d", ], + format=True, ) From b18d63dfa24ec5dc9bdf38d86453eaec5d704cb6 Mon Sep 17 00:00:00 2001 From: Max Jakob Date: Thu, 18 Apr 2024 18:30:31 +0200 Subject: 
[PATCH 08/36] batch embedding call; infer num_dimensions --- .../vectorstore/_async/strategies.py | 87 +++++-------------- .../vectorstore/_async/vectorestore.py | 27 +++++- .../_async/test_vectorestore.py | 59 ++++++------- 3 files changed, 73 insertions(+), 100 deletions(-) diff --git a/elasticsearch/vectorstore/_async/strategies.py b/elasticsearch/vectorstore/_async/strategies.py index 83e582726..7b67b71fa 100644 --- a/elasticsearch/vectorstore/_async/strategies.py +++ b/elasticsearch/vectorstore/_async/strategies.py @@ -21,7 +21,6 @@ from elasticsearch import AsyncElasticsearch from elasticsearch.vectorstore._async._utils import model_must_be_deployed -from elasticsearch.vectorstore._async.embedding_service import AsyncEmbeddingService class DistanceMetric(str, Enum): @@ -63,7 +62,8 @@ async def create_index( self, client: AsyncElasticsearch, index_name: str, - metadata_mapping: Optional[Dict[str, str]], + num_dimensions: Optional[int] = None, + metadata_mapping: Optional[Dict[str, str]] = None, ) -> None: """ Create the required index and do necessary preliminary work, like @@ -76,21 +76,11 @@ async def create_index( describe the schema of the metadata. """ - async def embed_for_indexing(self, text: str) -> Dict[str, Any]: + def needs_inference(self) -> bool: """ - If this strategy creates vector embeddings in Python (not in Elasticsearch), - this method is used to apply the inference. - The output is a dictionary with the vector field and the vector embedding. - It is merged in the ElasticserachStore with the rest of the document (text data, - metadata) before indexing. - - Args: - text: Text input that can be used as input for inference. - - Returns: - Dict: field and value pairs that extend the document to be indexed. 
+ TODO """ - return {} + return False # TODO test when repsective image is released @@ -134,6 +124,7 @@ async def create_index( self, client: AsyncElasticsearch, index_name: str, + num_dimensions: int, metadata_mapping: Optional[Dict[str, str]], ) -> None: if self.model_id: @@ -206,6 +197,7 @@ async def create_index( self, client: AsyncElasticsearch, index_name: str, + num_dimensions: int, metadata_mapping: Optional[Dict[str, str]], ) -> None: pipeline_name = f"{self.model_id}_sparse_embedding" @@ -257,19 +249,11 @@ def __init__( knn_type: Literal["hnsw", "int8_hnsw", "flat", "int8_flat"] = "hnsw", vector_field: str = "vector_field", distance: DistanceMetric = DistanceMetric.COSINE, - embedding_service: Optional[AsyncEmbeddingService] = None, model_id: Optional[str] = None, - num_dimensions: Optional[int] = None, hybrid: bool = False, rrf: Union[bool, Dict[str, Any]] = True, text_field: Optional[str] = "text_field", ): - if embedding_service and model_id: - raise ValueError("either specify embedding_service or model_id, not both") - if model_id and not num_dimensions: - raise ValueError( - "if model_id is specified, num_dimensions must also be specified" - ) if hybrid and not text_field: raise ValueError( "to enable hybrid you have to specify a text_field (for BM25 matching)" @@ -278,9 +262,7 @@ def __init__( self.knn_type = knn_type self.vector_field = vector_field self.distance = distance - self.embedding_service = embedding_service self.model_id = model_id - self.num_dimensions = num_dimensions self.hybrid = hybrid self.rrf = rrf self.text_field = text_field @@ -302,10 +284,6 @@ async def es_query( if query_vector: knn["query_vector"] = query_vector - elif self.embedding_service: - knn["query_vector"] = await self.embedding_service.embed_query( - cast(str, query) - ) else: # Inference in Elasticsearch. When initializing we make sure to always have # a model_id if don't have an embedding_service. 
@@ -325,13 +303,9 @@ async def create_index( self, client: AsyncElasticsearch, index_name: str, + num_dimensions: int, metadata_mapping: Optional[Dict[str, str]], ) -> None: - if self.embedding_service and not self.num_dimensions: - self.num_dimensions = len( - await self.embedding_service.embed_query("get number of dimensions") - ) - if self.model_id: await model_must_be_deployed(client, self.model_id) @@ -350,7 +324,7 @@ async def create_index( "properties": { self.vector_field: { "type": "dense_vector", - "dims": self.num_dimensions, + "dims": num_dimensions, "index": True, "similarity": similarityAlgo, }, @@ -362,12 +336,6 @@ async def create_index( r = await client.indices.create(index=index_name, mappings=mappings) print(r) - async def embed_for_indexing(self, text: str) -> Dict[str, Any]: - if self.embedding_service: - vector = await self.embedding_service.embed_query(text) - return {self.vector_field: vector} - return {} - def _hybrid( self, query: str, knn: Dict[str, Any], filter: List[Dict[str, Any]] ) -> Dict[str, Any]: @@ -393,28 +361,27 @@ def _hybrid( }, } - if isinstance(self.rrf, Dict[str, Any]): + if isinstance(self.rrf, Dict): query_body["rank"] = {"rrf": self.rrf} elif isinstance(self.rrf, bool) and self.rrf is True: query_body["rank"] = {"rrf": {}} return query_body + def needs_inference(self) -> bool: + return not self.model_id + class DenseVectorScriptScore(RetrievalStrategy): """Exact nearest neighbors retrieval using the `script_score` query.""" def __init__( self, - embedding_service: AsyncEmbeddingService, vector_field: str = "vector_field", distance: DistanceMetric = DistanceMetric.COSINE, - num_dimensions: Optional[int] = None, ) -> None: self.vector_field = vector_field self.distance = distance - self.embedding_service = embedding_service - self.num_dimensions = num_dimensions async def es_query( self, @@ -424,6 +391,9 @@ async def es_query( filter: List[Dict[str, Any]] = [], query_vector: Optional[List[float]] = None, ) -> Dict[str, 
Any]: + if not query_vector: + raise ValueError("specify a query_vector") + if self.distance is DistanceMetric.COSINE: similarityAlgo = ( f"cosineSimilarity(params.query_vector, '{self.vector_field}') + 1.0" @@ -452,16 +422,6 @@ async def es_query( if filter: queryBool = {"bool": {"filter": filter}} - if not query_vector: - if not self.embedding_service: - raise ValueError( - "if not embedding_service is given, you need to " - "procive a query_vector" - ) - if not query: - raise ValueError("either specify a query string or a query_vector") - query_vector = await self.embedding_service.embed_query(query) - return { "query": { "script_score": { @@ -478,18 +438,14 @@ async def create_index( self, client: AsyncElasticsearch, index_name: str, + num_dimensions: int, metadata_mapping: Optional[Dict[str, str]], ) -> None: - if not self.num_dimensions: - self.num_dimensions = len( - await self.embedding_service.embed_query("get number of dimensions") - ) - mappings = { "properties": { self.vector_field: { "type": "dense_vector", - "dims": self.num_dimensions, + "dims": num_dimensions, "index": False, } } @@ -499,10 +455,8 @@ async def create_index( await client.indices.create(index=index_name, mappings=mappings) - return None - - async def embed_for_indexing(self, text: str) -> Dict[str, Any]: - return {self.vector_field: await self.embedding_service.embed_query(text)} + def needs_inference(self) -> bool: + return True class BM25(RetrievalStrategy): @@ -545,6 +499,7 @@ async def create_index( self, client: AsyncElasticsearch, index_name: str, + num_dimensions: int, metadata_mapping: Optional[Dict[str, str]], ) -> None: similarity_name = "custom_bm25" diff --git a/elasticsearch/vectorstore/_async/vectorestore.py b/elasticsearch/vectorstore/_async/vectorestore.py index 670c1edfa..6bb62418f 100644 --- a/elasticsearch/vectorstore/_async/vectorestore.py +++ b/elasticsearch/vectorstore/_async/vectorestore.py @@ -44,6 +44,8 @@ def __init__( user_agent: str, index_name: str, 
retrieval_strategy: RetrievalStrategy, + embedding_service: Optional[AsyncEmbeddingService] = None, + num_dimensions: Optional[int] = None, text_field: str = "text_field", vector_field: str = "vector_field", metadata_mapping: Optional[Dict[str, str]] = None, @@ -61,7 +63,6 @@ def __init__( es_client: Elasticsearch client connection. Alternatively specify the Elasticsearch connection with the other es_* parameters. """ - # Add integration-specific usage header for tracking usage in Elastic Cloud. # client.options preserces existing (non-user-agent) headers. es_client = es_client.options(headers={"User-Agent": user_agent}) @@ -74,6 +75,8 @@ def __init__( self.es_client = es_client self.index_name = index_name self.retrieval_strategy = retrieval_strategy + self.embedding_service = embedding_service + self.num_dimensions = num_dimensions self.text_field = text_field self.vector_field = vector_field self.metadata_mapping = metadata_mapping @@ -118,6 +121,9 @@ async def add_texts( if create_index_if_not_exists: await self._create_index_if_not_exists() + if self.embedding_service and not vectors: + vectors = await self.embedding_service.embed_documents(texts) + for i, text in enumerate(texts): metadata = metadatas[i] if metadatas else {} @@ -132,7 +138,6 @@ async def add_texts( if vectors: request[self.vector_field] = vectors[i] - request.update(await self.retrieval_strategy.embed_for_indexing(text)) requests.append(request) if len(requests) > 0: @@ -240,6 +245,11 @@ async def search( if self.text_field not in fields: fields.append(self.text_field) + if self.embedding_service and not query_vector: + if not query: + raise ValueError("specify a query or a query_vector to search") + query_vector = await self.embedding_service.embed_query(query) + query_body = await self.retrieval_strategy.es_query( query=query, k=k, @@ -267,9 +277,22 @@ async def _create_index_if_not_exists(self) -> None: if exists.meta.status == 200: logger.debug(f"Index {self.index_name} already exists. 
Skipping creation.") else: + if self.retrieval_strategy.needs_inference(): + if not self.num_dimensions and not self.embedding_service: + raise ValueError( + "retrieval strategy requires embeddings; either embedding_service " + "or num_dimensions need to be specified" + ) + if not self.num_dimensions and self.embedding_service: + vector = await self.embedding_service.embed_query( + "get num dimensions" + ) + self.num_dimensions = len(vector) + await self.retrieval_strategy.create_index( client=self.es_client, index_name=self.index_name, + num_dimensions=self.num_dimensions, metadata_mapping=self.metadata_mapping, ) diff --git a/test_elasticsearch/test_server/test_vectorstore/_async/test_vectorestore.py b/test_elasticsearch/test_server/test_vectorstore/_async/test_vectorestore.py index dbf3d2578..29779c9b1 100644 --- a/test_elasticsearch/test_server/test_vectorstore/_async/test_vectorestore.py +++ b/test_elasticsearch/test_server/test_vectorstore/_async/test_vectorestore.py @@ -106,7 +106,8 @@ def assert_query(query_body: dict, query: Optional[str]) -> dict: store = AsyncVectorStore( user_agent="test", index_name=index_name, - retrieval_strategy=DenseVector(embedding_service=AsyncFakeEmbeddings()), + retrieval_strategy=DenseVector(), + embedding_service=AsyncFakeEmbeddings(), es_client=es_client, ) @@ -124,7 +125,8 @@ async def test_search_without_metadata_async( store = AsyncVectorStore( user_agent="test", index_name=index_name, - retrieval_strategy=DenseVector(embedding_service=AsyncFakeEmbeddings()), + retrieval_strategy=DenseVector(), + embedding_service=AsyncFakeEmbeddings(), es_client=es_client, ) @@ -155,7 +157,8 @@ async def test_add_vectors( store = AsyncVectorStore( user_agent="test", index_name=index_name, - retrieval_strategy=DenseVector(embedding_service=embeddings), + retrieval_strategy=DenseVector(), + embedding_service=embeddings, es_client=es_client, ) @@ -174,9 +177,8 @@ async def test_search_with_metadata( store = AsyncVectorStore( 
user_agent="test", index_name=index_name, - retrieval_strategy=DenseVector( - embedding_service=AsyncConsistentFakeEmbeddings() - ), + retrieval_strategy=DenseVector(), + embedding_service=AsyncConsistentFakeEmbeddings(), es_client=es_client, ) @@ -200,7 +202,8 @@ async def test_search_with_filter( store = AsyncVectorStore( user_agent="test", index_name=index_name, - retrieval_strategy=DenseVector(embedding_service=AsyncFakeEmbeddings()), + retrieval_strategy=DenseVector(), + embedding_service=AsyncFakeEmbeddings(), es_client=es_client, ) @@ -237,9 +240,8 @@ async def test_search_script_score( store = AsyncVectorStore( user_agent="test", index_name=index_name, - retrieval_strategy=DenseVectorScriptScore( - embedding_service=AsyncFakeEmbeddings() - ), + retrieval_strategy=DenseVectorScriptScore(), + embedding_service=AsyncFakeEmbeddings(), es_client=es_client, ) @@ -286,9 +288,8 @@ async def test_search_script_score_with_filter( store = AsyncVectorStore( user_agent="test", index_name=index_name, - retrieval_strategy=DenseVectorScriptScore( - embedding_service=AsyncFakeEmbeddings() - ), + retrieval_strategy=DenseVectorScriptScore(), + embedding_service=AsyncFakeEmbeddings(), es_client=es_client, ) @@ -342,9 +343,9 @@ async def test_search_script_score_distance_dot_product( user_agent="test", index_name=index_name, retrieval_strategy=DenseVectorScriptScore( - embedding_service=AsyncFakeEmbeddings(), distance=DistanceMetric.DOT_PRODUCT, ), + embedding_service=AsyncFakeEmbeddings(), es_client=es_client, ) @@ -392,10 +393,8 @@ async def test_search_knn_with_hybrid_search( store = AsyncVectorStore( user_agent="test", index_name=index_name, - retrieval_strategy=DenseVector( - embedding_service=AsyncFakeEmbeddings(), - hybrid=True, - ), + retrieval_strategy=DenseVector(hybrid=True), + embedding_service=AsyncFakeEmbeddings(), es_client=es_client, ) @@ -482,11 +481,8 @@ def assert_query( store = AsyncVectorStore( user_agent="test", index_name=index_name, - 
retrieval_strategy=DenseVector( - embedding_service=AsyncFakeEmbeddings(), - hybrid=True, - rrf=rrf_test_case, - ), + retrieval_strategy=DenseVector(hybrid=True, rrf=rrf_test_case), + embedding_service=AsyncFakeEmbeddings(), es_client=es_client, ) await store.add_texts(texts) @@ -526,10 +522,8 @@ def assert_query( store = AsyncVectorStore( user_agent="test", index_name=f"{index_name}_default", - retrieval_strategy=DenseVector( - embedding_service=AsyncFakeEmbeddings(), - hybrid=True, - ), + retrieval_strategy=DenseVector(hybrid=True), + embedding_service=AsyncFakeEmbeddings(), es_client=es_client, ) await store.add_texts(texts) @@ -551,7 +545,8 @@ async def test_search_knn_with_custom_query_fn( store = AsyncVectorStore( user_agent="test", index_name=index_name, - retrieval_strategy=DenseVector(embedding_service=AsyncFakeEmbeddings()), + retrieval_strategy=DenseVector(), + embedding_service=AsyncFakeEmbeddings(), es_client=es_client, ) @@ -778,7 +773,8 @@ async def test_delete(self, es_client: AsyncElasticsearch, index_name: str) -> N store = AsyncVectorStore( user_agent="test", index_name=index_name, - retrieval_strategy=DenseVector(embedding_service=AsyncFakeEmbeddings()), + retrieval_strategy=DenseVector(), + embedding_service=AsyncFakeEmbeddings(), es_client=es_client, ) @@ -889,9 +885,8 @@ async def test_max_marginal_relevance_search( store = AsyncVectorStore( user_agent="test", index_name=index_name, - retrieval_strategy=DenseVectorScriptScore( - embedding_service=embedding_service - ), + retrieval_strategy=DenseVectorScriptScore(), + embedding_service=embedding_service, vector_field=vector_field, text_field=text_field, es_client=es_client, From 9f83408b7babb21ca6c8e359e1dba63cb4e9f1aa Mon Sep 17 00:00:00 2001 From: Max Jakob Date: Mon, 22 Apr 2024 12:04:52 +0200 Subject: [PATCH 09/36] revert accidental changes --- examples/bulk-ingest/bulk-ingest.py | 9 +- examples/fastapi-apm/app.py | 11 +- examples/fastapi-apm/ping.py | 2 +- .../test_vectorestore copy.py1 | 
986 ------------------ 4 files changed, 8 insertions(+), 1000 deletions(-) delete mode 100644 test_elasticsearch/test_server/test_vectorstore/test_vectorestore copy.py1 diff --git a/examples/bulk-ingest/bulk-ingest.py b/examples/bulk-ingest/bulk-ingest.py index 0481a70fa..4c5c34c86 100644 --- a/examples/bulk-ingest/bulk-ingest.py +++ b/examples/bulk-ingest/bulk-ingest.py @@ -6,14 +6,13 @@ """Script that downloads a public dataset and streams it to an Elasticsearch cluster""" import csv -from os.path import abspath, dirname, exists, join - +from os.path import abspath, join, dirname, exists import tqdm import urllib3 - from elasticsearch import Elasticsearch from elasticsearch.helpers import streaming_bulk + NYC_RESTAURANTS = ( "https://data.cityofnewyork.us/api/views/43nn-pn8j/rows.csv?accessType=DOWNLOAD" ) @@ -100,9 +99,7 @@ def main(): progress = tqdm.tqdm(unit="docs", total=number_of_docs) successes = 0 for ok, action in streaming_bulk( - client=client, - index="nyc-restaurants", - actions=generate_actions(), + client=client, index="nyc-restaurants", actions=generate_actions(), ): progress.update(1) successes += ok diff --git a/examples/fastapi-apm/app.py b/examples/fastapi-apm/app.py index d96c1cd33..60a482448 100644 --- a/examples/fastapi-apm/app.py +++ b/examples/fastapi-apm/app.py @@ -2,16 +2,15 @@ # Elasticsearch B.V licenses this file to you under the Apache 2.0 License. 
# See the LICENSE file in the project root for more information +import aiohttp import datetime import os - -import aiohttp -from elasticapm.contrib.starlette import ElasticAPM, make_apm_client from fastapi import FastAPI from fastapi.encoders import jsonable_encoder - from elasticsearch import AsyncElasticsearch, NotFoundError from elasticsearch.helpers import async_streaming_bulk +from elasticapm.contrib.starlette import ElasticAPM, make_apm_client + apm = make_apm_client( {"SERVICE_NAME": "fastapi-app", "SERVER_URL": "http://apm-server:8200"} @@ -61,9 +60,7 @@ async def search(query): @app.get("/delete") async def delete(): - return await client.delete_by_query( - index="games", body={"query": {"match_all": {}}} - ) + return await client.delete_by_query(index="games", body={"query": {"match_all": {}}}) @app.get("/delete/{id}") diff --git a/examples/fastapi-apm/ping.py b/examples/fastapi-apm/ping.py index 83d67b6d5..94a364f93 100644 --- a/examples/fastapi-apm/ping.py +++ b/examples/fastapi-apm/ping.py @@ -3,9 +3,9 @@ # See the LICENSE file in the project root for more information import random +import urllib3 import time -import urllib3 endpoints = [ "http://app:9292/", diff --git a/test_elasticsearch/test_server/test_vectorstore/test_vectorestore copy.py1 b/test_elasticsearch/test_server/test_vectorstore/test_vectorestore copy.py1 deleted file mode 100644 index c432582d5..000000000 --- a/test_elasticsearch/test_server/test_vectorstore/test_vectorestore copy.py1 +++ /dev/null @@ -1,986 +0,0 @@ -import logging -import uuid -from typing import AsyncGenerator - -import pytest -import pytest_asyncio -from elasticsearch_serverless import AsyncElasticsearch -from elasticsearch_serverless.helpers import async_bulk - - -from ._test_utilities import ( - create_es_client, - create_requests_saving_client, - clear_test_indices, - read_env, -) - -logging.basicConfig(level=logging.DEBUG) - -""" -docker-compose up elasticsearch - -By default runs against local docker instance 
of Elasticsearch. -To run against Elastic Cloud, set the following environment variables: -- ES_CLOUD_ID -- ES_API_KEY - -Some of the tests require the following models to be deployed in the ML Node: -- elser (can be downloaded and deployed through Kibana and trained models UI) -- sentence-transformers__all-minilm-l6-v2 (can be deployed through the API, - loaded via eland) - -These tests that require the models to be deployed are skipped by default. -Enable them by adding the model name to the modelsDeployed list below. -""" - -ELSER_MODEL_ID = ".elser_model_2" -TRANSFORMER_MODEL_ID = "sentence-transformers__all-minilm-l6-v2" - - -class TestElasticsearch: - @pytest_asyncio.fixture(autouse=True) - async def es_client(self) -> AsyncGenerator[AsyncElasticsearch, None]: - params = read_env() - client = create_es_client(params) - - yield client - - # clear indices - if False: - await clear_test_indices(client) - - # clear all test pipelines - try: - response = await client.ingest.get_pipeline(id="test_*,*_sparse_embedding") - - for pipeline_id, _ in response.items(): - try: - await client.ingest.delete_pipeline(id=pipeline_id) - print(f"Deleted pipeline: {pipeline_id}") # noqa: T201 - except Exception as e: - print(f"Pipeline error: {e}") # noqa: T201 - - except Exception: - pass - finally: - await client.close() - - @pytest_asyncio.fixture(autouse=True) - async def requests_saving_client(self) -> AsyncGenerator[AsyncElasticsearch, None]: - client = create_requests_saving_client() - try: - yield client - finally: - await client.close() - - @pytest.fixture(scope="function") - def index_name(self) -> str: - """Return the index name.""" - return f"test_{uuid.uuid4().hex}" - - # def test_initialize_from_params(self, index_name: str) -> None: - # params = read_env() - # agent_header = "test initialize from params" - # store = VectorStore( - # agent_header=agent_header, - # index_name=index_name, - # retrieval_strategy=BM25(), - # **params, - # ) - - # assert 
store.es_client._headers["User-Agent"] == agent_header - - # texts = ["foo", "bar", "baz"] - # store.add_texts(texts) - - # output = store.search("foo", k=1) - # assert [doc["_source"]["text_field"] for doc in output] == ["foo"] - - # @pytest.mark.asyncio - # async def test_search_without_metadata( - # self, es_client: AsyncElasticsearch, index_name: str - # ) -> None: - # """Test end to end construction and search without metadata.""" - - # def assert_query(query_body: dict, query: Optional[str]) -> dict: - # assert query_body == { - # "knn": { - # "field": "vector_field", - # "filter": [], - # "k": 1, - # "num_candidates": 50, - # "query_vector": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - # } - # } - # return query_body - - # store = VectorStore( - # agent_header="test", - # index_name=index_name, - # retrieval_strategy=DenseVector(embedding_service=FakeEmbeddings()), - # es_client=es_client, - # ) - - # texts = ["foo", "bar", "baz"] - # store.add_texts(texts) - - # output = store.search("foo", k=1, custom_query=assert_query) - # assert [doc["_source"]["text_field"] for doc in output] == ["foo"] - - @pytest.mark.asyncio - async def test_add_vectors( - self, es_client: AsyncElasticsearch, index_name: str - ) -> None: - """ - Test adding pre-built embeddings instead of using inference for the texts. - This allows you to separate the embeddings text and the page_content - for better proximity between user's question and embedded text. - For example, your embedding text can be a question, whereas page_content - is the answer. 
- """ - docs = [ - ("foo1", [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0]), - ("foo2", [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]), - ("foo3", [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 2.0]), - ] - - texts = [t for t, _ in docs] - # embeddings = ConsistentFakeEmbeddings() - # texts = ["foo1", "foo2", "foo3"] - metadatas = [{"page": i} for i in range(len(texts))] - - """In real use case, embedding_input can be questions for each text""" - # embedding_vectors = embeddings.embed_documents(texts) - - embedding_vectors = [e for _, e in docs] - - index_name = "test_2" - - # store = VectorStore( - # agent_header="test", - # index_name=index_name, - # retrieval_strategy=DenseVector(num_dimensions=10), - # es_client=es_client, - # ) - - # store.add_texts(texts=texts, vectors=embedding_vectors, metadatas=metadatas) - - #################### - - mappings = { - "properties": { - "vector_field": { - "type": "dense_vector", - "dims": len(docs[0][1]), - "index": True, - "similarity": "cosine", - }, - } - } - await es_client.indices.create(index=index_name, mappings=mappings) - - indexing_requests = [ - { - "_op_type": "index", - "_index": index_name, - "_id": str(uuid.uuid4()), - "text_field": doc_id, - "metadata": metadatas[i], - "vector_field": vector, - } - for i, (doc_id, vector) in enumerate(docs) - ] - await async_bulk(es_client, indexing_requests, refresh=True) - - #################### - - # query_vector = embedding_vectors[0] - query_vector = docs[0][1] - assert query_vector == [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0] - query_body = { - "knn": { - "filter": [], - "field": "vector_field", - "k": 1, - "num_candidates": 50, - "query_vector": query_vector, - } - } - print(query_body) - - # if custom_query is not None: - # query_body = custom_query(query_body, query) - # logger.debug(f"Calling custom_query, Query body now: {query_body}") - - output = await es_client.search( - index=index_name, - **query_body, - size=1, - source=True, - ) - output = 
output["hits"]["hits"] - - # output = store.search(query=None, query_vector=query_vector, k=1) - - # print("\n".join([str(v) for v in embedding_vectors])) - print(query_vector) - print(output) - print([doc["_score"] for doc in output]) - - assert [doc["_source"]["text_field"] for doc in output] == ["foo1"] - assert [doc["_score"] for doc in output] == [1.0] - # assert [doc["_source"]["metadata"]["page"] for doc in output] == [0] - - # def test_search_with_metadata( - # self, es_client: AsyncElasticsearch, index_name: str - # ) -> None: - # """Test end to end construction and search with metadata.""" - # store = VectorStore( - # agent_header="test", - # index_name=index_name, - # retrieval_strategy=DenseVector( - # embedding_service=ConsistentFakeEmbeddings() - # ), - # es_client=es_client, - # ) - - # texts = ["foo", "bar", "baz"] - # metadatas = [{"page": i} for i in range(len(texts))] - # store.add_texts(texts=texts, metadatas=metadatas) - - # output = store.search("foo", k=1) - # assert [doc["_source"]["text_field"] for doc in output] == ["foo"] - # assert [doc["_source"]["metadata"]["page"] for doc in output] == [0] - - # output = store.search("bar", k=1) - # assert [doc["_source"]["text_field"] for doc in output] == ["bar"] - # assert [doc["_source"]["metadata"]["page"] for doc in output] == [1] - - # def test_search_with_filter( - # self, es_client: AsyncElasticsearch, index_name: str - # ) -> None: - # """Test end to end construction and search with metadata.""" - # store = VectorStore( - # agent_header="test", - # index_name=index_name, - # retrieval_strategy=DenseVector(embedding_service=FakeEmbeddings()), - # es_client=es_client, - # ) - - # texts = ["foo", "foo", "foo"] - # metadatas = [{"page": i} for i in range(len(texts))] - # store.add_texts(texts=texts, metadatas=metadatas) - - # def assert_query(query_body: dict, query: Optional[str]) -> dict: - # assert query_body == { - # "knn": { - # "field": "vector_field", - # "filter": [{"term": 
{"metadata.page": "1"}}], - # "k": 3, - # "num_candidates": 50, - # "query_vector": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - # } - # } - # return query_body - - # output = store.search( - # query="foo", - # k=3, - # filter=[{"term": {"metadata.page": "1"}}], - # custom_query=assert_query, - # ) - # assert [doc["_source"]["text_field"] for doc in output] == ["foo"] - # assert [doc["_source"]["metadata"]["page"] for doc in output] == [1] - - # def test_search_script_score( - # self, es_client: AsyncElasticsearch, index_name: str - # ) -> None: - # """Test end to end construction and search with metadata.""" - # store = VectorStore( - # agent_header="test", - # index_name=index_name, - # retrieval_strategy=DenseVectorScriptScore( - # embedding_service=FakeEmbeddings() - # ), - # es_client=es_client, - # ) - - # texts = ["foo", "bar", "baz"] - # store.add_texts(texts) - - # expected_query = { - # "query": { - # "script_score": { - # "query": {"match_all": {}}, - # "script": { - # "source": "cosineSimilarity(params.query_vector, 'vector_field') + 1.0", # noqa: E501 - # "params": { - # "query_vector": [ - # 1.0, - # 1.0, - # 1.0, - # 1.0, - # 1.0, - # 1.0, - # 1.0, - # 1.0, - # 1.0, - # 0.0, - # ] - # }, - # }, - # } - # } - # } - - # def assert_query(query_body: dict, query: Optional[str]) -> dict: - # assert query_body == expected_query - # return query_body - - # output = store.search("foo", k=1, custom_query=assert_query) - # assert [doc["_source"]["text_field"] for doc in output] == ["foo"] - - # def test_search_script_score_with_filter( - # self, es_client: AsyncElasticsearch, index_name: str - # ) -> None: - # """Test end to end construction and search with metadata.""" - # store = VectorStore( - # agent_header="test", - # index_name=index_name, - # retrieval_strategy=DenseVectorScriptScore( - # embedding_service=FakeEmbeddings() - # ), - # es_client=es_client, - # ) - - # texts = ["foo", "bar", "baz"] - # metadatas = [{"page": i} for i in 
range(len(texts))] - # store.add_texts(texts=texts, metadatas=metadatas) - - # def assert_query(query_body: dict, query: Optional[str]) -> dict: - # expected_query = { - # "query": { - # "script_score": { - # "query": {"bool": {"filter": [{"term": {"metadata.page": 0}}]}}, - # "script": { - # "source": "cosineSimilarity(params.query_vector, 'vector_field') + 1.0", # noqa: E501 - # "params": { - # "query_vector": [ - # 1.0, - # 1.0, - # 1.0, - # 1.0, - # 1.0, - # 1.0, - # 1.0, - # 1.0, - # 1.0, - # 0.0, - # ] - # }, - # }, - # } - # } - # } - # assert query_body == expected_query - # return query_body - - # output = store.search( - # "foo", - # k=1, - # custom_query=assert_query, - # filter=[{"term": {"metadata.page": 0}}], - # ) - # assert [doc["_source"]["text_field"] for doc in output] == ["foo"] - # assert [doc["_source"]["metadata"]["page"] for doc in output] == [0] - - # def test_search_script_score_distance_dot_product( - # self, es_client: AsyncElasticsearch, index_name: str - # ) -> None: - # """Test end to end construction and search with metadata.""" - # store = VectorStore( - # agent_header="test", - # index_name=index_name, - # retrieval_strategy=DenseVectorScriptScore( - # embedding_service=FakeEmbeddings(), - # distance=DistanceMetric.DOT_PRODUCT, - # ), - # es_client=es_client, - # ) - - # texts = ["foo", "bar", "baz"] - # store.add_texts(texts) - - # def assert_query(query_body: dict, query: Optional[str]) -> dict: - # assert query_body == { - # "query": { - # "script_score": { - # "query": {"match_all": {}}, - # "script": { - # "source": """ - # double value = dotProduct(params.query_vector, 'vector_field'); - # return sigmoid(1, Math.E, -value); - # """, - # "params": { - # "query_vector": [ - # 1.0, - # 1.0, - # 1.0, - # 1.0, - # 1.0, - # 1.0, - # 1.0, - # 1.0, - # 1.0, - # 0.0, - # ] - # }, - # }, - # } - # } - # } - # return query_body - - # output = store.search("foo", k=1, custom_query=assert_query) - # assert [doc["_source"]["text_field"] 
for doc in output] == ["foo"] - - # def test_search_knn_with_hybrid_search( - # self, es_client: AsyncElasticsearch, index_name: str - # ) -> None: - # """Test end to end construction and search with metadata.""" - # store = VectorStore( - # agent_header="test", - # index_name=index_name, - # retrieval_strategy=DenseVector( - # embedding_service=FakeEmbeddings(), - # hybrid=True, - # ), - # es_client=es_client, - # ) - - # texts = ["foo", "bar", "baz"] - # store.add_texts(texts) - - # def assert_query(query_body: dict, query: Optional[str]) -> dict: - # assert query_body == { - # "knn": { - # "field": "vector_field", - # "filter": [], - # "k": 1, - # "num_candidates": 50, - # "query_vector": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - # }, - # "query": { - # "bool": { - # "filter": [], - # "must": [{"match": {"text_field": {"query": "foo"}}}], - # } - # }, - # "rank": {"rrf": {}}, - # } - # return query_body - - # output = store.search("foo", k=1, custom_query=assert_query) - # assert [doc["_source"]["text_field"] for doc in output] == ["foo"] - - # @pytest.mark.asyncio - # async def test_search_knn_with_hybrid_search_rrf( - # self, es_client: AsyncElasticsearch, index_name: str - # ) -> None: - # """Test end to end construction and rrf hybrid search with metadata.""" - # texts = ["foo", "bar", "baz"] - - # def assert_query( - # query_body: dict, - # query: Optional[str], - # expected_rrf: Union[dict, bool], - # ) -> dict: - # cmp_query_body = { - # "knn": { - # "field": "vector_field", - # "filter": [], - # "k": 3, - # "num_candidates": 50, - # "query_vector": [ - # 1.0, - # 1.0, - # 1.0, - # 1.0, - # 1.0, - # 1.0, - # 1.0, - # 1.0, - # 1.0, - # 0.0, - # ], - # }, - # "query": { - # "bool": { - # "filter": [], - # "must": [{"match": {"text_field": {"query": "foo"}}}], - # } - # }, - # } - - # if isinstance(expected_rrf, dict): - # cmp_query_body["rank"] = {"rrf": expected_rrf} - # elif isinstance(expected_rrf, bool) and expected_rrf is True: - # 
cmp_query_body["rank"] = {"rrf": {}} - - # assert query_body == cmp_query_body - - # return query_body - - # # 1. check query_body is okay - # rrf_test_cases: List[Union[dict, bool]] = [ - # True, - # False, - # {"rank_constant": 1, "window_size": 5}, - # ] - # for rrf_test_case in rrf_test_cases: - # store = VectorStore( - # agent_header="test", - # index_name=index_name, - # retrieval_strategy=DenseVector( - # embedding_service=FakeEmbeddings(), - # hybrid=True, - # rrf=rrf_test_case, - # ), - # es_client=es_client, - # ) - # store.add_texts(texts) - - # ## without fetch_k parameter - # output = store.search( - # "foo", - # k=3, - # custom_query=partial(assert_query, expected_rrf=rrf_test_case), - # ) - - # # 2. check query result is okay - # es_output = await store.es_client.search( - # index=index_name, - # query={ - # "bool": { - # "filter": [], - # "must": [{"match": {"text_field": {"query": "foo"}}}], - # } - # }, - # knn={ - # "field": "vector_field", - # "filter": [], - # "k": 3, - # "num_candidates": 50, - # "query_vector": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - # }, - # size=3, - # rank={"rrf": {"rank_constant": 1, "window_size": 5}}, - # ) - - # assert [o["_source"]["text_field"] for o in output] == [ - # e["_source"]["text_field"] for e in es_output["hits"]["hits"] - # ] - - # # 3. 
check rrf default option is okay - # store = VectorStore( - # agent_header="test", - # index_name=f"{index_name}_default", - # retrieval_strategy=DenseVector( - # embedding_service=FakeEmbeddings(), - # hybrid=True, - # ), - # es_client=es_client, - # ) - # store.add_texts(texts) - - # ## with fetch_k parameter - # output = store.search( - # "foo", - # k=3, - # num_candidates=50, - # custom_query=partial(assert_query, expected_rrf={}), - # ) - - # def test_search_knn_with_custom_query_fn( - # self, es_client: AsyncElasticsearch, index_name: str - # ) -> None: - # """test that custom query function is called - # with the query string and query body""" - # store = VectorStore( - # agent_header="test", - # index_name=index_name, - # retrieval_strategy=DenseVector(embedding_service=FakeEmbeddings()), - # es_client=es_client, - # ) - - # def my_custom_query(query_body: dict, query: Optional[str]) -> dict: - # assert query == "foo" - # assert query_body == { - # "knn": { - # "field": "vector_field", - # "filter": [], - # "k": 1, - # "num_candidates": 50, - # "query_vector": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - # } - # } - # return {"query": {"match": {"text_field": {"query": "bar"}}}} - - # """Test end to end construction and search with metadata.""" - # texts = ["foo", "bar", "baz"] - # store.add_texts(texts) - - # output = store.search("foo", k=1, custom_query=my_custom_query) - # assert [doc["_source"]["text_field"] for doc in output] == ["bar"] - - # @pytest.mark.asyncio - # @pytest.mark.skipif( - # not model_is_deployed(create_es_client(), TRANSFORMER_MODEL_ID), - # reason=f"{TRANSFORMER_MODEL_ID} model not deployed in ML Node, " - # "skipping test", - # ) - # async def test_search_with_knn_infer_instack( - # self, es_client: AsyncElasticsearch, index_name: str - # ) -> None: - # """test end to end with knn retrieval strategy and inference in-stack""" - # text_field = "text_field" - - # store = VectorStore( - # agent_header="test", - # 
index_name=index_name, - # retrieval_strategy=Semantic( - # model_id="sentence-transformers__all-minilm-l6-v2", - # text_field=text_field, - # ), - # es_client=es_client, - # ) - - # # setting up the pipeline for inference - # await store.es_client.ingest.put_pipeline( - # id="test_pipeline", - # processors=[ - # { - # "inference": { - # "model_id": TRANSFORMER_MODEL_ID, - # "field_map": {"query_field": text_field}, - # "target_field": "vector_query_field", - # } - # } - # ], - # ) - - # # creating a new index with the pipeline, - # # not relying on langchain to create the index - # await store.es_client.indices.create( - # index=index_name, - # mappings={ - # "properties": { - # text_field: {"type": "text_field"}, - # "vector_query_field": { - # "properties": { - # "predicted_value": { - # "type": "dense_vector", - # "dims": 384, - # "index": True, - # "similarity": "l2_norm", - # } - # } - # }, - # } - # }, - # settings={"index": {"default_pipeline": "test_pipeline"}}, - # ) - - # # adding documents to the index - # texts = ["foo", "bar", "baz"] - - # for i, text in enumerate(texts): - # await store.es_client.create( - # index=index_name, - # id=str(i), - # document={text_field: text, "metadata": {}}, - # ) - - # await store.es_client.indices.refresh(index=index_name) - - # def assert_query(query_body: dict, query: Optional[str]) -> dict: - # assert query_body == { - # "knn": { - # "filter": [], - # "field": "vector_query_field.predicted_value", - # "k": 1, - # "num_candidates": 50, - # "query_vector_builder": { - # "text_embedding": { - # "model_id": TRANSFORMER_MODEL_ID, - # "model_text": "foo", - # } - # }, - # } - # } - # return query_body - - # output = store.search("foo", k=1, custom_query=assert_query) - # assert [doc["_source"]["text_field"] for doc in output] == ["foo"] - - # output = store.search("bar", k=1) - # assert [doc["_source"]["text_field"] for doc in output] == ["bar"] - - # @pytest.mark.skipif( - # not model_is_deployed(create_es_client(), 
ELSER_MODEL_ID), - # reason=f"{ELSER_MODEL_ID} model not deployed in ML Node, skipping test", - # ) - # def test_search_with_sparse_infer_instack( - # self, es_client: AsyncElasticsearch, index_name: str - # ) -> None: - # """test end to end with sparse retrieval strategy and inference in-stack""" - # store = VectorStore( - # agent_header="test", - # index_name=index_name, - # retrieval_strategy=Semantic(model_id=ELSER_MODEL_ID), - # es_client=es_client, - # ) - - # texts = ["foo", "bar", "baz"] - # store.add_texts(texts) - - # output = store.search("foo", k=1) - # assert [doc["_source"]["text_field"] for doc in output] == ["foo"] - - # def test_deployed_model_check_fails_semantic( - # self, es_client: AsyncElasticsearch, index_name: str - # ) -> None: - # """test that exceptions are raised if a specified model is not deployed""" - # with pytest.raises(NotFoundError): - # store = VectorStore( - # agent_header="test", - # index_name=index_name, - # retrieval_strategy=Semantic(model_id="non-existing model ID"), - # es_client=es_client, - # ) - # store.add_texts(["foo", "bar", "baz"]) - - # def test_search_bm25(self, es_client: AsyncElasticsearch, index_name: str) -> None: - # """Test end to end using the BM25 retrieval strategy.""" - # store = VectorStore( - # agent_header="test", - # index_name=index_name, - # retrieval_strategy=BM25(), - # es_client=es_client, - # ) - - # texts = ["foo", "bar", "baz"] - # store.add_texts(texts) - - # def assert_query(query_body: dict, query: Optional[str]) -> dict: - # assert query_body == { - # "query": { - # "bool": { - # "must": [{"match": {"text_field": {"query": "foo"}}}], - # "filter": [], - # } - # } - # } - # return query_body - - # output = store.search("foo", k=1, custom_query=assert_query) - # assert [doc["_source"]["text_field"] for doc in output] == ["foo"] - - # def test_search_bm25_with_filter( - # self, es_client: AsyncElasticsearch, index_name: str - # ) -> None: - # """Test end to using the BM25 retrieval strategy 
with metadata.""" - # store = VectorStore( - # agent_header="test", - # index_name=index_name, - # retrieval_strategy=BM25(), - # es_client=es_client, - # ) - - # texts = ["foo", "foo", "foo"] - # metadatas = [{"page": i} for i in range(len(texts))] - # store.add_texts(texts=texts, metadatas=metadatas) - - # def assert_query(query_body: dict, query: Optional[str]) -> dict: - # assert query_body == { - # "query": { - # "bool": { - # "must": [{"match": {"text_field": {"query": "foo"}}}], - # "filter": [{"term": {"metadata.page": 1}}], - # } - # } - # } - # return query_body - - # output = store.search( - # "foo", - # k=3, - # custom_query=assert_query, - # filter=[{"term": {"metadata.page": 1}}], - # ) - # assert [doc["_source"]["text_field"] for doc in output] == ["foo"] - # assert [doc["_source"]["metadata"]["page"] for doc in output] == [1] - - # def test_delete(self, es_client: AsyncElasticsearch, index_name: str) -> None: - # """Test delete methods from vector store.""" - # store = VectorStore( - # agent_header="test", - # index_name=index_name, - # retrieval_strategy=DenseVector(embedding_service=FakeEmbeddings()), - # es_client=es_client, - # ) - - # texts = ["foo", "bar", "baz", "gni"] - # metadatas = [{"page": i} for i in range(len(texts))] - # ids = store.add_texts(texts=texts, metadatas=metadatas) - - # output = store.search("foo", k=10) - # assert len(output) == 4 - - # store.delete(ids[1:3]) - # output = store.search("foo", k=10) - # assert len(output) == 2 - - # store.delete(["not-existing"]) - # output = store.search("foo", k=10) - # assert len(output) == 2 - - # store.delete([ids[0]]) - # output = store.search("foo", k=10) - # assert len(output) == 1 - - # store.delete([ids[3]]) - # output = store.search("gni", k=10) - # assert len(output) == 0 - - # @pytest.mark.asyncio - # async def test_indexing_exception_error( - # self, - # es_client: AsyncElasticsearch, - # index_name: str, - # caplog: pytest.LogCaptureFixture, - # ) -> None: - # """Test bulk 
exception logging is giving better hints.""" - # store = VectorStore( - # agent_header="test", - # index_name=index_name, - # retrieval_strategy=BM25(), - # es_client=es_client, - # ) - - # await store.es_client.indices.create( - # index=index_name, - # mappings={"properties": {}}, - # settings={"index": {"default_pipeline": "not-existing-pipeline"}}, - # ) - - # texts = ["foo"] - - # with pytest.raises(BulkIndexError): - # store.add_texts(texts) - - # error_reason = "pipeline with id [not-existing-pipeline] does not exist" - # log_message = f"First error reason: {error_reason}" - - # assert log_message in caplog.text - - # def test_user_agent( - # self, requests_saving_client: AsyncElasticsearch, index_name: str - # ) -> None: - # """Test to make sure the user-agent is set correctly.""" - # agent_header = "this is THE agent_header!" - # store = VectorStore( - # agent_header=agent_header, - # index_name=index_name, - # retrieval_strategy=BM25(), - # es_client=requests_saving_client, - # ) - - # assert store.es_client._headers["User-Agent"] == agent_header - - # texts = ["foo", "bob", "baz"] - # store.add_texts(texts) - - # transport = cast(RequestSavingTransport, store.es_client.transport) - - # for request in transport.requests: - # assert request["headers"]["User-Agent"] == agent_header - - # def test_bulk_args(self, requests_saving_client: Any, index_name: str) -> None: - # """Test to make sure the bulk arguments work as expected.""" - # store = VectorStore( - # agent_header="test", - # index_name=index_name, - # retrieval_strategy=BM25(), - # es_client=requests_saving_client, - # ) - - # texts = ["foo", "bob", "baz"] - # store.add_texts(texts, bulk_kwargs={"chunk_size": 1}) - - # # 1 for index exist, 1 for index create, 3 to index docs - # assert len(store.es_client.transport.requests) == 5 # type: ignore - - # def test_max_marginal_relevance_search( - # self, es_client: AsyncElasticsearch, index_name: str - # ) -> None: - # """Test max marginal relevance 
search.""" - # texts = ["foo", "bar", "baz"] - # vector_field = "vector_field" - # text_field = "text_field" - # embedding_service = ConsistentFakeEmbeddings() - # store = VectorStore( - # agent_header="test", - # index_name=index_name, - # retrieval_strategy=DenseVectorScriptScore( - # embedding_service=embedding_service - # ), - # vector_field=vector_field, - # text_field=text_field, - # es_client=es_client, - # ) - # store.add_texts(texts) - - # mmr_output = store.max_marginal_relevance_search( - # embedding_service, - # texts[0], - # vector_field=vector_field, - # k=3, - # num_candidates=3, - # ) - # sim_output = store.search(texts[0], k=3) - # assert mmr_output == sim_output - - # mmr_output = store.max_marginal_relevance_search( - # embedding_service, - # texts[0], - # vector_field=vector_field, - # k=2, - # num_candidates=3, - # ) - # assert len(mmr_output) == 2 - # assert mmr_output[0]["_source"][text_field] == texts[0] - # assert mmr_output[1]["_source"][text_field] == texts[1] - - # mmr_output = store.max_marginal_relevance_search( - # embedding_service, - # texts[0], - # vector_field=vector_field, - # k=2, - # num_candidates=3, - # lambda_mult=0.1, # more diversity - # ) - # assert len(mmr_output) == 2 - # assert mmr_output[0]["_source"][text_field] == texts[0] - # assert mmr_output[1]["_source"][text_field] == texts[2] - - # # if fetch_k < k, then the output will be less than k - # mmr_output = store.max_marginal_relevance_search( - # embedding_service, - # texts[0], - # vector_field=vector_field, - # k=3, - # num_candidates=2, - # ) - # assert len(mmr_output) == 2 From 9803414a3b484c016ed6d4a0c26f34428aafe29f Mon Sep 17 00:00:00 2001 From: Max Jakob Date: Mon, 22 Apr 2024 14:21:37 +0200 Subject: [PATCH 10/36] keep field names only in store; apply metadata mappings in store --- .../vectorstore/_async/strategies.py | 231 +++++++------- .../vectorstore/_async/vectorestore.py | 65 ++-- elasticsearch/vectorstore/_sync/strategies.py | 290 ++++++++---------- 
.../vectorstore/_sync/vectorestore.py | 66 +++- .../_async/test_vectorestore.py | 29 ++ .../_sync/test_vectorestore.py | 85 +++-- 6 files changed, 408 insertions(+), 358 deletions(-) diff --git a/elasticsearch/vectorstore/_async/strategies.py b/elasticsearch/vectorstore/_async/strategies.py index 7b67b71fa..6c5b0faf7 100644 --- a/elasticsearch/vectorstore/_async/strategies.py +++ b/elasticsearch/vectorstore/_async/strategies.py @@ -17,7 +17,7 @@ from abc import ABC, abstractmethod from enum import Enum -from typing import Any, Dict, List, Literal, Optional, Union, cast +from typing import Any, Dict, List, Literal, Optional, Tuple, Union, cast from elasticsearch import AsyncElasticsearch from elasticsearch.vectorstore._async._utils import model_must_be_deployed @@ -37,10 +37,12 @@ class RetrievalStrategy(ABC): async def es_query( self, query: Optional[str], + query_vector: Optional[List[float]], + text_field: str, + vector_field: str, k: int, num_candidates: int, filter: List[Dict[str, Any]] = [], - query_vector: Optional[List[float]] = None, ) -> Dict[str, Any]: """ Returns the Elasticsearch query body for the given parameters. @@ -58,13 +60,12 @@ async def es_query( """ @abstractmethod - async def create_index( + def es_mappings_settings( self, - client: AsyncElasticsearch, - index_name: str, - num_dimensions: Optional[int] = None, - metadata_mapping: Optional[Dict[str, str]] = None, - ) -> None: + text_field: str, + vector_field: str, + num_dimensions: Optional[int], + ) -> Tuple[Dict[str, Any], Dict[str, Any]]: """ Create the required index and do necessary preliminary work, like creating inference pipelines or checking if a required model was deployed. @@ -76,6 +77,20 @@ async def create_index( describe the schema of the metadata. """ + async def before_index_creation( + self, client: AsyncElasticsearch, text_field: str, vector_field: str + ) -> None: + """ + Executes before the index is created. 
Used for setting up + any required Elasticsearch resources like a pipeline. + + Args: + client: The Elasticsearch client. + text_field: The field containing the text data in the index. + vector_field: The field containing the vector representations in the index. + """ + pass + def needs_inference(self) -> bool: """ TODO @@ -91,19 +106,19 @@ def __init__( self, model_id: str, text_field: str = "text_field", - inference_field: str = "text_semantic", ): self.model_id = model_id self.text_field = text_field - self.inference_field = inference_field async def es_query( self, query: Optional[str], + query_vector: Optional[List[float]], + text_field: str, + vector_field: str, k: int, num_candidates: int, filter: List[Dict[str, Any]] = [], - query_vector: Optional[List[float]] = None, ) -> Dict[str, Any]: if query_vector: raise ValueError( @@ -114,57 +129,53 @@ async def es_query( return { "query": { "semantic": { - self.text_field: query, + text_field: query, }, }, "filter": filter, } - async def create_index( + def es_mappings_settings( self, - client: AsyncElasticsearch, - index_name: str, - num_dimensions: int, - metadata_mapping: Optional[Dict[str, str]], - ) -> None: - if self.model_id: - await model_must_be_deployed(client, self.model_id) - - mappings: Dict[str, Any] = { + text_field: str, + vector_field: str, + num_dimensions: Optional[int] = None, + ) -> Tuple[Dict[str, Any], Dict[str, Any]]: + mappings = { "properties": { - self.inference_field: { + vector_field: { "type": "semantic_text", "model_id": self.model_id, } } } - if metadata_mapping: - mappings["properties"]["metadata"] = {"properties": metadata_mapping} - await client.indices.create(index=index_name, mappings=mappings) + return mappings, {} + + async def before_index_creation( + self, client: AsyncElasticsearch, text_field: str, vector_field: str + ) -> None: + if self.model_id: + await model_must_be_deployed(client, self.model_id) class SparseVector(RetrievalStrategy): """Sparse retrieval strategy 
using the `text_expansion` processor.""" - def __init__( - self, - model_id: str = ".elser_model_2", - text_field: str = "text_field", - vector_field: str = "vector_field", - ): + def __init__(self, model_id: str = ".elser_model_2"): self.model_id = model_id - self.text_field = text_field - self.vector_field = vector_field self._tokens_field = "tokens" + self._pipeline_name = f"{self.model_id}_sparse_embedding" async def es_query( self, query: Optional[str], + query_vector: Optional[List[float]], + text_field: str, + vector_field: str, k: int, num_candidates: int, filter: List[Dict[str, Any]] = [], - query_vector: Optional[List[float]] = None, ) -> Dict[str, Any]: if query_vector: raise ValueError( @@ -180,7 +191,7 @@ async def es_query( "must": [ { "text_expansion": { - f"{self.vector_field}.{self._tokens_field}": { + f"{vector_field}.{self._tokens_field}": { "model_id": self.model_id, "model_text": query, } @@ -193,28 +204,39 @@ async def es_query( "size": k, } - async def create_index( + def es_mappings_settings( self, - client: AsyncElasticsearch, - index_name: str, - num_dimensions: int, - metadata_mapping: Optional[Dict[str, str]], - ) -> None: - pipeline_name = f"{self.model_id}_sparse_embedding" + text_field: str, + vector_field: str, + num_dimensions: Optional[int], + ) -> Tuple[Dict[str, Any], Dict[str, Any]]: + mappings: Dict[str, Any] = { + "properties": { + vector_field: { + "properties": {self._tokens_field: {"type": "rank_features"}} + } + } + } + settings = {"default_pipeline": self._pipeline_name} + return mappings, settings + + async def before_index_creation( + self, client: AsyncElasticsearch, text_field: str, vector_field: str + ) -> None: if self.model_id: await model_must_be_deployed(client, self.model_id) # Create a pipeline for the model await client.ingest.put_pipeline( - id=pipeline_name, + id=self._pipeline_name, description="Embedding pipeline for Python VectorStore", processors=[ { "inference": { "model_id": self.model_id, - 
"target_field": self.vector_field, - "field_map": {self.text_field: "text_field"}, + "target_field": vector_field, + "field_map": {text_field: "text_field"}, "inference_config": { "text_expansion": {"results_field": self._tokens_field} }, @@ -223,23 +245,6 @@ async def create_index( ], ) - mappings: Dict[str, Any] = { - "properties": { - self.vector_field: { - "properties": {self._tokens_field: {"type": "rank_features"}} - } - } - } - if metadata_mapping: - mappings["properties"]["metadata"] = {"properties": metadata_mapping} - settings = {"default_pipeline": pipeline_name} - - await client.indices.create( - index=index_name, mappings=mappings, settings=settings - ) - - return None - class DenseVector(RetrievalStrategy): """K-nearest-neighbors retrieval.""" @@ -247,7 +252,6 @@ class DenseVector(RetrievalStrategy): def __init__( self, knn_type: Literal["hnsw", "int8_hnsw", "flat", "int8_flat"] = "hnsw", - vector_field: str = "vector_field", distance: DistanceMetric = DistanceMetric.COSINE, model_id: Optional[str] = None, hybrid: bool = False, @@ -260,7 +264,6 @@ def __init__( ) self.knn_type = knn_type - self.vector_field = vector_field self.distance = distance self.model_id = model_id self.hybrid = hybrid @@ -270,14 +273,16 @@ def __init__( async def es_query( self, query: Optional[str], + query_vector: Optional[List[float]], + text_field: str, + vector_field: str, k: int, num_candidates: int, filter: List[Dict[str, Any]] = [], - query_vector: Optional[List[float]] = None, ) -> Dict[str, Any]: knn = { "filter": filter, - "field": self.vector_field, + "field": vector_field, "k": k, "num_candidates": num_candidates, } @@ -299,16 +304,12 @@ async def es_query( return {"knn": knn} - async def create_index( + def es_mappings_settings( self, - client: AsyncElasticsearch, - index_name: str, - num_dimensions: int, - metadata_mapping: Optional[Dict[str, str]], - ) -> None: - if self.model_id: - await model_must_be_deployed(client, self.model_id) - + text_field: str, + 
vector_field: str, + num_dimensions: Optional[int], + ) -> Tuple[Dict[str, Any], Dict[str, Any]]: if self.distance is DistanceMetric.COSINE: similarityAlgo = "cosine" elif self.distance is DistanceMetric.EUCLIDEAN_DISTANCE: @@ -322,7 +323,7 @@ async def create_index( mappings: Dict[str, Any] = { "properties": { - self.vector_field: { + vector_field: { "type": "dense_vector", "dims": num_dimensions, "index": True, @@ -330,11 +331,14 @@ async def create_index( }, } } - if metadata_mapping: - mappings["properties"]["metadata"] = {"properties": metadata_mapping} - r = await client.indices.create(index=index_name, mappings=mappings) - print(r) + return mappings, {} + + async def before_index_creation( + self, client: AsyncElasticsearch, text_field: str, vector_field: str + ) -> None: + if self.model_id: + await model_must_be_deployed(client, self.model_id) def _hybrid( self, query: str, knn: Dict[str, Any], filter: List[Dict[str, Any]] @@ -375,41 +379,36 @@ def needs_inference(self) -> bool: class DenseVectorScriptScore(RetrievalStrategy): """Exact nearest neighbors retrieval using the `script_score` query.""" - def __init__( - self, - vector_field: str = "vector_field", - distance: DistanceMetric = DistanceMetric.COSINE, - ) -> None: - self.vector_field = vector_field + def __init__(self, distance: DistanceMetric = DistanceMetric.COSINE) -> None: self.distance = distance async def es_query( self, query: Optional[str], + query_vector: Optional[List[float]], + text_field: str, + vector_field: str, k: int, num_candidates: int, filter: List[Dict[str, Any]] = [], - query_vector: Optional[List[float]] = None, ) -> Dict[str, Any]: if not query_vector: raise ValueError("specify a query_vector") if self.distance is DistanceMetric.COSINE: similarityAlgo = ( - f"cosineSimilarity(params.query_vector, '{self.vector_field}') + 1.0" + f"cosineSimilarity(params.query_vector, '{vector_field}') + 1.0" ) elif self.distance is DistanceMetric.EUCLIDEAN_DISTANCE: - similarityAlgo = ( - f"1 
/ (1 + l2norm(params.query_vector, '{self.vector_field}'))" - ) + similarityAlgo = f"1 / (1 + l2norm(params.query_vector, '{vector_field}'))" elif self.distance is DistanceMetric.DOT_PRODUCT: similarityAlgo = f""" - double value = dotProduct(params.query_vector, '{self.vector_field}'); + double value = dotProduct(params.query_vector, '{vector_field}'); return sigmoid(1, Math.E, -value); """ elif self.distance is DistanceMetric.MAX_INNER_PRODUCT: similarityAlgo = f""" - double value = dotProduct(params.query_vector, '{self.vector_field}'); + double value = dotProduct(params.query_vector, '{vector_field}'); if (dotProduct < 0) {{ return 1 / (1 + -1 * dotProduct); }} @@ -434,26 +433,23 @@ async def es_query( } } - async def create_index( + def es_mappings_settings( self, - client: AsyncElasticsearch, - index_name: str, - num_dimensions: int, - metadata_mapping: Optional[Dict[str, str]], - ) -> None: + text_field: str, + vector_field: str, + num_dimensions: Optional[int], + ) -> Tuple[Dict[str, Any], Dict[str, Any]]: mappings = { "properties": { - self.vector_field: { + vector_field: { "type": "dense_vector", "dims": num_dimensions, "index": False, } } } - if metadata_mapping: - mappings["properties"]["metadata"] = {"properties": metadata_mapping} - await client.indices.create(index=index_name, mappings=mappings) + return mappings, {} def needs_inference(self) -> bool: return True @@ -462,21 +458,21 @@ def needs_inference(self) -> bool: class BM25(RetrievalStrategy): def __init__( self, - text_field: str = "text_field", k1: Optional[float] = None, b: Optional[float] = None, ): - self.text_field = text_field self.k1 = k1 self.b = b async def es_query( self, query: Optional[str], + query_vector: Optional[List[float]], + text_field: str, + vector_field: str, k: int, num_candidates: int, filter: List[Dict[str, Any]] = [], - query_vector: Optional[List[float]] = None, ) -> Dict[str, Any]: return { "query": { @@ -484,7 +480,7 @@ async def es_query( "must": [ { "match": { - 
self.text_field: { + text_field: { "query": query, } }, @@ -495,25 +491,22 @@ async def es_query( }, } - async def create_index( + def es_mappings_settings( self, - client: AsyncElasticsearch, - index_name: str, - num_dimensions: int, - metadata_mapping: Optional[Dict[str, str]], - ) -> None: + text_field: str, + vector_field: str, + num_dimensions: Optional[int], + ) -> Tuple[Dict[str, Any], Dict[str, Any]]: similarity_name = "custom_bm25" mappings: Dict[str, Any] = { "properties": { - self.text_field: { + text_field: { "type": "text", "similarity": similarity_name, }, }, } - if metadata_mapping: - mappings["properties"]["metadata"] = {"properties": metadata_mapping} bm25: Dict[str, Any] = { "type": "BM25", @@ -528,6 +521,4 @@ async def create_index( } } - await client.indices.create( - index=index_name, mappings=mappings, settings=settings - ) + return mappings, settings diff --git a/elasticsearch/vectorstore/_async/vectorestore.py b/elasticsearch/vectorstore/_async/vectorestore.py index 6bb62418f..f9e88b43a 100644 --- a/elasticsearch/vectorstore/_async/vectorestore.py +++ b/elasticsearch/vectorstore/_async/vectorestore.py @@ -48,7 +48,7 @@ def __init__( num_dimensions: Optional[int] = None, text_field: str = "text_field", vector_field: str = "vector_field", - metadata_mapping: Optional[Dict[str, str]] = None, + metadata_mappings: Optional[Dict[str, Any]] = None, ) -> None: """ Args: @@ -79,7 +79,7 @@ def __init__( self.num_dimensions = num_dimensions self.text_field = text_field self.vector_field = vector_field - self.metadata_mapping = metadata_mapping + self.metadata_mappings = metadata_mappings async def close(self) -> None: return await self.es_client.close() @@ -161,7 +161,7 @@ async def add_texts( logger.debug("No texts to add to index") return [] - async def delete( + async def delete( # type: ignore[no-untyped-def] self, ids: Optional[List[str]] = None, query: Optional[Dict[str, Any]] = None, @@ -252,10 +252,12 @@ async def search( query_body = await 
self.retrieval_strategy.es_query( query=query, + query_vector=query_vector, + text_field=self.text_field, + vector_field=self.vector_field, k=k, num_candidates=num_candidates, filter=filter or [], - query_vector=query_vector, ) if custom_query is not None: @@ -269,32 +271,47 @@ async def search( source=True, source_includes=fields, ) + hits: List[Dict[str, Any]] = response["hits"]["hits"] - return response["hits"]["hits"] + return hits async def _create_index_if_not_exists(self) -> None: exists = await self.es_client.indices.exists(index=self.index_name) if exists.meta.status == 200: logger.debug(f"Index {self.index_name} already exists. Skipping creation.") - else: - if self.retrieval_strategy.needs_inference(): - if not self.num_dimensions and not self.embedding_service: - raise ValueError( - "retrieval strategy requires embeddings; either embedding_service " - "or num_dimensions need to be specified" - ) - if not self.num_dimensions and self.embedding_service: - vector = await self.embedding_service.embed_query( - "get num dimensions" - ) - self.num_dimensions = len(vector) - - await self.retrieval_strategy.create_index( - client=self.es_client, - index_name=self.index_name, - num_dimensions=self.num_dimensions, - metadata_mapping=self.metadata_mapping, - ) + return + + if self.retrieval_strategy.needs_inference(): + if not self.num_dimensions and not self.embedding_service: + raise ValueError( + "retrieval strategy requires embeddings; either embedding_service " + "or num_dimensions need to be specified" + ) + if not self.num_dimensions and self.embedding_service: + vector = await self.embedding_service.embed_query("get num dimensions") + self.num_dimensions = len(vector) + + mappings, settings = self.retrieval_strategy.es_mappings_settings( + text_field=self.text_field, + vector_field=self.vector_field, + num_dimensions=self.num_dimensions, + ) + + if self.metadata_mappings: + metadata = mappings["properties"].get("metadata", {"properties": {}}) + for key in 
self.metadata_mappings.keys(): + if key in metadata: + raise ValueError(f"metadata key {key} already exists in mappings") + + metadata = dict(**metadata["properties"], **self.metadata_mappings) + mappings["properties"] = {"metadata": {"properties": metadata}} + + await self.retrieval_strategy.before_index_creation( + self.es_client, self.text_field, self.vector_field + ) + await self.es_client.indices.create( + index=self.index_name, mappings=mappings, settings=settings + ) async def max_marginal_relevance_search( self, diff --git a/elasticsearch/vectorstore/_sync/strategies.py b/elasticsearch/vectorstore/_sync/strategies.py index ea155628d..f31fc1113 100644 --- a/elasticsearch/vectorstore/_sync/strategies.py +++ b/elasticsearch/vectorstore/_sync/strategies.py @@ -17,11 +17,10 @@ from abc import ABC, abstractmethod from enum import Enum -from typing import Any, Dict, List, Literal, Optional, Union, cast +from typing import Any, Dict, List, Literal, Optional, Tuple, Union, cast from elasticsearch import Elasticsearch from elasticsearch.vectorstore._sync._utils import model_must_be_deployed -from elasticsearch.vectorstore._sync.embedding_service import EmbeddingService class DistanceMetric(str, Enum): @@ -38,10 +37,12 @@ class RetrievalStrategy(ABC): def es_query( self, query: Optional[str], + query_vector: Optional[List[float]], + text_field: str, + vector_field: str, k: int, num_candidates: int, filter: List[Dict[str, Any]] = [], - query_vector: Optional[List[float]] = None, ) -> Dict[str, Any]: """ Returns the Elasticsearch query body for the given parameters. 
@@ -59,12 +60,12 @@ def es_query( """ @abstractmethod - def create_index( + def es_mappings_settings( self, - client: Elasticsearch, - index_name: str, - metadata_mapping: Optional[Dict[str, str]], - ) -> None: + text_field: str, + vector_field: str, + num_dimensions: Optional[int], + ) -> Tuple[Dict[str, Any], Dict[str, Any]]: """ Create the required index and do necessary preliminary work, like creating inference pipelines or checking if a required model was deployed. @@ -76,21 +77,25 @@ def create_index( describe the schema of the metadata. """ - def embed_for_indexing(self, text: str) -> Dict[str, Any]: + def before_index_creation( + self, client: Elasticsearch, text_field: str, vector_field: str + ) -> None: """ - If this strategy creates vector embeddings in Python (not in Elasticsearch), - this method is used to apply the inference. - The output is a dictionary with the vector field and the vector embedding. - It is merged in the ElasticserachStore with the rest of the document (text data, - metadata) before indexing. + Executes before the index is created. Used for setting up + any required Elasticsearch resources like a pipeline. Args: - text: Text input that can be used as input for inference. + client: The Elasticsearch client. + text_field: The field containing the text data in the index. + vector_field: The field containing the vector representations in the index. + """ + pass - Returns: - Dict: field and value pairs that extend the document to be indexed. 
+ def needs_inference(self) -> bool: + """ + TODO """ - return {} + return False # TODO test when repsective image is released @@ -101,19 +106,19 @@ def __init__( self, model_id: str, text_field: str = "text_field", - inference_field: str = "text_semantic", ): self.model_id = model_id self.text_field = text_field - self.inference_field = inference_field def es_query( self, query: Optional[str], + query_vector: Optional[List[float]], + text_field: str, + vector_field: str, k: int, num_candidates: int, filter: List[Dict[str, Any]] = [], - query_vector: Optional[List[float]] = None, ) -> Dict[str, Any]: if query_vector: raise ValueError( @@ -124,56 +129,53 @@ def es_query( return { "query": { "semantic": { - self.text_field: query, + text_field: query, }, }, "filter": filter, } - def create_index( + def es_mappings_settings( self, - client: Elasticsearch, - index_name: str, - metadata_mapping: Optional[Dict[str, str]], - ) -> None: - if self.model_id: - model_must_be_deployed(client, self.model_id) - - mappings: Dict[str, Any] = { + text_field: str, + vector_field: str, + num_dimensions: Optional[int] = None, + ) -> Tuple[Dict[str, Any], Dict[str, Any]]: + mappings = { "properties": { - self.inference_field: { + vector_field: { "type": "semantic_text", "model_id": self.model_id, } } } - if metadata_mapping: - mappings["properties"]["metadata"] = {"properties": metadata_mapping} - client.indices.create(index=index_name, mappings=mappings) + return mappings, {} + + def before_index_creation( + self, client: Elasticsearch, text_field: str, vector_field: str + ) -> None: + if self.model_id: + model_must_be_deployed(client, self.model_id) class SparseVector(RetrievalStrategy): """Sparse retrieval strategy using the `text_expansion` processor.""" - def __init__( - self, - model_id: str = ".elser_model_2", - text_field: str = "text_field", - vector_field: str = "vector_field", - ): + def __init__(self, model_id: str = ".elser_model_2"): self.model_id = model_id - 
self.text_field = text_field - self.vector_field = vector_field self._tokens_field = "tokens" + self._pipeline_name = f"{self.model_id}_sparse_embedding" def es_query( self, query: Optional[str], + query_vector: Optional[List[float]], + text_field: str, + vector_field: str, k: int, num_candidates: int, filter: List[Dict[str, Any]] = [], - query_vector: Optional[List[float]] = None, ) -> Dict[str, Any]: if query_vector: raise ValueError( @@ -189,7 +191,7 @@ def es_query( "must": [ { "text_expansion": { - f"{self.vector_field}.{self._tokens_field}": { + f"{vector_field}.{self._tokens_field}": { "model_id": self.model_id, "model_text": query, } @@ -202,27 +204,39 @@ def es_query( "size": k, } - def create_index( + def es_mappings_settings( self, - client: Elasticsearch, - index_name: str, - metadata_mapping: Optional[Dict[str, str]], - ) -> None: - pipeline_name = f"{self.model_id}_sparse_embedding" + text_field: str, + vector_field: str, + num_dimensions: Optional[int], + ) -> Tuple[Dict[str, Any], Dict[str, Any]]: + mappings: Dict[str, Any] = { + "properties": { + vector_field: { + "properties": {self._tokens_field: {"type": "rank_features"}} + } + } + } + settings = {"default_pipeline": self._pipeline_name} + + return mappings, settings + def before_index_creation( + self, client: Elasticsearch, text_field: str, vector_field: str + ) -> None: if self.model_id: model_must_be_deployed(client, self.model_id) # Create a pipeline for the model client.ingest.put_pipeline( - id=pipeline_name, + id=self._pipeline_name, description="Embedding pipeline for Python VectorStore", processors=[ { "inference": { "model_id": self.model_id, - "target_field": self.vector_field, - "field_map": {self.text_field: "text_field"}, + "target_field": vector_field, + "field_map": {text_field: "text_field"}, "inference_config": { "text_expansion": {"results_field": self._tokens_field} }, @@ -231,21 +245,6 @@ def create_index( ], ) - mappings: Dict[str, Any] = { - "properties": { - 
self.vector_field: { - "properties": {self._tokens_field: {"type": "rank_features"}} - } - } - } - if metadata_mapping: - mappings["properties"]["metadata"] = {"properties": metadata_mapping} - settings = {"default_pipeline": pipeline_name} - - client.indices.create(index=index_name, mappings=mappings, settings=settings) - - return None - class DenseVector(RetrievalStrategy): """K-nearest-neighbors retrieval.""" @@ -253,32 +252,20 @@ class DenseVector(RetrievalStrategy): def __init__( self, knn_type: Literal["hnsw", "int8_hnsw", "flat", "int8_flat"] = "hnsw", - vector_field: str = "vector_field", distance: DistanceMetric = DistanceMetric.COSINE, - embedding_service: Optional[EmbeddingService] = None, model_id: Optional[str] = None, - num_dimensions: Optional[int] = None, hybrid: bool = False, rrf: Union[bool, Dict[str, Any]] = True, text_field: Optional[str] = "text_field", ): - if embedding_service and model_id: - raise ValueError("either specify embedding_service or model_id, not both") - if model_id and not num_dimensions: - raise ValueError( - "if model_id is specified, num_dimensions must also be specified" - ) if hybrid and not text_field: raise ValueError( "to enable hybrid you have to specify a text_field (for BM25 matching)" ) self.knn_type = knn_type - self.vector_field = vector_field self.distance = distance - self.embedding_service = embedding_service self.model_id = model_id - self.num_dimensions = num_dimensions self.hybrid = hybrid self.rrf = rrf self.text_field = text_field @@ -286,22 +273,22 @@ def __init__( def es_query( self, query: Optional[str], + query_vector: Optional[List[float]], + text_field: str, + vector_field: str, k: int, num_candidates: int, filter: List[Dict[str, Any]] = [], - query_vector: Optional[List[float]] = None, ) -> Dict[str, Any]: knn = { "filter": filter, - "field": self.vector_field, + "field": vector_field, "k": k, "num_candidates": num_candidates, } if query_vector: knn["query_vector"] = query_vector - elif 
self.embedding_service: - knn["query_vector"] = self.embedding_service.embed_query(cast(str, query)) else: # Inference in Elasticsearch. When initializing we make sure to always have # a model_id if don't have an embedding_service. @@ -317,20 +304,12 @@ def es_query( return {"knn": knn} - def create_index( + def es_mappings_settings( self, - client: Elasticsearch, - index_name: str, - metadata_mapping: Optional[Dict[str, str]], - ) -> None: - if self.embedding_service and not self.num_dimensions: - self.num_dimensions = len( - self.embedding_service.embed_query("get number of dimensions") - ) - - if self.model_id: - model_must_be_deployed(client, self.model_id) - + text_field: str, + vector_field: str, + num_dimensions: Optional[int], + ) -> Tuple[Dict[str, Any], Dict[str, Any]]: if self.distance is DistanceMetric.COSINE: similarityAlgo = "cosine" elif self.distance is DistanceMetric.EUCLIDEAN_DISTANCE: @@ -344,25 +323,22 @@ def create_index( mappings: Dict[str, Any] = { "properties": { - self.vector_field: { + vector_field: { "type": "dense_vector", - "dims": self.num_dimensions, + "dims": num_dimensions, "index": True, "similarity": similarityAlgo, }, } } - if metadata_mapping: - mappings["properties"]["metadata"] = {"properties": metadata_mapping} - r = client.indices.create(index=index_name, mappings=mappings) - print(r) + return mappings, {} - def embed_for_indexing(self, text: str) -> Dict[str, Any]: - if self.embedding_service: - vector = self.embedding_service.embed_query(text) - return {self.vector_field: vector} - return {} + def before_index_creation( + self, client: Elasticsearch, text_field: str, vector_field: str + ) -> None: + if self.model_id: + model_must_be_deployed(client, self.model_id) def _hybrid( self, query: str, knn: Dict[str, Any], filter: List[Dict[str, Any]] @@ -389,53 +365,50 @@ def _hybrid( }, } - if isinstance(self.rrf, Dict[str, Any]): + if isinstance(self.rrf, Dict): query_body["rank"] = {"rrf": self.rrf} elif isinstance(self.rrf, 
bool) and self.rrf is True: query_body["rank"] = {"rrf": {}} return query_body + def needs_inference(self) -> bool: + return not self.model_id + class DenseVectorScriptScore(RetrievalStrategy): """Exact nearest neighbors retrieval using the `script_score` query.""" - def __init__( - self, - embedding_service: EmbeddingService, - vector_field: str = "vector_field", - distance: DistanceMetric = DistanceMetric.COSINE, - num_dimensions: Optional[int] = None, - ) -> None: - self.vector_field = vector_field + def __init__(self, distance: DistanceMetric = DistanceMetric.COSINE) -> None: self.distance = distance - self.embedding_service = embedding_service - self.num_dimensions = num_dimensions def es_query( self, query: Optional[str], + query_vector: Optional[List[float]], + text_field: str, + vector_field: str, k: int, num_candidates: int, filter: List[Dict[str, Any]] = [], - query_vector: Optional[List[float]] = None, ) -> Dict[str, Any]: + if not query_vector: + raise ValueError("specify a query_vector") + if self.distance is DistanceMetric.COSINE: similarityAlgo = ( - f"cosineSimilarity(params.query_vector, '{self.vector_field}') + 1.0" + f"cosineSimilarity(params.query_vector, '{vector_field}') + 1.0" ) elif self.distance is DistanceMetric.EUCLIDEAN_DISTANCE: - similarityAlgo = ( - f"1 / (1 + l2norm(params.query_vector, '{self.vector_field}'))" - ) + similarityAlgo = f"1 / (1 + l2norm(params.query_vector, '{vector_field}'))" elif self.distance is DistanceMetric.DOT_PRODUCT: similarityAlgo = f""" - double value = dotProduct(params.query_vector, '{self.vector_field}'); + double value = dotProduct(params.query_vector, '{vector_field}'); return sigmoid(1, Math.E, -value); """ elif self.distance is DistanceMetric.MAX_INNER_PRODUCT: similarityAlgo = f""" - double value = dotProduct(params.query_vector, '{self.vector_field}'); + double value = dotProduct(params.query_vector, '{vector_field}'); if (dotProduct < 0) {{ return 1 / (1 + -1 * dotProduct); }} @@ -448,16 +421,6 @@ 
def es_query( if filter: queryBool = {"bool": {"filter": filter}} - if not query_vector: - if not self.embedding_service: - raise ValueError( - "if not embedding_service is given, you need to " - "procive a query_vector" - ) - if not query: - raise ValueError("either specify a query string or a query_vector") - query_vector = self.embedding_service.embed_query(query) - return { "query": { "script_score": { @@ -470,55 +433,46 @@ def es_query( } } - def create_index( + def es_mappings_settings( self, - client: Elasticsearch, - index_name: str, - metadata_mapping: Optional[Dict[str, str]], - ) -> None: - if not self.num_dimensions: - self.num_dimensions = len( - self.embedding_service.embed_query("get number of dimensions") - ) - + text_field: str, + vector_field: str, + num_dimensions: Optional[int], + ) -> Tuple[Dict[str, Any], Dict[str, Any]]: mappings = { "properties": { - self.vector_field: { + vector_field: { "type": "dense_vector", - "dims": self.num_dimensions, + "dims": num_dimensions, "index": False, } } } - if metadata_mapping: - mappings["properties"]["metadata"] = {"properties": metadata_mapping} - - client.indices.create(index=index_name, mappings=mappings) - return None + return mappings, {} - def embed_for_indexing(self, text: str) -> Dict[str, Any]: - return {self.vector_field: self.embedding_service.embed_query(text)} + def needs_inference(self) -> bool: + return True class BM25(RetrievalStrategy): def __init__( self, - text_field: str = "text_field", k1: Optional[float] = None, b: Optional[float] = None, ): - self.text_field = text_field self.k1 = k1 self.b = b def es_query( self, query: Optional[str], + query_vector: Optional[List[float]], + text_field: str, + vector_field: str, k: int, num_candidates: int, filter: List[Dict[str, Any]] = [], - query_vector: Optional[List[float]] = None, ) -> Dict[str, Any]: return { "query": { @@ -526,7 +480,7 @@ def es_query( "must": [ { "match": { - self.text_field: { + text_field: { "query": query, } }, @@ 
-537,24 +491,22 @@ def es_query( }, } - def create_index( + def es_mappings_settings( self, - client: Elasticsearch, - index_name: str, - metadata_mapping: Optional[Dict[str, str]], - ) -> None: + text_field: str, + vector_field: str, + num_dimensions: Optional[int], + ) -> Tuple[Dict[str, Any], Dict[str, Any]]: similarity_name = "custom_bm25" mappings: Dict[str, Any] = { "properties": { - self.text_field: { + text_field: { "type": "text", "similarity": similarity_name, }, }, } - if metadata_mapping: - mappings["properties"]["metadata"] = {"properties": metadata_mapping} bm25: Dict[str, Any] = { "type": "BM25", @@ -569,4 +521,4 @@ def create_index( } } - client.indices.create(index=index_name, mappings=mappings, settings=settings) + return mappings, settings diff --git a/elasticsearch/vectorstore/_sync/vectorestore.py b/elasticsearch/vectorstore/_sync/vectorestore.py index fd465111f..862932eef 100644 --- a/elasticsearch/vectorstore/_sync/vectorestore.py +++ b/elasticsearch/vectorstore/_sync/vectorestore.py @@ -44,9 +44,11 @@ def __init__( user_agent: str, index_name: str, retrieval_strategy: RetrievalStrategy, + embedding_service: Optional[EmbeddingService] = None, + num_dimensions: Optional[int] = None, text_field: str = "text_field", vector_field: str = "vector_field", - metadata_mapping: Optional[Dict[str, str]] = None, + metadata_mappings: Optional[Dict[str, Any]] = None, ) -> None: """ Args: @@ -61,7 +63,6 @@ def __init__( es_client: Elasticsearch client connection. Alternatively specify the Elasticsearch connection with the other es_* parameters. """ - # Add integration-specific usage header for tracking usage in Elastic Cloud. # client.options preserces existing (non-user-agent) headers. 
es_client = es_client.options(headers={"User-Agent": user_agent}) @@ -74,9 +75,11 @@ def __init__( self.es_client = es_client self.index_name = index_name self.retrieval_strategy = retrieval_strategy + self.embedding_service = embedding_service + self.num_dimensions = num_dimensions self.text_field = text_field self.vector_field = vector_field - self.metadata_mapping = metadata_mapping + self.metadata_mappings = metadata_mappings def close(self) -> None: return self.es_client.close() @@ -118,6 +121,9 @@ def add_texts( if create_index_if_not_exists: self._create_index_if_not_exists() + if self.embedding_service and not vectors: + vectors = self.embedding_service.embed_documents(texts) + for i, text in enumerate(texts): metadata = metadatas[i] if metadatas else {} @@ -132,7 +138,6 @@ def add_texts( if vectors: request[self.vector_field] = vectors[i] - request.update(self.retrieval_strategy.embed_for_indexing(text)) requests.append(request) if len(requests) > 0: @@ -156,7 +161,7 @@ def add_texts( logger.debug("No texts to add to index") return [] - def delete( + def delete( # type: ignore[no-untyped-def] self, ids: Optional[List[str]] = None, query: Optional[Dict[str, Any]] = None, @@ -240,12 +245,19 @@ def search( if self.text_field not in fields: fields.append(self.text_field) + if self.embedding_service and not query_vector: + if not query: + raise ValueError("specify a query or a query_vector to search") + query_vector = self.embedding_service.embed_query(query) + query_body = self.retrieval_strategy.es_query( query=query, + query_vector=query_vector, + text_field=self.text_field, + vector_field=self.vector_field, k=k, num_candidates=num_candidates, filter=filter or [], - query_vector=query_vector, ) if custom_query is not None: @@ -259,19 +271,47 @@ def search( source=True, source_includes=fields, ) + hits: List[Dict[str, Any]] = response["hits"]["hits"] - return response["hits"]["hits"] + return hits def _create_index_if_not_exists(self) -> None: exists = 
self.es_client.indices.exists(index=self.index_name) if exists.meta.status == 200: logger.debug(f"Index {self.index_name} already exists. Skipping creation.") - else: - self.retrieval_strategy.create_index( - client=self.es_client, - index_name=self.index_name, - metadata_mapping=self.metadata_mapping, - ) + return + + if self.retrieval_strategy.needs_inference(): + if not self.num_dimensions and not self.embedding_service: + raise ValueError( + "retrieval strategy requires embeddings; either embedding_service " + "or num_dimensions need to be specified" + ) + if not self.num_dimensions and self.embedding_service: + vector = self.embedding_service.embed_query("get num dimensions") + self.num_dimensions = len(vector) + + mappings, settings = self.retrieval_strategy.es_mappings_settings( + text_field=self.text_field, + vector_field=self.vector_field, + num_dimensions=self.num_dimensions, + ) + + if self.metadata_mappings: + metadata = mappings["properties"].get("metadata", {"properties": {}}) + for key in self.metadata_mappings.keys(): + if key in metadata: + raise ValueError(f"metadata key {key} already exists in mappings") + + metadata = dict(**metadata["properties"], **self.metadata_mappings) + mappings["properties"] = {"metadata": {"properties": metadata}} + + self.retrieval_strategy.before_index_creation( + self.es_client, self.text_field, self.vector_field + ) + self.es_client.indices.create( + index=self.index_name, mappings=mappings, settings=settings + ) def max_marginal_relevance_search( self, diff --git a/test_elasticsearch/test_server/test_vectorstore/_async/test_vectorestore.py b/test_elasticsearch/test_server/test_vectorstore/_async/test_vectorestore.py index 29779c9b1..4f5736e78 100644 --- a/test_elasticsearch/test_server/test_vectorstore/_async/test_vectorestore.py +++ b/test_elasticsearch/test_server/test_vectorstore/_async/test_vectorestore.py @@ -935,3 +935,32 @@ async def test_max_marginal_relevance_search( num_candidates=2, ) assert 
len(mmr_output) == 2 + + @pytest.mark.asyncio + async def test_metadata_mapping( + self, es_client: AsyncElasticsearch, index_name: str + ) -> None: + """Test that the metadata mapping is applied.""" + test_mappings = { + "my_field": {"type": "keyword"}, + "another_field": {"type": "text"}, + } + store = AsyncVectorStore( + user_agent="test", + index_name=index_name, + retrieval_strategy=DenseVector(), + embedding_service=AsyncFakeEmbeddings(), + es_client=es_client, + metadata_mappings=test_mappings, + ) + + texts = ["foo", "foo", "foo"] + metadatas = [{"page": i} for i in range(len(texts))] + await store.add_texts(texts=texts, metadatas=metadatas) + + mapping_response = await es_client.indices.get_mapping(index=index_name) + mapping_properties = mapping_response[index_name]["mappings"]["properties"] + print(mapping_response) + assert "metadata" in mapping_properties + for key, val in test_mappings.items(): + assert mapping_properties["metadata"]["properties"][key] == val diff --git a/test_elasticsearch/test_server/test_vectorstore/_sync/test_vectorestore.py b/test_elasticsearch/test_server/test_vectorstore/_sync/test_vectorestore.py index 5564f8b6f..0ce312ca7 100644 --- a/test_elasticsearch/test_server/test_vectorstore/_sync/test_vectorestore.py +++ b/test_elasticsearch/test_server/test_vectorstore/_sync/test_vectorestore.py @@ -104,7 +104,8 @@ def assert_query(query_body: dict, query: Optional[str]) -> dict: store = VectorStore( user_agent="test", index_name=index_name, - retrieval_strategy=DenseVector(embedding_service=FakeEmbeddings()), + retrieval_strategy=DenseVector(), + embedding_service=FakeEmbeddings(), es_client=es_client, ) @@ -121,7 +122,8 @@ def test_search_without_metadata_async( store = VectorStore( user_agent="test", index_name=index_name, - retrieval_strategy=DenseVector(embedding_service=FakeEmbeddings()), + retrieval_strategy=DenseVector(), + embedding_service=FakeEmbeddings(), es_client=es_client, ) @@ -149,7 +151,8 @@ def 
test_add_vectors(self, es_client: Elasticsearch, index_name: str) -> None: store = VectorStore( user_agent="test", index_name=index_name, - retrieval_strategy=DenseVector(embedding_service=embeddings), + retrieval_strategy=DenseVector(), + embedding_service=embeddings, es_client=es_client, ) @@ -165,9 +168,8 @@ def test_search_with_metadata( store = VectorStore( user_agent="test", index_name=index_name, - retrieval_strategy=DenseVector( - embedding_service=ConsistentFakeEmbeddings() - ), + retrieval_strategy=DenseVector(), + embedding_service=ConsistentFakeEmbeddings(), es_client=es_client, ) @@ -190,7 +192,8 @@ def test_search_with_filter( store = VectorStore( user_agent="test", index_name=index_name, - retrieval_strategy=DenseVector(embedding_service=FakeEmbeddings()), + retrieval_strategy=DenseVector(), + embedding_service=FakeEmbeddings(), es_client=es_client, ) @@ -226,9 +229,8 @@ def test_search_script_score( store = VectorStore( user_agent="test", index_name=index_name, - retrieval_strategy=DenseVectorScriptScore( - embedding_service=FakeEmbeddings() - ), + retrieval_strategy=DenseVectorScriptScore(), + embedding_service=FakeEmbeddings(), es_client=es_client, ) @@ -274,9 +276,8 @@ def test_search_script_score_with_filter( store = VectorStore( user_agent="test", index_name=index_name, - retrieval_strategy=DenseVectorScriptScore( - embedding_service=FakeEmbeddings() - ), + retrieval_strategy=DenseVectorScriptScore(), + embedding_service=FakeEmbeddings(), es_client=es_client, ) @@ -329,9 +330,9 @@ def test_search_script_score_distance_dot_product( user_agent="test", index_name=index_name, retrieval_strategy=DenseVectorScriptScore( - embedding_service=FakeEmbeddings(), distance=DistanceMetric.DOT_PRODUCT, ), + embedding_service=FakeEmbeddings(), es_client=es_client, ) @@ -378,10 +379,8 @@ def test_search_knn_with_hybrid_search( store = VectorStore( user_agent="test", index_name=index_name, - retrieval_strategy=DenseVector( - embedding_service=FakeEmbeddings(), - 
hybrid=True, - ), + retrieval_strategy=DenseVector(hybrid=True), + embedding_service=FakeEmbeddings(), es_client=es_client, ) @@ -467,11 +466,8 @@ def assert_query( store = VectorStore( user_agent="test", index_name=index_name, - retrieval_strategy=DenseVector( - embedding_service=FakeEmbeddings(), - hybrid=True, - rrf=rrf_test_case, - ), + retrieval_strategy=DenseVector(hybrid=True, rrf=rrf_test_case), + embedding_service=FakeEmbeddings(), es_client=es_client, ) store.add_texts(texts) @@ -511,10 +507,8 @@ def assert_query( store = VectorStore( user_agent="test", index_name=f"{index_name}_default", - retrieval_strategy=DenseVector( - embedding_service=FakeEmbeddings(), - hybrid=True, - ), + retrieval_strategy=DenseVector(hybrid=True), + embedding_service=FakeEmbeddings(), es_client=es_client, ) store.add_texts(texts) @@ -535,7 +529,8 @@ def test_search_knn_with_custom_query_fn( store = VectorStore( user_agent="test", index_name=index_name, - retrieval_strategy=DenseVector(embedding_service=FakeEmbeddings()), + retrieval_strategy=DenseVector(), + embedding_service=FakeEmbeddings(), es_client=es_client, ) @@ -754,7 +749,8 @@ def test_delete(self, es_client: Elasticsearch, index_name: str) -> None: store = VectorStore( user_agent="test", index_name=index_name, - retrieval_strategy=DenseVector(embedding_service=FakeEmbeddings()), + retrieval_strategy=DenseVector(), + embedding_service=FakeEmbeddings(), es_client=es_client, ) @@ -859,9 +855,8 @@ def test_max_marginal_relevance_search( store = VectorStore( user_agent="test", index_name=index_name, - retrieval_strategy=DenseVectorScriptScore( - embedding_service=embedding_service - ), + retrieval_strategy=DenseVectorScriptScore(), + embedding_service=embedding_service, vector_field=vector_field, text_field=text_field, es_client=es_client, @@ -910,3 +905,29 @@ def test_max_marginal_relevance_search( num_candidates=2, ) assert len(mmr_output) == 2 + + def test_metadata_mapping(self, es_client: Elasticsearch, index_name: 
str) -> None: + """Test that the metadata mapping is applied.""" + test_mappings = { + "my_field": {"type": "keyword"}, + "another_field": {"type": "text"}, + } + store = VectorStore( + user_agent="test", + index_name=index_name, + retrieval_strategy=DenseVector(), + embedding_service=FakeEmbeddings(), + es_client=es_client, + metadata_mappings=test_mappings, + ) + + texts = ["foo", "foo", "foo"] + metadatas = [{"page": i} for i in range(len(texts))] + store.add_texts(texts=texts, metadatas=metadatas) + + mapping_response = es_client.indices.get_mapping(index=index_name) + mapping_properties = mapping_response[index_name]["mappings"]["properties"] + print(mapping_response) + assert "metadata" in mapping_properties + for key, val in test_mappings.items(): + assert mapping_properties["metadata"]["properties"][key] == val From 76479614f86033edcd8cfcb4bdc23bdb4b6bde41 Mon Sep 17 00:00:00 2001 From: Max Jakob Date: Mon, 22 Apr 2024 14:39:07 +0200 Subject: [PATCH 11/36] fix typos in file names --- elasticsearch/vectorstore/_async/__init__.py | 2 +- .../vectorstore/_async/strategies.py | 1 - .../{vectorestore.py => vectorstore.py} | 0 elasticsearch/vectorstore/_sync/__init__.py | 2 +- elasticsearch/vectorstore/_sync/strategies.py | 1 - .../vectorstore/_sync/vectorstore.py | 381 +++++++ setup.py | 74 ++ ...st_vectorestore.py => test_vectorstore.py} | 2 +- .../_sync/test_vectorstore.py | 933 ++++++++++++++++++ 9 files changed, 1391 insertions(+), 5 deletions(-) rename elasticsearch/vectorstore/_async/{vectorestore.py => vectorstore.py} (100%) create mode 100644 elasticsearch/vectorstore/_sync/vectorstore.py rename test_elasticsearch/test_server/test_vectorstore/_async/{test_vectorestore.py => test_vectorstore.py} (99%) create mode 100644 test_elasticsearch/test_server/test_vectorstore/_sync/test_vectorstore.py diff --git a/elasticsearch/vectorstore/_async/__init__.py b/elasticsearch/vectorstore/_async/__init__.py index 0d0e0ac9f..a16a14343 100644 --- 
a/elasticsearch/vectorstore/_async/__init__.py +++ b/elasticsearch/vectorstore/_async/__init__.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -from elasticsearch.vectorstore._async.vectorestore import AsyncVectorStore +from elasticsearch.vectorstore._async.vectorstore import AsyncVectorStore __all__ = [ "AsyncVectorStore", diff --git a/elasticsearch/vectorstore/_async/strategies.py b/elasticsearch/vectorstore/_async/strategies.py index 6c5b0faf7..ab0f9c35b 100644 --- a/elasticsearch/vectorstore/_async/strategies.py +++ b/elasticsearch/vectorstore/_async/strategies.py @@ -98,7 +98,6 @@ def needs_inference(self) -> bool: return False -# TODO test when repsective image is released class Semantic(RetrievalStrategy): """Dense or sparse retrieval with in-stack inference using semantic_text fields.""" diff --git a/elasticsearch/vectorstore/_async/vectorestore.py b/elasticsearch/vectorstore/_async/vectorstore.py similarity index 100% rename from elasticsearch/vectorstore/_async/vectorestore.py rename to elasticsearch/vectorstore/_async/vectorstore.py diff --git a/elasticsearch/vectorstore/_sync/__init__.py b/elasticsearch/vectorstore/_sync/__init__.py index 903dc00ed..fadcc6b9f 100644 --- a/elasticsearch/vectorstore/_sync/__init__.py +++ b/elasticsearch/vectorstore/_sync/__init__.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. 
-from elasticsearch.vectorstore._sync.vectorestore import VectorStore +from elasticsearch.vectorstore._sync.vectorstore import VectorStore __all__ = [ "VectorStore", diff --git a/elasticsearch/vectorstore/_sync/strategies.py b/elasticsearch/vectorstore/_sync/strategies.py index f31fc1113..cb0f75e83 100644 --- a/elasticsearch/vectorstore/_sync/strategies.py +++ b/elasticsearch/vectorstore/_sync/strategies.py @@ -98,7 +98,6 @@ def needs_inference(self) -> bool: return False -# TODO test when repsective image is released class Semantic(RetrievalStrategy): """Dense or sparse retrieval with in-stack inference using semantic_text fields.""" diff --git a/elasticsearch/vectorstore/_sync/vectorstore.py b/elasticsearch/vectorstore/_sync/vectorstore.py new file mode 100644 index 000000000..862932eef --- /dev/null +++ b/elasticsearch/vectorstore/_sync/vectorstore.py @@ -0,0 +1,381 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import logging +import uuid +from typing import Any, Callable, Dict, List, Optional + +from elasticsearch import Elasticsearch +from elasticsearch.helpers import BulkIndexError, bulk +from elasticsearch.vectorstore._sync.embedding_service import EmbeddingService +from elasticsearch.vectorstore._sync.strategies import RetrievalStrategy +from elasticsearch.vectorstore._utils import maximal_marginal_relevance + +logger = logging.getLogger(__name__) + + +class VectorStore: + """VectorStore is a higher-level abstraction of indexing and search. + Users can pick from available retrieval strategies. + + Documents are flat text documents. Depending on the strategy, vector embeddings are + - created by the user beforehand + - created by this class in Python + - created in-stack by inference pipelines. + """ + + def __init__( + self, + es_client: Elasticsearch, + user_agent: str, + index_name: str, + retrieval_strategy: RetrievalStrategy, + embedding_service: Optional[EmbeddingService] = None, + num_dimensions: Optional[int] = None, + text_field: str = "text_field", + vector_field: str = "vector_field", + metadata_mappings: Optional[Dict[str, Any]] = None, + ) -> None: + """ + Args: + user_header: user agent header specific to the 3rd party integration. + Used for usage tracking in Elastic Cloud. + index_name: The name of the index to query. + retrieval_strategy: how to index and search the data. See the strategies + module for availble strategies. + text_field: Name of the field with the textual data. + vector_field: For strategies that perform embedding inference in Python, + the embedding vector goes in this field. + es_client: Elasticsearch client connection. Alternatively specify the + Elasticsearch connection with the other es_* parameters. + """ + # Add integration-specific usage header for tracking usage in Elastic Cloud. + # client.options preserces existing (non-user-agent) headers. 
+ es_client = es_client.options(headers={"User-Agent": user_agent}) + + if hasattr(retrieval_strategy, "text_field"): + retrieval_strategy.text_field = text_field + if hasattr(retrieval_strategy, "vector_field"): + retrieval_strategy.vector_field = vector_field + + self.es_client = es_client + self.index_name = index_name + self.retrieval_strategy = retrieval_strategy + self.embedding_service = embedding_service + self.num_dimensions = num_dimensions + self.text_field = text_field + self.vector_field = vector_field + self.metadata_mappings = metadata_mappings + + def close(self) -> None: + return self.es_client.close() + + def add_texts( + self, + texts: List[str], + metadatas: Optional[List[Dict[str, Any]]] = None, + vectors: Optional[List[List[float]]] = None, + ids: Optional[List[str]] = None, + refresh_indices: bool = True, + create_index_if_not_exists: bool = True, + bulk_kwargs: Optional[Dict[str, Any]] = None, + ) -> List[str]: + """Add documents to the Elasticsearch index. + + Args: + texts: List of text documents. + metadata: Optional list of document metadata. Must be of same length as + texts. + vectors: Optional list of embedding vectors. Must be of same length as + texts. + ids: Optional list of ID strings. Must be of same length as texts. + refresh_indices: Whether to refresh the index after deleting documents. + Defaults to True. + create_index_if_not_exists: Whether to create the index if it does not + exist. Defaults to True. + bulk_kwargs: Arguments to pass to the bulk function when indexing + (for example chunk_size). + + Returns: + List of IDs of the created documents, either echoing the provided one + or returning newly created ones. 
+ """ + bulk_kwargs = bulk_kwargs or {} + ids = ids or [str(uuid.uuid4()) for _ in texts] + requests = [] + + if create_index_if_not_exists: + self._create_index_if_not_exists() + + if self.embedding_service and not vectors: + vectors = self.embedding_service.embed_documents(texts) + + for i, text in enumerate(texts): + metadata = metadatas[i] if metadatas else {} + + request: Dict[str, Any] = { + "_op_type": "index", + "_index": self.index_name, + self.text_field: text, + "metadata": metadata, + "_id": ids[i], + } + + if vectors: + request[self.vector_field] = vectors[i] + + requests.append(request) + + if len(requests) > 0: + try: + success, failed = bulk( + self.es_client, + requests, + stats_only=True, + refresh=refresh_indices, + **bulk_kwargs, + ) + logger.debug(f"added texts {ids} to index") + return ids + except BulkIndexError as e: + logger.error(f"Error adding texts: {e}") + firstError = e.errors[0].get("index", {}).get("error", {}) + logger.error(f"First error reason: {firstError.get('reason')}") + raise e + + else: + logger.debug("No texts to add to index") + return [] + + def delete( # type: ignore[no-untyped-def] + self, + ids: Optional[List[str]] = None, + query: Optional[Dict[str, Any]] = None, + refresh_indices: bool = True, + **delete_kwargs, + ) -> bool: + """Delete documents from the Elasticsearch index. + + Args: + ids: List of IDs of documents to delete. + refresh_indices: Whether to refresh the index after deleting documents. + Defaults to True. 
+ """ + if ids is not None and query is not None: + raise ValueError("one of ids or query must be specified") + elif ids is None and query is None: + raise ValueError("either specify ids or query") + + try: + if ids: + body = [ + {"_op_type": "delete", "_index": self.index_name, "_id": _id} + for _id in ids + ] + bulk( + self.es_client, + body, + refresh=refresh_indices, + ignore_status=404, + **delete_kwargs, + ) + logger.debug(f"Deleted {len(body)} texts from index") + + else: + self.es_client.delete_by_query( + index=self.index_name, + query=query, + refresh=refresh_indices, + **delete_kwargs, + ) + + except BulkIndexError as e: + logger.error(f"Error deleting texts: {e}") + firstError = e.errors[0].get("index", {}).get("error", {}) + logger.error(f"First error reason: {firstError.get('reason')}") + raise e + + return True + + def search( + self, + query: Optional[str], + query_vector: Optional[List[float]] = None, + k: int = 4, + num_candidates: int = 50, + fields: Optional[List[str]] = None, + filter: Optional[List[Dict[str, Any]]] = None, + custom_query: Optional[ + Callable[[Dict[str, Any], Optional[str]], Dict[str, Any]] + ] = None, + ) -> List[Dict[str, Any]]: + """ + Args: + query: Input query string. + query_vector: Input embedding vector. If given, input query string is + ignored. + k: Number of returned results. + num_candidates: Number of candidates to fetch from data nodes in knn. + fields: List of field names to return. + filter: Elasticsearch filters to apply. + custom_query: Function to modify the Elasticsearch query body before it is + sent to Elasticsearch. + + Returns: + List of document hits. Includes _index, _id, _score and _source. 
+ """ + if fields is None: + fields = [] + if "metadata" not in fields: + fields.append("metadata") + if self.text_field not in fields: + fields.append(self.text_field) + + if self.embedding_service and not query_vector: + if not query: + raise ValueError("specify a query or a query_vector to search") + query_vector = self.embedding_service.embed_query(query) + + query_body = self.retrieval_strategy.es_query( + query=query, + query_vector=query_vector, + text_field=self.text_field, + vector_field=self.vector_field, + k=k, + num_candidates=num_candidates, + filter=filter or [], + ) + + if custom_query is not None: + query_body = custom_query(query_body, query) + logger.debug(f"Calling custom_query, Query body now: {query_body}") + + response = self.es_client.search( + index=self.index_name, + **query_body, + size=k, + source=True, + source_includes=fields, + ) + hits: List[Dict[str, Any]] = response["hits"]["hits"] + + return hits + + def _create_index_if_not_exists(self) -> None: + exists = self.es_client.indices.exists(index=self.index_name) + if exists.meta.status == 200: + logger.debug(f"Index {self.index_name} already exists. 
Skipping creation.") + return + + if self.retrieval_strategy.needs_inference(): + if not self.num_dimensions and not self.embedding_service: + raise ValueError( + "retrieval strategy requires embeddings; either embedding_service " + "or num_dimensions need to be specified" + ) + if not self.num_dimensions and self.embedding_service: + vector = self.embedding_service.embed_query("get num dimensions") + self.num_dimensions = len(vector) + + mappings, settings = self.retrieval_strategy.es_mappings_settings( + text_field=self.text_field, + vector_field=self.vector_field, + num_dimensions=self.num_dimensions, + ) + + if self.metadata_mappings: + metadata = mappings["properties"].get("metadata", {"properties": {}}) + for key in self.metadata_mappings.keys(): + if key in metadata: + raise ValueError(f"metadata key {key} already exists in mappings") + + metadata = dict(**metadata["properties"], **self.metadata_mappings) + mappings["properties"] = {"metadata": {"properties": metadata}} + + self.retrieval_strategy.before_index_creation( + self.es_client, self.text_field, self.vector_field + ) + self.es_client.indices.create( + index=self.index_name, mappings=mappings, settings=settings + ) + + def max_marginal_relevance_search( + self, + embedding_service: EmbeddingService, + query: str, + vector_field: str, + k: int = 4, + num_candidates: int = 20, + lambda_mult: float = 0.5, + fields: Optional[List[str]] = None, + custom_query: Optional[ + Callable[[Dict[str, Any], Optional[str]], Dict[str, Any]] + ] = None, + ) -> List[Dict[str, Any]]: + """Return docs selected using the maximal marginal relevance. + + Maximal marginal relevance optimizes for similarity to query AND diversity + among selected documents. + + Args: + query (str): Text to look up documents similar to. + k (int): Number of Documents to return. Defaults to 4. + fetch_k (int): Number of Documents to fetch to pass to MMR algorithm. 
+ lambda_mult (float): Number between 0 and 1 that determines the degree + of diversity among the results with 0 corresponding + to maximum diversity and 1 to minimum diversity. + Defaults to 0.5. + fields: Other fields to get from elasticsearch source. These fields + will be added to the document metadata. + + Returns: + List[Document]: A list of Documents selected by maximal marginal relevance. + """ + remove_vector_query_field_from_metadata = True + if fields is None: + fields = [vector_field] + elif vector_field not in fields: + fields.append(vector_field) + else: + remove_vector_query_field_from_metadata = False + + # Embed the query + query_embedding = embedding_service.embed_query(query) + + # Fetch the initial documents + got_hits = self.search( + query=None, + query_vector=query_embedding, + k=num_candidates, + fields=fields, + custom_query=custom_query, + ) + + # Get the embeddings for the fetched documents + got_embeddings = [hit["_source"][vector_field] for hit in got_hits] + + # Select documents using maximal marginal relevance + selected_indices = maximal_marginal_relevance( + query_embedding, got_embeddings, lambda_mult=lambda_mult, k=k + ) + selected_hits = [got_hits[i] for i in selected_indices] + + if remove_vector_query_field_from_metadata: + for hit in selected_hits: + del hit["_source"][vector_field] + + return selected_hits diff --git a/setup.py b/setup.py index dc592dcc4..39ffc7191 100644 --- a/setup.py +++ b/setup.py @@ -94,3 +94,77 @@ "orjson": ["orjson>=3"], }, ) + +vectorstore_package_name = "elasticsearch[vectorstore]" +base_dir = abspath(dirname(__file__)) + +with open(join(base_dir, package_name, "_version.py")) as f: + package_version = re.search( + r"__versionstr__\s+=\s+[\"\']([^\"\']+)[\"\']", f.read() + ).group(1) + +with open(join(base_dir, "README.rst")) as f: + # Remove reST raw directive from README as they're not allowed on PyPI + # Those blocks start with a newline and continue until the next newline + mode = None + lines = 
[] + for line in f: + if line.startswith(".. raw::"): + mode = "ignore_nl" + elif line == "\n": + mode = "wait_nl" if mode == "ignore_nl" else None + if mode is None: + lines.append(line) + + long_description = "".join(lines) + + +packages = [ + package + for package in find_packages(where=".", exclude=("test_elasticsearch*",)) + if package == package_name or package.startswith(package_name + ".") +] + +setup( + name=package_name, + description="Python client for Elasticsearch", + license="Apache-2.0", + url="https://github.com/elastic/elasticsearch-py", + long_description=long_description, + long_description_content_type="text/x-rst", + version=package_version, + author="Elastic Client Library Maintainers", + author_email="client-libs@elastic.co", + project_urls={ + "Documentation": "https://elasticsearch-py.readthedocs.io", + "Source Code": "https://github.com/elastic/elasticsearch-py", + "Issue Tracker": "https://github.com/elastic/elasticsearch-py/issues", + }, + packages=packages, + package_data={"elasticsearch": ["py.typed", "*.pyi"]}, + include_package_data=True, + zip_safe=False, + classifiers=[ + "Development Status :: 5 - Production/Stable", + "License :: OSI Approved :: Apache Software License", + "Intended Audience :: Developers", + "Operating System :: OS Independent", + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: Implementation :: CPython", + "Programming Language :: Python :: Implementation :: PyPy", + ], + python_requires=">=3.7", + install_requires=["elastic-transport>=8.13,<9"], + extras_require={ + "requests": ["requests>=2.4.0, <3.0.0"], + "async": ["aiohttp>=3,<4"], + "orjson": ["orjson>=3"], + }, +) diff --git 
a/test_elasticsearch/test_server/test_vectorstore/_async/test_vectorestore.py b/test_elasticsearch/test_server/test_vectorstore/_async/test_vectorstore.py similarity index 99% rename from test_elasticsearch/test_server/test_vectorstore/_async/test_vectorestore.py rename to test_elasticsearch/test_server/test_vectorstore/_async/test_vectorstore.py index 4f5736e78..ec5e7cb4e 100644 --- a/test_elasticsearch/test_server/test_vectorstore/_async/test_vectorestore.py +++ b/test_elasticsearch/test_server/test_vectorstore/_async/test_vectorstore.py @@ -66,7 +66,7 @@ TRANSFORMER_MODEL_ID = "sentence-transformers__all-minilm-l6-v2" -class TestElasticsearch: +class TestVectorStore: @pytest_asyncio.fixture async def es_client(self) -> AsyncIterator[AsyncElasticsearch]: async for x in es_client_fixture(): diff --git a/test_elasticsearch/test_server/test_vectorstore/_sync/test_vectorstore.py b/test_elasticsearch/test_server/test_vectorstore/_sync/test_vectorstore.py new file mode 100644 index 000000000..60d47bb4f --- /dev/null +++ b/test_elasticsearch/test_server/test_vectorstore/_sync/test_vectorstore.py @@ -0,0 +1,933 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import logging +import uuid +from functools import partial +from typing import Any, Iterator, List, Optional, Union, cast + +import pytest + +from elasticsearch import Elasticsearch, NotFoundError +from elasticsearch.helpers import BulkIndexError +from elasticsearch.vectorstore._sync import VectorStore +from elasticsearch.vectorstore._sync._utils import model_is_deployed +from elasticsearch.vectorstore._sync.strategies import ( + BM25, + DenseVector, + DenseVectorScriptScore, + DistanceMetric, + Semantic, +) + +from ._test_utils import ( + ConsistentFakeEmbeddings, + FakeEmbeddings, + RequestSavingTransport, + create_requests_saving_client, + es_client_fixture, +) + +logging.basicConfig(level=logging.DEBUG) + +""" +docker-compose up elasticsearch + +By default runs against local docker instance of Elasticsearch. +To run against Elastic Cloud, set the following environment variables: +- ES_CLOUD_ID +- ES_API_KEY + +Some of the tests require the following models to be deployed in the ML Node: +- elser (can be downloaded and deployed through Kibana and trained models UI) +- sentence-transformers__all-minilm-l6-v2 (can be deployed through the API, + loaded via eland) + +These tests that require the models to be deployed are skipped by default. +Enable them by adding the model name to the modelsDeployed list below. 
+""" + +ELSER_MODEL_ID = ".elser_model_2" +TRANSFORMER_MODEL_ID = "sentence-transformers__all-minilm-l6-v2" + + +class TestVectorStore: + @pytest.fixture + def es_client(self) -> Iterator[Elasticsearch]: + for x in es_client_fixture(): + yield x + + @pytest.fixture + def requests_saving_client(self) -> Iterator[Elasticsearch]: + client = create_requests_saving_client() + try: + yield client + finally: + client.close() + + @pytest.fixture(scope="function") + def index_name(self) -> str: + """Return the index name.""" + return f"test_{uuid.uuid4().hex}" + + def test_search_without_metadata( + self, es_client: Elasticsearch, index_name: str + ) -> None: + """Test end to end construction and search without metadata.""" + + def assert_query(query_body: dict, query: Optional[str]) -> dict: + assert query_body == { + "knn": { + "field": "vector_field", + "filter": [], + "k": 1, + "num_candidates": 50, + "query_vector": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + } + } + return query_body + + store = VectorStore( + user_agent="test", + index_name=index_name, + retrieval_strategy=DenseVector(), + embedding_service=FakeEmbeddings(), + es_client=es_client, + ) + + texts = ["foo", "bar", "baz"] + store.add_texts(texts) + + output = store.search("foo", k=1, custom_query=assert_query) + assert [doc["_source"]["text_field"] for doc in output] == ["foo"] + + def test_search_without_metadata_async( + self, es_client: Elasticsearch, index_name: str + ) -> None: + """Test end to end construction and search without metadata.""" + store = VectorStore( + user_agent="test", + index_name=index_name, + retrieval_strategy=DenseVector(), + embedding_service=FakeEmbeddings(), + es_client=es_client, + ) + + texts = ["foo", "bar", "baz"] + store.add_texts(texts) + + output = store.search("foo", k=1) + assert [doc["_source"]["text_field"] for doc in output] == ["foo"] + + def test_add_vectors(self, es_client: Elasticsearch, index_name: str) -> None: + """ + Test adding pre-built 
embeddings instead of using inference for the texts. + This allows you to separate the embeddings text and the page_content + for better proximity between user's question and embedded text. + For example, your embedding text can be a question, whereas page_content + is the answer. + """ + embeddings = ConsistentFakeEmbeddings() + texts = ["foo1", "foo2", "foo3"] + metadatas = [{"page": i} for i in range(len(texts))] + + """In real use case, embedding_input can be questions for each text""" + embedding_vectors = embeddings.embed_documents(texts) + + store = VectorStore( + user_agent="test", + index_name=index_name, + retrieval_strategy=DenseVector(), + embedding_service=embeddings, + es_client=es_client, + ) + + store.add_texts(texts=texts, vectors=embedding_vectors, metadatas=metadatas) + output = store.search("foo1", k=1) + assert [doc["_source"]["text_field"] for doc in output] == ["foo1"] + assert [doc["_source"]["metadata"]["page"] for doc in output] == [0] + + def test_search_with_metadata( + self, es_client: Elasticsearch, index_name: str + ) -> None: + """Test end to end construction and search with metadata.""" + store = VectorStore( + user_agent="test", + index_name=index_name, + retrieval_strategy=DenseVector(), + embedding_service=ConsistentFakeEmbeddings(), + es_client=es_client, + ) + + texts = ["foo", "bar", "baz"] + metadatas = [{"page": i} for i in range(len(texts))] + store.add_texts(texts=texts, metadatas=metadatas) + + output = store.search("foo", k=1) + assert [doc["_source"]["text_field"] for doc in output] == ["foo"] + assert [doc["_source"]["metadata"]["page"] for doc in output] == [0] + + output = store.search("bar", k=1) + assert [doc["_source"]["text_field"] for doc in output] == ["bar"] + assert [doc["_source"]["metadata"]["page"] for doc in output] == [1] + + def test_search_with_filter( + self, es_client: Elasticsearch, index_name: str + ) -> None: + """Test end to end construction and search with metadata.""" + store = VectorStore( + 
user_agent="test", + index_name=index_name, + retrieval_strategy=DenseVector(), + embedding_service=FakeEmbeddings(), + es_client=es_client, + ) + + texts = ["foo", "foo", "foo"] + metadatas = [{"page": i} for i in range(len(texts))] + store.add_texts(texts=texts, metadatas=metadatas) + + def assert_query(query_body: dict, query: Optional[str]) -> dict: + assert query_body == { + "knn": { + "field": "vector_field", + "filter": [{"term": {"metadata.page": "1"}}], + "k": 3, + "num_candidates": 50, + "query_vector": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + } + } + return query_body + + output = store.search( + query="foo", + k=3, + filter=[{"term": {"metadata.page": "1"}}], + custom_query=assert_query, + ) + assert [doc["_source"]["text_field"] for doc in output] == ["foo"] + assert [doc["_source"]["metadata"]["page"] for doc in output] == [1] + + def test_search_script_score( + self, es_client: Elasticsearch, index_name: str + ) -> None: + """Test end to end construction and search with metadata.""" + store = VectorStore( + user_agent="test", + index_name=index_name, + retrieval_strategy=DenseVectorScriptScore(), + embedding_service=FakeEmbeddings(), + es_client=es_client, + ) + + texts = ["foo", "bar", "baz"] + store.add_texts(texts) + + expected_query = { + "query": { + "script_score": { + "query": {"match_all": {}}, + "script": { + "source": "cosineSimilarity(params.query_vector, 'vector_field') + 1.0", # noqa: E501 + "params": { + "query_vector": [ + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 0.0, + ] + }, + }, + } + } + } + + def assert_query(query_body: dict, query: Optional[str]) -> dict: + assert query_body == expected_query + return query_body + + output = store.search("foo", k=1, custom_query=assert_query) + assert [doc["_source"]["text_field"] for doc in output] == ["foo"] + + def test_search_script_score_with_filter( + self, es_client: Elasticsearch, index_name: str + ) -> None: + """Test end to end construction and search 
with metadata.""" + store = VectorStore( + user_agent="test", + index_name=index_name, + retrieval_strategy=DenseVectorScriptScore(), + embedding_service=FakeEmbeddings(), + es_client=es_client, + ) + + texts = ["foo", "bar", "baz"] + metadatas = [{"page": i} for i in range(len(texts))] + store.add_texts(texts=texts, metadatas=metadatas) + + def assert_query(query_body: dict, query: Optional[str]) -> dict: + expected_query = { + "query": { + "script_score": { + "query": {"bool": {"filter": [{"term": {"metadata.page": 0}}]}}, + "script": { + "source": "cosineSimilarity(params.query_vector, 'vector_field') + 1.0", # noqa: E501 + "params": { + "query_vector": [ + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 0.0, + ] + }, + }, + } + } + } + assert query_body == expected_query + return query_body + + output = store.search( + "foo", + k=1, + custom_query=assert_query, + filter=[{"term": {"metadata.page": 0}}], + ) + assert [doc["_source"]["text_field"] for doc in output] == ["foo"] + assert [doc["_source"]["metadata"]["page"] for doc in output] == [0] + + def test_search_script_score_distance_dot_product( + self, es_client: Elasticsearch, index_name: str + ) -> None: + """Test end to end construction and search with metadata.""" + store = VectorStore( + user_agent="test", + index_name=index_name, + retrieval_strategy=DenseVectorScriptScore( + distance=DistanceMetric.DOT_PRODUCT, + ), + embedding_service=FakeEmbeddings(), + es_client=es_client, + ) + + texts = ["foo", "bar", "baz"] + store.add_texts(texts) + + def assert_query(query_body: dict, query: Optional[str]) -> dict: + assert query_body == { + "query": { + "script_score": { + "query": {"match_all": {}}, + "script": { + "source": """ + double value = dotProduct(params.query_vector, 'vector_field'); + return sigmoid(1, Math.E, -value); + """, + "params": { + "query_vector": [ + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 0.0, + ] + }, + }, + } + } + } + return query_body + + 
output = store.search("foo", k=1, custom_query=assert_query) + assert [doc["_source"]["text_field"] for doc in output] == ["foo"] + + def test_search_knn_with_hybrid_search( + self, es_client: Elasticsearch, index_name: str + ) -> None: + """Test end to end construction and search with metadata.""" + store = VectorStore( + user_agent="test", + index_name=index_name, + retrieval_strategy=DenseVector(hybrid=True), + embedding_service=FakeEmbeddings(), + es_client=es_client, + ) + + texts = ["foo", "bar", "baz"] + store.add_texts(texts) + + def assert_query(query_body: dict, query: Optional[str]) -> dict: + assert query_body == { + "knn": { + "field": "vector_field", + "filter": [], + "k": 1, + "num_candidates": 50, + "query_vector": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + }, + "query": { + "bool": { + "filter": [], + "must": [{"match": {"text_field": {"query": "foo"}}}], + } + }, + "rank": {"rrf": {}}, + } + return query_body + + output = store.search("foo", k=1, custom_query=assert_query) + assert [doc["_source"]["text_field"] for doc in output] == ["foo"] + + def test_search_knn_with_hybrid_search_rrf( + self, es_client: Elasticsearch, index_name: str + ) -> None: + """Test end to end construction and rrf hybrid search with metadata.""" + texts = ["foo", "bar", "baz"] + + def assert_query( + query_body: dict, + query: Optional[str], + expected_rrf: Union[dict, bool], + ) -> dict: + cmp_query_body = { + "knn": { + "field": "vector_field", + "filter": [], + "k": 3, + "num_candidates": 50, + "query_vector": [ + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 0.0, + ], + }, + "query": { + "bool": { + "filter": [], + "must": [{"match": {"text_field": {"query": "foo"}}}], + } + }, + } + + if isinstance(expected_rrf, dict): + cmp_query_body["rank"] = {"rrf": expected_rrf} + elif isinstance(expected_rrf, bool) and expected_rrf is True: + cmp_query_body["rank"] = {"rrf": {}} + + assert query_body == cmp_query_body + + return query_body + + # 1. 
check query_body is okay + rrf_test_cases: List[Union[dict, bool]] = [ + True, + False, + {"rank_constant": 1, "window_size": 5}, + ] + for rrf_test_case in rrf_test_cases: + store = VectorStore( + user_agent="test", + index_name=index_name, + retrieval_strategy=DenseVector(hybrid=True, rrf=rrf_test_case), + embedding_service=FakeEmbeddings(), + es_client=es_client, + ) + store.add_texts(texts) + + ## without fetch_k parameter + output = store.search( + "foo", + k=3, + custom_query=partial(assert_query, expected_rrf=rrf_test_case), + ) + + # 2. check query result is okay + es_output = store.es_client.search( + index=index_name, + query={ + "bool": { + "filter": [], + "must": [{"match": {"text_field": {"query": "foo"}}}], + } + }, + knn={ + "field": "vector_field", + "filter": [], + "k": 3, + "num_candidates": 50, + "query_vector": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + }, + size=3, + rank={"rrf": {"rank_constant": 1, "window_size": 5}}, + ) + + assert [o["_source"]["text_field"] for o in output] == [ + e["_source"]["text_field"] for e in es_output["hits"]["hits"] + ] + + # 3. 
check rrf default option is okay + store = VectorStore( + user_agent="test", + index_name=f"{index_name}_default", + retrieval_strategy=DenseVector(hybrid=True), + embedding_service=FakeEmbeddings(), + es_client=es_client, + ) + store.add_texts(texts) + + ## with fetch_k parameter + output = store.search( + "foo", + k=3, + num_candidates=50, + custom_query=partial(assert_query, expected_rrf={}), + ) + + def test_search_knn_with_custom_query_fn( + self, es_client: Elasticsearch, index_name: str + ) -> None: + """test that custom query function is called + with the query string and query body""" + store = VectorStore( + user_agent="test", + index_name=index_name, + retrieval_strategy=DenseVector(), + embedding_service=FakeEmbeddings(), + es_client=es_client, + ) + + def my_custom_query(query_body: dict, query: Optional[str]) -> dict: + assert query == "foo" + assert query_body == { + "knn": { + "field": "vector_field", + "filter": [], + "k": 1, + "num_candidates": 50, + "query_vector": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + } + } + return {"query": {"match": {"text_field": {"query": "bar"}}}} + + """Test end to end construction and search with metadata.""" + texts = ["foo", "bar", "baz"] + store.add_texts(texts) + + output = store.search("foo", k=1, custom_query=my_custom_query) + assert [doc["_source"]["text_field"] for doc in output] == ["bar"] + + def test_search_with_knn_infer_instack( + self, es_client: Elasticsearch, index_name: str + ) -> None: + """test end to end with knn retrieval strategy and inference in-stack""" + + if not model_is_deployed(es_client, TRANSFORMER_MODEL_ID): + pytest.skip( + f"{TRANSFORMER_MODEL_ID} model not deployed in ML Node skipping test" + ) + + text_field = "text_field" + + store = VectorStore( + user_agent="test", + index_name=index_name, + retrieval_strategy=Semantic( + model_id="sentence-transformers__all-minilm-l6-v2", + text_field=text_field, + ), + es_client=es_client, + ) + + # setting up the pipeline for 
inference + store.es_client.ingest.put_pipeline( + id="test_pipeline", + processors=[ + { + "inference": { + "model_id": TRANSFORMER_MODEL_ID, + "field_map": {"query_field": text_field}, + "target_field": "vector_query_field", + } + } + ], + ) + + # creating a new index with the pipeline, + # not relying on langchain to create the index + store.es_client.indices.create( + index=index_name, + mappings={ + "properties": { + text_field: {"type": "text_field"}, + "vector_query_field": { + "properties": { + "predicted_value": { + "type": "dense_vector", + "dims": 384, + "index": True, + "similarity": "l2_norm", + } + } + }, + } + }, + settings={"index": {"default_pipeline": "test_pipeline"}}, + ) + + # adding documents to the index + texts = ["foo", "bar", "baz"] + + for i, text in enumerate(texts): + store.es_client.create( + index=index_name, + id=str(i), + document={text_field: text, "metadata": {}}, + ) + + store.es_client.indices.refresh(index=index_name) + + def assert_query(query_body: dict, query: Optional[str]) -> dict: + assert query_body == { + "knn": { + "filter": [], + "field": "vector_query_field.predicted_value", + "k": 1, + "num_candidates": 50, + "query_vector_builder": { + "text_embedding": { + "model_id": TRANSFORMER_MODEL_ID, + "model_text": "foo", + } + }, + } + } + return query_body + + output = store.search("foo", k=1, custom_query=assert_query) + assert [doc["_source"]["text_field"] for doc in output] == ["foo"] + + output = store.search("bar", k=1) + assert [doc["_source"]["text_field"] for doc in output] == ["bar"] + + def test_search_with_sparse_infer_instack( + self, es_client: Elasticsearch, index_name: str + ) -> None: + """test end to end with sparse retrieval strategy and inference in-stack""" + + if not model_is_deployed(es_client, ELSER_MODEL_ID): + reason = f"{ELSER_MODEL_ID} model not deployed in ML Node, skipping test" + + pytest.skip(reason) + + store = VectorStore( + user_agent="test", + index_name=index_name, + 
retrieval_strategy=Semantic(model_id=ELSER_MODEL_ID), + es_client=es_client, + ) + + texts = ["foo", "bar", "baz"] + store.add_texts(texts) + + output = store.search("foo", k=1) + assert [doc["_source"]["text_field"] for doc in output] == ["foo"] + + def test_deployed_model_check_fails_semantic( + self, es_client: Elasticsearch, index_name: str + ) -> None: + """test that exceptions are raised if a specified model is not deployed""" + with pytest.raises(NotFoundError): + store = VectorStore( + user_agent="test", + index_name=index_name, + retrieval_strategy=Semantic(model_id="non-existing model ID"), + es_client=es_client, + ) + store.add_texts(["foo", "bar", "baz"]) + + def test_search_bm25(self, es_client: Elasticsearch, index_name: str) -> None: + """Test end to end using the BM25 retrieval strategy.""" + store = VectorStore( + user_agent="test", + index_name=index_name, + retrieval_strategy=BM25(), + es_client=es_client, + ) + + texts = ["foo", "bar", "baz"] + store.add_texts(texts) + + def assert_query(query_body: dict, query: Optional[str]) -> dict: + assert query_body == { + "query": { + "bool": { + "must": [{"match": {"text_field": {"query": "foo"}}}], + "filter": [], + } + } + } + return query_body + + output = store.search("foo", k=1, custom_query=assert_query) + assert [doc["_source"]["text_field"] for doc in output] == ["foo"] + + def test_search_bm25_with_filter( + self, es_client: Elasticsearch, index_name: str + ) -> None: + """Test end to using the BM25 retrieval strategy with metadata.""" + store = VectorStore( + user_agent="test", + index_name=index_name, + retrieval_strategy=BM25(), + es_client=es_client, + ) + + texts = ["foo", "foo", "foo"] + metadatas = [{"page": i} for i in range(len(texts))] + store.add_texts(texts=texts, metadatas=metadatas) + + def assert_query(query_body: dict, query: Optional[str]) -> dict: + assert query_body == { + "query": { + "bool": { + "must": [{"match": {"text_field": {"query": "foo"}}}], + "filter": [{"term": 
{"metadata.page": 1}}], + } + } + } + return query_body + + output = store.search( + "foo", + k=3, + custom_query=assert_query, + filter=[{"term": {"metadata.page": 1}}], + ) + assert [doc["_source"]["text_field"] for doc in output] == ["foo"] + assert [doc["_source"]["metadata"]["page"] for doc in output] == [1] + + def test_delete(self, es_client: Elasticsearch, index_name: str) -> None: + """Test delete methods from vector store.""" + store = VectorStore( + user_agent="test", + index_name=index_name, + retrieval_strategy=DenseVector(), + embedding_service=FakeEmbeddings(), + es_client=es_client, + ) + + texts = ["foo", "bar", "baz", "gni"] + metadatas = [{"page": i} for i in range(len(texts))] + ids = store.add_texts(texts=texts, metadatas=metadatas) + + output = store.search("foo", k=10) + assert len(output) == 4 + + store.delete(ids[1:3]) + output = store.search("foo", k=10) + assert len(output) == 2 + + store.delete(["not-existing"]) + output = store.search("foo", k=10) + assert len(output) == 2 + + store.delete([ids[0]]) + output = store.search("foo", k=10) + assert len(output) == 1 + + store.delete([ids[3]]) + output = store.search("gni", k=10) + assert len(output) == 0 + + def test_indexing_exception_error( + self, + es_client: Elasticsearch, + index_name: str, + caplog: pytest.LogCaptureFixture, + ) -> None: + """Test bulk exception logging is giving better hints.""" + store = VectorStore( + user_agent="test", + index_name=index_name, + retrieval_strategy=BM25(), + es_client=es_client, + ) + + store.es_client.indices.create( + index=index_name, + mappings={"properties": {}}, + settings={"index": {"default_pipeline": "not-existing-pipeline"}}, + ) + + texts = ["foo"] + + with pytest.raises(BulkIndexError): + store.add_texts(texts) + + error_reason = "pipeline with id [not-existing-pipeline] does not exist" + log_message = f"First error reason: {error_reason}" + + assert log_message in caplog.text + + def test_user_agent( + self, requests_saving_client: 
Elasticsearch, index_name: str + ) -> None: + """Test to make sure the user-agent is set correctly.""" + user_agent = "this is THE user_agent!" + store = VectorStore( + user_agent=user_agent, + index_name=index_name, + retrieval_strategy=BM25(), + es_client=requests_saving_client, + ) + + assert store.es_client._headers["User-Agent"] == user_agent + + texts = ["foo", "bob", "baz"] + store.add_texts(texts) + + transport = cast(RequestSavingTransport, store.es_client.transport) + + for request in transport.requests: + assert request["headers"]["User-Agent"] == user_agent + + def test_bulk_args(self, requests_saving_client: Any, index_name: str) -> None: + """Test to make sure the bulk arguments work as expected.""" + store = VectorStore( + user_agent="test", + index_name=index_name, + retrieval_strategy=BM25(), + es_client=requests_saving_client, + ) + + texts = ["foo", "bob", "baz"] + store.add_texts(texts, bulk_kwargs={"chunk_size": 1}) + + # 1 for index exist, 1 for index create, 3 to index docs + assert len(store.es_client.transport.requests) == 5 # type: ignore + + def test_max_marginal_relevance_search( + self, es_client: Elasticsearch, index_name: str + ) -> None: + """Test max marginal relevance search.""" + texts = ["foo", "bar", "baz"] + vector_field = "vector_field" + text_field = "text_field" + embedding_service = ConsistentFakeEmbeddings() + store = VectorStore( + user_agent="test", + index_name=index_name, + retrieval_strategy=DenseVectorScriptScore(), + embedding_service=embedding_service, + vector_field=vector_field, + text_field=text_field, + es_client=es_client, + ) + store.add_texts(texts) + + mmr_output = store.max_marginal_relevance_search( + embedding_service, + texts[0], + vector_field=vector_field, + k=3, + num_candidates=3, + ) + sim_output = store.search(texts[0], k=3) + assert mmr_output == sim_output + + mmr_output = store.max_marginal_relevance_search( + embedding_service, + texts[0], + vector_field=vector_field, + k=2, + 
num_candidates=3, + ) + assert len(mmr_output) == 2 + assert mmr_output[0]["_source"][text_field] == texts[0] + assert mmr_output[1]["_source"][text_field] == texts[1] + + mmr_output = store.max_marginal_relevance_search( + embedding_service, + texts[0], + vector_field=vector_field, + k=2, + num_candidates=3, + lambda_mult=0.1, # more diversity + ) + assert len(mmr_output) == 2 + assert mmr_output[0]["_source"][text_field] == texts[0] + assert mmr_output[1]["_source"][text_field] == texts[2] + + # if fetch_k < k, then the output will be less than k + mmr_output = store.max_marginal_relevance_search( + embedding_service, + texts[0], + vector_field=vector_field, + k=3, + num_candidates=2, + ) + assert len(mmr_output) == 2 + + def test_metadata_mapping(self, es_client: Elasticsearch, index_name: str) -> None: + """Test that the metadata mapping is applied.""" + test_mappings = { + "my_field": {"type": "keyword"}, + "another_field": {"type": "text"}, + } + store = VectorStore( + user_agent="test", + index_name=index_name, + retrieval_strategy=DenseVector(), + embedding_service=FakeEmbeddings(), + es_client=es_client, + metadata_mappings=test_mappings, + ) + + texts = ["foo", "foo", "foo"] + metadatas = [{"page": i} for i in range(len(texts))] + store.add_texts(texts=texts, metadatas=metadatas) + + mapping_response = es_client.indices.get_mapping(index=index_name) + mapping_properties = mapping_response[index_name]["mappings"]["properties"] + print(mapping_response) + assert "metadata" in mapping_properties + for key, val in test_mappings.items(): + assert mapping_properties["metadata"]["properties"][key] == val From d3979828bd9e47d2343d27f0b2535c95887117f7 Mon Sep 17 00:00:00 2001 From: Max Jakob Date: Mon, 22 Apr 2024 16:22:49 +0200 Subject: [PATCH 12/36] use `elasticsearch_url` fixture; create conftest.py --- .../vectorstore/_async/strategies.py | 4 +- elasticsearch/vectorstore/_sync/strategies.py | 4 +- elasticsearch/vectorstore/_utils.py | 4 +- setup.py | 74 -- 
.../test_vectorstore/_async/_test_utils.py | 70 +- .../test_vectorstore/_async/conftest.py | 103 ++ .../_async/test_embedding_service.py | 10 - .../_async/test_vectorstore.py | 24 +- .../test_vectorstore/_sync/_test_utils.py | 70 +- .../test_vectorstore/_sync/conftest.py | 100 ++ .../_sync/test_embedding_service.py | 9 - .../_sync/test_vectorestore.py | 933 ------------------ .../_sync/test_vectorstore.py | 23 +- 13 files changed, 211 insertions(+), 1217 deletions(-) create mode 100644 test_elasticsearch/test_server/test_vectorstore/_async/conftest.py create mode 100644 test_elasticsearch/test_server/test_vectorstore/_sync/conftest.py delete mode 100644 test_elasticsearch/test_server/test_vectorstore/_sync/test_vectorestore.py diff --git a/elasticsearch/vectorstore/_async/strategies.py b/elasticsearch/vectorstore/_async/strategies.py index ab0f9c35b..238855247 100644 --- a/elasticsearch/vectorstore/_async/strategies.py +++ b/elasticsearch/vectorstore/_async/strategies.py @@ -17,7 +17,7 @@ from abc import ABC, abstractmethod from enum import Enum -from typing import Any, Dict, List, Literal, Optional, Tuple, Union, cast +from typing import Any, Dict, List, Optional, Tuple, Union, cast from elasticsearch import AsyncElasticsearch from elasticsearch.vectorstore._async._utils import model_must_be_deployed @@ -250,7 +250,6 @@ class DenseVector(RetrievalStrategy): def __init__( self, - knn_type: Literal["hnsw", "int8_hnsw", "flat", "int8_flat"] = "hnsw", distance: DistanceMetric = DistanceMetric.COSINE, model_id: Optional[str] = None, hybrid: bool = False, @@ -262,7 +261,6 @@ def __init__( "to enable hybrid you have to specify a text_field (for BM25 matching)" ) - self.knn_type = knn_type self.distance = distance self.model_id = model_id self.hybrid = hybrid diff --git a/elasticsearch/vectorstore/_sync/strategies.py b/elasticsearch/vectorstore/_sync/strategies.py index cb0f75e83..a10222f12 100644 --- a/elasticsearch/vectorstore/_sync/strategies.py +++ 
b/elasticsearch/vectorstore/_sync/strategies.py @@ -17,7 +17,7 @@ from abc import ABC, abstractmethod from enum import Enum -from typing import Any, Dict, List, Literal, Optional, Tuple, Union, cast +from typing import Any, Dict, List, Optional, Tuple, Union, cast from elasticsearch import Elasticsearch from elasticsearch.vectorstore._sync._utils import model_must_be_deployed @@ -250,7 +250,6 @@ class DenseVector(RetrievalStrategy): def __init__( self, - knn_type: Literal["hnsw", "int8_hnsw", "flat", "int8_flat"] = "hnsw", distance: DistanceMetric = DistanceMetric.COSINE, model_id: Optional[str] = None, hybrid: bool = False, @@ -262,7 +261,6 @@ def __init__( "to enable hybrid you have to specify a text_field (for BM25 matching)" ) - self.knn_type = knn_type self.distance = distance self.model_id = model_id self.hybrid = hybrid diff --git a/elasticsearch/vectorstore/_utils.py b/elasticsearch/vectorstore/_utils.py index 1eb7e9026..5c72962c5 100644 --- a/elasticsearch/vectorstore/_utils.py +++ b/elasticsearch/vectorstore/_utils.py @@ -23,8 +23,8 @@ def maximal_marginal_relevance( - query_embedding: list[float], - embedding_list: list[list[float]], + query_embedding: List[float], + embedding_list: List[List[float]], lambda_mult: float = 0.5, k: int = 4, ) -> List[int]: diff --git a/setup.py b/setup.py index 39ffc7191..dc592dcc4 100644 --- a/setup.py +++ b/setup.py @@ -94,77 +94,3 @@ "orjson": ["orjson>=3"], }, ) - -vectorstore_package_name = "elasticsearch[vectorstore]" -base_dir = abspath(dirname(__file__)) - -with open(join(base_dir, package_name, "_version.py")) as f: - package_version = re.search( - r"__versionstr__\s+=\s+[\"\']([^\"\']+)[\"\']", f.read() - ).group(1) - -with open(join(base_dir, "README.rst")) as f: - # Remove reST raw directive from README as they're not allowed on PyPI - # Those blocks start with a newline and continue until the next newline - mode = None - lines = [] - for line in f: - if line.startswith(".. 
raw::"): - mode = "ignore_nl" - elif line == "\n": - mode = "wait_nl" if mode == "ignore_nl" else None - if mode is None: - lines.append(line) - - long_description = "".join(lines) - - -packages = [ - package - for package in find_packages(where=".", exclude=("test_elasticsearch*",)) - if package == package_name or package.startswith(package_name + ".") -] - -setup( - name=package_name, - description="Python client for Elasticsearch", - license="Apache-2.0", - url="https://github.com/elastic/elasticsearch-py", - long_description=long_description, - long_description_content_type="text/x-rst", - version=package_version, - author="Elastic Client Library Maintainers", - author_email="client-libs@elastic.co", - project_urls={ - "Documentation": "https://elasticsearch-py.readthedocs.io", - "Source Code": "https://github.com/elastic/elasticsearch-py", - "Issue Tracker": "https://github.com/elastic/elasticsearch-py/issues", - }, - packages=packages, - package_data={"elasticsearch": ["py.typed", "*.pyi"]}, - include_package_data=True, - zip_safe=False, - classifiers=[ - "Development Status :: 5 - Production/Stable", - "License :: OSI Approved :: Apache Software License", - "Intended Audience :: Developers", - "Operating System :: OS Independent", - "Programming Language :: Python", - "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.7", - "Programming Language :: Python :: 3.8", - "Programming Language :: Python :: 3.9", - "Programming Language :: Python :: 3.10", - "Programming Language :: Python :: 3.11", - "Programming Language :: Python :: 3.12", - "Programming Language :: Python :: Implementation :: CPython", - "Programming Language :: Python :: Implementation :: PyPy", - ], - python_requires=">=3.7", - install_requires=["elastic-transport>=8.13,<9"], - extras_require={ - "requests": ["requests>=2.4.0, <3.0.0"], - "async": ["aiohttp>=3,<4"], - "orjson": ["orjson>=3"], - }, -) diff --git 
a/test_elasticsearch/test_server/test_vectorstore/_async/_test_utils.py b/test_elasticsearch/test_server/test_vectorstore/_async/_test_utils.py index 87b8b7a96..9f6522811 100644 --- a/test_elasticsearch/test_server/test_vectorstore/_async/_test_utils.py +++ b/test_elasticsearch/test_server/test_vectorstore/_async/_test_utils.py @@ -15,12 +15,10 @@ # specific language governing permissions and limitations # under the License. -import os -from typing import Any, AsyncIterator, Dict, List, Optional +from typing import Any, Dict, List from elastic_transport import AsyncTransport -from elasticsearch import AsyncElasticsearch from elasticsearch.vectorstore._async.embedding_service import AsyncEmbeddingService @@ -87,69 +85,3 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: async def perform_request(self, *args, **kwargs): # type: ignore self.requests.append(kwargs) return await super().perform_request(*args, **kwargs) - - -def create_es_client( - es_params: Optional[Dict[str, str]] = None, es_kwargs: Dict = {} -) -> AsyncElasticsearch: - if es_params is None: - es_params = read_env() - if not es_kwargs: - es_kwargs = {} - - if "es_cloud_id" in es_params: - return AsyncElasticsearch( - cloud_id=es_params["es_cloud_id"], - api_key=es_params["es_api_key"], - **es_kwargs, - ) - return AsyncElasticsearch(hosts=[es_params["es_url"]], **es_kwargs) - - -def create_requests_saving_client() -> AsyncElasticsearch: - return create_es_client(es_kwargs={"transport_class": AsyncRequestSavingTransport}) - - -async def es_client_fixture() -> AsyncIterator[AsyncElasticsearch]: - params = read_env() - client = create_es_client(params) - - yield client - - # clear indices - await clear_test_indices(client) - - # clear all test pipelines - try: - response = await client.ingest.get_pipeline(id="test_*,*_sparse_embedding") - - for pipeline_id, _ in response.items(): - try: - await client.ingest.delete_pipeline(id=pipeline_id) - print(f"Deleted pipeline: {pipeline_id}") # noqa: T201 - 
except Exception as e: - print(f"Pipeline error: {e}") # noqa: T201 - - except Exception: - pass - finally: - await client.close() - - -async def clear_test_indices(client: AsyncElasticsearch) -> None: - response = await client.indices.get(index="_all") - index_names = response.keys() - for index_name in index_names: - if index_name.startswith("test_"): - await client.indices.delete(index=index_name) - await client.indices.refresh(index="_all") - - -def read_env() -> Dict: - url = os.environ.get("ES_URL", "http://localhost:9200") - cloud_id = os.environ.get("ES_CLOUD_ID") - api_key = os.environ.get("ES_API_KEY") - - if cloud_id: - return {"es_cloud_id": cloud_id, "es_api_key": api_key} - return {"es_url": url} diff --git a/test_elasticsearch/test_server/test_vectorstore/_async/conftest.py b/test_elasticsearch/test_server/test_vectorstore/_async/conftest.py new file mode 100644 index 000000000..32b36a329 --- /dev/null +++ b/test_elasticsearch/test_server/test_vectorstore/_async/conftest.py @@ -0,0 +1,103 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import os +import uuid +from typing import AsyncIterator, Dict + +import pytest +import pytest_asyncio + +from elasticsearch import AsyncElasticsearch + +from ._test_utils import AsyncRequestSavingTransport + + +@pytest_asyncio.fixture +async def es_client(elasticsearch_url: str) -> AsyncIterator[AsyncElasticsearch]: + client = _create_es_client(elasticsearch_url) + + yield client + + # clear indices + await _clear_test_indices(client) + + # clear all test pipelines + try: + response = await client.ingest.get_pipeline(id="test_*,*_sparse_embedding") + + for pipeline_id, _ in response.items(): + try: + await client.ingest.delete_pipeline(id=pipeline_id) + print(f"Deleted pipeline: {pipeline_id}") # noqa: T201 + except Exception as e: + print(f"Pipeline error: {e}") # noqa: T201 + + except Exception: + pass + finally: + await client.close() + + +@pytest_asyncio.fixture +async def requests_saving_client( + elasticsearch_url: str, +) -> AsyncIterator[AsyncElasticsearch]: + client = _create_es_client( + elasticsearch_url, es_kwargs={"transport_class": AsyncRequestSavingTransport} + ) + + try: + yield client + finally: + await client.close() + + +@pytest.fixture(scope="function") +def index_name() -> str: + return f"test_{uuid.uuid4().hex}" + + +async def _clear_test_indices(client: AsyncElasticsearch) -> None: + response = await client.indices.get(index="_all") + index_names = response.keys() + for index_name in index_names: + if index_name.startswith("test_"): + await client.indices.delete(index=index_name) + await client.indices.refresh(index="_all") + + +def _create_es_client( + elasticsearch_url: str, es_kwargs: Dict = {} +) -> AsyncElasticsearch: + if not elasticsearch_url: + elasticsearch_url = os.environ.get("ES_URL", "http://localhost:9200") + cloud_id = os.environ.get("ES_CLOUD_ID") + api_key = os.environ.get("ES_API_KEY") + + if cloud_id: + es_params = {"es_cloud_id": cloud_id, "es_api_key": api_key} + else: + es_params = {"es_url": elasticsearch_url} + + 
if "es_cloud_id" in es_params: + return AsyncElasticsearch( + cloud_id=es_params["es_cloud_id"], + api_key=es_params["es_api_key"], + **es_kwargs, + ) + return AsyncElasticsearch(hosts=[es_params["es_url"]], **es_kwargs) diff --git a/test_elasticsearch/test_server/test_vectorstore/_async/test_embedding_service.py b/test_elasticsearch/test_server/test_vectorstore/_async/test_embedding_service.py index 924339683..d18b5ddb2 100644 --- a/test_elasticsearch/test_server/test_vectorstore/_async/test_embedding_service.py +++ b/test_elasticsearch/test_server/test_vectorstore/_async/test_embedding_service.py @@ -16,10 +16,8 @@ # under the License. import os -from typing import AsyncIterator import pytest -import pytest_asyncio from elasticsearch import AsyncElasticsearch from elasticsearch.vectorstore._async._utils import model_is_deployed @@ -27,20 +25,12 @@ AsyncElasticsearchEmbeddings, ) -from ._test_utils import es_client_fixture - # deployed with # https://www.elastic.co/guide/en/machine-learning/current/ml-nlp-text-emb-vector-search-example.html MODEL_ID = os.getenv("MODEL_ID", "sentence-transformers__msmarco-minilm-l-12-v3") NUM_DIMENSIONS = int(os.getenv("NUM_DIMENTIONS", "384")) -@pytest_asyncio.fixture -async def es_client() -> AsyncIterator[AsyncElasticsearch]: - async for x in es_client_fixture(): - yield x - - @pytest.mark.asyncio async def test_elasticsearch_embedding_documents(es_client: AsyncElasticsearch) -> None: """Test Elasticsearch embedding documents.""" diff --git a/test_elasticsearch/test_server/test_vectorstore/_async/test_vectorstore.py b/test_elasticsearch/test_server/test_vectorstore/_async/test_vectorstore.py index ec5e7cb4e..089ae8e63 100644 --- a/test_elasticsearch/test_server/test_vectorstore/_async/test_vectorstore.py +++ b/test_elasticsearch/test_server/test_vectorstore/_async/test_vectorstore.py @@ -16,12 +16,10 @@ # under the License. 
import logging -import uuid from functools import partial -from typing import Any, AsyncIterator, List, Optional, Union, cast +from typing import Any, List, Optional, Union, cast import pytest -import pytest_asyncio from elasticsearch import AsyncElasticsearch, NotFoundError from elasticsearch.helpers import BulkIndexError @@ -39,8 +37,6 @@ AsyncConsistentFakeEmbeddings, AsyncFakeEmbeddings, AsyncRequestSavingTransport, - create_requests_saving_client, - es_client_fixture, ) logging.basicConfig(level=logging.DEBUG) @@ -67,24 +63,6 @@ class TestVectorStore: - @pytest_asyncio.fixture - async def es_client(self) -> AsyncIterator[AsyncElasticsearch]: - async for x in es_client_fixture(): - yield x - - @pytest_asyncio.fixture - async def requests_saving_client(self) -> AsyncIterator[AsyncElasticsearch]: - client = create_requests_saving_client() - try: - yield client - finally: - await client.close() - - @pytest.fixture(scope="function") - def index_name(self) -> str: - """Return the index name.""" - return f"test_{uuid.uuid4().hex}" - @pytest.mark.asyncio async def test_search_without_metadata( self, es_client: AsyncElasticsearch, index_name: str diff --git a/test_elasticsearch/test_server/test_vectorstore/_sync/_test_utils.py b/test_elasticsearch/test_server/test_vectorstore/_sync/_test_utils.py index e43dbeffd..68fa9dc91 100644 --- a/test_elasticsearch/test_server/test_vectorstore/_sync/_test_utils.py +++ b/test_elasticsearch/test_server/test_vectorstore/_sync/_test_utils.py @@ -15,12 +15,10 @@ # specific language governing permissions and limitations # under the License. 
-import os -from typing import Any, Dict, Iterator, List, Optional +from typing import Any, Dict, List from elastic_transport import Transport -from elasticsearch import Elasticsearch from elasticsearch.vectorstore._sync.embedding_service import EmbeddingService @@ -87,69 +85,3 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: def perform_request(self, *args, **kwargs): # type: ignore self.requests.append(kwargs) return super().perform_request(*args, **kwargs) - - -def create_es_client( - es_params: Optional[Dict[str, str]] = None, es_kwargs: Dict = {} -) -> Elasticsearch: - if es_params is None: - es_params = read_env() - if not es_kwargs: - es_kwargs = {} - - if "es_cloud_id" in es_params: - return Elasticsearch( - cloud_id=es_params["es_cloud_id"], - api_key=es_params["es_api_key"], - **es_kwargs, - ) - return Elasticsearch(hosts=[es_params["es_url"]], **es_kwargs) - - -def create_requests_saving_client() -> Elasticsearch: - return create_es_client(es_kwargs={"transport_class": RequestSavingTransport}) - - -def es_client_fixture() -> Iterator[Elasticsearch]: - params = read_env() - client = create_es_client(params) - - yield client - - # clear indices - clear_test_indices(client) - - # clear all test pipelines - try: - response = client.ingest.get_pipeline(id="test_*,*_sparse_embedding") - - for pipeline_id, _ in response.items(): - try: - client.ingest.delete_pipeline(id=pipeline_id) - print(f"Deleted pipeline: {pipeline_id}") # noqa: T201 - except Exception as e: - print(f"Pipeline error: {e}") # noqa: T201 - - except Exception: - pass - finally: - client.close() - - -def clear_test_indices(client: Elasticsearch) -> None: - response = client.indices.get(index="_all") - index_names = response.keys() - for index_name in index_names: - if index_name.startswith("test_"): - client.indices.delete(index=index_name) - client.indices.refresh(index="_all") - - -def read_env() -> Dict: - url = os.environ.get("ES_URL", "http://localhost:9200") - cloud_id = 
os.environ.get("ES_CLOUD_ID") - api_key = os.environ.get("ES_API_KEY") - - if cloud_id: - return {"es_cloud_id": cloud_id, "es_api_key": api_key} - return {"es_url": url} diff --git a/test_elasticsearch/test_server/test_vectorstore/_sync/conftest.py b/test_elasticsearch/test_server/test_vectorstore/_sync/conftest.py new file mode 100644 index 000000000..be11547f5 --- /dev/null +++ b/test_elasticsearch/test_server/test_vectorstore/_sync/conftest.py @@ -0,0 +1,100 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import os +import uuid +from typing import Dict, Iterator + +import pytest + +from elasticsearch import Elasticsearch + +from ._test_utils import RequestSavingTransport + + +@pytest.fixture +def es_client(elasticsearch_url: str) -> Iterator[Elasticsearch]: + client = _create_es_client(elasticsearch_url) + + yield client + + # clear indices + _clear_test_indices(client) + + # clear all test pipelines + try: + response = client.ingest.get_pipeline(id="test_*,*_sparse_embedding") + + for pipeline_id, _ in response.items(): + try: + client.ingest.delete_pipeline(id=pipeline_id) + print(f"Deleted pipeline: {pipeline_id}") # noqa: T201 + except Exception as e: + print(f"Pipeline error: {e}") # noqa: T201 + + except Exception: + pass + finally: + client.close() + + +@pytest.fixture +def requests_saving_client( + elasticsearch_url: str, +) -> Iterator[Elasticsearch]: + client = _create_es_client( + elasticsearch_url, es_kwargs={"transport_class": RequestSavingTransport} + ) + + try: + yield client + finally: + client.close() + + +@pytest.fixture(scope="function") +def index_name() -> str: + return f"test_{uuid.uuid4().hex}" + + +def _clear_test_indices(client: Elasticsearch) -> None: + response = client.indices.get(index="_all") + index_names = response.keys() + for index_name in index_names: + if index_name.startswith("test_"): + client.indices.delete(index=index_name) + client.indices.refresh(index="_all") + + +def _create_es_client(elasticsearch_url: str, es_kwargs: Dict = {}) -> Elasticsearch: + if not elasticsearch_url: + elasticsearch_url = os.environ.get("ES_URL", "http://localhost:9200") + cloud_id = os.environ.get("ES_CLOUD_ID") + api_key = os.environ.get("ES_API_KEY") + + if cloud_id: + es_params = {"es_cloud_id": cloud_id, "es_api_key": api_key} + else: + es_params = {"es_url": elasticsearch_url} + + if "es_cloud_id" in es_params: + return Elasticsearch( + cloud_id=es_params["es_cloud_id"], + api_key=es_params["es_api_key"], + **es_kwargs, + ) + return 
Elasticsearch(hosts=[es_params["es_url"]], **es_kwargs) diff --git a/test_elasticsearch/test_server/test_vectorstore/_sync/test_embedding_service.py b/test_elasticsearch/test_server/test_vectorstore/_sync/test_embedding_service.py index 979096d39..f9677fd04 100644 --- a/test_elasticsearch/test_server/test_vectorstore/_sync/test_embedding_service.py +++ b/test_elasticsearch/test_server/test_vectorstore/_sync/test_embedding_service.py @@ -16,7 +16,6 @@ # under the License. import os -from typing import Iterator import pytest @@ -24,20 +23,12 @@ from elasticsearch.vectorstore._sync._utils import model_is_deployed from elasticsearch.vectorstore._sync.embedding_service import ElasticsearchEmbeddings -from ._test_utils import es_client_fixture - # deployed with # https://www.elastic.co/guide/en/machine-learning/current/ml-nlp-text-emb-vector-search-example.html MODEL_ID = os.getenv("MODEL_ID", "sentence-transformers__msmarco-minilm-l-12-v3") NUM_DIMENSIONS = int(os.getenv("NUM_DIMENTIONS", "384")) -@pytest.fixture -def es_client() -> Iterator[Elasticsearch]: - for x in es_client_fixture(): - yield x - - def test_elasticsearch_embedding_documents(es_client: Elasticsearch) -> None: """Test Elasticsearch embedding documents.""" diff --git a/test_elasticsearch/test_server/test_vectorstore/_sync/test_vectorestore.py b/test_elasticsearch/test_server/test_vectorstore/_sync/test_vectorestore.py deleted file mode 100644 index 0ce312ca7..000000000 --- a/test_elasticsearch/test_server/test_vectorstore/_sync/test_vectorestore.py +++ /dev/null @@ -1,933 +0,0 @@ -# Licensed to Elasticsearch B.V. under one or more contributor -# license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright -# ownership. Elasticsearch B.V. licenses this file to you under -# the Apache License, Version 2.0 (the "License"); you may -# not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -import logging -import uuid -from functools import partial -from typing import Any, Iterator, List, Optional, Union, cast - -import pytest - -from elasticsearch import Elasticsearch, NotFoundError -from elasticsearch.helpers import BulkIndexError -from elasticsearch.vectorstore._sync import VectorStore -from elasticsearch.vectorstore._sync._utils import model_is_deployed -from elasticsearch.vectorstore._sync.strategies import ( - BM25, - DenseVector, - DenseVectorScriptScore, - DistanceMetric, - Semantic, -) - -from ._test_utils import ( - ConsistentFakeEmbeddings, - FakeEmbeddings, - RequestSavingTransport, - create_requests_saving_client, - es_client_fixture, -) - -logging.basicConfig(level=logging.DEBUG) - -""" -docker-compose up elasticsearch - -By default runs against local docker instance of Elasticsearch. -To run against Elastic Cloud, set the following environment variables: -- ES_CLOUD_ID -- ES_API_KEY - -Some of the tests require the following models to be deployed in the ML Node: -- elser (can be downloaded and deployed through Kibana and trained models UI) -- sentence-transformers__all-minilm-l6-v2 (can be deployed through the API, - loaded via eland) - -These tests that require the models to be deployed are skipped by default. -Enable them by adding the model name to the modelsDeployed list below. 
-""" - -ELSER_MODEL_ID = ".elser_model_2" -TRANSFORMER_MODEL_ID = "sentence-transformers__all-minilm-l6-v2" - - -class TestElasticsearch: - @pytest.fixture - def es_client(self) -> Iterator[Elasticsearch]: - for x in es_client_fixture(): - yield x - - @pytest.fixture - def requests_saving_client(self) -> Iterator[Elasticsearch]: - client = create_requests_saving_client() - try: - yield client - finally: - client.close() - - @pytest.fixture(scope="function") - def index_name(self) -> str: - """Return the index name.""" - return f"test_{uuid.uuid4().hex}" - - def test_search_without_metadata( - self, es_client: Elasticsearch, index_name: str - ) -> None: - """Test end to end construction and search without metadata.""" - - def assert_query(query_body: dict, query: Optional[str]) -> dict: - assert query_body == { - "knn": { - "field": "vector_field", - "filter": [], - "k": 1, - "num_candidates": 50, - "query_vector": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - } - } - return query_body - - store = VectorStore( - user_agent="test", - index_name=index_name, - retrieval_strategy=DenseVector(), - embedding_service=FakeEmbeddings(), - es_client=es_client, - ) - - texts = ["foo", "bar", "baz"] - store.add_texts(texts) - - output = store.search("foo", k=1, custom_query=assert_query) - assert [doc["_source"]["text_field"] for doc in output] == ["foo"] - - def test_search_without_metadata_async( - self, es_client: Elasticsearch, index_name: str - ) -> None: - """Test end to end construction and search without metadata.""" - store = VectorStore( - user_agent="test", - index_name=index_name, - retrieval_strategy=DenseVector(), - embedding_service=FakeEmbeddings(), - es_client=es_client, - ) - - texts = ["foo", "bar", "baz"] - store.add_texts(texts) - - output = store.search("foo", k=1) - assert [doc["_source"]["text_field"] for doc in output] == ["foo"] - - def test_add_vectors(self, es_client: Elasticsearch, index_name: str) -> None: - """ - Test adding pre-built 
embeddings instead of using inference for the texts. - This allows you to separate the embeddings text and the page_content - for better proximity between user's question and embedded text. - For example, your embedding text can be a question, whereas page_content - is the answer. - """ - embeddings = ConsistentFakeEmbeddings() - texts = ["foo1", "foo2", "foo3"] - metadatas = [{"page": i} for i in range(len(texts))] - - """In real use case, embedding_input can be questions for each text""" - embedding_vectors = embeddings.embed_documents(texts) - - store = VectorStore( - user_agent="test", - index_name=index_name, - retrieval_strategy=DenseVector(), - embedding_service=embeddings, - es_client=es_client, - ) - - store.add_texts(texts=texts, vectors=embedding_vectors, metadatas=metadatas) - output = store.search("foo1", k=1) - assert [doc["_source"]["text_field"] for doc in output] == ["foo1"] - assert [doc["_source"]["metadata"]["page"] for doc in output] == [0] - - def test_search_with_metadata( - self, es_client: Elasticsearch, index_name: str - ) -> None: - """Test end to end construction and search with metadata.""" - store = VectorStore( - user_agent="test", - index_name=index_name, - retrieval_strategy=DenseVector(), - embedding_service=ConsistentFakeEmbeddings(), - es_client=es_client, - ) - - texts = ["foo", "bar", "baz"] - metadatas = [{"page": i} for i in range(len(texts))] - store.add_texts(texts=texts, metadatas=metadatas) - - output = store.search("foo", k=1) - assert [doc["_source"]["text_field"] for doc in output] == ["foo"] - assert [doc["_source"]["metadata"]["page"] for doc in output] == [0] - - output = store.search("bar", k=1) - assert [doc["_source"]["text_field"] for doc in output] == ["bar"] - assert [doc["_source"]["metadata"]["page"] for doc in output] == [1] - - def test_search_with_filter( - self, es_client: Elasticsearch, index_name: str - ) -> None: - """Test end to end construction and search with metadata.""" - store = VectorStore( - 
user_agent="test", - index_name=index_name, - retrieval_strategy=DenseVector(), - embedding_service=FakeEmbeddings(), - es_client=es_client, - ) - - texts = ["foo", "foo", "foo"] - metadatas = [{"page": i} for i in range(len(texts))] - store.add_texts(texts=texts, metadatas=metadatas) - - def assert_query(query_body: dict, query: Optional[str]) -> dict: - assert query_body == { - "knn": { - "field": "vector_field", - "filter": [{"term": {"metadata.page": "1"}}], - "k": 3, - "num_candidates": 50, - "query_vector": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - } - } - return query_body - - output = store.search( - query="foo", - k=3, - filter=[{"term": {"metadata.page": "1"}}], - custom_query=assert_query, - ) - assert [doc["_source"]["text_field"] for doc in output] == ["foo"] - assert [doc["_source"]["metadata"]["page"] for doc in output] == [1] - - def test_search_script_score( - self, es_client: Elasticsearch, index_name: str - ) -> None: - """Test end to end construction and search with metadata.""" - store = VectorStore( - user_agent="test", - index_name=index_name, - retrieval_strategy=DenseVectorScriptScore(), - embedding_service=FakeEmbeddings(), - es_client=es_client, - ) - - texts = ["foo", "bar", "baz"] - store.add_texts(texts) - - expected_query = { - "query": { - "script_score": { - "query": {"match_all": {}}, - "script": { - "source": "cosineSimilarity(params.query_vector, 'vector_field') + 1.0", # noqa: E501 - "params": { - "query_vector": [ - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 0.0, - ] - }, - }, - } - } - } - - def assert_query(query_body: dict, query: Optional[str]) -> dict: - assert query_body == expected_query - return query_body - - output = store.search("foo", k=1, custom_query=assert_query) - assert [doc["_source"]["text_field"] for doc in output] == ["foo"] - - def test_search_script_score_with_filter( - self, es_client: Elasticsearch, index_name: str - ) -> None: - """Test end to end construction and search 
with metadata.""" - store = VectorStore( - user_agent="test", - index_name=index_name, - retrieval_strategy=DenseVectorScriptScore(), - embedding_service=FakeEmbeddings(), - es_client=es_client, - ) - - texts = ["foo", "bar", "baz"] - metadatas = [{"page": i} for i in range(len(texts))] - store.add_texts(texts=texts, metadatas=metadatas) - - def assert_query(query_body: dict, query: Optional[str]) -> dict: - expected_query = { - "query": { - "script_score": { - "query": {"bool": {"filter": [{"term": {"metadata.page": 0}}]}}, - "script": { - "source": "cosineSimilarity(params.query_vector, 'vector_field') + 1.0", # noqa: E501 - "params": { - "query_vector": [ - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 0.0, - ] - }, - }, - } - } - } - assert query_body == expected_query - return query_body - - output = store.search( - "foo", - k=1, - custom_query=assert_query, - filter=[{"term": {"metadata.page": 0}}], - ) - assert [doc["_source"]["text_field"] for doc in output] == ["foo"] - assert [doc["_source"]["metadata"]["page"] for doc in output] == [0] - - def test_search_script_score_distance_dot_product( - self, es_client: Elasticsearch, index_name: str - ) -> None: - """Test end to end construction and search with metadata.""" - store = VectorStore( - user_agent="test", - index_name=index_name, - retrieval_strategy=DenseVectorScriptScore( - distance=DistanceMetric.DOT_PRODUCT, - ), - embedding_service=FakeEmbeddings(), - es_client=es_client, - ) - - texts = ["foo", "bar", "baz"] - store.add_texts(texts) - - def assert_query(query_body: dict, query: Optional[str]) -> dict: - assert query_body == { - "query": { - "script_score": { - "query": {"match_all": {}}, - "script": { - "source": """ - double value = dotProduct(params.query_vector, 'vector_field'); - return sigmoid(1, Math.E, -value); - """, - "params": { - "query_vector": [ - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 0.0, - ] - }, - }, - } - } - } - return query_body - - 
output = store.search("foo", k=1, custom_query=assert_query) - assert [doc["_source"]["text_field"] for doc in output] == ["foo"] - - def test_search_knn_with_hybrid_search( - self, es_client: Elasticsearch, index_name: str - ) -> None: - """Test end to end construction and search with metadata.""" - store = VectorStore( - user_agent="test", - index_name=index_name, - retrieval_strategy=DenseVector(hybrid=True), - embedding_service=FakeEmbeddings(), - es_client=es_client, - ) - - texts = ["foo", "bar", "baz"] - store.add_texts(texts) - - def assert_query(query_body: dict, query: Optional[str]) -> dict: - assert query_body == { - "knn": { - "field": "vector_field", - "filter": [], - "k": 1, - "num_candidates": 50, - "query_vector": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - }, - "query": { - "bool": { - "filter": [], - "must": [{"match": {"text_field": {"query": "foo"}}}], - } - }, - "rank": {"rrf": {}}, - } - return query_body - - output = store.search("foo", k=1, custom_query=assert_query) - assert [doc["_source"]["text_field"] for doc in output] == ["foo"] - - def test_search_knn_with_hybrid_search_rrf( - self, es_client: Elasticsearch, index_name: str - ) -> None: - """Test end to end construction and rrf hybrid search with metadata.""" - texts = ["foo", "bar", "baz"] - - def assert_query( - query_body: dict, - query: Optional[str], - expected_rrf: Union[dict, bool], - ) -> dict: - cmp_query_body = { - "knn": { - "field": "vector_field", - "filter": [], - "k": 3, - "num_candidates": 50, - "query_vector": [ - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 0.0, - ], - }, - "query": { - "bool": { - "filter": [], - "must": [{"match": {"text_field": {"query": "foo"}}}], - } - }, - } - - if isinstance(expected_rrf, dict): - cmp_query_body["rank"] = {"rrf": expected_rrf} - elif isinstance(expected_rrf, bool) and expected_rrf is True: - cmp_query_body["rank"] = {"rrf": {}} - - assert query_body == cmp_query_body - - return query_body - - # 1. 
check query_body is okay - rrf_test_cases: List[Union[dict, bool]] = [ - True, - False, - {"rank_constant": 1, "window_size": 5}, - ] - for rrf_test_case in rrf_test_cases: - store = VectorStore( - user_agent="test", - index_name=index_name, - retrieval_strategy=DenseVector(hybrid=True, rrf=rrf_test_case), - embedding_service=FakeEmbeddings(), - es_client=es_client, - ) - store.add_texts(texts) - - ## without fetch_k parameter - output = store.search( - "foo", - k=3, - custom_query=partial(assert_query, expected_rrf=rrf_test_case), - ) - - # 2. check query result is okay - es_output = store.es_client.search( - index=index_name, - query={ - "bool": { - "filter": [], - "must": [{"match": {"text_field": {"query": "foo"}}}], - } - }, - knn={ - "field": "vector_field", - "filter": [], - "k": 3, - "num_candidates": 50, - "query_vector": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - }, - size=3, - rank={"rrf": {"rank_constant": 1, "window_size": 5}}, - ) - - assert [o["_source"]["text_field"] for o in output] == [ - e["_source"]["text_field"] for e in es_output["hits"]["hits"] - ] - - # 3. 
check rrf default option is okay - store = VectorStore( - user_agent="test", - index_name=f"{index_name}_default", - retrieval_strategy=DenseVector(hybrid=True), - embedding_service=FakeEmbeddings(), - es_client=es_client, - ) - store.add_texts(texts) - - ## with fetch_k parameter - output = store.search( - "foo", - k=3, - num_candidates=50, - custom_query=partial(assert_query, expected_rrf={}), - ) - - def test_search_knn_with_custom_query_fn( - self, es_client: Elasticsearch, index_name: str - ) -> None: - """test that custom query function is called - with the query string and query body""" - store = VectorStore( - user_agent="test", - index_name=index_name, - retrieval_strategy=DenseVector(), - embedding_service=FakeEmbeddings(), - es_client=es_client, - ) - - def my_custom_query(query_body: dict, query: Optional[str]) -> dict: - assert query == "foo" - assert query_body == { - "knn": { - "field": "vector_field", - "filter": [], - "k": 1, - "num_candidates": 50, - "query_vector": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - } - } - return {"query": {"match": {"text_field": {"query": "bar"}}}} - - """Test end to end construction and search with metadata.""" - texts = ["foo", "bar", "baz"] - store.add_texts(texts) - - output = store.search("foo", k=1, custom_query=my_custom_query) - assert [doc["_source"]["text_field"] for doc in output] == ["bar"] - - def test_search_with_knn_infer_instack( - self, es_client: Elasticsearch, index_name: str - ) -> None: - """test end to end with knn retrieval strategy and inference in-stack""" - - if not model_is_deployed(es_client, TRANSFORMER_MODEL_ID): - pytest.skip( - f"{TRANSFORMER_MODEL_ID} model not deployed in ML Node skipping test" - ) - - text_field = "text_field" - - store = VectorStore( - user_agent="test", - index_name=index_name, - retrieval_strategy=Semantic( - model_id="sentence-transformers__all-minilm-l6-v2", - text_field=text_field, - ), - es_client=es_client, - ) - - # setting up the pipeline for 
inference - store.es_client.ingest.put_pipeline( - id="test_pipeline", - processors=[ - { - "inference": { - "model_id": TRANSFORMER_MODEL_ID, - "field_map": {"query_field": text_field}, - "target_field": "vector_query_field", - } - } - ], - ) - - # creating a new index with the pipeline, - # not relying on langchain to create the index - store.es_client.indices.create( - index=index_name, - mappings={ - "properties": { - text_field: {"type": "text_field"}, - "vector_query_field": { - "properties": { - "predicted_value": { - "type": "dense_vector", - "dims": 384, - "index": True, - "similarity": "l2_norm", - } - } - }, - } - }, - settings={"index": {"default_pipeline": "test_pipeline"}}, - ) - - # adding documents to the index - texts = ["foo", "bar", "baz"] - - for i, text in enumerate(texts): - store.es_client.create( - index=index_name, - id=str(i), - document={text_field: text, "metadata": {}}, - ) - - store.es_client.indices.refresh(index=index_name) - - def assert_query(query_body: dict, query: Optional[str]) -> dict: - assert query_body == { - "knn": { - "filter": [], - "field": "vector_query_field.predicted_value", - "k": 1, - "num_candidates": 50, - "query_vector_builder": { - "text_embedding": { - "model_id": TRANSFORMER_MODEL_ID, - "model_text": "foo", - } - }, - } - } - return query_body - - output = store.search("foo", k=1, custom_query=assert_query) - assert [doc["_source"]["text_field"] for doc in output] == ["foo"] - - output = store.search("bar", k=1) - assert [doc["_source"]["text_field"] for doc in output] == ["bar"] - - def test_search_with_sparse_infer_instack( - self, es_client: Elasticsearch, index_name: str - ) -> None: - """test end to end with sparse retrieval strategy and inference in-stack""" - - if not model_is_deployed(es_client, ELSER_MODEL_ID): - reason = f"{ELSER_MODEL_ID} model not deployed in ML Node, skipping test" - - pytest.skip(reason) - - store = VectorStore( - user_agent="test", - index_name=index_name, - 
retrieval_strategy=Semantic(model_id=ELSER_MODEL_ID), - es_client=es_client, - ) - - texts = ["foo", "bar", "baz"] - store.add_texts(texts) - - output = store.search("foo", k=1) - assert [doc["_source"]["text_field"] for doc in output] == ["foo"] - - def test_deployed_model_check_fails_semantic( - self, es_client: Elasticsearch, index_name: str - ) -> None: - """test that exceptions are raised if a specified model is not deployed""" - with pytest.raises(NotFoundError): - store = VectorStore( - user_agent="test", - index_name=index_name, - retrieval_strategy=Semantic(model_id="non-existing model ID"), - es_client=es_client, - ) - store.add_texts(["foo", "bar", "baz"]) - - def test_search_bm25(self, es_client: Elasticsearch, index_name: str) -> None: - """Test end to end using the BM25 retrieval strategy.""" - store = VectorStore( - user_agent="test", - index_name=index_name, - retrieval_strategy=BM25(), - es_client=es_client, - ) - - texts = ["foo", "bar", "baz"] - store.add_texts(texts) - - def assert_query(query_body: dict, query: Optional[str]) -> dict: - assert query_body == { - "query": { - "bool": { - "must": [{"match": {"text_field": {"query": "foo"}}}], - "filter": [], - } - } - } - return query_body - - output = store.search("foo", k=1, custom_query=assert_query) - assert [doc["_source"]["text_field"] for doc in output] == ["foo"] - - def test_search_bm25_with_filter( - self, es_client: Elasticsearch, index_name: str - ) -> None: - """Test end to using the BM25 retrieval strategy with metadata.""" - store = VectorStore( - user_agent="test", - index_name=index_name, - retrieval_strategy=BM25(), - es_client=es_client, - ) - - texts = ["foo", "foo", "foo"] - metadatas = [{"page": i} for i in range(len(texts))] - store.add_texts(texts=texts, metadatas=metadatas) - - def assert_query(query_body: dict, query: Optional[str]) -> dict: - assert query_body == { - "query": { - "bool": { - "must": [{"match": {"text_field": {"query": "foo"}}}], - "filter": [{"term": 
{"metadata.page": 1}}], - } - } - } - return query_body - - output = store.search( - "foo", - k=3, - custom_query=assert_query, - filter=[{"term": {"metadata.page": 1}}], - ) - assert [doc["_source"]["text_field"] for doc in output] == ["foo"] - assert [doc["_source"]["metadata"]["page"] for doc in output] == [1] - - def test_delete(self, es_client: Elasticsearch, index_name: str) -> None: - """Test delete methods from vector store.""" - store = VectorStore( - user_agent="test", - index_name=index_name, - retrieval_strategy=DenseVector(), - embedding_service=FakeEmbeddings(), - es_client=es_client, - ) - - texts = ["foo", "bar", "baz", "gni"] - metadatas = [{"page": i} for i in range(len(texts))] - ids = store.add_texts(texts=texts, metadatas=metadatas) - - output = store.search("foo", k=10) - assert len(output) == 4 - - store.delete(ids[1:3]) - output = store.search("foo", k=10) - assert len(output) == 2 - - store.delete(["not-existing"]) - output = store.search("foo", k=10) - assert len(output) == 2 - - store.delete([ids[0]]) - output = store.search("foo", k=10) - assert len(output) == 1 - - store.delete([ids[3]]) - output = store.search("gni", k=10) - assert len(output) == 0 - - def test_indexing_exception_error( - self, - es_client: Elasticsearch, - index_name: str, - caplog: pytest.LogCaptureFixture, - ) -> None: - """Test bulk exception logging is giving better hints.""" - store = VectorStore( - user_agent="test", - index_name=index_name, - retrieval_strategy=BM25(), - es_client=es_client, - ) - - store.es_client.indices.create( - index=index_name, - mappings={"properties": {}}, - settings={"index": {"default_pipeline": "not-existing-pipeline"}}, - ) - - texts = ["foo"] - - with pytest.raises(BulkIndexError): - store.add_texts(texts) - - error_reason = "pipeline with id [not-existing-pipeline] does not exist" - log_message = f"First error reason: {error_reason}" - - assert log_message in caplog.text - - def test_user_agent( - self, requests_saving_client: 
Elasticsearch, index_name: str - ) -> None: - """Test to make sure the user-agent is set correctly.""" - user_agent = "this is THE user_agent!" - store = VectorStore( - user_agent=user_agent, - index_name=index_name, - retrieval_strategy=BM25(), - es_client=requests_saving_client, - ) - - assert store.es_client._headers["User-Agent"] == user_agent - - texts = ["foo", "bob", "baz"] - store.add_texts(texts) - - transport = cast(RequestSavingTransport, store.es_client.transport) - - for request in transport.requests: - assert request["headers"]["User-Agent"] == user_agent - - def test_bulk_args(self, requests_saving_client: Any, index_name: str) -> None: - """Test to make sure the bulk arguments work as expected.""" - store = VectorStore( - user_agent="test", - index_name=index_name, - retrieval_strategy=BM25(), - es_client=requests_saving_client, - ) - - texts = ["foo", "bob", "baz"] - store.add_texts(texts, bulk_kwargs={"chunk_size": 1}) - - # 1 for index exist, 1 for index create, 3 to index docs - assert len(store.es_client.transport.requests) == 5 # type: ignore - - def test_max_marginal_relevance_search( - self, es_client: Elasticsearch, index_name: str - ) -> None: - """Test max marginal relevance search.""" - texts = ["foo", "bar", "baz"] - vector_field = "vector_field" - text_field = "text_field" - embedding_service = ConsistentFakeEmbeddings() - store = VectorStore( - user_agent="test", - index_name=index_name, - retrieval_strategy=DenseVectorScriptScore(), - embedding_service=embedding_service, - vector_field=vector_field, - text_field=text_field, - es_client=es_client, - ) - store.add_texts(texts) - - mmr_output = store.max_marginal_relevance_search( - embedding_service, - texts[0], - vector_field=vector_field, - k=3, - num_candidates=3, - ) - sim_output = store.search(texts[0], k=3) - assert mmr_output == sim_output - - mmr_output = store.max_marginal_relevance_search( - embedding_service, - texts[0], - vector_field=vector_field, - k=2, - 
num_candidates=3, - ) - assert len(mmr_output) == 2 - assert mmr_output[0]["_source"][text_field] == texts[0] - assert mmr_output[1]["_source"][text_field] == texts[1] - - mmr_output = store.max_marginal_relevance_search( - embedding_service, - texts[0], - vector_field=vector_field, - k=2, - num_candidates=3, - lambda_mult=0.1, # more diversity - ) - assert len(mmr_output) == 2 - assert mmr_output[0]["_source"][text_field] == texts[0] - assert mmr_output[1]["_source"][text_field] == texts[2] - - # if fetch_k < k, then the output will be less than k - mmr_output = store.max_marginal_relevance_search( - embedding_service, - texts[0], - vector_field=vector_field, - k=3, - num_candidates=2, - ) - assert len(mmr_output) == 2 - - def test_metadata_mapping(self, es_client: Elasticsearch, index_name: str) -> None: - """Test that the metadata mapping is applied.""" - test_mappings = { - "my_field": {"type": "keyword"}, - "another_field": {"type": "text"}, - } - store = VectorStore( - user_agent="test", - index_name=index_name, - retrieval_strategy=DenseVector(), - embedding_service=FakeEmbeddings(), - es_client=es_client, - metadata_mappings=test_mappings, - ) - - texts = ["foo", "foo", "foo"] - metadatas = [{"page": i} for i in range(len(texts))] - store.add_texts(texts=texts, metadatas=metadatas) - - mapping_response = es_client.indices.get_mapping(index=index_name) - mapping_properties = mapping_response[index_name]["mappings"]["properties"] - print(mapping_response) - assert "metadata" in mapping_properties - for key, val in test_mappings.items(): - assert mapping_properties["metadata"]["properties"][key] == val diff --git a/test_elasticsearch/test_server/test_vectorstore/_sync/test_vectorstore.py b/test_elasticsearch/test_server/test_vectorstore/_sync/test_vectorstore.py index 60d47bb4f..7251cd399 100644 --- a/test_elasticsearch/test_server/test_vectorstore/_sync/test_vectorstore.py +++ b/test_elasticsearch/test_server/test_vectorstore/_sync/test_vectorstore.py @@ 
-16,9 +16,8 @@ # under the License. import logging -import uuid from functools import partial -from typing import Any, Iterator, List, Optional, Union, cast +from typing import Any, List, Optional, Union, cast import pytest @@ -38,8 +37,6 @@ ConsistentFakeEmbeddings, FakeEmbeddings, RequestSavingTransport, - create_requests_saving_client, - es_client_fixture, ) logging.basicConfig(level=logging.DEBUG) @@ -66,24 +63,6 @@ class TestVectorStore: - @pytest.fixture - def es_client(self) -> Iterator[Elasticsearch]: - for x in es_client_fixture(): - yield x - - @pytest.fixture - def requests_saving_client(self) -> Iterator[Elasticsearch]: - client = create_requests_saving_client() - try: - yield client - finally: - client.close() - - @pytest.fixture(scope="function") - def index_name(self) -> str: - """Return the index name.""" - return f"test_{uuid.uuid4().hex}" - def test_search_without_metadata( self, es_client: Elasticsearch, index_name: str ) -> None: From 2f1fcb073022d2c59c6cfc09fcbec2652faa1fe9 Mon Sep 17 00:00:00 2001 From: Max Jakob Date: Tue, 23 Apr 2024 15:33:42 +0200 Subject: [PATCH 13/36] export relevant classes --- elasticsearch/vectorstore/_async/__init__.py | 20 ++++++++++++++++++++ elasticsearch/vectorstore/_sync/__init__.py | 20 ++++++++++++++++++++ 2 files changed, 40 insertions(+) diff --git a/elasticsearch/vectorstore/_async/__init__.py b/elasticsearch/vectorstore/_async/__init__.py index a16a14343..94071c177 100644 --- a/elasticsearch/vectorstore/_async/__init__.py +++ b/elasticsearch/vectorstore/_async/__init__.py @@ -15,8 +15,28 @@ # specific language governing permissions and limitations # under the License. 
+from elasticsearch.vectorstore._async.embedding_service import ( + AsyncElasticsearchEmbeddings, + AsyncEmbeddingService, +) +from elasticsearch.vectorstore._async.strategies import ( + BM25, + DenseVector, + DenseVectorScriptScore, + DistanceMetric, + RetrievalStrategy, + SparseVector, +) from elasticsearch.vectorstore._async.vectorstore import AsyncVectorStore __all__ = [ + "AsyncEmbeddingService", + "AsyncElasticsearchEmbeddings", "AsyncVectorStore", + "BM25", + "DenseVector", + "DenseVectorScriptScore", + "DistanceMetric", + "RetrievalStrategy", + "SparseVector", ] diff --git a/elasticsearch/vectorstore/_sync/__init__.py b/elasticsearch/vectorstore/_sync/__init__.py index fadcc6b9f..fa7981c82 100644 --- a/elasticsearch/vectorstore/_sync/__init__.py +++ b/elasticsearch/vectorstore/_sync/__init__.py @@ -15,8 +15,28 @@ # specific language governing permissions and limitations # under the License. +from elasticsearch.vectorstore._sync.embedding_service import ( + ElasticsearchEmbeddings, + EmbeddingService, +) +from elasticsearch.vectorstore._sync.strategies import ( + BM25, + DenseVector, + DenseVectorScriptScore, + DistanceMetric, + RetrievalStrategy, + SparseVector, +) from elasticsearch.vectorstore._sync.vectorstore import VectorStore __all__ = [ + "EmbeddingService", + "ElasticsearchEmbeddings", "VectorStore", + "BM25", + "DenseVector", + "DenseVectorScriptScore", + "DistanceMetric", + "RetrievalStrategy", + "SparseVector", ] From b19de27395da7698744889345aebded2910047e9 Mon Sep 17 00:00:00 2001 From: Max Jakob Date: Tue, 23 Apr 2024 15:35:04 +0200 Subject: [PATCH 14/36] remove Semantic strategy wait for `semantic_text` to land --- .../vectorstore/_async/strategies.py | 60 ------------------- elasticsearch/vectorstore/_sync/strategies.py | 60 ------------------- .../_async/test_vectorstore.py | 11 ++-- .../_sync/test_vectorstore.py | 11 ++-- 4 files changed, 10 insertions(+), 132 deletions(-) diff --git a/elasticsearch/vectorstore/_async/strategies.py 
b/elasticsearch/vectorstore/_async/strategies.py index 238855247..54cbe106b 100644 --- a/elasticsearch/vectorstore/_async/strategies.py +++ b/elasticsearch/vectorstore/_async/strategies.py @@ -98,66 +98,6 @@ def needs_inference(self) -> bool: return False -class Semantic(RetrievalStrategy): - """Dense or sparse retrieval with in-stack inference using semantic_text fields.""" - - def __init__( - self, - model_id: str, - text_field: str = "text_field", - ): - self.model_id = model_id - self.text_field = text_field - - async def es_query( - self, - query: Optional[str], - query_vector: Optional[List[float]], - text_field: str, - vector_field: str, - k: int, - num_candidates: int, - filter: List[Dict[str, Any]] = [], - ) -> Dict[str, Any]: - if query_vector: - raise ValueError( - "Cannot do sparse retrieval with a query_vector. " - "Inference is currently always applied in-stack." - ) - - return { - "query": { - "semantic": { - text_field: query, - }, - }, - "filter": filter, - } - - def es_mappings_settings( - self, - text_field: str, - vector_field: str, - num_dimensions: Optional[int] = None, - ) -> Tuple[Dict[str, Any], Dict[str, Any]]: - mappings = { - "properties": { - vector_field: { - "type": "semantic_text", - "model_id": self.model_id, - } - } - } - - return mappings, {} - - async def before_index_creation( - self, client: AsyncElasticsearch, text_field: str, vector_field: str - ) -> None: - if self.model_id: - await model_must_be_deployed(client, self.model_id) - - class SparseVector(RetrievalStrategy): """Sparse retrieval strategy using the `text_expansion` processor.""" diff --git a/elasticsearch/vectorstore/_sync/strategies.py b/elasticsearch/vectorstore/_sync/strategies.py index a10222f12..20558cf0a 100644 --- a/elasticsearch/vectorstore/_sync/strategies.py +++ b/elasticsearch/vectorstore/_sync/strategies.py @@ -98,66 +98,6 @@ def needs_inference(self) -> bool: return False -class Semantic(RetrievalStrategy): - """Dense or sparse retrieval with in-stack 
inference using semantic_text fields.""" - - def __init__( - self, - model_id: str, - text_field: str = "text_field", - ): - self.model_id = model_id - self.text_field = text_field - - def es_query( - self, - query: Optional[str], - query_vector: Optional[List[float]], - text_field: str, - vector_field: str, - k: int, - num_candidates: int, - filter: List[Dict[str, Any]] = [], - ) -> Dict[str, Any]: - if query_vector: - raise ValueError( - "Cannot do sparse retrieval with a query_vector. " - "Inference is currently always applied in-stack." - ) - - return { - "query": { - "semantic": { - text_field: query, - }, - }, - "filter": filter, - } - - def es_mappings_settings( - self, - text_field: str, - vector_field: str, - num_dimensions: Optional[int] = None, - ) -> Tuple[Dict[str, Any], Dict[str, Any]]: - mappings = { - "properties": { - vector_field: { - "type": "semantic_text", - "model_id": self.model_id, - } - } - } - - return mappings, {} - - def before_index_creation( - self, client: Elasticsearch, text_field: str, vector_field: str - ) -> None: - if self.model_id: - model_must_be_deployed(client, self.model_id) - - class SparseVector(RetrievalStrategy): """Sparse retrieval strategy using the `text_expansion` processor.""" diff --git a/test_elasticsearch/test_server/test_vectorstore/_async/test_vectorstore.py b/test_elasticsearch/test_server/test_vectorstore/_async/test_vectorstore.py index 089ae8e63..68f4c2e9e 100644 --- a/test_elasticsearch/test_server/test_vectorstore/_async/test_vectorstore.py +++ b/test_elasticsearch/test_server/test_vectorstore/_async/test_vectorstore.py @@ -30,7 +30,7 @@ DenseVector, DenseVectorScriptScore, DistanceMetric, - Semantic, + SparseVector, ) from ._test_utils import ( @@ -564,9 +564,8 @@ async def test_search_with_knn_infer_instack( store = AsyncVectorStore( user_agent="test", index_name=index_name, - retrieval_strategy=Semantic( - model_id="sentence-transformers__all-minilm-l6-v2", - text_field=text_field, + 
retrieval_strategy=DenseVector( + model_id="sentence-transformers__all-minilm-l6-v2" ), es_client=es_client, ) @@ -656,7 +655,7 @@ async def test_search_with_sparse_infer_instack( store = AsyncVectorStore( user_agent="test", index_name=index_name, - retrieval_strategy=Semantic(model_id=ELSER_MODEL_ID), + retrieval_strategy=SparseVector(model_id=ELSER_MODEL_ID), es_client=es_client, ) @@ -675,7 +674,7 @@ async def test_deployed_model_check_fails_semantic( store = AsyncVectorStore( user_agent="test", index_name=index_name, - retrieval_strategy=Semantic(model_id="non-existing model ID"), + retrieval_strategy=DenseVector(model_id="non-existing model ID"), es_client=es_client, ) await store.add_texts(["foo", "bar", "baz"]) diff --git a/test_elasticsearch/test_server/test_vectorstore/_sync/test_vectorstore.py b/test_elasticsearch/test_server/test_vectorstore/_sync/test_vectorstore.py index 7251cd399..12282ac41 100644 --- a/test_elasticsearch/test_server/test_vectorstore/_sync/test_vectorstore.py +++ b/test_elasticsearch/test_server/test_vectorstore/_sync/test_vectorstore.py @@ -30,7 +30,7 @@ DenseVector, DenseVectorScriptScore, DistanceMetric, - Semantic, + SparseVector, ) from ._test_utils import ( @@ -548,9 +548,8 @@ def test_search_with_knn_infer_instack( store = VectorStore( user_agent="test", index_name=index_name, - retrieval_strategy=Semantic( - model_id="sentence-transformers__all-minilm-l6-v2", - text_field=text_field, + retrieval_strategy=DenseVector( + model_id="sentence-transformers__all-minilm-l6-v2" ), es_client=es_client, ) @@ -639,7 +638,7 @@ def test_search_with_sparse_infer_instack( store = VectorStore( user_agent="test", index_name=index_name, - retrieval_strategy=Semantic(model_id=ELSER_MODEL_ID), + retrieval_strategy=SparseVector(model_id=ELSER_MODEL_ID), es_client=es_client, ) @@ -657,7 +656,7 @@ def test_deployed_model_check_fails_semantic( store = VectorStore( user_agent="test", index_name=index_name, - 
retrieval_strategy=Semantic(model_id="non-existing model ID"), + retrieval_strategy=DenseVector(model_id="non-existing model ID"), es_client=es_client, ) store.add_texts(["foo", "bar", "baz"]) From 274911adaf79bca65feb1814ebe9ab336a09b69e Mon Sep 17 00:00:00 2001 From: Max Jakob Date: Tue, 23 Apr 2024 15:53:03 +0200 Subject: [PATCH 15/36] es_query is sync --- elasticsearch/vectorstore/_async/strategies.py | 10 +++++----- elasticsearch/vectorstore/_async/vectorstore.py | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/elasticsearch/vectorstore/_async/strategies.py b/elasticsearch/vectorstore/_async/strategies.py index 54cbe106b..2d54720b8 100644 --- a/elasticsearch/vectorstore/_async/strategies.py +++ b/elasticsearch/vectorstore/_async/strategies.py @@ -34,7 +34,7 @@ class DistanceMetric(str, Enum): class RetrievalStrategy(ABC): @abstractmethod - async def es_query( + def es_query( self, query: Optional[str], query_vector: Optional[List[float]], @@ -106,7 +106,7 @@ def __init__(self, model_id: str = ".elser_model_2"): self._tokens_field = "tokens" self._pipeline_name = f"{self.model_id}_sparse_embedding" - async def es_query( + def es_query( self, query: Optional[str], query_vector: Optional[List[float]], @@ -207,7 +207,7 @@ def __init__( self.rrf = rrf self.text_field = text_field - async def es_query( + def es_query( self, query: Optional[str], query_vector: Optional[List[float]], @@ -319,7 +319,7 @@ class DenseVectorScriptScore(RetrievalStrategy): def __init__(self, distance: DistanceMetric = DistanceMetric.COSINE) -> None: self.distance = distance - async def es_query( + def es_query( self, query: Optional[str], query_vector: Optional[List[float]], @@ -401,7 +401,7 @@ def __init__( self.k1 = k1 self.b = b - async def es_query( + def es_query( self, query: Optional[str], query_vector: Optional[List[float]], diff --git a/elasticsearch/vectorstore/_async/vectorstore.py b/elasticsearch/vectorstore/_async/vectorstore.py index f9e88b43a..0d5dc65b7 
100644 --- a/elasticsearch/vectorstore/_async/vectorstore.py +++ b/elasticsearch/vectorstore/_async/vectorstore.py @@ -250,7 +250,7 @@ async def search( raise ValueError("specify a query or a query_vector to search") query_vector = await self.embedding_service.embed_query(query) - query_body = await self.retrieval_strategy.es_query( + query_body = self.retrieval_strategy.es_query( query=query, query_vector=query_vector, text_field=self.text_field, From 8cec9cc9645cd37508ba3a4bddbd3766fe797d52 Mon Sep 17 00:00:00 2001 From: Max Jakob Date: Tue, 23 Apr 2024 16:12:51 +0200 Subject: [PATCH 16/36] async strategies --- elasticsearch/vectorstore/__init__.py | 46 ++++++++++++++ elasticsearch/vectorstore/_async/__init__.py | 26 -------- .../vectorstore/_async/strategies.py | 21 ++----- .../vectorstore/_async/vectorstore.py | 4 +- elasticsearch/vectorstore/_sync/__init__.py | 26 -------- elasticsearch/vectorstore/_sync/strategies.py | 11 +--- elasticsearch/vectorstore/_utils.py | 10 ++++ .../_async/test_vectorstore.py | 60 +++++++++---------- .../_sync/test_vectorstore.py | 6 +- utils/run-unasync.py | 18 ++++-- 10 files changed, 112 insertions(+), 116 deletions(-) diff --git a/elasticsearch/vectorstore/__init__.py b/elasticsearch/vectorstore/__init__.py index 2a87d183f..a53bde47f 100644 --- a/elasticsearch/vectorstore/__init__.py +++ b/elasticsearch/vectorstore/__init__.py @@ -14,3 +14,49 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+ +from elasticsearch.vectorstore._async.embedding_service import ( + AsyncElasticsearchEmbeddings, + AsyncEmbeddingService, +) +from elasticsearch.vectorstore._async.strategies import ( + AsyncBM25, + AsyncDenseVector, + AsyncDenseVectorScriptScore, + AsyncRetrievalStrategy, + AsyncSparseVector, +) +from elasticsearch.vectorstore._async.vectorstore import AsyncVectorStore +from elasticsearch.vectorstore._sync.embedding_service import ( + ElasticsearchEmbeddings, + EmbeddingService, +) +from elasticsearch.vectorstore._sync.strategies import ( + BM25, + DenseVector, + DenseVectorScriptScore, + RetrievalStrategy, + SparseVector, +) +from elasticsearch.vectorstore._sync.vectorstore import VectorStore +from elasticsearch.vectorstore._utils import DistanceMetric + +__all__ = [ + "BM25", + "DenseVector", + "DenseVectorScriptScore", + "ElasticsearchEmbeddings", + "EmbeddingService", + "RetrievalStrategy", + "SparseVector", + "VectorStore", + "AsyncBM25", + "AsyncDenseVector", + "AsyncDenseVectorScriptScore", + "AsyncElasticsearchEmbeddings", + "AsyncEmbeddingService", + "AsyncRetrievalStrategy", + "AsyncSparseVector", + "AsyncVectorStore", + "DistanceMetric", +] diff --git a/elasticsearch/vectorstore/_async/__init__.py b/elasticsearch/vectorstore/_async/__init__.py index 94071c177..2a87d183f 100644 --- a/elasticsearch/vectorstore/_async/__init__.py +++ b/elasticsearch/vectorstore/_async/__init__.py @@ -14,29 +14,3 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
- -from elasticsearch.vectorstore._async.embedding_service import ( - AsyncElasticsearchEmbeddings, - AsyncEmbeddingService, -) -from elasticsearch.vectorstore._async.strategies import ( - BM25, - DenseVector, - DenseVectorScriptScore, - DistanceMetric, - RetrievalStrategy, - SparseVector, -) -from elasticsearch.vectorstore._async.vectorstore import AsyncVectorStore - -__all__ = [ - "AsyncEmbeddingService", - "AsyncElasticsearchEmbeddings", - "AsyncVectorStore", - "BM25", - "DenseVector", - "DenseVectorScriptScore", - "DistanceMetric", - "RetrievalStrategy", - "SparseVector", -] diff --git a/elasticsearch/vectorstore/_async/strategies.py b/elasticsearch/vectorstore/_async/strategies.py index 2d54720b8..f10d664c0 100644 --- a/elasticsearch/vectorstore/_async/strategies.py +++ b/elasticsearch/vectorstore/_async/strategies.py @@ -16,23 +16,14 @@ # under the License. from abc import ABC, abstractmethod -from enum import Enum from typing import Any, Dict, List, Optional, Tuple, Union, cast from elasticsearch import AsyncElasticsearch from elasticsearch.vectorstore._async._utils import model_must_be_deployed +from elasticsearch.vectorstore._utils import DistanceMetric -class DistanceMetric(str, Enum): - """Enumerator of all Elasticsearch dense vector distance metrics.""" - - COSINE = "COSINE" - DOT_PRODUCT = "DOT_PRODUCT" - EUCLIDEAN_DISTANCE = "EUCLIDEAN_DISTANCE" - MAX_INNER_PRODUCT = "MAX_INNER_PRODUCT" - - -class RetrievalStrategy(ABC): +class AsyncRetrievalStrategy(ABC): @abstractmethod def es_query( self, @@ -98,7 +89,7 @@ def needs_inference(self) -> bool: return False -class SparseVector(RetrievalStrategy): +class AsyncSparseVector(AsyncRetrievalStrategy): """Sparse retrieval strategy using the `text_expansion` processor.""" def __init__(self, model_id: str = ".elser_model_2"): @@ -185,7 +176,7 @@ async def before_index_creation( ) -class DenseVector(RetrievalStrategy): +class AsyncDenseVector(AsyncRetrievalStrategy): """K-nearest-neighbors retrieval.""" def 
__init__( @@ -313,7 +304,7 @@ def needs_inference(self) -> bool: return not self.model_id -class DenseVectorScriptScore(RetrievalStrategy): +class AsyncDenseVectorScriptScore(AsyncRetrievalStrategy): """Exact nearest neighbors retrieval using the `script_score` query.""" def __init__(self, distance: DistanceMetric = DistanceMetric.COSINE) -> None: @@ -392,7 +383,7 @@ def needs_inference(self) -> bool: return True -class BM25(RetrievalStrategy): +class AsyncBM25(AsyncRetrievalStrategy): def __init__( self, k1: Optional[float] = None, diff --git a/elasticsearch/vectorstore/_async/vectorstore.py b/elasticsearch/vectorstore/_async/vectorstore.py index 0d5dc65b7..87e05b10a 100644 --- a/elasticsearch/vectorstore/_async/vectorstore.py +++ b/elasticsearch/vectorstore/_async/vectorstore.py @@ -22,7 +22,7 @@ from elasticsearch import AsyncElasticsearch from elasticsearch.helpers import BulkIndexError, async_bulk from elasticsearch.vectorstore._async.embedding_service import AsyncEmbeddingService -from elasticsearch.vectorstore._async.strategies import RetrievalStrategy +from elasticsearch.vectorstore._async.strategies import AsyncRetrievalStrategy from elasticsearch.vectorstore._utils import maximal_marginal_relevance logger = logging.getLogger(__name__) @@ -43,7 +43,7 @@ def __init__( es_client: AsyncElasticsearch, user_agent: str, index_name: str, - retrieval_strategy: RetrievalStrategy, + retrieval_strategy: AsyncRetrievalStrategy, embedding_service: Optional[AsyncEmbeddingService] = None, num_dimensions: Optional[int] = None, text_field: str = "text_field", diff --git a/elasticsearch/vectorstore/_sync/__init__.py b/elasticsearch/vectorstore/_sync/__init__.py index fa7981c82..2a87d183f 100644 --- a/elasticsearch/vectorstore/_sync/__init__.py +++ b/elasticsearch/vectorstore/_sync/__init__.py @@ -14,29 +14,3 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
- -from elasticsearch.vectorstore._sync.embedding_service import ( - ElasticsearchEmbeddings, - EmbeddingService, -) -from elasticsearch.vectorstore._sync.strategies import ( - BM25, - DenseVector, - DenseVectorScriptScore, - DistanceMetric, - RetrievalStrategy, - SparseVector, -) -from elasticsearch.vectorstore._sync.vectorstore import VectorStore - -__all__ = [ - "EmbeddingService", - "ElasticsearchEmbeddings", - "VectorStore", - "BM25", - "DenseVector", - "DenseVectorScriptScore", - "DistanceMetric", - "RetrievalStrategy", - "SparseVector", -] diff --git a/elasticsearch/vectorstore/_sync/strategies.py b/elasticsearch/vectorstore/_sync/strategies.py index 20558cf0a..a9b0df01e 100644 --- a/elasticsearch/vectorstore/_sync/strategies.py +++ b/elasticsearch/vectorstore/_sync/strategies.py @@ -16,20 +16,11 @@ # under the License. from abc import ABC, abstractmethod -from enum import Enum from typing import Any, Dict, List, Optional, Tuple, Union, cast from elasticsearch import Elasticsearch from elasticsearch.vectorstore._sync._utils import model_must_be_deployed - - -class DistanceMetric(str, Enum): - """Enumerator of all Elasticsearch dense vector distance metrics.""" - - COSINE = "COSINE" - DOT_PRODUCT = "DOT_PRODUCT" - EUCLIDEAN_DISTANCE = "EUCLIDEAN_DISTANCE" - MAX_INNER_PRODUCT = "MAX_INNER_PRODUCT" +from elasticsearch.vectorstore._utils import DistanceMetric class RetrievalStrategy(ABC): diff --git a/elasticsearch/vectorstore/_utils.py b/elasticsearch/vectorstore/_utils.py index 5c72962c5..342411fb5 100644 --- a/elasticsearch/vectorstore/_utils.py +++ b/elasticsearch/vectorstore/_utils.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. 
+from enum import Enum from typing import List, Union import numpy as np @@ -22,6 +23,15 @@ Matrix = Union[List[List[float]], List[np.ndarray], np.ndarray] +class DistanceMetric(str, Enum): + """Enumerator of all Elasticsearch dense vector distance metrics.""" + + COSINE = "COSINE" + DOT_PRODUCT = "DOT_PRODUCT" + EUCLIDEAN_DISTANCE = "EUCLIDEAN_DISTANCE" + MAX_INNER_PRODUCT = "MAX_INNER_PRODUCT" + + def maximal_marginal_relevance( query_embedding: List[float], embedding_list: List[List[float]], diff --git a/test_elasticsearch/test_server/test_vectorstore/_async/test_vectorstore.py b/test_elasticsearch/test_server/test_vectorstore/_async/test_vectorstore.py index 68f4c2e9e..f53d3336f 100644 --- a/test_elasticsearch/test_server/test_vectorstore/_async/test_vectorstore.py +++ b/test_elasticsearch/test_server/test_vectorstore/_async/test_vectorstore.py @@ -23,15 +23,15 @@ from elasticsearch import AsyncElasticsearch, NotFoundError from elasticsearch.helpers import BulkIndexError -from elasticsearch.vectorstore._async import AsyncVectorStore -from elasticsearch.vectorstore._async._utils import model_is_deployed -from elasticsearch.vectorstore._async.strategies import ( - BM25, - DenseVector, - DenseVectorScriptScore, +from elasticsearch.vectorstore import ( + AsyncBM25, + AsyncDenseVector, + AsyncDenseVectorScriptScore, + AsyncSparseVector, + AsyncVectorStore, DistanceMetric, - SparseVector, ) +from elasticsearch.vectorstore._async._utils import model_is_deployed from ._test_utils import ( AsyncConsistentFakeEmbeddings, @@ -84,7 +84,7 @@ def assert_query(query_body: dict, query: Optional[str]) -> dict: store = AsyncVectorStore( user_agent="test", index_name=index_name, - retrieval_strategy=DenseVector(), + retrieval_strategy=AsyncDenseVector(), embedding_service=AsyncFakeEmbeddings(), es_client=es_client, ) @@ -103,7 +103,7 @@ async def test_search_without_metadata_async( store = AsyncVectorStore( user_agent="test", index_name=index_name, - 
retrieval_strategy=DenseVector(), + retrieval_strategy=AsyncDenseVector(), embedding_service=AsyncFakeEmbeddings(), es_client=es_client, ) @@ -135,7 +135,7 @@ async def test_add_vectors( store = AsyncVectorStore( user_agent="test", index_name=index_name, - retrieval_strategy=DenseVector(), + retrieval_strategy=AsyncDenseVector(), embedding_service=embeddings, es_client=es_client, ) @@ -155,7 +155,7 @@ async def test_search_with_metadata( store = AsyncVectorStore( user_agent="test", index_name=index_name, - retrieval_strategy=DenseVector(), + retrieval_strategy=AsyncDenseVector(), embedding_service=AsyncConsistentFakeEmbeddings(), es_client=es_client, ) @@ -180,7 +180,7 @@ async def test_search_with_filter( store = AsyncVectorStore( user_agent="test", index_name=index_name, - retrieval_strategy=DenseVector(), + retrieval_strategy=AsyncDenseVector(), embedding_service=AsyncFakeEmbeddings(), es_client=es_client, ) @@ -218,7 +218,7 @@ async def test_search_script_score( store = AsyncVectorStore( user_agent="test", index_name=index_name, - retrieval_strategy=DenseVectorScriptScore(), + retrieval_strategy=AsyncDenseVectorScriptScore(), embedding_service=AsyncFakeEmbeddings(), es_client=es_client, ) @@ -266,7 +266,7 @@ async def test_search_script_score_with_filter( store = AsyncVectorStore( user_agent="test", index_name=index_name, - retrieval_strategy=DenseVectorScriptScore(), + retrieval_strategy=AsyncDenseVectorScriptScore(), embedding_service=AsyncFakeEmbeddings(), es_client=es_client, ) @@ -320,7 +320,7 @@ async def test_search_script_score_distance_dot_product( store = AsyncVectorStore( user_agent="test", index_name=index_name, - retrieval_strategy=DenseVectorScriptScore( + retrieval_strategy=AsyncDenseVectorScriptScore( distance=DistanceMetric.DOT_PRODUCT, ), embedding_service=AsyncFakeEmbeddings(), @@ -371,7 +371,7 @@ async def test_search_knn_with_hybrid_search( store = AsyncVectorStore( user_agent="test", index_name=index_name, - 
retrieval_strategy=DenseVector(hybrid=True), + retrieval_strategy=AsyncDenseVector(hybrid=True), embedding_service=AsyncFakeEmbeddings(), es_client=es_client, ) @@ -459,7 +459,7 @@ def assert_query( store = AsyncVectorStore( user_agent="test", index_name=index_name, - retrieval_strategy=DenseVector(hybrid=True, rrf=rrf_test_case), + retrieval_strategy=AsyncDenseVector(hybrid=True, rrf=rrf_test_case), embedding_service=AsyncFakeEmbeddings(), es_client=es_client, ) @@ -500,7 +500,7 @@ def assert_query( store = AsyncVectorStore( user_agent="test", index_name=f"{index_name}_default", - retrieval_strategy=DenseVector(hybrid=True), + retrieval_strategy=AsyncDenseVector(hybrid=True), embedding_service=AsyncFakeEmbeddings(), es_client=es_client, ) @@ -523,7 +523,7 @@ async def test_search_knn_with_custom_query_fn( store = AsyncVectorStore( user_agent="test", index_name=index_name, - retrieval_strategy=DenseVector(), + retrieval_strategy=AsyncDenseVector(), embedding_service=AsyncFakeEmbeddings(), es_client=es_client, ) @@ -564,7 +564,7 @@ async def test_search_with_knn_infer_instack( store = AsyncVectorStore( user_agent="test", index_name=index_name, - retrieval_strategy=DenseVector( + retrieval_strategy=AsyncDenseVector( model_id="sentence-transformers__all-minilm-l6-v2" ), es_client=es_client, @@ -655,7 +655,7 @@ async def test_search_with_sparse_infer_instack( store = AsyncVectorStore( user_agent="test", index_name=index_name, - retrieval_strategy=SparseVector(model_id=ELSER_MODEL_ID), + retrieval_strategy=AsyncSparseVector(model_id=ELSER_MODEL_ID), es_client=es_client, ) @@ -674,7 +674,7 @@ async def test_deployed_model_check_fails_semantic( store = AsyncVectorStore( user_agent="test", index_name=index_name, - retrieval_strategy=DenseVector(model_id="non-existing model ID"), + retrieval_strategy=AsyncDenseVector(model_id="non-existing model ID"), es_client=es_client, ) await store.add_texts(["foo", "bar", "baz"]) @@ -687,7 +687,7 @@ async def test_search_bm25( store = 
AsyncVectorStore( user_agent="test", index_name=index_name, - retrieval_strategy=BM25(), + retrieval_strategy=AsyncBM25(), es_client=es_client, ) @@ -716,7 +716,7 @@ async def test_search_bm25_with_filter( store = AsyncVectorStore( user_agent="test", index_name=index_name, - retrieval_strategy=BM25(), + retrieval_strategy=AsyncBM25(), es_client=es_client, ) @@ -750,7 +750,7 @@ async def test_delete(self, es_client: AsyncElasticsearch, index_name: str) -> N store = AsyncVectorStore( user_agent="test", index_name=index_name, - retrieval_strategy=DenseVector(), + retrieval_strategy=AsyncDenseVector(), embedding_service=AsyncFakeEmbeddings(), es_client=es_client, ) @@ -789,7 +789,7 @@ async def test_indexing_exception_error( store = AsyncVectorStore( user_agent="test", index_name=index_name, - retrieval_strategy=BM25(), + retrieval_strategy=AsyncBM25(), es_client=es_client, ) @@ -818,7 +818,7 @@ async def test_user_agent( store = AsyncVectorStore( user_agent=user_agent, index_name=index_name, - retrieval_strategy=BM25(), + retrieval_strategy=AsyncBM25(), es_client=requests_saving_client, ) @@ -840,7 +840,7 @@ async def test_bulk_args( store = AsyncVectorStore( user_agent="test", index_name=index_name, - retrieval_strategy=BM25(), + retrieval_strategy=AsyncBM25(), es_client=requests_saving_client, ) @@ -862,7 +862,7 @@ async def test_max_marginal_relevance_search( store = AsyncVectorStore( user_agent="test", index_name=index_name, - retrieval_strategy=DenseVectorScriptScore(), + retrieval_strategy=AsyncDenseVectorScriptScore(), embedding_service=embedding_service, vector_field=vector_field, text_field=text_field, @@ -925,7 +925,7 @@ async def test_metadata_mapping( store = AsyncVectorStore( user_agent="test", index_name=index_name, - retrieval_strategy=DenseVector(), + retrieval_strategy=AsyncDenseVector(), embedding_service=AsyncFakeEmbeddings(), es_client=es_client, metadata_mappings=test_mappings, diff --git 
a/test_elasticsearch/test_server/test_vectorstore/_sync/test_vectorstore.py b/test_elasticsearch/test_server/test_vectorstore/_sync/test_vectorstore.py index 12282ac41..a27afd2c4 100644 --- a/test_elasticsearch/test_server/test_vectorstore/_sync/test_vectorstore.py +++ b/test_elasticsearch/test_server/test_vectorstore/_sync/test_vectorstore.py @@ -23,15 +23,15 @@ from elasticsearch import Elasticsearch, NotFoundError from elasticsearch.helpers import BulkIndexError -from elasticsearch.vectorstore._sync import VectorStore -from elasticsearch.vectorstore._sync._utils import model_is_deployed -from elasticsearch.vectorstore._sync.strategies import ( +from elasticsearch.vectorstore import ( BM25, DenseVector, DenseVectorScriptScore, DistanceMetric, SparseVector, + VectorStore, ) +from elasticsearch.vectorstore._sync._utils import model_is_deployed from ._test_utils import ( ConsistentFakeEmbeddings, diff --git a/utils/run-unasync.py b/utils/run-unasync.py index 2d8156721..990a78517 100644 --- a/utils/run-unasync.py +++ b/utils/run-unasync.py @@ -86,13 +86,18 @@ def main(): fromdir="elasticsearch/vectorstore/_async/", todir="elasticsearch/vectorstore/_sync/", additional_replacements={ - "_async": "_sync", - "async_bulk": "bulk", + "AsyncBM25": "BM25", + "AsyncDenseVector": "DenseVector", + "AsyncDenseVectorScriptScore": "DenseVectorScriptScore", "AsyncElasticsearch": "Elasticsearch", "AsyncElasticsearchEmbeddings": "ElasticsearchEmbeddings", "AsyncEmbeddingService": "EmbeddingService", + "AsyncRetrievalStrategy": "RetrievalStrategy", + "AsyncSparseVector": "SparseVector", "AsyncTransport": "Transport", "AsyncVectorStore": "VectorStore", + "async_bulk": "bulk", + "_async": "_sync", }, ), cleanup_patterns=[ @@ -108,13 +113,18 @@ def main(): fromdir="test_elasticsearch/test_server/test_vectorstore/_async/", todir="test_elasticsearch/test_server/test_vectorstore/_sync/", additional_replacements={ - "_async": "_sync", - "async_bulk": "bulk", + "AsyncBM25": "BM25", + 
"AsyncDenseVector": "DenseVector", + "AsyncDenseVectorScriptScore": "DenseVectorScriptScore", "AsyncElasticsearch": "Elasticsearch", "AsyncElasticsearchEmbeddings": "ElasticsearchEmbeddings", "AsyncEmbeddingService": "EmbeddingService", + "AsyncRetrievalStrategy": "RetrievalStrategy", + "AsyncSparseVector": "SparseVector", "AsyncTransport": "Transport", "AsyncVectorStore": "VectorStore", + "async_bulk": "bulk", + "_async": "_sync", # Tests-specific "AsyncConsistentFakeEmbeddings": "ConsistentFakeEmbeddings", "AsyncFakeEmbeddings": "FakeEmbeddings", From bbf2be94a2b65e00fac1e826ce8567e46da5731d Mon Sep 17 00:00:00 2001 From: Max Jakob Date: Wed, 24 Apr 2024 10:14:45 +0200 Subject: [PATCH 17/36] cleanup old file --- .../vectorstore/_sync/vectorestore.py | 381 ------------------ 1 file changed, 381 deletions(-) delete mode 100644 elasticsearch/vectorstore/_sync/vectorestore.py diff --git a/elasticsearch/vectorstore/_sync/vectorestore.py b/elasticsearch/vectorstore/_sync/vectorestore.py deleted file mode 100644 index 862932eef..000000000 --- a/elasticsearch/vectorstore/_sync/vectorestore.py +++ /dev/null @@ -1,381 +0,0 @@ -# Licensed to Elasticsearch B.V. under one or more contributor -# license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright -# ownership. Elasticsearch B.V. licenses this file to you under -# the Apache License, Version 2.0 (the "License"); you may -# not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -import logging -import uuid -from typing import Any, Callable, Dict, List, Optional - -from elasticsearch import Elasticsearch -from elasticsearch.helpers import BulkIndexError, bulk -from elasticsearch.vectorstore._sync.embedding_service import EmbeddingService -from elasticsearch.vectorstore._sync.strategies import RetrievalStrategy -from elasticsearch.vectorstore._utils import maximal_marginal_relevance - -logger = logging.getLogger(__name__) - - -class VectorStore: - """VectorStore is a higher-level abstraction of indexing and search. - Users can pick from available retrieval strategies. - - Documents are flat text documents. Depending on the strategy, vector embeddings are - - created by the user beforehand - - created by this class in Python - - created in-stack by inference pipelines. - """ - - def __init__( - self, - es_client: Elasticsearch, - user_agent: str, - index_name: str, - retrieval_strategy: RetrievalStrategy, - embedding_service: Optional[EmbeddingService] = None, - num_dimensions: Optional[int] = None, - text_field: str = "text_field", - vector_field: str = "vector_field", - metadata_mappings: Optional[Dict[str, Any]] = None, - ) -> None: - """ - Args: - user_header: user agent header specific to the 3rd party integration. - Used for usage tracking in Elastic Cloud. - index_name: The name of the index to query. - retrieval_strategy: how to index and search the data. See the strategies - module for availble strategies. - text_field: Name of the field with the textual data. - vector_field: For strategies that perform embedding inference in Python, - the embedding vector goes in this field. - es_client: Elasticsearch client connection. Alternatively specify the - Elasticsearch connection with the other es_* parameters. - """ - # Add integration-specific usage header for tracking usage in Elastic Cloud. - # client.options preserces existing (non-user-agent) headers. 
- es_client = es_client.options(headers={"User-Agent": user_agent}) - - if hasattr(retrieval_strategy, "text_field"): - retrieval_strategy.text_field = text_field - if hasattr(retrieval_strategy, "vector_field"): - retrieval_strategy.vector_field = vector_field - - self.es_client = es_client - self.index_name = index_name - self.retrieval_strategy = retrieval_strategy - self.embedding_service = embedding_service - self.num_dimensions = num_dimensions - self.text_field = text_field - self.vector_field = vector_field - self.metadata_mappings = metadata_mappings - - def close(self) -> None: - return self.es_client.close() - - def add_texts( - self, - texts: List[str], - metadatas: Optional[List[Dict[str, Any]]] = None, - vectors: Optional[List[List[float]]] = None, - ids: Optional[List[str]] = None, - refresh_indices: bool = True, - create_index_if_not_exists: bool = True, - bulk_kwargs: Optional[Dict[str, Any]] = None, - ) -> List[str]: - """Add documents to the Elasticsearch index. - - Args: - texts: List of text documents. - metadata: Optional list of document metadata. Must be of same length as - texts. - vectors: Optional list of embedding vectors. Must be of same length as - texts. - ids: Optional list of ID strings. Must be of same length as texts. - refresh_indices: Whether to refresh the index after deleting documents. - Defaults to True. - create_index_if_not_exists: Whether to create the index if it does not - exist. Defaults to True. - bulk_kwargs: Arguments to pass to the bulk function when indexing - (for example chunk_size). - - Returns: - List of IDs of the created documents, either echoing the provided one - or returning newly created ones. 
- """ - bulk_kwargs = bulk_kwargs or {} - ids = ids or [str(uuid.uuid4()) for _ in texts] - requests = [] - - if create_index_if_not_exists: - self._create_index_if_not_exists() - - if self.embedding_service and not vectors: - vectors = self.embedding_service.embed_documents(texts) - - for i, text in enumerate(texts): - metadata = metadatas[i] if metadatas else {} - - request: Dict[str, Any] = { - "_op_type": "index", - "_index": self.index_name, - self.text_field: text, - "metadata": metadata, - "_id": ids[i], - } - - if vectors: - request[self.vector_field] = vectors[i] - - requests.append(request) - - if len(requests) > 0: - try: - success, failed = bulk( - self.es_client, - requests, - stats_only=True, - refresh=refresh_indices, - **bulk_kwargs, - ) - logger.debug(f"added texts {ids} to index") - return ids - except BulkIndexError as e: - logger.error(f"Error adding texts: {e}") - firstError = e.errors[0].get("index", {}).get("error", {}) - logger.error(f"First error reason: {firstError.get('reason')}") - raise e - - else: - logger.debug("No texts to add to index") - return [] - - def delete( # type: ignore[no-untyped-def] - self, - ids: Optional[List[str]] = None, - query: Optional[Dict[str, Any]] = None, - refresh_indices: bool = True, - **delete_kwargs, - ) -> bool: - """Delete documents from the Elasticsearch index. - - Args: - ids: List of IDs of documents to delete. - refresh_indices: Whether to refresh the index after deleting documents. - Defaults to True. 
- """ - if ids is not None and query is not None: - raise ValueError("one of ids or query must be specified") - elif ids is None and query is None: - raise ValueError("either specify ids or query") - - try: - if ids: - body = [ - {"_op_type": "delete", "_index": self.index_name, "_id": _id} - for _id in ids - ] - bulk( - self.es_client, - body, - refresh=refresh_indices, - ignore_status=404, - **delete_kwargs, - ) - logger.debug(f"Deleted {len(body)} texts from index") - - else: - self.es_client.delete_by_query( - index=self.index_name, - query=query, - refresh=refresh_indices, - **delete_kwargs, - ) - - except BulkIndexError as e: - logger.error(f"Error deleting texts: {e}") - firstError = e.errors[0].get("index", {}).get("error", {}) - logger.error(f"First error reason: {firstError.get('reason')}") - raise e - - return True - - def search( - self, - query: Optional[str], - query_vector: Optional[List[float]] = None, - k: int = 4, - num_candidates: int = 50, - fields: Optional[List[str]] = None, - filter: Optional[List[Dict[str, Any]]] = None, - custom_query: Optional[ - Callable[[Dict[str, Any], Optional[str]], Dict[str, Any]] - ] = None, - ) -> List[Dict[str, Any]]: - """ - Args: - query: Input query string. - query_vector: Input embedding vector. If given, input query string is - ignored. - k: Number of returned results. - num_candidates: Number of candidates to fetch from data nodes in knn. - fields: List of field names to return. - filter: Elasticsearch filters to apply. - custom_query: Function to modify the Elasticsearch query body before it is - sent to Elasticsearch. - - Returns: - List of document hits. Includes _index, _id, _score and _source. 
- """ - if fields is None: - fields = [] - if "metadata" not in fields: - fields.append("metadata") - if self.text_field not in fields: - fields.append(self.text_field) - - if self.embedding_service and not query_vector: - if not query: - raise ValueError("specify a query or a query_vector to search") - query_vector = self.embedding_service.embed_query(query) - - query_body = self.retrieval_strategy.es_query( - query=query, - query_vector=query_vector, - text_field=self.text_field, - vector_field=self.vector_field, - k=k, - num_candidates=num_candidates, - filter=filter or [], - ) - - if custom_query is not None: - query_body = custom_query(query_body, query) - logger.debug(f"Calling custom_query, Query body now: {query_body}") - - response = self.es_client.search( - index=self.index_name, - **query_body, - size=k, - source=True, - source_includes=fields, - ) - hits: List[Dict[str, Any]] = response["hits"]["hits"] - - return hits - - def _create_index_if_not_exists(self) -> None: - exists = self.es_client.indices.exists(index=self.index_name) - if exists.meta.status == 200: - logger.debug(f"Index {self.index_name} already exists. 
Skipping creation.") - return - - if self.retrieval_strategy.needs_inference(): - if not self.num_dimensions and not self.embedding_service: - raise ValueError( - "retrieval strategy requires embeddings; either embedding_service " - "or num_dimensions need to be specified" - ) - if not self.num_dimensions and self.embedding_service: - vector = self.embedding_service.embed_query("get num dimensions") - self.num_dimensions = len(vector) - - mappings, settings = self.retrieval_strategy.es_mappings_settings( - text_field=self.text_field, - vector_field=self.vector_field, - num_dimensions=self.num_dimensions, - ) - - if self.metadata_mappings: - metadata = mappings["properties"].get("metadata", {"properties": {}}) - for key in self.metadata_mappings.keys(): - if key in metadata: - raise ValueError(f"metadata key {key} already exists in mappings") - - metadata = dict(**metadata["properties"], **self.metadata_mappings) - mappings["properties"] = {"metadata": {"properties": metadata}} - - self.retrieval_strategy.before_index_creation( - self.es_client, self.text_field, self.vector_field - ) - self.es_client.indices.create( - index=self.index_name, mappings=mappings, settings=settings - ) - - def max_marginal_relevance_search( - self, - embedding_service: EmbeddingService, - query: str, - vector_field: str, - k: int = 4, - num_candidates: int = 20, - lambda_mult: float = 0.5, - fields: Optional[List[str]] = None, - custom_query: Optional[ - Callable[[Dict[str, Any], Optional[str]], Dict[str, Any]] - ] = None, - ) -> List[Dict[str, Any]]: - """Return docs selected using the maximal marginal relevance. - - Maximal marginal relevance optimizes for similarity to query AND diversity - among selected documents. - - Args: - query (str): Text to look up documents similar to. - k (int): Number of Documents to return. Defaults to 4. - fetch_k (int): Number of Documents to fetch to pass to MMR algorithm. 
- lambda_mult (float): Number between 0 and 1 that determines the degree - of diversity among the results with 0 corresponding - to maximum diversity and 1 to minimum diversity. - Defaults to 0.5. - fields: Other fields to get from elasticsearch source. These fields - will be added to the document metadata. - - Returns: - List[Document]: A list of Documents selected by maximal marginal relevance. - """ - remove_vector_query_field_from_metadata = True - if fields is None: - fields = [vector_field] - elif vector_field not in fields: - fields.append(vector_field) - else: - remove_vector_query_field_from_metadata = False - - # Embed the query - query_embedding = embedding_service.embed_query(query) - - # Fetch the initial documents - got_hits = self.search( - query=None, - query_vector=query_embedding, - k=num_candidates, - fields=fields, - custom_query=custom_query, - ) - - # Get the embeddings for the fetched documents - got_embeddings = [hit["_source"][vector_field] for hit in got_hits] - - # Select documents using maximal marginal relevance - selected_indices = maximal_marginal_relevance( - query_embedding, got_embeddings, lambda_mult=lambda_mult, k=k - ) - selected_hits = [got_hits[i] for i in selected_indices] - - if remove_vector_query_field_from_metadata: - for hit in selected_hits: - del hit["_source"][vector_field] - - return selected_hits From 299cd94caf3efede67d32615f5911dae344cb337 Mon Sep 17 00:00:00 2001 From: Max Jakob Date: Wed, 24 Apr 2024 10:17:16 +0200 Subject: [PATCH 18/36] add docker-compose service with model deployment --- .../test_server/test_vectorstore/docker-compose.yml | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/test_elasticsearch/test_server/test_vectorstore/docker-compose.yml b/test_elasticsearch/test_server/test_vectorstore/docker-compose.yml index b0e832e37..d598fe235 100644 --- a/test_elasticsearch/test_server/test_vectorstore/docker-compose.yml +++ b/test_elasticsearch/test_server/test_vectorstore/docker-compose.yml @@ 
-18,6 +18,14 @@ services: interval: 10s retries: 60 + # Currently fails on Mac: https://github.com/elastic/elasticsearch/issues/106206 + elasticsearch-with-models: + image: docker.elastic.co/eland/eland + depends_on: + - elasticsearch + restart: no + command: sh -c "sleep 10 && eland_import_hub_model --url http://elasticsearch:9200 --hub-model-id sentence-transformers/msmarco-minilm-l-12-v3 --start" + kibana: image: kibana:8.13.0 environment: From 5f0d98d0166d5ff31550cc60f576fc2272e5ffce Mon Sep 17 00:00:00 2001 From: Max Jakob Date: Wed, 24 Apr 2024 11:07:54 +0200 Subject: [PATCH 19/36] optional dependencies for MMR --- dev-requirements.txt | 5 ++- elasticsearch/vectorstore/_utils.py | 54 ++++++++++++++++++----------- setup.py | 1 + 3 files changed, 39 insertions(+), 21 deletions(-) diff --git a/dev-requirements.txt b/dev-requirements.txt index c42af4eab..330cb2701 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -15,10 +15,13 @@ twine build nox -numpy pandas orjson +# mmr for vectorstore +numpy +simsimd + # Testing the 'search_mvt' API response mapbox-vector-tile # Python 3.7 gets an old version of mapbox-vector-tile, requiring an diff --git a/elasticsearch/vectorstore/_utils.py b/elasticsearch/vectorstore/_utils.py index 342411fb5..a9b29623a 100644 --- a/elasticsearch/vectorstore/_utils.py +++ b/elasticsearch/vectorstore/_utils.py @@ -16,11 +16,12 @@ # under the License. 
from enum import Enum -from typing import List, Union +from typing import TYPE_CHECKING, List, Union -import numpy as np +if TYPE_CHECKING: + import numpy as np -Matrix = Union[List[List[float]], List[np.ndarray], np.ndarray] +Matrix = Union[List[List[float]], List["np.ndarray"], "np.ndarray"] class DistanceMetric(str, Enum): @@ -39,6 +40,12 @@ def maximal_marginal_relevance( k: int = 4, ) -> List[int]: """Calculate maximal marginal relevance.""" + + try: + import numpy as np + except ModuleNotFoundError as e: + _raise_missing_mmr_deps_error(e) + query_embedding_arr = np.array(query_embedding) if min(k, len(embedding_list)) <= 0: @@ -68,8 +75,15 @@ def maximal_marginal_relevance( return idxs -def _cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray: +def _cosine_similarity(X: Matrix, Y: Matrix) -> "np.ndarray": """Row-wise cosine similarity between two equal-width matrices.""" + + try: + import numpy as np + import simsimd as simd + except ModuleNotFoundError as e: + _raise_missing_mmr_deps_error(e) + if len(X) == 0 or len(Y) == 0: return np.array([]) @@ -80,20 +94,20 @@ def _cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray: f"Number of columns in X and Y must be the same. X has shape {X.shape} " f"and Y has shape {Y.shape}." ) - try: - import simsimd as simd - X = np.array(X, dtype=np.float32) - Y = np.array(Y, dtype=np.float32) - Z = 1 - simd.cdist(X, Y, metric="cosine") - if isinstance(Z, float): - return np.array([Z]) - return np.array(Z) - except ImportError: - X_norm = np.linalg.norm(X, axis=1) - Y_norm = np.linalg.norm(Y, axis=1) - # Ignore divide by zero errors run time warnings as those are handled below. 
- with np.errstate(divide="ignore", invalid="ignore"): - similarity = np.dot(X, Y.T) / np.outer(X_norm, Y_norm) - similarity[np.isnan(similarity) | np.isinf(similarity)] = 0.0 - return similarity + X = np.array(X, dtype=np.float32) + Y = np.array(Y, dtype=np.float32) + Z = 1 - np.array(simd.cdist(X, Y, metric="cosine")) + if isinstance(Z, float): + return np.array([Z]) + return np.array(Z) + + +def _raise_missing_mmr_deps_error(parent_error: ModuleNotFoundError) -> None: + import sys + + raise ModuleNotFoundError( + f"Failed to compute maximal marginal relevance because the required " + f"module '{parent_error.name}' is missing. You can install it by running: " + f"'{sys.executable} -m pip install elasticsearch[mmr]'" + ) from parent_error diff --git a/setup.py b/setup.py index dc592dcc4..775a7b319 100644 --- a/setup.py +++ b/setup.py @@ -92,5 +92,6 @@ "requests": ["requests>=2.4.0, <3.0.0"], "async": ["aiohttp>=3,<4"], "orjson": ["orjson>=3"], + "mmr": ["numpy>=1", "simsimd>=3"], }, ) From 58c8b7df6d20246ef8c144307e581ff930047042 Mon Sep 17 00:00:00 2001 From: Max Jakob Date: Wed, 24 Apr 2024 12:06:33 +0200 Subject: [PATCH 20/36] only test sync parts --- test_elasticsearch/test_server/conftest.py | 6 +- .../test_vectorstore/{_async => }/__init__.py | 0 .../test_vectorstore/_async/_test_utils.py | 87 -- .../test_vectorstore/_async/conftest.py | 103 -- .../_async/test_embedding_service.py | 64 -- .../_async/test_vectorstore.py | 943 ------------------ .../test_vectorstore/_sync/__init__.py | 16 - .../{_sync => }/_test_utils.py | 0 .../test_vectorstore/{_sync => }/conftest.py | 0 .../{_sync => }/test_embedding_service.py | 0 .../{_sync => }/test_vectorstore.py | 91 +- 11 files changed, 51 insertions(+), 1259 deletions(-) rename test_elasticsearch/test_server/test_vectorstore/{_async => }/__init__.py (100%) delete mode 100644 test_elasticsearch/test_server/test_vectorstore/_async/_test_utils.py delete mode 100644 
test_elasticsearch/test_server/test_vectorstore/_async/conftest.py delete mode 100644 test_elasticsearch/test_server/test_vectorstore/_async/test_embedding_service.py delete mode 100644 test_elasticsearch/test_server/test_vectorstore/_async/test_vectorstore.py delete mode 100644 test_elasticsearch/test_server/test_vectorstore/_sync/__init__.py rename test_elasticsearch/test_server/test_vectorstore/{_sync => }/_test_utils.py (100%) rename test_elasticsearch/test_server/test_vectorstore/{_sync => }/conftest.py (100%) rename test_elasticsearch/test_server/test_vectorstore/{_sync => }/test_embedding_service.py (100%) rename test_elasticsearch/test_server/test_vectorstore/{_sync => }/test_vectorstore.py (93%) diff --git a/test_elasticsearch/test_server/conftest.py b/test_elasticsearch/test_server/conftest.py index 558d0b013..474440afa 100644 --- a/test_elasticsearch/test_server/conftest.py +++ b/test_elasticsearch/test_server/conftest.py @@ -36,10 +36,14 @@ def sync_client_factory(elasticsearch_url): try: # Configure the client with certificates and optionally # an HTTP conn class depending on 'PYTHON_CONNECTION_CLASS' envvar - kw = {"ca_certs": CA_CERTS} + kw = {} if "PYTHON_CONNECTION_CLASS" in os.environ: kw["node_class"] = os.environ["PYTHON_CONNECTION_CLASS"] + # Add certificate verification if we're using HTTPS. + if elasticsearch_url.startswith("https"): + kw["ca_certs"] = CA_CERTS + # We do this little dance with the URL to force # Requests to respect 'headers: None' within rest API spec tests. 
client = elasticsearch.Elasticsearch(elasticsearch_url, **kw) diff --git a/test_elasticsearch/test_server/test_vectorstore/_async/__init__.py b/test_elasticsearch/test_server/test_vectorstore/__init__.py similarity index 100% rename from test_elasticsearch/test_server/test_vectorstore/_async/__init__.py rename to test_elasticsearch/test_server/test_vectorstore/__init__.py diff --git a/test_elasticsearch/test_server/test_vectorstore/_async/_test_utils.py b/test_elasticsearch/test_server/test_vectorstore/_async/_test_utils.py deleted file mode 100644 index 9f6522811..000000000 --- a/test_elasticsearch/test_server/test_vectorstore/_async/_test_utils.py +++ /dev/null @@ -1,87 +0,0 @@ -# Licensed to Elasticsearch B.V. under one or more contributor -# license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright -# ownership. Elasticsearch B.V. licenses this file to you under -# the Apache License, Version 2.0 (the "License"); you may -# not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -from typing import Any, Dict, List - -from elastic_transport import AsyncTransport - -from elasticsearch.vectorstore._async.embedding_service import AsyncEmbeddingService - - -class AsyncFakeEmbeddings(AsyncEmbeddingService): - """Fake embeddings functionality for testing.""" - - def __init__(self, dimensionality: int = 10) -> None: - self.dimensionality = dimensionality - - def num_dimensions(self) -> int: - return self.dimensionality - - async def embed_documents(self, texts: List[str]) -> List[List[float]]: - """Return simple embeddings. Embeddings encode each text as its index.""" - return [ - [float(1.0)] * (self.dimensionality - 1) + [float(i)] - for i in range(len(texts)) - ] - - async def embed_query(self, text: str) -> List[float]: - """Return constant query embeddings. - Embeddings are identical to embed_documents(texts)[0]. - Distance to each text will be that text's index, - as it was passed to embed_documents. - """ - return [float(1.0)] * (self.dimensionality - 1) + [float(0.0)] - - -class AsyncConsistentFakeEmbeddings(AsyncFakeEmbeddings): - """Fake embeddings which remember all the texts seen so far to return consistent - vectors for the same texts.""" - - def __init__(self, dimensionality: int = 10) -> None: - self.known_texts: List[str] = [] - self.dimensionality = dimensionality - - def num_dimensions(self) -> int: - return self.dimensionality - - async def embed_documents(self, texts: List[str]) -> List[List[float]]: - """Return consistent embeddings for each text seen so far.""" - out_vectors = [] - for text in texts: - if text not in self.known_texts: - self.known_texts.append(text) - vector = [float(1.0)] * (self.dimensionality - 1) + [ - float(self.known_texts.index(text)) - ] - out_vectors.append(vector) - return out_vectors - - async def embed_query(self, text: str) -> List[float]: - """Return consistent embeddings for the text, if seen before, or a constant - one if the text is unknown.""" - result = await self.embed_documents([text]) 
- return result[0] - - -class AsyncRequestSavingTransport(AsyncTransport): - def __init__(self, *args: Any, **kwargs: Any) -> None: - super().__init__(*args, **kwargs) - self.requests: List[Dict] = [] - - async def perform_request(self, *args, **kwargs): # type: ignore - self.requests.append(kwargs) - return await super().perform_request(*args, **kwargs) diff --git a/test_elasticsearch/test_server/test_vectorstore/_async/conftest.py b/test_elasticsearch/test_server/test_vectorstore/_async/conftest.py deleted file mode 100644 index 32b36a329..000000000 --- a/test_elasticsearch/test_server/test_vectorstore/_async/conftest.py +++ /dev/null @@ -1,103 +0,0 @@ -# Licensed to Elasticsearch B.V. under one or more contributor -# license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright -# ownership. Elasticsearch B.V. licenses this file to you under -# the Apache License, Version 2.0 (the "License"); you may -# not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -import os -import uuid -from typing import AsyncIterator, Dict - -import pytest -import pytest_asyncio - -from elasticsearch import AsyncElasticsearch - -from ._test_utils import AsyncRequestSavingTransport - - -@pytest_asyncio.fixture -async def es_client(elasticsearch_url: str) -> AsyncIterator[AsyncElasticsearch]: - client = _create_es_client(elasticsearch_url) - - yield client - - # clear indices - await _clear_test_indices(client) - - # clear all test pipelines - try: - response = await client.ingest.get_pipeline(id="test_*,*_sparse_embedding") - - for pipeline_id, _ in response.items(): - try: - await client.ingest.delete_pipeline(id=pipeline_id) - print(f"Deleted pipeline: {pipeline_id}") # noqa: T201 - except Exception as e: - print(f"Pipeline error: {e}") # noqa: T201 - - except Exception: - pass - finally: - await client.close() - - -@pytest_asyncio.fixture -async def requests_saving_client( - elasticsearch_url: str, -) -> AsyncIterator[AsyncElasticsearch]: - client = _create_es_client( - elasticsearch_url, es_kwargs={"transport_class": AsyncRequestSavingTransport} - ) - - try: - yield client - finally: - await client.close() - - -@pytest.fixture(scope="function") -def index_name() -> str: - return f"test_{uuid.uuid4().hex}" - - -async def _clear_test_indices(client: AsyncElasticsearch) -> None: - response = await client.indices.get(index="_all") - index_names = response.keys() - for index_name in index_names: - if index_name.startswith("test_"): - await client.indices.delete(index=index_name) - await client.indices.refresh(index="_all") - - -def _create_es_client( - elasticsearch_url: str, es_kwargs: Dict = {} -) -> AsyncElasticsearch: - if not elasticsearch_url: - elasticsearch_url = os.environ.get("ES_URL", "http://localhost:9200") - cloud_id = os.environ.get("ES_CLOUD_ID") - api_key = os.environ.get("ES_API_KEY") - - if cloud_id: - es_params = {"es_cloud_id": cloud_id, "es_api_key": api_key} - else: - es_params = {"es_url": elasticsearch_url} - - 
if "es_cloud_id" in es_params: - return AsyncElasticsearch( - cloud_id=es_params["es_cloud_id"], - api_key=es_params["es_api_key"], - **es_kwargs, - ) - return AsyncElasticsearch(hosts=[es_params["es_url"]], **es_kwargs) diff --git a/test_elasticsearch/test_server/test_vectorstore/_async/test_embedding_service.py b/test_elasticsearch/test_server/test_vectorstore/_async/test_embedding_service.py deleted file mode 100644 index d18b5ddb2..000000000 --- a/test_elasticsearch/test_server/test_vectorstore/_async/test_embedding_service.py +++ /dev/null @@ -1,64 +0,0 @@ -# Licensed to Elasticsearch B.V. under one or more contributor -# license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright -# ownership. Elasticsearch B.V. licenses this file to you under -# the Apache License, Version 2.0 (the "License"); you may -# not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -import os - -import pytest - -from elasticsearch import AsyncElasticsearch -from elasticsearch.vectorstore._async._utils import model_is_deployed -from elasticsearch.vectorstore._async.embedding_service import ( - AsyncElasticsearchEmbeddings, -) - -# deployed with -# https://www.elastic.co/guide/en/machine-learning/current/ml-nlp-text-emb-vector-search-example.html -MODEL_ID = os.getenv("MODEL_ID", "sentence-transformers__msmarco-minilm-l-12-v3") -NUM_DIMENSIONS = int(os.getenv("NUM_DIMENTIONS", "384")) - - -@pytest.mark.asyncio -async def test_elasticsearch_embedding_documents(es_client: AsyncElasticsearch) -> None: - """Test Elasticsearch embedding documents.""" - - if not await model_is_deployed(es_client, MODEL_ID): - pytest.skip(f"{MODEL_ID} model is not deployed in ML Node, skipping test") - - documents = ["foo bar", "bar foo", "foo"] - embedding = AsyncElasticsearchEmbeddings( - es_client=es_client, user_agent="test", model_id=MODEL_ID - ) - output = await embedding.embed_documents(documents) - assert len(output) == 3 - assert len(output[0]) == NUM_DIMENSIONS - assert len(output[1]) == NUM_DIMENSIONS - assert len(output[2]) == NUM_DIMENSIONS - - -@pytest.mark.asyncio -async def test_elasticsearch_embedding_query(es_client: AsyncElasticsearch) -> None: - """Test Elasticsearch embedding query.""" - - if not await model_is_deployed(es_client, MODEL_ID): - pytest.skip(f"{MODEL_ID} model is not deployed in ML Node, skipping test") - - document = "foo bar" - embedding = AsyncElasticsearchEmbeddings( - es_client=es_client, user_agent="test", model_id=MODEL_ID - ) - output = await embedding.embed_query(document) - assert len(output) == NUM_DIMENSIONS diff --git a/test_elasticsearch/test_server/test_vectorstore/_async/test_vectorstore.py b/test_elasticsearch/test_server/test_vectorstore/_async/test_vectorstore.py deleted file mode 100644 index f53d3336f..000000000 --- a/test_elasticsearch/test_server/test_vectorstore/_async/test_vectorstore.py +++ /dev/null @@ 
-1,943 +0,0 @@ -# Licensed to Elasticsearch B.V. under one or more contributor -# license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright -# ownership. Elasticsearch B.V. licenses this file to you under -# the Apache License, Version 2.0 (the "License"); you may -# not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -import logging -from functools import partial -from typing import Any, List, Optional, Union, cast - -import pytest - -from elasticsearch import AsyncElasticsearch, NotFoundError -from elasticsearch.helpers import BulkIndexError -from elasticsearch.vectorstore import ( - AsyncBM25, - AsyncDenseVector, - AsyncDenseVectorScriptScore, - AsyncSparseVector, - AsyncVectorStore, - DistanceMetric, -) -from elasticsearch.vectorstore._async._utils import model_is_deployed - -from ._test_utils import ( - AsyncConsistentFakeEmbeddings, - AsyncFakeEmbeddings, - AsyncRequestSavingTransport, -) - -logging.basicConfig(level=logging.DEBUG) - -""" -docker-compose up elasticsearch - -By default runs against local docker instance of Elasticsearch. -To run against Elastic Cloud, set the following environment variables: -- ES_CLOUD_ID -- ES_API_KEY - -Some of the tests require the following models to be deployed in the ML Node: -- elser (can be downloaded and deployed through Kibana and trained models UI) -- sentence-transformers__all-minilm-l6-v2 (can be deployed through the API, - loaded via eland) - -These tests that require the models to be deployed are skipped by default. 
-Enable them by adding the model name to the modelsDeployed list below. -""" - -ELSER_MODEL_ID = ".elser_model_2" -TRANSFORMER_MODEL_ID = "sentence-transformers__all-minilm-l6-v2" - - -class TestVectorStore: - @pytest.mark.asyncio - async def test_search_without_metadata( - self, es_client: AsyncElasticsearch, index_name: str - ) -> None: - """Test end to end construction and search without metadata.""" - - def assert_query(query_body: dict, query: Optional[str]) -> dict: - assert query_body == { - "knn": { - "field": "vector_field", - "filter": [], - "k": 1, - "num_candidates": 50, - "query_vector": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - } - } - return query_body - - store = AsyncVectorStore( - user_agent="test", - index_name=index_name, - retrieval_strategy=AsyncDenseVector(), - embedding_service=AsyncFakeEmbeddings(), - es_client=es_client, - ) - - texts = ["foo", "bar", "baz"] - await store.add_texts(texts) - - output = await store.search("foo", k=1, custom_query=assert_query) - assert [doc["_source"]["text_field"] for doc in output] == ["foo"] - - @pytest.mark.asyncio - async def test_search_without_metadata_async( - self, es_client: AsyncElasticsearch, index_name: str - ) -> None: - """Test end to end construction and search without metadata.""" - store = AsyncVectorStore( - user_agent="test", - index_name=index_name, - retrieval_strategy=AsyncDenseVector(), - embedding_service=AsyncFakeEmbeddings(), - es_client=es_client, - ) - - texts = ["foo", "bar", "baz"] - await store.add_texts(texts) - - output = await store.search("foo", k=1) - assert [doc["_source"]["text_field"] for doc in output] == ["foo"] - - @pytest.mark.asyncio - async def test_add_vectors( - self, es_client: AsyncElasticsearch, index_name: str - ) -> None: - """ - Test adding pre-built embeddings instead of using inference for the texts. - This allows you to separate the embeddings text and the page_content - for better proximity between user's question and embedded text. 
- For example, your embedding text can be a question, whereas page_content - is the answer. - """ - embeddings = AsyncConsistentFakeEmbeddings() - texts = ["foo1", "foo2", "foo3"] - metadatas = [{"page": i} for i in range(len(texts))] - - """In real use case, embedding_input can be questions for each text""" - embedding_vectors = await embeddings.embed_documents(texts) - - store = AsyncVectorStore( - user_agent="test", - index_name=index_name, - retrieval_strategy=AsyncDenseVector(), - embedding_service=embeddings, - es_client=es_client, - ) - - await store.add_texts( - texts=texts, vectors=embedding_vectors, metadatas=metadatas - ) - output = await store.search("foo1", k=1) - assert [doc["_source"]["text_field"] for doc in output] == ["foo1"] - assert [doc["_source"]["metadata"]["page"] for doc in output] == [0] - - @pytest.mark.asyncio - async def test_search_with_metadata( - self, es_client: AsyncElasticsearch, index_name: str - ) -> None: - """Test end to end construction and search with metadata.""" - store = AsyncVectorStore( - user_agent="test", - index_name=index_name, - retrieval_strategy=AsyncDenseVector(), - embedding_service=AsyncConsistentFakeEmbeddings(), - es_client=es_client, - ) - - texts = ["foo", "bar", "baz"] - metadatas = [{"page": i} for i in range(len(texts))] - await store.add_texts(texts=texts, metadatas=metadatas) - - output = await store.search("foo", k=1) - assert [doc["_source"]["text_field"] for doc in output] == ["foo"] - assert [doc["_source"]["metadata"]["page"] for doc in output] == [0] - - output = await store.search("bar", k=1) - assert [doc["_source"]["text_field"] for doc in output] == ["bar"] - assert [doc["_source"]["metadata"]["page"] for doc in output] == [1] - - @pytest.mark.asyncio - async def test_search_with_filter( - self, es_client: AsyncElasticsearch, index_name: str - ) -> None: - """Test end to end construction and search with metadata.""" - store = AsyncVectorStore( - user_agent="test", - index_name=index_name, - 
retrieval_strategy=AsyncDenseVector(), - embedding_service=AsyncFakeEmbeddings(), - es_client=es_client, - ) - - texts = ["foo", "foo", "foo"] - metadatas = [{"page": i} for i in range(len(texts))] - await store.add_texts(texts=texts, metadatas=metadatas) - - def assert_query(query_body: dict, query: Optional[str]) -> dict: - assert query_body == { - "knn": { - "field": "vector_field", - "filter": [{"term": {"metadata.page": "1"}}], - "k": 3, - "num_candidates": 50, - "query_vector": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - } - } - return query_body - - output = await store.search( - query="foo", - k=3, - filter=[{"term": {"metadata.page": "1"}}], - custom_query=assert_query, - ) - assert [doc["_source"]["text_field"] for doc in output] == ["foo"] - assert [doc["_source"]["metadata"]["page"] for doc in output] == [1] - - @pytest.mark.asyncio - async def test_search_script_score( - self, es_client: AsyncElasticsearch, index_name: str - ) -> None: - """Test end to end construction and search with metadata.""" - store = AsyncVectorStore( - user_agent="test", - index_name=index_name, - retrieval_strategy=AsyncDenseVectorScriptScore(), - embedding_service=AsyncFakeEmbeddings(), - es_client=es_client, - ) - - texts = ["foo", "bar", "baz"] - await store.add_texts(texts) - - expected_query = { - "query": { - "script_score": { - "query": {"match_all": {}}, - "script": { - "source": "cosineSimilarity(params.query_vector, 'vector_field') + 1.0", # noqa: E501 - "params": { - "query_vector": [ - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 0.0, - ] - }, - }, - } - } - } - - def assert_query(query_body: dict, query: Optional[str]) -> dict: - assert query_body == expected_query - return query_body - - output = await store.search("foo", k=1, custom_query=assert_query) - assert [doc["_source"]["text_field"] for doc in output] == ["foo"] - - @pytest.mark.asyncio - async def test_search_script_score_with_filter( - self, es_client: AsyncElasticsearch, 
index_name: str - ) -> None: - """Test end to end construction and search with metadata.""" - store = AsyncVectorStore( - user_agent="test", - index_name=index_name, - retrieval_strategy=AsyncDenseVectorScriptScore(), - embedding_service=AsyncFakeEmbeddings(), - es_client=es_client, - ) - - texts = ["foo", "bar", "baz"] - metadatas = [{"page": i} for i in range(len(texts))] - await store.add_texts(texts=texts, metadatas=metadatas) - - def assert_query(query_body: dict, query: Optional[str]) -> dict: - expected_query = { - "query": { - "script_score": { - "query": {"bool": {"filter": [{"term": {"metadata.page": 0}}]}}, - "script": { - "source": "cosineSimilarity(params.query_vector, 'vector_field') + 1.0", # noqa: E501 - "params": { - "query_vector": [ - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 0.0, - ] - }, - }, - } - } - } - assert query_body == expected_query - return query_body - - output = await store.search( - "foo", - k=1, - custom_query=assert_query, - filter=[{"term": {"metadata.page": 0}}], - ) - assert [doc["_source"]["text_field"] for doc in output] == ["foo"] - assert [doc["_source"]["metadata"]["page"] for doc in output] == [0] - - @pytest.mark.asyncio - async def test_search_script_score_distance_dot_product( - self, es_client: AsyncElasticsearch, index_name: str - ) -> None: - """Test end to end construction and search with metadata.""" - store = AsyncVectorStore( - user_agent="test", - index_name=index_name, - retrieval_strategy=AsyncDenseVectorScriptScore( - distance=DistanceMetric.DOT_PRODUCT, - ), - embedding_service=AsyncFakeEmbeddings(), - es_client=es_client, - ) - - texts = ["foo", "bar", "baz"] - await store.add_texts(texts) - - def assert_query(query_body: dict, query: Optional[str]) -> dict: - assert query_body == { - "query": { - "script_score": { - "query": {"match_all": {}}, - "script": { - "source": """ - double value = dotProduct(params.query_vector, 'vector_field'); - return sigmoid(1, Math.E, -value); - """, - 
"params": { - "query_vector": [ - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 0.0, - ] - }, - }, - } - } - } - return query_body - - output = await store.search("foo", k=1, custom_query=assert_query) - assert [doc["_source"]["text_field"] for doc in output] == ["foo"] - - @pytest.mark.asyncio - async def test_search_knn_with_hybrid_search( - self, es_client: AsyncElasticsearch, index_name: str - ) -> None: - """Test end to end construction and search with metadata.""" - store = AsyncVectorStore( - user_agent="test", - index_name=index_name, - retrieval_strategy=AsyncDenseVector(hybrid=True), - embedding_service=AsyncFakeEmbeddings(), - es_client=es_client, - ) - - texts = ["foo", "bar", "baz"] - await store.add_texts(texts) - - def assert_query(query_body: dict, query: Optional[str]) -> dict: - assert query_body == { - "knn": { - "field": "vector_field", - "filter": [], - "k": 1, - "num_candidates": 50, - "query_vector": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - }, - "query": { - "bool": { - "filter": [], - "must": [{"match": {"text_field": {"query": "foo"}}}], - } - }, - "rank": {"rrf": {}}, - } - return query_body - - output = await store.search("foo", k=1, custom_query=assert_query) - assert [doc["_source"]["text_field"] for doc in output] == ["foo"] - - @pytest.mark.asyncio - async def test_search_knn_with_hybrid_search_rrf( - self, es_client: AsyncElasticsearch, index_name: str - ) -> None: - """Test end to end construction and rrf hybrid search with metadata.""" - texts = ["foo", "bar", "baz"] - - def assert_query( - query_body: dict, - query: Optional[str], - expected_rrf: Union[dict, bool], - ) -> dict: - cmp_query_body = { - "knn": { - "field": "vector_field", - "filter": [], - "k": 3, - "num_candidates": 50, - "query_vector": [ - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 0.0, - ], - }, - "query": { - "bool": { - "filter": [], - "must": [{"match": {"text_field": {"query": "foo"}}}], - } - }, - } - - if 
isinstance(expected_rrf, dict): - cmp_query_body["rank"] = {"rrf": expected_rrf} - elif isinstance(expected_rrf, bool) and expected_rrf is True: - cmp_query_body["rank"] = {"rrf": {}} - - assert query_body == cmp_query_body - - return query_body - - # 1. check query_body is okay - rrf_test_cases: List[Union[dict, bool]] = [ - True, - False, - {"rank_constant": 1, "window_size": 5}, - ] - for rrf_test_case in rrf_test_cases: - store = AsyncVectorStore( - user_agent="test", - index_name=index_name, - retrieval_strategy=AsyncDenseVector(hybrid=True, rrf=rrf_test_case), - embedding_service=AsyncFakeEmbeddings(), - es_client=es_client, - ) - await store.add_texts(texts) - - ## without fetch_k parameter - output = await store.search( - "foo", - k=3, - custom_query=partial(assert_query, expected_rrf=rrf_test_case), - ) - - # 2. check query result is okay - es_output = await store.es_client.search( - index=index_name, - query={ - "bool": { - "filter": [], - "must": [{"match": {"text_field": {"query": "foo"}}}], - } - }, - knn={ - "field": "vector_field", - "filter": [], - "k": 3, - "num_candidates": 50, - "query_vector": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - }, - size=3, - rank={"rrf": {"rank_constant": 1, "window_size": 5}}, - ) - - assert [o["_source"]["text_field"] for o in output] == [ - e["_source"]["text_field"] for e in es_output["hits"]["hits"] - ] - - # 3. 
check rrf default option is okay - store = AsyncVectorStore( - user_agent="test", - index_name=f"{index_name}_default", - retrieval_strategy=AsyncDenseVector(hybrid=True), - embedding_service=AsyncFakeEmbeddings(), - es_client=es_client, - ) - await store.add_texts(texts) - - ## with fetch_k parameter - output = await store.search( - "foo", - k=3, - num_candidates=50, - custom_query=partial(assert_query, expected_rrf={}), - ) - - @pytest.mark.asyncio - async def test_search_knn_with_custom_query_fn( - self, es_client: AsyncElasticsearch, index_name: str - ) -> None: - """test that custom query function is called - with the query string and query body""" - store = AsyncVectorStore( - user_agent="test", - index_name=index_name, - retrieval_strategy=AsyncDenseVector(), - embedding_service=AsyncFakeEmbeddings(), - es_client=es_client, - ) - - def my_custom_query(query_body: dict, query: Optional[str]) -> dict: - assert query == "foo" - assert query_body == { - "knn": { - "field": "vector_field", - "filter": [], - "k": 1, - "num_candidates": 50, - "query_vector": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - } - } - return {"query": {"match": {"text_field": {"query": "bar"}}}} - - """Test end to end construction and search with metadata.""" - texts = ["foo", "bar", "baz"] - await store.add_texts(texts) - - output = await store.search("foo", k=1, custom_query=my_custom_query) - assert [doc["_source"]["text_field"] for doc in output] == ["bar"] - - @pytest.mark.asyncio - async def test_search_with_knn_infer_instack( - self, es_client: AsyncElasticsearch, index_name: str - ) -> None: - """test end to end with knn retrieval strategy and inference in-stack""" - - if not await model_is_deployed(es_client, TRANSFORMER_MODEL_ID): - pytest.skip( - f"{TRANSFORMER_MODEL_ID} model not deployed in ML Node skipping test" - ) - - text_field = "text_field" - - store = AsyncVectorStore( - user_agent="test", - index_name=index_name, - retrieval_strategy=AsyncDenseVector( - 
model_id="sentence-transformers__all-minilm-l6-v2" - ), - es_client=es_client, - ) - - # setting up the pipeline for inference - await store.es_client.ingest.put_pipeline( - id="test_pipeline", - processors=[ - { - "inference": { - "model_id": TRANSFORMER_MODEL_ID, - "field_map": {"query_field": text_field}, - "target_field": "vector_query_field", - } - } - ], - ) - - # creating a new index with the pipeline, - # not relying on langchain to create the index - await store.es_client.indices.create( - index=index_name, - mappings={ - "properties": { - text_field: {"type": "text_field"}, - "vector_query_field": { - "properties": { - "predicted_value": { - "type": "dense_vector", - "dims": 384, - "index": True, - "similarity": "l2_norm", - } - } - }, - } - }, - settings={"index": {"default_pipeline": "test_pipeline"}}, - ) - - # adding documents to the index - texts = ["foo", "bar", "baz"] - - for i, text in enumerate(texts): - await store.es_client.create( - index=index_name, - id=str(i), - document={text_field: text, "metadata": {}}, - ) - - await store.es_client.indices.refresh(index=index_name) - - def assert_query(query_body: dict, query: Optional[str]) -> dict: - assert query_body == { - "knn": { - "filter": [], - "field": "vector_query_field.predicted_value", - "k": 1, - "num_candidates": 50, - "query_vector_builder": { - "text_embedding": { - "model_id": TRANSFORMER_MODEL_ID, - "model_text": "foo", - } - }, - } - } - return query_body - - output = await store.search("foo", k=1, custom_query=assert_query) - assert [doc["_source"]["text_field"] for doc in output] == ["foo"] - - output = await store.search("bar", k=1) - assert [doc["_source"]["text_field"] for doc in output] == ["bar"] - - @pytest.mark.asyncio - async def test_search_with_sparse_infer_instack( - self, es_client: AsyncElasticsearch, index_name: str - ) -> None: - """test end to end with sparse retrieval strategy and inference in-stack""" - - if not await model_is_deployed(es_client, ELSER_MODEL_ID): 
- reason = f"{ELSER_MODEL_ID} model not deployed in ML Node, skipping test" - - pytest.skip(reason) - - store = AsyncVectorStore( - user_agent="test", - index_name=index_name, - retrieval_strategy=AsyncSparseVector(model_id=ELSER_MODEL_ID), - es_client=es_client, - ) - - texts = ["foo", "bar", "baz"] - await store.add_texts(texts) - - output = await store.search("foo", k=1) - assert [doc["_source"]["text_field"] for doc in output] == ["foo"] - - @pytest.mark.asyncio - async def test_deployed_model_check_fails_semantic( - self, es_client: AsyncElasticsearch, index_name: str - ) -> None: - """test that exceptions are raised if a specified model is not deployed""" - with pytest.raises(NotFoundError): - store = AsyncVectorStore( - user_agent="test", - index_name=index_name, - retrieval_strategy=AsyncDenseVector(model_id="non-existing model ID"), - es_client=es_client, - ) - await store.add_texts(["foo", "bar", "baz"]) - - @pytest.mark.asyncio - async def test_search_bm25( - self, es_client: AsyncElasticsearch, index_name: str - ) -> None: - """Test end to end using the BM25 retrieval strategy.""" - store = AsyncVectorStore( - user_agent="test", - index_name=index_name, - retrieval_strategy=AsyncBM25(), - es_client=es_client, - ) - - texts = ["foo", "bar", "baz"] - await store.add_texts(texts) - - def assert_query(query_body: dict, query: Optional[str]) -> dict: - assert query_body == { - "query": { - "bool": { - "must": [{"match": {"text_field": {"query": "foo"}}}], - "filter": [], - } - } - } - return query_body - - output = await store.search("foo", k=1, custom_query=assert_query) - assert [doc["_source"]["text_field"] for doc in output] == ["foo"] - - @pytest.mark.asyncio - async def test_search_bm25_with_filter( - self, es_client: AsyncElasticsearch, index_name: str - ) -> None: - """Test end to using the BM25 retrieval strategy with metadata.""" - store = AsyncVectorStore( - user_agent="test", - index_name=index_name, - retrieval_strategy=AsyncBM25(), - 
es_client=es_client, - ) - - texts = ["foo", "foo", "foo"] - metadatas = [{"page": i} for i in range(len(texts))] - await store.add_texts(texts=texts, metadatas=metadatas) - - def assert_query(query_body: dict, query: Optional[str]) -> dict: - assert query_body == { - "query": { - "bool": { - "must": [{"match": {"text_field": {"query": "foo"}}}], - "filter": [{"term": {"metadata.page": 1}}], - } - } - } - return query_body - - output = await store.search( - "foo", - k=3, - custom_query=assert_query, - filter=[{"term": {"metadata.page": 1}}], - ) - assert [doc["_source"]["text_field"] for doc in output] == ["foo"] - assert [doc["_source"]["metadata"]["page"] for doc in output] == [1] - - @pytest.mark.asyncio - async def test_delete(self, es_client: AsyncElasticsearch, index_name: str) -> None: - """Test delete methods from vector store.""" - store = AsyncVectorStore( - user_agent="test", - index_name=index_name, - retrieval_strategy=AsyncDenseVector(), - embedding_service=AsyncFakeEmbeddings(), - es_client=es_client, - ) - - texts = ["foo", "bar", "baz", "gni"] - metadatas = [{"page": i} for i in range(len(texts))] - ids = await store.add_texts(texts=texts, metadatas=metadatas) - - output = await store.search("foo", k=10) - assert len(output) == 4 - - await store.delete(ids[1:3]) - output = await store.search("foo", k=10) - assert len(output) == 2 - - await store.delete(["not-existing"]) - output = await store.search("foo", k=10) - assert len(output) == 2 - - await store.delete([ids[0]]) - output = await store.search("foo", k=10) - assert len(output) == 1 - - await store.delete([ids[3]]) - output = await store.search("gni", k=10) - assert len(output) == 0 - - @pytest.mark.asyncio - async def test_indexing_exception_error( - self, - es_client: AsyncElasticsearch, - index_name: str, - caplog: pytest.LogCaptureFixture, - ) -> None: - """Test bulk exception logging is giving better hints.""" - store = AsyncVectorStore( - user_agent="test", - index_name=index_name, - 
retrieval_strategy=AsyncBM25(), - es_client=es_client, - ) - - await store.es_client.indices.create( - index=index_name, - mappings={"properties": {}}, - settings={"index": {"default_pipeline": "not-existing-pipeline"}}, - ) - - texts = ["foo"] - - with pytest.raises(BulkIndexError): - await store.add_texts(texts) - - error_reason = "pipeline with id [not-existing-pipeline] does not exist" - log_message = f"First error reason: {error_reason}" - - assert log_message in caplog.text - - @pytest.mark.asyncio - async def test_user_agent( - self, requests_saving_client: AsyncElasticsearch, index_name: str - ) -> None: - """Test to make sure the user-agent is set correctly.""" - user_agent = "this is THE user_agent!" - store = AsyncVectorStore( - user_agent=user_agent, - index_name=index_name, - retrieval_strategy=AsyncBM25(), - es_client=requests_saving_client, - ) - - assert store.es_client._headers["User-Agent"] == user_agent - - texts = ["foo", "bob", "baz"] - await store.add_texts(texts) - - transport = cast(AsyncRequestSavingTransport, store.es_client.transport) - - for request in transport.requests: - assert request["headers"]["User-Agent"] == user_agent - - @pytest.mark.asyncio - async def test_bulk_args( - self, requests_saving_client: Any, index_name: str - ) -> None: - """Test to make sure the bulk arguments work as expected.""" - store = AsyncVectorStore( - user_agent="test", - index_name=index_name, - retrieval_strategy=AsyncBM25(), - es_client=requests_saving_client, - ) - - texts = ["foo", "bob", "baz"] - await store.add_texts(texts, bulk_kwargs={"chunk_size": 1}) - - # 1 for index exist, 1 for index create, 3 to index docs - assert len(store.es_client.transport.requests) == 5 # type: ignore - - @pytest.mark.asyncio - async def test_max_marginal_relevance_search( - self, es_client: AsyncElasticsearch, index_name: str - ) -> None: - """Test max marginal relevance search.""" - texts = ["foo", "bar", "baz"] - vector_field = "vector_field" - text_field = 
"text_field" - embedding_service = AsyncConsistentFakeEmbeddings() - store = AsyncVectorStore( - user_agent="test", - index_name=index_name, - retrieval_strategy=AsyncDenseVectorScriptScore(), - embedding_service=embedding_service, - vector_field=vector_field, - text_field=text_field, - es_client=es_client, - ) - await store.add_texts(texts) - - mmr_output = await store.max_marginal_relevance_search( - embedding_service, - texts[0], - vector_field=vector_field, - k=3, - num_candidates=3, - ) - sim_output = await store.search(texts[0], k=3) - assert mmr_output == sim_output - - mmr_output = await store.max_marginal_relevance_search( - embedding_service, - texts[0], - vector_field=vector_field, - k=2, - num_candidates=3, - ) - assert len(mmr_output) == 2 - assert mmr_output[0]["_source"][text_field] == texts[0] - assert mmr_output[1]["_source"][text_field] == texts[1] - - mmr_output = await store.max_marginal_relevance_search( - embedding_service, - texts[0], - vector_field=vector_field, - k=2, - num_candidates=3, - lambda_mult=0.1, # more diversity - ) - assert len(mmr_output) == 2 - assert mmr_output[0]["_source"][text_field] == texts[0] - assert mmr_output[1]["_source"][text_field] == texts[2] - - # if fetch_k < k, then the output will be less than k - mmr_output = await store.max_marginal_relevance_search( - embedding_service, - texts[0], - vector_field=vector_field, - k=3, - num_candidates=2, - ) - assert len(mmr_output) == 2 - - @pytest.mark.asyncio - async def test_metadata_mapping( - self, es_client: AsyncElasticsearch, index_name: str - ) -> None: - """Test that the metadata mapping is applied.""" - test_mappings = { - "my_field": {"type": "keyword"}, - "another_field": {"type": "text"}, - } - store = AsyncVectorStore( - user_agent="test", - index_name=index_name, - retrieval_strategy=AsyncDenseVector(), - embedding_service=AsyncFakeEmbeddings(), - es_client=es_client, - metadata_mappings=test_mappings, - ) - - texts = ["foo", "foo", "foo"] - metadatas = 
[{"page": i} for i in range(len(texts))] - await store.add_texts(texts=texts, metadatas=metadatas) - - mapping_response = await es_client.indices.get_mapping(index=index_name) - mapping_properties = mapping_response[index_name]["mappings"]["properties"] - print(mapping_response) - assert "metadata" in mapping_properties - for key, val in test_mappings.items(): - assert mapping_properties["metadata"]["properties"][key] == val diff --git a/test_elasticsearch/test_server/test_vectorstore/_sync/__init__.py b/test_elasticsearch/test_server/test_vectorstore/_sync/__init__.py deleted file mode 100644 index 2a87d183f..000000000 --- a/test_elasticsearch/test_server/test_vectorstore/_sync/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -# Licensed to Elasticsearch B.V. under one or more contributor -# license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright -# ownership. Elasticsearch B.V. licenses this file to you under -# the Apache License, Version 2.0 (the "License"); you may -# not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
diff --git a/test_elasticsearch/test_server/test_vectorstore/_sync/_test_utils.py b/test_elasticsearch/test_server/test_vectorstore/_test_utils.py similarity index 100% rename from test_elasticsearch/test_server/test_vectorstore/_sync/_test_utils.py rename to test_elasticsearch/test_server/test_vectorstore/_test_utils.py diff --git a/test_elasticsearch/test_server/test_vectorstore/_sync/conftest.py b/test_elasticsearch/test_server/test_vectorstore/conftest.py similarity index 100% rename from test_elasticsearch/test_server/test_vectorstore/_sync/conftest.py rename to test_elasticsearch/test_server/test_vectorstore/conftest.py diff --git a/test_elasticsearch/test_server/test_vectorstore/_sync/test_embedding_service.py b/test_elasticsearch/test_server/test_vectorstore/test_embedding_service.py similarity index 100% rename from test_elasticsearch/test_server/test_vectorstore/_sync/test_embedding_service.py rename to test_elasticsearch/test_server/test_vectorstore/test_embedding_service.py diff --git a/test_elasticsearch/test_server/test_vectorstore/_sync/test_vectorstore.py b/test_elasticsearch/test_server/test_vectorstore/test_vectorstore.py similarity index 93% rename from test_elasticsearch/test_server/test_vectorstore/_sync/test_vectorstore.py rename to test_elasticsearch/test_server/test_vectorstore/test_vectorstore.py index a27afd2c4..f7eab24dc 100644 --- a/test_elasticsearch/test_server/test_vectorstore/_sync/test_vectorstore.py +++ b/test_elasticsearch/test_server/test_vectorstore/test_vectorstore.py @@ -64,7 +64,7 @@ class TestVectorStore: def test_search_without_metadata( - self, es_client: Elasticsearch, index_name: str + self, sync_client: Elasticsearch, index_name: str ) -> None: """Test end to end construction and search without metadata.""" @@ -85,7 +85,7 @@ def assert_query(query_body: dict, query: Optional[str]) -> dict: index_name=index_name, retrieval_strategy=DenseVector(), embedding_service=FakeEmbeddings(), - es_client=es_client, + 
es_client=sync_client, ) texts = ["foo", "bar", "baz"] @@ -95,7 +95,7 @@ def assert_query(query_body: dict, query: Optional[str]) -> dict: assert [doc["_source"]["text_field"] for doc in output] == ["foo"] def test_search_without_metadata_async( - self, es_client: Elasticsearch, index_name: str + self, sync_client: Elasticsearch, index_name: str ) -> None: """Test end to end construction and search without metadata.""" store = VectorStore( @@ -103,7 +103,7 @@ def test_search_without_metadata_async( index_name=index_name, retrieval_strategy=DenseVector(), embedding_service=FakeEmbeddings(), - es_client=es_client, + es_client=sync_client, ) texts = ["foo", "bar", "baz"] @@ -112,7 +112,7 @@ def test_search_without_metadata_async( output = store.search("foo", k=1) assert [doc["_source"]["text_field"] for doc in output] == ["foo"] - def test_add_vectors(self, es_client: Elasticsearch, index_name: str) -> None: + def test_add_vectors(self, sync_client: Elasticsearch, index_name: str) -> None: """ Test adding pre-built embeddings instead of using inference for the texts. 
This allows you to separate the embeddings text and the page_content @@ -132,7 +132,7 @@ def test_add_vectors(self, es_client: Elasticsearch, index_name: str) -> None: index_name=index_name, retrieval_strategy=DenseVector(), embedding_service=embeddings, - es_client=es_client, + es_client=sync_client, ) store.add_texts(texts=texts, vectors=embedding_vectors, metadatas=metadatas) @@ -141,7 +141,7 @@ def test_add_vectors(self, es_client: Elasticsearch, index_name: str) -> None: assert [doc["_source"]["metadata"]["page"] for doc in output] == [0] def test_search_with_metadata( - self, es_client: Elasticsearch, index_name: str + self, sync_client: Elasticsearch, index_name: str ) -> None: """Test end to end construction and search with metadata.""" store = VectorStore( @@ -149,7 +149,7 @@ def test_search_with_metadata( index_name=index_name, retrieval_strategy=DenseVector(), embedding_service=ConsistentFakeEmbeddings(), - es_client=es_client, + es_client=sync_client, ) texts = ["foo", "bar", "baz"] @@ -165,7 +165,7 @@ def test_search_with_metadata( assert [doc["_source"]["metadata"]["page"] for doc in output] == [1] def test_search_with_filter( - self, es_client: Elasticsearch, index_name: str + self, sync_client: Elasticsearch, index_name: str ) -> None: """Test end to end construction and search with metadata.""" store = VectorStore( @@ -173,7 +173,7 @@ def test_search_with_filter( index_name=index_name, retrieval_strategy=DenseVector(), embedding_service=FakeEmbeddings(), - es_client=es_client, + es_client=sync_client, ) texts = ["foo", "foo", "foo"] @@ -202,7 +202,7 @@ def assert_query(query_body: dict, query: Optional[str]) -> dict: assert [doc["_source"]["metadata"]["page"] for doc in output] == [1] def test_search_script_score( - self, es_client: Elasticsearch, index_name: str + self, sync_client: Elasticsearch, index_name: str ) -> None: """Test end to end construction and search with metadata.""" store = VectorStore( @@ -210,7 +210,7 @@ def 
test_search_script_score( index_name=index_name, retrieval_strategy=DenseVectorScriptScore(), embedding_service=FakeEmbeddings(), - es_client=es_client, + es_client=sync_client, ) texts = ["foo", "bar", "baz"] @@ -249,7 +249,7 @@ def assert_query(query_body: dict, query: Optional[str]) -> dict: assert [doc["_source"]["text_field"] for doc in output] == ["foo"] def test_search_script_score_with_filter( - self, es_client: Elasticsearch, index_name: str + self, sync_client: Elasticsearch, index_name: str ) -> None: """Test end to end construction and search with metadata.""" store = VectorStore( @@ -257,7 +257,7 @@ def test_search_script_score_with_filter( index_name=index_name, retrieval_strategy=DenseVectorScriptScore(), embedding_service=FakeEmbeddings(), - es_client=es_client, + es_client=sync_client, ) texts = ["foo", "bar", "baz"] @@ -302,7 +302,7 @@ def assert_query(query_body: dict, query: Optional[str]) -> dict: assert [doc["_source"]["metadata"]["page"] for doc in output] == [0] def test_search_script_score_distance_dot_product( - self, es_client: Elasticsearch, index_name: str + self, sync_client: Elasticsearch, index_name: str ) -> None: """Test end to end construction and search with metadata.""" store = VectorStore( @@ -312,7 +312,7 @@ def test_search_script_score_distance_dot_product( distance=DistanceMetric.DOT_PRODUCT, ), embedding_service=FakeEmbeddings(), - es_client=es_client, + es_client=sync_client, ) texts = ["foo", "bar", "baz"] @@ -352,7 +352,7 @@ def assert_query(query_body: dict, query: Optional[str]) -> dict: assert [doc["_source"]["text_field"] for doc in output] == ["foo"] def test_search_knn_with_hybrid_search( - self, es_client: Elasticsearch, index_name: str + self, sync_client: Elasticsearch, index_name: str ) -> None: """Test end to end construction and search with metadata.""" store = VectorStore( @@ -360,7 +360,7 @@ def test_search_knn_with_hybrid_search( index_name=index_name, retrieval_strategy=DenseVector(hybrid=True), 
embedding_service=FakeEmbeddings(), - es_client=es_client, + es_client=sync_client, ) texts = ["foo", "bar", "baz"] @@ -389,7 +389,7 @@ def assert_query(query_body: dict, query: Optional[str]) -> dict: assert [doc["_source"]["text_field"] for doc in output] == ["foo"] def test_search_knn_with_hybrid_search_rrf( - self, es_client: Elasticsearch, index_name: str + self, sync_client: Elasticsearch, index_name: str ) -> None: """Test end to end construction and rrf hybrid search with metadata.""" texts = ["foo", "bar", "baz"] @@ -447,7 +447,7 @@ def assert_query( index_name=index_name, retrieval_strategy=DenseVector(hybrid=True, rrf=rrf_test_case), embedding_service=FakeEmbeddings(), - es_client=es_client, + es_client=sync_client, ) store.add_texts(texts) @@ -488,7 +488,7 @@ def assert_query( index_name=f"{index_name}_default", retrieval_strategy=DenseVector(hybrid=True), embedding_service=FakeEmbeddings(), - es_client=es_client, + es_client=sync_client, ) store.add_texts(texts) @@ -501,7 +501,7 @@ def assert_query( ) def test_search_knn_with_custom_query_fn( - self, es_client: Elasticsearch, index_name: str + self, sync_client: Elasticsearch, index_name: str ) -> None: """test that custom query function is called with the query string and query body""" @@ -510,7 +510,7 @@ def test_search_knn_with_custom_query_fn( index_name=index_name, retrieval_strategy=DenseVector(), embedding_service=FakeEmbeddings(), - es_client=es_client, + es_client=sync_client, ) def my_custom_query(query_body: dict, query: Optional[str]) -> dict: @@ -534,11 +534,11 @@ def my_custom_query(query_body: dict, query: Optional[str]) -> dict: assert [doc["_source"]["text_field"] for doc in output] == ["bar"] def test_search_with_knn_infer_instack( - self, es_client: Elasticsearch, index_name: str + self, sync_client: Elasticsearch, index_name: str ) -> None: """test end to end with knn retrieval strategy and inference in-stack""" - if not model_is_deployed(es_client, TRANSFORMER_MODEL_ID): + if not 
model_is_deployed(sync_client, TRANSFORMER_MODEL_ID): pytest.skip( f"{TRANSFORMER_MODEL_ID} model not deployed in ML Node skipping test" ) @@ -551,7 +551,7 @@ def test_search_with_knn_infer_instack( retrieval_strategy=DenseVector( model_id="sentence-transformers__all-minilm-l6-v2" ), - es_client=es_client, + es_client=sync_client, ) # setting up the pipeline for inference @@ -626,11 +626,11 @@ def assert_query(query_body: dict, query: Optional[str]) -> dict: assert [doc["_source"]["text_field"] for doc in output] == ["bar"] def test_search_with_sparse_infer_instack( - self, es_client: Elasticsearch, index_name: str + self, sync_client: Elasticsearch, index_name: str ) -> None: """test end to end with sparse retrieval strategy and inference in-stack""" - if not model_is_deployed(es_client, ELSER_MODEL_ID): + if not model_is_deployed(sync_client, ELSER_MODEL_ID): reason = f"{ELSER_MODEL_ID} model not deployed in ML Node, skipping test" pytest.skip(reason) @@ -639,7 +639,7 @@ def test_search_with_sparse_infer_instack( user_agent="test", index_name=index_name, retrieval_strategy=SparseVector(model_id=ELSER_MODEL_ID), - es_client=es_client, + es_client=sync_client, ) texts = ["foo", "bar", "baz"] @@ -649,7 +649,7 @@ def test_search_with_sparse_infer_instack( assert [doc["_source"]["text_field"] for doc in output] == ["foo"] def test_deployed_model_check_fails_semantic( - self, es_client: Elasticsearch, index_name: str + self, sync_client: Elasticsearch, index_name: str ) -> None: """test that exceptions are raised if a specified model is not deployed""" with pytest.raises(NotFoundError): @@ -657,17 +657,17 @@ def test_deployed_model_check_fails_semantic( user_agent="test", index_name=index_name, retrieval_strategy=DenseVector(model_id="non-existing model ID"), - es_client=es_client, + es_client=sync_client, ) store.add_texts(["foo", "bar", "baz"]) - def test_search_bm25(self, es_client: Elasticsearch, index_name: str) -> None: + def test_search_bm25(self, sync_client: 
Elasticsearch, index_name: str) -> None: """Test end to end using the BM25 retrieval strategy.""" store = VectorStore( user_agent="test", index_name=index_name, retrieval_strategy=BM25(), - es_client=es_client, + es_client=sync_client, ) texts = ["foo", "bar", "baz"] @@ -688,14 +688,14 @@ def assert_query(query_body: dict, query: Optional[str]) -> dict: assert [doc["_source"]["text_field"] for doc in output] == ["foo"] def test_search_bm25_with_filter( - self, es_client: Elasticsearch, index_name: str + self, sync_client: Elasticsearch, index_name: str ) -> None: """Test end to using the BM25 retrieval strategy with metadata.""" store = VectorStore( user_agent="test", index_name=index_name, retrieval_strategy=BM25(), - es_client=es_client, + es_client=sync_client, ) texts = ["foo", "foo", "foo"] @@ -722,14 +722,14 @@ def assert_query(query_body: dict, query: Optional[str]) -> dict: assert [doc["_source"]["text_field"] for doc in output] == ["foo"] assert [doc["_source"]["metadata"]["page"] for doc in output] == [1] - def test_delete(self, es_client: Elasticsearch, index_name: str) -> None: + def test_delete(self, sync_client: Elasticsearch, index_name: str) -> None: """Test delete methods from vector store.""" store = VectorStore( user_agent="test", index_name=index_name, retrieval_strategy=DenseVector(), embedding_service=FakeEmbeddings(), - es_client=es_client, + es_client=sync_client, ) texts = ["foo", "bar", "baz", "gni"] @@ -757,7 +757,7 @@ def test_delete(self, es_client: Elasticsearch, index_name: str) -> None: def test_indexing_exception_error( self, - es_client: Elasticsearch, + sync_client: Elasticsearch, index_name: str, caplog: pytest.LogCaptureFixture, ) -> None: @@ -766,7 +766,7 @@ def test_indexing_exception_error( user_agent="test", index_name=index_name, retrieval_strategy=BM25(), - es_client=es_client, + es_client=sync_client, ) store.es_client.indices.create( @@ -823,7 +823,7 @@ def test_bulk_args(self, requests_saving_client: Any, index_name: 
str) -> None: assert len(store.es_client.transport.requests) == 5 # type: ignore def test_max_marginal_relevance_search( - self, es_client: Elasticsearch, index_name: str + self, sync_client: Elasticsearch, index_name: str ) -> None: """Test max marginal relevance search.""" texts = ["foo", "bar", "baz"] @@ -837,7 +837,7 @@ def test_max_marginal_relevance_search( embedding_service=embedding_service, vector_field=vector_field, text_field=text_field, - es_client=es_client, + es_client=sync_client, ) store.add_texts(texts) @@ -884,7 +884,9 @@ def test_max_marginal_relevance_search( ) assert len(mmr_output) == 2 - def test_metadata_mapping(self, es_client: Elasticsearch, index_name: str) -> None: + def test_metadata_mapping( + self, sync_client: Elasticsearch, index_name: str + ) -> None: """Test that the metadata mapping is applied.""" test_mappings = { "my_field": {"type": "keyword"}, @@ -895,7 +897,7 @@ def test_metadata_mapping(self, es_client: Elasticsearch, index_name: str) -> No index_name=index_name, retrieval_strategy=DenseVector(), embedding_service=FakeEmbeddings(), - es_client=es_client, + es_client=sync_client, metadata_mappings=test_mappings, ) @@ -903,9 +905,8 @@ def test_metadata_mapping(self, es_client: Elasticsearch, index_name: str) -> No metadatas = [{"page": i} for i in range(len(texts))] store.add_texts(texts=texts, metadatas=metadatas) - mapping_response = es_client.indices.get_mapping(index=index_name) + mapping_response = sync_client.indices.get_mapping(index=index_name) mapping_properties = mapping_response[index_name]["mappings"]["properties"] - print(mapping_response) assert "metadata" in mapping_properties for key, val in test_mappings.items(): assert mapping_properties["metadata"]["properties"][key] == val From 994b412b5739eb45a2a3f2952362f93aef8fc4cc Mon Sep 17 00:00:00 2001 From: Max Jakob Date: Wed, 24 Apr 2024 12:08:36 +0200 Subject: [PATCH 21/36] cleanup unasync script --- utils/run-unasync.py | 33 --------------------------------- 1 
file changed, 33 deletions(-) diff --git a/utils/run-unasync.py b/utils/run-unasync.py index 990a78517..80d12b427 100644 --- a/utils/run-unasync.py +++ b/utils/run-unasync.py @@ -108,39 +108,6 @@ def main(): format=True, ) - run( - rule=unasync.Rule( - fromdir="test_elasticsearch/test_server/test_vectorstore/_async/", - todir="test_elasticsearch/test_server/test_vectorstore/_sync/", - additional_replacements={ - "AsyncBM25": "BM25", - "AsyncDenseVector": "DenseVector", - "AsyncDenseVectorScriptScore": "DenseVectorScriptScore", - "AsyncElasticsearch": "Elasticsearch", - "AsyncElasticsearchEmbeddings": "ElasticsearchEmbeddings", - "AsyncEmbeddingService": "EmbeddingService", - "AsyncRetrievalStrategy": "RetrievalStrategy", - "AsyncSparseVector": "SparseVector", - "AsyncTransport": "Transport", - "AsyncVectorStore": "VectorStore", - "async_bulk": "bulk", - "_async": "_sync", - # Tests-specific - "AsyncConsistentFakeEmbeddings": "ConsistentFakeEmbeddings", - "AsyncFakeEmbeddings": "FakeEmbeddings", - "AsyncGenerator": "Generator", - "AsyncRequestSavingTransport": "RequestSavingTransport", - "pytest_asyncio": "pytest", - }, - ), - cleanup_patterns=[ - "/^import asyncio$/d", - "/^import pytest_asyncio*/d", - "/ *@pytest.mark.asyncio$/d", - ], - format=True, - ) - if __name__ == "__main__": main() From 5073af1aad858579f781018d50a482030fc7bbc2 Mon Sep 17 00:00:00 2001 From: Max Jakob Date: Wed, 24 Apr 2024 12:49:03 +0200 Subject: [PATCH 22/36] nox: install optional deps --- noxfile.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/noxfile.py b/noxfile.py index c303fe26c..275deceed 100644 --- a/noxfile.py +++ b/noxfile.py @@ -48,7 +48,7 @@ def pytest_argv(): @nox.session(python=["3.7", "3.8", "3.9", "3.10", "3.11", "3.12"]) def test(session): - session.install(".[async,requests,orjson]", env=INSTALL_ENV, silent=False) + session.install(".[async,requests,orjson,mmr]", env=INSTALL_ENV, silent=False) session.install("-r", "dev-requirements.txt", 
silent=False) session.run(*pytest_argv()) @@ -95,7 +95,7 @@ def lint(session): session.run("flake8", *SOURCE_FILES) session.run("python", "utils/license-headers.py", "check", *SOURCE_FILES) - session.install(".[async,requests,orjson]", env=INSTALL_ENV) + session.install(".[async,requests,orjson,mmr]", env=INSTALL_ENV) # Run mypy on the package and then the type examples separately for # the two different mypy use-cases, ourselves and our users. From 9c50c6d71764e833c4206572a149ee94a9df8155 Mon Sep 17 00:00:00 2001 From: Max Jakob Date: Wed, 24 Apr 2024 12:49:41 +0200 Subject: [PATCH 23/36] fix tests with requests remembering Transport --- test_elasticsearch/test_server/conftest.py | 38 ++++---- .../test_vectorstore/_test_utils.py | 19 +--- .../test_server/test_vectorstore/conftest.py | 91 +++++-------------- .../test_embedding_service.py | 12 +-- .../test_vectorstore/test_vectorstore.py | 16 ++-- 5 files changed, 59 insertions(+), 117 deletions(-) diff --git a/test_elasticsearch/test_server/conftest.py b/test_elasticsearch/test_server/conftest.py index 474440afa..1fa1f2c74 100644 --- a/test_elasticsearch/test_server/conftest.py +++ b/test_elasticsearch/test_server/conftest.py @@ -30,28 +30,30 @@ ELASTICSEARCH_REST_API_TESTS = [] -@pytest.fixture(scope="session") -def sync_client_factory(elasticsearch_url): - client = None - try: - # Configure the client with certificates and optionally - # an HTTP conn class depending on 'PYTHON_CONNECTION_CLASS' envvar - kw = {} - if "PYTHON_CONNECTION_CLASS" in os.environ: - kw["node_class"] = os.environ["PYTHON_CONNECTION_CLASS"] +def _create(elasticsearch_url, transport=None): + # Configure the client with certificates + kw = {} + if elasticsearch_url.startswith("https://"): + kw["ca_certs"] = CA_CERTS + + # Optionally configure an HTTP conn class depending on + # 'PYTHON_CONNECTION_CLASS' env var + if "PYTHON_CONNECTION_CLASS" in os.environ: + kw["node_class"] = os.environ["PYTHON_CONNECTION_CLASS"] - # Add certificate 
verification if we're using HTTPS. - if elasticsearch_url.startswith("https"): - kw["ca_certs"] = CA_CERTS + if transport: + kw["transport_class"] = transport - # We do this little dance with the URL to force - # Requests to respect 'headers: None' within rest API spec tests. - client = elasticsearch.Elasticsearch(elasticsearch_url, **kw) + # We do this little dance with the URL to force + # Requests to respect 'headers: None' within rest API spec tests. + return elasticsearch.Elasticsearch(elasticsearch_url, **kw) - # Wipe the cluster before we start testing just in case it wasn't wiped - # cleanly from the previous run of pytest? - wipe_cluster(client) +@pytest.fixture(scope="session") +def sync_client_factory(elasticsearch_url): + client = None + try: + client = _create(elasticsearch_url) yield client finally: if client: diff --git a/test_elasticsearch/test_server/test_vectorstore/_test_utils.py b/test_elasticsearch/test_server/test_vectorstore/_test_utils.py index 68fa9dc91..6b38ed901 100644 --- a/test_elasticsearch/test_server/test_vectorstore/_test_utils.py +++ b/test_elasticsearch/test_server/test_vectorstore/_test_utils.py @@ -15,9 +15,8 @@ # specific language governing permissions and limitations # under the License. -from typing import Any, Dict, List +from typing import List -from elastic_transport import Transport from elasticsearch.vectorstore._sync.embedding_service import EmbeddingService @@ -28,9 +27,6 @@ class FakeEmbeddings(EmbeddingService): def __init__(self, dimensionality: int = 10) -> None: self.dimensionality = dimensionality - def num_dimensions(self) -> int: - return self.dimensionality - def embed_documents(self, texts: List[str]) -> List[List[float]]: """Return simple embeddings. 
Embeddings encode each text as its index.""" return [ @@ -55,9 +51,6 @@ def __init__(self, dimensionality: int = 10) -> None: self.known_texts: List[str] = [] self.dimensionality = dimensionality - def num_dimensions(self) -> int: - return self.dimensionality - def embed_documents(self, texts: List[str]) -> List[List[float]]: """Return consistent embeddings for each text seen so far.""" out_vectors = [] @@ -75,13 +68,3 @@ def embed_query(self, text: str) -> List[float]: one if the text is unknown.""" result = self.embed_documents([text]) return result[0] - - -class RequestSavingTransport(Transport): - def __init__(self, *args: Any, **kwargs: Any) -> None: - super().__init__(*args, **kwargs) - self.requests: List[Dict] = [] - - def perform_request(self, *args, **kwargs): # type: ignore - self.requests.append(kwargs) - return super().perform_request(*args, **kwargs) diff --git a/test_elasticsearch/test_server/test_vectorstore/conftest.py b/test_elasticsearch/test_server/test_vectorstore/conftest.py index be11547f5..246c728b6 100644 --- a/test_elasticsearch/test_server/test_vectorstore/conftest.py +++ b/test_elasticsearch/test_server/test_vectorstore/conftest.py @@ -15,86 +15,45 @@ # specific language governing permissions and limitations # under the License. 
-import os import uuid -from typing import Dict, Iterator import pytest +from elastic_transport import Transport -from elasticsearch import Elasticsearch +from ...utils import wipe_cluster +from ..conftest import _create -from ._test_utils import RequestSavingTransport +class RequestSavingTransport(Transport): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.requests: list = [] -@pytest.fixture -def es_client(elasticsearch_url: str) -> Iterator[Elasticsearch]: - client = _create_es_client(elasticsearch_url) + def perform_request(self, *args, **kwargs): + self.requests.append(kwargs) + return super().perform_request(*args, **kwargs) - yield client - - # clear indices - _clear_test_indices(client) - - # clear all test pipelines - try: - response = client.ingest.get_pipeline(id="test_*,*_sparse_embedding") - - for pipeline_id, _ in response.items(): - try: - client.ingest.delete_pipeline(id=pipeline_id) - print(f"Deleted pipeline: {pipeline_id}") # noqa: T201 - except Exception as e: - print(f"Pipeline error: {e}") # noqa: T201 - - except Exception: - pass - finally: - client.close() +@pytest.fixture(scope="function") +def index_name() -> str: + return f"test_{uuid.uuid4().hex}" -@pytest.fixture -def requests_saving_client( - elasticsearch_url: str, -) -> Iterator[Elasticsearch]: - client = _create_es_client( - elasticsearch_url, es_kwargs={"transport_class": RequestSavingTransport} - ) +@pytest.fixture(scope="function") +def sync_client_request_saving_factory(elasticsearch_url): + client = None try: + client = _create(elasticsearch_url, RequestSavingTransport) yield client finally: - client.close() + if client: + client.close() @pytest.fixture(scope="function") -def index_name() -> str: - return f"test_{uuid.uuid4().hex}" - - -def _clear_test_indices(client: Elasticsearch) -> None: - response = client.indices.get(index="_all") - index_names = response.keys() - for index_name in index_names: - if index_name.startswith("test_"): 
- client.indices.delete(index=index_name) - client.indices.refresh(index="_all") - - -def _create_es_client(elasticsearch_url: str, es_kwargs: Dict = {}) -> Elasticsearch: - if not elasticsearch_url: - elasticsearch_url = os.environ.get("ES_URL", "http://localhost:9200") - cloud_id = os.environ.get("ES_CLOUD_ID") - api_key = os.environ.get("ES_API_KEY") - - if cloud_id: - es_params = {"es_cloud_id": cloud_id, "es_api_key": api_key} - else: - es_params = {"es_url": elasticsearch_url} - - if "es_cloud_id" in es_params: - return Elasticsearch( - cloud_id=es_params["es_cloud_id"], - api_key=es_params["es_api_key"], - **es_kwargs, - ) - return Elasticsearch(hosts=[es_params["es_url"]], **es_kwargs) +def sync_client_request_saving(sync_client_request_saving_factory): + try: + yield sync_client_request_saving_factory + finally: + # Wipe the cluster clean after every test execution. + wipe_cluster(sync_client_request_saving_factory) diff --git a/test_elasticsearch/test_server/test_vectorstore/test_embedding_service.py b/test_elasticsearch/test_server/test_vectorstore/test_embedding_service.py index f9677fd04..85595684c 100644 --- a/test_elasticsearch/test_server/test_vectorstore/test_embedding_service.py +++ b/test_elasticsearch/test_server/test_vectorstore/test_embedding_service.py @@ -29,15 +29,15 @@ NUM_DIMENSIONS = int(os.getenv("NUM_DIMENTIONS", "384")) -def test_elasticsearch_embedding_documents(es_client: Elasticsearch) -> None: +def test_elasticsearch_embedding_documents(sync_client: Elasticsearch) -> None: """Test Elasticsearch embedding documents.""" - if not model_is_deployed(es_client, MODEL_ID): + if not model_is_deployed(sync_client, MODEL_ID): pytest.skip(f"{MODEL_ID} model is not deployed in ML Node, skipping test") documents = ["foo bar", "bar foo", "foo"] embedding = ElasticsearchEmbeddings( - es_client=es_client, user_agent="test", model_id=MODEL_ID + es_client=sync_client, user_agent="test", model_id=MODEL_ID ) output = 
embedding.embed_documents(documents) assert len(output) == 3 @@ -46,15 +46,15 @@ def test_elasticsearch_embedding_documents(es_client: Elasticsearch) -> None: assert len(output[2]) == NUM_DIMENSIONS -def test_elasticsearch_embedding_query(es_client: Elasticsearch) -> None: +def test_elasticsearch_embedding_query(sync_client: Elasticsearch) -> None: """Test Elasticsearch embedding query.""" - if not model_is_deployed(es_client, MODEL_ID): + if not model_is_deployed(sync_client, MODEL_ID): pytest.skip(f"{MODEL_ID} model is not deployed in ML Node, skipping test") document = "foo bar" embedding = ElasticsearchEmbeddings( - es_client=es_client, user_agent="test", model_id=MODEL_ID + es_client=sync_client, user_agent="test", model_id=MODEL_ID ) output = embedding.embed_query(document) assert len(output) == NUM_DIMENSIONS diff --git a/test_elasticsearch/test_server/test_vectorstore/test_vectorstore.py b/test_elasticsearch/test_server/test_vectorstore/test_vectorstore.py index f7eab24dc..b4799b01c 100644 --- a/test_elasticsearch/test_server/test_vectorstore/test_vectorstore.py +++ b/test_elasticsearch/test_server/test_vectorstore/test_vectorstore.py @@ -17,7 +17,7 @@ import logging from functools import partial -from typing import Any, List, Optional, Union, cast +from typing import Any, List, Optional, Union import pytest @@ -36,7 +36,6 @@ from ._test_utils import ( ConsistentFakeEmbeddings, FakeEmbeddings, - RequestSavingTransport, ) logging.basicConfig(level=logging.DEBUG) @@ -786,15 +785,16 @@ def test_indexing_exception_error( assert log_message in caplog.text def test_user_agent( - self, requests_saving_client: Elasticsearch, index_name: str + self, sync_client_request_saving: Elasticsearch, index_name: str ) -> None: """Test to make sure the user-agent is set correctly.""" user_agent = "this is THE user_agent!" 
+ store = VectorStore( user_agent=user_agent, index_name=index_name, retrieval_strategy=BM25(), - es_client=requests_saving_client, + es_client=sync_client_request_saving, ) assert store.es_client._headers["User-Agent"] == user_agent @@ -802,18 +802,16 @@ def test_user_agent( texts = ["foo", "bob", "baz"] store.add_texts(texts) - transport = cast(RequestSavingTransport, store.es_client.transport) - - for request in transport.requests: + for request in store.es_client.transport.requests: # type: ignore assert request["headers"]["User-Agent"] == user_agent - def test_bulk_args(self, requests_saving_client: Any, index_name: str) -> None: + def test_bulk_args(self, sync_client_request_saving: Any, index_name: str) -> None: """Test to make sure the bulk arguments work as expected.""" store = VectorStore( user_agent="test", index_name=index_name, retrieval_strategy=BM25(), - es_client=requests_saving_client, + es_client=sync_client_request_saving, ) texts = ["foo", "bob", "baz"] From a99a4f476dd041a215c0613c53de7bee584cca54 Mon Sep 17 00:00:00 2001 From: Max Jakob Date: Wed, 24 Apr 2024 13:58:36 +0200 Subject: [PATCH 24/36] fix numpy typing --- elasticsearch/vectorstore/_utils.py | 7 +++++-- .../test_server/test_vectorstore/_test_utils.py | 11 +++++++++++ .../test_server/test_vectorstore/conftest.py | 12 +----------- .../test_server/test_vectorstore/test_vectorstore.py | 5 +---- 4 files changed, 18 insertions(+), 17 deletions(-) diff --git a/elasticsearch/vectorstore/_utils.py b/elasticsearch/vectorstore/_utils.py index a9b29623a..02e133220 100644 --- a/elasticsearch/vectorstore/_utils.py +++ b/elasticsearch/vectorstore/_utils.py @@ -20,8 +20,11 @@ if TYPE_CHECKING: import numpy as np + import numpy.typing as npt -Matrix = Union[List[List[float]], List["np.ndarray"], "np.ndarray"] +Matrix = Union[ + List[List[float]], List["npt.NDArray[np.float64]"], "npt.NDArray[np.float64]" +] class DistanceMetric(str, Enum): @@ -75,7 +78,7 @@ def maximal_marginal_relevance( return 
idxs -def _cosine_similarity(X: Matrix, Y: Matrix) -> "np.ndarray": +def _cosine_similarity(X: Matrix, Y: Matrix) -> "npt.NDArray[np.float64]": """Row-wise cosine similarity between two equal-width matrices.""" try: diff --git a/test_elasticsearch/test_server/test_vectorstore/_test_utils.py b/test_elasticsearch/test_server/test_vectorstore/_test_utils.py index 6b38ed901..84937c782 100644 --- a/test_elasticsearch/test_server/test_vectorstore/_test_utils.py +++ b/test_elasticsearch/test_server/test_vectorstore/_test_utils.py @@ -17,10 +17,21 @@ from typing import List +from elastic_transport import Transport from elasticsearch.vectorstore._sync.embedding_service import EmbeddingService +class RequestSavingTransport(Transport): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.requests: list = [] + + def perform_request(self, *args, **kwargs): + self.requests.append(kwargs) + return super().perform_request(*args, **kwargs) + + class FakeEmbeddings(EmbeddingService): """Fake embeddings functionality for testing.""" diff --git a/test_elasticsearch/test_server/test_vectorstore/conftest.py b/test_elasticsearch/test_server/test_vectorstore/conftest.py index 246c728b6..31c8c4def 100644 --- a/test_elasticsearch/test_server/test_vectorstore/conftest.py +++ b/test_elasticsearch/test_server/test_vectorstore/conftest.py @@ -18,20 +18,10 @@ import uuid import pytest -from elastic_transport import Transport from ...utils import wipe_cluster from ..conftest import _create - - -class RequestSavingTransport(Transport): - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - self.requests: list = [] - - def perform_request(self, *args, **kwargs): - self.requests.append(kwargs) - return super().perform_request(*args, **kwargs) +from ._test_utils import RequestSavingTransport @pytest.fixture(scope="function") diff --git a/test_elasticsearch/test_server/test_vectorstore/test_vectorstore.py 
b/test_elasticsearch/test_server/test_vectorstore/test_vectorstore.py index b4799b01c..2ecb852fe 100644 --- a/test_elasticsearch/test_server/test_vectorstore/test_vectorstore.py +++ b/test_elasticsearch/test_server/test_vectorstore/test_vectorstore.py @@ -33,10 +33,7 @@ ) from elasticsearch.vectorstore._sync._utils import model_is_deployed -from ._test_utils import ( - ConsistentFakeEmbeddings, - FakeEmbeddings, -) +from ._test_utils import ConsistentFakeEmbeddings, FakeEmbeddings logging.basicConfig(level=logging.DEBUG) From d3c2e629ed7de836b0cb9e7be65613eb076ebd91 Mon Sep 17 00:00:00 2001 From: Max Jakob Date: Wed, 24 Apr 2024 14:02:17 +0200 Subject: [PATCH 25/36] add user agent default argument --- .../vectorstore/_async/vectorstore.py | 3 +- .../vectorstore/_sync/vectorstore.py | 3 +- .../test_vectorstore/test_vectorstore.py | 42 +++++++++---------- 3 files changed, 23 insertions(+), 25 deletions(-) diff --git a/elasticsearch/vectorstore/_async/vectorstore.py b/elasticsearch/vectorstore/_async/vectorstore.py index 87e05b10a..f7126b39b 100644 --- a/elasticsearch/vectorstore/_async/vectorstore.py +++ b/elasticsearch/vectorstore/_async/vectorstore.py @@ -20,6 +20,7 @@ from typing import Any, Callable, Dict, List, Optional from elasticsearch import AsyncElasticsearch +from elasticsearch._version import __versionstr__ as lib_version from elasticsearch.helpers import BulkIndexError, async_bulk from elasticsearch.vectorstore._async.embedding_service import AsyncEmbeddingService from elasticsearch.vectorstore._async.strategies import AsyncRetrievalStrategy @@ -41,7 +42,6 @@ class AsyncVectorStore: def __init__( self, es_client: AsyncElasticsearch, - user_agent: str, index_name: str, retrieval_strategy: AsyncRetrievalStrategy, embedding_service: Optional[AsyncEmbeddingService] = None, @@ -49,6 +49,7 @@ def __init__( text_field: str = "text_field", vector_field: str = "vector_field", metadata_mappings: Optional[Dict[str, Any]] = None, + user_agent: str = 
f"es-py-vs/{lib_version}", ) -> None: """ Args: diff --git a/elasticsearch/vectorstore/_sync/vectorstore.py b/elasticsearch/vectorstore/_sync/vectorstore.py index 862932eef..35c808ce1 100644 --- a/elasticsearch/vectorstore/_sync/vectorstore.py +++ b/elasticsearch/vectorstore/_sync/vectorstore.py @@ -20,6 +20,7 @@ from typing import Any, Callable, Dict, List, Optional from elasticsearch import Elasticsearch +from elasticsearch._version import __versionstr__ as lib_version from elasticsearch.helpers import BulkIndexError, bulk from elasticsearch.vectorstore._sync.embedding_service import EmbeddingService from elasticsearch.vectorstore._sync.strategies import RetrievalStrategy @@ -41,7 +42,6 @@ class VectorStore: def __init__( self, es_client: Elasticsearch, - user_agent: str, index_name: str, retrieval_strategy: RetrievalStrategy, embedding_service: Optional[EmbeddingService] = None, @@ -49,6 +49,7 @@ def __init__( text_field: str = "text_field", vector_field: str = "vector_field", metadata_mappings: Optional[Dict[str, Any]] = None, + user_agent: str = f"es-py-vs/{lib_version}", ) -> None: """ Args: diff --git a/test_elasticsearch/test_server/test_vectorstore/test_vectorstore.py b/test_elasticsearch/test_server/test_vectorstore/test_vectorstore.py index 2ecb852fe..b572f5d84 100644 --- a/test_elasticsearch/test_server/test_vectorstore/test_vectorstore.py +++ b/test_elasticsearch/test_server/test_vectorstore/test_vectorstore.py @@ -77,7 +77,6 @@ def assert_query(query_body: dict, query: Optional[str]) -> dict: return query_body store = VectorStore( - user_agent="test", index_name=index_name, retrieval_strategy=DenseVector(), embedding_service=FakeEmbeddings(), @@ -95,7 +94,6 @@ def test_search_without_metadata_async( ) -> None: """Test end to end construction and search without metadata.""" store = VectorStore( - user_agent="test", index_name=index_name, retrieval_strategy=DenseVector(), embedding_service=FakeEmbeddings(), @@ -124,7 +122,6 @@ def test_add_vectors(self, 
sync_client: Elasticsearch, index_name: str) -> None: embedding_vectors = embeddings.embed_documents(texts) store = VectorStore( - user_agent="test", index_name=index_name, retrieval_strategy=DenseVector(), embedding_service=embeddings, @@ -141,7 +138,6 @@ def test_search_with_metadata( ) -> None: """Test end to end construction and search with metadata.""" store = VectorStore( - user_agent="test", index_name=index_name, retrieval_strategy=DenseVector(), embedding_service=ConsistentFakeEmbeddings(), @@ -165,7 +161,6 @@ def test_search_with_filter( ) -> None: """Test end to end construction and search with metadata.""" store = VectorStore( - user_agent="test", index_name=index_name, retrieval_strategy=DenseVector(), embedding_service=FakeEmbeddings(), @@ -202,7 +197,6 @@ def test_search_script_score( ) -> None: """Test end to end construction and search with metadata.""" store = VectorStore( - user_agent="test", index_name=index_name, retrieval_strategy=DenseVectorScriptScore(), embedding_service=FakeEmbeddings(), @@ -249,7 +243,6 @@ def test_search_script_score_with_filter( ) -> None: """Test end to end construction and search with metadata.""" store = VectorStore( - user_agent="test", index_name=index_name, retrieval_strategy=DenseVectorScriptScore(), embedding_service=FakeEmbeddings(), @@ -302,7 +295,6 @@ def test_search_script_score_distance_dot_product( ) -> None: """Test end to end construction and search with metadata.""" store = VectorStore( - user_agent="test", index_name=index_name, retrieval_strategy=DenseVectorScriptScore( distance=DistanceMetric.DOT_PRODUCT, @@ -352,7 +344,6 @@ def test_search_knn_with_hybrid_search( ) -> None: """Test end to end construction and search with metadata.""" store = VectorStore( - user_agent="test", index_name=index_name, retrieval_strategy=DenseVector(hybrid=True), embedding_service=FakeEmbeddings(), @@ -439,7 +430,6 @@ def assert_query( ] for rrf_test_case in rrf_test_cases: store = VectorStore( - user_agent="test", 
index_name=index_name, retrieval_strategy=DenseVector(hybrid=True, rrf=rrf_test_case), embedding_service=FakeEmbeddings(), @@ -480,7 +470,6 @@ def assert_query( # 3. check rrf default option is okay store = VectorStore( - user_agent="test", index_name=f"{index_name}_default", retrieval_strategy=DenseVector(hybrid=True), embedding_service=FakeEmbeddings(), @@ -502,7 +491,6 @@ def test_search_knn_with_custom_query_fn( """test that custom query function is called with the query string and query body""" store = VectorStore( - user_agent="test", index_name=index_name, retrieval_strategy=DenseVector(), embedding_service=FakeEmbeddings(), @@ -542,7 +530,6 @@ def test_search_with_knn_infer_instack( text_field = "text_field" store = VectorStore( - user_agent="test", index_name=index_name, retrieval_strategy=DenseVector( model_id="sentence-transformers__all-minilm-l6-v2" @@ -632,7 +619,6 @@ def test_search_with_sparse_infer_instack( pytest.skip(reason) store = VectorStore( - user_agent="test", index_name=index_name, retrieval_strategy=SparseVector(model_id=ELSER_MODEL_ID), es_client=sync_client, @@ -650,7 +636,6 @@ def test_deployed_model_check_fails_semantic( """test that exceptions are raised if a specified model is not deployed""" with pytest.raises(NotFoundError): store = VectorStore( - user_agent="test", index_name=index_name, retrieval_strategy=DenseVector(model_id="non-existing model ID"), es_client=sync_client, @@ -660,7 +645,6 @@ def test_deployed_model_check_fails_semantic( def test_search_bm25(self, sync_client: Elasticsearch, index_name: str) -> None: """Test end to end using the BM25 retrieval strategy.""" store = VectorStore( - user_agent="test", index_name=index_name, retrieval_strategy=BM25(), es_client=sync_client, @@ -688,7 +672,6 @@ def test_search_bm25_with_filter( ) -> None: """Test end to using the BM25 retrieval strategy with metadata.""" store = VectorStore( - user_agent="test", index_name=index_name, retrieval_strategy=BM25(), es_client=sync_client, 
@@ -721,7 +704,6 @@ def assert_query(query_body: dict, query: Optional[str]) -> dict: def test_delete(self, sync_client: Elasticsearch, index_name: str) -> None: """Test delete methods from vector store.""" store = VectorStore( - user_agent="test", index_name=index_name, retrieval_strategy=DenseVector(), embedding_service=FakeEmbeddings(), @@ -759,7 +741,6 @@ def test_indexing_exception_error( ) -> None: """Test bulk exception logging is giving better hints.""" store = VectorStore( - user_agent="test", index_name=index_name, retrieval_strategy=BM25(), es_client=sync_client, @@ -781,7 +762,25 @@ def test_indexing_exception_error( assert log_message in caplog.text - def test_user_agent( + def test_user_agent_default( + self, sync_client_request_saving: Elasticsearch, index_name: str + ) -> None: + """Test to make sure the user-agent is set correctly.""" + store = VectorStore( + index_name=index_name, + retrieval_strategy=BM25(), + es_client=sync_client_request_saving, + ) + + assert store.es_client._headers["User-Agent"].startswith("es-py-vs/") + + texts = ["foo", "bob", "baz"] + store.add_texts(texts) + + for request in store.es_client.transport.requests: # type: ignore + assert request["headers"]["User-Agent"].startswith("es-py-vs/") + + def test_user_agent_custom( self, sync_client_request_saving: Elasticsearch, index_name: str ) -> None: """Test to make sure the user-agent is set correctly.""" @@ -805,7 +804,6 @@ def test_user_agent( def test_bulk_args(self, sync_client_request_saving: Any, index_name: str) -> None: """Test to make sure the bulk arguments work as expected.""" store = VectorStore( - user_agent="test", index_name=index_name, retrieval_strategy=BM25(), es_client=sync_client_request_saving, @@ -826,7 +824,6 @@ def test_max_marginal_relevance_search( text_field = "text_field" embedding_service = ConsistentFakeEmbeddings() store = VectorStore( - user_agent="test", index_name=index_name, retrieval_strategy=DenseVectorScriptScore(), 
embedding_service=embedding_service, @@ -888,7 +885,6 @@ def test_metadata_mapping( "another_field": {"type": "text"}, } store = VectorStore( - user_agent="test", index_name=index_name, retrieval_strategy=DenseVector(), embedding_service=FakeEmbeddings(), From 11c882515471fd0011c077139354afdb3e42590c Mon Sep 17 00:00:00 2001 From: Max Jakob Date: Wed, 24 Apr 2024 14:20:14 +0200 Subject: [PATCH 26/36] move to `elasticsearch.helpers.vectorstore` --- .../{ => helpers}/vectorstore/__init__.py | 14 +++++++------- .../{ => helpers}/vectorstore/_async/__init__.py | 0 .../{ => helpers}/vectorstore/_async/_utils.py | 0 .../vectorstore/_async/embedding_service.py | 0 .../{ => helpers}/vectorstore/_async/strategies.py | 4 ++-- .../vectorstore/_async/vectorstore.py | 8 +++++--- .../{ => helpers}/vectorstore/_sync/__init__.py | 0 .../{ => helpers}/vectorstore/_sync/_utils.py | 0 .../vectorstore/_sync/embedding_service.py | 0 .../{ => helpers}/vectorstore/_sync/strategies.py | 4 ++-- .../{ => helpers}/vectorstore/_sync/vectorstore.py | 6 +++--- elasticsearch/{ => helpers}/vectorstore/_utils.py | 0 .../__init__.py | 0 .../_test_utils.py | 2 +- .../conftest.py | 0 .../docker-compose.yml | 0 .../test_embedding_service.py | 6 ++++-- .../test_vectorstore.py | 4 ++-- test_elasticsearch/utils.py | 14 ++++++++++---- utils/run-unasync.py | 4 ++-- 20 files changed, 38 insertions(+), 28 deletions(-) rename elasticsearch/{ => helpers}/vectorstore/__init__.py (75%) rename elasticsearch/{ => helpers}/vectorstore/_async/__init__.py (100%) rename elasticsearch/{ => helpers}/vectorstore/_async/_utils.py (100%) rename elasticsearch/{ => helpers}/vectorstore/_async/embedding_service.py (100%) rename elasticsearch/{ => helpers}/vectorstore/_async/strategies.py (98%) rename elasticsearch/{ => helpers}/vectorstore/_async/vectorstore.py (98%) rename elasticsearch/{ => helpers}/vectorstore/_sync/__init__.py (100%) rename elasticsearch/{ => helpers}/vectorstore/_sync/_utils.py (100%) rename 
elasticsearch/{ => helpers}/vectorstore/_sync/embedding_service.py (100%) rename elasticsearch/{ => helpers}/vectorstore/_sync/strategies.py (98%) rename elasticsearch/{ => helpers}/vectorstore/_sync/vectorstore.py (98%) rename elasticsearch/{ => helpers}/vectorstore/_utils.py (100%) rename test_elasticsearch/test_server/{test_vectorstore => test_helpers_vectorstore}/__init__.py (100%) rename test_elasticsearch/test_server/{test_vectorstore => test_helpers_vectorstore}/_test_utils.py (97%) rename test_elasticsearch/test_server/{test_vectorstore => test_helpers_vectorstore}/conftest.py (100%) rename test_elasticsearch/test_server/{test_vectorstore => test_helpers_vectorstore}/docker-compose.yml (100%) rename test_elasticsearch/test_server/{test_vectorstore => test_helpers_vectorstore}/test_embedding_service.py (92%) rename test_elasticsearch/test_server/{test_vectorstore => test_helpers_vectorstore}/test_vectorstore.py (99%) diff --git a/elasticsearch/vectorstore/__init__.py b/elasticsearch/helpers/vectorstore/__init__.py similarity index 75% rename from elasticsearch/vectorstore/__init__.py rename to elasticsearch/helpers/vectorstore/__init__.py index a53bde47f..7312d5f4c 100644 --- a/elasticsearch/vectorstore/__init__.py +++ b/elasticsearch/helpers/vectorstore/__init__.py @@ -15,31 +15,31 @@ # specific language governing permissions and limitations # under the License. 
-from elasticsearch.vectorstore._async.embedding_service import ( +from elasticsearch.helpers.vectorstore._async.embedding_service import ( AsyncElasticsearchEmbeddings, AsyncEmbeddingService, ) -from elasticsearch.vectorstore._async.strategies import ( +from elasticsearch.helpers.vectorstore._async.strategies import ( AsyncBM25, AsyncDenseVector, AsyncDenseVectorScriptScore, AsyncRetrievalStrategy, AsyncSparseVector, ) -from elasticsearch.vectorstore._async.vectorstore import AsyncVectorStore -from elasticsearch.vectorstore._sync.embedding_service import ( +from elasticsearch.helpers.vectorstore._async.vectorstore import AsyncVectorStore +from elasticsearch.helpers.vectorstore._sync.embedding_service import ( ElasticsearchEmbeddings, EmbeddingService, ) -from elasticsearch.vectorstore._sync.strategies import ( +from elasticsearch.helpers.vectorstore._sync.strategies import ( BM25, DenseVector, DenseVectorScriptScore, RetrievalStrategy, SparseVector, ) -from elasticsearch.vectorstore._sync.vectorstore import VectorStore -from elasticsearch.vectorstore._utils import DistanceMetric +from elasticsearch.helpers.vectorstore._sync.vectorstore import VectorStore +from elasticsearch.helpers.vectorstore._utils import DistanceMetric __all__ = [ "BM25", diff --git a/elasticsearch/vectorstore/_async/__init__.py b/elasticsearch/helpers/vectorstore/_async/__init__.py similarity index 100% rename from elasticsearch/vectorstore/_async/__init__.py rename to elasticsearch/helpers/vectorstore/_async/__init__.py diff --git a/elasticsearch/vectorstore/_async/_utils.py b/elasticsearch/helpers/vectorstore/_async/_utils.py similarity index 100% rename from elasticsearch/vectorstore/_async/_utils.py rename to elasticsearch/helpers/vectorstore/_async/_utils.py diff --git a/elasticsearch/vectorstore/_async/embedding_service.py b/elasticsearch/helpers/vectorstore/_async/embedding_service.py similarity index 100% rename from elasticsearch/vectorstore/_async/embedding_service.py rename to 
elasticsearch/helpers/vectorstore/_async/embedding_service.py diff --git a/elasticsearch/vectorstore/_async/strategies.py b/elasticsearch/helpers/vectorstore/_async/strategies.py similarity index 98% rename from elasticsearch/vectorstore/_async/strategies.py rename to elasticsearch/helpers/vectorstore/_async/strategies.py index f10d664c0..79e1049fe 100644 --- a/elasticsearch/vectorstore/_async/strategies.py +++ b/elasticsearch/helpers/vectorstore/_async/strategies.py @@ -19,8 +19,8 @@ from typing import Any, Dict, List, Optional, Tuple, Union, cast from elasticsearch import AsyncElasticsearch -from elasticsearch.vectorstore._async._utils import model_must_be_deployed -from elasticsearch.vectorstore._utils import DistanceMetric +from elasticsearch.helpers.vectorstore._async._utils import model_must_be_deployed +from elasticsearch.helpers.vectorstore._utils import DistanceMetric class AsyncRetrievalStrategy(ABC): diff --git a/elasticsearch/vectorstore/_async/vectorstore.py b/elasticsearch/helpers/vectorstore/_async/vectorstore.py similarity index 98% rename from elasticsearch/vectorstore/_async/vectorstore.py rename to elasticsearch/helpers/vectorstore/_async/vectorstore.py index f7126b39b..521c05ca4 100644 --- a/elasticsearch/vectorstore/_async/vectorstore.py +++ b/elasticsearch/helpers/vectorstore/_async/vectorstore.py @@ -22,9 +22,11 @@ from elasticsearch import AsyncElasticsearch from elasticsearch._version import __versionstr__ as lib_version from elasticsearch.helpers import BulkIndexError, async_bulk -from elasticsearch.vectorstore._async.embedding_service import AsyncEmbeddingService -from elasticsearch.vectorstore._async.strategies import AsyncRetrievalStrategy -from elasticsearch.vectorstore._utils import maximal_marginal_relevance +from elasticsearch.helpers.vectorstore._async.embedding_service import ( + AsyncEmbeddingService, +) +from elasticsearch.helpers.vectorstore._async.strategies import AsyncRetrievalStrategy +from 
elasticsearch.helpers.vectorstore._utils import maximal_marginal_relevance logger = logging.getLogger(__name__) diff --git a/elasticsearch/vectorstore/_sync/__init__.py b/elasticsearch/helpers/vectorstore/_sync/__init__.py similarity index 100% rename from elasticsearch/vectorstore/_sync/__init__.py rename to elasticsearch/helpers/vectorstore/_sync/__init__.py diff --git a/elasticsearch/vectorstore/_sync/_utils.py b/elasticsearch/helpers/vectorstore/_sync/_utils.py similarity index 100% rename from elasticsearch/vectorstore/_sync/_utils.py rename to elasticsearch/helpers/vectorstore/_sync/_utils.py diff --git a/elasticsearch/vectorstore/_sync/embedding_service.py b/elasticsearch/helpers/vectorstore/_sync/embedding_service.py similarity index 100% rename from elasticsearch/vectorstore/_sync/embedding_service.py rename to elasticsearch/helpers/vectorstore/_sync/embedding_service.py diff --git a/elasticsearch/vectorstore/_sync/strategies.py b/elasticsearch/helpers/vectorstore/_sync/strategies.py similarity index 98% rename from elasticsearch/vectorstore/_sync/strategies.py rename to elasticsearch/helpers/vectorstore/_sync/strategies.py index a9b0df01e..603de51e4 100644 --- a/elasticsearch/vectorstore/_sync/strategies.py +++ b/elasticsearch/helpers/vectorstore/_sync/strategies.py @@ -19,8 +19,8 @@ from typing import Any, Dict, List, Optional, Tuple, Union, cast from elasticsearch import Elasticsearch -from elasticsearch.vectorstore._sync._utils import model_must_be_deployed -from elasticsearch.vectorstore._utils import DistanceMetric +from elasticsearch.helpers.vectorstore._sync._utils import model_must_be_deployed +from elasticsearch.helpers.vectorstore._utils import DistanceMetric class RetrievalStrategy(ABC): diff --git a/elasticsearch/vectorstore/_sync/vectorstore.py b/elasticsearch/helpers/vectorstore/_sync/vectorstore.py similarity index 98% rename from elasticsearch/vectorstore/_sync/vectorstore.py rename to elasticsearch/helpers/vectorstore/_sync/vectorstore.py 
index 35c808ce1..5944da41e 100644 --- a/elasticsearch/vectorstore/_sync/vectorstore.py +++ b/elasticsearch/helpers/vectorstore/_sync/vectorstore.py @@ -22,9 +22,9 @@ from elasticsearch import Elasticsearch from elasticsearch._version import __versionstr__ as lib_version from elasticsearch.helpers import BulkIndexError, bulk -from elasticsearch.vectorstore._sync.embedding_service import EmbeddingService -from elasticsearch.vectorstore._sync.strategies import RetrievalStrategy -from elasticsearch.vectorstore._utils import maximal_marginal_relevance +from elasticsearch.helpers.vectorstore._sync.embedding_service import EmbeddingService +from elasticsearch.helpers.vectorstore._sync.strategies import RetrievalStrategy +from elasticsearch.helpers.vectorstore._utils import maximal_marginal_relevance logger = logging.getLogger(__name__) diff --git a/elasticsearch/vectorstore/_utils.py b/elasticsearch/helpers/vectorstore/_utils.py similarity index 100% rename from elasticsearch/vectorstore/_utils.py rename to elasticsearch/helpers/vectorstore/_utils.py diff --git a/test_elasticsearch/test_server/test_vectorstore/__init__.py b/test_elasticsearch/test_server/test_helpers_vectorstore/__init__.py similarity index 100% rename from test_elasticsearch/test_server/test_vectorstore/__init__.py rename to test_elasticsearch/test_server/test_helpers_vectorstore/__init__.py diff --git a/test_elasticsearch/test_server/test_vectorstore/_test_utils.py b/test_elasticsearch/test_server/test_helpers_vectorstore/_test_utils.py similarity index 97% rename from test_elasticsearch/test_server/test_vectorstore/_test_utils.py rename to test_elasticsearch/test_server/test_helpers_vectorstore/_test_utils.py index 84937c782..f855c1d19 100644 --- a/test_elasticsearch/test_server/test_vectorstore/_test_utils.py +++ b/test_elasticsearch/test_server/test_helpers_vectorstore/_test_utils.py @@ -19,7 +19,7 @@ from elastic_transport import Transport -from elasticsearch.vectorstore._sync.embedding_service 
import EmbeddingService +from elasticsearch.helpers.vectorstore._sync.embedding_service import EmbeddingService class RequestSavingTransport(Transport): diff --git a/test_elasticsearch/test_server/test_vectorstore/conftest.py b/test_elasticsearch/test_server/test_helpers_vectorstore/conftest.py similarity index 100% rename from test_elasticsearch/test_server/test_vectorstore/conftest.py rename to test_elasticsearch/test_server/test_helpers_vectorstore/conftest.py diff --git a/test_elasticsearch/test_server/test_vectorstore/docker-compose.yml b/test_elasticsearch/test_server/test_helpers_vectorstore/docker-compose.yml similarity index 100% rename from test_elasticsearch/test_server/test_vectorstore/docker-compose.yml rename to test_elasticsearch/test_server/test_helpers_vectorstore/docker-compose.yml diff --git a/test_elasticsearch/test_server/test_vectorstore/test_embedding_service.py b/test_elasticsearch/test_server/test_helpers_vectorstore/test_embedding_service.py similarity index 92% rename from test_elasticsearch/test_server/test_vectorstore/test_embedding_service.py rename to test_elasticsearch/test_server/test_helpers_vectorstore/test_embedding_service.py index 85595684c..a2e6dfc20 100644 --- a/test_elasticsearch/test_server/test_vectorstore/test_embedding_service.py +++ b/test_elasticsearch/test_server/test_helpers_vectorstore/test_embedding_service.py @@ -20,8 +20,10 @@ import pytest from elasticsearch import Elasticsearch -from elasticsearch.vectorstore._sync._utils import model_is_deployed -from elasticsearch.vectorstore._sync.embedding_service import ElasticsearchEmbeddings +from elasticsearch.helpers.vectorstore._sync._utils import model_is_deployed +from elasticsearch.helpers.vectorstore._sync.embedding_service import ( + ElasticsearchEmbeddings, +) # deployed with # https://www.elastic.co/guide/en/machine-learning/current/ml-nlp-text-emb-vector-search-example.html diff --git a/test_elasticsearch/test_server/test_vectorstore/test_vectorstore.py 
b/test_elasticsearch/test_server/test_helpers_vectorstore/test_vectorstore.py similarity index 99% rename from test_elasticsearch/test_server/test_vectorstore/test_vectorstore.py rename to test_elasticsearch/test_server/test_helpers_vectorstore/test_vectorstore.py index b572f5d84..a8ccb3f94 100644 --- a/test_elasticsearch/test_server/test_vectorstore/test_vectorstore.py +++ b/test_elasticsearch/test_server/test_helpers_vectorstore/test_vectorstore.py @@ -23,7 +23,7 @@ from elasticsearch import Elasticsearch, NotFoundError from elasticsearch.helpers import BulkIndexError -from elasticsearch.vectorstore import ( +from elasticsearch.helpers.vectorstore import ( BM25, DenseVector, DenseVectorScriptScore, @@ -31,7 +31,7 @@ SparseVector, VectorStore, ) -from elasticsearch.vectorstore._sync._utils import model_is_deployed +from elasticsearch.helpers.vectorstore._sync._utils import model_is_deployed from ._test_utils import ConsistentFakeEmbeddings, FakeEmbeddings diff --git a/test_elasticsearch/utils.py b/test_elasticsearch/utils.py index f4f8cc885..0d27b2222 100644 --- a/test_elasticsearch/utils.py +++ b/test_elasticsearch/utils.py @@ -215,10 +215,16 @@ def wipe_data_streams(client): def wipe_indices(client): - client.options(ignore_status=404).indices.delete( - index="*,-.ds-ilm-history-*", - expand_wildcards="all", - ) + response = client.indices.get(index="_all") + index_names = response.keys() + for index_name in index_names: + if index_name.startswith("test_"): + client.indices.delete(index=index_name) + client.indices.refresh(index="_all") + # client.options(ignore_status=404).indices.delete( + # index="*,-.ds-ilm-history-*", + # expand_wildcards="all", + # ) def wipe_searchable_snapshot_indices(client): diff --git a/utils/run-unasync.py b/utils/run-unasync.py index 80d12b427..53aa39eed 100644 --- a/utils/run-unasync.py +++ b/utils/run-unasync.py @@ -83,8 +83,8 @@ def main(): run( rule=unasync.Rule( - fromdir="elasticsearch/vectorstore/_async/", - 
todir="elasticsearch/vectorstore/_sync/", + fromdir="elasticsearch/helpers/vectorstore/_async/", + todir="elasticsearch/helpers/vectorstore/_sync/", additional_replacements={ "AsyncBM25": "BM25", "AsyncDenseVector": "DenseVector", From 0d94881bd841616ed0dfd767ff5a04a17b177b5a Mon Sep 17 00:00:00 2001 From: Max Jakob Date: Wed, 24 Apr 2024 14:25:07 +0200 Subject: [PATCH 27/36] use Protocol over ABC --- .../vectorstore/_async/embedding_service.py | 7 ++----- .../helpers/vectorstore/_async/strategies.py | 7 ++----- .../helpers/vectorstore/_sync/embedding_service.py | 7 ++----- .../helpers/vectorstore/_sync/strategies.py | 7 ++----- test_elasticsearch/utils.py | 14 ++++---------- utils/run-unasync.py | 14 +++++++------- 6 files changed, 19 insertions(+), 37 deletions(-) diff --git a/elasticsearch/helpers/vectorstore/_async/embedding_service.py b/elasticsearch/helpers/vectorstore/_async/embedding_service.py index 1027d4f12..84d39d406 100644 --- a/elasticsearch/helpers/vectorstore/_async/embedding_service.py +++ b/elasticsearch/helpers/vectorstore/_async/embedding_service.py @@ -15,14 +15,12 @@ # specific language governing permissions and limitations # under the License. -from abc import ABC, abstractmethod -from typing import List, Optional +from typing import List, Optional, Protocol from elasticsearch import AsyncElasticsearch -class AsyncEmbeddingService(ABC): - @abstractmethod +class AsyncEmbeddingService(Protocol): async def embed_documents(self, texts: List[str]) -> List[List[float]]: """Generate embeddings for a list of documents. @@ -33,7 +31,6 @@ async def embed_documents(self, texts: List[str]) -> List[List[float]]: A list of embeddings, one for each document in the input. """ - @abstractmethod async def embed_query(self, query: str) -> List[float]: """Generate an embedding for a single query text. 
diff --git a/elasticsearch/helpers/vectorstore/_async/strategies.py b/elasticsearch/helpers/vectorstore/_async/strategies.py index 79e1049fe..2fd14adc8 100644 --- a/elasticsearch/helpers/vectorstore/_async/strategies.py +++ b/elasticsearch/helpers/vectorstore/_async/strategies.py @@ -15,16 +15,14 @@ # specific language governing permissions and limitations # under the License. -from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional, Tuple, Union, cast +from typing import Any, Dict, List, Optional, Protocol, Tuple, Union, cast from elasticsearch import AsyncElasticsearch from elasticsearch.helpers.vectorstore._async._utils import model_must_be_deployed from elasticsearch.helpers.vectorstore._utils import DistanceMetric -class AsyncRetrievalStrategy(ABC): - @abstractmethod +class AsyncRetrievalStrategy(Protocol): def es_query( self, query: Optional[str], @@ -50,7 +48,6 @@ def es_query( Dict: The Elasticsearch query body. """ - @abstractmethod def es_mappings_settings( self, text_field: str, diff --git a/elasticsearch/helpers/vectorstore/_sync/embedding_service.py b/elasticsearch/helpers/vectorstore/_sync/embedding_service.py index 272c34214..6373646db 100644 --- a/elasticsearch/helpers/vectorstore/_sync/embedding_service.py +++ b/elasticsearch/helpers/vectorstore/_sync/embedding_service.py @@ -15,14 +15,12 @@ # specific language governing permissions and limitations # under the License. -from abc import ABC, abstractmethod -from typing import List, Optional +from typing import List, Optional, Protocol from elasticsearch import Elasticsearch -class EmbeddingService(ABC): - @abstractmethod +class EmbeddingService(Protocol): def embed_documents(self, texts: List[str]) -> List[List[float]]: """Generate embeddings for a list of documents. @@ -33,7 +31,6 @@ def embed_documents(self, texts: List[str]) -> List[List[float]]: A list of embeddings, one for each document in the input. 
""" - @abstractmethod def embed_query(self, query: str) -> List[float]: """Generate an embedding for a single query text. diff --git a/elasticsearch/helpers/vectorstore/_sync/strategies.py b/elasticsearch/helpers/vectorstore/_sync/strategies.py index 603de51e4..0ed538894 100644 --- a/elasticsearch/helpers/vectorstore/_sync/strategies.py +++ b/elasticsearch/helpers/vectorstore/_sync/strategies.py @@ -15,16 +15,14 @@ # specific language governing permissions and limitations # under the License. -from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional, Tuple, Union, cast +from typing import Any, Dict, List, Optional, Protocol, Tuple, Union, cast from elasticsearch import Elasticsearch from elasticsearch.helpers.vectorstore._sync._utils import model_must_be_deployed from elasticsearch.helpers.vectorstore._utils import DistanceMetric -class RetrievalStrategy(ABC): - @abstractmethod +class RetrievalStrategy(Protocol): def es_query( self, query: Optional[str], @@ -50,7 +48,6 @@ def es_query( Dict: The Elasticsearch query body. 
""" - @abstractmethod def es_mappings_settings( self, text_field: str, diff --git a/test_elasticsearch/utils.py b/test_elasticsearch/utils.py index 0d27b2222..f4f8cc885 100644 --- a/test_elasticsearch/utils.py +++ b/test_elasticsearch/utils.py @@ -215,16 +215,10 @@ def wipe_data_streams(client): def wipe_indices(client): - response = client.indices.get(index="_all") - index_names = response.keys() - for index_name in index_names: - if index_name.startswith("test_"): - client.indices.delete(index=index_name) - client.indices.refresh(index="_all") - # client.options(ignore_status=404).indices.delete( - # index="*,-.ds-ilm-history-*", - # expand_wildcards="all", - # ) + client.options(ignore_status=404).indices.delete( + index="*,-.ds-ilm-history-*", + expand_wildcards="all", + ) def wipe_searchable_snapshot_indices(client): diff --git a/utils/run-unasync.py b/utils/run-unasync.py index 53aa39eed..51c041974 100644 --- a/utils/run-unasync.py +++ b/utils/run-unasync.py @@ -24,12 +24,11 @@ def cleanup(source_dir: Path, output_dir: Path, patterns: list[str]): - if patterns: - for file in glob("*.py", root_dir=source_dir): - path = Path(output_dir) / file - for pattern in patterns: - subprocess.check_call(["sed", "-i.bak", pattern, str(path)]) - subprocess.check_call(["rm", f"{path}.bak"]) + for file in glob("*.py", root_dir=source_dir): + path = Path(output_dir) / file + for pattern in patterns: + subprocess.check_call(["sed", "-i.bak", pattern, str(path)]) + subprocess.check_call(["rm", f"{path}.bak"]) def format_dir(dir: Path): @@ -57,7 +56,8 @@ def run( unasync.unasync_files(filepaths, [rule]) - cleanup(source_dir, output_dir, cleanup_patterns) + if cleanup_patterns: + cleanup(source_dir, output_dir, cleanup_patterns) if format: format_dir(source_dir) From 6aa6d736656b4feaf4971b652c281ba03cfeeb73 Mon Sep 17 00:00:00 2001 From: Max Jakob Date: Wed, 24 Apr 2024 14:41:57 +0200 Subject: [PATCH 28/36] revert Protocol change because Python 3.7 --- 
.../helpers/vectorstore/_async/embedding_service.py | 7 +++++-- elasticsearch/helpers/vectorstore/_async/strategies.py | 7 +++++-- .../helpers/vectorstore/_sync/embedding_service.py | 7 +++++-- elasticsearch/helpers/vectorstore/_sync/strategies.py | 7 +++++-- 4 files changed, 20 insertions(+), 8 deletions(-) diff --git a/elasticsearch/helpers/vectorstore/_async/embedding_service.py b/elasticsearch/helpers/vectorstore/_async/embedding_service.py index 84d39d406..1027d4f12 100644 --- a/elasticsearch/helpers/vectorstore/_async/embedding_service.py +++ b/elasticsearch/helpers/vectorstore/_async/embedding_service.py @@ -15,12 +15,14 @@ # specific language governing permissions and limitations # under the License. -from typing import List, Optional, Protocol +from abc import ABC, abstractmethod +from typing import List, Optional from elasticsearch import AsyncElasticsearch -class AsyncEmbeddingService(Protocol): +class AsyncEmbeddingService(ABC): + @abstractmethod async def embed_documents(self, texts: List[str]) -> List[List[float]]: """Generate embeddings for a list of documents. @@ -31,6 +33,7 @@ async def embed_documents(self, texts: List[str]) -> List[List[float]]: A list of embeddings, one for each document in the input. """ + @abstractmethod async def embed_query(self, query: str) -> List[float]: """Generate an embedding for a single query text. diff --git a/elasticsearch/helpers/vectorstore/_async/strategies.py b/elasticsearch/helpers/vectorstore/_async/strategies.py index 2fd14adc8..79e1049fe 100644 --- a/elasticsearch/helpers/vectorstore/_async/strategies.py +++ b/elasticsearch/helpers/vectorstore/_async/strategies.py @@ -15,14 +15,16 @@ # specific language governing permissions and limitations # under the License. 
-from typing import Any, Dict, List, Optional, Protocol, Tuple, Union, cast +from abc import ABC, abstractmethod +from typing import Any, Dict, List, Optional, Tuple, Union, cast from elasticsearch import AsyncElasticsearch from elasticsearch.helpers.vectorstore._async._utils import model_must_be_deployed from elasticsearch.helpers.vectorstore._utils import DistanceMetric -class AsyncRetrievalStrategy(Protocol): +class AsyncRetrievalStrategy(ABC): + @abstractmethod def es_query( self, query: Optional[str], @@ -48,6 +50,7 @@ def es_query( Dict: The Elasticsearch query body. """ + @abstractmethod def es_mappings_settings( self, text_field: str, diff --git a/elasticsearch/helpers/vectorstore/_sync/embedding_service.py b/elasticsearch/helpers/vectorstore/_sync/embedding_service.py index 6373646db..272c34214 100644 --- a/elasticsearch/helpers/vectorstore/_sync/embedding_service.py +++ b/elasticsearch/helpers/vectorstore/_sync/embedding_service.py @@ -15,12 +15,14 @@ # specific language governing permissions and limitations # under the License. -from typing import List, Optional, Protocol +from abc import ABC, abstractmethod +from typing import List, Optional from elasticsearch import Elasticsearch -class EmbeddingService(Protocol): +class EmbeddingService(ABC): + @abstractmethod def embed_documents(self, texts: List[str]) -> List[List[float]]: """Generate embeddings for a list of documents. @@ -31,6 +33,7 @@ def embed_documents(self, texts: List[str]) -> List[List[float]]: A list of embeddings, one for each document in the input. """ + @abstractmethod def embed_query(self, query: str) -> List[float]: """Generate an embedding for a single query text. 
diff --git a/elasticsearch/helpers/vectorstore/_sync/strategies.py b/elasticsearch/helpers/vectorstore/_sync/strategies.py index 0ed538894..603de51e4 100644 --- a/elasticsearch/helpers/vectorstore/_sync/strategies.py +++ b/elasticsearch/helpers/vectorstore/_sync/strategies.py @@ -15,14 +15,16 @@ # specific language governing permissions and limitations # under the License. -from typing import Any, Dict, List, Optional, Protocol, Tuple, Union, cast +from abc import ABC, abstractmethod +from typing import Any, Dict, List, Optional, Tuple, Union, cast from elasticsearch import Elasticsearch from elasticsearch.helpers.vectorstore._sync._utils import model_must_be_deployed from elasticsearch.helpers.vectorstore._utils import DistanceMetric -class RetrievalStrategy(Protocol): +class RetrievalStrategy(ABC): + @abstractmethod def es_query( self, query: Optional[str], @@ -48,6 +50,7 @@ def es_query( Dict: The Elasticsearch query body. """ + @abstractmethod def es_mappings_settings( self, text_field: str, From 71ca330449129ff0d78b1950a42e51726caa51e1 Mon Sep 17 00:00:00 2001 From: Max Jakob Date: Thu, 25 Apr 2024 12:01:35 +0200 Subject: [PATCH 29/36] address PR feedback: - Strategy suffix - Sphinx docstrings - add user agent to EmbeddingService - raise ConflictError - various cleanup --- elasticsearch/helpers/vectorstore/__init__.py | 32 ++--- .../helpers/vectorstore/_async/_utils.py | 28 ++--- .../vectorstore/_async/embedding_service.py | 42 +++---- .../helpers/vectorstore/_async/strategies.py | 50 ++++---- .../helpers/vectorstore/_async/vectorstore.py | 116 +++++++++--------- .../helpers/vectorstore/_sync/_utils.py | 23 ++-- .../vectorstore/_sync/embedding_service.py | 42 +++---- .../helpers/vectorstore/_sync/strategies.py | 50 ++++---- .../helpers/vectorstore/_sync/vectorstore.py | 116 +++++++++--------- test_elasticsearch/test_server/conftest.py | 5 + .../test_helpers_vectorstore/conftest.py | 11 ++ .../docker-compose.yml | 15 --- .../test_embedding_service.py | 33 ++++- 
.../test_vectorstore.py | 74 ++++++----- utils/run-unasync.py | 10 +- 15 files changed, 322 insertions(+), 325 deletions(-) diff --git a/elasticsearch/helpers/vectorstore/__init__.py b/elasticsearch/helpers/vectorstore/__init__.py index 7312d5f4c..815f3f2ee 100644 --- a/elasticsearch/helpers/vectorstore/__init__.py +++ b/elasticsearch/helpers/vectorstore/__init__.py @@ -20,11 +20,11 @@ AsyncEmbeddingService, ) from elasticsearch.helpers.vectorstore._async.strategies import ( - AsyncBM25, - AsyncDenseVector, - AsyncDenseVectorScriptScore, + AsyncBM25Strategy, + AsyncDenseVectorScriptScoreStrategy, + AsyncDenseVectorStrategy, AsyncRetrievalStrategy, - AsyncSparseVector, + AsyncSparseVectorStrategy, ) from elasticsearch.helpers.vectorstore._async.vectorstore import AsyncVectorStore from elasticsearch.helpers.vectorstore._sync.embedding_service import ( @@ -32,31 +32,31 @@ EmbeddingService, ) from elasticsearch.helpers.vectorstore._sync.strategies import ( - BM25, - DenseVector, - DenseVectorScriptScore, + BM25Strategy, + DenseVectorScriptScoreStrategy, + DenseVectorStrategy, RetrievalStrategy, - SparseVector, + SparseVectorStrategy, ) from elasticsearch.helpers.vectorstore._sync.vectorstore import VectorStore from elasticsearch.helpers.vectorstore._utils import DistanceMetric __all__ = [ - "BM25", - "DenseVector", - "DenseVectorScriptScore", + "BM25Strategy", + "DenseVectorStrategy", + "DenseVectorScriptScoreStrategy", "ElasticsearchEmbeddings", "EmbeddingService", "RetrievalStrategy", - "SparseVector", + "SparseVectorStrategy", "VectorStore", - "AsyncBM25", - "AsyncDenseVector", - "AsyncDenseVectorScriptScore", + "AsyncBM25Strategy", + "AsyncDenseVectorStrategy", + "AsyncDenseVectorScriptScoreStrategy", "AsyncElasticsearchEmbeddings", "AsyncEmbeddingService", "AsyncRetrievalStrategy", - "AsyncSparseVector", + "AsyncSparseVectorStrategy", "AsyncVectorStore", "DistanceMetric", ] diff --git a/elasticsearch/helpers/vectorstore/_async/_utils.py 
b/elasticsearch/helpers/vectorstore/_async/_utils.py index ad8794def..2305b1448 100644 --- a/elasticsearch/helpers/vectorstore/_async/_utils.py +++ b/elasticsearch/helpers/vectorstore/_async/_utils.py @@ -15,33 +15,21 @@ # specific language governing permissions and limitations # under the License. -from elasticsearch import ( - AsyncElasticsearch, - BadRequestError, - ConflictError, - NotFoundError, -) +from elasticsearch import AsyncElasticsearch, BadRequestError, NotFoundError async def model_must_be_deployed(client: AsyncElasticsearch, model_id: str) -> None: + """ + :raises [NotFoundError]: if the model is neither downloaded nor deployed. + :raises [ConflictError]: if the model is downloaded but not yet deployed. + """ + doc = {"text_field": f"test if the model '{model_id}' is deployed"} try: - dummy = {"x": "y"} - await client.ml.infer_trained_model(model_id=model_id, docs=[dummy]) - except NotFoundError as err: - raise err - except ConflictError as err: - raise NotFoundError( - f"model '{model_id}' not found, please deploy it first", - meta=err.meta, - body=err.body, - ) from err + await client.ml.infer_trained_model(model_id=model_id, docs=[doc]) except BadRequestError: - # This error is expected because we do not know the expected document - # shape and just use a dummy doc above. + # The model is deployed but expects a different input field name. pass - return None - async def model_is_deployed(es_client: AsyncElasticsearch, model_id: str) -> bool: try: diff --git a/elasticsearch/helpers/vectorstore/_async/embedding_service.py b/elasticsearch/helpers/vectorstore/_async/embedding_service.py index 1027d4f12..e86304ed4 100644 --- a/elasticsearch/helpers/vectorstore/_async/embedding_service.py +++ b/elasticsearch/helpers/vectorstore/_async/embedding_service.py @@ -16,9 +16,10 @@ # under the License. 
from abc import ABC, abstractmethod -from typing import List, Optional +from typing import List from elasticsearch import AsyncElasticsearch +from elasticsearch._version import __versionstr__ as lib_version class AsyncEmbeddingService(ABC): @@ -26,22 +27,18 @@ class AsyncEmbeddingService(ABC): async def embed_documents(self, texts: List[str]) -> List[List[float]]: """Generate embeddings for a list of documents. - Args: - texts: A list of document strings to generate embeddings for. + :param texts: A list of document strings to generate embeddings for. - Returns: - A list of embeddings, one for each document in the input. + :return: A list of embeddings, one for each document in the input. """ @abstractmethod async def embed_query(self, query: str) -> List[float]: """Generate an embedding for a single query text. - Args: - text: The query text to generate an embedding for. + :param text: The query text to generate an embedding for. - Returns: - The embedding for the input query text. + :return: The embedding for the input query text. """ @@ -56,31 +53,26 @@ class AsyncElasticsearchEmbeddings(AsyncEmbeddingService): def __init__( self, es_client: AsyncElasticsearch, - user_agent: str, model_id: str, input_field: str = "text_field", - num_dimensions: Optional[int] = None, + user_agent: str = f"elasticsearch-py-es/{lib_version}", ): """ - Args: - agent_header: user agent header specific to the 3rd party integration. - Used for usage tracking in Elastic Cloud. - model_id: The model_id of the model deployed in the Elasticsearch cluster. - input_field: The name of the key for the input text field in the - document. Defaults to 'text_field'. - num_dimensions: The number of embedding dimensions. If None, then dimensions - will be infer from an example inference call. - es_client: Elasticsearch client connection. Alternatively specify the - Elasticsearch connection with the other es_* parameters. + :param agent_header: user agent header specific to the 3rd party integration. 
+ Used for usage tracking in Elastic Cloud. + :param model_id: The model_id of the model deployed in the Elasticsearch cluster. + :param input_field: The name of the key for the input text field in the + document. Defaults to 'text_field'. + :param es_client: Elasticsearch client connection. Alternatively specify the + Elasticsearch connection with the other es_* parameters. """ # Add integration-specific usage header for tracking usage in Elastic Cloud. - # client.options preserces existing (non-user-agent) headers. + # client.options preserves existing (non-user-agent) headers. es_client = es_client.options(headers={"User-Agent": user_agent}) - self.client = es_client.ml + self.es_client = es_client self.model_id = model_id self.input_field = input_field - self._num_dimensions = num_dimensions async def embed_documents(self, texts: List[str]) -> List[List[float]]: result = await self._embedding_func(texts) @@ -91,7 +83,7 @@ async def embed_query(self, text: str) -> List[float]: return result[0] async def _embedding_func(self, texts: List[str]) -> List[List[float]]: - response = await self.client.infer_trained_model( + response = await self.es_client.ml.infer_trained_model( model_id=self.model_id, docs=[{self.input_field: text} for text in texts] ) return [doc["predicted_value"] for doc in response["inference_results"]] diff --git a/elasticsearch/helpers/vectorstore/_async/strategies.py b/elasticsearch/helpers/vectorstore/_async/strategies.py index 79e1049fe..29c1f58d7 100644 --- a/elasticsearch/helpers/vectorstore/_async/strategies.py +++ b/elasticsearch/helpers/vectorstore/_async/strategies.py @@ -39,15 +39,13 @@ def es_query( Returns the Elasticsearch query body for the given parameters. The store will execute the query. - Args: - query: The text query. Can be None if query_vector is given. - k: The total number of results to retrieve. - num_candidates: The number of results to fetch initially in knn search. 
- filter: List of filter clauses to apply to the query. - query_vector: The query vector. Can be None if a query string is given. - - Returns: - Dict: The Elasticsearch query body. + :param query: The text query. Can be None if query_vector is given. + :param k: The total number of results to retrieve. + :param num_candidates: The number of results to fetch initially in knn search. + :param filter: List of filter clauses to apply to the query. + :param query_vector: The query vector. Can be None if a query string is given. + + :return: The Elasticsearch query body. """ @abstractmethod @@ -61,11 +59,10 @@ def es_mappings_settings( Create the required index and do necessary preliminary work, like creating inference pipelines or checking if a required model was deployed. - Args: - client: Elasticsearch client connection. - index_name: The name of the Elasticsearch index to create. - metadata_mapping: Flat dictionary with field and field type pairs that - describe the schema of the metadata. + :param client: Elasticsearch client connection. + :param index_name: The name of the Elasticsearch index to create. + :param metadata_mapping: Flat dictionary with field and field type pairs that + describe the schema of the metadata. """ async def before_index_creation( @@ -74,22 +71,27 @@ async def before_index_creation( """ Executes before the index is created. Used for setting up any required Elasticsearch resources like a pipeline. + Defaults to a no-op. - Args: - client: The Elasticsearch client. - text_field: The field containing the text data in the index. - vector_field: The field containing the vector representations in the index. + :param client: The Elasticsearch client. + :param text_field: The field containing the text data in the index. + :param vector_field: The field containing the vector representations in the index. 
""" pass def needs_inference(self) -> bool: """ - TODO + Some retrieval strategies index embedding vectors and allow search by embedding + vector, for example the `DenseVectorStrategy` strategy. Mapping a user input query + string to an embedding vector is called inference. Inference can be applied + in Elasticsearch (using a `model_id`) or outside of Elasticsearch (using an + `EmbeddingService` defined on the `VectorStore`). In the latter case, + this method has to return True. """ return False -class AsyncSparseVector(AsyncRetrievalStrategy): +class AsyncSparseVectorStrategy(AsyncRetrievalStrategy): """Sparse retrieval strategy using the `text_expansion` processor.""" def __init__(self, model_id: str = ".elser_model_2"): @@ -176,7 +178,7 @@ async def before_index_creation( ) -class AsyncDenseVector(AsyncRetrievalStrategy): +class AsyncDenseVectorStrategy(AsyncRetrievalStrategy): """K-nearest-neighbors retrieval.""" def __init__( @@ -189,7 +191,7 @@ def __init__( ): if hybrid and not text_field: raise ValueError( - "to enable hybrid you have to specify a text_field (for BM25 matching)" + "to enable hybrid you have to specify a text_field (for BM25Strategy matching)" ) self.distance = distance @@ -304,7 +306,7 @@ def needs_inference(self) -> bool: return not self.model_id -class AsyncDenseVectorScriptScore(AsyncRetrievalStrategy): +class AsyncDenseVectorScriptScoreStrategy(AsyncRetrievalStrategy): """Exact nearest neighbors retrieval using the `script_score` query.""" def __init__(self, distance: DistanceMetric = DistanceMetric.COSINE) -> None: @@ -383,7 +385,7 @@ def needs_inference(self) -> bool: return True -class AsyncBM25(AsyncRetrievalStrategy): +class AsyncBM25Strategy(AsyncRetrievalStrategy): def __init__( self, k1: Optional[float] = None, diff --git a/elasticsearch/helpers/vectorstore/_async/vectorstore.py b/elasticsearch/helpers/vectorstore/_async/vectorstore.py index 521c05ca4..829610400 100644 --- 
a/elasticsearch/helpers/vectorstore/_async/vectorstore.py +++ b/elasticsearch/helpers/vectorstore/_async/vectorstore.py @@ -37,7 +37,7 @@ class AsyncVectorStore: Documents are flat text documents. Depending on the strategy, vector embeddings are - created by the user beforehand - - created by this class in Python + - created by this AsyncVectorStore class in Python - created in-stack by inference pipelines. """ @@ -51,23 +51,22 @@ def __init__( text_field: str = "text_field", vector_field: str = "vector_field", metadata_mappings: Optional[Dict[str, Any]] = None, - user_agent: str = f"es-py-vs/{lib_version}", + user_agent: str = f"elasticsearch-py-vs/{lib_version}", ) -> None: """ - Args: - user_header: user agent header specific to the 3rd party integration. - Used for usage tracking in Elastic Cloud. - index_name: The name of the index to query. - retrieval_strategy: how to index and search the data. See the strategies - module for availble strategies. - text_field: Name of the field with the textual data. - vector_field: For strategies that perform embedding inference in Python, - the embedding vector goes in this field. - es_client: Elasticsearch client connection. Alternatively specify the - Elasticsearch connection with the other es_* parameters. + :param user_header: user agent header specific to the 3rd party integration. + Used for usage tracking in Elastic Cloud. + :param index_name: The name of the index to query. + :param retrieval_strategy: how to index and search the data. See the strategies + module for availble strategies. + :param text_field: Name of the field with the textual data. + :param vector_field: For strategies that perform embedding inference in Python, + the embedding vector goes in this field. + :param es_client: Elasticsearch client connection. Alternatively specify the + Elasticsearch connection with the other es_* parameters. """ # Add integration-specific usage header for tracking usage in Elastic Cloud. 
- # client.options preserces existing (non-user-agent) headers. + # client.options preserves existing (non-user-agent) headers. es_client = es_client.options(headers={"User-Agent": user_agent}) if hasattr(retrieval_strategy, "text_field"): @@ -99,23 +98,21 @@ async def add_texts( ) -> List[str]: """Add documents to the Elasticsearch index. - Args: - texts: List of text documents. - metadata: Optional list of document metadata. Must be of same length as - texts. - vectors: Optional list of embedding vectors. Must be of same length as - texts. - ids: Optional list of ID strings. Must be of same length as texts. - refresh_indices: Whether to refresh the index after deleting documents. - Defaults to True. - create_index_if_not_exists: Whether to create the index if it does not - exist. Defaults to True. - bulk_kwargs: Arguments to pass to the bulk function when indexing - (for example chunk_size). - - Returns: - List of IDs of the created documents, either echoing the provided one - or returning newly created ones. + :param texts: List of text documents. + :param metadata: Optional list of document metadata. Must be of same length as + texts. + :param vectors: Optional list of embedding vectors. Must be of same length as + texts. + :param ids: Optional list of ID strings. Must be of same length as texts. + :param refresh_indices: Whether to refresh the index after deleting documents. + Defaults to True. + :param create_index_if_not_exists: Whether to create the index if it does not + exist. Defaults to True. + :param bulk_kwargs: Arguments to pass to the bulk function when indexing + (for example chunk_size). + + :return: List of IDs of the created documents, either echoing the provided one + or returning newly created ones. """ bulk_kwargs = bulk_kwargs or {} ids = ids or [str(uuid.uuid4()) for _ in texts] @@ -173,10 +170,11 @@ async def delete( # type: ignore[no-untyped-def] ) -> bool: """Delete documents from the Elasticsearch index. 
- Args: - ids: List of IDs of documents to delete. - refresh_indices: Whether to refresh the index after deleting documents. - Defaults to True. + :param ids: List of IDs of documents to delete. + :param refresh_indices: Whether to refresh the index after deleting documents. + Defaults to True. + + :return: True if deletion was successful. """ if ids is not None and query is not None: raise ValueError("one of ids or query must be specified") @@ -227,19 +225,17 @@ async def search( ] = None, ) -> List[Dict[str, Any]]: """ - Args: - query: Input query string. - query_vector: Input embedding vector. If given, input query string is - ignored. - k: Number of returned results. - num_candidates: Number of candidates to fetch from data nodes in knn. - fields: List of field names to return. - filter: Elasticsearch filters to apply. - custom_query: Function to modify the Elasticsearch query body before it is - sent to Elasticsearch. - - Returns: - List of document hits. Includes _index, _id, _score and _source. + :param query: Input query string. + :param query_vector: Input embedding vector. If given, input query string is + ignored. + :param k: Number of returned results. + :param num_candidates: Number of candidates to fetch from data nodes in knn. + :param fields: List of field names to return. + :param filter: Elasticsearch filters to apply. + :param custom_query: Function to modify the Elasticsearch query body before it is + sent to Elasticsearch. + + :return: List of document hits. Includes _index, _id, _score and _source. """ if fields is None: fields = [] @@ -334,19 +330,17 @@ async def max_marginal_relevance_search( Maximal marginal relevance optimizes for similarity to query AND diversity among selected documents. - Args: - query (str): Text to look up documents similar to. - k (int): Number of Documents to return. Defaults to 4. - fetch_k (int): Number of Documents to fetch to pass to MMR algorithm. 
- lambda_mult (float): Number between 0 and 1 that determines the degree - of diversity among the results with 0 corresponding - to maximum diversity and 1 to minimum diversity. - Defaults to 0.5. - fields: Other fields to get from elasticsearch source. These fields - will be added to the document metadata. - - Returns: - List[Document]: A list of Documents selected by maximal marginal relevance. + :param query (str): Text to look up documents similar to. + :param k (int): Number of Documents to return. Defaults to 4. + :param fetch_k (int): Number of Documents to fetch to pass to MMR algorithm. + :param lambda_mult (float): Number between 0 and 1 that determines the degree + of diversity among the results with 0 corresponding + to maximum diversity and 1 to minimum diversity. + Defaults to 0.5. + :param fields: Other fields to get from elasticsearch source. These fields + will be added to the document metadata. + + :return: A list of Documents selected by maximal marginal relevance. """ remove_vector_query_field_from_metadata = True if fields is None: diff --git a/elasticsearch/helpers/vectorstore/_sync/_utils.py b/elasticsearch/helpers/vectorstore/_sync/_utils.py index ad77be5aa..dba9bdcd4 100644 --- a/elasticsearch/helpers/vectorstore/_sync/_utils.py +++ b/elasticsearch/helpers/vectorstore/_sync/_utils.py @@ -15,28 +15,21 @@ # specific language governing permissions and limitations # under the License. -from elasticsearch import BadRequestError, ConflictError, Elasticsearch, NotFoundError +from elasticsearch import BadRequestError, Elasticsearch, NotFoundError def model_must_be_deployed(client: Elasticsearch, model_id: str) -> None: + """ + :raises [NotFoundError]: if the model is neither downloaded nor deployed. + :raises [ConflictError]: if the model is downloaded but not yet deployed. 
+ """ + doc = {"text_field": f"test if the model '{model_id}' is deployed"} try: - dummy = {"x": "y"} - client.ml.infer_trained_model(model_id=model_id, docs=[dummy]) - except NotFoundError as err: - raise err - except ConflictError as err: - raise NotFoundError( - f"model '{model_id}' not found, please deploy it first", - meta=err.meta, - body=err.body, - ) from err + client.ml.infer_trained_model(model_id=model_id, docs=[doc]) except BadRequestError: - # This error is expected because we do not know the expected document - # shape and just use a dummy doc above. + # The model is deployed but expects a different input field name. pass - return None - def model_is_deployed(es_client: Elasticsearch, model_id: str) -> bool: try: diff --git a/elasticsearch/helpers/vectorstore/_sync/embedding_service.py b/elasticsearch/helpers/vectorstore/_sync/embedding_service.py index 272c34214..a4d95cd84 100644 --- a/elasticsearch/helpers/vectorstore/_sync/embedding_service.py +++ b/elasticsearch/helpers/vectorstore/_sync/embedding_service.py @@ -16,9 +16,10 @@ # under the License. from abc import ABC, abstractmethod -from typing import List, Optional +from typing import List from elasticsearch import Elasticsearch +from elasticsearch._version import __versionstr__ as lib_version class EmbeddingService(ABC): @@ -26,22 +27,18 @@ class EmbeddingService(ABC): def embed_documents(self, texts: List[str]) -> List[List[float]]: """Generate embeddings for a list of documents. - Args: - texts: A list of document strings to generate embeddings for. + :param texts: A list of document strings to generate embeddings for. - Returns: - A list of embeddings, one for each document in the input. + :return: A list of embeddings, one for each document in the input. """ @abstractmethod def embed_query(self, query: str) -> List[float]: """Generate an embedding for a single query text. - Args: - text: The query text to generate an embedding for. + :param text: The query text to generate an embedding for. 
- Returns: - The embedding for the input query text. + :return: The embedding for the input query text. """ @@ -56,31 +53,26 @@ class ElasticsearchEmbeddings(EmbeddingService): def __init__( self, es_client: Elasticsearch, - user_agent: str, model_id: str, input_field: str = "text_field", - num_dimensions: Optional[int] = None, + user_agent: str = f"elasticsearch-py-es/{lib_version}", ): """ - Args: - agent_header: user agent header specific to the 3rd party integration. - Used for usage tracking in Elastic Cloud. - model_id: The model_id of the model deployed in the Elasticsearch cluster. - input_field: The name of the key for the input text field in the - document. Defaults to 'text_field'. - num_dimensions: The number of embedding dimensions. If None, then dimensions - will be infer from an example inference call. - es_client: Elasticsearch client connection. Alternatively specify the - Elasticsearch connection with the other es_* parameters. + :param agent_header: user agent header specific to the 3rd party integration. + Used for usage tracking in Elastic Cloud. + :param model_id: The model_id of the model deployed in the Elasticsearch cluster. + :param input_field: The name of the key for the input text field in the + document. Defaults to 'text_field'. + :param es_client: Elasticsearch client connection. Alternatively specify the + Elasticsearch connection with the other es_* parameters. """ # Add integration-specific usage header for tracking usage in Elastic Cloud. - # client.options preserces existing (non-user-agent) headers. + # client.options preserves existing (non-user-agent) headers. 
es_client = es_client.options(headers={"User-Agent": user_agent}) - self.client = es_client.ml + self.es_client = es_client self.model_id = model_id self.input_field = input_field - self._num_dimensions = num_dimensions def embed_documents(self, texts: List[str]) -> List[List[float]]: result = self._embedding_func(texts) @@ -91,7 +83,7 @@ def embed_query(self, text: str) -> List[float]: return result[0] def _embedding_func(self, texts: List[str]) -> List[List[float]]: - response = self.client.infer_trained_model( + response = self.es_client.ml.infer_trained_model( model_id=self.model_id, docs=[{self.input_field: text} for text in texts] ) return [doc["predicted_value"] for doc in response["inference_results"]] diff --git a/elasticsearch/helpers/vectorstore/_sync/strategies.py b/elasticsearch/helpers/vectorstore/_sync/strategies.py index 603de51e4..32b6adaf6 100644 --- a/elasticsearch/helpers/vectorstore/_sync/strategies.py +++ b/elasticsearch/helpers/vectorstore/_sync/strategies.py @@ -39,15 +39,13 @@ def es_query( Returns the Elasticsearch query body for the given parameters. The store will execute the query. - Args: - query: The text query. Can be None if query_vector is given. - k: The total number of results to retrieve. - num_candidates: The number of results to fetch initially in knn search. - filter: List of filter clauses to apply to the query. - query_vector: The query vector. Can be None if a query string is given. - - Returns: - Dict: The Elasticsearch query body. + :param query: The text query. Can be None if query_vector is given. + :param k: The total number of results to retrieve. + :param num_candidates: The number of results to fetch initially in knn search. + :param filter: List of filter clauses to apply to the query. + :param query_vector: The query vector. Can be None if a query string is given. + + :return: The Elasticsearch query body. 
""" @abstractmethod @@ -61,11 +59,10 @@ def es_mappings_settings( Create the required index and do necessary preliminary work, like creating inference pipelines or checking if a required model was deployed. - Args: - client: Elasticsearch client connection. - index_name: The name of the Elasticsearch index to create. - metadata_mapping: Flat dictionary with field and field type pairs that - describe the schema of the metadata. + :param client: Elasticsearch client connection. + :param index_name: The name of the Elasticsearch index to create. + :param metadata_mapping: Flat dictionary with field and field type pairs that + describe the schema of the metadata. """ def before_index_creation( @@ -74,22 +71,27 @@ def before_index_creation( """ Executes before the index is created. Used for setting up any required Elasticsearch resources like a pipeline. + Defaults to a no-op. - Args: - client: The Elasticsearch client. - text_field: The field containing the text data in the index. - vector_field: The field containing the vector representations in the index. + :param client: The Elasticsearch client. + :param text_field: The field containing the text data in the index. + :param vector_field: The field containing the vector representations in the index. """ pass def needs_inference(self) -> bool: """ - TODO + Some retrieval strategies index embedding vectors and allow search by embedding + vector, for example the `DenseVectorStrategy` strategy. Mapping a user input query + string to an embedding vector is called inference. Inference can be applied + in Elasticsearch (using a `model_id`) or outside of Elasticsearch (using an + `EmbeddingService` defined on the `VectorStore`). In the latter case, + this method has to return True. 
""" return False -class SparseVector(RetrievalStrategy): +class SparseVectorStrategy(RetrievalStrategy): """Sparse retrieval strategy using the `text_expansion` processor.""" def __init__(self, model_id: str = ".elser_model_2"): @@ -176,7 +178,7 @@ def before_index_creation( ) -class DenseVector(RetrievalStrategy): +class DenseVectorStrategy(RetrievalStrategy): """K-nearest-neighbors retrieval.""" def __init__( @@ -189,7 +191,7 @@ def __init__( ): if hybrid and not text_field: raise ValueError( - "to enable hybrid you have to specify a text_field (for BM25 matching)" + "to enable hybrid you have to specify a text_field (for BM25Strategy matching)" ) self.distance = distance @@ -304,7 +306,7 @@ def needs_inference(self) -> bool: return not self.model_id -class DenseVectorScriptScore(RetrievalStrategy): +class DenseVectorScriptScoreStrategy(RetrievalStrategy): """Exact nearest neighbors retrieval using the `script_score` query.""" def __init__(self, distance: DistanceMetric = DistanceMetric.COSINE) -> None: @@ -383,7 +385,7 @@ def needs_inference(self) -> bool: return True -class BM25(RetrievalStrategy): +class BM25Strategy(RetrievalStrategy): def __init__( self, k1: Optional[float] = None, diff --git a/elasticsearch/helpers/vectorstore/_sync/vectorstore.py b/elasticsearch/helpers/vectorstore/_sync/vectorstore.py index 5944da41e..a715fe282 100644 --- a/elasticsearch/helpers/vectorstore/_sync/vectorstore.py +++ b/elasticsearch/helpers/vectorstore/_sync/vectorstore.py @@ -35,7 +35,7 @@ class VectorStore: Documents are flat text documents. Depending on the strategy, vector embeddings are - created by the user beforehand - - created by this class in Python + - created by this AsyncVectorStore class in Python - created in-stack by inference pipelines. 
""" @@ -49,23 +49,22 @@ def __init__( text_field: str = "text_field", vector_field: str = "vector_field", metadata_mappings: Optional[Dict[str, Any]] = None, - user_agent: str = f"es-py-vs/{lib_version}", + user_agent: str = f"elasticsearch-py-vs/{lib_version}", ) -> None: """ - Args: - user_header: user agent header specific to the 3rd party integration. - Used for usage tracking in Elastic Cloud. - index_name: The name of the index to query. - retrieval_strategy: how to index and search the data. See the strategies - module for availble strategies. - text_field: Name of the field with the textual data. - vector_field: For strategies that perform embedding inference in Python, - the embedding vector goes in this field. - es_client: Elasticsearch client connection. Alternatively specify the - Elasticsearch connection with the other es_* parameters. + :param user_header: user agent header specific to the 3rd party integration. + Used for usage tracking in Elastic Cloud. + :param index_name: The name of the index to query. + :param retrieval_strategy: how to index and search the data. See the strategies + module for availble strategies. + :param text_field: Name of the field with the textual data. + :param vector_field: For strategies that perform embedding inference in Python, + the embedding vector goes in this field. + :param es_client: Elasticsearch client connection. Alternatively specify the + Elasticsearch connection with the other es_* parameters. """ # Add integration-specific usage header for tracking usage in Elastic Cloud. - # client.options preserces existing (non-user-agent) headers. + # client.options preserves existing (non-user-agent) headers. es_client = es_client.options(headers={"User-Agent": user_agent}) if hasattr(retrieval_strategy, "text_field"): @@ -97,23 +96,21 @@ def add_texts( ) -> List[str]: """Add documents to the Elasticsearch index. - Args: - texts: List of text documents. - metadata: Optional list of document metadata. 
Must be of same length as - texts. - vectors: Optional list of embedding vectors. Must be of same length as - texts. - ids: Optional list of ID strings. Must be of same length as texts. - refresh_indices: Whether to refresh the index after deleting documents. - Defaults to True. - create_index_if_not_exists: Whether to create the index if it does not - exist. Defaults to True. - bulk_kwargs: Arguments to pass to the bulk function when indexing - (for example chunk_size). - - Returns: - List of IDs of the created documents, either echoing the provided one - or returning newly created ones. + :param texts: List of text documents. + :param metadata: Optional list of document metadata. Must be of same length as + texts. + :param vectors: Optional list of embedding vectors. Must be of same length as + texts. + :param ids: Optional list of ID strings. Must be of same length as texts. + :param refresh_indices: Whether to refresh the index after deleting documents. + Defaults to True. + :param create_index_if_not_exists: Whether to create the index if it does not + exist. Defaults to True. + :param bulk_kwargs: Arguments to pass to the bulk function when indexing + (for example chunk_size). + + :return: List of IDs of the created documents, either echoing the provided one + or returning newly created ones. """ bulk_kwargs = bulk_kwargs or {} ids = ids or [str(uuid.uuid4()) for _ in texts] @@ -171,10 +168,11 @@ def delete( # type: ignore[no-untyped-def] ) -> bool: """Delete documents from the Elasticsearch index. - Args: - ids: List of IDs of documents to delete. - refresh_indices: Whether to refresh the index after deleting documents. - Defaults to True. + :param ids: List of IDs of documents to delete. + :param refresh_indices: Whether to refresh the index after deleting documents. + Defaults to True. + + :return: True if deletion was successful. 
""" if ids is not None and query is not None: raise ValueError("one of ids or query must be specified") @@ -225,19 +223,17 @@ def search( ] = None, ) -> List[Dict[str, Any]]: """ - Args: - query: Input query string. - query_vector: Input embedding vector. If given, input query string is - ignored. - k: Number of returned results. - num_candidates: Number of candidates to fetch from data nodes in knn. - fields: List of field names to return. - filter: Elasticsearch filters to apply. - custom_query: Function to modify the Elasticsearch query body before it is - sent to Elasticsearch. - - Returns: - List of document hits. Includes _index, _id, _score and _source. + :param query: Input query string. + :param query_vector: Input embedding vector. If given, input query string is + ignored. + :param k: Number of returned results. + :param num_candidates: Number of candidates to fetch from data nodes in knn. + :param fields: List of field names to return. + :param filter: Elasticsearch filters to apply. + :param custom_query: Function to modify the Elasticsearch query body before it is + sent to Elasticsearch. + + :return: List of document hits. Includes _index, _id, _score and _source. """ if fields is None: fields = [] @@ -332,19 +328,17 @@ def max_marginal_relevance_search( Maximal marginal relevance optimizes for similarity to query AND diversity among selected documents. - Args: - query (str): Text to look up documents similar to. - k (int): Number of Documents to return. Defaults to 4. - fetch_k (int): Number of Documents to fetch to pass to MMR algorithm. - lambda_mult (float): Number between 0 and 1 that determines the degree - of diversity among the results with 0 corresponding - to maximum diversity and 1 to minimum diversity. - Defaults to 0.5. - fields: Other fields to get from elasticsearch source. These fields - will be added to the document metadata. - - Returns: - List[Document]: A list of Documents selected by maximal marginal relevance. 
+ :param query (str): Text to look up documents similar to. + :param k (int): Number of Documents to return. Defaults to 4. + :param fetch_k (int): Number of Documents to fetch to pass to MMR algorithm. + :param lambda_mult (float): Number between 0 and 1 that determines the degree + of diversity among the results with 0 corresponding + to maximum diversity and 1 to minimum diversity. + Defaults to 0.5. + :param fields: Other fields to get from elasticsearch source. These fields + will be added to the document metadata. + + :return: A list of Documents selected by maximal marginal relevance. """ remove_vector_query_field_from_metadata = True if fields is None: diff --git a/test_elasticsearch/test_server/conftest.py b/test_elasticsearch/test_server/conftest.py index 1fa1f2c74..9bbadece7 100644 --- a/test_elasticsearch/test_server/conftest.py +++ b/test_elasticsearch/test_server/conftest.py @@ -54,6 +54,11 @@ def sync_client_factory(elasticsearch_url): client = None try: client = _create(elasticsearch_url) + + # Wipe the cluster before we start testing just in case it wasn't wiped + # cleanly from the previous run of pytest? + wipe_cluster(client) + yield client finally: if client: diff --git a/test_elasticsearch/test_server/test_helpers_vectorstore/conftest.py b/test_elasticsearch/test_server/test_helpers_vectorstore/conftest.py index 31c8c4def..97c264532 100644 --- a/test_elasticsearch/test_server/test_helpers_vectorstore/conftest.py +++ b/test_elasticsearch/test_server/test_helpers_vectorstore/conftest.py @@ -32,8 +32,19 @@ def index_name() -> str: @pytest.fixture(scope="function") def sync_client_request_saving_factory(elasticsearch_url): client = None + try: + client = _create(elasticsearch_url) + # Wipe the cluster before we start testing just in case it wasn't wiped + # cleanly from the previous run of pytest? + wipe_cluster(client) + finally: + client.close() + + try: + # Recreate client with a transport that saves requests. 
client = _create(elasticsearch_url, RequestSavingTransport) + yield client finally: if client: diff --git a/test_elasticsearch/test_server/test_helpers_vectorstore/docker-compose.yml b/test_elasticsearch/test_server/test_helpers_vectorstore/docker-compose.yml index d598fe235..b3aa43f18 100644 --- a/test_elasticsearch/test_server/test_helpers_vectorstore/docker-compose.yml +++ b/test_elasticsearch/test_server/test_helpers_vectorstore/docker-compose.yml @@ -25,18 +25,3 @@ services: - elasticsearch restart: no command: sh -c "sleep 10 && eland_import_hub_model --url http://elasticsearch:9200 --hub-model-id sentence-transformers/msmarco-minilm-l-12-v3 --start" - - kibana: - image: kibana:8.13.0 - environment: - - ELASTICSEARCH_URL=http://elasticsearch:9200 - ports: - - "5601:5601" - healthcheck: - test: - [ - "CMD-SHELL", - "curl --silent --fail http://localhost:5601/login || exit 1" - ] - interval: 10s - retries: 60 diff --git a/test_elasticsearch/test_server/test_helpers_vectorstore/test_embedding_service.py b/test_elasticsearch/test_server/test_helpers_vectorstore/test_embedding_service.py index a2e6dfc20..396480e3e 100644 --- a/test_elasticsearch/test_server/test_helpers_vectorstore/test_embedding_service.py +++ b/test_elasticsearch/test_server/test_helpers_vectorstore/test_embedding_service.py @@ -16,6 +16,7 @@ # under the License. 
import os +import re import pytest @@ -28,7 +29,7 @@ # deployed with # https://www.elastic.co/guide/en/machine-learning/current/ml-nlp-text-emb-vector-search-example.html MODEL_ID = os.getenv("MODEL_ID", "sentence-transformers__msmarco-minilm-l-12-v3") -NUM_DIMENSIONS = int(os.getenv("NUM_DIMENTIONS", "384")) +NUM_DIMENSIONS = int(os.getenv("NUM_DIMENSIONS", "384")) def test_elasticsearch_embedding_documents(sync_client: Elasticsearch) -> None: @@ -60,3 +61,33 @@ def test_elasticsearch_embedding_query(sync_client: Elasticsearch) -> None: ) output = embedding.embed_query(document) assert len(output) == NUM_DIMENSIONS + + +def test_user_agent_default( + sync_client: Elasticsearch, sync_client_request_saving: Elasticsearch +) -> None: + """Test to make sure the user-agent is set correctly.""" + + if not model_is_deployed(sync_client, MODEL_ID): + pytest.skip(f"{MODEL_ID} model is not deployed in ML Node, skipping test") + + embeddings = ElasticsearchEmbeddings( + es_client=sync_client_request_saving, model_id=MODEL_ID + ) + + expected_pattern = r"^elasticsearch-py-es/\d+\.\d+\.\d+$" + + got_agent = embeddings.es_client._headers["User-Agent"] + assert ( + re.match(expected_pattern, got_agent) is not None + ), f"The user agent '{got_agent}' does not match the expected pattern." + + embeddings.embed_query("foo bar") + + requests = embeddings.es_client.transport.requests # type: ignore + assert len(requests) == 1 + + got_request_agent = requests[0]["headers"]["User-Agent"] + assert ( + re.match(expected_pattern, got_request_agent) is not None + ), f"The user agent '{got_request_agent}' does not match the expected pattern." 
diff --git a/test_elasticsearch/test_server/test_helpers_vectorstore/test_vectorstore.py b/test_elasticsearch/test_server/test_helpers_vectorstore/test_vectorstore.py index a8ccb3f94..54e6cfded 100644 --- a/test_elasticsearch/test_server/test_helpers_vectorstore/test_vectorstore.py +++ b/test_elasticsearch/test_server/test_helpers_vectorstore/test_vectorstore.py @@ -16,6 +16,7 @@ # under the License. import logging +import re from functools import partial from typing import Any, List, Optional, Union @@ -24,11 +25,11 @@ from elasticsearch import Elasticsearch, NotFoundError from elasticsearch.helpers import BulkIndexError from elasticsearch.helpers.vectorstore import ( - BM25, - DenseVector, - DenseVectorScriptScore, + BM25Strategy, + DenseVectorScriptScoreStrategy, + DenseVectorStrategy, DistanceMetric, - SparseVector, + SparseVectorStrategy, VectorStore, ) from elasticsearch.helpers.vectorstore._sync._utils import model_is_deployed @@ -78,7 +79,7 @@ def assert_query(query_body: dict, query: Optional[str]) -> dict: store = VectorStore( index_name=index_name, - retrieval_strategy=DenseVector(), + retrieval_strategy=DenseVectorStrategy(), embedding_service=FakeEmbeddings(), es_client=sync_client, ) @@ -95,7 +96,7 @@ def test_search_without_metadata_async( """Test end to end construction and search without metadata.""" store = VectorStore( index_name=index_name, - retrieval_strategy=DenseVector(), + retrieval_strategy=DenseVectorStrategy(), embedding_service=FakeEmbeddings(), es_client=sync_client, ) @@ -123,7 +124,7 @@ def test_add_vectors(self, sync_client: Elasticsearch, index_name: str) -> None: store = VectorStore( index_name=index_name, - retrieval_strategy=DenseVector(), + retrieval_strategy=DenseVectorStrategy(), embedding_service=embeddings, es_client=sync_client, ) @@ -139,7 +140,7 @@ def test_search_with_metadata( """Test end to end construction and search with metadata.""" store = VectorStore( index_name=index_name, - retrieval_strategy=DenseVector(), + 
retrieval_strategy=DenseVectorStrategy(), embedding_service=ConsistentFakeEmbeddings(), es_client=sync_client, ) @@ -162,7 +163,7 @@ def test_search_with_filter( """Test end to end construction and search with metadata.""" store = VectorStore( index_name=index_name, - retrieval_strategy=DenseVector(), + retrieval_strategy=DenseVectorStrategy(), embedding_service=FakeEmbeddings(), es_client=sync_client, ) @@ -198,7 +199,7 @@ def test_search_script_score( """Test end to end construction and search with metadata.""" store = VectorStore( index_name=index_name, - retrieval_strategy=DenseVectorScriptScore(), + retrieval_strategy=DenseVectorScriptScoreStrategy(), embedding_service=FakeEmbeddings(), es_client=sync_client, ) @@ -244,7 +245,7 @@ def test_search_script_score_with_filter( """Test end to end construction and search with metadata.""" store = VectorStore( index_name=index_name, - retrieval_strategy=DenseVectorScriptScore(), + retrieval_strategy=DenseVectorScriptScoreStrategy(), embedding_service=FakeEmbeddings(), es_client=sync_client, ) @@ -296,7 +297,7 @@ def test_search_script_score_distance_dot_product( """Test end to end construction and search with metadata.""" store = VectorStore( index_name=index_name, - retrieval_strategy=DenseVectorScriptScore( + retrieval_strategy=DenseVectorScriptScoreStrategy( distance=DistanceMetric.DOT_PRODUCT, ), embedding_service=FakeEmbeddings(), @@ -345,7 +346,7 @@ def test_search_knn_with_hybrid_search( """Test end to end construction and search with metadata.""" store = VectorStore( index_name=index_name, - retrieval_strategy=DenseVector(hybrid=True), + retrieval_strategy=DenseVectorStrategy(hybrid=True), embedding_service=FakeEmbeddings(), es_client=sync_client, ) @@ -431,7 +432,7 @@ def assert_query( for rrf_test_case in rrf_test_cases: store = VectorStore( index_name=index_name, - retrieval_strategy=DenseVector(hybrid=True, rrf=rrf_test_case), + retrieval_strategy=DenseVectorStrategy(hybrid=True, rrf=rrf_test_case), 
embedding_service=FakeEmbeddings(), es_client=sync_client, ) @@ -471,7 +472,7 @@ def assert_query( # 3. check rrf default option is okay store = VectorStore( index_name=f"{index_name}_default", - retrieval_strategy=DenseVector(hybrid=True), + retrieval_strategy=DenseVectorStrategy(hybrid=True), embedding_service=FakeEmbeddings(), es_client=sync_client, ) @@ -492,7 +493,7 @@ def test_search_knn_with_custom_query_fn( with the query string and query body""" store = VectorStore( index_name=index_name, - retrieval_strategy=DenseVector(), + retrieval_strategy=DenseVectorStrategy(), embedding_service=FakeEmbeddings(), es_client=sync_client, ) @@ -531,7 +532,7 @@ def test_search_with_knn_infer_instack( store = VectorStore( index_name=index_name, - retrieval_strategy=DenseVector( + retrieval_strategy=DenseVectorStrategy( model_id="sentence-transformers__all-minilm-l6-v2" ), es_client=sync_client, @@ -620,7 +621,7 @@ def test_search_with_sparse_infer_instack( store = VectorStore( index_name=index_name, - retrieval_strategy=SparseVector(model_id=ELSER_MODEL_ID), + retrieval_strategy=SparseVectorStrategy(model_id=ELSER_MODEL_ID), es_client=sync_client, ) @@ -637,16 +638,18 @@ def test_deployed_model_check_fails_semantic( with pytest.raises(NotFoundError): store = VectorStore( index_name=index_name, - retrieval_strategy=DenseVector(model_id="non-existing model ID"), + retrieval_strategy=DenseVectorStrategy( + model_id="non-existing model ID" + ), es_client=sync_client, ) store.add_texts(["foo", "bar", "baz"]) def test_search_bm25(self, sync_client: Elasticsearch, index_name: str) -> None: - """Test end to end using the BM25 retrieval strategy.""" + """Test end to end using the BM25Strategy retrieval strategy.""" store = VectorStore( index_name=index_name, - retrieval_strategy=BM25(), + retrieval_strategy=BM25Strategy(), es_client=sync_client, ) @@ -670,10 +673,10 @@ def assert_query(query_body: dict, query: Optional[str]) -> dict: def test_search_bm25_with_filter( self, 
sync_client: Elasticsearch, index_name: str ) -> None: - """Test end to using the BM25 retrieval strategy with metadata.""" + """Test end to using the BM25Strategy retrieval strategy with metadata.""" store = VectorStore( index_name=index_name, - retrieval_strategy=BM25(), + retrieval_strategy=BM25Strategy(), es_client=sync_client, ) @@ -705,7 +708,7 @@ def test_delete(self, sync_client: Elasticsearch, index_name: str) -> None: """Test delete methods from vector store.""" store = VectorStore( index_name=index_name, - retrieval_strategy=DenseVector(), + retrieval_strategy=DenseVectorStrategy(), embedding_service=FakeEmbeddings(), es_client=sync_client, ) @@ -742,7 +745,7 @@ def test_indexing_exception_error( """Test bulk exception logging is giving better hints.""" store = VectorStore( index_name=index_name, - retrieval_strategy=BM25(), + retrieval_strategy=BM25Strategy(), es_client=sync_client, ) @@ -768,17 +771,24 @@ def test_user_agent_default( """Test to make sure the user-agent is set correctly.""" store = VectorStore( index_name=index_name, - retrieval_strategy=BM25(), + retrieval_strategy=BM25Strategy(), es_client=sync_client_request_saving, ) + expected_pattern = r"^elasticsearch-py-vs/\d+\.\d+\.\d+$" - assert store.es_client._headers["User-Agent"].startswith("es-py-vs/") + got_agent = store.es_client._headers["User-Agent"] + assert ( + re.match(expected_pattern, got_agent) is not None + ), f"The user agent '{got_agent}' does not match the expected pattern." texts = ["foo", "bob", "baz"] store.add_texts(texts) for request in store.es_client.transport.requests: # type: ignore - assert request["headers"]["User-Agent"].startswith("es-py-vs/") + agent = request["headers"]["User-Agent"] + assert ( + re.match(expected_pattern, agent) is not None + ), f"The user agent '{agent}' does not match the expected pattern." 
def test_user_agent_custom( self, sync_client_request_saving: Elasticsearch, index_name: str @@ -789,7 +799,7 @@ def test_user_agent_custom( store = VectorStore( user_agent=user_agent, index_name=index_name, - retrieval_strategy=BM25(), + retrieval_strategy=BM25Strategy(), es_client=sync_client_request_saving, ) @@ -805,7 +815,7 @@ def test_bulk_args(self, sync_client_request_saving: Any, index_name: str) -> No """Test to make sure the bulk arguments work as expected.""" store = VectorStore( index_name=index_name, - retrieval_strategy=BM25(), + retrieval_strategy=BM25Strategy(), es_client=sync_client_request_saving, ) @@ -825,7 +835,7 @@ def test_max_marginal_relevance_search( embedding_service = ConsistentFakeEmbeddings() store = VectorStore( index_name=index_name, - retrieval_strategy=DenseVectorScriptScore(), + retrieval_strategy=DenseVectorScriptScoreStrategy(), embedding_service=embedding_service, vector_field=vector_field, text_field=text_field, @@ -886,7 +896,7 @@ def test_metadata_mapping( } store = VectorStore( index_name=index_name, - retrieval_strategy=DenseVector(), + retrieval_strategy=DenseVectorStrategy(), embedding_service=FakeEmbeddings(), es_client=sync_client, metadata_mappings=test_mappings, diff --git a/utils/run-unasync.py b/utils/run-unasync.py index 51c041974..ec7623869 100644 --- a/utils/run-unasync.py +++ b/utils/run-unasync.py @@ -86,14 +86,14 @@ def main(): fromdir="elasticsearch/helpers/vectorstore/_async/", todir="elasticsearch/helpers/vectorstore/_sync/", additional_replacements={ - "AsyncBM25": "BM25", - "AsyncDenseVector": "DenseVector", - "AsyncDenseVectorScriptScore": "DenseVectorScriptScore", + "AsyncBM25Strategy": "BM25Strategy", + "AsyncDenseVectorStrategy": "DenseVectorStrategy", + "AsyncDenseVectorScriptScoreStrategy": "DenseVectorScriptScoreStrategy", "AsyncElasticsearch": "Elasticsearch", "AsyncElasticsearchEmbeddings": "ElasticsearchEmbeddings", "AsyncEmbeddingService": "EmbeddingService", "AsyncRetrievalStrategy": 
"RetrievalStrategy", - "AsyncSparseVector": "SparseVector", + "AsyncSparseVectorStrategy": "SparseVectorStrategy", "AsyncTransport": "Transport", "AsyncVectorStore": "VectorStore", "async_bulk": "bulk", @@ -102,8 +102,6 @@ def main(): ), cleanup_patterns=[ "/^import asyncio$/d", - "/^import pytest_asyncio*/d", - "/ *@pytest.mark.asyncio$/d", ], format=True, ) From a5dea84c7b540deffe09f8f334b328e61766af77 Mon Sep 17 00:00:00 2001 From: Max Jakob Date: Thu, 25 Apr 2024 13:22:48 +0200 Subject: [PATCH 30/36] improve docstring --- .../helpers/vectorstore/_async/strategies.py | 2 +- .../helpers/vectorstore/_async/vectorstore.py | 17 ++++++++++++----- .../helpers/vectorstore/_sync/strategies.py | 2 +- .../helpers/vectorstore/_sync/vectorstore.py | 17 ++++++++++++----- 4 files changed, 26 insertions(+), 12 deletions(-) diff --git a/elasticsearch/helpers/vectorstore/_async/strategies.py b/elasticsearch/helpers/vectorstore/_async/strategies.py index 29c1f58d7..88027ebc6 100644 --- a/elasticsearch/helpers/vectorstore/_async/strategies.py +++ b/elasticsearch/helpers/vectorstore/_async/strategies.py @@ -217,7 +217,7 @@ def es_query( "num_candidates": num_candidates, } - if query_vector: + if query_vector is not None: knn["query_vector"] = query_vector else: # Inference in Elasticsearch. When initializing we make sure to always have diff --git a/elasticsearch/helpers/vectorstore/_async/vectorstore.py b/elasticsearch/helpers/vectorstore/_async/vectorstore.py index 829610400..bd72284e4 100644 --- a/elasticsearch/helpers/vectorstore/_async/vectorstore.py +++ b/elasticsearch/helpers/vectorstore/_async/vectorstore.py @@ -32,13 +32,20 @@ class AsyncVectorStore: - """VectorStore is a higher-level abstraction of indexing and search. + """ + VectorStore is a higher-level abstraction of indexing and search. Users can pick from available retrieval strategies. - Documents are flat text documents. 
Depending on the strategy, vector embeddings are - - created by the user beforehand - - created by this AsyncVectorStore class in Python - - created in-stack by inference pipelines. + Documents have up to 3 fields: + - text_field: the text to be indexed and searched. + - metadata: additional information about the document, either schema-free + or defined by the supplied metadata_mappings. + - vector_field (usually not filled by the user): the embedding vector of the text. + + Depending on the strategy, vector embeddings are + - created by the user beforehand + - created by this AsyncVectorStore class in Python + - created in-stack by inference pipelines. """ def __init__( diff --git a/elasticsearch/helpers/vectorstore/_sync/strategies.py b/elasticsearch/helpers/vectorstore/_sync/strategies.py index 32b6adaf6..83b75392b 100644 --- a/elasticsearch/helpers/vectorstore/_sync/strategies.py +++ b/elasticsearch/helpers/vectorstore/_sync/strategies.py @@ -217,7 +217,7 @@ def es_query( "num_candidates": num_candidates, } - if query_vector: + if query_vector is not None: knn["query_vector"] = query_vector else: # Inference in Elasticsearch. When initializing we make sure to always have diff --git a/elasticsearch/helpers/vectorstore/_sync/vectorstore.py b/elasticsearch/helpers/vectorstore/_sync/vectorstore.py index a715fe282..34e3cfc1a 100644 --- a/elasticsearch/helpers/vectorstore/_sync/vectorstore.py +++ b/elasticsearch/helpers/vectorstore/_sync/vectorstore.py @@ -30,13 +30,20 @@ class VectorStore: - """VectorStore is a higher-level abstraction of indexing and search. + """ + VectorStore is a higher-level abstraction of indexing and search. Users can pick from available retrieval strategies. - Documents are flat text documents. Depending on the strategy, vector embeddings are - - created by the user beforehand - - created by this AsyncVectorStore class in Python - - created in-stack by inference pipelines. 
+ Documents have up to 3 fields: + - text_field: the text to be indexed and searched. + - metadata: additional information about the document, either schema-free + or defined by the supplied metadata_mappings. + - vector_field (usually not filled by the user): the embedding vector of the text. + + Depending on the strategy, vector embeddings are + - created by the user beforehand + - created by this AsyncVectorStore class in Python + - created in-stack by inference pipelines. """ def __init__( From 6f81af91f774aa77d251bdf93e3f2dff51d298c5 Mon Sep 17 00:00:00 2001 From: Max Jakob Date: Fri, 26 Apr 2024 11:08:02 +0200 Subject: [PATCH 31/36] fix metadata mappings issue --- .../helpers/vectorstore/_async/vectorstore.py | 3 +-- .../helpers/vectorstore/_sync/vectorstore.py | 3 +-- .../test_helpers_vectorstore/test_vectorstore.py | 12 ++++++++++-- 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/elasticsearch/helpers/vectorstore/_async/vectorstore.py b/elasticsearch/helpers/vectorstore/_async/vectorstore.py index bd72284e4..01cb17dfc 100644 --- a/elasticsearch/helpers/vectorstore/_async/vectorstore.py +++ b/elasticsearch/helpers/vectorstore/_async/vectorstore.py @@ -302,7 +302,6 @@ async def _create_index_if_not_exists(self) -> None: vector_field=self.vector_field, num_dimensions=self.num_dimensions, ) - if self.metadata_mappings: metadata = mappings["properties"].get("metadata", {"properties": {}}) for key in self.metadata_mappings.keys(): @@ -310,7 +309,7 @@ async def _create_index_if_not_exists(self) -> None: raise ValueError(f"metadata key {key} already exists in mappings") metadata = dict(**metadata["properties"], **self.metadata_mappings) - mappings["properties"] = {"metadata": {"properties": metadata}} + mappings["properties"]["metadata"] = {"properties": metadata} await self.retrieval_strategy.before_index_creation( self.es_client, self.text_field, self.vector_field diff --git a/elasticsearch/helpers/vectorstore/_sync/vectorstore.py 
b/elasticsearch/helpers/vectorstore/_sync/vectorstore.py index 34e3cfc1a..bca9d1b3e 100644 --- a/elasticsearch/helpers/vectorstore/_sync/vectorstore.py +++ b/elasticsearch/helpers/vectorstore/_sync/vectorstore.py @@ -300,7 +300,6 @@ def _create_index_if_not_exists(self) -> None: vector_field=self.vector_field, num_dimensions=self.num_dimensions, ) - if self.metadata_mappings: metadata = mappings["properties"].get("metadata", {"properties": {}}) for key in self.metadata_mappings.keys(): @@ -308,7 +307,7 @@ def _create_index_if_not_exists(self) -> None: raise ValueError(f"metadata key {key} already exists in mappings") metadata = dict(**metadata["properties"], **self.metadata_mappings) - mappings["properties"] = {"metadata": {"properties": metadata}} + mappings["properties"]["metadata"] = {"properties": metadata} self.retrieval_strategy.before_index_creation( self.es_client, self.text_field, self.vector_field diff --git a/test_elasticsearch/test_server/test_helpers_vectorstore/test_vectorstore.py b/test_elasticsearch/test_server/test_helpers_vectorstore/test_vectorstore.py index 54e6cfded..62e345119 100644 --- a/test_elasticsearch/test_server/test_helpers_vectorstore/test_vectorstore.py +++ b/test_elasticsearch/test_server/test_helpers_vectorstore/test_vectorstore.py @@ -896,18 +896,26 @@ def test_metadata_mapping( } store = VectorStore( index_name=index_name, - retrieval_strategy=DenseVectorStrategy(), + retrieval_strategy=DenseVectorStrategy(distance=DistanceMetric.COSINE), embedding_service=FakeEmbeddings(), + num_dimensions=10, es_client=sync_client, metadata_mappings=test_mappings, ) texts = ["foo", "foo", "foo"] - metadatas = [{"page": i} for i in range(len(texts))] + metadatas = [{"my_field": str(i)} for i in range(len(texts))] store.add_texts(texts=texts, metadatas=metadatas) mapping_response = sync_client.indices.get_mapping(index=index_name) mapping_properties = mapping_response[index_name]["mappings"]["properties"] + assert 
mapping_properties["vector_field"] == { + "type": "dense_vector", + "dims": 10, + "index": True, + "similarity": "cosine", + } + assert "metadata" in mapping_properties for key, val in test_mappings.items(): assert mapping_properties["metadata"]["properties"][key] == val From 881d56c6d499b2e4019b5cad08161b5ce07acb97 Mon Sep 17 00:00:00 2001 From: Max Jakob Date: Mon, 29 Apr 2024 12:34:36 +0200 Subject: [PATCH 32/36] address PR feedback --- elasticsearch/helpers/vectorstore/__init__.py | 18 ++--- .../vectorstore/_async/embedding_service.py | 3 +- .../helpers/vectorstore/_async/strategies.py | 26 +++--- .../helpers/vectorstore/_async/vectorstore.py | 4 +- .../vectorstore/_sync/embedding_service.py | 3 +- .../helpers/vectorstore/_sync/strategies.py | 26 +++--- .../helpers/vectorstore/_sync/vectorstore.py | 3 +- elasticsearch/helpers/vectorstore/_utils.py | 2 +- noxfile.py | 6 +- setup.py | 3 +- .../test_helpers_vectorstore/__init__.py | 65 +++++++++++++++ .../test_helpers_vectorstore/_test_utils.py | 81 ------------------- .../test_helpers_vectorstore/conftest.py | 2 +- .../test_embedding_service.py | 4 +- .../test_vectorstore.py | 3 +- utils/run-unasync.py | 23 ++---- 16 files changed, 121 insertions(+), 151 deletions(-) delete mode 100644 test_elasticsearch/test_server/test_helpers_vectorstore/_test_utils.py diff --git a/elasticsearch/helpers/vectorstore/__init__.py b/elasticsearch/helpers/vectorstore/__init__.py index 815f3f2ee..30a4c3d6e 100644 --- a/elasticsearch/helpers/vectorstore/__init__.py +++ b/elasticsearch/helpers/vectorstore/__init__.py @@ -42,21 +42,21 @@ from elasticsearch.helpers.vectorstore._utils import DistanceMetric __all__ = [ - "BM25Strategy", - "DenseVectorStrategy", - "DenseVectorScriptScoreStrategy", - "ElasticsearchEmbeddings", - "EmbeddingService", - "RetrievalStrategy", - "SparseVectorStrategy", - "VectorStore", "AsyncBM25Strategy", - "AsyncDenseVectorStrategy", "AsyncDenseVectorScriptScoreStrategy", + "AsyncDenseVectorStrategy", 
"AsyncElasticsearchEmbeddings", "AsyncEmbeddingService", "AsyncRetrievalStrategy", "AsyncSparseVectorStrategy", "AsyncVectorStore", + "BM25Strategy", + "DenseVectorScriptScoreStrategy", + "DenseVectorStrategy", "DistanceMetric", + "ElasticsearchEmbeddings", + "EmbeddingService", + "RetrievalStrategy", + "SparseVectorStrategy", + "VectorStore", ] diff --git a/elasticsearch/helpers/vectorstore/_async/embedding_service.py b/elasticsearch/helpers/vectorstore/_async/embedding_service.py index e86304ed4..2612372b4 100644 --- a/elasticsearch/helpers/vectorstore/_async/embedding_service.py +++ b/elasticsearch/helpers/vectorstore/_async/embedding_service.py @@ -75,8 +75,7 @@ def __init__( self.input_field = input_field async def embed_documents(self, texts: List[str]) -> List[List[float]]: - result = await self._embedding_func(texts) - return result + return await self._embedding_func(texts) async def embed_query(self, text: str) -> List[float]: result = await self._embedding_func([text]) diff --git a/elasticsearch/helpers/vectorstore/_async/strategies.py b/elasticsearch/helpers/vectorstore/_async/strategies.py index 88027ebc6..8d19698e6 100644 --- a/elasticsearch/helpers/vectorstore/_async/strategies.py +++ b/elasticsearch/helpers/vectorstore/_async/strategies.py @@ -241,13 +241,13 @@ def es_mappings_settings( num_dimensions: Optional[int], ) -> Tuple[Dict[str, Any], Dict[str, Any]]: if self.distance is DistanceMetric.COSINE: - similarityAlgo = "cosine" + similarity = "cosine" elif self.distance is DistanceMetric.EUCLIDEAN_DISTANCE: - similarityAlgo = "l2_norm" + similarity = "l2_norm" elif self.distance is DistanceMetric.DOT_PRODUCT: - similarityAlgo = "dot_product" + similarity = "dot_product" elif self.distance is DistanceMetric.MAX_INNER_PRODUCT: - similarityAlgo = "max_inner_product" + similarity = "max_inner_product" else: raise ValueError(f"Similarity {self.distance} not supported.") @@ -257,7 +257,7 @@ def es_mappings_settings( "type": "dense_vector", "dims": 
num_dimensions, "index": True, - "similarity": similarityAlgo, + "similarity": similarity, }, } } @@ -326,18 +326,18 @@ def es_query( raise ValueError("specify a query_vector") if self.distance is DistanceMetric.COSINE: - similarityAlgo = ( + similarity_algo = ( f"cosineSimilarity(params.query_vector, '{vector_field}') + 1.0" ) elif self.distance is DistanceMetric.EUCLIDEAN_DISTANCE: - similarityAlgo = f"1 / (1 + l2norm(params.query_vector, '{vector_field}'))" + similarity_algo = f"1 / (1 + l2norm(params.query_vector, '{vector_field}'))" elif self.distance is DistanceMetric.DOT_PRODUCT: - similarityAlgo = f""" + similarity_algo = f""" double value = dotProduct(params.query_vector, '{vector_field}'); return sigmoid(1, Math.E, -value); """ elif self.distance is DistanceMetric.MAX_INNER_PRODUCT: - similarityAlgo = f""" + similarity_algo = f""" double value = dotProduct(params.query_vector, '{vector_field}'); if (dotProduct < 0) {{ return 1 / (1 + -1 * dotProduct); @@ -347,16 +347,16 @@ def es_query( else: raise ValueError(f"Similarity {self.distance} not supported.") - queryBool: Dict[str, Any] = {"match_all": {}} + query_bool: Dict[str, Any] = {"match_all": {}} if filter: - queryBool = {"bool": {"filter": filter}} + query_bool = {"bool": {"filter": filter}} return { "query": { "script_score": { - "query": queryBool, + "query": query_bool, "script": { - "source": similarityAlgo, + "source": similarity_algo, "params": {"query_vector": query_vector}, }, }, diff --git a/elasticsearch/helpers/vectorstore/_async/vectorstore.py b/elasticsearch/helpers/vectorstore/_async/vectorstore.py index 01cb17dfc..146b05394 100644 --- a/elasticsearch/helpers/vectorstore/_async/vectorstore.py +++ b/elasticsearch/helpers/vectorstore/_async/vectorstore.py @@ -22,10 +22,10 @@ from elasticsearch import AsyncElasticsearch from elasticsearch._version import __versionstr__ as lib_version from elasticsearch.helpers import BulkIndexError, async_bulk -from 
elasticsearch.helpers.vectorstore._async.embedding_service import ( +from elasticsearch.helpers.vectorstore import ( AsyncEmbeddingService, + AsyncRetrievalStrategy, ) -from elasticsearch.helpers.vectorstore._async.strategies import AsyncRetrievalStrategy from elasticsearch.helpers.vectorstore._utils import maximal_marginal_relevance logger = logging.getLogger(__name__) diff --git a/elasticsearch/helpers/vectorstore/_sync/embedding_service.py b/elasticsearch/helpers/vectorstore/_sync/embedding_service.py index a4d95cd84..51b607237 100644 --- a/elasticsearch/helpers/vectorstore/_sync/embedding_service.py +++ b/elasticsearch/helpers/vectorstore/_sync/embedding_service.py @@ -75,8 +75,7 @@ def __init__( self.input_field = input_field def embed_documents(self, texts: List[str]) -> List[List[float]]: - result = self._embedding_func(texts) - return result + return self._embedding_func(texts) def embed_query(self, text: str) -> List[float]: result = self._embedding_func([text]) diff --git a/elasticsearch/helpers/vectorstore/_sync/strategies.py b/elasticsearch/helpers/vectorstore/_sync/strategies.py index 83b75392b..d172a7c1f 100644 --- a/elasticsearch/helpers/vectorstore/_sync/strategies.py +++ b/elasticsearch/helpers/vectorstore/_sync/strategies.py @@ -241,13 +241,13 @@ def es_mappings_settings( num_dimensions: Optional[int], ) -> Tuple[Dict[str, Any], Dict[str, Any]]: if self.distance is DistanceMetric.COSINE: - similarityAlgo = "cosine" + similarity = "cosine" elif self.distance is DistanceMetric.EUCLIDEAN_DISTANCE: - similarityAlgo = "l2_norm" + similarity = "l2_norm" elif self.distance is DistanceMetric.DOT_PRODUCT: - similarityAlgo = "dot_product" + similarity = "dot_product" elif self.distance is DistanceMetric.MAX_INNER_PRODUCT: - similarityAlgo = "max_inner_product" + similarity = "max_inner_product" else: raise ValueError(f"Similarity {self.distance} not supported.") @@ -257,7 +257,7 @@ def es_mappings_settings( "type": "dense_vector", "dims": num_dimensions, 
"index": True, - "similarity": similarityAlgo, + "similarity": similarity, }, } } @@ -326,18 +326,18 @@ def es_query( raise ValueError("specify a query_vector") if self.distance is DistanceMetric.COSINE: - similarityAlgo = ( + similarity_algo = ( f"cosineSimilarity(params.query_vector, '{vector_field}') + 1.0" ) elif self.distance is DistanceMetric.EUCLIDEAN_DISTANCE: - similarityAlgo = f"1 / (1 + l2norm(params.query_vector, '{vector_field}'))" + similarity_algo = f"1 / (1 + l2norm(params.query_vector, '{vector_field}'))" elif self.distance is DistanceMetric.DOT_PRODUCT: - similarityAlgo = f""" + similarity_algo = f""" double value = dotProduct(params.query_vector, '{vector_field}'); return sigmoid(1, Math.E, -value); """ elif self.distance is DistanceMetric.MAX_INNER_PRODUCT: - similarityAlgo = f""" + similarity_algo = f""" double value = dotProduct(params.query_vector, '{vector_field}'); if (dotProduct < 0) {{ return 1 / (1 + -1 * dotProduct); @@ -347,16 +347,16 @@ def es_query( else: raise ValueError(f"Similarity {self.distance} not supported.") - queryBool: Dict[str, Any] = {"match_all": {}} + query_bool: Dict[str, Any] = {"match_all": {}} if filter: - queryBool = {"bool": {"filter": filter}} + query_bool = {"bool": {"filter": filter}} return { "query": { "script_score": { - "query": queryBool, + "query": query_bool, "script": { - "source": similarityAlgo, + "source": similarity_algo, "params": {"query_vector": query_vector}, }, }, diff --git a/elasticsearch/helpers/vectorstore/_sync/vectorstore.py b/elasticsearch/helpers/vectorstore/_sync/vectorstore.py index bca9d1b3e..03d9932a2 100644 --- a/elasticsearch/helpers/vectorstore/_sync/vectorstore.py +++ b/elasticsearch/helpers/vectorstore/_sync/vectorstore.py @@ -22,8 +22,7 @@ from elasticsearch import Elasticsearch from elasticsearch._version import __versionstr__ as lib_version from elasticsearch.helpers import BulkIndexError, bulk -from elasticsearch.helpers.vectorstore._sync.embedding_service import 
EmbeddingService -from elasticsearch.helpers.vectorstore._sync.strategies import RetrievalStrategy +from elasticsearch.helpers.vectorstore import EmbeddingService, RetrievalStrategy from elasticsearch.helpers.vectorstore._utils import maximal_marginal_relevance logger = logging.getLogger(__name__) diff --git a/elasticsearch/helpers/vectorstore/_utils.py b/elasticsearch/helpers/vectorstore/_utils.py index 02e133220..df91b5cc9 100644 --- a/elasticsearch/helpers/vectorstore/_utils.py +++ b/elasticsearch/helpers/vectorstore/_utils.py @@ -112,5 +112,5 @@ def _raise_missing_mmr_deps_error(parent_error: ModuleNotFoundError) -> None: raise ModuleNotFoundError( f"Failed to compute maximal marginal relevance because the required " f"module '{parent_error.name}' is missing. You can install it by running: " - f"'{sys.executable} -m pip install elasticsearch[mmr]'" + f"'{sys.executable} -m pip install elasticsearch[vectorstore_mmr]'" ) from parent_error diff --git a/noxfile.py b/noxfile.py index 275deceed..a3e1fc172 100644 --- a/noxfile.py +++ b/noxfile.py @@ -48,7 +48,9 @@ def pytest_argv(): @nox.session(python=["3.7", "3.8", "3.9", "3.10", "3.11", "3.12"]) def test(session): - session.install(".[async,requests,orjson,mmr]", env=INSTALL_ENV, silent=False) + session.install( + ".[async,requests,orjson,vectorstore_mmr]", env=INSTALL_ENV, silent=False + ) session.install("-r", "dev-requirements.txt", silent=False) session.run(*pytest_argv()) @@ -95,7 +97,7 @@ def lint(session): session.run("flake8", *SOURCE_FILES) session.run("python", "utils/license-headers.py", "check", *SOURCE_FILES) - session.install(".[async,requests,orjson,mmr]", env=INSTALL_ENV) + session.install(".[async,requests,orjson,vectorstore_mmr]", env=INSTALL_ENV) # Run mypy on the package and then the type examples separately for # the two different mypy use-cases, ourselves and our users. 
diff --git a/setup.py b/setup.py index 775a7b319..e9ee3a377 100644 --- a/setup.py +++ b/setup.py @@ -92,6 +92,7 @@ "requests": ["requests>=2.4.0, <3.0.0"], "async": ["aiohttp>=3,<4"], "orjson": ["orjson>=3"], - "mmr": ["numpy>=1", "simsimd>=3"], + # Maximal Marginal Relevance (MMR) for search results + "vectorstore_mmr": ["numpy>=1", "simsimd>=3"], }, ) diff --git a/test_elasticsearch/test_server/test_helpers_vectorstore/__init__.py b/test_elasticsearch/test_server/test_helpers_vectorstore/__init__.py index 2a87d183f..87710976a 100644 --- a/test_elasticsearch/test_server/test_helpers_vectorstore/__init__.py +++ b/test_elasticsearch/test_server/test_helpers_vectorstore/__init__.py @@ -14,3 +14,68 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. + +from typing import List + +from elastic_transport import Transport + +from elasticsearch.helpers.vectorstore import EmbeddingService + + +class RequestSavingTransport(Transport): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.requests: list = [] + + def perform_request(self, *args, **kwargs): + self.requests.append(kwargs) + return super().perform_request(*args, **kwargs) + + +class FakeEmbeddings(EmbeddingService): + """Fake embeddings functionality for testing.""" + + def __init__(self, dimensionality: int = 10) -> None: + self.dimensionality = dimensionality + + def embed_documents(self, texts: List[str]) -> List[List[float]]: + """Return simple embeddings. Embeddings encode each text as its index.""" + return [ + [float(1.0)] * (self.dimensionality - 1) + [float(i)] + for i in range(len(texts)) + ] + + def embed_query(self, text: str) -> List[float]: + """Return constant query embeddings. + Embeddings are identical to embed_documents(texts)[0]. + Distance to each text will be that text's index, + as it was passed to embed_documents. 
+ """ + return [float(1.0)] * (self.dimensionality - 1) + [float(0.0)] + + +class ConsistentFakeEmbeddings(FakeEmbeddings): + """Fake embeddings which remember all the texts seen so far to return consistent + vectors for the same texts.""" + + def __init__(self, dimensionality: int = 10) -> None: + self.known_texts: List[str] = [] + self.dimensionality = dimensionality + + def embed_documents(self, texts: List[str]) -> List[List[float]]: + """Return consistent embeddings for each text seen so far.""" + out_vectors = [] + for text in texts: + if text not in self.known_texts: + self.known_texts.append(text) + vector = [float(1.0)] * (self.dimensionality - 1) + [ + float(self.known_texts.index(text)) + ] + out_vectors.append(vector) + return out_vectors + + def embed_query(self, text: str) -> List[float]: + """Return consistent embeddings for the text, if seen before, or a constant + one if the text is unknown.""" + result = self.embed_documents([text]) + return result[0] diff --git a/test_elasticsearch/test_server/test_helpers_vectorstore/_test_utils.py b/test_elasticsearch/test_server/test_helpers_vectorstore/_test_utils.py deleted file mode 100644 index f855c1d19..000000000 --- a/test_elasticsearch/test_server/test_helpers_vectorstore/_test_utils.py +++ /dev/null @@ -1,81 +0,0 @@ -# Licensed to Elasticsearch B.V. under one or more contributor -# license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright -# ownership. Elasticsearch B.V. licenses this file to you under -# the Apache License, Version 2.0 (the "License"); you may -# not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. 
See the License for the -# specific language governing permissions and limitations -# under the License. - -from typing import List - -from elastic_transport import Transport - -from elasticsearch.helpers.vectorstore._sync.embedding_service import EmbeddingService - - -class RequestSavingTransport(Transport): - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - self.requests: list = [] - - def perform_request(self, *args, **kwargs): - self.requests.append(kwargs) - return super().perform_request(*args, **kwargs) - - -class FakeEmbeddings(EmbeddingService): - """Fake embeddings functionality for testing.""" - - def __init__(self, dimensionality: int = 10) -> None: - self.dimensionality = dimensionality - - def embed_documents(self, texts: List[str]) -> List[List[float]]: - """Return simple embeddings. Embeddings encode each text as its index.""" - return [ - [float(1.0)] * (self.dimensionality - 1) + [float(i)] - for i in range(len(texts)) - ] - - def embed_query(self, text: str) -> List[float]: - """Return constant query embeddings. - Embeddings are identical to embed_documents(texts)[0]. - Distance to each text will be that text's index, - as it was passed to embed_documents. 
- """ - return [float(1.0)] * (self.dimensionality - 1) + [float(0.0)] - - -class ConsistentFakeEmbeddings(FakeEmbeddings): - """Fake embeddings which remember all the texts seen so far to return consistent - vectors for the same texts.""" - - def __init__(self, dimensionality: int = 10) -> None: - self.known_texts: List[str] = [] - self.dimensionality = dimensionality - - def embed_documents(self, texts: List[str]) -> List[List[float]]: - """Return consistent embeddings for each text seen so far.""" - out_vectors = [] - for text in texts: - if text not in self.known_texts: - self.known_texts.append(text) - vector = [float(1.0)] * (self.dimensionality - 1) + [ - float(self.known_texts.index(text)) - ] - out_vectors.append(vector) - return out_vectors - - def embed_query(self, text: str) -> List[float]: - """Return consistent embeddings for the text, if seen before, or a constant - one if the text is unknown.""" - result = self.embed_documents([text]) - return result[0] diff --git a/test_elasticsearch/test_server/test_helpers_vectorstore/conftest.py b/test_elasticsearch/test_server/test_helpers_vectorstore/conftest.py index 97c264532..c2028920e 100644 --- a/test_elasticsearch/test_server/test_helpers_vectorstore/conftest.py +++ b/test_elasticsearch/test_server/test_helpers_vectorstore/conftest.py @@ -21,7 +21,7 @@ from ...utils import wipe_cluster from ..conftest import _create -from ._test_utils import RequestSavingTransport +from . 
import RequestSavingTransport @pytest.fixture(scope="function") diff --git a/test_elasticsearch/test_server/test_helpers_vectorstore/test_embedding_service.py b/test_elasticsearch/test_server/test_helpers_vectorstore/test_embedding_service.py index 396480e3e..ddfc5b8ec 100644 --- a/test_elasticsearch/test_server/test_helpers_vectorstore/test_embedding_service.py +++ b/test_elasticsearch/test_server/test_helpers_vectorstore/test_embedding_service.py @@ -21,10 +21,8 @@ import pytest from elasticsearch import Elasticsearch +from elasticsearch.helpers.vectorstore import ElasticsearchEmbeddings from elasticsearch.helpers.vectorstore._sync._utils import model_is_deployed -from elasticsearch.helpers.vectorstore._sync.embedding_service import ( - ElasticsearchEmbeddings, -) # deployed with # https://www.elastic.co/guide/en/machine-learning/current/ml-nlp-text-emb-vector-search-example.html diff --git a/test_elasticsearch/test_server/test_helpers_vectorstore/test_vectorstore.py b/test_elasticsearch/test_server/test_helpers_vectorstore/test_vectorstore.py index 62e345119..792ae36a3 100644 --- a/test_elasticsearch/test_server/test_helpers_vectorstore/test_vectorstore.py +++ b/test_elasticsearch/test_server/test_helpers_vectorstore/test_vectorstore.py @@ -34,7 +34,7 @@ ) from elasticsearch.helpers.vectorstore._sync._utils import model_is_deployed -from ._test_utils import ConsistentFakeEmbeddings, FakeEmbeddings +from . 
import ConsistentFakeEmbeddings, FakeEmbeddings logging.basicConfig(level=logging.DEBUG) @@ -119,7 +119,6 @@ def test_add_vectors(self, sync_client: Elasticsearch, index_name: str) -> None: texts = ["foo1", "foo2", "foo3"] metadatas = [{"page": i} for i in range(len(texts))] - """In real use case, embedding_input can be questions for each text""" embedding_vectors = embeddings.embed_documents(texts) store = VectorStore( diff --git a/utils/run-unasync.py b/utils/run-unasync.py index ec7623869..4a943c10f 100644 --- a/utils/run-unasync.py +++ b/utils/run-unasync.py @@ -25,33 +25,27 @@ def cleanup(source_dir: Path, output_dir: Path, patterns: list[str]): for file in glob("*.py", root_dir=source_dir): - path = Path(output_dir) / file + path = output_dir / file for pattern in patterns: subprocess.check_call(["sed", "-i.bak", pattern, str(path)]) subprocess.check_call(["rm", f"{path}.bak"]) -def format_dir(dir: Path): - subprocess.check_call(["isort", "--profile=black", dir]) - subprocess.check_call(["black", dir]) - - def run( rule: unasync.Rule, cleanup_patterns: list[str] = [], - format: bool = False, ): - root = Path(__file__).absolute().parent.parent - source_dir = root / rule.fromdir.lstrip("/") - output_dir = root / rule.todir.lstrip("/") + root_dir = Path(__file__).absolute().parent.parent + source_dir = root_dir / rule.fromdir.lstrip("/") + output_dir = root_dir / rule.todir.lstrip("/") filepaths = [] for root, _, filenames in os.walk(source_dir): for filename in filenames: - if filename.rpartition(".")[-1] in ( + if filename.rpartition(".")[-1] in { "py", "pyi", - ) and not filename.startswith("utils.py"): + } and not filename.startswith("utils.py"): filepaths.append(os.path.join(root, filename)) unasync.unasync_files(filepaths, [rule]) @@ -59,10 +53,6 @@ def run( if cleanup_patterns: cleanup(source_dir, output_dir, cleanup_patterns) - if format: - format_dir(source_dir) - format_dir(output_dir) - def main(): run( @@ -103,7 +93,6 @@ def main(): 
cleanup_patterns=[ "/^import asyncio$/d", ], - format=True, ) From f32ceb2c6dda1ecdd21701847095ecce5fc05820 Mon Sep 17 00:00:00 2001 From: Max Jakob Date: Mon, 29 Apr 2024 14:10:22 +0200 Subject: [PATCH 33/36] add error tests for strategies --- .../__init__.py | 0 .../conftest.py | 0 .../docker-compose.yml | 0 .../test_embedding_service.py | 0 .../test_vectorstore.py | 0 test_elasticsearch/test_strategies.py | 62 +++++++++++++++++++ 6 files changed, 62 insertions(+) rename test_elasticsearch/test_server/{test_helpers_vectorstore => test_vectorstore}/__init__.py (100%) rename test_elasticsearch/test_server/{test_helpers_vectorstore => test_vectorstore}/conftest.py (100%) rename test_elasticsearch/test_server/{test_helpers_vectorstore => test_vectorstore}/docker-compose.yml (100%) rename test_elasticsearch/test_server/{test_helpers_vectorstore => test_vectorstore}/test_embedding_service.py (100%) rename test_elasticsearch/test_server/{test_helpers_vectorstore => test_vectorstore}/test_vectorstore.py (100%) create mode 100644 test_elasticsearch/test_strategies.py diff --git a/test_elasticsearch/test_server/test_helpers_vectorstore/__init__.py b/test_elasticsearch/test_server/test_vectorstore/__init__.py similarity index 100% rename from test_elasticsearch/test_server/test_helpers_vectorstore/__init__.py rename to test_elasticsearch/test_server/test_vectorstore/__init__.py diff --git a/test_elasticsearch/test_server/test_helpers_vectorstore/conftest.py b/test_elasticsearch/test_server/test_vectorstore/conftest.py similarity index 100% rename from test_elasticsearch/test_server/test_helpers_vectorstore/conftest.py rename to test_elasticsearch/test_server/test_vectorstore/conftest.py diff --git a/test_elasticsearch/test_server/test_helpers_vectorstore/docker-compose.yml b/test_elasticsearch/test_server/test_vectorstore/docker-compose.yml similarity index 100% rename from test_elasticsearch/test_server/test_helpers_vectorstore/docker-compose.yml rename to 
test_elasticsearch/test_server/test_vectorstore/docker-compose.yml diff --git a/test_elasticsearch/test_server/test_helpers_vectorstore/test_embedding_service.py b/test_elasticsearch/test_server/test_vectorstore/test_embedding_service.py similarity index 100% rename from test_elasticsearch/test_server/test_helpers_vectorstore/test_embedding_service.py rename to test_elasticsearch/test_server/test_vectorstore/test_embedding_service.py diff --git a/test_elasticsearch/test_server/test_helpers_vectorstore/test_vectorstore.py b/test_elasticsearch/test_server/test_vectorstore/test_vectorstore.py similarity index 100% rename from test_elasticsearch/test_server/test_helpers_vectorstore/test_vectorstore.py rename to test_elasticsearch/test_server/test_vectorstore/test_vectorstore.py diff --git a/test_elasticsearch/test_strategies.py b/test_elasticsearch/test_strategies.py new file mode 100644 index 000000000..11eb21579 --- /dev/null +++ b/test_elasticsearch/test_strategies.py @@ -0,0 +1,62 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import pytest + +from elasticsearch.helpers.vectorstore import ( + DenseVectorScriptScoreStrategy, + DenseVectorStrategy, + SparseVectorStrategy, +) + + +def test_sparse_vector_strategy_raises_errors(): + strategy = SparseVectorStrategy("my_model_id") + + with pytest.raises(ValueError): + # missing query + strategy.es_query(None, None, "text_field", "vector_field", 10, 20, []) + + with pytest.raises(ValueError): + # query vector not allowed + strategy.es_query("hi", [1, 2, 3], "text_field", "vector_field", 10, 20, []) + + +def test_dense_vector_strategy_raises_error(): + with pytest.raises(ValueError): + # unknown distance + DenseVectorStrategy(hybrid=True, text_field=None) + + with pytest.raises(ValueError): + # unknown distance + DenseVectorStrategy(distance="unknown distance").es_mappings_settings( + "text_field", "vector_field", 10 + ) + + +def test_dense_vector_script_score_strategy_raises_error(): + with pytest.raises(ValueError): + # missing query vector + DenseVectorScriptScoreStrategy().es_query( + None, None, "text_field", "vector_field", 10, 20, [] + ) + + with pytest.raises(ValueError): + # unknown distance + DenseVectorScriptScoreStrategy(distance="unknown distance").es_query( + None, [1, 2, 3], "text_field", "vector_field", 10, 20, [] + ) From 9b1778ef8cdf90583a4190af4c5f4c97d56302dd Mon Sep 17 00:00:00 2001 From: Max Jakob Date: Tue, 30 Apr 2024 14:20:13 +0200 Subject: [PATCH 34/36] canonical names, keyword args only --- .../helpers/vectorstore/_async/_utils.py | 4 +- .../vectorstore/_async/embedding_service.py | 11 +- .../helpers/vectorstore/_async/strategies.py | 25 +- .../helpers/vectorstore/_async/vectorstore.py | 49 ++-- .../helpers/vectorstore/_sync/_utils.py | 4 +- .../vectorstore/_sync/embedding_service.py | 11 +- .../helpers/vectorstore/_sync/strategies.py | 25 +- .../helpers/vectorstore/_sync/vectorstore.py | 49 ++-- test_elasticsearch/test_server/conftest.py | 5 +- .../test_server/test_mapbox_vector_tile.py | 9 +- 
.../test_server/test_vectorstore/conftest.py | 2 +- .../test_embedding_service.py | 10 +- .../test_vectorstore/test_vectorstore.py | 255 +++++++++--------- test_elasticsearch/test_strategies.py | 38 ++- 14 files changed, 282 insertions(+), 215 deletions(-) diff --git a/elasticsearch/helpers/vectorstore/_async/_utils.py b/elasticsearch/helpers/vectorstore/_async/_utils.py index 2305b1448..67b6b6a27 100644 --- a/elasticsearch/helpers/vectorstore/_async/_utils.py +++ b/elasticsearch/helpers/vectorstore/_async/_utils.py @@ -31,9 +31,9 @@ async def model_must_be_deployed(client: AsyncElasticsearch, model_id: str) -> N pass -async def model_is_deployed(es_client: AsyncElasticsearch, model_id: str) -> bool: +async def model_is_deployed(client: AsyncElasticsearch, model_id: str) -> bool: try: - await model_must_be_deployed(es_client, model_id) + await model_must_be_deployed(client, model_id) return True except NotFoundError: return False diff --git a/elasticsearch/helpers/vectorstore/_async/embedding_service.py b/elasticsearch/helpers/vectorstore/_async/embedding_service.py index 2612372b4..20005b665 100644 --- a/elasticsearch/helpers/vectorstore/_async/embedding_service.py +++ b/elasticsearch/helpers/vectorstore/_async/embedding_service.py @@ -52,7 +52,8 @@ class AsyncElasticsearchEmbeddings(AsyncEmbeddingService): def __init__( self, - es_client: AsyncElasticsearch, + *, + client: AsyncElasticsearch, model_id: str, input_field: str = "text_field", user_agent: str = f"elasticsearch-py-es/{lib_version}", @@ -63,14 +64,14 @@ def __init__( :param model_id: The model_id of the model deployed in the Elasticsearch cluster. :param input_field: The name of the key for the input text field in the document. Defaults to 'text_field'. - :param es_client: Elasticsearch client connection. Alternatively specify the + :param client: Elasticsearch client connection. Alternatively specify the Elasticsearch connection with the other es_* parameters. 
""" # Add integration-specific usage header for tracking usage in Elastic Cloud. # client.options preserves existing (non-user-agent) headers. - es_client = es_client.options(headers={"User-Agent": user_agent}) + client = client.options(headers={"User-Agent": user_agent}) - self.es_client = es_client + self.client = client self.model_id = model_id self.input_field = input_field @@ -82,7 +83,7 @@ async def embed_query(self, text: str) -> List[float]: return result[0] async def _embedding_func(self, texts: List[str]) -> List[List[float]]: - response = await self.es_client.ml.infer_trained_model( + response = await self.client.ml.infer_trained_model( model_id=self.model_id, docs=[{self.input_field: text} for text in texts] ) return [doc["predicted_value"] for doc in response["inference_results"]] diff --git a/elasticsearch/helpers/vectorstore/_async/strategies.py b/elasticsearch/helpers/vectorstore/_async/strategies.py index 8d19698e6..09690e24c 100644 --- a/elasticsearch/helpers/vectorstore/_async/strategies.py +++ b/elasticsearch/helpers/vectorstore/_async/strategies.py @@ -27,6 +27,7 @@ class AsyncRetrievalStrategy(ABC): @abstractmethod def es_query( self, + *, query: Optional[str], query_vector: Optional[List[float]], text_field: str, @@ -51,6 +52,7 @@ def es_query( @abstractmethod def es_mappings_settings( self, + *, text_field: str, vector_field: str, num_dimensions: Optional[int], @@ -60,13 +62,15 @@ def es_mappings_settings( creating inference pipelines or checking if a required model was deployed. :param client: Elasticsearch client connection. - :param index_name: The name of the Elasticsearch index to create. - :param metadata_mapping: Flat dictionary with field and field type pairs that - describe the schema of the metadata. + :param text_field: The field containing the text data in the index. + :param vector_field: The field containing the vector representations in the index. 
+ :param num_dimensions: If vectors are indexed, how many dimensions do they have. + + :return: Dictionary with field and field type pairs that describe the schema. """ async def before_index_creation( - self, client: AsyncElasticsearch, text_field: str, vector_field: str + self, *, client: AsyncElasticsearch, text_field: str, vector_field: str ) -> None: """ Executes before the index is created. Used for setting up @@ -101,6 +105,7 @@ def __init__(self, model_id: str = ".elser_model_2"): def es_query( self, + *, query: Optional[str], query_vector: Optional[List[float]], text_field: str, @@ -138,6 +143,7 @@ def es_query( def es_mappings_settings( self, + *, text_field: str, vector_field: str, num_dimensions: Optional[int], @@ -154,7 +160,7 @@ def es_mappings_settings( return mappings, settings async def before_index_creation( - self, client: AsyncElasticsearch, text_field: str, vector_field: str + self, *, client: AsyncElasticsearch, text_field: str, vector_field: str ) -> None: if self.model_id: await model_must_be_deployed(client, self.model_id) @@ -183,6 +189,7 @@ class AsyncDenseVectorStrategy(AsyncRetrievalStrategy): def __init__( self, + *, distance: DistanceMetric = DistanceMetric.COSINE, model_id: Optional[str] = None, hybrid: bool = False, @@ -202,6 +209,7 @@ def __init__( def es_query( self, + *, query: Optional[str], query_vector: Optional[List[float]], text_field: str, @@ -236,6 +244,7 @@ def es_query( def es_mappings_settings( self, + *, text_field: str, vector_field: str, num_dimensions: Optional[int], @@ -265,7 +274,7 @@ def es_mappings_settings( return mappings, {} async def before_index_creation( - self, client: AsyncElasticsearch, text_field: str, vector_field: str + self, *, client: AsyncElasticsearch, text_field: str, vector_field: str ) -> None: if self.model_id: await model_must_be_deployed(client, self.model_id) @@ -314,6 +323,7 @@ def __init__(self, distance: DistanceMetric = DistanceMetric.COSINE) -> None: def es_query( self, + *, query: 
Optional[str], query_vector: Optional[List[float]], text_field: str, @@ -365,6 +375,7 @@ def es_query( def es_mappings_settings( self, + *, text_field: str, vector_field: str, num_dimensions: Optional[int], @@ -396,6 +407,7 @@ def __init__( def es_query( self, + *, query: Optional[str], query_vector: Optional[List[float]], text_field: str, @@ -423,6 +435,7 @@ def es_query( def es_mappings_settings( self, + *, text_field: str, vector_field: str, num_dimensions: Optional[int], diff --git a/elasticsearch/helpers/vectorstore/_async/vectorstore.py b/elasticsearch/helpers/vectorstore/_async/vectorstore.py index 146b05394..b79e2dcaf 100644 --- a/elasticsearch/helpers/vectorstore/_async/vectorstore.py +++ b/elasticsearch/helpers/vectorstore/_async/vectorstore.py @@ -50,8 +50,9 @@ class AsyncVectorStore: def __init__( self, - es_client: AsyncElasticsearch, - index_name: str, + client: AsyncElasticsearch, + *, + index: str, retrieval_strategy: AsyncRetrievalStrategy, embedding_service: Optional[AsyncEmbeddingService] = None, num_dimensions: Optional[int] = None, @@ -63,26 +64,26 @@ def __init__( """ :param user_header: user agent header specific to the 3rd party integration. Used for usage tracking in Elastic Cloud. - :param index_name: The name of the index to query. + :param index: The name of the index to query. :param retrieval_strategy: how to index and search the data. See the strategies module for availble strategies. :param text_field: Name of the field with the textual data. :param vector_field: For strategies that perform embedding inference in Python, the embedding vector goes in this field. - :param es_client: Elasticsearch client connection. Alternatively specify the + :param client: Elasticsearch client connection. Alternatively specify the Elasticsearch connection with the other es_* parameters. """ # Add integration-specific usage header for tracking usage in Elastic Cloud. # client.options preserves existing (non-user-agent) headers. 
- es_client = es_client.options(headers={"User-Agent": user_agent}) + client = client.options(headers={"User-Agent": user_agent}) if hasattr(retrieval_strategy, "text_field"): retrieval_strategy.text_field = text_field if hasattr(retrieval_strategy, "vector_field"): retrieval_strategy.vector_field = vector_field - self.es_client = es_client - self.index_name = index_name + self.client = client + self.index = index self.retrieval_strategy = retrieval_strategy self.embedding_service = embedding_service self.num_dimensions = num_dimensions @@ -91,11 +92,12 @@ def __init__( self.metadata_mappings = metadata_mappings async def close(self) -> None: - return await self.es_client.close() + return await self.client.close() async def add_texts( self, texts: List[str], + *, metadatas: Optional[List[Dict[str, Any]]] = None, vectors: Optional[List[List[float]]] = None, ids: Optional[List[str]] = None, @@ -136,7 +138,7 @@ async def add_texts( request: Dict[str, Any] = { "_op_type": "index", - "_index": self.index_name, + "_index": self.index, self.text_field: text, "metadata": metadata, "_id": ids[i], @@ -150,7 +152,7 @@ async def add_texts( if len(requests) > 0: try: success, failed = await async_bulk( - self.es_client, + self.client, requests, stats_only=True, refresh=refresh_indices, @@ -170,6 +172,7 @@ async def add_texts( async def delete( # type: ignore[no-untyped-def] self, + *, ids: Optional[List[str]] = None, query: Optional[Dict[str, Any]] = None, refresh_indices: bool = True, @@ -191,11 +194,11 @@ async def delete( # type: ignore[no-untyped-def] try: if ids: body = [ - {"_op_type": "delete", "_index": self.index_name, "_id": _id} + {"_op_type": "delete", "_index": self.index, "_id": _id} for _id in ids ] await async_bulk( - self.es_client, + self.client, body, refresh=refresh_indices, ignore_status=404, @@ -204,8 +207,8 @@ async def delete( # type: ignore[no-untyped-def] logger.debug(f"Deleted {len(body)} texts from index") else: - await 
self.es_client.delete_by_query( - index=self.index_name, + await self.client.delete_by_query( + index=self.index, query=query, refresh=refresh_indices, **delete_kwargs, @@ -221,6 +224,7 @@ async def delete( # type: ignore[no-untyped-def] async def search( self, + *, query: Optional[str], query_vector: Optional[List[float]] = None, k: int = 4, @@ -270,8 +274,8 @@ async def search( query_body = custom_query(query_body, query) logger.debug(f"Calling custom_query, Query body now: {query_body}") - response = await self.es_client.search( - index=self.index_name, + response = await self.client.search( + index=self.index, **query_body, size=k, source=True, @@ -282,9 +286,9 @@ async def search( return hits async def _create_index_if_not_exists(self) -> None: - exists = await self.es_client.indices.exists(index=self.index_name) + exists = await self.client.indices.exists(index=self.index) if exists.meta.status == 200: - logger.debug(f"Index {self.index_name} already exists. Skipping creation.") + logger.debug(f"Index {self.index} already exists. 
Skipping creation.") return if self.retrieval_strategy.needs_inference(): @@ -312,14 +316,17 @@ async def _create_index_if_not_exists(self) -> None: mappings["properties"]["metadata"] = {"properties": metadata} await self.retrieval_strategy.before_index_creation( - self.es_client, self.text_field, self.vector_field + client=self.client, + text_field=self.text_field, + vector_field=self.vector_field, ) - await self.es_client.indices.create( - index=self.index_name, mappings=mappings, settings=settings + await self.client.indices.create( + index=self.index, mappings=mappings, settings=settings ) async def max_marginal_relevance_search( self, + *, embedding_service: AsyncEmbeddingService, query: str, vector_field: str, diff --git a/elasticsearch/helpers/vectorstore/_sync/_utils.py b/elasticsearch/helpers/vectorstore/_sync/_utils.py index dba9bdcd4..496aec970 100644 --- a/elasticsearch/helpers/vectorstore/_sync/_utils.py +++ b/elasticsearch/helpers/vectorstore/_sync/_utils.py @@ -31,9 +31,9 @@ def model_must_be_deployed(client: Elasticsearch, model_id: str) -> None: pass -def model_is_deployed(es_client: Elasticsearch, model_id: str) -> bool: +def model_is_deployed(client: Elasticsearch, model_id: str) -> bool: try: - model_must_be_deployed(es_client, model_id) + model_must_be_deployed(client, model_id) return True except NotFoundError: return False diff --git a/elasticsearch/helpers/vectorstore/_sync/embedding_service.py b/elasticsearch/helpers/vectorstore/_sync/embedding_service.py index 51b607237..5b0163d98 100644 --- a/elasticsearch/helpers/vectorstore/_sync/embedding_service.py +++ b/elasticsearch/helpers/vectorstore/_sync/embedding_service.py @@ -52,7 +52,8 @@ class ElasticsearchEmbeddings(EmbeddingService): def __init__( self, - es_client: Elasticsearch, + *, + client: Elasticsearch, model_id: str, input_field: str = "text_field", user_agent: str = f"elasticsearch-py-es/{lib_version}", @@ -63,14 +64,14 @@ def __init__( :param model_id: The model_id of the model 
deployed in the Elasticsearch cluster. :param input_field: The name of the key for the input text field in the document. Defaults to 'text_field'. - :param es_client: Elasticsearch client connection. Alternatively specify the + :param client: Elasticsearch client connection. Alternatively specify the Elasticsearch connection with the other es_* parameters. """ # Add integration-specific usage header for tracking usage in Elastic Cloud. # client.options preserves existing (non-user-agent) headers. - es_client = es_client.options(headers={"User-Agent": user_agent}) + client = client.options(headers={"User-Agent": user_agent}) - self.es_client = es_client + self.client = client self.model_id = model_id self.input_field = input_field @@ -82,7 +83,7 @@ def embed_query(self, text: str) -> List[float]: return result[0] def _embedding_func(self, texts: List[str]) -> List[List[float]]: - response = self.es_client.ml.infer_trained_model( + response = self.client.ml.infer_trained_model( model_id=self.model_id, docs=[{self.input_field: text} for text in texts] ) return [doc["predicted_value"] for doc in response["inference_results"]] diff --git a/elasticsearch/helpers/vectorstore/_sync/strategies.py b/elasticsearch/helpers/vectorstore/_sync/strategies.py index d172a7c1f..17b42beb1 100644 --- a/elasticsearch/helpers/vectorstore/_sync/strategies.py +++ b/elasticsearch/helpers/vectorstore/_sync/strategies.py @@ -27,6 +27,7 @@ class RetrievalStrategy(ABC): @abstractmethod def es_query( self, + *, query: Optional[str], query_vector: Optional[List[float]], text_field: str, @@ -51,6 +52,7 @@ def es_query( @abstractmethod def es_mappings_settings( self, + *, text_field: str, vector_field: str, num_dimensions: Optional[int], @@ -60,13 +62,15 @@ def es_mappings_settings( creating inference pipelines or checking if a required model was deployed. :param client: Elasticsearch client connection. - :param index_name: The name of the Elasticsearch index to create. 
- :param metadata_mapping: Flat dictionary with field and field type pairs that - describe the schema of the metadata. + :param text_field: The field containing the text data in the index. + :param vector_field: The field containing the vector representations in the index. + :param num_dimensions: If vectors are indexed, how many dimensions do they have. + + :return: Dictionary with field and field type pairs that describe the schema. """ def before_index_creation( - self, client: Elasticsearch, text_field: str, vector_field: str + self, *, client: Elasticsearch, text_field: str, vector_field: str ) -> None: """ Executes before the index is created. Used for setting up @@ -101,6 +105,7 @@ def __init__(self, model_id: str = ".elser_model_2"): def es_query( self, + *, query: Optional[str], query_vector: Optional[List[float]], text_field: str, @@ -138,6 +143,7 @@ def es_query( def es_mappings_settings( self, + *, text_field: str, vector_field: str, num_dimensions: Optional[int], @@ -154,7 +160,7 @@ def es_mappings_settings( return mappings, settings def before_index_creation( - self, client: Elasticsearch, text_field: str, vector_field: str + self, *, client: Elasticsearch, text_field: str, vector_field: str ) -> None: if self.model_id: model_must_be_deployed(client, self.model_id) @@ -183,6 +189,7 @@ class DenseVectorStrategy(RetrievalStrategy): def __init__( self, + *, distance: DistanceMetric = DistanceMetric.COSINE, model_id: Optional[str] = None, hybrid: bool = False, @@ -202,6 +209,7 @@ def __init__( def es_query( self, + *, query: Optional[str], query_vector: Optional[List[float]], text_field: str, @@ -236,6 +244,7 @@ def es_query( def es_mappings_settings( self, + *, text_field: str, vector_field: str, num_dimensions: Optional[int], @@ -265,7 +274,7 @@ def es_mappings_settings( return mappings, {} def before_index_creation( - self, client: Elasticsearch, text_field: str, vector_field: str + self, *, client: Elasticsearch, text_field: str, vector_field: str ) 
-> None: if self.model_id: model_must_be_deployed(client, self.model_id) @@ -314,6 +323,7 @@ def __init__(self, distance: DistanceMetric = DistanceMetric.COSINE) -> None: def es_query( self, + *, query: Optional[str], query_vector: Optional[List[float]], text_field: str, @@ -365,6 +375,7 @@ def es_query( def es_mappings_settings( self, + *, text_field: str, vector_field: str, num_dimensions: Optional[int], @@ -396,6 +407,7 @@ def __init__( def es_query( self, + *, query: Optional[str], query_vector: Optional[List[float]], text_field: str, @@ -423,6 +435,7 @@ def es_query( def es_mappings_settings( self, + *, text_field: str, vector_field: str, num_dimensions: Optional[int], diff --git a/elasticsearch/helpers/vectorstore/_sync/vectorstore.py b/elasticsearch/helpers/vectorstore/_sync/vectorstore.py index 03d9932a2..2feb96ec4 100644 --- a/elasticsearch/helpers/vectorstore/_sync/vectorstore.py +++ b/elasticsearch/helpers/vectorstore/_sync/vectorstore.py @@ -47,8 +47,9 @@ class VectorStore: def __init__( self, - es_client: Elasticsearch, - index_name: str, + client: Elasticsearch, + *, + index: str, retrieval_strategy: RetrievalStrategy, embedding_service: Optional[EmbeddingService] = None, num_dimensions: Optional[int] = None, @@ -60,26 +61,26 @@ def __init__( """ :param user_header: user agent header specific to the 3rd party integration. Used for usage tracking in Elastic Cloud. - :param index_name: The name of the index to query. + :param index: The name of the index to query. :param retrieval_strategy: how to index and search the data. See the strategies module for availble strategies. :param text_field: Name of the field with the textual data. :param vector_field: For strategies that perform embedding inference in Python, the embedding vector goes in this field. - :param es_client: Elasticsearch client connection. Alternatively specify the + :param client: Elasticsearch client connection. 
Alternatively specify the Elasticsearch connection with the other es_* parameters. """ # Add integration-specific usage header for tracking usage in Elastic Cloud. # client.options preserves existing (non-user-agent) headers. - es_client = es_client.options(headers={"User-Agent": user_agent}) + client = client.options(headers={"User-Agent": user_agent}) if hasattr(retrieval_strategy, "text_field"): retrieval_strategy.text_field = text_field if hasattr(retrieval_strategy, "vector_field"): retrieval_strategy.vector_field = vector_field - self.es_client = es_client - self.index_name = index_name + self.client = client + self.index = index self.retrieval_strategy = retrieval_strategy self.embedding_service = embedding_service self.num_dimensions = num_dimensions @@ -88,11 +89,12 @@ def __init__( self.metadata_mappings = metadata_mappings def close(self) -> None: - return self.es_client.close() + return self.client.close() def add_texts( self, texts: List[str], + *, metadatas: Optional[List[Dict[str, Any]]] = None, vectors: Optional[List[List[float]]] = None, ids: Optional[List[str]] = None, @@ -133,7 +135,7 @@ def add_texts( request: Dict[str, Any] = { "_op_type": "index", - "_index": self.index_name, + "_index": self.index, self.text_field: text, "metadata": metadata, "_id": ids[i], @@ -147,7 +149,7 @@ def add_texts( if len(requests) > 0: try: success, failed = bulk( - self.es_client, + self.client, requests, stats_only=True, refresh=refresh_indices, @@ -167,6 +169,7 @@ def add_texts( def delete( # type: ignore[no-untyped-def] self, + *, ids: Optional[List[str]] = None, query: Optional[Dict[str, Any]] = None, refresh_indices: bool = True, @@ -188,11 +191,11 @@ def delete( # type: ignore[no-untyped-def] try: if ids: body = [ - {"_op_type": "delete", "_index": self.index_name, "_id": _id} + {"_op_type": "delete", "_index": self.index, "_id": _id} for _id in ids ] bulk( - self.es_client, + self.client, body, refresh=refresh_indices, ignore_status=404, @@ -201,8 +204,8 @@ 
def delete( # type: ignore[no-untyped-def] logger.debug(f"Deleted {len(body)} texts from index") else: - self.es_client.delete_by_query( - index=self.index_name, + self.client.delete_by_query( + index=self.index, query=query, refresh=refresh_indices, **delete_kwargs, @@ -218,6 +221,7 @@ def delete( # type: ignore[no-untyped-def] def search( self, + *, query: Optional[str], query_vector: Optional[List[float]] = None, k: int = 4, @@ -267,8 +271,8 @@ def search( query_body = custom_query(query_body, query) logger.debug(f"Calling custom_query, Query body now: {query_body}") - response = self.es_client.search( - index=self.index_name, + response = self.client.search( + index=self.index, **query_body, size=k, source=True, @@ -279,9 +283,9 @@ def search( return hits def _create_index_if_not_exists(self) -> None: - exists = self.es_client.indices.exists(index=self.index_name) + exists = self.client.indices.exists(index=self.index) if exists.meta.status == 200: - logger.debug(f"Index {self.index_name} already exists. Skipping creation.") + logger.debug(f"Index {self.index} already exists. 
Skipping creation.") return if self.retrieval_strategy.needs_inference(): @@ -309,14 +313,17 @@ def _create_index_if_not_exists(self) -> None: mappings["properties"]["metadata"] = {"properties": metadata} self.retrieval_strategy.before_index_creation( - self.es_client, self.text_field, self.vector_field + client=self.client, + text_field=self.text_field, + vector_field=self.vector_field, ) - self.es_client.indices.create( - index=self.index_name, mappings=mappings, settings=settings + self.client.indices.create( + index=self.index, mappings=mappings, settings=settings ) def max_marginal_relevance_search( self, + *, embedding_service: EmbeddingService, query: str, vector_field: str, diff --git a/test_elasticsearch/test_server/conftest.py b/test_elasticsearch/test_server/conftest.py index 9bbadece7..7b87fd1d3 100644 --- a/test_elasticsearch/test_server/conftest.py +++ b/test_elasticsearch/test_server/conftest.py @@ -30,7 +30,7 @@ ELASTICSEARCH_REST_API_TESTS = [] -def _create(elasticsearch_url, transport=None): +def _create(elasticsearch_url, transport=None, node_class=None): # Configure the client with certificates kw = {} if elasticsearch_url.startswith("https://"): @@ -41,6 +41,9 @@ def _create(elasticsearch_url, transport=None): if "PYTHON_CONNECTION_CLASS" in os.environ: kw["node_class"] = os.environ["PYTHON_CONNECTION_CLASS"] + if node_class is not None and "node_class" not in kw: + kw["node_class"] = node_class + if transport: kw["transport_class"] = transport diff --git a/test_elasticsearch/test_server/test_mapbox_vector_tile.py b/test_elasticsearch/test_server/test_mapbox_vector_tile.py index 988210984..332e8d144 100644 --- a/test_elasticsearch/test_server/test_mapbox_vector_tile.py +++ b/test_elasticsearch/test_server/test_mapbox_vector_tile.py @@ -17,7 +17,9 @@ import pytest -from elasticsearch import Elasticsearch, RequestError +from elasticsearch import RequestError + +from .conftest import _create @pytest.fixture(scope="function") @@ -73,7 +75,8 @@ def 
mvt_setup(sync_client): @pytest.mark.parametrize("node_class", ["urllib3", "requests"]) def test_mapbox_vector_tile_error(elasticsearch_url, mvt_setup, node_class, ca_certs): - client = Elasticsearch(elasticsearch_url, node_class=node_class, ca_certs=ca_certs) + client = _create(elasticsearch_url, node_class=node_class) + client.search_mvt( index="museums", zoom=13, @@ -121,7 +124,7 @@ def test_mapbox_vector_tile_response( except ImportError: return pytest.skip("Requires the 'mapbox-vector-tile' package") - client = Elasticsearch(elasticsearch_url, node_class=node_class, ca_certs=ca_certs) + client = _create(elasticsearch_url, node_class=node_class) resp = client.search_mvt( index="museums", diff --git a/test_elasticsearch/test_server/test_vectorstore/conftest.py b/test_elasticsearch/test_server/test_vectorstore/conftest.py index c2028920e..a0886a9c4 100644 --- a/test_elasticsearch/test_server/test_vectorstore/conftest.py +++ b/test_elasticsearch/test_server/test_vectorstore/conftest.py @@ -25,7 +25,7 @@ @pytest.fixture(scope="function") -def index_name() -> str: +def index() -> str: return f"test_{uuid.uuid4().hex}" diff --git a/test_elasticsearch/test_server/test_vectorstore/test_embedding_service.py b/test_elasticsearch/test_server/test_vectorstore/test_embedding_service.py index ddfc5b8ec..c667a8a38 100644 --- a/test_elasticsearch/test_server/test_vectorstore/test_embedding_service.py +++ b/test_elasticsearch/test_server/test_vectorstore/test_embedding_service.py @@ -38,7 +38,7 @@ def test_elasticsearch_embedding_documents(sync_client: Elasticsearch) -> None: documents = ["foo bar", "bar foo", "foo"] embedding = ElasticsearchEmbeddings( - es_client=sync_client, user_agent="test", model_id=MODEL_ID + client=sync_client, user_agent="test", model_id=MODEL_ID ) output = embedding.embed_documents(documents) assert len(output) == 3 @@ -55,7 +55,7 @@ def test_elasticsearch_embedding_query(sync_client: Elasticsearch) -> None: document = "foo bar" embedding = 
ElasticsearchEmbeddings( - es_client=sync_client, user_agent="test", model_id=MODEL_ID + client=sync_client, user_agent="test", model_id=MODEL_ID ) output = embedding.embed_query(document) assert len(output) == NUM_DIMENSIONS @@ -70,19 +70,19 @@ def test_user_agent_default( pytest.skip(f"{MODEL_ID} model is not deployed in ML Node, skipping test") embeddings = ElasticsearchEmbeddings( - es_client=sync_client_request_saving, model_id=MODEL_ID + client=sync_client_request_saving, model_id=MODEL_ID ) expected_pattern = r"^elasticsearch-py-es/\d+\.\d+\.\d+$" - got_agent = embeddings.es_client._headers["User-Agent"] + got_agent = embeddings.client._headers["User-Agent"] assert ( re.match(expected_pattern, got_agent) is not None ), f"The user agent '{got_agent}' does not match the expected pattern." embeddings.embed_query("foo bar") - requests = embeddings.es_client.transport.requests # type: ignore + requests = embeddings.client.transport.requests # type: ignore assert len(requests) == 1 got_request_agent = requests[0]["headers"]["User-Agent"] diff --git a/test_elasticsearch/test_server/test_vectorstore/test_vectorstore.py b/test_elasticsearch/test_server/test_vectorstore/test_vectorstore.py index 792ae36a3..2dedc61c3 100644 --- a/test_elasticsearch/test_server/test_vectorstore/test_vectorstore.py +++ b/test_elasticsearch/test_server/test_vectorstore/test_vectorstore.py @@ -61,7 +61,7 @@ class TestVectorStore: def test_search_without_metadata( - self, sync_client: Elasticsearch, index_name: str + self, sync_client: Elasticsearch, index: str ) -> None: """Test end to end construction and search without metadata.""" @@ -78,36 +78,36 @@ def assert_query(query_body: dict, query: Optional[str]) -> dict: return query_body store = VectorStore( - index_name=index_name, + index=index, retrieval_strategy=DenseVectorStrategy(), embedding_service=FakeEmbeddings(), - es_client=sync_client, + client=sync_client, ) texts = ["foo", "bar", "baz"] store.add_texts(texts) - output = 
store.search("foo", k=1, custom_query=assert_query) + output = store.search(query="foo", k=1, custom_query=assert_query) assert [doc["_source"]["text_field"] for doc in output] == ["foo"] def test_search_without_metadata_async( - self, sync_client: Elasticsearch, index_name: str + self, sync_client: Elasticsearch, index: str ) -> None: """Test end to end construction and search without metadata.""" store = VectorStore( - index_name=index_name, + index=index, retrieval_strategy=DenseVectorStrategy(), embedding_service=FakeEmbeddings(), - es_client=sync_client, + client=sync_client, ) texts = ["foo", "bar", "baz"] store.add_texts(texts) - output = store.search("foo", k=1) + output = store.search(query="foo", k=1) assert [doc["_source"]["text_field"] for doc in output] == ["foo"] - def test_add_vectors(self, sync_client: Elasticsearch, index_name: str) -> None: + def test_add_vectors(self, sync_client: Elasticsearch, index: str) -> None: """ Test adding pre-built embeddings instead of using inference for the texts. 
This allows you to separate the embeddings text and the page_content @@ -122,49 +122,45 @@ def test_add_vectors(self, sync_client: Elasticsearch, index_name: str) -> None: embedding_vectors = embeddings.embed_documents(texts) store = VectorStore( - index_name=index_name, + index=index, retrieval_strategy=DenseVectorStrategy(), embedding_service=embeddings, - es_client=sync_client, + client=sync_client, ) store.add_texts(texts=texts, vectors=embedding_vectors, metadatas=metadatas) - output = store.search("foo1", k=1) + output = store.search(query="foo1", k=1) assert [doc["_source"]["text_field"] for doc in output] == ["foo1"] assert [doc["_source"]["metadata"]["page"] for doc in output] == [0] - def test_search_with_metadata( - self, sync_client: Elasticsearch, index_name: str - ) -> None: + def test_search_with_metadata(self, sync_client: Elasticsearch, index: str) -> None: """Test end to end construction and search with metadata.""" store = VectorStore( - index_name=index_name, + index=index, retrieval_strategy=DenseVectorStrategy(), embedding_service=ConsistentFakeEmbeddings(), - es_client=sync_client, + client=sync_client, ) texts = ["foo", "bar", "baz"] metadatas = [{"page": i} for i in range(len(texts))] store.add_texts(texts=texts, metadatas=metadatas) - output = store.search("foo", k=1) + output = store.search(query="foo", k=1) assert [doc["_source"]["text_field"] for doc in output] == ["foo"] assert [doc["_source"]["metadata"]["page"] for doc in output] == [0] - output = store.search("bar", k=1) + output = store.search(query="bar", k=1) assert [doc["_source"]["text_field"] for doc in output] == ["bar"] assert [doc["_source"]["metadata"]["page"] for doc in output] == [1] - def test_search_with_filter( - self, sync_client: Elasticsearch, index_name: str - ) -> None: + def test_search_with_filter(self, sync_client: Elasticsearch, index: str) -> None: """Test end to end construction and search with metadata.""" store = VectorStore( - index_name=index_name, + 
index=index, retrieval_strategy=DenseVectorStrategy(), embedding_service=FakeEmbeddings(), - es_client=sync_client, + client=sync_client, ) texts = ["foo", "foo", "foo"] @@ -192,15 +188,13 @@ def assert_query(query_body: dict, query: Optional[str]) -> dict: assert [doc["_source"]["text_field"] for doc in output] == ["foo"] assert [doc["_source"]["metadata"]["page"] for doc in output] == [1] - def test_search_script_score( - self, sync_client: Elasticsearch, index_name: str - ) -> None: + def test_search_script_score(self, sync_client: Elasticsearch, index: str) -> None: """Test end to end construction and search with metadata.""" store = VectorStore( - index_name=index_name, + index=index, retrieval_strategy=DenseVectorScriptScoreStrategy(), embedding_service=FakeEmbeddings(), - es_client=sync_client, + client=sync_client, ) texts = ["foo", "bar", "baz"] @@ -235,18 +229,18 @@ def assert_query(query_body: dict, query: Optional[str]) -> dict: assert query_body == expected_query return query_body - output = store.search("foo", k=1, custom_query=assert_query) + output = store.search(query="foo", k=1, custom_query=assert_query) assert [doc["_source"]["text_field"] for doc in output] == ["foo"] def test_search_script_score_with_filter( - self, sync_client: Elasticsearch, index_name: str + self, sync_client: Elasticsearch, index: str ) -> None: """Test end to end construction and search with metadata.""" store = VectorStore( - index_name=index_name, + index=index, retrieval_strategy=DenseVectorScriptScoreStrategy(), embedding_service=FakeEmbeddings(), - es_client=sync_client, + client=sync_client, ) texts = ["foo", "bar", "baz"] @@ -282,7 +276,7 @@ def assert_query(query_body: dict, query: Optional[str]) -> dict: return query_body output = store.search( - "foo", + query="foo", k=1, custom_query=assert_query, filter=[{"term": {"metadata.page": 0}}], @@ -291,16 +285,16 @@ def assert_query(query_body: dict, query: Optional[str]) -> dict: assert 
[doc["_source"]["metadata"]["page"] for doc in output] == [0] def test_search_script_score_distance_dot_product( - self, sync_client: Elasticsearch, index_name: str + self, sync_client: Elasticsearch, index: str ) -> None: """Test end to end construction and search with metadata.""" store = VectorStore( - index_name=index_name, + index=index, retrieval_strategy=DenseVectorScriptScoreStrategy( distance=DistanceMetric.DOT_PRODUCT, ), embedding_service=FakeEmbeddings(), - es_client=sync_client, + client=sync_client, ) texts = ["foo", "bar", "baz"] @@ -336,18 +330,18 @@ def assert_query(query_body: dict, query: Optional[str]) -> dict: } return query_body - output = store.search("foo", k=1, custom_query=assert_query) + output = store.search(query="foo", k=1, custom_query=assert_query) assert [doc["_source"]["text_field"] for doc in output] == ["foo"] def test_search_knn_with_hybrid_search( - self, sync_client: Elasticsearch, index_name: str + self, sync_client: Elasticsearch, index: str ) -> None: """Test end to end construction and search with metadata.""" store = VectorStore( - index_name=index_name, + index=index, retrieval_strategy=DenseVectorStrategy(hybrid=True), embedding_service=FakeEmbeddings(), - es_client=sync_client, + client=sync_client, ) texts = ["foo", "bar", "baz"] @@ -372,11 +366,11 @@ def assert_query(query_body: dict, query: Optional[str]) -> dict: } return query_body - output = store.search("foo", k=1, custom_query=assert_query) + output = store.search(query="foo", k=1, custom_query=assert_query) assert [doc["_source"]["text_field"] for doc in output] == ["foo"] def test_search_knn_with_hybrid_search_rrf( - self, sync_client: Elasticsearch, index_name: str + self, sync_client: Elasticsearch, index: str ) -> None: """Test end to end construction and rrf hybrid search with metadata.""" texts = ["foo", "bar", "baz"] @@ -430,23 +424,23 @@ def assert_query( ] for rrf_test_case in rrf_test_cases: store = VectorStore( - index_name=index_name, + 
index=index, retrieval_strategy=DenseVectorStrategy(hybrid=True, rrf=rrf_test_case), embedding_service=FakeEmbeddings(), - es_client=sync_client, + client=sync_client, ) store.add_texts(texts) ## without fetch_k parameter output = store.search( - "foo", + query="foo", k=3, custom_query=partial(assert_query, expected_rrf=rrf_test_case), ) # 2. check query result is okay - es_output = store.es_client.search( - index=index_name, + es_output = store.client.search( + index=index, query={ "bool": { "filter": [], @@ -470,31 +464,31 @@ def assert_query( # 3. check rrf default option is okay store = VectorStore( - index_name=f"{index_name}_default", + index=f"{index}_default", retrieval_strategy=DenseVectorStrategy(hybrid=True), embedding_service=FakeEmbeddings(), - es_client=sync_client, + client=sync_client, ) store.add_texts(texts) ## with fetch_k parameter output = store.search( - "foo", + query="foo", k=3, num_candidates=50, custom_query=partial(assert_query, expected_rrf={}), ) def test_search_knn_with_custom_query_fn( - self, sync_client: Elasticsearch, index_name: str + self, sync_client: Elasticsearch, index: str ) -> None: """test that custom query function is called with the query string and query body""" store = VectorStore( - index_name=index_name, + index=index, retrieval_strategy=DenseVectorStrategy(), embedding_service=FakeEmbeddings(), - es_client=sync_client, + client=sync_client, ) def my_custom_query(query_body: dict, query: Optional[str]) -> dict: @@ -514,11 +508,11 @@ def my_custom_query(query_body: dict, query: Optional[str]) -> dict: texts = ["foo", "bar", "baz"] store.add_texts(texts) - output = store.search("foo", k=1, custom_query=my_custom_query) + output = store.search(query="foo", k=1, custom_query=my_custom_query) assert [doc["_source"]["text_field"] for doc in output] == ["bar"] def test_search_with_knn_infer_instack( - self, sync_client: Elasticsearch, index_name: str + self, sync_client: Elasticsearch, index: str ) -> None: """test end to 
end with knn retrieval strategy and inference in-stack""" @@ -530,15 +524,15 @@ def test_search_with_knn_infer_instack( text_field = "text_field" store = VectorStore( - index_name=index_name, + index=index, retrieval_strategy=DenseVectorStrategy( model_id="sentence-transformers__all-minilm-l6-v2" ), - es_client=sync_client, + client=sync_client, ) # setting up the pipeline for inference - store.es_client.ingest.put_pipeline( + store.client.ingest.put_pipeline( id="test_pipeline", processors=[ { @@ -553,8 +547,8 @@ def test_search_with_knn_infer_instack( # creating a new index with the pipeline, # not relying on langchain to create the index - store.es_client.indices.create( - index=index_name, + store.client.indices.create( + index=index, mappings={ "properties": { text_field: {"type": "text_field"}, @@ -577,13 +571,13 @@ def test_search_with_knn_infer_instack( texts = ["foo", "bar", "baz"] for i, text in enumerate(texts): - store.es_client.create( - index=index_name, + store.client.create( + index=index, id=str(i), document={text_field: text, "metadata": {}}, ) - store.es_client.indices.refresh(index=index_name) + store.client.indices.refresh(index=index) def assert_query(query_body: dict, query: Optional[str]) -> dict: assert query_body == { @@ -602,54 +596,53 @@ def assert_query(query_body: dict, query: Optional[str]) -> dict: } return query_body - output = store.search("foo", k=1, custom_query=assert_query) + output = store.search(query="foo", k=1, custom_query=assert_query) assert [doc["_source"]["text_field"] for doc in output] == ["foo"] - output = store.search("bar", k=1) + output = store.search(query="bar", k=1) assert [doc["_source"]["text_field"] for doc in output] == ["bar"] def test_search_with_sparse_infer_instack( - self, sync_client: Elasticsearch, index_name: str + self, sync_client: Elasticsearch, index: str ) -> None: """test end to end with sparse retrieval strategy and inference in-stack""" if not model_is_deployed(sync_client, ELSER_MODEL_ID): 
reason = f"{ELSER_MODEL_ID} model not deployed in ML Node, skipping test" - pytest.skip(reason) store = VectorStore( - index_name=index_name, + index=index, retrieval_strategy=SparseVectorStrategy(model_id=ELSER_MODEL_ID), - es_client=sync_client, + client=sync_client, ) texts = ["foo", "bar", "baz"] store.add_texts(texts) - output = store.search("foo", k=1) + output = store.search(query="foo", k=1) assert [doc["_source"]["text_field"] for doc in output] == ["foo"] def test_deployed_model_check_fails_semantic( - self, sync_client: Elasticsearch, index_name: str + self, sync_client: Elasticsearch, index: str ) -> None: """test that exceptions are raised if a specified model is not deployed""" with pytest.raises(NotFoundError): store = VectorStore( - index_name=index_name, + index=index, retrieval_strategy=DenseVectorStrategy( model_id="non-existing model ID" ), - es_client=sync_client, + client=sync_client, ) store.add_texts(["foo", "bar", "baz"]) - def test_search_bm25(self, sync_client: Elasticsearch, index_name: str) -> None: + def test_search_bm25(self, sync_client: Elasticsearch, index: str) -> None: """Test end to end using the BM25Strategy retrieval strategy.""" store = VectorStore( - index_name=index_name, + index=index, retrieval_strategy=BM25Strategy(), - es_client=sync_client, + client=sync_client, ) texts = ["foo", "bar", "baz"] @@ -666,17 +659,17 @@ def assert_query(query_body: dict, query: Optional[str]) -> dict: } return query_body - output = store.search("foo", k=1, custom_query=assert_query) + output = store.search(query="foo", k=1, custom_query=assert_query) assert [doc["_source"]["text_field"] for doc in output] == ["foo"] def test_search_bm25_with_filter( - self, sync_client: Elasticsearch, index_name: str + self, sync_client: Elasticsearch, index: str ) -> None: """Test end to using the BM25Strategy retrieval strategy with metadata.""" store = VectorStore( - index_name=index_name, + index=index, retrieval_strategy=BM25Strategy(), - 
es_client=sync_client, + client=sync_client, ) texts = ["foo", "foo", "foo"] @@ -695,7 +688,7 @@ def assert_query(query_body: dict, query: Optional[str]) -> dict: return query_body output = store.search( - "foo", + query="foo", k=3, custom_query=assert_query, filter=[{"term": {"metadata.page": 1}}], @@ -703,53 +696,53 @@ def assert_query(query_body: dict, query: Optional[str]) -> dict: assert [doc["_source"]["text_field"] for doc in output] == ["foo"] assert [doc["_source"]["metadata"]["page"] for doc in output] == [1] - def test_delete(self, sync_client: Elasticsearch, index_name: str) -> None: + def test_delete(self, sync_client: Elasticsearch, index: str) -> None: """Test delete methods from vector store.""" store = VectorStore( - index_name=index_name, + index=index, retrieval_strategy=DenseVectorStrategy(), embedding_service=FakeEmbeddings(), - es_client=sync_client, + client=sync_client, ) texts = ["foo", "bar", "baz", "gni"] metadatas = [{"page": i} for i in range(len(texts))] ids = store.add_texts(texts=texts, metadatas=metadatas) - output = store.search("foo", k=10) + output = store.search(query="foo", k=10) assert len(output) == 4 - store.delete(ids[1:3]) - output = store.search("foo", k=10) + store.delete(ids=ids[1:3]) + output = store.search(query="foo", k=10) assert len(output) == 2 - store.delete(["not-existing"]) - output = store.search("foo", k=10) + store.delete(ids=["not-existing"]) + output = store.search(query="foo", k=10) assert len(output) == 2 - store.delete([ids[0]]) - output = store.search("foo", k=10) + store.delete(ids=[ids[0]]) + output = store.search(query="foo", k=10) assert len(output) == 1 - store.delete([ids[3]]) - output = store.search("gni", k=10) + store.delete(ids=[ids[3]]) + output = store.search(query="gni", k=10) assert len(output) == 0 def test_indexing_exception_error( self, sync_client: Elasticsearch, - index_name: str, + index: str, caplog: pytest.LogCaptureFixture, ) -> None: """Test bulk exception logging is giving 
better hints.""" store = VectorStore( - index_name=index_name, + index=index, retrieval_strategy=BM25Strategy(), - es_client=sync_client, + client=sync_client, ) - store.es_client.indices.create( - index=index_name, + store.client.indices.create( + index=index, mappings={"properties": {}}, settings={"index": {"default_pipeline": "not-existing-pipeline"}}, ) @@ -765,17 +758,17 @@ def test_indexing_exception_error( assert log_message in caplog.text def test_user_agent_default( - self, sync_client_request_saving: Elasticsearch, index_name: str + self, sync_client_request_saving: Elasticsearch, index: str ) -> None: """Test to make sure the user-agent is set correctly.""" store = VectorStore( - index_name=index_name, + index=index, retrieval_strategy=BM25Strategy(), - es_client=sync_client_request_saving, + client=sync_client_request_saving, ) expected_pattern = r"^elasticsearch-py-vs/\d+\.\d+\.\d+$" - got_agent = store.es_client._headers["User-Agent"] + got_agent = store.client._headers["User-Agent"] assert ( re.match(expected_pattern, got_agent) is not None ), f"The user agent '{got_agent}' does not match the expected pattern." @@ -783,49 +776,49 @@ def test_user_agent_default( texts = ["foo", "bob", "baz"] store.add_texts(texts) - for request in store.es_client.transport.requests: # type: ignore + for request in store.client.transport.requests: # type: ignore agent = request["headers"]["User-Agent"] assert ( re.match(expected_pattern, agent) is not None ), f"The user agent '{agent}' does not match the expected pattern." def test_user_agent_custom( - self, sync_client_request_saving: Elasticsearch, index_name: str + self, sync_client_request_saving: Elasticsearch, index: str ) -> None: """Test to make sure the user-agent is set correctly.""" user_agent = "this is THE user_agent!" 
store = VectorStore( user_agent=user_agent, - index_name=index_name, + index=index, retrieval_strategy=BM25Strategy(), - es_client=sync_client_request_saving, + client=sync_client_request_saving, ) - assert store.es_client._headers["User-Agent"] == user_agent + assert store.client._headers["User-Agent"] == user_agent texts = ["foo", "bob", "baz"] store.add_texts(texts) - for request in store.es_client.transport.requests: # type: ignore + for request in store.client.transport.requests: # type: ignore assert request["headers"]["User-Agent"] == user_agent - def test_bulk_args(self, sync_client_request_saving: Any, index_name: str) -> None: + def test_bulk_args(self, sync_client_request_saving: Any, index: str) -> None: """Test to make sure the bulk arguments work as expected.""" store = VectorStore( - index_name=index_name, + index=index, retrieval_strategy=BM25Strategy(), - es_client=sync_client_request_saving, + client=sync_client_request_saving, ) texts = ["foo", "bob", "baz"] store.add_texts(texts, bulk_kwargs={"chunk_size": 1}) # 1 for index exist, 1 for index create, 3 to index docs - assert len(store.es_client.transport.requests) == 5 # type: ignore + assert len(store.client.transport.requests) == 5 # type: ignore def test_max_marginal_relevance_search( - self, sync_client: Elasticsearch, index_name: str + self, sync_client: Elasticsearch, index: str ) -> None: """Test max marginal relevance search.""" texts = ["foo", "bar", "baz"] @@ -833,28 +826,28 @@ def test_max_marginal_relevance_search( text_field = "text_field" embedding_service = ConsistentFakeEmbeddings() store = VectorStore( - index_name=index_name, + index=index, retrieval_strategy=DenseVectorScriptScoreStrategy(), embedding_service=embedding_service, vector_field=vector_field, text_field=text_field, - es_client=sync_client, + client=sync_client, ) store.add_texts(texts) mmr_output = store.max_marginal_relevance_search( - embedding_service, - texts[0], + embedding_service=embedding_service, + 
query=texts[0], vector_field=vector_field, k=3, num_candidates=3, ) - sim_output = store.search(texts[0], k=3) + sim_output = store.search(query=texts[0], k=3) assert mmr_output == sim_output mmr_output = store.max_marginal_relevance_search( - embedding_service, - texts[0], + embedding_service=embedding_service, + query=texts[0], vector_field=vector_field, k=2, num_candidates=3, @@ -864,8 +857,8 @@ def test_max_marginal_relevance_search( assert mmr_output[1]["_source"][text_field] == texts[1] mmr_output = store.max_marginal_relevance_search( - embedding_service, - texts[0], + embedding_service=embedding_service, + query=texts[0], vector_field=vector_field, k=2, num_candidates=3, @@ -877,28 +870,26 @@ def test_max_marginal_relevance_search( # if fetch_k < k, then the output will be less than k mmr_output = store.max_marginal_relevance_search( - embedding_service, - texts[0], + embedding_service=embedding_service, + query=texts[0], vector_field=vector_field, k=3, num_candidates=2, ) assert len(mmr_output) == 2 - def test_metadata_mapping( - self, sync_client: Elasticsearch, index_name: str - ) -> None: + def test_metadata_mapping(self, sync_client: Elasticsearch, index: str) -> None: """Test that the metadata mapping is applied.""" test_mappings = { "my_field": {"type": "keyword"}, "another_field": {"type": "text"}, } store = VectorStore( - index_name=index_name, + index=index, retrieval_strategy=DenseVectorStrategy(distance=DistanceMetric.COSINE), embedding_service=FakeEmbeddings(), num_dimensions=10, - es_client=sync_client, + client=sync_client, metadata_mappings=test_mappings, ) @@ -906,8 +897,8 @@ def test_metadata_mapping( metadatas = [{"my_field": str(i)} for i in range(len(texts))] store.add_texts(texts=texts, metadatas=metadatas) - mapping_response = sync_client.indices.get_mapping(index=index_name) - mapping_properties = mapping_response[index_name]["mappings"]["properties"] + mapping_response = sync_client.indices.get_mapping(index=index) + 
mapping_properties = mapping_response[index]["mappings"]["properties"] assert mapping_properties["vector_field"] == { "type": "dense_vector", "dims": 10, diff --git a/test_elasticsearch/test_strategies.py b/test_elasticsearch/test_strategies.py index 11eb21579..36ce63e9f 100644 --- a/test_elasticsearch/test_strategies.py +++ b/test_elasticsearch/test_strategies.py @@ -29,11 +29,27 @@ def test_sparse_vector_strategy_raises_errors(): with pytest.raises(ValueError): # missing query - strategy.es_query(None, None, "text_field", "vector_field", 10, 20, []) + strategy.es_query( + query=None, + query_vector=None, + text_field="text_field", + vector_field="vector_field", + k=10, + num_candidates=20, + filter=[], + ) with pytest.raises(ValueError): # query vector not allowed - strategy.es_query("hi", [1, 2, 3], "text_field", "vector_field", 10, 20, []) + strategy.es_query( + query="hi", + query_vector=[1, 2, 3], + text_field="text_field", + vector_field="vector_field", + k=10, + num_candidates=20, + filter=[], + ) def test_dense_vector_strategy_raises_error(): @@ -44,7 +60,7 @@ def test_dense_vector_strategy_raises_error(): with pytest.raises(ValueError): # unknown distance DenseVectorStrategy(distance="unknown distance").es_mappings_settings( - "text_field", "vector_field", 10 + text_field="text_field", vector_field="vector_field", num_dimensions=10 ) @@ -52,11 +68,23 @@ def test_dense_vector_script_score_strategy_raises_error(): with pytest.raises(ValueError): # missing query vector DenseVectorScriptScoreStrategy().es_query( - None, None, "text_field", "vector_field", 10, 20, [] + query=None, + query_vector=None, + text_field="text_field", + vector_field="vector_field", + k=10, + num_candidates=20, + filter=[], ) with pytest.raises(ValueError): # unknown distance DenseVectorScriptScoreStrategy(distance="unknown distance").es_query( - None, [1, 2, 3], "text_field", "vector_field", 10, 20, [] + query=None, + query_vector=[1, 2, 3], + text_field="text_field", + 
vector_field="vector_field", + k=10, + num_candidates=20, + filter=[], ) From a8d80f2a3f818f9b916fdbe63b1e50349897c9f5 Mon Sep 17 00:00:00 2001 From: Max Jakob Date: Tue, 30 Apr 2024 14:41:14 +0200 Subject: [PATCH 35/36] fix sparse vector strategy bug (duplicate `size`) --- elasticsearch/helpers/vectorstore/_async/strategies.py | 3 +-- elasticsearch/helpers/vectorstore/_sync/strategies.py | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/elasticsearch/helpers/vectorstore/_async/strategies.py b/elasticsearch/helpers/vectorstore/_async/strategies.py index 09690e24c..a7f813f43 100644 --- a/elasticsearch/helpers/vectorstore/_async/strategies.py +++ b/elasticsearch/helpers/vectorstore/_async/strategies.py @@ -137,8 +137,7 @@ def es_query( ], "filter": filter, } - }, - "size": k, + } } def es_mappings_settings( diff --git a/elasticsearch/helpers/vectorstore/_sync/strategies.py b/elasticsearch/helpers/vectorstore/_sync/strategies.py index 17b42beb1..928d34143 100644 --- a/elasticsearch/helpers/vectorstore/_sync/strategies.py +++ b/elasticsearch/helpers/vectorstore/_sync/strategies.py @@ -137,8 +137,7 @@ def es_query( ], "filter": filter, } - }, - "size": k, + } } def es_mappings_settings( From d27f9f831fe783d27cbcf7f6086e85387da3bce5 Mon Sep 17 00:00:00 2001 From: Max Jakob Date: Tue, 30 Apr 2024 14:45:48 +0200 Subject: [PATCH 36/36] all wildcard deletes in compose ES --- .../test_server/test_vectorstore/docker-compose.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/test_elasticsearch/test_server/test_vectorstore/docker-compose.yml b/test_elasticsearch/test_server/test_vectorstore/docker-compose.yml index b3aa43f18..b13520e06 100644 --- a/test_elasticsearch/test_server/test_vectorstore/docker-compose.yml +++ b/test_elasticsearch/test_server/test_vectorstore/docker-compose.yml @@ -4,6 +4,7 @@ services: elasticsearch: image: elasticsearch:8.13.0 environment: + - action.destructive_requires_name=false # allow wildcard index deletions - 
discovery.type=single-node - xpack.license.self_generated.type=trial - xpack.security.enabled=false # disable password and TLS; never do this in production!