From 9c3bccd4c4efdb5a71400cb7b8ce14bf8fc72303 Mon Sep 17 00:00:00 2001 From: Miguel Grinberg Date: Wed, 27 Nov 2024 17:05:36 +0000 Subject: [PATCH] async support (#53) * async support * format * fixed unit test failures * sort imports * more code reformat * typing fixes * typing fixes * one more reformat * typing fixes * docs * use sed to change get_messages() to messages --- libs/elasticsearch/Makefile | 6 + .../langchain_elasticsearch/__init__.py | 34 +- .../langchain_elasticsearch/_async/cache.py | 427 ++++++ .../_async/chat_history.py | 166 ++ .../_async/embeddings.py | 254 ++++ .../_async/retrievers.py | 118 ++ .../_async/vectorstores.py | 860 +++++++++++ .../langchain_elasticsearch/_sync/cache.py | 425 ++++++ .../_sync/chat_history.py | 166 ++ .../_sync/embeddings.py | 254 ++++ .../_sync/retrievers.py | 118 ++ .../_sync/vectorstores.py | 858 +++++++++++ .../langchain_elasticsearch/_utilities.py | 488 +++++- .../langchain_elasticsearch/cache.py | 422 +----- .../langchain_elasticsearch/chat_history.py | 156 +- .../langchain_elasticsearch/client.py | 36 +- .../langchain_elasticsearch/embeddings.py | 270 +--- .../langchain_elasticsearch/retrievers.py | 134 +- .../langchain_elasticsearch/vectorstores.py | 1347 +---------------- libs/elasticsearch/pyproject.toml | 3 +- libs/elasticsearch/scripts/run_unasync.py | 149 ++ .../tests/_async/fake_embeddings.py | 49 + .../tests/_sync/fake_embeddings.py | 49 + libs/elasticsearch/tests/conftest.py | 68 +- libs/elasticsearch/tests/fake_embeddings.py | 55 +- .../integration_tests/_async/__init__.py | 0 .../_async/_test_utilities.py | 77 + .../integration_tests/_async/test_cache.py | 299 ++++ .../_async/test_chat_history.py | 72 + .../_async/test_embeddings.py | 54 + .../_async/test_retrievers.py | 239 +++ .../_async/test_vectorstores.py | 936 ++++++++++++ .../tests/integration_tests/_sync/__init__.py | 0 .../{ => _sync}/_test_utilities.py | 0 .../{ => _sync}/test_cache.py | 69 +- .../{ => _sync}/test_chat_history.py | 11 +- .../{ => _sync}/test_embeddings.py | 31 +- .../{ => _sync}/test_retrievers.py | 19 +- .../{ => _sync}/test_vectorstores.py | 42 +- .../tests/unit_tests/_async/__init__.py | 0 .../tests/unit_tests/_async/test_cache.py | 405 +++++ .../unit_tests/_async/test_vectorstores.py | 417 +++++ .../tests/unit_tests/_sync/__init__.py | 0 .../unit_tests/{ => _sync}/test_cache.py | 50 +- .../{ => _sync}/test_vectorstores.py | 22 +- .../tests/unit_tests/test_imports.py | 53 +- 46 files changed, 7334 insertions(+), 2374 deletions(-) create mode 100644 libs/elasticsearch/langchain_elasticsearch/_async/cache.py create mode 100644 libs/elasticsearch/langchain_elasticsearch/_async/chat_history.py create mode 100644 libs/elasticsearch/langchain_elasticsearch/_async/embeddings.py create mode 100644 libs/elasticsearch/langchain_elasticsearch/_async/retrievers.py create mode 100644 libs/elasticsearch/langchain_elasticsearch/_async/vectorstores.py create mode 100644 libs/elasticsearch/langchain_elasticsearch/_sync/cache.py create mode 100644 libs/elasticsearch/langchain_elasticsearch/_sync/chat_history.py create mode 100644 libs/elasticsearch/langchain_elasticsearch/_sync/embeddings.py create mode 100644 libs/elasticsearch/langchain_elasticsearch/_sync/retrievers.py create mode 100644 libs/elasticsearch/langchain_elasticsearch/_sync/vectorstores.py create mode 100644 libs/elasticsearch/scripts/run_unasync.py create mode 100644 libs/elasticsearch/tests/_async/fake_embeddings.py create mode 100644 libs/elasticsearch/tests/_sync/fake_embeddings.py create mode 100644 libs/elasticsearch/tests/integration_tests/_async/__init__.py create mode 100644 libs/elasticsearch/tests/integration_tests/_async/_test_utilities.py create mode 100644 libs/elasticsearch/tests/integration_tests/_async/test_cache.py create mode 100644 libs/elasticsearch/tests/integration_tests/_async/test_chat_history.py create mode 100644 libs/elasticsearch/tests/integration_tests/_async/test_embeddings.py create mode 100644 libs/elasticsearch/tests/integration_tests/_async/test_retrievers.py create mode 100644 libs/elasticsearch/tests/integration_tests/_async/test_vectorstores.py create mode 100644 libs/elasticsearch/tests/integration_tests/_sync/__init__.py rename libs/elasticsearch/tests/integration_tests/{ => _sync}/_test_utilities.py (100%) rename libs/elasticsearch/tests/integration_tests/{ => _sync}/test_cache.py (81%) rename libs/elasticsearch/tests/integration_tests/{ => _sync}/test_chat_history.py (87%) rename libs/elasticsearch/tests/integration_tests/{ => _sync}/test_embeddings.py (62%) rename libs/elasticsearch/tests/integration_tests/{ => _sync}/test_retrievers.py (94%) rename libs/elasticsearch/tests/integration_tests/{ => _sync}/test_vectorstores.py (97%) create mode 100644 libs/elasticsearch/tests/unit_tests/_async/__init__.py create mode 100644 libs/elasticsearch/tests/unit_tests/_async/test_cache.py create mode 100644 libs/elasticsearch/tests/unit_tests/_async/test_vectorstores.py create mode 100644 libs/elasticsearch/tests/unit_tests/_sync/__init__.py rename libs/elasticsearch/tests/unit_tests/{ => _sync}/test_cache.py (90%) rename libs/elasticsearch/tests/unit_tests/{ => _sync}/test_vectorstores.py (96%) diff --git a/libs/elasticsearch/Makefile b/libs/elasticsearch/Makefile index 9ada9f6..218bb83 100644 --- a/libs/elasticsearch/Makefile +++ b/libs/elasticsearch/Makefile @@ -46,6 +46,12 @@ spell_fix: check_imports: $(shell find langchain_elasticsearch -name '*.py') poetry run python ./scripts/check_imports.py $^ +run_unasync: + poetry run python ./scripts/run_unasync.py + +run_unasync_check: + poetry run python ./scripts/run_unasync.py --check + ###################### # HELP ###################### diff --git a/libs/elasticsearch/langchain_elasticsearch/__init__.py b/libs/elasticsearch/langchain_elasticsearch/__init__.py index 17611c4..45cfe70 100644 --- a/libs/elasticsearch/langchain_elasticsearch/__init__.py +++ b/libs/elasticsearch/langchain_elasticsearch/__init__.py @@ -1,4 +1,9 @@ from elasticsearch.helpers.vectorstore import ( + AsyncBM25Strategy, + AsyncDenseVectorScriptScoreStrategy, + AsyncDenseVectorStrategy, + AsyncRetrievalStrategy, + AsyncSparseVectorStrategy, BM25Strategy, DenseVectorScriptScoreStrategy, DenseVectorStrategy, @@ -8,14 +13,26 @@ ) from langchain_elasticsearch.cache import ( + AsyncElasticsearchCache, + AsyncElasticsearchEmbeddingsCache, ElasticsearchCache, ElasticsearchEmbeddingsCache, ) -from langchain_elasticsearch.chat_history import ElasticsearchChatMessageHistory -from langchain_elasticsearch.embeddings import ElasticsearchEmbeddings -from langchain_elasticsearch.retrievers import ElasticsearchRetriever +from langchain_elasticsearch.chat_history import ( + AsyncElasticsearchChatMessageHistory, + ElasticsearchChatMessageHistory, +) +from langchain_elasticsearch.embeddings import ( + AsyncElasticsearchEmbeddings, + ElasticsearchEmbeddings, +) +from langchain_elasticsearch.retrievers import ( + AsyncElasticsearchRetriever, + ElasticsearchRetriever, +) from langchain_elasticsearch.vectorstores import ( ApproxRetrievalStrategy, + AsyncElasticsearchStore, BM25RetrievalStrategy, ElasticsearchStore, ExactRetrievalStrategy, @@ -23,6 +40,12 @@ ) __all__ = [ + "AsyncElasticsearchCache", + "AsyncElasticsearchChatMessageHistory", + "AsyncElasticsearchEmbeddings", + "AsyncElasticsearchEmbeddingsCache", + "AsyncElasticsearchRetriever", + "AsyncElasticsearchStore", "ElasticsearchCache", "ElasticsearchChatMessageHistory", "ElasticsearchEmbeddings", @@ -30,6 +53,11 @@ "ElasticsearchRetriever", "ElasticsearchStore", # retrieval strategies + "AsyncBM25Strategy", + "AsyncDenseVectorScriptScoreStrategy", + "AsyncDenseVectorStrategy", + "AsyncRetrievalStrategy", + "AsyncSparseVectorStrategy", "BM25Strategy", "DenseVectorScriptScoreStrategy", "DenseVectorStrategy", diff --git a/libs/elasticsearch/langchain_elasticsearch/_async/cache.py b/libs/elasticsearch/langchain_elasticsearch/_async/cache.py new file mode 100644 index 0000000..56625ac --- /dev/null +++ b/libs/elasticsearch/langchain_elasticsearch/_async/cache.py @@ -0,0 +1,427 @@ +import base64 +import hashlib +import logging +from datetime import datetime +from functools import cached_property +from typing import ( + TYPE_CHECKING, + Any, + AsyncIterator, + Dict, + Iterable, + List, + Optional, + Sequence, + Tuple, +) + +from elasticsearch import AsyncElasticsearch, exceptions, helpers +from elasticsearch.helpers import BulkIndexError +from langchain_core.caches import RETURN_VAL_TYPE, BaseCache +from langchain_core.load import dumps, loads +from langchain_core.stores import ByteStore + +from langchain_elasticsearch.client import create_async_elasticsearch_client + +if TYPE_CHECKING: + from elasticsearch import AsyncElasticsearch + +logger = logging.getLogger(__name__) + + +async def _manage_cache_index( + es_client: AsyncElasticsearch, index_name: str, mapping: Dict[str, Any] +) -> bool: + """Write or update an index or alias according to the default mapping""" + if await es_client.indices.exists_alias(name=index_name): + await es_client.indices.put_mapping(index=index_name, body=mapping["mappings"]) + return True + + elif not await es_client.indices.exists(index=index_name): + logger.debug(f"Creating new Elasticsearch index: {index_name}") + await es_client.indices.create(index=index_name, body=mapping) + return False + + return False + + +class AsyncElasticsearchCache(BaseCache): + """An Elasticsearch cache integration for LLMs. + + For synchronous applications, use the ``ElasticsearchCache`` class. + For asyhchronous applications, use the ``AsyncElasticsearchCache`` class. + """ + + def __init__( + self, + index_name: str, + store_input: bool = True, + store_input_params: bool = True, + metadata: Optional[Dict[str, Any]] = None, + *, + es_url: Optional[str] = None, + es_cloud_id: Optional[str] = None, + es_user: Optional[str] = None, + es_api_key: Optional[str] = None, + es_password: Optional[str] = None, + es_params: Optional[Dict[str, Any]] = None, + ): + """ + Initialize the Elasticsearch cache store by specifying the index/alias + to use and determining which additional information (like input, input + parameters, and any other metadata) should be stored in the cache. + + Args: + index_name (str): The name of the index or the alias to use for the cache. + If they do not exist an index is created, + according to the default mapping defined by the `mapping` property. + store_input (bool): Whether to store the LLM input in the cache, i.e., + the input prompt. Default to True. + store_input_params (bool): Whether to store the input parameters in the + cache, i.e., the LLM parameters used to generate the LLM response. + Default to True. + metadata (Optional[dict]): Additional metadata to store in the cache, + for filtering purposes. This must be JSON serializable in an + Elasticsearch document. Default to None. + es_url: URL of the Elasticsearch instance to connect to. + es_cloud_id: Cloud ID of the Elasticsearch instance to connect to. + es_user: Username to use when connecting to Elasticsearch. + es_password: Password to use when connecting to Elasticsearch. + es_api_key: API key to use when connecting to Elasticsearch. + es_params: Other parameters for the Elasticsearch client. + """ + + self._index_name = index_name + self._store_input = store_input + self._store_input_params = store_input_params + self._metadata = metadata + self._es_client = create_async_elasticsearch_client( + url=es_url, + cloud_id=es_cloud_id, + api_key=es_api_key, + username=es_user, + password=es_password, + params=es_params, + ) + self._is_alias: Optional[bool] = None + + async def is_alias(self) -> bool: + if self._is_alias is None: + self._is_alias = await _manage_cache_index( + self._es_client, + self._index_name, + self.mapping, + ) + return self._is_alias # type: ignore[return-value] + + @cached_property + def mapping(self) -> Dict[str, Any]: + """Get the default mapping for the index.""" + return { + "mappings": { + "properties": { + "llm_output": {"type": "text", "index": False}, + "llm_params": {"type": "text", "index": False}, + "llm_input": {"type": "text", "index": False}, + "metadata": {"type": "object"}, + "timestamp": {"type": "date"}, + } + } + } + + @staticmethod + def _key(prompt: str, llm_string: str) -> str: + """Generate a key for the cache store.""" + return hashlib.md5((prompt + llm_string).encode()).hexdigest() + + async def alookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]: + """Look up based on prompt and llm_string.""" + cache_key = self._key(prompt, llm_string) + if await self.is_alias(): + # get the latest record according to its writing date, in order to + # address cases where multiple indices have a doc with the same id + result = await self._es_client.search( + index=self._index_name, + body={ + "query": {"term": {"_id": cache_key}}, + "sort": {"timestamp": {"order": "asc"}}, + }, + source_includes=["llm_output"], + ) + if result["hits"]["total"]["value"] > 0: + record = result["hits"]["hits"][0] + else: + return None + else: + try: + record = await self._es_client.get( + index=self._index_name, id=cache_key, source=["llm_output"] + ) + except exceptions.NotFoundError: + return None + return [loads(item) for item in record["_source"]["llm_output"]] + + def build_document( + self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE + ) -> Dict[str, Any]: + """Build the Elasticsearch document for storing a single LLM interaction""" + body: Dict[str, Any] = { + "llm_output": [dumps(item) for item in return_val], + "timestamp": datetime.now().isoformat(), + } + if self._store_input_params: + body["llm_params"] = llm_string + if self._metadata is not None: + body["metadata"] = self._metadata + if self._store_input: + body["llm_input"] = prompt + return body + + async def aupdate( + self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE + ) -> None: + """Update based on prompt and llm_string.""" + body = self.build_document(prompt, llm_string, return_val) + await self._es_client.index( + index=self._index_name, + id=self._key(prompt, llm_string), + body=body, + require_alias=await self.is_alias(), + refresh=True, + ) + + async def aclear(self, **kwargs: Any) -> None: + """Clear cache.""" + await self._es_client.delete_by_query( + index=self._index_name, + body={"query": {"match_all": {}}}, + refresh=True, + wait_for_completion=True, + ) + + +class AsyncElasticsearchEmbeddingsCache(ByteStore): + """An Elasticsearch store for caching embeddings. + + For synchronous applications, use the `ElasticsearchEmbeddingsCache` class. + For asyhchronous applications, use the `AsyncElasticsearchEmbeddingsCache` class. + """ + + def __init__( + self, + index_name: str, + store_input: bool = True, + metadata: Optional[Dict[str, Any]] = None, + namespace: Optional[str] = None, + maximum_duplicates_allowed: int = 1, + *, + es_url: Optional[str] = None, + es_cloud_id: Optional[str] = None, + es_user: Optional[str] = None, + es_api_key: Optional[str] = None, + es_password: Optional[str] = None, + es_params: Optional[Dict[str, Any]] = None, + ): + """ + Initialize the Elasticsearch cache store by specifying the index/alias + to use and determining which additional information (like input, input + parameters, and any other metadata) should be stored in the cache. + Provide a namespace to organize the cache. + + + Args: + index_name (str): The name of the index or the alias to use for the cache. + If they do not exist an index is created, + according to the default mapping defined by the `mapping` property. + store_input (bool): Whether to store the input in the cache. + Default to True. + metadata (Optional[dict]): Additional metadata to store in the cache, + for filtering purposes. This must be JSON serializable in an + Elasticsearch document. Default to None. + namespace (Optional[str]): A namespace to use for the cache. + maximum_duplicates_allowed (int): Defines the maximum number of duplicate + keys permitted. Must be used in scenarios where the same key appears + across multiple indices that share the same alias. Default to 1. + es_url: URL of the Elasticsearch instance to connect to. + es_cloud_id: Cloud ID of the Elasticsearch instance to connect to. + es_user: Username to use when connecting to Elasticsearch. + es_password: Password to use when connecting to Elasticsearch. + es_api_key: API key to use when connecting to Elasticsearch. + es_params: Other parameters for the Elasticsearch client. + """ + self._namespace = namespace + self._maximum_duplicates_allowed = maximum_duplicates_allowed + self._index_name = index_name + self._store_input = store_input + self._metadata = metadata + self._es_client = create_async_elasticsearch_client( + url=es_url, + cloud_id=es_cloud_id, + api_key=es_api_key, + username=es_user, + password=es_password, + params=es_params, + ) + self._is_alias: Optional[bool] = None + + async def is_alias(self) -> bool: + if self._is_alias is None: + self._is_alias = await _manage_cache_index( + self._es_client, + self._index_name, + self.mapping, + ) + return self._is_alias # type: ignore[return-value] + + @staticmethod + def encode_vector(data: bytes) -> str: + """Encode the vector data as bytes to as a base64 string.""" + return base64.b64encode(data).decode("utf-8") + + @staticmethod + def decode_vector(data: str) -> bytes: + """Decode the base64 string to vector data as bytes.""" + return base64.b64decode(data) + + @cached_property + def mapping(self) -> Dict[str, Any]: + """Get the default mapping for the index.""" + return { + "mappings": { + "properties": { + "text_input": {"type": "text", "index": False}, + "vector_dump": { + "type": "binary", + "doc_values": False, + }, + "metadata": {"type": "object"}, + "timestamp": {"type": "date"}, + "namespace": {"type": "keyword"}, + } + } + } + + def _key(self, input_text: str) -> str: + """Generate a key for the store.""" + return hashlib.md5(((self._namespace or "") + input_text).encode()).hexdigest() + + @classmethod + def _deduplicate_hits(cls, hits: List[dict]) -> Dict[str, bytes]: + """ + Collapse the results from a search query with multiple indices + returning only the latest version of the documents + """ + map_ids = {} + for hit in sorted( + hits, + key=lambda x: datetime.fromisoformat(x["_source"]["timestamp"]), + reverse=True, + ): + vector_id: str = hit["_id"] + if vector_id not in map_ids: + map_ids[vector_id] = cls.decode_vector(hit["_source"]["vector_dump"]) + + return map_ids + + async def amget(self, keys: Sequence[str]) -> List[Optional[bytes]]: + """Get the values associated with the given keys.""" + if not any(keys): + return [] + + cache_keys = [self._key(k) for k in keys] + if await self.is_alias(): + try: + results = await self._es_client.search( + index=self._index_name, + body={ + "query": {"ids": {"values": cache_keys}}, + "size": len(cache_keys) * self._maximum_duplicates_allowed, + }, + source_includes=["vector_dump", "timestamp"], + ) + + except exceptions.BadRequestError as e: + if "window too large" in ( + e.body.get("error", {}).get("root_cause", [{}])[0].get("reason", "") + ): + logger.warning( + "Exceeded the maximum window size, " + "Reduce the duplicates manually or lower " + "`maximum_duplicate_allowed.`" + ) + raise e + + total_hits = results["hits"]["total"]["value"] + if self._maximum_duplicates_allowed > 1 and total_hits > len(cache_keys): + logger.warning( + f"Deduplicating, found {total_hits} hits for {len(cache_keys)} keys" + ) + map_ids = self._deduplicate_hits(results["hits"]["hits"]) + else: + map_ids = { + r["_id"]: self.decode_vector(r["_source"]["vector_dump"]) + for r in results["hits"]["hits"] + } + + return [map_ids.get(k) for k in cache_keys] + + else: + records = await self._es_client.mget( + index=self._index_name, ids=cache_keys, source_includes=["vector_dump"] + ) + return [ + self.decode_vector(r["_source"]["vector_dump"]) if r["found"] else None + for r in records["docs"] + ] + + def build_document(self, text_input: str, vector: bytes) -> Dict[str, Any]: + """Build the Elasticsearch document for storing a single embedding""" + body: Dict[str, Any] = { + "vector_dump": self.encode_vector(vector), + "timestamp": datetime.now().isoformat(), + } + if self._metadata is not None: + body["metadata"] = self._metadata + if self._store_input: + body["text_input"] = text_input + if self._namespace: + body["namespace"] = self._namespace + return body + + async def _bulk(self, actions: Iterable[Dict[str, Any]]) -> None: + try: + await helpers.async_bulk( + client=self._es_client, + actions=actions, + index=self._index_name, + require_alias=await self.is_alias(), + refresh=True, + ) + except BulkIndexError as e: + first_error = e.errors[0].get("index", {}).get("error", {}) + logger.error(f"First bulk error reason: {first_error.get('reason')}") + raise e + + async def amset(self, key_value_pairs: Sequence[Tuple[str, bytes]]) -> None: + """Set the values for the given keys.""" + actions = ( + { + "_op_type": "index", + "_id": self._key(key), + "_source": self.build_document(key, vector), + } + for key, vector in key_value_pairs + ) + await self._bulk(actions) + + async def amdelete(self, keys: Sequence[str]) -> None: + """Delete the given keys and their associated values.""" + actions = ({"_op_type": "delete", "_id": self._key(key)} for key in keys) + await self._bulk(actions) + + async def ayield_keys(self, *, prefix: Optional[str] = None) -> AsyncIterator[str]: # type: ignore[override] + """Get an iterator over keys that match the given prefix.""" + # TODO This method is not currently used by CacheBackedEmbeddings, + # we can leave it blank. It could be implemented with ES "index_prefixes", + # but they are limited and expensive. + raise NotImplementedError() diff --git a/libs/elasticsearch/langchain_elasticsearch/_async/chat_history.py b/libs/elasticsearch/langchain_elasticsearch/_async/chat_history.py new file mode 100644 index 0000000..fb98458 --- /dev/null +++ b/libs/elasticsearch/langchain_elasticsearch/_async/chat_history.py @@ -0,0 +1,166 @@ +import json +import logging +from time import time +from typing import TYPE_CHECKING, List, Optional, Sequence + +from langchain_core.chat_history import BaseChatMessageHistory +from langchain_core.messages import BaseMessage, message_to_dict, messages_from_dict + +from langchain_elasticsearch._utilities import async_with_user_agent_header +from langchain_elasticsearch.client import create_async_elasticsearch_client + +if TYPE_CHECKING: + from elasticsearch import AsyncElasticsearch + +logger = logging.getLogger(__name__) + + +class AsyncElasticsearchChatMessageHistory(BaseChatMessageHistory): + """Chat message history that stores history in Elasticsearch. + + Args: + es_url: URL of the Elasticsearch instance to connect to. + es_cloud_id: Cloud ID of the Elasticsearch instance to connect to. + es_user: Username to use when connecting to Elasticsearch. + es_password: Password to use when connecting to Elasticsearch. + es_api_key: API key to use when connecting to Elasticsearch. + es_connection: Optional pre-existing Elasticsearch connection. + esnsure_ascii: Used to escape ASCII symbols in json.dumps. Defaults to True. + index: Name of the index to use. + session_id: Arbitrary key that is used to store the messages + of a single chat session. + + For synchronous applications, use the `ElasticsearchChatMessageHistory` class. + For asyhchronous applications, use the `AsyncElasticsearchChatMessageHistory` class. + """ + + def __init__( + self, + index: str, + session_id: str, + *, + es_connection: Optional["AsyncElasticsearch"] = None, + es_url: Optional[str] = None, + es_cloud_id: Optional[str] = None, + es_user: Optional[str] = None, + es_api_key: Optional[str] = None, + es_password: Optional[str] = None, + esnsure_ascii: Optional[bool] = True, + ): + self.index: str = index + self.session_id: str = session_id + self.ensure_ascii = esnsure_ascii + + # Initialize Elasticsearch client from passed client arg or connection info + if es_connection is not None: + self.client = es_connection + elif es_url is not None or es_cloud_id is not None: + try: + self.client = create_async_elasticsearch_client( + url=es_url, + username=es_user, + password=es_password, + cloud_id=es_cloud_id, + api_key=es_api_key, + ) + except Exception as err: + logger.error(f"Error connecting to Elasticsearch: {err}") + raise err + else: + raise ValueError( + """Either provide a pre-existing Elasticsearch connection, \ + or valid credentials for creating a new connection.""" + ) + + self.client = async_with_user_agent_header(self.client, "langchain-py-ms") + self.created = False + + async def create_if_missing(self) -> None: + if not self.created: + if await self.client.indices.exists(index=self.index): + logger.debug( + ( + f"Chat history index {self.index} already exists, " + "skipping creation." + ) + ) + else: + logger.debug(f"Creating index {self.index} for storing chat history.") + + await self.client.indices.create( + index=self.index, + mappings={ + "properties": { + "session_id": {"type": "keyword"}, + "created_at": {"type": "date"}, + "history": {"type": "text"}, + } + }, + ) + self.created = True + + async def aget_messages(self) -> List[BaseMessage]: # type: ignore[override] + """Retrieve the messages from Elasticsearch""" + try: + from elasticsearch import ApiError + + await self.create_if_missing() + result = await self.client.search( + index=self.index, + query={"term": {"session_id": self.session_id}}, + sort="created_at:asc", + ) + except ApiError as err: + logger.error(f"Could not retrieve messages from Elasticsearch: {err}") + raise err + + if result and len(result["hits"]["hits"]) > 0: + items = [ + json.loads(document["_source"]["history"]) + for document in result["hits"]["hits"] + ] + else: + items = [] + + return messages_from_dict(items) + + async def aadd_message(self, message: BaseMessage) -> None: + """Add messages to the chat session in Elasticsearch""" + try: + from elasticsearch import ApiError + + await self.create_if_missing() + await self.client.index( + index=self.index, + document={ + "session_id": self.session_id, + "created_at": round(time() * 1000), + "history": json.dumps( + message_to_dict(message), + ensure_ascii=bool(self.ensure_ascii), + ), + }, + refresh=True, + ) + except ApiError as err: + logger.error(f"Could not add message to Elasticsearch: {err}") + raise err + + async def aadd_messages(self, messages: Sequence[BaseMessage]) -> None: + for message in messages: + await self.aadd_message(message) + + async def aclear(self) -> None: + """Clear session memory in Elasticsearch""" + try: + from elasticsearch import ApiError + + await self.create_if_missing() + await self.client.delete_by_query( + index=self.index, + query={"term": {"session_id": self.session_id}}, + refresh=True, + ) + except ApiError as err: + logger.error(f"Could not clear session memory in Elasticsearch: {err}") + raise err diff --git a/libs/elasticsearch/langchain_elasticsearch/_async/embeddings.py b/libs/elasticsearch/langchain_elasticsearch/_async/embeddings.py new file mode 100644 index 0000000..c6d1b7e --- /dev/null +++ b/libs/elasticsearch/langchain_elasticsearch/_async/embeddings.py @@ -0,0 +1,254 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, List, Optional + +from elasticsearch import AsyncElasticsearch +from elasticsearch.helpers.vectorstore import AsyncEmbeddingService +from langchain_core.embeddings import Embeddings +from langchain_core.utils import get_from_env + +if TYPE_CHECKING: + from elasticsearch._async.client.ml import MlClient + + +class AsyncElasticsearchEmbeddings(Embeddings): + """Elasticsearch embedding models. + + This class provides an interface to generate embeddings using a model deployed + in an Elasticsearch cluster. It requires an Elasticsearch connection object + and the model_id of the model deployed in the cluster. + + In Elasticsearch you need to have an embedding model loaded and deployed. + - https://www.elastic.co/guide/en/elasticsearch/reference/current/infer-trained-model.html + - https://www.elastic.co/guide/en/machine-learning/current/ml-nlp-deploy-models.html + + For synchronous applications, use the `ElasticsearchEmbeddings` class. + For asyhchronous applications, use the `AsyncElasticsearchEmbeddings` class. + """ # noqa: E501 + + def __init__( + self, + client: MlClient, + model_id: str, + *, + input_field: str = "text_field", + ): + """ + Initialize the ElasticsearchEmbeddings instance. + + Args: + client (MlClient): An Elasticsearch ML client object. + model_id (str): The model_id of the model deployed in the Elasticsearch + cluster. + input_field (str): The name of the key for the input text field in the + document. Defaults to 'text_field'. + """ + self.client = client + self.model_id = model_id + self.input_field = input_field + + @classmethod + def from_credentials( + cls, + model_id: str, + *, + es_cloud_id: Optional[str] = None, + es_api_key: Optional[str] = None, + input_field: str = "text_field", + ) -> AsyncElasticsearchEmbeddings: + """Instantiate embeddings from Elasticsearch credentials. + + Args: + model_id (str): The model_id of the model deployed in the Elasticsearch + cluster. + input_field (str): The name of the key for the input text field in the + document. Defaults to 'text_field'. + es_cloud_id: (str, optional): The Elasticsearch cloud ID to connect to. + es_user: (str, optional): Elasticsearch username. + es_password: (str, optional): Elasticsearch password. + + Example: + .. code-block:: python + + from langchain_elasticserach.embeddings import ElasticsearchEmbeddings + + # Define the model ID and input field name (if different from default) + model_id = "your_model_id" + # Optional, only if different from 'text_field' + input_field = "your_input_field" + + # Credentials can be passed in two ways. Either set the env vars + # ES_CLOUD_ID, ES_USER, ES_PASSWORD and they will be automatically + # pulled in, or pass them in directly as kwargs. + embeddings = ElasticsearchEmbeddings.from_credentials( + model_id, + input_field=input_field, + # es_cloud_id="foo", + # es_user="bar", + # es_password="baz", + ) + + documents = [ + "This is an example document.", + "Another example document to generate embeddings for.", + ] + embeddings_generator.embed_documents(documents) + """ + from elasticsearch._async.client.ml import MlClient + + es_cloud_id = es_cloud_id or get_from_env("es_cloud_id", "ES_CLOUD_ID") + es_api_key = es_api_key or get_from_env("es_api_key", "ES_API_KEY") + + # Connect to Elasticsearch + es_connection = AsyncElasticsearch(cloud_id=es_cloud_id, api_key=es_api_key) + client = MlClient(es_connection) + return cls(client, model_id, input_field=input_field) + + @classmethod + def from_es_connection( + cls, + model_id: str, + es_connection: AsyncElasticsearch, + input_field: str = "text_field", + ) -> AsyncElasticsearchEmbeddings: + """ + Instantiate embeddings from an existing Elasticsearch connection. + + This method provides a way to create an instance of the ElasticsearchEmbeddings + class using an existing Elasticsearch connection. The connection object is used + to create an MlClient, which is then used to initialize the + ElasticsearchEmbeddings instance. + + Args: + model_id (str): The model_id of the model deployed in the Elasticsearch cluster. + es_connection (elasticsearch.Elasticsearch): An existing Elasticsearch + connection object. input_field (str, optional): The name of the key for the + input text field in the document. Defaults to 'text_field'. + + Returns: + ElasticsearchEmbeddings: An instance of the ElasticsearchEmbeddings class. + + Example: + .. code-block:: python + + from elasticsearch import Elasticsearch + + from langchain_elasticsearch.embeddings import ElasticsearchEmbeddings + + # Define the model ID and input field name (if different from default) + model_id = "your_model_id" + # Optional, only if different from 'text_field' + input_field = "your_input_field" + + # Create Elasticsearch connection + es_connection = Elasticsearch( + hosts=["localhost:9200"], http_auth=("user", "password") + ) + + # Instantiate ElasticsearchEmbeddings using the existing connection + embeddings = ElasticsearchEmbeddings.from_es_connection( + model_id, + es_connection, + input_field=input_field, + ) + + documents = [ + "This is an example document.", + "Another example document to generate embeddings for.", + ] + embeddings_generator.embed_documents(documents) + """ + from elasticsearch._async.client.ml import MlClient + + # Create an MlClient from the given Elasticsearch connection + client = MlClient(es_connection) + + # Return a new instance of the ElasticsearchEmbeddings class with + # the MlClient, model_id, and input_field + return cls(client, model_id, input_field=input_field) + + async def _embedding_func(self, texts: List[str]) -> List[List[float]]: + """ + Generate embeddings for the given texts using the Elasticsearch model. + + Args: + texts (List[str]): A list of text strings to generate embeddings for. + + Returns: + List[List[float]]: A list of embeddings, one for each text in the input + list. + """ + response = await self.client.infer_trained_model( + model_id=self.model_id, docs=[{self.input_field: text} for text in texts] + ) + + embeddings = [doc["predicted_value"] for doc in response["inference_results"]] + return embeddings + + async def aembed_documents(self, texts: List[str]) -> List[List[float]]: + """ + Generate embeddings for a list of documents. + + Args: + texts (List[str]): A list of document text strings to generate embeddings + for. + + Returns: + List[List[float]]: A list of embeddings, one for each document in the input + list. + """ + return await self._embedding_func(texts) + + async def aembed_query(self, text: str) -> List[float]: + """ + Generate an embedding for a single query text. + + Args: + text (str): The query text to generate an embedding for. + + Returns: + List[float]: The embedding for the input query text. + """ + return (await self._embedding_func([text]))[0] + + +class AsyncEmbeddingServiceAdapter(AsyncEmbeddingService): + """ + Adapter for LangChain Embeddings to support the EmbeddingService interface from + elasticsearch.helpers.vectorstore. + """ + + def __init__(self, langchain_embeddings: Embeddings): + self._langchain_embeddings = langchain_embeddings + + def __eq__(self, other): # type: ignore[no-untyped-def] + if isinstance(other, self.__class__): + return self.__dict__ == other.__dict__ + else: + return False + + async def embed_documents(self, texts: List[str]) -> List[List[float]]: + """ + Generate embeddings for a list of documents. + + Args: + texts (List[str]): A list of document text strings to generate embeddings + for. + + Returns: + List[List[float]]: A list of embeddings, one for each document in the input + list. + """ + return await self._langchain_embeddings.aembed_documents(texts) + + async def embed_query(self, text: str) -> List[float]: + """ + Generate an embedding for a single query text. + + Args: + text (str): The query text to generate an embedding for. + + Returns: + List[float]: The embedding for the input query text. + """ + return await self._langchain_embeddings.aembed_query(text) diff --git a/libs/elasticsearch/langchain_elasticsearch/_async/retrievers.py b/libs/elasticsearch/langchain_elasticsearch/_async/retrievers.py new file mode 100644 index 0000000..487ee10 --- /dev/null +++ b/libs/elasticsearch/langchain_elasticsearch/_async/retrievers.py @@ -0,0 +1,118 @@ +import logging +from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Union, cast + +from elasticsearch import AsyncElasticsearch +from langchain_core.callbacks import AsyncCallbackManagerForRetrieverRun +from langchain_core.documents import Document +from langchain_core.retrievers import BaseRetriever + +from langchain_elasticsearch._utilities import async_with_user_agent_header +from langchain_elasticsearch.client import create_async_elasticsearch_client + +logger = logging.getLogger(__name__) + + +class AsyncElasticsearchRetriever(BaseRetriever): + """ + Elasticsearch retriever + + Args: + es_client: Elasticsearch client connection. Alternatively you can use the + `from_es_params` method with parameters to initialize the client. + index_name: The name of the index to query. Can also be a list of names. + body_func: Function to create an Elasticsearch DSL query body from a search + string. The returned query body must fit what you would normally send in a + POST request the the _search endpoint. If applicable, it also includes + parameters the `size` parameter etc. + content_field: The document field name that contains the page content. If + multiple indices are queried, specify a dict {index_name: field_name} here. + document_mapper: Function to map Elasticsearch hits to LangChain Documents. + + For synchronous applications, use the ``ElasticsearchRetriever`` class. + For asyhchronous applications, use the ``AsyncElasticsearchRetriever`` class. + """ + + es_client: AsyncElasticsearch + index_name: Union[str, Sequence[str]] + body_func: Callable[[str], Dict] + content_field: Optional[Union[str, Mapping[str, str]]] = None + document_mapper: Optional[Callable[[Mapping], Document]] = None + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + + if self.content_field is None and self.document_mapper is None: + raise ValueError("One of content_field or document_mapper must be defined.") + if self.content_field is not None and self.document_mapper is not None: + raise ValueError( + "Both content_field and document_mapper are defined. " + "Please provide only one." + ) + + if not self.document_mapper: + if isinstance(self.content_field, str): + self.document_mapper = self._single_field_mapper + elif isinstance(self.content_field, Mapping): + self.document_mapper = self._multi_field_mapper + else: + raise ValueError( + "unknown type for content_field, expected string or dict." + ) + + self.es_client = async_with_user_agent_header(self.es_client, "langchain-py-r") + + @classmethod + def from_es_params( + cls, + index_name: Union[str, Sequence[str]], + body_func: Callable[[str], Dict], + content_field: Optional[Union[str, Mapping[str, str]]] = None, + document_mapper: Optional[Callable[[Mapping], Document]] = None, + url: Optional[str] = None, + cloud_id: Optional[str] = None, + api_key: Optional[str] = None, + username: Optional[str] = None, + password: Optional[str] = None, + params: Optional[Dict[str, Any]] = None, + ) -> "AsyncElasticsearchRetriever": + client = None + try: + client = create_async_elasticsearch_client( + url=url, + cloud_id=cloud_id, + api_key=api_key, + username=username, + password=password, + params=params, + ) + except Exception as err: + logger.error(f"Error connecting to Elasticsearch: {err}") + raise err + + return cls( + es_client=client, + index_name=index_name, + body_func=body_func, + content_field=content_field, + document_mapper=document_mapper, + ) + + async def _aget_relevant_documents( + self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun + ) -> List[Document]: + if not self.es_client or not self.document_mapper: + raise ValueError("faulty configuration") # should not happen + + body = self.body_func(query) + results = await self.es_client.search(index=self.index_name, body=body) + return [self.document_mapper(hit) for hit in results["hits"]["hits"]] + + def _single_field_mapper(self, hit: Mapping[str, Any]) -> Document: + content = hit["_source"].pop(self.content_field) + return Document(page_content=content, metadata=hit) + + def _multi_field_mapper(self, hit: Mapping[str, Any]) -> Document: + self.content_field = cast(Mapping, self.content_field) + field = self.content_field[hit["_index"]] + content = hit["_source"].pop(field) + return Document(page_content=content, metadata=hit) diff --git a/libs/elasticsearch/langchain_elasticsearch/_async/vectorstores.py b/libs/elasticsearch/langchain_elasticsearch/_async/vectorstores.py new file mode 100644 index 0000000..b07a34a --- /dev/null +++ b/libs/elasticsearch/langchain_elasticsearch/_async/vectorstores.py @@ -0,0 +1,860 @@ +import logging +from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Tuple, Union + +from elasticsearch import AsyncElasticsearch +from elasticsearch.helpers.vectorstore import ( + AsyncBM25Strategy, + AsyncDenseVectorScriptScoreStrategy, + AsyncDenseVectorStrategy, + AsyncRetrievalStrategy, + AsyncSparseVectorStrategy, + DistanceMetric, +) +from elasticsearch.helpers.vectorstore import AsyncVectorStore as EVectorStore +from langchain_core.documents import Document +from langchain_core.embeddings import Embeddings +from langchain_core.vectorstores import VectorStore + +from langchain_elasticsearch._utilities import ( + ApproxRetrievalStrategy, + BaseRetrievalStrategy, + BM25RetrievalStrategy, + DistanceStrategy, + ExactRetrievalStrategy, + SparseRetrievalStrategy, + _hits_to_docs_scores, + user_agent, +) +from langchain_elasticsearch.client import create_async_elasticsearch_client +from langchain_elasticsearch.embeddings import AsyncEmbeddingServiceAdapter + +logger = logging.getLogger(__name__) + + +def _convert_retrieval_strategy( + langchain_strategy: BaseRetrievalStrategy, + distance: Optional[DistanceStrategy] = None, +) -> AsyncRetrievalStrategy: + if isinstance(langchain_strategy, ApproxRetrievalStrategy): + if distance is None: + raise ValueError( + "ApproxRetrievalStrategy requires a distance strategy to be provided." + ) + return AsyncDenseVectorStrategy( + distance=DistanceMetric[distance], + model_id=langchain_strategy.query_model_id, + hybrid=( + False + if langchain_strategy.hybrid is None + else langchain_strategy.hybrid + ), + rrf=False if langchain_strategy.rrf is None else langchain_strategy.rrf, + ) + elif isinstance(langchain_strategy, ExactRetrievalStrategy): + if distance is None: + raise ValueError( + "ExactRetrievalStrategy requires a distance strategy to be provided." + ) + return AsyncDenseVectorScriptScoreStrategy(distance=DistanceMetric[distance]) + elif isinstance(langchain_strategy, SparseRetrievalStrategy): + return AsyncSparseVectorStrategy(langchain_strategy.model_id) + elif isinstance(langchain_strategy, BM25RetrievalStrategy): + return AsyncBM25Strategy(k1=langchain_strategy.k1, b=langchain_strategy.b) + else: + raise TypeError( + f"Strategy {langchain_strategy} not supported. To provide a " + f"custom strategy, please subclass {AsyncRetrievalStrategy}." + ) + + +class AsyncElasticsearchStore(VectorStore): + """`Elasticsearch` vector store. + + Setup: + Install ``langchain_elasticsearch`` and running the Elasticsearch docker container. + + .. code-block:: bash + + pip install -qU langchain_elasticsearch + docker run -p 9200:9200 \ + -e "discovery.type=single-node" \ + -e "xpack.security.enabled=false" \ + -e "xpack.security.http.ssl.enabled=false" \ + docker.elastic.co/elasticsearch/elasticsearch:8.12.1 + + Key init args — indexing params: + index_name: str + Name of the index to create. + embedding: Embeddings + Embedding function to use. + + Key init args — client params: + es_connection: Optional[Elasticsearch] + Pre-existing Elasticsearch connection. + es_url: Optional[str] + URL of the Elasticsearch instance to connect to. + es_cloud_id: Optional[str] + Cloud ID of the Elasticsearch instance to connect to. + es_user: Optional[str] + Username to use when connecting to Elasticsearch. + es_password: Optional[str] + Password to use when connecting to Elasticsearch. + es_api_key: Optional[str] + API key to use when connecting to Elasticsearch. + + Instantiate: + .. code-block:: python + + from langchain_elasticsearch import ElasticsearchStore + from langchain_openai import OpenAIEmbeddings + + vector_store = ElasticsearchStore( + index_name="langchain-demo", + embedding=OpenAIEmbeddings(), + es_url="http://localhost:9200", + ) + + If you want to use a cloud hosted Elasticsearch instance, you can pass in the + cloud_id argument instead of the es_url argument. + + Instantiate from cloud: + .. code-block:: python + + from langchain_elasticsearch.vectorstores import ElasticsearchStore + from langchain_openai import OpenAIEmbeddings + + store = ElasticsearchStore( + embedding=OpenAIEmbeddings(), + index_name="langchain-demo", + es_cloud_id="" + es_user="elastic", + es_password="" + ) + + You can also connect to an existing Elasticsearch instance by passing in a + pre-existing Elasticsearch connection via the es_connection argument. + + Instantiate from existing connection: + .. code-block:: python + + from langchain_elasticsearch.vectorstores import ElasticsearchStore + from langchain_openai import OpenAIEmbeddings + + from elasticsearch import Elasticsearch + + es_connection = Elasticsearch("http://localhost:9200") + + store = ElasticsearchStore( + embedding=OpenAIEmbeddings(), + index_name="langchain-demo", + es_connection=es_connection + ) + + Add Documents: + .. code-block:: python + + from langchain_core.documents import Document + + document_1 = Document(page_content="foo", metadata={"baz": "bar"}) + document_2 = Document(page_content="thud", metadata={"bar": "baz"}) + document_3 = Document(page_content="i will be deleted :(") + + documents = [document_1, document_2, document_3] + ids = ["1", "2", "3"] + vector_store.add_documents(documents=documents, ids=ids) + + Delete Documents: + .. code-block:: python + + vector_store.delete(ids=["3"]) + + Search: + .. code-block:: python + + results = vector_store.similarity_search(query="thud",k=1) + for doc in results: + print(f"* {doc.page_content} [{doc.metadata}]") + + .. code-block:: python + + * thud [{'bar': 'baz'}] + + Search with filter: + .. code-block:: python + + results = vector_store.similarity_search(query="thud",k=1,filter=[{"term": {"metadata.bar.keyword": "baz"}}]) + for doc in results: + print(f"* {doc.page_content} [{doc.metadata}]") + + .. code-block:: python + + * thud [{'bar': 'baz'}] + + Search with score: + .. code-block:: python + + results = vector_store.similarity_search_with_score(query="qux",k=1) + for doc, score in results: + print(f"* [SIM={score:3f}] {doc.page_content} [{doc.metadata}]") + + .. code-block:: python + + * [SIM=0.916092] foo [{'baz': 'bar'}] + + Async: + .. code-block:: python + + from langchain_elasticsearch import AsyncElasticsearchStore + + vector_store = AsyncElasticsearchStore(...) + + # add documents + await vector_store.aadd_documents(documents=documents, ids=ids) + + # delete documents + await vector_store.adelete(ids=["3"]) + + # search + results = vector_store.asimilarity_search(query="thud",k=1) + + # search with score + results = await vector_store.asimilarity_search_with_score(query="qux",k=1) + for doc,score in results: + print(f"* [SIM={score:3f}] {doc.page_content} [{doc.metadata}]") + + .. code-block:: python + + * [SIM=0.916092] foo [{'baz': 'bar'}] + + Use as Retriever: + + .. code-block:: bash + + pip install "elasticsearch[vectorstore_mmr]" + + .. code-block:: python + + retriever = vector_store.as_retriever( + search_type="mmr", + search_kwargs={"k": 1, "fetch_k": 2, "lambda_mult": 0.5}, + ) + retriever.invoke("thud") + + .. code-block:: python + + [Document(metadata={'bar': 'baz'}, page_content='thud')] + + **Advanced Uses:** + + ElasticsearchStore by default uses the ApproxRetrievalStrategy, which uses the + HNSW algorithm to perform approximate nearest neighbor search. This is the + fastest and most memory efficient algorithm. + + If you want to use the Brute force / Exact strategy for searching vectors, you + can pass in the ExactRetrievalStrategy to the ElasticsearchStore constructor. + + Use ExactRetrievalStrategy: + .. code-block:: python + + from langchain_elasticsearch.vectorstores import ElasticsearchStore + from langchain_openai import OpenAIEmbeddings + + store = ElasticsearchStore( + embedding=OpenAIEmbeddings(), + index_name="langchain-demo", + es_url="http://localhost:9200", + strategy=ElasticsearchStore.ExactRetrievalStrategy() + ) + + Both strategies require that you know the similarity metric you want to use + when creating the index. The default is cosine similarity, but you can also + use dot product or euclidean distance. + + Use dot product similarity: + .. code-block:: python + + from langchain_elasticsearch.vectorstores import ElasticsearchStore + from langchain_openai import OpenAIEmbeddings + from langchain_community.vectorstores.utils import DistanceStrategy + + store = ElasticsearchStore( + "langchain-demo", + embedding=OpenAIEmbeddings(), + es_url="http://localhost:9200", + distance_strategy="DOT_PRODUCT" + ) + + """ # noqa: E501 + + def __init__( + self, + index_name: str, + *, + embedding: Optional[Embeddings] = None, + es_connection: Optional[AsyncElasticsearch] = None, + es_url: Optional[str] = None, + es_cloud_id: Optional[str] = None, + es_user: Optional[str] = None, + es_api_key: Optional[str] = None, + es_password: Optional[str] = None, + vector_query_field: str = "vector", + query_field: str = "text", + distance_strategy: Optional[ + Literal[ + DistanceStrategy.COSINE, + DistanceStrategy.DOT_PRODUCT, + DistanceStrategy.EUCLIDEAN_DISTANCE, + DistanceStrategy.MAX_INNER_PRODUCT, + ] + ] = None, + strategy: Union[ + BaseRetrievalStrategy, AsyncRetrievalStrategy + ] = ApproxRetrievalStrategy(), + es_params: Optional[Dict[str, Any]] = None, + ): + if isinstance(strategy, BaseRetrievalStrategy): + strategy = _convert_retrieval_strategy( + strategy, distance=distance_strategy or DistanceStrategy.COSINE + ) + + embedding_service = None + if embedding: + embedding_service = AsyncEmbeddingServiceAdapter(embedding) + + if not es_connection: + es_connection = create_async_elasticsearch_client( + url=es_url, + cloud_id=es_cloud_id, + api_key=es_api_key, + username=es_user, + password=es_password, + params=es_params, + ) + + self._store = EVectorStore( + client=es_connection, + index=index_name, + retrieval_strategy=strategy, + embedding_service=embedding_service, + text_field=query_field, + vector_field=vector_query_field, + user_agent=user_agent("langchain-py-vs"), + ) + + self.embedding = embedding + self.client = self._store.client + self._embedding_service = embedding_service + self.query_field = query_field + self.vector_query_field = vector_query_field + + async def aclose(self) -> None: + await self._store.close() + + @property + def embeddings(self) -> Optional[Embeddings]: + return self.embedding + + @staticmethod + def connect_to_elasticsearch( + *, + es_url: Optional[str] = None, + cloud_id: Optional[str] = None, + api_key: Optional[str] = None, + username: Optional[str] = None, + password: Optional[str] = None, + es_params: Optional[Dict[str, Any]] = None, + ) -> AsyncElasticsearch: + return create_async_elasticsearch_client( + url=es_url, + cloud_id=cloud_id, + api_key=api_key, + username=username, + password=password, + params=es_params, + ) + + async def asimilarity_search( + self, + query: str, + k: int = 4, + fetch_k: int = 50, + filter: Optional[List[dict]] = None, + *, + custom_query: Optional[ + Callable[[Dict[str, Any], Optional[str]], Dict[str, Any]] + ] = None, + doc_builder: Optional[Callable[[Dict], Document]] = None, + **kwargs: Any, + ) -> List[Document]: + """Return Elasticsearch documents most similar to query. + + Args: + query: Text to look up documents similar to. + k: Number of Documents to return. Defaults to 4. + fetch_k (int): Number of Documents to fetch to pass to knn num_candidates. + filter: Array of Elasticsearch filter clauses to apply to the query. + + Returns: + List of Documents most similar to the query, + in descending order of similarity. + """ + hits = await self._store.search( + query=query, + k=k, + num_candidates=fetch_k, + filter=filter, + custom_query=custom_query, + ) + docs = _hits_to_docs_scores( + hits=hits, + content_field=self.query_field, + doc_builder=doc_builder, + ) + return [doc for doc, _score in docs] + + async def amax_marginal_relevance_search( + self, + query: str, + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + fields: Optional[List[str]] = None, + *, + custom_query: Optional[ + Callable[[Dict[str, Any], Optional[str]], Dict[str, Any]] + ] = None, + doc_builder: Optional[Callable[[Dict], Document]] = None, + **kwargs: Any, + ) -> List[Document]: + """Return docs selected using the maximal marginal relevance. + + Maximal marginal relevance optimizes for similarity to query AND diversity + among selected documents. + + Args: + query (str): Text to look up documents similar to. + k (int): Number of Documents to return. Defaults to 4. + fetch_k (int): Number of Documents to fetch to pass to MMR algorithm. + lambda_mult (float): Number between 0 and 1 that determines the degree + of diversity among the results with 0 corresponding + to maximum diversity and 1 to minimum diversity. + Defaults to 0.5. + fields: Other fields to get from elasticsearch source. These fields + will be added to the document metadata. + + Returns: + List[Document]: A list of Documents selected by maximal marginal relevance. + """ + if self._embedding_service is None: + raise ValueError( + "maximal marginal relevance search requires an embedding service." + ) + + hits = await self._store.max_marginal_relevance_search( + embedding_service=self._embedding_service, + query=query, + vector_field=self.vector_query_field, + k=k, + num_candidates=fetch_k, + lambda_mult=lambda_mult, + fields=fields, + custom_query=custom_query, + ) + + docs_scores = _hits_to_docs_scores( + hits=hits, + content_field=self.query_field, + fields=fields, + doc_builder=doc_builder, + ) + + return [doc for doc, _score in docs_scores] + + @staticmethod + def _identity_fn(score: float) -> float: + return score + + def _select_relevance_score_fn(self) -> Callable[[float], float]: + """ + The 'correct' relevance function + may differ depending on a few things, including: + - the distance / similarity metric used by the VectorStore + - the scale of your embeddings (OpenAI's are unit normed. Many others are not!) + - embedding dimensionality + - etc. + + Vectorstores should define their own selection based method of relevance. + """ + # All scores from Elasticsearch are already normalized similarities: + # https://www.elastic.co/guide/en/elasticsearch/reference/current/dense-vector.html#dense-vector-params + return self._identity_fn + + async def asimilarity_search_with_score( + self, + query: str, + k: int = 4, + filter: Optional[List[dict]] = None, + *, + custom_query: Optional[ + Callable[[Dict[str, Any], Optional[str]], Dict[str, Any]] + ] = None, + doc_builder: Optional[Callable[[Dict], Document]] = None, + **kwargs: Any, + ) -> List[Tuple[Document, float]]: + """Return Elasticsearch documents most similar to query, along with scores. + + Args: + query: Text to look up documents similar to. + k: Number of Documents to return. Defaults to 4. + filter: Array of Elasticsearch filter clauses to apply to the query. + + Returns: + List of Documents most similar to the query and score for each + """ + if ( + isinstance(self._store.retrieval_strategy, AsyncDenseVectorStrategy) + and self._store.retrieval_strategy.hybrid + ): + raise ValueError("scores are currently not supported in hybrid mode") + + hits = await self._store.search( + query=query, k=k, filter=filter, custom_query=custom_query + ) + return _hits_to_docs_scores( + hits=hits, + content_field=self.query_field, + doc_builder=doc_builder, + ) + + async def asimilarity_search_by_vector_with_relevance_scores( + self, + embedding: List[float], + k: int = 4, + filter: Optional[List[Dict]] = None, + *, + custom_query: Optional[ + Callable[[Dict[str, Any], Optional[str]], Dict[str, Any]] + ] = None, + doc_builder: Optional[Callable[[Dict], Document]] = None, + **kwargs: Any, + ) -> List[Tuple[Document, float]]: + """Return Elasticsearch documents most similar to query, along with scores. + + Args: + embedding: Embedding to look up documents similar to. + k: Number of Documents to return. Defaults to 4. + filter: Array of Elasticsearch filter clauses to apply to the query. + + Returns: + List of Documents most similar to the embedding and score for each + """ + if ( + isinstance(self._store.retrieval_strategy, AsyncDenseVectorStrategy) + and self._store.retrieval_strategy.hybrid + ): + raise ValueError("scores are currently not supported in hybrid mode") + + hits = await self._store.search( + query=None, + query_vector=embedding, + k=k, + filter=filter, + custom_query=custom_query, + ) + return _hits_to_docs_scores( + hits=hits, + content_field=self.query_field, + doc_builder=doc_builder, + ) + + async def adelete( + self, + ids: Optional[List[str]] = None, + refresh_indices: Optional[bool] = True, + **kwargs: Any, + ) -> Optional[bool]: + """Delete documents from the Elasticsearch index. + + Args: + ids: List of ids of documents to delete. + refresh_indices: Whether to refresh the index + after deleting documents. Defaults to True. + """ + if ids is None: + raise ValueError("please specify some IDs") + + return await self._store.delete( + ids=ids, refresh_indices=refresh_indices or False + ) + + async def aadd_texts( + self, + texts: Iterable[str], + metadatas: Optional[List[Dict[Any, Any]]] = None, + ids: Optional[List[str]] = None, + refresh_indices: bool = True, + create_index_if_not_exists: bool = True, + bulk_kwargs: Optional[Dict] = None, + **kwargs: Any, + ) -> List[str]: + """Run more texts through the embeddings and add to the store. + + Args: + texts: Iterable of strings to add to the store. + metadatas: Optional list of metadatas associated with the texts. + ids: Optional list of ids to associate with the texts. + refresh_indices: Whether to refresh the Elasticsearch indices + after adding the texts. + create_index_if_not_exists: Whether to create the Elasticsearch + index if it doesn't already exist. + *bulk_kwargs: Additional arguments to pass to Elasticsearch bulk. + - chunk_size: Optional. Number of texts to add to the + index at a time. Defaults to 500. + + Returns: + List of ids from adding the texts into the store. + """ + return await self._store.add_texts( + texts=list(texts), + metadatas=metadatas, + ids=ids, + refresh_indices=refresh_indices, + create_index_if_not_exists=create_index_if_not_exists, + bulk_kwargs=bulk_kwargs, + ) + + async def aadd_embeddings( + self, + text_embeddings: Iterable[Tuple[str, List[float]]], + metadatas: Optional[List[dict]] = None, + ids: Optional[List[str]] = None, + refresh_indices: bool = True, + create_index_if_not_exists: bool = True, + bulk_kwargs: Optional[Dict] = None, + **kwargs: Any, + ) -> List[str]: + """Add the given texts and embeddings to the store. + + Args: + text_embeddings: Iterable pairs of string and embedding to + add to the store. + metadatas: Optional list of metadatas associated with the texts. + ids: Optional list of unique IDs. + refresh_indices: Whether to refresh the Elasticsearch indices + after adding the texts. + create_index_if_not_exists: Whether to create the Elasticsearch + index if it doesn't already exist. + *bulk_kwargs: Additional arguments to pass to Elasticsearch bulk. + - chunk_size: Optional. Number of texts to add to the + index at a time. Defaults to 500. + + Returns: + List of ids from adding the texts into the store. + """ + texts, embeddings = zip(*text_embeddings) + return await self._store.add_texts( + texts=list(texts), + metadatas=metadatas, + vectors=list(embeddings), + ids=ids, + refresh_indices=refresh_indices, + create_index_if_not_exists=create_index_if_not_exists, + bulk_kwargs=bulk_kwargs, + ) + + @classmethod + async def afrom_texts( + cls, + texts: List[str], + embedding: Optional[Embeddings] = None, + metadatas: Optional[List[Dict[str, Any]]] = None, + bulk_kwargs: Optional[Dict] = None, + **kwargs: Any, + ) -> "AsyncElasticsearchStore": + """Construct ElasticsearchStore wrapper from raw documents. + + Example: + .. code-block:: python + + from langchain_elasticsearch.vectorstores import ElasticsearchStore + from langchain_openai import OpenAIEmbeddings + + db = ElasticsearchStore.from_texts( + texts, + // embeddings optional if using + // a strategy that doesn't require inference + embeddings, + index_name="langchain-demo", + es_url="http://localhost:9200" + ) + + Args: + texts: List of texts to add to the Elasticsearch index. + embedding: Embedding function to use to embed the texts. + metadatas: Optional list of metadatas associated with the texts. + index_name: Name of the Elasticsearch index to create. + es_url: URL of the Elasticsearch instance to connect to. + cloud_id: Cloud ID of the Elasticsearch instance to connect to. + es_user: Username to use when connecting to Elasticsearch. + es_password: Password to use when connecting to Elasticsearch. + es_api_key: API key to use when connecting to Elasticsearch. + es_connection: Optional pre-existing Elasticsearch connection. + vector_query_field: Optional. Name of the field to + store the embedding vectors in. + query_field: Optional. Name of the field to store the texts in. + distance_strategy: Optional. Name of the distance + strategy to use. Defaults to "COSINE". + can be one of "COSINE", + "EUCLIDEAN_DISTANCE", "DOT_PRODUCT", + "MAX_INNER_PRODUCT". + bulk_kwargs: Optional. Additional arguments to pass to + Elasticsearch bulk. + """ + + index_name = kwargs.get("index_name") + if index_name is None: + raise ValueError("Please provide an index_name.") + + elasticsearchStore = cls(embedding=embedding, **kwargs) + + # Encode the provided texts and add them to the newly created index. + await elasticsearchStore.aadd_texts( + texts=texts, metadatas=metadatas, bulk_kwargs=bulk_kwargs + ) + + return elasticsearchStore + + @classmethod + async def afrom_documents( + cls, + documents: List[Document], + embedding: Optional[Embeddings] = None, + bulk_kwargs: Optional[Dict] = None, + **kwargs: Any, + ) -> "AsyncElasticsearchStore": + """Construct ElasticsearchStore wrapper from documents. + + Example: + .. code-block:: python + + from langchain_elasticsearch.vectorstores import ElasticsearchStore + from langchain_openai import OpenAIEmbeddings + + db = ElasticsearchStore.from_documents( + texts, + embeddings, + index_name="langchain-demo", + es_url="http://localhost:9200" + ) + + Args: + texts: List of texts to add to the Elasticsearch index. + embedding: Embedding function to use to embed the texts. + Do not provide if using a strategy + that doesn't require inference. + metadatas: Optional list of metadatas associated with the texts. + index_name: Name of the Elasticsearch index to create. + es_url: URL of the Elasticsearch instance to connect to. + cloud_id: Cloud ID of the Elasticsearch instance to connect to. + es_user: Username to use when connecting to Elasticsearch. + es_password: Password to use when connecting to Elasticsearch. + es_api_key: API key to use when connecting to Elasticsearch. + es_connection: Optional pre-existing Elasticsearch connection. + vector_query_field: Optional. Name of the field + to store the embedding vectors in. + query_field: Optional. Name of the field to store the texts in. + bulk_kwargs: Optional. Additional arguments to pass to + Elasticsearch bulk. + """ + + index_name = kwargs.get("index_name") + if index_name is None: + raise ValueError("Please provide an index_name.") + + elasticsearchStore = cls(embedding=embedding, **kwargs) + + # Encode the provided texts and add them to the newly created index. + await elasticsearchStore.aadd_documents(documents, bulk_kwargs=bulk_kwargs) + + return elasticsearchStore + + @staticmethod + def ExactRetrievalStrategy() -> "ExactRetrievalStrategy": + """Used to perform brute force / exact + nearest neighbor search via script_score.""" + return ExactRetrievalStrategy() + + @staticmethod + def ApproxRetrievalStrategy( + query_model_id: Optional[str] = None, + hybrid: Optional[bool] = False, + rrf: Optional[Union[dict, bool]] = True, + ) -> "ApproxRetrievalStrategy": + """Used to perform approximate nearest neighbor search + using the HNSW algorithm. + + At build index time, this strategy will create a + dense vector field in the index and store the + embedding vectors in the index. + + At query time, the text will either be embedded using the + provided embedding function or the query_model_id + will be used to embed the text using the model + deployed to Elasticsearch. + + if query_model_id is used, do not provide an embedding function. + + Args: + query_model_id: Optional. ID of the model to use to + embed the query text within the stack. Requires + embedding model to be deployed to Elasticsearch. + hybrid: Optional. If True, will perform a hybrid search + using both the knn query and a text query. + Defaults to False. + rrf: Optional. rrf is Reciprocal Rank Fusion. + When `hybrid` is True, + and `rrf` is True, then rrf: {}. + and `rrf` is False, then rrf is omitted. + and isinstance(rrf, dict) is True, then pass in the dict values. + rrf could be passed for adjusting 'rank_constant' and 'window_size'. + """ + return ApproxRetrievalStrategy( + query_model_id=query_model_id, hybrid=hybrid, rrf=rrf + ) + + @staticmethod + def SparseVectorRetrievalStrategy( + model_id: Optional[str] = None, + ) -> "SparseRetrievalStrategy": + """Used to perform sparse vector search via text_expansion. + Used for when you want to use ELSER model to perform document search. + + At build index time, this strategy will create a pipeline that + will embed the text using the ELSER model and store the + resulting tokens in the index. + + At query time, the text will be embedded using the ELSER + model and the resulting tokens will be used to + perform a text_expansion query. + + Args: + model_id: Optional. Default is ".elser_model_1". + ID of the model to use to embed the query text + within the stack. Requires embedding model to be + deployed to Elasticsearch. + """ + return SparseRetrievalStrategy(model_id=model_id) + + @staticmethod + def BM25RetrievalStrategy( + k1: Union[float, None] = None, b: Union[float, None] = None + ) -> "BM25RetrievalStrategy": + """Used to apply BM25 without vector search. + + Args: + k1: Optional. This corresponds to the BM25 parameter, k1. Default is None, + which uses the default setting of Elasticsearch. + b: Optional. This corresponds to the BM25 parameter, b. Default is None, + which uses the default setting of Elasticsearch. + """ + return BM25RetrievalStrategy(k1=k1, b=b) diff --git a/libs/elasticsearch/langchain_elasticsearch/_sync/cache.py b/libs/elasticsearch/langchain_elasticsearch/_sync/cache.py new file mode 100644 index 0000000..9620dba --- /dev/null +++ b/libs/elasticsearch/langchain_elasticsearch/_sync/cache.py @@ -0,0 +1,425 @@ +import base64 +import hashlib +import logging +from datetime import datetime +from functools import cached_property +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Iterable, + Iterator, + List, + Optional, + Sequence, + Tuple, +) + +from elasticsearch import Elasticsearch, exceptions, helpers +from elasticsearch.helpers import BulkIndexError +from langchain_core.caches import RETURN_VAL_TYPE, BaseCache +from langchain_core.load import dumps, loads +from langchain_core.stores import ByteStore + +from langchain_elasticsearch.client import create_elasticsearch_client + +if TYPE_CHECKING: + from elasticsearch import Elasticsearch + +logger = logging.getLogger(__name__) + + +def _manage_cache_index( + es_client: Elasticsearch, index_name: str, mapping: Dict[str, Any] +) -> bool: + """Write or update an index or alias according to the default mapping""" + if es_client.indices.exists_alias(name=index_name): + es_client.indices.put_mapping(index=index_name, body=mapping["mappings"]) + return True + + elif not es_client.indices.exists(index=index_name): + logger.debug(f"Creating new Elasticsearch index: {index_name}") + es_client.indices.create(index=index_name, body=mapping) + return False + + return False + + +class ElasticsearchCache(BaseCache): + """An Elasticsearch cache integration for LLMs. + + For synchronous applications, use the ``ElasticsearchCache`` class. + For asyhchronous applications, use the ``AsyncElasticsearchCache`` class. + """ + + def __init__( + self, + index_name: str, + store_input: bool = True, + store_input_params: bool = True, + metadata: Optional[Dict[str, Any]] = None, + *, + es_url: Optional[str] = None, + es_cloud_id: Optional[str] = None, + es_user: Optional[str] = None, + es_api_key: Optional[str] = None, + es_password: Optional[str] = None, + es_params: Optional[Dict[str, Any]] = None, + ): + """ + Initialize the Elasticsearch cache store by specifying the index/alias + to use and determining which additional information (like input, input + parameters, and any other metadata) should be stored in the cache. + + Args: + index_name (str): The name of the index or the alias to use for the cache. + If they do not exist an index is created, + according to the default mapping defined by the `mapping` property. + store_input (bool): Whether to store the LLM input in the cache, i.e., + the input prompt. Default to True. + store_input_params (bool): Whether to store the input parameters in the + cache, i.e., the LLM parameters used to generate the LLM response. + Default to True. + metadata (Optional[dict]): Additional metadata to store in the cache, + for filtering purposes. This must be JSON serializable in an + Elasticsearch document. Default to None. + es_url: URL of the Elasticsearch instance to connect to. + es_cloud_id: Cloud ID of the Elasticsearch instance to connect to. + es_user: Username to use when connecting to Elasticsearch. + es_password: Password to use when connecting to Elasticsearch. + es_api_key: API key to use when connecting to Elasticsearch. + es_params: Other parameters for the Elasticsearch client. + """ + + self._index_name = index_name + self._store_input = store_input + self._store_input_params = store_input_params + self._metadata = metadata + self._es_client = create_elasticsearch_client( + url=es_url, + cloud_id=es_cloud_id, + api_key=es_api_key, + username=es_user, + password=es_password, + params=es_params, + ) + self._is_alias: Optional[bool] = None + + def is_alias(self) -> bool: + if self._is_alias is None: + self._is_alias = _manage_cache_index( + self._es_client, + self._index_name, + self.mapping, + ) + return self._is_alias # type: ignore[return-value] + + @cached_property + def mapping(self) -> Dict[str, Any]: + """Get the default mapping for the index.""" + return { + "mappings": { + "properties": { + "llm_output": {"type": "text", "index": False}, + "llm_params": {"type": "text", "index": False}, + "llm_input": {"type": "text", "index": False}, + "metadata": {"type": "object"}, + "timestamp": {"type": "date"}, + } + } + } + + @staticmethod + def _key(prompt: str, llm_string: str) -> str: + """Generate a key for the cache store.""" + return hashlib.md5((prompt + llm_string).encode()).hexdigest() + + def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]: + """Look up based on prompt and llm_string.""" + cache_key = self._key(prompt, llm_string) + if self.is_alias(): + # get the latest record according to its writing date, in order to + # address cases where multiple indices have a doc with the same id + result = self._es_client.search( + index=self._index_name, + body={ + "query": {"term": {"_id": cache_key}}, + "sort": {"timestamp": {"order": "asc"}}, + }, + source_includes=["llm_output"], + ) + if result["hits"]["total"]["value"] > 0: + record = result["hits"]["hits"][0] + else: + return None + else: + try: + record = self._es_client.get( + index=self._index_name, id=cache_key, source=["llm_output"] + ) + except exceptions.NotFoundError: + return None + return [loads(item) for item in record["_source"]["llm_output"]] + + def build_document( + self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE + ) -> Dict[str, Any]: + """Build the Elasticsearch document for storing a single LLM interaction""" + body: Dict[str, Any] = { + "llm_output": [dumps(item) for item in return_val], + "timestamp": datetime.now().isoformat(), + } + if self._store_input_params: + body["llm_params"] = llm_string + if self._metadata is not None: + body["metadata"] = self._metadata + if self._store_input: + body["llm_input"] = prompt + return body + + def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None: + """Update based on prompt and llm_string.""" + body = self.build_document(prompt, llm_string, return_val) + self._es_client.index( + index=self._index_name, + id=self._key(prompt, llm_string), + body=body, + require_alias=self.is_alias(), + refresh=True, + ) + + def clear(self, **kwargs: Any) -> None: + """Clear cache.""" + self._es_client.delete_by_query( + index=self._index_name, + body={"query": {"match_all": {}}}, + refresh=True, + wait_for_completion=True, + ) + + +class ElasticsearchEmbeddingsCache(ByteStore): + """An Elasticsearch store for caching embeddings. + + For synchronous applications, use the `ElasticsearchEmbeddingsCache` class. + For asyhchronous applications, use the `AsyncElasticsearchEmbeddingsCache` class. + """ + + def __init__( + self, + index_name: str, + store_input: bool = True, + metadata: Optional[Dict[str, Any]] = None, + namespace: Optional[str] = None, + maximum_duplicates_allowed: int = 1, + *, + es_url: Optional[str] = None, + es_cloud_id: Optional[str] = None, + es_user: Optional[str] = None, + es_api_key: Optional[str] = None, + es_password: Optional[str] = None, + es_params: Optional[Dict[str, Any]] = None, + ): + """ + Initialize the Elasticsearch cache store by specifying the index/alias + to use and determining which additional information (like input, input + parameters, and any other metadata) should be stored in the cache. + Provide a namespace to organize the cache. + + + Args: + index_name (str): The name of the index or the alias to use for the cache. + If they do not exist an index is created, + according to the default mapping defined by the `mapping` property. + store_input (bool): Whether to store the input in the cache. + Default to True. + metadata (Optional[dict]): Additional metadata to store in the cache, + for filtering purposes. This must be JSON serializable in an + Elasticsearch document. Default to None. + namespace (Optional[str]): A namespace to use for the cache. + maximum_duplicates_allowed (int): Defines the maximum number of duplicate + keys permitted. Must be used in scenarios where the same key appears + across multiple indices that share the same alias. Default to 1. + es_url: URL of the Elasticsearch instance to connect to. + es_cloud_id: Cloud ID of the Elasticsearch instance to connect to. + es_user: Username to use when connecting to Elasticsearch. + es_password: Password to use when connecting to Elasticsearch. + es_api_key: API key to use when connecting to Elasticsearch. + es_params: Other parameters for the Elasticsearch client. + """ + self._namespace = namespace + self._maximum_duplicates_allowed = maximum_duplicates_allowed + self._index_name = index_name + self._store_input = store_input + self._metadata = metadata + self._es_client = create_elasticsearch_client( + url=es_url, + cloud_id=es_cloud_id, + api_key=es_api_key, + username=es_user, + password=es_password, + params=es_params, + ) + self._is_alias: Optional[bool] = None + + def is_alias(self) -> bool: + if self._is_alias is None: + self._is_alias = _manage_cache_index( + self._es_client, + self._index_name, + self.mapping, + ) + return self._is_alias # type: ignore[return-value] + + @staticmethod + def encode_vector(data: bytes) -> str: + """Encode the vector data as bytes to as a base64 string.""" + return base64.b64encode(data).decode("utf-8") + + @staticmethod + def decode_vector(data: str) -> bytes: + """Decode the base64 string to vector data as bytes.""" + return base64.b64decode(data) + + @cached_property + def mapping(self) -> Dict[str, Any]: + """Get the default mapping for the index.""" + return { + "mappings": { + "properties": { + "text_input": {"type": "text", "index": False}, + "vector_dump": { + "type": "binary", + "doc_values": False, + }, + "metadata": {"type": "object"}, + "timestamp": {"type": "date"}, + "namespace": {"type": "keyword"}, + } + } + } + + def _key(self, input_text: str) -> str: + """Generate a key for the store.""" + return hashlib.md5(((self._namespace or "") + input_text).encode()).hexdigest() + + @classmethod + def _deduplicate_hits(cls, hits: List[dict]) -> Dict[str, bytes]: + """ + Collapse the results from a search query with multiple indices + returning only the latest version of the documents + """ + map_ids = {} + for hit in sorted( + hits, + key=lambda x: datetime.fromisoformat(x["_source"]["timestamp"]), + reverse=True, + ): + vector_id: str = hit["_id"] + if vector_id not in map_ids: + map_ids[vector_id] = cls.decode_vector(hit["_source"]["vector_dump"]) + + return map_ids + + def mget(self, keys: Sequence[str]) -> List[Optional[bytes]]: + """Get the values associated with the given keys.""" + if not any(keys): + return [] + + cache_keys = [self._key(k) for k in keys] + if self.is_alias(): + try: + results = self._es_client.search( + index=self._index_name, + body={ + "query": {"ids": {"values": cache_keys}}, + "size": len(cache_keys) * self._maximum_duplicates_allowed, + }, + source_includes=["vector_dump", "timestamp"], + ) + + except exceptions.BadRequestError as e: + if "window too large" in ( + e.body.get("error", {}).get("root_cause", [{}])[0].get("reason", "") + ): + logger.warning( + "Exceeded the maximum window size, " + "Reduce the duplicates manually or lower " + "`maximum_duplicate_allowed.`" + ) + raise e + + total_hits = results["hits"]["total"]["value"] + if self._maximum_duplicates_allowed > 1 and total_hits > len(cache_keys): + logger.warning( + f"Deduplicating, found {total_hits} hits for {len(cache_keys)} keys" + ) + map_ids = self._deduplicate_hits(results["hits"]["hits"]) + else: + map_ids = { + r["_id"]: self.decode_vector(r["_source"]["vector_dump"]) + for r in results["hits"]["hits"] + } + + return [map_ids.get(k) for k in cache_keys] + + else: + records = self._es_client.mget( + index=self._index_name, ids=cache_keys, source_includes=["vector_dump"] + ) + return [ + self.decode_vector(r["_source"]["vector_dump"]) if r["found"] else None + for r in records["docs"] + ] + + def build_document(self, text_input: str, vector: bytes) -> Dict[str, Any]: + """Build the Elasticsearch document for storing a single embedding""" + body: Dict[str, Any] = { + "vector_dump": self.encode_vector(vector), + "timestamp": datetime.now().isoformat(), + } + if self._metadata is not None: + body["metadata"] = self._metadata + if self._store_input: + body["text_input"] = text_input + if self._namespace: + body["namespace"] = self._namespace + return body + + def _bulk(self, actions: Iterable[Dict[str, Any]]) -> None: + try: + helpers.bulk( + client=self._es_client, + actions=actions, + index=self._index_name, + require_alias=self.is_alias(), + refresh=True, + ) + except BulkIndexError as e: + first_error = e.errors[0].get("index", {}).get("error", {}) + logger.error(f"First bulk error reason: {first_error.get('reason')}") + raise e + + def mset(self, key_value_pairs: Sequence[Tuple[str, bytes]]) -> None: + """Set the values for the given keys.""" + actions = ( + { + "_op_type": "index", + "_id": self._key(key), + "_source": self.build_document(key, vector), + } + for key, vector in key_value_pairs + ) + self._bulk(actions) + + def mdelete(self, keys: Sequence[str]) -> None: + """Delete the given keys and their associated values.""" + actions = ({"_op_type": "delete", "_id": self._key(key)} for key in keys) + self._bulk(actions) + + def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]: # type: ignore[override] + """Get an iterator over keys that match the given prefix.""" + # TODO This method is not currently used by CacheBackedEmbeddings, + # we can leave it blank. It could be implemented with ES "index_prefixes", + # but they are limited and expensive. + raise NotImplementedError() diff --git a/libs/elasticsearch/langchain_elasticsearch/_sync/chat_history.py b/libs/elasticsearch/langchain_elasticsearch/_sync/chat_history.py new file mode 100644 index 0000000..5c5614d --- /dev/null +++ b/libs/elasticsearch/langchain_elasticsearch/_sync/chat_history.py @@ -0,0 +1,166 @@ +import json +import logging +from time import time +from typing import TYPE_CHECKING, List, Optional, Sequence + +from langchain_core.chat_history import BaseChatMessageHistory +from langchain_core.messages import BaseMessage, message_to_dict, messages_from_dict + +from langchain_elasticsearch._utilities import with_user_agent_header +from langchain_elasticsearch.client import create_elasticsearch_client + +if TYPE_CHECKING: + from elasticsearch import Elasticsearch + +logger = logging.getLogger(__name__) + + +class ElasticsearchChatMessageHistory(BaseChatMessageHistory): + """Chat message history that stores history in Elasticsearch. + + Args: + es_url: URL of the Elasticsearch instance to connect to. + es_cloud_id: Cloud ID of the Elasticsearch instance to connect to. + es_user: Username to use when connecting to Elasticsearch. + es_password: Password to use when connecting to Elasticsearch. + es_api_key: API key to use when connecting to Elasticsearch. + es_connection: Optional pre-existing Elasticsearch connection. + esnsure_ascii: Used to escape ASCII symbols in json.dumps. Defaults to True. + index: Name of the index to use. + session_id: Arbitrary key that is used to store the messages + of a single chat session. + + For synchronous applications, use the `ElasticsearchChatMessageHistory` class. + For asyhchronous applications, use the `AsyncElasticsearchChatMessageHistory` class. + """ + + def __init__( + self, + index: str, + session_id: str, + *, + es_connection: Optional["Elasticsearch"] = None, + es_url: Optional[str] = None, + es_cloud_id: Optional[str] = None, + es_user: Optional[str] = None, + es_api_key: Optional[str] = None, + es_password: Optional[str] = None, + esnsure_ascii: Optional[bool] = True, + ): + self.index: str = index + self.session_id: str = session_id + self.ensure_ascii = esnsure_ascii + + # Initialize Elasticsearch client from passed client arg or connection info + if es_connection is not None: + self.client = es_connection + elif es_url is not None or es_cloud_id is not None: + try: + self.client = create_elasticsearch_client( + url=es_url, + username=es_user, + password=es_password, + cloud_id=es_cloud_id, + api_key=es_api_key, + ) + except Exception as err: + logger.error(f"Error connecting to Elasticsearch: {err}") + raise err + else: + raise ValueError( + """Either provide a pre-existing Elasticsearch connection, \ + or valid credentials for creating a new connection.""" + ) + + self.client = with_user_agent_header(self.client, "langchain-py-ms") + self.created = False + + def create_if_missing(self) -> None: + if not self.created: + if self.client.indices.exists(index=self.index): + logger.debug( + ( + f"Chat history index {self.index} already exists, " + "skipping creation." + ) + ) + else: + logger.debug(f"Creating index {self.index} for storing chat history.") + + self.client.indices.create( + index=self.index, + mappings={ + "properties": { + "session_id": {"type": "keyword"}, + "created_at": {"type": "date"}, + "history": {"type": "text"}, + } + }, + ) + self.created = True + + def get_messages(self) -> List[BaseMessage]: # type: ignore[override] + """Retrieve the messages from Elasticsearch""" + try: + from elasticsearch import ApiError + + self.create_if_missing() + result = self.client.search( + index=self.index, + query={"term": {"session_id": self.session_id}}, + sort="created_at:asc", + ) + except ApiError as err: + logger.error(f"Could not retrieve messages from Elasticsearch: {err}") + raise err + + if result and len(result["hits"]["hits"]) > 0: + items = [ + json.loads(document["_source"]["history"]) + for document in result["hits"]["hits"] + ] + else: + items = [] + + return messages_from_dict(items) + + def add_message(self, message: BaseMessage) -> None: + """Add messages to the chat session in Elasticsearch""" + try: + from elasticsearch import ApiError + + self.create_if_missing() + self.client.index( + index=self.index, + document={ + "session_id": self.session_id, + "created_at": round(time() * 1000), + "history": json.dumps( + message_to_dict(message), + ensure_ascii=bool(self.ensure_ascii), + ), + }, + refresh=True, + ) + except ApiError as err: + logger.error(f"Could not add message to Elasticsearch: {err}") + raise err + + def add_messages(self, messages: Sequence[BaseMessage]) -> None: + for message in messages: + self.add_message(message) + + def clear(self) -> None: + """Clear session memory in Elasticsearch""" + try: + from elasticsearch import ApiError + + self.create_if_missing() + self.client.delete_by_query( + index=self.index, + query={"term": {"session_id": self.session_id}}, + refresh=True, + ) + except ApiError as err: + logger.error(f"Could not clear session memory in Elasticsearch: {err}") + raise err diff --git a/libs/elasticsearch/langchain_elasticsearch/_sync/embeddings.py b/libs/elasticsearch/langchain_elasticsearch/_sync/embeddings.py new file mode 100644 index 0000000..03a6a64 --- /dev/null +++ b/libs/elasticsearch/langchain_elasticsearch/_sync/embeddings.py @@ -0,0 +1,254 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, List, Optional + +from elasticsearch import Elasticsearch +from elasticsearch.helpers.vectorstore import EmbeddingService +from langchain_core.embeddings import Embeddings +from langchain_core.utils import get_from_env + +if TYPE_CHECKING: + from elasticsearch._sync.client.ml import MlClient + + +class ElasticsearchEmbeddings(Embeddings): + """Elasticsearch embedding models. + + This class provides an interface to generate embeddings using a model deployed + in an Elasticsearch cluster. It requires an Elasticsearch connection object + and the model_id of the model deployed in the cluster. + + In Elasticsearch you need to have an embedding model loaded and deployed. + - https://www.elastic.co/guide/en/elasticsearch/reference/current/infer-trained-model.html + - https://www.elastic.co/guide/en/machine-learning/current/ml-nlp-deploy-models.html + + For synchronous applications, use the `ElasticsearchEmbeddings` class. + For asyhchronous applications, use the `AsyncElasticsearchEmbeddings` class. + """ # noqa: E501 + + def __init__( + self, + client: MlClient, + model_id: str, + *, + input_field: str = "text_field", + ): + """ + Initialize the ElasticsearchEmbeddings instance. + + Args: + client (MlClient): An Elasticsearch ML client object. + model_id (str): The model_id of the model deployed in the Elasticsearch + cluster. + input_field (str): The name of the key for the input text field in the + document. Defaults to 'text_field'. + """ + self.client = client + self.model_id = model_id + self.input_field = input_field + + @classmethod + def from_credentials( + cls, + model_id: str, + *, + es_cloud_id: Optional[str] = None, + es_api_key: Optional[str] = None, + input_field: str = "text_field", + ) -> ElasticsearchEmbeddings: + """Instantiate embeddings from Elasticsearch credentials. + + Args: + model_id (str): The model_id of the model deployed in the Elasticsearch + cluster. + input_field (str): The name of the key for the input text field in the + document. Defaults to 'text_field'. + es_cloud_id: (str, optional): The Elasticsearch cloud ID to connect to. + es_user: (str, optional): Elasticsearch username. + es_password: (str, optional): Elasticsearch password. + + Example: + .. code-block:: python + + from langchain_elasticserach.embeddings import ElasticsearchEmbeddings + + # Define the model ID and input field name (if different from default) + model_id = "your_model_id" + # Optional, only if different from 'text_field' + input_field = "your_input_field" + + # Credentials can be passed in two ways. Either set the env vars + # ES_CLOUD_ID, ES_USER, ES_PASSWORD and they will be automatically + # pulled in, or pass them in directly as kwargs. + embeddings = ElasticsearchEmbeddings.from_credentials( + model_id, + input_field=input_field, + # es_cloud_id="foo", + # es_user="bar", + # es_password="baz", + ) + + documents = [ + "This is an example document.", + "Another example document to generate embeddings for.", + ] + embeddings_generator.embed_documents(documents) + """ + from elasticsearch._sync.client.ml import MlClient + + es_cloud_id = es_cloud_id or get_from_env("es_cloud_id", "ES_CLOUD_ID") + es_api_key = es_api_key or get_from_env("es_api_key", "ES_API_KEY") + + # Connect to Elasticsearch + es_connection = Elasticsearch(cloud_id=es_cloud_id, api_key=es_api_key) + client = MlClient(es_connection) + return cls(client, model_id, input_field=input_field) + + @classmethod + def from_es_connection( + cls, + model_id: str, + es_connection: Elasticsearch, + input_field: str = "text_field", + ) -> ElasticsearchEmbeddings: + """ + Instantiate embeddings from an existing Elasticsearch connection. + + This method provides a way to create an instance of the ElasticsearchEmbeddings + class using an existing Elasticsearch connection. The connection object is used + to create an MlClient, which is then used to initialize the + ElasticsearchEmbeddings instance. + + Args: + model_id (str): The model_id of the model deployed in the Elasticsearch cluster. + es_connection (elasticsearch.Elasticsearch): An existing Elasticsearch + connection object. input_field (str, optional): The name of the key for the + input text field in the document. Defaults to 'text_field'. + + Returns: + ElasticsearchEmbeddings: An instance of the ElasticsearchEmbeddings class. + + Example: + .. code-block:: python + + from elasticsearch import Elasticsearch + + from langchain_elasticsearch.embeddings import ElasticsearchEmbeddings + + # Define the model ID and input field name (if different from default) + model_id = "your_model_id" + # Optional, only if different from 'text_field' + input_field = "your_input_field" + + # Create Elasticsearch connection + es_connection = Elasticsearch( + hosts=["localhost:9200"], http_auth=("user", "password") + ) + + # Instantiate ElasticsearchEmbeddings using the existing connection + embeddings = ElasticsearchEmbeddings.from_es_connection( + model_id, + es_connection, + input_field=input_field, + ) + + documents = [ + "This is an example document.", + "Another example document to generate embeddings for.", + ] + embeddings_generator.embed_documents(documents) + """ + from elasticsearch._sync.client.ml import MlClient + + # Create an MlClient from the given Elasticsearch connection + client = MlClient(es_connection) + + # Return a new instance of the ElasticsearchEmbeddings class with + # the MlClient, model_id, and input_field + return cls(client, model_id, input_field=input_field) + + def _embedding_func(self, texts: List[str]) -> List[List[float]]: + """ + Generate embeddings for the given texts using the Elasticsearch model. + + Args: + texts (List[str]): A list of text strings to generate embeddings for. + + Returns: + List[List[float]]: A list of embeddings, one for each text in the input + list. + """ + response = self.client.infer_trained_model( + model_id=self.model_id, docs=[{self.input_field: text} for text in texts] + ) + + embeddings = [doc["predicted_value"] for doc in response["inference_results"]] + return embeddings + + def embed_documents(self, texts: List[str]) -> List[List[float]]: + """ + Generate embeddings for a list of documents. + + Args: + texts (List[str]): A list of document text strings to generate embeddings + for. + + Returns: + List[List[float]]: A list of embeddings, one for each document in the input + list. + """ + return self._embedding_func(texts) + + def embed_query(self, text: str) -> List[float]: + """ + Generate an embedding for a single query text. + + Args: + text (str): The query text to generate an embedding for. + + Returns: + List[float]: The embedding for the input query text. + """ + return (self._embedding_func([text]))[0] + + +class EmbeddingServiceAdapter(EmbeddingService): + """ + Adapter for LangChain Embeddings to support the EmbeddingService interface from + elasticsearch.helpers.vectorstore. + """ + + def __init__(self, langchain_embeddings: Embeddings): + self._langchain_embeddings = langchain_embeddings + + def __eq__(self, other): # type: ignore[no-untyped-def] + if isinstance(other, self.__class__): + return self.__dict__ == other.__dict__ + else: + return False + + def embed_documents(self, texts: List[str]) -> List[List[float]]: + """ + Generate embeddings for a list of documents. + + Args: + texts (List[str]): A list of document text strings to generate embeddings + for. + + Returns: + List[List[float]]: A list of embeddings, one for each document in the input + list. + """ + return self._langchain_embeddings.embed_documents(texts) + + def embed_query(self, text: str) -> List[float]: + """ + Generate an embedding for a single query text. + + Args: + text (str): The query text to generate an embedding for. + + Returns: + List[float]: The embedding for the input query text. + """ + return self._langchain_embeddings.embed_query(text) diff --git a/libs/elasticsearch/langchain_elasticsearch/_sync/retrievers.py b/libs/elasticsearch/langchain_elasticsearch/_sync/retrievers.py new file mode 100644 index 0000000..c35974d --- /dev/null +++ b/libs/elasticsearch/langchain_elasticsearch/_sync/retrievers.py @@ -0,0 +1,118 @@ +import logging +from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Union, cast + +from elasticsearch import Elasticsearch +from langchain_core.callbacks import CallbackManagerForRetrieverRun +from langchain_core.documents import Document +from langchain_core.retrievers import BaseRetriever + +from langchain_elasticsearch._utilities import with_user_agent_header +from langchain_elasticsearch.client import create_elasticsearch_client + +logger = logging.getLogger(__name__) + + +class ElasticsearchRetriever(BaseRetriever): + """ + Elasticsearch retriever + + Args: + es_client: Elasticsearch client connection. Alternatively you can use the + `from_es_params` method with parameters to initialize the client. + index_name: The name of the index to query. Can also be a list of names. + body_func: Function to create an Elasticsearch DSL query body from a search + string. The returned query body must fit what you would normally send in a + POST request the the _search endpoint. If applicable, it also includes + parameters the `size` parameter etc. + content_field: The document field name that contains the page content. If + multiple indices are queried, specify a dict {index_name: field_name} here. + document_mapper: Function to map Elasticsearch hits to LangChain Documents. + + For synchronous applications, use the ``ElasticsearchRetriever`` class. + For asyhchronous applications, use the ``AsyncElasticsearchRetriever`` class. + """ + + es_client: Elasticsearch + index_name: Union[str, Sequence[str]] + body_func: Callable[[str], Dict] + content_field: Optional[Union[str, Mapping[str, str]]] = None + document_mapper: Optional[Callable[[Mapping], Document]] = None + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + + if self.content_field is None and self.document_mapper is None: + raise ValueError("One of content_field or document_mapper must be defined.") + if self.content_field is not None and self.document_mapper is not None: + raise ValueError( + "Both content_field and document_mapper are defined. " + "Please provide only one." + ) + + if not self.document_mapper: + if isinstance(self.content_field, str): + self.document_mapper = self._single_field_mapper + elif isinstance(self.content_field, Mapping): + self.document_mapper = self._multi_field_mapper + else: + raise ValueError( + "unknown type for content_field, expected string or dict." + ) + + self.es_client = with_user_agent_header(self.es_client, "langchain-py-r") + + @classmethod + def from_es_params( + cls, + index_name: Union[str, Sequence[str]], + body_func: Callable[[str], Dict], + content_field: Optional[Union[str, Mapping[str, str]]] = None, + document_mapper: Optional[Callable[[Mapping], Document]] = None, + url: Optional[str] = None, + cloud_id: Optional[str] = None, + api_key: Optional[str] = None, + username: Optional[str] = None, + password: Optional[str] = None, + params: Optional[Dict[str, Any]] = None, + ) -> "ElasticsearchRetriever": + client = None + try: + client = create_elasticsearch_client( + url=url, + cloud_id=cloud_id, + api_key=api_key, + username=username, + password=password, + params=params, + ) + except Exception as err: + logger.error(f"Error connecting to Elasticsearch: {err}") + raise err + + return cls( + es_client=client, + index_name=index_name, + body_func=body_func, + content_field=content_field, + document_mapper=document_mapper, + ) + + def _get_relevant_documents( + self, query: str, *, run_manager: CallbackManagerForRetrieverRun + ) -> List[Document]: + if not self.es_client or not self.document_mapper: + raise ValueError("faulty configuration") # should not happen + + body = self.body_func(query) + results = self.es_client.search(index=self.index_name, body=body) + return [self.document_mapper(hit) for hit in results["hits"]["hits"]] + + def _single_field_mapper(self, hit: Mapping[str, Any]) -> Document: + content = hit["_source"].pop(self.content_field) + return Document(page_content=content, metadata=hit) + + def _multi_field_mapper(self, hit: Mapping[str, Any]) -> Document: + self.content_field = cast(Mapping, self.content_field) + field = self.content_field[hit["_index"]] + content = hit["_source"].pop(field) + return Document(page_content=content, metadata=hit) diff --git a/libs/elasticsearch/langchain_elasticsearch/_sync/vectorstores.py b/libs/elasticsearch/langchain_elasticsearch/_sync/vectorstores.py new file mode 100644 index 0000000..758b225 --- /dev/null +++ b/libs/elasticsearch/langchain_elasticsearch/_sync/vectorstores.py @@ -0,0 +1,858 @@ +import logging +from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Tuple, Union + +from elasticsearch import Elasticsearch +from elasticsearch.helpers.vectorstore import ( + BM25Strategy, + DenseVectorScriptScoreStrategy, + DenseVectorStrategy, + DistanceMetric, + RetrievalStrategy, + SparseVectorStrategy, +) +from elasticsearch.helpers.vectorstore import VectorStore as EVectorStore +from langchain_core.documents import Document +from langchain_core.embeddings import Embeddings +from langchain_core.vectorstores import VectorStore + +from langchain_elasticsearch._utilities import ( + ApproxRetrievalStrategy, + BaseRetrievalStrategy, + BM25RetrievalStrategy, + DistanceStrategy, + ExactRetrievalStrategy, + SparseRetrievalStrategy, + _hits_to_docs_scores, + user_agent, +) +from langchain_elasticsearch.client import create_elasticsearch_client +from langchain_elasticsearch.embeddings import EmbeddingServiceAdapter + +logger = logging.getLogger(__name__) + + +def _convert_retrieval_strategy( + langchain_strategy: BaseRetrievalStrategy, + distance: Optional[DistanceStrategy] = None, +) -> RetrievalStrategy: + if isinstance(langchain_strategy, ApproxRetrievalStrategy): + if distance is None: + raise ValueError( + "ApproxRetrievalStrategy requires a distance strategy to be provided." + ) + return DenseVectorStrategy( + distance=DistanceMetric[distance], + model_id=langchain_strategy.query_model_id, + hybrid=( + False + if langchain_strategy.hybrid is None + else langchain_strategy.hybrid + ), + rrf=False if langchain_strategy.rrf is None else langchain_strategy.rrf, + ) + elif isinstance(langchain_strategy, ExactRetrievalStrategy): + if distance is None: + raise ValueError( + "ExactRetrievalStrategy requires a distance strategy to be provided." + ) + return DenseVectorScriptScoreStrategy(distance=DistanceMetric[distance]) + elif isinstance(langchain_strategy, SparseRetrievalStrategy): + return SparseVectorStrategy(langchain_strategy.model_id) + elif isinstance(langchain_strategy, BM25RetrievalStrategy): + return BM25Strategy(k1=langchain_strategy.k1, b=langchain_strategy.b) + else: + raise TypeError( + f"Strategy {langchain_strategy} not supported. To provide a " + f"custom strategy, please subclass {RetrievalStrategy}." + ) + + +class ElasticsearchStore(VectorStore): + """`Elasticsearch` vector store. + + Setup: + Install ``langchain_elasticsearch`` and running the Elasticsearch docker container. + + .. code-block:: bash + + pip install -qU langchain_elasticsearch + docker run -p 9200:9200 \ + -e "discovery.type=single-node" \ + -e "xpack.security.enabled=false" \ + -e "xpack.security.http.ssl.enabled=false" \ + docker.elastic.co/elasticsearch/elasticsearch:8.12.1 + + Key init args — indexing params: + index_name: str + Name of the index to create. + embedding: Embeddings + Embedding function to use. + + Key init args — client params: + es_connection: Optional[Elasticsearch] + Pre-existing Elasticsearch connection. + es_url: Optional[str] + URL of the Elasticsearch instance to connect to. + es_cloud_id: Optional[str] + Cloud ID of the Elasticsearch instance to connect to. + es_user: Optional[str] + Username to use when connecting to Elasticsearch. + es_password: Optional[str] + Password to use when connecting to Elasticsearch. + es_api_key: Optional[str] + API key to use when connecting to Elasticsearch. + + Instantiate: + .. code-block:: python + + from langchain_elasticsearch import ElasticsearchStore + from langchain_openai import OpenAIEmbeddings + + vector_store = ElasticsearchStore( + index_name="langchain-demo", + embedding=OpenAIEmbeddings(), + es_url="http://localhost:9200", + ) + + If you want to use a cloud hosted Elasticsearch instance, you can pass in the + cloud_id argument instead of the es_url argument. + + Instantiate from cloud: + .. code-block:: python + + from langchain_elasticsearch.vectorstores import ElasticsearchStore + from langchain_openai import OpenAIEmbeddings + + store = ElasticsearchStore( + embedding=OpenAIEmbeddings(), + index_name="langchain-demo", + es_cloud_id="" + es_user="elastic", + es_password="" + ) + + You can also connect to an existing Elasticsearch instance by passing in a + pre-existing Elasticsearch connection via the es_connection argument. + + Instantiate from existing connection: + .. code-block:: python + + from langchain_elasticsearch.vectorstores import ElasticsearchStore + from langchain_openai import OpenAIEmbeddings + + from elasticsearch import Elasticsearch + + es_connection = Elasticsearch("http://localhost:9200") + + store = ElasticsearchStore( + embedding=OpenAIEmbeddings(), + index_name="langchain-demo", + es_connection=es_connection + ) + + Add Documents: + .. code-block:: python + + from langchain_core.documents import Document + + document_1 = Document(page_content="foo", metadata={"baz": "bar"}) + document_2 = Document(page_content="thud", metadata={"bar": "baz"}) + document_3 = Document(page_content="i will be deleted :(") + + documents = [document_1, document_2, document_3] + ids = ["1", "2", "3"] + vector_store.add_documents(documents=documents, ids=ids) + + Delete Documents: + .. code-block:: python + + vector_store.delete(ids=["3"]) + + Search: + .. code-block:: python + + results = vector_store.similarity_search(query="thud",k=1) + for doc in results: + print(f"* {doc.page_content} [{doc.metadata}]") + + .. code-block:: python + + * thud [{'bar': 'baz'}] + + Search with filter: + .. code-block:: python + + results = vector_store.similarity_search(query="thud",k=1,filter=[{"term": {"metadata.bar.keyword": "baz"}}]) + for doc in results: + print(f"* {doc.page_content} [{doc.metadata}]") + + .. code-block:: python + + * thud [{'bar': 'baz'}] + + Search with score: + .. code-block:: python + + results = vector_store.similarity_search_with_score(query="qux",k=1) + for doc, score in results: + print(f"* [SIM={score:3f}] {doc.page_content} [{doc.metadata}]") + + .. code-block:: python + + * [SIM=0.916092] foo [{'baz': 'bar'}] + + Async: + .. code-block:: python + + from langchain_elasticsearch import AsyncElasticsearchStore + + vector_store = AsyncElasticsearchStore(...) + + # add documents + await vector_store.aadd_documents(documents=documents, ids=ids) + + # delete documents + await vector_store.adelete(ids=["3"]) + + # search + results = vector_store.asimilarity_search(query="thud",k=1) + + # search with score + results = await vector_store.asimilarity_search_with_score(query="qux",k=1) + for doc,score in results: + print(f"* [SIM={score:3f}] {doc.page_content} [{doc.metadata}]") + + .. code-block:: python + + * [SIM=0.916092] foo [{'baz': 'bar'}] + + Use as Retriever: + + .. code-block:: bash + + pip install "elasticsearch[vectorstore_mmr]" + + .. code-block:: python + + retriever = vector_store.as_retriever( + search_type="mmr", + search_kwargs={"k": 1, "fetch_k": 2, "lambda_mult": 0.5}, + ) + retriever.invoke("thud") + + .. code-block:: python + + [Document(metadata={'bar': 'baz'}, page_content='thud')] + + **Advanced Uses:** + + ElasticsearchStore by default uses the ApproxRetrievalStrategy, which uses the + HNSW algorithm to perform approximate nearest neighbor search. This is the + fastest and most memory efficient algorithm. + + If you want to use the Brute force / Exact strategy for searching vectors, you + can pass in the ExactRetrievalStrategy to the ElasticsearchStore constructor. + + Use ExactRetrievalStrategy: + .. code-block:: python + + from langchain_elasticsearch.vectorstores import ElasticsearchStore + from langchain_openai import OpenAIEmbeddings + + store = ElasticsearchStore( + embedding=OpenAIEmbeddings(), + index_name="langchain-demo", + es_url="http://localhost:9200", + strategy=ElasticsearchStore.ExactRetrievalStrategy() + ) + + Both strategies require that you know the similarity metric you want to use + when creating the index. The default is cosine similarity, but you can also + use dot product or euclidean distance. + + Use dot product similarity: + .. code-block:: python + + from langchain_elasticsearch.vectorstores import ElasticsearchStore + from langchain_openai import OpenAIEmbeddings + from langchain_community.vectorstores.utils import DistanceStrategy + + store = ElasticsearchStore( + "langchain-demo", + embedding=OpenAIEmbeddings(), + es_url="http://localhost:9200", + distance_strategy="DOT_PRODUCT" + ) + + """ # noqa: E501 + + def __init__( + self, + index_name: str, + *, + embedding: Optional[Embeddings] = None, + es_connection: Optional[Elasticsearch] = None, + es_url: Optional[str] = None, + es_cloud_id: Optional[str] = None, + es_user: Optional[str] = None, + es_api_key: Optional[str] = None, + es_password: Optional[str] = None, + vector_query_field: str = "vector", + query_field: str = "text", + distance_strategy: Optional[ + Literal[ + DistanceStrategy.COSINE, + DistanceStrategy.DOT_PRODUCT, + DistanceStrategy.EUCLIDEAN_DISTANCE, + DistanceStrategy.MAX_INNER_PRODUCT, + ] + ] = None, + strategy: Union[ + BaseRetrievalStrategy, RetrievalStrategy + ] = ApproxRetrievalStrategy(), + es_params: Optional[Dict[str, Any]] = None, + ): + if isinstance(strategy, BaseRetrievalStrategy): + strategy = _convert_retrieval_strategy( + strategy, distance=distance_strategy or DistanceStrategy.COSINE + ) + + embedding_service = None + if embedding: + embedding_service = EmbeddingServiceAdapter(embedding) + + if not es_connection: + es_connection = create_elasticsearch_client( + url=es_url, + cloud_id=es_cloud_id, + api_key=es_api_key, + username=es_user, + password=es_password, + params=es_params, + ) + + self._store = EVectorStore( + client=es_connection, + index=index_name, + retrieval_strategy=strategy, + embedding_service=embedding_service, + text_field=query_field, + vector_field=vector_query_field, + user_agent=user_agent("langchain-py-vs"), + ) + + self.embedding = embedding + self.client = self._store.client + self._embedding_service = embedding_service + self.query_field = query_field + self.vector_query_field = vector_query_field + + def close(self) -> None: + self._store.close() + + @property + def embeddings(self) -> Optional[Embeddings]: + return self.embedding + + @staticmethod + def connect_to_elasticsearch( + *, + es_url: Optional[str] = None, + cloud_id: Optional[str] = None, + api_key: Optional[str] = None, + username: Optional[str] = None, + password: Optional[str] = None, + es_params: Optional[Dict[str, Any]] = None, + ) -> Elasticsearch: + return create_elasticsearch_client( + url=es_url, + cloud_id=cloud_id, + api_key=api_key, + username=username, + password=password, + params=es_params, + ) + + def similarity_search( + self, + query: str, + k: int = 4, + fetch_k: int = 50, + filter: Optional[List[dict]] = None, + *, + custom_query: Optional[ + Callable[[Dict[str, Any], Optional[str]], Dict[str, Any]] + ] = None, + doc_builder: Optional[Callable[[Dict], Document]] = None, + **kwargs: Any, + ) -> List[Document]: + """Return Elasticsearch documents most similar to query. + + Args: + query: Text to look up documents similar to. + k: Number of Documents to return. Defaults to 4. + fetch_k (int): Number of Documents to fetch to pass to knn num_candidates. + filter: Array of Elasticsearch filter clauses to apply to the query. + + Returns: + List of Documents most similar to the query, + in descending order of similarity. + """ + hits = self._store.search( + query=query, + k=k, + num_candidates=fetch_k, + filter=filter, + custom_query=custom_query, + ) + docs = _hits_to_docs_scores( + hits=hits, + content_field=self.query_field, + doc_builder=doc_builder, + ) + return [doc for doc, _score in docs] + + def max_marginal_relevance_search( + self, + query: str, + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + fields: Optional[List[str]] = None, + *, + custom_query: Optional[ + Callable[[Dict[str, Any], Optional[str]], Dict[str, Any]] + ] = None, + doc_builder: Optional[Callable[[Dict], Document]] = None, + **kwargs: Any, + ) -> List[Document]: + """Return docs selected using the maximal marginal relevance. + + Maximal marginal relevance optimizes for similarity to query AND diversity + among selected documents. + + Args: + query (str): Text to look up documents similar to. + k (int): Number of Documents to return. Defaults to 4. + fetch_k (int): Number of Documents to fetch to pass to MMR algorithm. + lambda_mult (float): Number between 0 and 1 that determines the degree + of diversity among the results with 0 corresponding + to maximum diversity and 1 to minimum diversity. + Defaults to 0.5. + fields: Other fields to get from elasticsearch source. These fields + will be added to the document metadata. + + Returns: + List[Document]: A list of Documents selected by maximal marginal relevance. + """ + if self._embedding_service is None: + raise ValueError( + "maximal marginal relevance search requires an embedding service." + ) + + hits = self._store.max_marginal_relevance_search( + embedding_service=self._embedding_service, + query=query, + vector_field=self.vector_query_field, + k=k, + num_candidates=fetch_k, + lambda_mult=lambda_mult, + fields=fields, + custom_query=custom_query, + ) + + docs_scores = _hits_to_docs_scores( + hits=hits, + content_field=self.query_field, + fields=fields, + doc_builder=doc_builder, + ) + + return [doc for doc, _score in docs_scores] + + @staticmethod + def _identity_fn(score: float) -> float: + return score + + def _select_relevance_score_fn(self) -> Callable[[float], float]: + """ + The 'correct' relevance function + may differ depending on a few things, including: + - the distance / similarity metric used by the VectorStore + - the scale of your embeddings (OpenAI's are unit normed. Many others are not!) + - embedding dimensionality + - etc. + + Vectorstores should define their own selection based method of relevance. + """ + # All scores from Elasticsearch are already normalized similarities: + # https://www.elastic.co/guide/en/elasticsearch/reference/current/dense-vector.html#dense-vector-params + return self._identity_fn + + def similarity_search_with_score( + self, + query: str, + k: int = 4, + filter: Optional[List[dict]] = None, + *, + custom_query: Optional[ + Callable[[Dict[str, Any], Optional[str]], Dict[str, Any]] + ] = None, + doc_builder: Optional[Callable[[Dict], Document]] = None, + **kwargs: Any, + ) -> List[Tuple[Document, float]]: + """Return Elasticsearch documents most similar to query, along with scores. + + Args: + query: Text to look up documents similar to. + k: Number of Documents to return. Defaults to 4. + filter: Array of Elasticsearch filter clauses to apply to the query. + + Returns: + List of Documents most similar to the query and score for each + """ + if ( + isinstance(self._store.retrieval_strategy, DenseVectorStrategy) + and self._store.retrieval_strategy.hybrid + ): + raise ValueError("scores are currently not supported in hybrid mode") + + hits = self._store.search( + query=query, k=k, filter=filter, custom_query=custom_query + ) + return _hits_to_docs_scores( + hits=hits, + content_field=self.query_field, + doc_builder=doc_builder, + ) + + def similarity_search_by_vector_with_relevance_scores( + self, + embedding: List[float], + k: int = 4, + filter: Optional[List[Dict]] = None, + *, + custom_query: Optional[ + Callable[[Dict[str, Any], Optional[str]], Dict[str, Any]] + ] = None, + doc_builder: Optional[Callable[[Dict], Document]] = None, + **kwargs: Any, + ) -> List[Tuple[Document, float]]: + """Return Elasticsearch documents most similar to query, along with scores. + + Args: + embedding: Embedding to look up documents similar to. + k: Number of Documents to return. Defaults to 4. + filter: Array of Elasticsearch filter clauses to apply to the query. + + Returns: + List of Documents most similar to the embedding and score for each + """ + if ( + isinstance(self._store.retrieval_strategy, DenseVectorStrategy) + and self._store.retrieval_strategy.hybrid + ): + raise ValueError("scores are currently not supported in hybrid mode") + + hits = self._store.search( + query=None, + query_vector=embedding, + k=k, + filter=filter, + custom_query=custom_query, + ) + return _hits_to_docs_scores( + hits=hits, + content_field=self.query_field, + doc_builder=doc_builder, + ) + + def delete( + self, + ids: Optional[List[str]] = None, + refresh_indices: Optional[bool] = True, + **kwargs: Any, + ) -> Optional[bool]: + """Delete documents from the Elasticsearch index. + + Args: + ids: List of ids of documents to delete. + refresh_indices: Whether to refresh the index + after deleting documents. Defaults to True. + """ + if ids is None: + raise ValueError("please specify some IDs") + + return self._store.delete(ids=ids, refresh_indices=refresh_indices or False) + + def add_texts( + self, + texts: Iterable[str], + metadatas: Optional[List[Dict[Any, Any]]] = None, + ids: Optional[List[str]] = None, + refresh_indices: bool = True, + create_index_if_not_exists: bool = True, + bulk_kwargs: Optional[Dict] = None, + **kwargs: Any, + ) -> List[str]: + """Run more texts through the embeddings and add to the store. + + Args: + texts: Iterable of strings to add to the store. + metadatas: Optional list of metadatas associated with the texts. + ids: Optional list of ids to associate with the texts. + refresh_indices: Whether to refresh the Elasticsearch indices + after adding the texts. + create_index_if_not_exists: Whether to create the Elasticsearch + index if it doesn't already exist. + *bulk_kwargs: Additional arguments to pass to Elasticsearch bulk. + - chunk_size: Optional. Number of texts to add to the + index at a time. Defaults to 500. + + Returns: + List of ids from adding the texts into the store. + """ + return self._store.add_texts( + texts=list(texts), + metadatas=metadatas, + ids=ids, + refresh_indices=refresh_indices, + create_index_if_not_exists=create_index_if_not_exists, + bulk_kwargs=bulk_kwargs, + ) + + def add_embeddings( + self, + text_embeddings: Iterable[Tuple[str, List[float]]], + metadatas: Optional[List[dict]] = None, + ids: Optional[List[str]] = None, + refresh_indices: bool = True, + create_index_if_not_exists: bool = True, + bulk_kwargs: Optional[Dict] = None, + **kwargs: Any, + ) -> List[str]: + """Add the given texts and embeddings to the store. + + Args: + text_embeddings: Iterable pairs of string and embedding to + add to the store. + metadatas: Optional list of metadatas associated with the texts. + ids: Optional list of unique IDs. + refresh_indices: Whether to refresh the Elasticsearch indices + after adding the texts. + create_index_if_not_exists: Whether to create the Elasticsearch + index if it doesn't already exist. + *bulk_kwargs: Additional arguments to pass to Elasticsearch bulk. + - chunk_size: Optional. Number of texts to add to the + index at a time. Defaults to 500. + + Returns: + List of ids from adding the texts into the store. + """ + texts, embeddings = zip(*text_embeddings) + return self._store.add_texts( + texts=list(texts), + metadatas=metadatas, + vectors=list(embeddings), + ids=ids, + refresh_indices=refresh_indices, + create_index_if_not_exists=create_index_if_not_exists, + bulk_kwargs=bulk_kwargs, + ) + + @classmethod + def from_texts( + cls, + texts: List[str], + embedding: Optional[Embeddings] = None, + metadatas: Optional[List[Dict[str, Any]]] = None, + bulk_kwargs: Optional[Dict] = None, + **kwargs: Any, + ) -> "ElasticsearchStore": + """Construct ElasticsearchStore wrapper from raw documents. + + Example: + .. code-block:: python + + from langchain_elasticsearch.vectorstores import ElasticsearchStore + from langchain_openai import OpenAIEmbeddings + + db = ElasticsearchStore.from_texts( + texts, + // embeddings optional if using + // a strategy that doesn't require inference + embeddings, + index_name="langchain-demo", + es_url="http://localhost:9200" + ) + + Args: + texts: List of texts to add to the Elasticsearch index. + embedding: Embedding function to use to embed the texts. + metadatas: Optional list of metadatas associated with the texts. + index_name: Name of the Elasticsearch index to create. + es_url: URL of the Elasticsearch instance to connect to. + cloud_id: Cloud ID of the Elasticsearch instance to connect to. + es_user: Username to use when connecting to Elasticsearch. + es_password: Password to use when connecting to Elasticsearch. + es_api_key: API key to use when connecting to Elasticsearch. + es_connection: Optional pre-existing Elasticsearch connection. + vector_query_field: Optional. Name of the field to + store the embedding vectors in. + query_field: Optional. Name of the field to store the texts in. + distance_strategy: Optional. Name of the distance + strategy to use. Defaults to "COSINE". + can be one of "COSINE", + "EUCLIDEAN_DISTANCE", "DOT_PRODUCT", + "MAX_INNER_PRODUCT". + bulk_kwargs: Optional. Additional arguments to pass to + Elasticsearch bulk. + """ + + index_name = kwargs.get("index_name") + if index_name is None: + raise ValueError("Please provide an index_name.") + + elasticsearchStore = cls(embedding=embedding, **kwargs) + + # Encode the provided texts and add them to the newly created index. + elasticsearchStore.add_texts( + texts=texts, metadatas=metadatas, bulk_kwargs=bulk_kwargs + ) + + return elasticsearchStore + + @classmethod + def from_documents( + cls, + documents: List[Document], + embedding: Optional[Embeddings] = None, + bulk_kwargs: Optional[Dict] = None, + **kwargs: Any, + ) -> "ElasticsearchStore": + """Construct ElasticsearchStore wrapper from documents. + + Example: + .. code-block:: python + + from langchain_elasticsearch.vectorstores import ElasticsearchStore + from langchain_openai import OpenAIEmbeddings + + db = ElasticsearchStore.from_documents( + texts, + embeddings, + index_name="langchain-demo", + es_url="http://localhost:9200" + ) + + Args: + texts: List of texts to add to the Elasticsearch index. + embedding: Embedding function to use to embed the texts. + Do not provide if using a strategy + that doesn't require inference. + metadatas: Optional list of metadatas associated with the texts. + index_name: Name of the Elasticsearch index to create. + es_url: URL of the Elasticsearch instance to connect to. + cloud_id: Cloud ID of the Elasticsearch instance to connect to. + es_user: Username to use when connecting to Elasticsearch. + es_password: Password to use when connecting to Elasticsearch. + es_api_key: API key to use when connecting to Elasticsearch. + es_connection: Optional pre-existing Elasticsearch connection. + vector_query_field: Optional. Name of the field + to store the embedding vectors in. + query_field: Optional. Name of the field to store the texts in. + bulk_kwargs: Optional. Additional arguments to pass to + Elasticsearch bulk. + """ + + index_name = kwargs.get("index_name") + if index_name is None: + raise ValueError("Please provide an index_name.") + + elasticsearchStore = cls(embedding=embedding, **kwargs) + + # Encode the provided texts and add them to the newly created index. + elasticsearchStore.add_documents(documents, bulk_kwargs=bulk_kwargs) + + return elasticsearchStore + + @staticmethod + def ExactRetrievalStrategy() -> "ExactRetrievalStrategy": + """Used to perform brute force / exact + nearest neighbor search via script_score.""" + return ExactRetrievalStrategy() + + @staticmethod + def ApproxRetrievalStrategy( + query_model_id: Optional[str] = None, + hybrid: Optional[bool] = False, + rrf: Optional[Union[dict, bool]] = True, + ) -> "ApproxRetrievalStrategy": + """Used to perform approximate nearest neighbor search + using the HNSW algorithm. + + At build index time, this strategy will create a + dense vector field in the index and store the + embedding vectors in the index. + + At query time, the text will either be embedded using the + provided embedding function or the query_model_id + will be used to embed the text using the model + deployed to Elasticsearch. + + if query_model_id is used, do not provide an embedding function. + + Args: + query_model_id: Optional. ID of the model to use to + embed the query text within the stack. Requires + embedding model to be deployed to Elasticsearch. + hybrid: Optional. If True, will perform a hybrid search + using both the knn query and a text query. + Defaults to False. + rrf: Optional. rrf is Reciprocal Rank Fusion. + When `hybrid` is True, + and `rrf` is True, then rrf: {}. + and `rrf` is False, then rrf is omitted. + and isinstance(rrf, dict) is True, then pass in the dict values. + rrf could be passed for adjusting 'rank_constant' and 'window_size'. + """ + return ApproxRetrievalStrategy( + query_model_id=query_model_id, hybrid=hybrid, rrf=rrf + ) + + @staticmethod + def SparseVectorRetrievalStrategy( + model_id: Optional[str] = None, + ) -> "SparseRetrievalStrategy": + """Used to perform sparse vector search via text_expansion. + Used for when you want to use ELSER model to perform document search. + + At build index time, this strategy will create a pipeline that + will embed the text using the ELSER model and store the + resulting tokens in the index. + + At query time, the text will be embedded using the ELSER + model and the resulting tokens will be used to + perform a text_expansion query. + + Args: + model_id: Optional. Default is ".elser_model_1". + ID of the model to use to embed the query text + within the stack. Requires embedding model to be + deployed to Elasticsearch. + """ + return SparseRetrievalStrategy(model_id=model_id) + + @staticmethod + def BM25RetrievalStrategy( + k1: Union[float, None] = None, b: Union[float, None] = None + ) -> "BM25RetrievalStrategy": + """Used to apply BM25 without vector search. + + Args: + k1: Optional. This corresponds to the BM25 parameter, k1. Default is None, + which uses the default setting of Elasticsearch. + b: Optional. This corresponds to the BM25 parameter, b. Default is None, + which uses the default setting of Elasticsearch. + """ + return BM25RetrievalStrategy(k1=k1, b=b) diff --git a/libs/elasticsearch/langchain_elasticsearch/_utilities.py b/libs/elasticsearch/langchain_elasticsearch/_utilities.py index c6dc21d..7fb37a6 100644 --- a/libs/elasticsearch/langchain_elasticsearch/_utilities.py +++ b/libs/elasticsearch/langchain_elasticsearch/_utilities.py @@ -1,8 +1,12 @@ import logging +from abc import ABC, abstractmethod from enum import Enum +from typing import Any, Callable, Dict, List, Optional, Tuple, Union -from elasticsearch import Elasticsearch, exceptions +from elasticsearch import AsyncElasticsearch, Elasticsearch, exceptions from langchain_core import __version__ as langchain_version +from langchain_core._api.deprecation import deprecated +from langchain_core.documents import Document logger = logging.getLogger(__name__) @@ -28,6 +32,14 @@ def with_user_agent_header(client: Elasticsearch, header_prefix: str) -> Elastic return client.options(headers=headers) +def async_with_user_agent_header( + client: AsyncElasticsearch, header_prefix: str +) -> AsyncElasticsearch: + headers = dict(client._headers) + headers.update({"user-agent": f"{user_agent(header_prefix)}"}) + return client.options(headers=headers) + + def model_must_be_deployed(client: Elasticsearch, model_id: str) -> None: try: dummy = {"x": "y"} @@ -44,3 +56,477 @@ def model_must_be_deployed(client: Elasticsearch, model_id: str) -> None: # This error is expected because we do not know the expected document # shape and just use a dummy doc above. pass + + +def _hits_to_docs_scores( + hits: List[Dict[str, Any]], + content_field: str, + fields: Optional[List[str]] = None, + doc_builder: Optional[Callable[[Dict], Document]] = None, +) -> List[Tuple[Document, float]]: + if fields is None: + fields = [] + + documents = [] + + def default_doc_builder(hit: Dict) -> Document: + return Document( + page_content=hit["_source"].get(content_field, ""), + metadata=hit["_source"].get("metadata", {}), + ) + + doc_builder = doc_builder or default_doc_builder + + for hit in hits: + for field in fields: + if "metadata" not in hit["_source"]: + hit["_source"]["metadata"] = {} + if field in hit["_source"] and field not in [ + "metadata", + content_field, + ]: + hit["_source"]["metadata"][field] = hit["_source"][field] + + doc = doc_builder(hit) + documents.append((doc, hit["_score"])) + + return documents + + +@deprecated("0.2.0", alternative="RetrievalStrategy", pending=True) +class BaseRetrievalStrategy(ABC): + """Base class for `Elasticsearch` retrieval strategies.""" + + @abstractmethod + def query( + self, + query_vector: Union[List[float], None], + query: Union[str, None], + *, + k: int, + fetch_k: int, + vector_query_field: str, + text_field: str, + filter: List[dict], + similarity: Union[DistanceStrategy, None], + ) -> Dict: + """ + Executes when a search is performed on the store. + + Args: + query_vector: The query vector, + or None if not using vector-based query. + query: The text query, or None if not using text-based query. + k: The total number of results to retrieve. + fetch_k: The number of results to fetch initially. + vector_query_field: The field containing the vector + representations in the index. + text_field: The field containing the text data in the index. + filter: List of filter clauses to apply to the query. + similarity: The similarity strategy to use, or None if not using one. + + Returns: + Dict: The Elasticsearch query body. + """ + + @abstractmethod + def index( + self, + dims_length: Union[int, None], + vector_query_field: str, + text_field: str, + similarity: Union[DistanceStrategy, None], + ) -> Dict: + """ + Executes when the index is created. + + Args: + dims_length: Numeric length of the embedding vectors, + or None if not using vector-based query. + vector_query_field: The field containing the vector + representations in the index. + text_field: The field containing the text data in the index. + similarity: The similarity strategy to use, + or None if not using one. + + Returns: + Dict: The Elasticsearch settings and mappings for the strategy. + """ + + def before_index_setup( + self, client: "Elasticsearch", text_field: str, vector_query_field: str + ) -> None: + """ + Executes before the index is created. Used for setting up + any required Elasticsearch resources like a pipeline. + + Args: + client: The Elasticsearch client. + text_field: The field containing the text data in the index. + vector_query_field: The field containing the vector + representations in the index. + """ + + def require_inference(self) -> bool: + """ + Returns whether or not the strategy requires inference + to be performed on the text before it is added to the index. + + Returns: + bool: Whether or not the strategy requires inference + to be performed on the text before it is added to the index. + """ + return True + + +@deprecated("0.2.0", alternative="DenseVectorStrategy", pending=True) +class ApproxRetrievalStrategy(BaseRetrievalStrategy): + """Approximate retrieval strategy using the `HNSW` algorithm.""" + + def __init__( + self, + query_model_id: Optional[str] = None, + hybrid: Optional[bool] = False, + rrf: Optional[Union[dict, bool]] = True, + ): + self.query_model_id = query_model_id + self.hybrid = hybrid + + # RRF has two optional parameters + # 'rank_constant', 'window_size' + # https://www.elastic.co/guide/en/elasticsearch/reference/current/rrf.html + self.rrf = rrf + + def query( + self, + query_vector: Union[List[float], None], + query: Union[str, None], + k: int, + fetch_k: int, + vector_query_field: str, + text_field: str, + filter: List[dict], + similarity: Union[DistanceStrategy, None], + ) -> Dict: + knn = { + "filter": filter, + "field": vector_query_field, + "k": k, + "num_candidates": fetch_k, + } + + # Embedding provided via the embedding function + if query_vector is not None and not self.query_model_id: + knn["query_vector"] = list(query_vector) + + # Case 2: Used when model has been deployed to + # Elasticsearch and can infer the query vector from the query text + elif query and self.query_model_id: + knn["query_vector_builder"] = { + "text_embedding": { + "model_id": self.query_model_id, # use 'model_id' argument + "model_text": query, # use 'query' argument + } + } + + else: + raise ValueError( + "You must provide an embedding function or a" + " query_model_id to perform a similarity search." + ) + + # If hybrid, add a query to the knn query + # RRF is used to even the score from the knn query and text query + # RRF has two optional parameters: {'rank_constant':int, 'window_size':int} + # https://www.elastic.co/guide/en/elasticsearch/reference/current/rrf.html + if self.hybrid: + query_body = { + "knn": knn, + "query": { + "bool": { + "must": [ + { + "match": { + text_field: { + "query": query, + } + } + } + ], + "filter": filter, + } + }, + } + + if isinstance(self.rrf, dict): + query_body["rank"] = {"rrf": self.rrf} + elif isinstance(self.rrf, bool) and self.rrf is True: + query_body["rank"] = {"rrf": {}} + + return query_body + else: + return {"knn": knn} + + def before_index_setup( + self, client: "Elasticsearch", text_field: str, vector_query_field: str + ) -> None: + if self.query_model_id: + model_must_be_deployed(client, self.query_model_id) + + def index( + self, + dims_length: Union[int, None], + vector_query_field: str, + text_field: str, + similarity: Union[DistanceStrategy, None], + ) -> Dict: + """Create the mapping for the Elasticsearch index.""" + + if similarity is DistanceStrategy.COSINE: + similarityAlgo = "cosine" + elif similarity is DistanceStrategy.EUCLIDEAN_DISTANCE: + similarityAlgo = "l2_norm" + elif similarity is DistanceStrategy.DOT_PRODUCT: + similarityAlgo = "dot_product" + elif similarity is DistanceStrategy.MAX_INNER_PRODUCT: + similarityAlgo = "max_inner_product" + else: + raise ValueError(f"Similarity {similarity} not supported.") + + return { + "mappings": { + "properties": { + vector_query_field: { + "type": "dense_vector", + "dims": dims_length, + "index": True, + "similarity": similarityAlgo, + }, + } + } + } + + +@deprecated("0.2.0", alternative="DenseVectorScriptScoreStrategy", pending=True) +class ExactRetrievalStrategy(BaseRetrievalStrategy): + """Exact retrieval strategy using the `script_score` query.""" + + def query( + self, + query_vector: Union[List[float], None], + query: Union[str, None], + k: int, + fetch_k: int, + vector_query_field: str, + text_field: str, + filter: Union[List[dict], None], + similarity: Union[DistanceStrategy, None], + ) -> Dict: + if similarity is DistanceStrategy.COSINE: + similarityAlgo = ( + f"cosineSimilarity(params.query_vector, '{vector_query_field}') + 1.0" + ) + elif similarity is DistanceStrategy.EUCLIDEAN_DISTANCE: + similarityAlgo = ( + f"1 / (1 + l2norm(params.query_vector, '{vector_query_field}'))" + ) + elif similarity is DistanceStrategy.DOT_PRODUCT: + similarityAlgo = f""" + double value = dotProduct(params.query_vector, '{vector_query_field}'); + return sigmoid(1, Math.E, -value); + """ + else: + raise ValueError(f"Similarity {similarity} not supported.") + + queryBool: Dict = {"match_all": {}} + if filter: + queryBool = {"bool": {"filter": filter}} + + return { + "query": { + "script_score": { + "query": queryBool, + "script": { + "source": similarityAlgo, + "params": {"query_vector": query_vector}, + }, + }, + } + } + + def index( + self, + dims_length: Union[int, None], + vector_query_field: str, + text_field: str, + similarity: Union[DistanceStrategy, None], + ) -> Dict: + """Create the mapping for the Elasticsearch index.""" + + return { + "mappings": { + "properties": { + vector_query_field: { + "type": "dense_vector", + "dims": dims_length, + "index": False, + }, + } + } + } + + +@deprecated("0.2.0", alternative="SparseVectorStrategy", pending=True) +class SparseRetrievalStrategy(BaseRetrievalStrategy): + """Sparse retrieval strategy using the `text_expansion` processor.""" + + def __init__(self, model_id: Optional[str] = None): + self.model_id = model_id or ".elser_model_1" + + def query( + self, + query_vector: Union[List[float], None], + query: Union[str, None], + k: int, + fetch_k: int, + vector_query_field: str, + text_field: str, + filter: List[dict], + similarity: Union[DistanceStrategy, None], + ) -> Dict: + return { + "query": { + "bool": { + "must": [ + { + "text_expansion": { + f"{vector_query_field}.tokens": { + "model_id": self.model_id, + "model_text": query, + } + } + } + ], + "filter": filter, + } + } + } + + def _get_pipeline_name(self) -> str: + return f"{self.model_id}_sparse_embedding" + + def before_index_setup( + self, client: "Elasticsearch", text_field: str, vector_query_field: str + ) -> None: + if self.model_id: + model_must_be_deployed(client, self.model_id) + + # Create a pipeline for the model + client.ingest.put_pipeline( + id=self._get_pipeline_name(), + description="Embedding pipeline for langchain vectorstore", + processors=[ + { + "inference": { + "model_id": self.model_id, + "target_field": vector_query_field, + "field_map": {text_field: "text_field"}, + "inference_config": { + "text_expansion": {"results_field": "tokens"} + }, + } + } + ], + ) + + def index( + self, + dims_length: Union[int, None], + vector_query_field: str, + text_field: str, + similarity: Union[DistanceStrategy, None], + ) -> Dict: + return { + "mappings": { + "properties": { + vector_query_field: { + "properties": {"tokens": {"type": "rank_features"}} + } + } + }, + "settings": {"default_pipeline": self._get_pipeline_name()}, + } + + def require_inference(self) -> bool: + return False + + +@deprecated("0.2.0", alternative="BM25Strategy", pending=True) +class BM25RetrievalStrategy(BaseRetrievalStrategy): + """Retrieval strategy using the native BM25 algorithm of Elasticsearch.""" + + def __init__(self, k1: Union[float, None] = None, b: Union[float, None] = None): + self.k1 = k1 + self.b = b + + def query( + self, + query_vector: Union[List[float], None], + query: Union[str, None], + k: int, + fetch_k: int, + vector_query_field: str, + text_field: str, + filter: List[dict], + similarity: Union[DistanceStrategy, None], + ) -> Dict: + return { + "query": { + "bool": { + "must": [ + { + "match": { + text_field: { + "query": query, + } + }, + }, + ], + "filter": filter, + }, + }, + } + + def index( + self, + dims_length: Union[int, None], + vector_query_field: str, + text_field: str, + similarity: Union[DistanceStrategy, None], + ) -> Dict: + mappings: Dict = { + "properties": { + text_field: { + "type": "text", + "similarity": "custom_bm25", + }, + }, + } + settings: Dict = { + "similarity": { + "custom_bm25": { + "type": "BM25", + }, + }, + } + + if self.k1 is not None: + settings["similarity"]["custom_bm25"]["k1"] = self.k1 + + if self.b is not None: + settings["similarity"]["custom_bm25"]["b"] = self.b + + return {"mappings": mappings, "settings": settings} + + def require_inference(self) -> bool: + return False diff --git a/libs/elasticsearch/langchain_elasticsearch/cache.py b/libs/elasticsearch/langchain_elasticsearch/cache.py index 478b5a1..fefacf1 100644 --- a/libs/elasticsearch/langchain_elasticsearch/cache.py +++ b/libs/elasticsearch/langchain_elasticsearch/cache.py @@ -1,407 +1,51 @@ -import base64 -import hashlib -import logging -from datetime import datetime -from functools import cached_property -from typing import ( - TYPE_CHECKING, - Any, - Dict, - Iterable, - Iterator, - List, - Optional, - Sequence, - Tuple, -) - -from elasticsearch import Elasticsearch, exceptions, helpers -from elasticsearch.helpers import BulkIndexError -from langchain_core.caches import RETURN_VAL_TYPE, BaseCache -from langchain_core.load import dumps, loads -from langchain_core.stores import ByteStore - -from langchain_elasticsearch.client import create_elasticsearch_client - -if TYPE_CHECKING: - from elasticsearch import Elasticsearch - -logger = logging.getLogger(__name__) - - -def _manage_cache_index( - es_client: Elasticsearch, index_name: str, mapping: Dict[str, Any] -) -> bool: - """Write or update an index or alias according to the default mapping""" - if es_client.indices.exists_alias(name=index_name): - es_client.indices.put_mapping(index=index_name, body=mapping["mappings"]) - return True - - elif not es_client.indices.exists(index=index_name): - logger.debug(f"Creating new Elasticsearch index: {index_name}") - es_client.indices.create(index=index_name, body=mapping) - return False - - return False +from typing import Any, Iterator, List, Optional, Sequence, Tuple +from elasticsearch import Elasticsearch, exceptions, helpers # noqa: F401 +from elasticsearch.helpers import BulkIndexError # noqa: F401 +from langchain_core.caches import RETURN_VAL_TYPE, BaseCache # noqa: F401 +from langchain_core.load import dumps, loads # noqa: F401 +from langchain_core.stores import ByteStore # noqa: F401 -class ElasticsearchCache(BaseCache): - """An Elasticsearch cache integration for LLMs.""" - - def __init__( - self, - index_name: str, - store_input: bool = True, - store_input_params: bool = True, - metadata: Optional[Dict[str, Any]] = None, - *, - es_url: Optional[str] = None, - es_cloud_id: Optional[str] = None, - es_user: Optional[str] = None, - es_api_key: Optional[str] = None, - es_password: Optional[str] = None, - es_params: Optional[Dict[str, Any]] = None, - ): - """ - Initialize the Elasticsearch cache store by specifying the index/alias - to use and determining which additional information (like input, input - parameters, and any other metadata) should be stored in the cache. - - Args: - index_name (str): The name of the index or the alias to use for the cache. - If they do not exist an index is created, - according to the default mapping defined by the `mapping` property. - store_input (bool): Whether to store the LLM input in the cache, i.e., - the input prompt. Default to True. - store_input_params (bool): Whether to store the input parameters in the - cache, i.e., the LLM parameters used to generate the LLM response. - Default to True. - metadata (Optional[dict]): Additional metadata to store in the cache, - for filtering purposes. This must be JSON serializable in an - Elasticsearch document. Default to None. - es_url: URL of the Elasticsearch instance to connect to. - es_cloud_id: Cloud ID of the Elasticsearch instance to connect to. - es_user: Username to use when connecting to Elasticsearch. - es_password: Password to use when connecting to Elasticsearch. - es_api_key: API key to use when connecting to Elasticsearch. - es_params: Other parameters for the Elasticsearch client. - """ - - self._index_name = index_name - self._store_input = store_input - self._store_input_params = store_input_params - self._metadata = metadata - self._es_client = create_elasticsearch_client( - url=es_url, - cloud_id=es_cloud_id, - api_key=es_api_key, - username=es_user, - password=es_password, - params=es_params, - ) - self._is_alias = _manage_cache_index( - self._es_client, - self._index_name, - self.mapping, - ) - - @cached_property - def mapping(self) -> Dict[str, Any]: - """Get the default mapping for the index.""" - return { - "mappings": { - "properties": { - "llm_output": {"type": "text", "index": False}, - "llm_params": {"type": "text", "index": False}, - "llm_input": {"type": "text", "index": False}, - "metadata": {"type": "object"}, - "timestamp": {"type": "date"}, - } - } - } +from langchain_elasticsearch._async.cache import ( + AsyncElasticsearchCache as _AsyncElasticsearchCache, +) +from langchain_elasticsearch._async.cache import ( + AsyncElasticsearchEmbeddingsCache as _AsyncElasticsearchEmbeddingsCache, +) +from langchain_elasticsearch._sync.cache import ( # noqa: F401 + ElasticsearchCache, + ElasticsearchEmbeddingsCache, +) +from langchain_elasticsearch.client import ( # noqa: F401 + create_async_elasticsearch_client, + create_elasticsearch_client, +) - @staticmethod - def _key(prompt: str, llm_string: str) -> str: - """Generate a key for the cache store.""" - return hashlib.md5((prompt + llm_string).encode()).hexdigest() +# langchain defines some sync methods as abstract in its base class +# so we have to add dummy methods for them, even though we only use the async versions +class AsyncElasticsearchCache(_AsyncElasticsearchCache): def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]: - """Look up based on prompt and llm_string.""" - cache_key = self._key(prompt, llm_string) - if self._is_alias: - # get the latest record according to its writing date, in order to - # address cases where multiple indices have a doc with the same id - result = self._es_client.search( - index=self._index_name, - body={ - "query": {"term": {"_id": cache_key}}, - "sort": {"timestamp": {"order": "asc"}}, - }, - source_includes=["llm_output"], - ) - if result["hits"]["total"]["value"] > 0: - record = result["hits"]["hits"][0] - else: - return None - else: - try: - record = self._es_client.get( - index=self._index_name, id=cache_key, source=["llm_output"] - ) - except exceptions.NotFoundError: - return None - return [loads(item) for item in record["_source"]["llm_output"]] - - def build_document( - self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE - ) -> Dict[str, Any]: - """Build the Elasticsearch document for storing a single LLM interaction""" - body: Dict[str, Any] = { - "llm_output": [dumps(item) for item in return_val], - "timestamp": datetime.now().isoformat(), - } - if self._store_input_params: - body["llm_params"] = llm_string - if self._metadata is not None: - body["metadata"] = self._metadata - if self._store_input: - body["llm_input"] = prompt - return body + raise NotImplementedError("This class is asynchronous, use alookup()") def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None: - """Update based on prompt and llm_string.""" - body = self.build_document(prompt, llm_string, return_val) - self._es_client.index( - index=self._index_name, - id=self._key(prompt, llm_string), - body=body, - require_alias=self._is_alias, - refresh=True, - ) + raise NotImplementedError("This class is asynchronous, use aupdate()") def clear(self, **kwargs: Any) -> None: - """Clear cache.""" - self._es_client.delete_by_query( - index=self._index_name, - body={"query": {"match_all": {}}}, - refresh=True, - wait_for_completion=True, - ) + raise NotImplementedError("This class is asynchronous, use aclear()") -class ElasticsearchEmbeddingsCache(ByteStore): - """An Elasticsearch store for caching embeddings.""" - - def __init__( - self, - index_name: str, - store_input: bool = True, - metadata: Optional[Dict[str, Any]] = None, - namespace: Optional[str] = None, - maximum_duplicates_allowed: int = 1, - *, - es_url: Optional[str] = None, - es_cloud_id: Optional[str] = None, - es_user: Optional[str] = None, - es_api_key: Optional[str] = None, - es_password: Optional[str] = None, - es_params: Optional[Dict[str, Any]] = None, - ): - """ - Initialize the Elasticsearch cache store by specifying the index/alias - to use and determining which additional information (like input, input - parameters, and any other metadata) should be stored in the cache. - Provide a namespace to organize the cache. - - - Args: - index_name (str): The name of the index or the alias to use for the cache. - If they do not exist an index is created, - according to the default mapping defined by the `mapping` property. - store_input (bool): Whether to store the input in the cache. - Default to True. - metadata (Optional[dict]): Additional metadata to store in the cache, - for filtering purposes. This must be JSON serializable in an - Elasticsearch document. Default to None. - namespace (Optional[str]): A namespace to use for the cache. - maximum_duplicates_allowed (int): Defines the maximum number of duplicate - keys permitted. Must be used in scenarios where the same key appears - across multiple indices that share the same alias. Default to 1. - es_url: URL of the Elasticsearch instance to connect to. - es_cloud_id: Cloud ID of the Elasticsearch instance to connect to. - es_user: Username to use when connecting to Elasticsearch. - es_password: Password to use when connecting to Elasticsearch. - es_api_key: API key to use when connecting to Elasticsearch. - es_params: Other parameters for the Elasticsearch client. - """ - self._namespace = namespace - self._maximum_duplicates_allowed = maximum_duplicates_allowed - self._index_name = index_name - self._store_input = store_input - self._metadata = metadata - self._es_client = create_elasticsearch_client( - url=es_url, - cloud_id=es_cloud_id, - api_key=es_api_key, - username=es_user, - password=es_password, - params=es_params, - ) - self._is_alias = _manage_cache_index( - self._es_client, - self._index_name, - self.mapping, - ) - - @staticmethod - def encode_vector(data: bytes) -> str: - """Encode the vector data as bytes to as a base64 string.""" - return base64.b64encode(data).decode("utf-8") - - @staticmethod - def decode_vector(data: str) -> bytes: - """Decode the base64 string to vector data as bytes.""" - return base64.b64decode(data) - - @cached_property - def mapping(self) -> Dict[str, Any]: - """Get the default mapping for the index.""" - return { - "mappings": { - "properties": { - "text_input": {"type": "text", "index": False}, - "vector_dump": { - "type": "binary", - "doc_values": False, - }, - "metadata": {"type": "object"}, - "timestamp": {"type": "date"}, - "namespace": {"type": "keyword"}, - } - } - } - - def _key(self, input_text: str) -> str: - """Generate a key for the store.""" - return hashlib.md5(((self._namespace or "") + input_text).encode()).hexdigest() - - @classmethod - def _deduplicate_hits(cls, hits: List[dict]) -> Dict[str, bytes]: - """ - Collapse the results from a search query with multiple indices - returning only the latest version of the documents - """ - map_ids = {} - for hit in sorted( - hits, - key=lambda x: datetime.fromisoformat(x["_source"]["timestamp"]), - reverse=True, - ): - vector_id: str = hit["_id"] - if vector_id not in map_ids: - map_ids[vector_id] = cls.decode_vector(hit["_source"]["vector_dump"]) - - return map_ids - +# langchain defines some sync methods as abstract in its base class +# so we have to add dummy methods for them, even though we only use the async versions +class AsyncElasticsearchEmbeddingsCache(_AsyncElasticsearchEmbeddingsCache): def mget(self, keys: Sequence[str]) -> List[Optional[bytes]]: - """Get the values associated with the given keys.""" - if not any(keys): - return [] - - cache_keys = [self._key(k) for k in keys] - if self._is_alias: - try: - results = self._es_client.search( - index=self._index_name, - body={ - "query": {"ids": {"values": cache_keys}}, - "size": len(cache_keys) * self._maximum_duplicates_allowed, - }, - source_includes=["vector_dump", "timestamp"], - ) - - except exceptions.BadRequestError as e: - if "window too large" in ( - e.body.get("error", {}).get("root_cause", [{}])[0].get("reason", "") - ): - logger.warning( - "Exceeded the maximum window size, " - "Reduce the duplicates manually or lower " - "`maximum_duplicate_allowed.`" - ) - raise e - - total_hits = results["hits"]["total"]["value"] - if self._maximum_duplicates_allowed > 1 and total_hits > len(cache_keys): - logger.warning( - f"Deduplicating, found {total_hits} hits for {len(cache_keys)} keys" - ) - map_ids = self._deduplicate_hits(results["hits"]["hits"]) - else: - map_ids = { - r["_id"]: self.decode_vector(r["_source"]["vector_dump"]) - for r in results["hits"]["hits"] - } - - return [map_ids.get(k) for k in cache_keys] - - else: - records = self._es_client.mget( - index=self._index_name, ids=cache_keys, source_includes=["vector_dump"] - ) - return [ - self.decode_vector(r["_source"]["vector_dump"]) if r["found"] else None - for r in records["docs"] - ] - - def build_document(self, text_input: str, vector: bytes) -> Dict[str, Any]: - """Build the Elasticsearch document for storing a single embedding""" - body: Dict[str, Any] = { - "vector_dump": self.encode_vector(vector), - "timestamp": datetime.now().isoformat(), - } - if self._metadata is not None: - body["metadata"] = self._metadata - if self._store_input: - body["text_input"] = text_input - if self._namespace: - body["namespace"] = self._namespace - return body - - def _bulk(self, actions: Iterable[Dict[str, Any]]) -> None: - try: - helpers.bulk( - client=self._es_client, - actions=actions, - index=self._index_name, - require_alias=self._is_alias, - refresh=True, - ) - except BulkIndexError as e: - first_error = e.errors[0].get("index", {}).get("error", {}) - logger.error(f"First bulk error reason: {first_error.get('reason')}") - raise e + raise NotImplementedError("This class is asynchronous, use amget()") def mset(self, key_value_pairs: Sequence[Tuple[str, bytes]]) -> None: - """Set the values for the given keys.""" - actions = ( - { - "_op_type": "index", - "_id": self._key(key), - "_source": self.build_document(key, vector), - } - for key, vector in key_value_pairs - ) - self._bulk(actions) + raise NotImplementedError("This class is asynchronous, use amset()") def mdelete(self, keys: Sequence[str]) -> None: - """Delete the given keys and their associated values.""" - actions = ({"_op_type": "delete", "_id": self._key(key)} for key in keys) - self._bulk(actions) + raise NotImplementedError("This class is asynchronous, use amdelete()") def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]: - """Get an iterator over keys that match the given prefix.""" - # TODO This method is not currently used by CacheBackedEmbeddings, - # we can leave it blank. It could be implemented with ES "index_prefixes", - # but they are limited and expensive. - raise NotImplementedError() + raise NotImplementedError("This class is asynchronous, use ayield_keys()") diff --git a/libs/elasticsearch/langchain_elasticsearch/chat_history.py b/libs/elasticsearch/langchain_elasticsearch/chat_history.py index d2f0774..a7b335b 100644 --- a/libs/elasticsearch/langchain_elasticsearch/chat_history.py +++ b/libs/elasticsearch/langchain_elasticsearch/chat_history.py @@ -1,150 +1,24 @@ -import json -import logging -from time import time -from typing import TYPE_CHECKING, List, Optional +from typing import List -from langchain_core.chat_history import BaseChatMessageHistory -from langchain_core.messages import BaseMessage, message_to_dict, messages_from_dict +from langchain_core.messages import BaseMessage -from langchain_elasticsearch._utilities import with_user_agent_header -from langchain_elasticsearch.client import create_elasticsearch_client +from langchain_elasticsearch._async.chat_history import ( + AsyncElasticsearchChatMessageHistory as _AsyncElasticsearchChatMessageHistory, +) +from langchain_elasticsearch._sync.chat_history import ( + ElasticsearchChatMessageHistory as _ElasticsearchChatMessageHistory, +) -if TYPE_CHECKING: - from elasticsearch import Elasticsearch - -logger = logging.getLogger(__name__) - - -class ElasticsearchChatMessageHistory(BaseChatMessageHistory): - """Chat message history that stores history in Elasticsearch. - - Args: - es_url: URL of the Elasticsearch instance to connect to. - es_cloud_id: Cloud ID of the Elasticsearch instance to connect to. - es_user: Username to use when connecting to Elasticsearch. - es_password: Password to use when connecting to Elasticsearch. - es_api_key: API key to use when connecting to Elasticsearch. - es_connection: Optional pre-existing Elasticsearch connection. - esnsure_ascii: Used to escape ASCII symbols in json.dumps. Defaults to True. - index: Name of the index to use. - session_id: Arbitrary key that is used to store the messages - of a single chat session. - """ - - def __init__( - self, - index: str, - session_id: str, - *, - es_connection: Optional["Elasticsearch"] = None, - es_url: Optional[str] = None, - es_cloud_id: Optional[str] = None, - es_user: Optional[str] = None, - es_api_key: Optional[str] = None, - es_password: Optional[str] = None, - esnsure_ascii: Optional[bool] = True, - ): - self.index: str = index - self.session_id: str = session_id - self.ensure_ascii = esnsure_ascii - - # Initialize Elasticsearch client from passed client arg or connection info - if es_connection is not None: - self.client = es_connection - elif es_url is not None or es_cloud_id is not None: - try: - self.client = create_elasticsearch_client( - url=es_url, - username=es_user, - password=es_password, - cloud_id=es_cloud_id, - api_key=es_api_key, - ) - except Exception as err: - logger.error(f"Error connecting to Elasticsearch: {err}") - raise err - else: - raise ValueError( - """Either provide a pre-existing Elasticsearch connection, \ - or valid credentials for creating a new connection.""" - ) - - self.client = with_user_agent_header(self.client, "langchain-py-ms") - - if self.client.indices.exists(index=index): - logger.debug( - f"Chat history index {index} already exists, skipping creation." - ) - else: - logger.debug(f"Creating index {index} for storing chat history.") - - self.client.indices.create( - index=index, - mappings={ - "properties": { - "session_id": {"type": "keyword"}, - "created_at": {"type": "date"}, - "history": {"type": "text"}, - } - }, - ) +# add the messages property which is only in the sync version +class ElasticsearchChatMessageHistory(_ElasticsearchChatMessageHistory): @property def messages(self) -> List[BaseMessage]: # type: ignore[override] - """Retrieve the messages from Elasticsearch""" - try: - from elasticsearch import ApiError - - result = self.client.search( - index=self.index, - query={"term": {"session_id": self.session_id}}, - sort="created_at:asc", - ) - except ApiError as err: - logger.error(f"Could not retrieve messages from Elasticsearch: {err}") - raise err - - if result and len(result["hits"]["hits"]) > 0: - items = [ - json.loads(document["_source"]["history"]) - for document in result["hits"]["hits"] - ] - else: - items = [] + return self.get_messages() - return messages_from_dict(items) - - def add_message(self, message: BaseMessage) -> None: - """Add a message to the chat session in Elasticsearch""" - try: - from elasticsearch import ApiError - - self.client.index( - index=self.index, - document={ - "session_id": self.session_id, - "created_at": round(time() * 1000), - "history": json.dumps( - message_to_dict(message), - ensure_ascii=bool(self.ensure_ascii), - ), - }, - refresh=True, - ) - except ApiError as err: - logger.error(f"Could not add message to Elasticsearch: {err}") - raise err +# langchain defines some sync methods as abstract in its base class +# so we have to add dummy methods for them, even though we only use the async versions +class AsyncElasticsearchChatMessageHistory(_AsyncElasticsearchChatMessageHistory): def clear(self) -> None: - """Clear session memory in Elasticsearch""" - try: - from elasticsearch import ApiError - - self.client.delete_by_query( - index=self.index, - query={"term": {"session_id": self.session_id}}, - refresh=True, - ) - except ApiError as err: - logger.error(f"Could not clear session memory in Elasticsearch: {err}") - raise err + raise NotImplementedError("This class is asynchronous, use aclear()") diff --git a/libs/elasticsearch/langchain_elasticsearch/client.py b/libs/elasticsearch/langchain_elasticsearch/client.py index 3e4b546..7a5c8be 100644 --- a/libs/elasticsearch/langchain_elasticsearch/client.py +++ b/libs/elasticsearch/langchain_elasticsearch/client.py @@ -1,6 +1,6 @@ from typing import Any, Dict, Optional -from elasticsearch import Elasticsearch +from elasticsearch import AsyncElasticsearch, Elasticsearch def create_elasticsearch_client( @@ -38,3 +38,37 @@ def create_elasticsearch_client( es_client.info() # test connection return es_client + + +def create_async_elasticsearch_client( + url: Optional[str] = None, + cloud_id: Optional[str] = None, + api_key: Optional[str] = None, + username: Optional[str] = None, + password: Optional[str] = None, + params: Optional[Dict[str, Any]] = None, +) -> AsyncElasticsearch: + if url and cloud_id: + raise ValueError( + "Both es_url and cloud_id are defined. Please provide only one." + ) + + connection_params: Dict[str, Any] = {} + + if url: + connection_params["hosts"] = [url] + elif cloud_id: + connection_params["cloud_id"] = cloud_id + else: + raise ValueError("Please provide either elasticsearch_url or cloud_id.") + + if api_key: + connection_params["api_key"] = api_key + elif username and password: + connection_params["basic_auth"] = (username, password) + + if params is not None: + connection_params.update(params) + + es_client = AsyncElasticsearch(**connection_params) + return es_client diff --git a/libs/elasticsearch/langchain_elasticsearch/embeddings.py b/libs/elasticsearch/langchain_elasticsearch/embeddings.py index a4768bd..278be14 100644 --- a/libs/elasticsearch/langchain_elasticsearch/embeddings.py +++ b/libs/elasticsearch/langchain_elasticsearch/embeddings.py @@ -1,251 +1,25 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, List, Optional - -from elasticsearch import Elasticsearch -from elasticsearch.helpers.vectorstore import EmbeddingService -from langchain_core.embeddings import Embeddings -from langchain_core.utils import get_from_env - -if TYPE_CHECKING: - from elasticsearch.client import MlClient - - -class ElasticsearchEmbeddings(Embeddings): - """Elasticsearch embedding models. - - This class provides an interface to generate embeddings using a model deployed - in an Elasticsearch cluster. It requires an Elasticsearch connection object - and the model_id of the model deployed in the cluster. - - In Elasticsearch you need to have an embedding model loaded and deployed. - - https://www.elastic.co/guide/en/elasticsearch/reference/current/infer-trained-model.html - - https://www.elastic.co/guide/en/machine-learning/current/ml-nlp-deploy-models.html - """ # noqa: E501 - - def __init__( - self, - client: MlClient, - model_id: str, - *, - input_field: str = "text_field", - ): - """ - Initialize the ElasticsearchEmbeddings instance. - - Args: - client (MlClient): An Elasticsearch ML client object. - model_id (str): The model_id of the model deployed in the Elasticsearch - cluster. - input_field (str): The name of the key for the input text field in the - document. Defaults to 'text_field'. - """ - self.client = client - self.model_id = model_id - self.input_field = input_field - - @classmethod - def from_credentials( - cls, - model_id: str, - *, - es_cloud_id: Optional[str] = None, - es_api_key: Optional[str] = None, - input_field: str = "text_field", - ) -> ElasticsearchEmbeddings: - """Instantiate embeddings from Elasticsearch credentials. - - Args: - model_id (str): The model_id of the model deployed in the Elasticsearch - cluster. - input_field (str): The name of the key for the input text field in the - document. Defaults to 'text_field'. - es_cloud_id: (str, optional): The Elasticsearch cloud ID to connect to. - es_user: (str, optional): Elasticsearch username. - es_password: (str, optional): Elasticsearch password. - - Example: - .. code-block:: python - - from langchain_elasticserach.embeddings import ElasticsearchEmbeddings - - # Define the model ID and input field name (if different from default) - model_id = "your_model_id" - # Optional, only if different from 'text_field' - input_field = "your_input_field" - - # Credentials can be passed in two ways. Either set the env vars - # ES_CLOUD_ID, ES_USER, ES_PASSWORD and they will be automatically - # pulled in, or pass them in directly as kwargs. - embeddings = ElasticsearchEmbeddings.from_credentials( - model_id, - input_field=input_field, - # es_cloud_id="foo", - # es_user="bar", - # es_password="baz", - ) - - documents = [ - "This is an example document.", - "Another example document to generate embeddings for.", - ] - embeddings_generator.embed_documents(documents) - """ - from elasticsearch.client import MlClient - - es_cloud_id = es_cloud_id or get_from_env("es_cloud_id", "ES_CLOUD_ID") - es_api_key = es_api_key or get_from_env("es_api_key", "ES_API_KEY") - - # Connect to Elasticsearch - es_connection = Elasticsearch(cloud_id=es_cloud_id, api_key=es_api_key) - client = MlClient(es_connection) - return cls(client, model_id, input_field=input_field) - - @classmethod - def from_es_connection( - cls, - model_id: str, - es_connection: Elasticsearch, - input_field: str = "text_field", - ) -> ElasticsearchEmbeddings: - """ - Instantiate embeddings from an existing Elasticsearch connection. - - This method provides a way to create an instance of the ElasticsearchEmbeddings - class using an existing Elasticsearch connection. The connection object is used - to create an MlClient, which is then used to initialize the - ElasticsearchEmbeddings instance. - - Args: - model_id (str): The model_id of the model deployed in the Elasticsearch cluster. - es_connection (elasticsearch.Elasticsearch): An existing Elasticsearch - connection object. input_field (str, optional): The name of the key for the - input text field in the document. Defaults to 'text_field'. - - Returns: - ElasticsearchEmbeddings: An instance of the ElasticsearchEmbeddings class. - - Example: - .. code-block:: python - - from elasticsearch import Elasticsearch - - from langchain_elasticsearch.embeddings import ElasticsearchEmbeddings - - # Define the model ID and input field name (if different from default) - model_id = "your_model_id" - # Optional, only if different from 'text_field' - input_field = "your_input_field" - - # Create Elasticsearch connection - es_connection = Elasticsearch( - hosts=["localhost:9200"], http_auth=("user", "password") - ) - - # Instantiate ElasticsearchEmbeddings using the existing connection - embeddings = ElasticsearchEmbeddings.from_es_connection( - model_id, - es_connection, - input_field=input_field, - ) - - documents = [ - "This is an example document.", - "Another example document to generate embeddings for.", - ] - embeddings_generator.embed_documents(documents) - """ - from elasticsearch.client import MlClient - - # Create an MlClient from the given Elasticsearch connection - client = MlClient(es_connection) - - # Return a new instance of the ElasticsearchEmbeddings class with - # the MlClient, model_id, and input_field - return cls(client, model_id, input_field=input_field) - - def _embedding_func(self, texts: List[str]) -> List[List[float]]: - """ - Generate embeddings for the given texts using the Elasticsearch model. - - Args: - texts (List[str]): A list of text strings to generate embeddings for. - - Returns: - List[List[float]]: A list of embeddings, one for each text in the input - list. - """ - response = self.client.infer_trained_model( - model_id=self.model_id, docs=[{self.input_field: text} for text in texts] - ) - - embeddings = [doc["predicted_value"] for doc in response["inference_results"]] - return embeddings - - def embed_documents(self, texts: List[str]) -> List[List[float]]: - """ - Generate embeddings for a list of documents. - - Args: - texts (List[str]): A list of document text strings to generate embeddings - for. - - Returns: - List[List[float]]: A list of embeddings, one for each document in the input - list. - """ - return self._embedding_func(texts) - - def embed_query(self, text: str) -> List[float]: - """ - Generate an embedding for a single query text. - - Args: - text (str): The query text to generate an embedding for. - - Returns: - List[float]: The embedding for the input query text. - """ - return self._embedding_func([text])[0] - - -class EmbeddingServiceAdapter(EmbeddingService): - """ - Adapter for LangChain Embeddings to support the EmbeddingService interface from - elasticsearch.helpers.vectorstore. - """ - - def __init__(self, langchain_embeddings: Embeddings): - self._langchain_embeddings = langchain_embeddings - - def __eq__(self, other): # type: ignore[no-untyped-def] - if isinstance(other, self.__class__): - return self.__dict__ == other.__dict__ - else: - return False - +from typing import List + +from langchain_elasticsearch._async.embeddings import ( + AsyncElasticsearchEmbeddings as _AsyncElasticsearchEmbeddings, +) +from langchain_elasticsearch._async.embeddings import ( # noqa: F401 + AsyncEmbeddingService, + AsyncEmbeddingServiceAdapter, + Embeddings, +) +from langchain_elasticsearch._sync.embeddings import ( # noqa: F401 + ElasticsearchEmbeddings, + EmbeddingService, + EmbeddingServiceAdapter, +) + + +# langchain defines some sync methods as abstract in its base class +# so we have to add dummy methods for them, even though we only use the async versions +class AsyncElasticsearchEmbeddings(_AsyncElasticsearchEmbeddings): def embed_documents(self, texts: List[str]) -> List[List[float]]: - """ - Generate embeddings for a list of documents. - - Args: - texts (List[str]): A list of document text strings to generate embeddings - for. - - Returns: - List[List[float]]: A list of embeddings, one for each document in the input - list. - """ - return self._langchain_embeddings.embed_documents(texts) + raise NotImplementedError("This class is asynchronous, use aembed_documents()") def embed_query(self, text: str) -> List[float]: - """ - Generate an embedding for a single query text. - - Args: - text (str): The query text to generate an embedding for. - - Returns: - List[float]: The embedding for the input query text. - """ - return self._langchain_embeddings.embed_query(text) + raise NotImplementedError("This class is asynchronous, use aembed_query()") diff --git a/libs/elasticsearch/langchain_elasticsearch/retrievers.py b/libs/elasticsearch/langchain_elasticsearch/retrievers.py index 85618e1..018acf3 100644 --- a/libs/elasticsearch/langchain_elasticsearch/retrievers.py +++ b/libs/elasticsearch/langchain_elasticsearch/retrievers.py @@ -1,128 +1,22 @@ -import logging -from typing import ( - Any, - Callable, - Dict, - List, - Mapping, - Optional, - Sequence, - Union, - cast, -) +from typing import Any, List -from elasticsearch import Elasticsearch -from langchain_core.callbacks import ( - CallbackManagerForRetrieverRun, -) +from langchain_core.callbacks import CallbackManagerForRetrieverRun from langchain_core.documents import Document -from langchain_core.retrievers import BaseRetriever - -from langchain_elasticsearch._utilities import with_user_agent_header -from langchain_elasticsearch.client import create_elasticsearch_client - -logger = logging.getLogger(__name__) - - -class ElasticsearchRetriever(BaseRetriever): - """ - Elasticsearch retriever - - Args: - es_client: Elasticsearch client connection. Alternatively you can use the - `from_es_params` method with parameters to initialize the client. - index_name: The name of the index to query. Can also be a list of names. - body_func: Function to create an Elasticsearch DSL query body from a search - string. The returned query body must fit what you would normally send in a - POST request the the _search endpoint. If applicable, it also includes - parameters the `size` parameter etc. - content_field: The document field name that contains the page content. If - multiple indices are queried, specify a dict {index_name: field_name} here. - document_mapper: Function to map Elasticsearch hits to LangChain Documents. - """ - - _expects_other_args = True - - es_client: Elasticsearch - index_name: Union[str, Sequence[str]] - body_func: Callable[[str], Dict] - content_field: Optional[Union[str, Mapping[str, str]]] = None - document_mapper: Optional[Callable[[Mapping], Document]] = None - - def __init__(self, **kwargs: Any) -> None: - super().__init__(**kwargs) - if self.content_field is None and self.document_mapper is None: - raise ValueError("One of content_field or document_mapper must be defined.") - if self.content_field is not None and self.document_mapper is not None: - raise ValueError( - "Both content_field and document_mapper are defined. " - "Please provide only one." - ) - - if not self.document_mapper: - if isinstance(self.content_field, str): - self.document_mapper = self._single_field_mapper - elif isinstance(self.content_field, Mapping): - self.document_mapper = self._multi_field_mapper - else: - raise ValueError( - "unknown type for content_field, expected string or dict." - ) - - self.es_client = with_user_agent_header(self.es_client, "langchain-py-r") - - @staticmethod - def from_es_params( - index_name: Union[str, Sequence[str]], - body_func: Callable[[str], Dict], - content_field: Optional[Union[str, Mapping[str, str]]] = None, - document_mapper: Optional[Callable[[Mapping], Document]] = None, - url: Optional[str] = None, - cloud_id: Optional[str] = None, - api_key: Optional[str] = None, - username: Optional[str] = None, - password: Optional[str] = None, - params: Optional[Dict[str, Any]] = None, - ) -> "ElasticsearchRetriever": - client = None - try: - client = create_elasticsearch_client( - url=url, - cloud_id=cloud_id, - api_key=api_key, - username=username, - password=password, - params=params, - ) - except Exception as err: - logger.error(f"Error connecting to Elasticsearch: {err}") - raise err +from langchain_elasticsearch._async.retrievers import ( + AsyncElasticsearchRetriever as _AsyncElasticsearchRetriever, +) +from langchain_elasticsearch._sync.retrievers import ( + ElasticsearchRetriever, # noqa: F401 +) - return ElasticsearchRetriever( - es_client=client, - index_name=index_name, - body_func=body_func, - content_field=content_field, - document_mapper=document_mapper, - ) +# langchain defines some sync methods as abstract in its base class +# so we have to add dummy methods for them, even though we only use the async versions +class AsyncElasticsearchRetriever(_AsyncElasticsearchRetriever): def _get_relevant_documents( self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any ) -> List[Document]: - if not self.es_client or not self.document_mapper: - raise ValueError("faulty configuration") # should not happen - - body = self.body_func(query, **kwargs) - results = self.es_client.search(index=self.index_name, body=body) - return [self.document_mapper(hit) for hit in results["hits"]["hits"]] - - def _single_field_mapper(self, hit: Mapping[str, Any]) -> Document: - content = hit["_source"].pop(self.content_field) - return Document(page_content=content, metadata=hit) - - def _multi_field_mapper(self, hit: Mapping[str, Any]) -> Document: - self.content_field = cast(Mapping, self.content_field) - field = self.content_field[hit["_index"]] - content = hit["_source"].pop(field) - return Document(page_content=content, metadata=hit) + raise NotImplementedError( + "This class is asynchronous, use _aget_relevant_documents()" + ) diff --git a/libs/elasticsearch/langchain_elasticsearch/vectorstores.py b/libs/elasticsearch/langchain_elasticsearch/vectorstores.py index 902e024..1934c89 100644 --- a/libs/elasticsearch/langchain_elasticsearch/vectorstores.py +++ b/libs/elasticsearch/langchain_elasticsearch/vectorstores.py @@ -1,1325 +1,56 @@ -import logging -from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Tuple, Union - -from elasticsearch import Elasticsearch -from elasticsearch.helpers.vectorstore import ( +from typing import Any, Optional + +from langchain_elasticsearch._async.vectorstores import ( # noqa: F401 + AsyncBM25Strategy, + AsyncDenseVectorScriptScoreStrategy, + AsyncDenseVectorStrategy, + AsyncRetrievalStrategy, + AsyncSparseVectorStrategy, + DistanceMetric, + Document, + Embeddings, +) +from langchain_elasticsearch._async.vectorstores import ( + AsyncElasticsearchStore as _AsyncElasticsearchStore, +) +from langchain_elasticsearch._sync.vectorstores import ( # noqa: F401 BM25Strategy, DenseVectorScriptScoreStrategy, DenseVectorStrategy, - DistanceMetric, + ElasticsearchStore, RetrievalStrategy, SparseVectorStrategy, ) -from elasticsearch.helpers.vectorstore import VectorStore as EVectorStore -from langchain_core._api.deprecation import deprecated -from langchain_core.documents import Document -from langchain_core.embeddings import Embeddings -from langchain_core.vectorstores import VectorStore -from langchain_elasticsearch._utilities import ( +# deprecated strategy classes +from langchain_elasticsearch._utilities import ( # noqa: F401 + ApproxRetrievalStrategy, + BaseRetrievalStrategy, + BM25RetrievalStrategy, DistanceStrategy, - model_must_be_deployed, - user_agent, + ExactRetrievalStrategy, + SparseRetrievalStrategy, ) -from langchain_elasticsearch.client import create_elasticsearch_client -from langchain_elasticsearch.embeddings import EmbeddingServiceAdapter - -logger = logging.getLogger(__name__) - - -@deprecated("0.2.0", alternative="RetrievalStrategy", pending=True) -class BaseRetrievalStrategy(ABC): - """Base class for `Elasticsearch` retrieval strategies.""" - - @abstractmethod - def query( - self, - query_vector: Union[List[float], None], - query: Union[str, None], - *, - k: int, - fetch_k: int, - vector_query_field: str, - text_field: str, - filter: List[dict], - similarity: Union[DistanceStrategy, None], - ) -> Dict: - """ - Executes when a search is performed on the store. - - Args: - query_vector: The query vector, - or None if not using vector-based query. - query: The text query, or None if not using text-based query. - k: The total number of results to retrieve. - fetch_k: The number of results to fetch initially. - vector_query_field: The field containing the vector - representations in the index. - text_field: The field containing the text data in the index. - filter: List of filter clauses to apply to the query. - similarity: The similarity strategy to use, or None if not using one. - - Returns: - Dict: The Elasticsearch query body. - """ - - @abstractmethod - def index( - self, - dims_length: Union[int, None], - vector_query_field: str, - text_field: str, - similarity: Union[DistanceStrategy, None], - ) -> Dict: - """ - Executes when the index is created. - - Args: - dims_length: Numeric length of the embedding vectors, - or None if not using vector-based query. - vector_query_field: The field containing the vector - representations in the index. - text_field: The field containing the text data in the index. - similarity: The similarity strategy to use, - or None if not using one. - - Returns: - Dict: The Elasticsearch settings and mappings for the strategy. - """ - - def before_index_setup( - self, client: "Elasticsearch", text_field: str, vector_query_field: str - ) -> None: - """ - Executes before the index is created. Used for setting up - any required Elasticsearch resources like a pipeline. - - Args: - client: The Elasticsearch client. - text_field: The field containing the text data in the index. - vector_query_field: The field containing the vector - representations in the index. - """ - - def require_inference(self) -> bool: - """ - Returns whether or not the strategy requires inference - to be performed on the text before it is added to the index. - - Returns: - bool: Whether or not the strategy requires inference - to be performed on the text before it is added to the index. - """ - return True - - -@deprecated("0.2.0", alternative="DenseVectorStrategy", pending=True) -class ApproxRetrievalStrategy(BaseRetrievalStrategy): - """Approximate retrieval strategy using the `HNSW` algorithm.""" - - def __init__( - self, - query_model_id: Optional[str] = None, - hybrid: Optional[bool] = False, - rrf: Optional[Union[dict, bool]] = True, - ): - self.query_model_id = query_model_id - self.hybrid = hybrid - - # RRF has two optional parameters - # 'rank_constant', 'window_size' - # https://www.elastic.co/guide/en/elasticsearch/reference/current/rrf.html - self.rrf = rrf - - def query( - self, - query_vector: Union[List[float], None], - query: Union[str, None], - k: int, - fetch_k: int, - vector_query_field: str, - text_field: str, - filter: List[dict], - similarity: Union[DistanceStrategy, None], - ) -> Dict: - knn = { - "filter": filter, - "field": vector_query_field, - "k": k, - "num_candidates": fetch_k, - } - - # Embedding provided via the embedding function - if query_vector is not None and not self.query_model_id: - knn["query_vector"] = list(query_vector) - - # Case 2: Used when model has been deployed to - # Elasticsearch and can infer the query vector from the query text - elif query and self.query_model_id: - knn["query_vector_builder"] = { - "text_embedding": { - "model_id": self.query_model_id, # use 'model_id' argument - "model_text": query, # use 'query' argument - } - } - - else: - raise ValueError( - "You must provide an embedding function or a" - " query_model_id to perform a similarity search." - ) - - # If hybrid, add a query to the knn query - # RRF is used to even the score from the knn query and text query - # RRF has two optional parameters: {'rank_constant':int, 'window_size':int} - # https://www.elastic.co/guide/en/elasticsearch/reference/current/rrf.html - if self.hybrid: - query_body = { - "knn": knn, - "query": { - "bool": { - "must": [ - { - "match": { - text_field: { - "query": query, - } - } - } - ], - "filter": filter, - } - }, - } - - if isinstance(self.rrf, dict): - query_body["rank"] = {"rrf": self.rrf} - elif isinstance(self.rrf, bool) and self.rrf is True: - query_body["rank"] = {"rrf": {}} - - return query_body - else: - return {"knn": knn} - - def before_index_setup( - self, client: "Elasticsearch", text_field: str, vector_query_field: str - ) -> None: - if self.query_model_id: - model_must_be_deployed(client, self.query_model_id) - - def index( - self, - dims_length: Union[int, None], - vector_query_field: str, - text_field: str, - similarity: Union[DistanceStrategy, None], - ) -> Dict: - """Create the mapping for the Elasticsearch index.""" - - if similarity is DistanceStrategy.COSINE: - similarityAlgo = "cosine" - elif similarity is DistanceStrategy.EUCLIDEAN_DISTANCE: - similarityAlgo = "l2_norm" - elif similarity is DistanceStrategy.DOT_PRODUCT: - similarityAlgo = "dot_product" - elif similarity is DistanceStrategy.MAX_INNER_PRODUCT: - similarityAlgo = "max_inner_product" - else: - raise ValueError(f"Similarity {similarity} not supported.") - - return { - "mappings": { - "properties": { - vector_query_field: { - "type": "dense_vector", - "dims": dims_length, - "index": True, - "similarity": similarityAlgo, - }, - } - } - } - - -@deprecated("0.2.0", alternative="DenseVectorScriptScoreStrategy", pending=True) -class ExactRetrievalStrategy(BaseRetrievalStrategy): - """Exact retrieval strategy using the `script_score` query.""" - - def query( - self, - query_vector: Union[List[float], None], - query: Union[str, None], - k: int, - fetch_k: int, - vector_query_field: str, - text_field: str, - filter: Union[List[dict], None], - similarity: Union[DistanceStrategy, None], - ) -> Dict: - if similarity is DistanceStrategy.COSINE: - similarityAlgo = ( - f"cosineSimilarity(params.query_vector, '{vector_query_field}') + 1.0" - ) - elif similarity is DistanceStrategy.EUCLIDEAN_DISTANCE: - similarityAlgo = ( - f"1 / (1 + l2norm(params.query_vector, '{vector_query_field}'))" - ) - elif similarity is DistanceStrategy.DOT_PRODUCT: - similarityAlgo = f""" - double value = dotProduct(params.query_vector, '{vector_query_field}'); - return sigmoid(1, Math.E, -value); - """ - else: - raise ValueError(f"Similarity {similarity} not supported.") - - queryBool: Dict = {"match_all": {}} - if filter: - queryBool = {"bool": {"filter": filter}} - - return { - "query": { - "script_score": { - "query": queryBool, - "script": { - "source": similarityAlgo, - "params": {"query_vector": query_vector}, - }, - }, - } - } - - def index( - self, - dims_length: Union[int, None], - vector_query_field: str, - text_field: str, - similarity: Union[DistanceStrategy, None], - ) -> Dict: - """Create the mapping for the Elasticsearch index.""" - - return { - "mappings": { - "properties": { - vector_query_field: { - "type": "dense_vector", - "dims": dims_length, - "index": False, - }, - } - } - } - - -@deprecated("0.2.0", alternative="SparseVectorStrategy", pending=True) -class SparseRetrievalStrategy(BaseRetrievalStrategy): - """Sparse retrieval strategy using the `text_expansion` processor.""" - - def __init__(self, model_id: Optional[str] = None): - self.model_id = model_id or ".elser_model_1" - - def query( - self, - query_vector: Union[List[float], None], - query: Union[str, None], - k: int, - fetch_k: int, - vector_query_field: str, - text_field: str, - filter: List[dict], - similarity: Union[DistanceStrategy, None], - ) -> Dict: - return { - "query": { - "bool": { - "must": [ - { - "text_expansion": { - f"{vector_query_field}.tokens": { - "model_id": self.model_id, - "model_text": query, - } - } - } - ], - "filter": filter, - } - } - } - - def _get_pipeline_name(self) -> str: - return f"{self.model_id}_sparse_embedding" - - def before_index_setup( - self, client: "Elasticsearch", text_field: str, vector_query_field: str - ) -> None: - if self.model_id: - model_must_be_deployed(client, self.model_id) - - # Create a pipeline for the model - client.ingest.put_pipeline( - id=self._get_pipeline_name(), - description="Embedding pipeline for langchain vectorstore", - processors=[ - { - "inference": { - "model_id": self.model_id, - "target_field": vector_query_field, - "field_map": {text_field: "text_field"}, - "inference_config": { - "text_expansion": {"results_field": "tokens"} - }, - } - } - ], - ) - - def index( - self, - dims_length: Union[int, None], - vector_query_field: str, - text_field: str, - similarity: Union[DistanceStrategy, None], - ) -> Dict: - return { - "mappings": { - "properties": { - vector_query_field: { - "properties": {"tokens": {"type": "rank_features"}} - } - } - }, - "settings": {"default_pipeline": self._get_pipeline_name()}, - } - - def require_inference(self) -> bool: - return False - - -@deprecated("0.2.0", alternative="BM25Strategy", pending=True) -class BM25RetrievalStrategy(BaseRetrievalStrategy): - """Retrieval strategy using the native BM25 algorithm of Elasticsearch.""" - - def __init__(self, k1: Union[float, None] = None, b: Union[float, None] = None): - self.k1 = k1 - self.b = b - - def query( - self, - query_vector: Union[List[float], None], - query: Union[str, None], - k: int, - fetch_k: int, - vector_query_field: str, - text_field: str, - filter: List[dict], - similarity: Union[DistanceStrategy, None], - ) -> Dict: - return { - "query": { - "bool": { - "must": [ - { - "match": { - text_field: { - "query": query, - } - }, - }, - ], - "filter": filter, - }, - }, - } - - def index( - self, - dims_length: Union[int, None], - vector_query_field: str, - text_field: str, - similarity: Union[DistanceStrategy, None], - ) -> Dict: - mappings: Dict = { - "properties": { - text_field: { - "type": "text", - "similarity": "custom_bm25", - }, - }, - } - settings: Dict = { - "similarity": { - "custom_bm25": { - "type": "BM25", - }, - }, - } - - if self.k1 is not None: - settings["similarity"]["custom_bm25"]["k1"] = self.k1 - - if self.b is not None: - settings["similarity"]["custom_bm25"]["b"] = self.b - - return {"mappings": mappings, "settings": settings} - - def require_inference(self) -> bool: - return False - - -def _convert_retrieval_strategy( - langchain_strategy: BaseRetrievalStrategy, - distance: Optional[DistanceStrategy] = None, -) -> RetrievalStrategy: - if isinstance(langchain_strategy, ApproxRetrievalStrategy): - if distance is None: - raise ValueError( - "ApproxRetrievalStrategy requires a distance strategy to be provided." - ) - return DenseVectorStrategy( - distance=DistanceMetric[distance], - model_id=langchain_strategy.query_model_id, - hybrid=( - False - if langchain_strategy.hybrid is None - else langchain_strategy.hybrid - ), - rrf=False if langchain_strategy.rrf is None else langchain_strategy.rrf, - ) - elif isinstance(langchain_strategy, ExactRetrievalStrategy): - if distance is None: - raise ValueError( - "ExactRetrievalStrategy requires a distance strategy to be provided." - ) - return DenseVectorScriptScoreStrategy(distance=DistanceMetric[distance]) - elif isinstance(langchain_strategy, SparseRetrievalStrategy): - return SparseVectorStrategy(langchain_strategy.model_id) - elif isinstance(langchain_strategy, BM25RetrievalStrategy): - return BM25Strategy(k1=langchain_strategy.k1, b=langchain_strategy.b) - else: - raise TypeError( - f"Strategy {langchain_strategy} not supported. To provide a " - f"custom strategy, please subclass {RetrievalStrategy}." - ) - - -def _hits_to_docs_scores( - hits: List[Dict[str, Any]], - content_field: str, - fields: Optional[List[str]] = None, - doc_builder: Optional[Callable[[Dict], Document]] = None, -) -> List[Tuple[Document, float]]: - if fields is None: - fields = [] - - documents = [] - - def default_doc_builder(hit: Dict) -> Document: - return Document( - page_content=hit["_source"].get(content_field, ""), - metadata=hit["_source"].get("metadata", {}), - ) - - doc_builder = doc_builder or default_doc_builder - - for hit in hits: - for field in fields: - if "metadata" not in hit["_source"]: - hit["_source"]["metadata"] = {} - if field in hit["_source"] and field not in [ - "metadata", - content_field, - ]: - hit["_source"]["metadata"][field] = hit["_source"][field] - - doc = doc_builder(hit) - documents.append((doc, hit["_score"])) - - return documents - - -class ElasticsearchStore(VectorStore): - """`Elasticsearch` vector store. - - Setup: - Install ``langchain_elasticsearch`` and running the Elasticsearch docker container. - - .. code-block:: bash - - pip install -qU langchain_elasticsearch - docker run -p 9200:9200 \ - -e "discovery.type=single-node" \ - -e "xpack.security.enabled=false" \ - -e "xpack.security.http.ssl.enabled=false" \ - docker.elastic.co/elasticsearch/elasticsearch:8.12.1 - - Key init args — indexing params: - index_name: str - Name of the index to create. - embedding: Embeddings - Embedding function to use. - - Key init args — client params: - es_connection: Optional[Elasticsearch] - Pre-existing Elasticsearch connection. - es_url: Optional[str] - URL of the Elasticsearch instance to connect to. - es_cloud_id: Optional[str] - Cloud ID of the Elasticsearch instance to connect to. - es_user: Optional[str] - Username to use when connecting to Elasticsearch. - es_password: Optional[str] - Password to use when connecting to Elasticsearch. - es_api_key: Optional[str] - API key to use when connecting to Elasticsearch. - - Instantiate: - .. code-block:: python - - from langchain_elasticsearch import ElasticsearchStore - from langchain_openai import OpenAIEmbeddings - - vector_store = ElasticsearchStore( - index_name="langchain-demo", - embedding=OpenAIEmbeddings(), - es_url="http://localhost:9200", - ) - - If you want to use a cloud hosted Elasticsearch instance, you can pass in the - cloud_id argument instead of the es_url argument. - - Instantiate from cloud: - .. code-block:: python - - from langchain_elasticsearch.vectorstores import ElasticsearchStore - from langchain_openai import OpenAIEmbeddings - - store = ElasticsearchStore( - embedding=OpenAIEmbeddings(), - index_name="langchain-demo", - es_cloud_id="" - es_user="elastic", - es_password="" - ) - - You can also connect to an existing Elasticsearch instance by passing in a - pre-existing Elasticsearch connection via the es_connection argument. - - Instantiate from existing connection: - .. code-block:: python - - from langchain_elasticsearch.vectorstores import ElasticsearchStore - from langchain_openai import OpenAIEmbeddings - - from elasticsearch import Elasticsearch - - es_connection = Elasticsearch("http://localhost:9200") - - store = ElasticsearchStore( - embedding=OpenAIEmbeddings(), - index_name="langchain-demo", - es_connection=es_connection - ) - - Add Documents: - .. code-block:: python - - from langchain_core.documents import Document - - document_1 = Document(page_content="foo", metadata={"baz": "bar"}) - document_2 = Document(page_content="thud", metadata={"bar": "baz"}) - document_3 = Document(page_content="i will be deleted :(") - - documents = [document_1, document_2, document_3] - ids = ["1", "2", "3"] - vector_store.add_documents(documents=documents, ids=ids) - - Delete Documents: - .. code-block:: python - - vector_store.delete(ids=["3"]) - - Search: - .. code-block:: python - - results = vector_store.similarity_search(query="thud",k=1) - for doc in results: - print(f"* {doc.page_content} [{doc.metadata}]") - - .. code-block:: python - - * thud [{'bar': 'baz'}] - - Search with filter: - .. code-block:: python - - results = vector_store.similarity_search(query="thud",k=1,filter=[{"term": {"metadata.bar.keyword": "baz"}}]) - for doc in results: - print(f"* {doc.page_content} [{doc.metadata}]") - - .. code-block:: python - - * thud [{'bar': 'baz'}] - - Search with score: - .. code-block:: python - - results = vector_store.similarity_search_with_score(query="qux",k=1) - for doc, score in results: - print(f"* [SIM={score:3f}] {doc.page_content} [{doc.metadata}]") - - .. code-block:: python - - * [SIM=0.916092] foo [{'baz': 'bar'}] - - Async: - .. code-block:: python - - # add documents - # await vector_store.aadd_documents(documents=documents, ids=ids) - - # delete documents - # await vector_store.adelete(ids=["3"]) - - # search - # results = vector_store.asimilarity_search(query="thud",k=1) - - # search with score - results = await vector_store.asimilarity_search_with_score(query="qux",k=1) - for doc,score in results: - print(f"* [SIM={score:3f}] {doc.page_content} [{doc.metadata}]") - - .. code-block:: python - * [SIM=0.916092] foo [{'baz': 'bar'}] - - Use as Retriever: - - .. code-block:: bash - - pip install "elasticsearch[vectorstore_mmr]" - - .. code-block:: python - - retriever = vector_store.as_retriever( - search_type="mmr", - search_kwargs={"k": 1, "fetch_k": 2, "lambda_mult": 0.5}, - ) - retriever.invoke("thud") - - .. code-block:: python - - [Document(metadata={'bar': 'baz'}, page_content='thud')] - - **Advanced Uses:** - - ElasticsearchStore by default uses the ApproxRetrievalStrategy, which uses the - HNSW algorithm to perform approximate nearest neighbor search. This is the - fastest and most memory efficient algorithm. - - If you want to use the Brute force / Exact strategy for searching vectors, you - can pass in the ExactRetrievalStrategy to the ElasticsearchStore constructor. - - Use ExactRetrievalStrategy: - .. code-block:: python - - from langchain_elasticsearch.vectorstores import ElasticsearchStore - from langchain_openai import OpenAIEmbeddings - - store = ElasticsearchStore( - embedding=OpenAIEmbeddings(), - index_name="langchain-demo", - es_url="http://localhost:9200", - strategy=ElasticsearchStore.ExactRetrievalStrategy() - ) - - Both strategies require that you know the similarity metric you want to use - when creating the index. The default is cosine similarity, but you can also - use dot product or euclidean distance. - - Use dot product similarity: - .. code-block:: python - - from langchain_elasticsearch.vectorstores import ElasticsearchStore - from langchain_openai import OpenAIEmbeddings - from langchain_community.vectorstores.utils import DistanceStrategy - - store = ElasticsearchStore( - "langchain-demo", - embedding=OpenAIEmbeddings(), - es_url="http://localhost:9200", - distance_strategy="DOT_PRODUCT" - ) - - """ # noqa: E501 - - def __init__( - self, - index_name: str, - *, - embedding: Optional[Embeddings] = None, - es_connection: Optional[Elasticsearch] = None, - es_url: Optional[str] = None, - es_cloud_id: Optional[str] = None, - es_user: Optional[str] = None, - es_api_key: Optional[str] = None, - es_password: Optional[str] = None, - vector_query_field: str = "vector", - query_field: str = "text", - distance_strategy: Optional[ - Literal[ - DistanceStrategy.COSINE, - DistanceStrategy.DOT_PRODUCT, - DistanceStrategy.EUCLIDEAN_DISTANCE, - DistanceStrategy.MAX_INNER_PRODUCT, - ] - ] = None, - strategy: Union[ - BaseRetrievalStrategy, RetrievalStrategy - ] = ApproxRetrievalStrategy(), - es_params: Optional[Dict[str, Any]] = None, - ): - if isinstance(strategy, BaseRetrievalStrategy): - strategy = _convert_retrieval_strategy( - strategy, distance=distance_strategy or DistanceStrategy.COSINE - ) - - embedding_service = None - if embedding: - embedding_service = EmbeddingServiceAdapter(embedding) - - if not es_connection: - es_connection = create_elasticsearch_client( - url=es_url, - cloud_id=es_cloud_id, - api_key=es_api_key, - username=es_user, - password=es_password, - params=es_params, - ) - - self._store = EVectorStore( - client=es_connection, - index=index_name, - retrieval_strategy=strategy, - embedding_service=embedding_service, - text_field=query_field, - vector_field=vector_query_field, - user_agent=user_agent("langchain-py-vs"), - ) - - self.embedding = embedding - self.client = self._store.client - self._embedding_service = embedding_service - self.query_field = query_field - self.vector_query_field = vector_query_field - - def close(self) -> None: - self._store.close() - - @property - def embeddings(self) -> Optional[Embeddings]: - return self.embedding - - @staticmethod - def connect_to_elasticsearch( - *, - es_url: Optional[str] = None, - cloud_id: Optional[str] = None, - api_key: Optional[str] = None, - username: Optional[str] = None, - password: Optional[str] = None, - es_params: Optional[Dict[str, Any]] = None, - ) -> Elasticsearch: - return create_elasticsearch_client( - url=es_url, - cloud_id=cloud_id, - api_key=api_key, - username=username, - password=password, - params=es_params, - ) - - def similarity_search( - self, - query: str, - k: int = 4, - fetch_k: int = 50, - filter: Optional[List[dict]] = None, - *, - custom_query: Optional[ - Callable[[Dict[str, Any], Optional[str]], Dict[str, Any]] - ] = None, - doc_builder: Optional[Callable[[Dict], Document]] = None, - **kwargs: Any, - ) -> List[Document]: - """Return Elasticsearch documents most similar to query. - - Args: - query: Text to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - fetch_k (int): Number of Documents to fetch to pass to knn num_candidates. - filter: Array of Elasticsearch filter clauses to apply to the query. - - Returns: - List of Documents most similar to the query, - in descending order of similarity. - """ - hits = self._store.search( - query=query, - k=k, - num_candidates=fetch_k, - filter=filter, - custom_query=custom_query, - ) - docs = _hits_to_docs_scores( - hits=hits, - content_field=self.query_field, - doc_builder=doc_builder, - ) - return [doc for doc, _score in docs] - - def max_marginal_relevance_search( - self, - query: str, - k: int = 4, - fetch_k: int = 20, - lambda_mult: float = 0.5, - fields: Optional[List[str]] = None, - *, - custom_query: Optional[ - Callable[[Dict[str, Any], Optional[str]], Dict[str, Any]] - ] = None, - doc_builder: Optional[Callable[[Dict], Document]] = None, - **kwargs: Any, - ) -> List[Document]: - """Return docs selected using the maximal marginal relevance. - - Maximal marginal relevance optimizes for similarity to query AND diversity - among selected documents. - - Args: - query (str): Text to look up documents similar to. - k (int): Number of Documents to return. Defaults to 4. - fetch_k (int): Number of Documents to fetch to pass to MMR algorithm. - lambda_mult (float): Number between 0 and 1 that determines the degree - of diversity among the results with 0 corresponding - to maximum diversity and 1 to minimum diversity. - Defaults to 0.5. - fields: Other fields to get from elasticsearch source. These fields - will be added to the document metadata. - - Returns: - List[Document]: A list of Documents selected by maximal marginal relevance. - """ - if self._embedding_service is None: - raise ValueError( - "maximal marginal relevance search requires an embedding service." - ) - - hits = self._store.max_marginal_relevance_search( - embedding_service=self._embedding_service, - query=query, - vector_field=self.vector_query_field, - k=k, - num_candidates=fetch_k, - lambda_mult=lambda_mult, - fields=fields, - custom_query=custom_query, - ) - - docs_scores = _hits_to_docs_scores( - hits=hits, - content_field=self.query_field, - fields=fields, - doc_builder=doc_builder, - ) - - return [doc for doc, _score in docs_scores] - - @staticmethod - def _identity_fn(score: float) -> float: - return score - - def _select_relevance_score_fn(self) -> Callable[[float], float]: - """ - The 'correct' relevance function - may differ depending on a few things, including: - - the distance / similarity metric used by the VectorStore - - the scale of your embeddings (OpenAI's are unit normed. Many others are not!) - - embedding dimensionality - - etc. - - Vectorstores should define their own selection based method of relevance. - """ - # All scores from Elasticsearch are already normalized similarities: - # https://www.elastic.co/guide/en/elasticsearch/reference/current/dense-vector.html#dense-vector-params - return self._identity_fn - - def similarity_search_with_score( - self, - query: str, - k: int = 4, - filter: Optional[List[dict]] = None, - *, - custom_query: Optional[ - Callable[[Dict[str, Any], Optional[str]], Dict[str, Any]] - ] = None, - doc_builder: Optional[Callable[[Dict], Document]] = None, - **kwargs: Any, - ) -> List[Tuple[Document, float]]: - """Return Elasticsearch documents most similar to query, along with scores. - - Args: - query: Text to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - filter: Array of Elasticsearch filter clauses to apply to the query. - - Returns: - List of Documents most similar to the query and score for each - """ - if ( - isinstance(self._store.retrieval_strategy, DenseVectorStrategy) - and self._store.retrieval_strategy.hybrid - ): - raise ValueError("scores are currently not supported in hybrid mode") - - hits = self._store.search( - query=query, k=k, filter=filter, custom_query=custom_query - ) - return _hits_to_docs_scores( - hits=hits, - content_field=self.query_field, - doc_builder=doc_builder, - ) - - def similarity_search_by_vector_with_relevance_scores( - self, - embedding: List[float], - k: int = 4, - filter: Optional[List[Dict]] = None, - *, - custom_query: Optional[ - Callable[[Dict[str, Any], Optional[str]], Dict[str, Any]] - ] = None, - doc_builder: Optional[Callable[[Dict], Document]] = None, - **kwargs: Any, - ) -> List[Tuple[Document, float]]: - """Return Elasticsearch documents most similar to query, along with scores. - - Args: - embedding: Embedding to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - filter: Array of Elasticsearch filter clauses to apply to the query. - - Returns: - List of Documents most similar to the embedding and score for each - """ - if ( - isinstance(self._store.retrieval_strategy, DenseVectorStrategy) - and self._store.retrieval_strategy.hybrid - ): - raise ValueError("scores are currently not supported in hybrid mode") - - hits = self._store.search( - query=None, - query_vector=embedding, - k=k, - filter=filter, - custom_query=custom_query, - ) - return _hits_to_docs_scores( - hits=hits, - content_field=self.query_field, - doc_builder=doc_builder, - ) - - def delete( - self, - ids: Optional[List[str]] = None, - refresh_indices: Optional[bool] = True, - **kwargs: Any, - ) -> Optional[bool]: - """Delete documents from the Elasticsearch index. - - Args: - ids: List of ids of documents to delete. - refresh_indices: Whether to refresh the index - after deleting documents. Defaults to True. - """ - if ids is None: - raise ValueError("please specify some IDs") - - return self._store.delete(ids=ids, refresh_indices=refresh_indices or False) - - def add_texts( - self, - texts: Iterable[str], - metadatas: Optional[List[Dict[Any, Any]]] = None, - ids: Optional[List[str]] = None, - refresh_indices: bool = True, - create_index_if_not_exists: bool = True, - bulk_kwargs: Optional[Dict] = None, - **kwargs: Any, - ) -> List[str]: - """Run more texts through the embeddings and add to the store. - - Args: - texts: Iterable of strings to add to the store. - metadatas: Optional list of metadatas associated with the texts. - ids: Optional list of ids to associate with the texts. - refresh_indices: Whether to refresh the Elasticsearch indices - after adding the texts. - create_index_if_not_exists: Whether to create the Elasticsearch - index if it doesn't already exist. - *bulk_kwargs: Additional arguments to pass to Elasticsearch bulk. - - chunk_size: Optional. Number of texts to add to the - index at a time. Defaults to 500. - - Returns: - List of ids from adding the texts into the store. - """ - return self._store.add_texts( - texts=list(texts), - metadatas=metadatas, - ids=ids, - refresh_indices=refresh_indices, - create_index_if_not_exists=create_index_if_not_exists, - bulk_kwargs=bulk_kwargs, - ) - - def add_embeddings( - self, - text_embeddings: Iterable[Tuple[str, List[float]]], - metadatas: Optional[List[dict]] = None, - ids: Optional[List[str]] = None, - refresh_indices: bool = True, - create_index_if_not_exists: bool = True, - bulk_kwargs: Optional[Dict] = None, - **kwargs: Any, - ) -> List[str]: - """Add the given texts and embeddings to the store. - - Args: - text_embeddings: Iterable pairs of string and embedding to - add to the store. - metadatas: Optional list of metadatas associated with the texts. - ids: Optional list of unique IDs. - refresh_indices: Whether to refresh the Elasticsearch indices - after adding the texts. - create_index_if_not_exists: Whether to create the Elasticsearch - index if it doesn't already exist. - *bulk_kwargs: Additional arguments to pass to Elasticsearch bulk. - - chunk_size: Optional. Number of texts to add to the - index at a time. Defaults to 500. - - Returns: - List of ids from adding the texts into the store. - """ - texts, embeddings = zip(*text_embeddings) - return self._store.add_texts( - texts=list(texts), - metadatas=metadatas, - vectors=list(embeddings), - ids=ids, - refresh_indices=refresh_indices, - create_index_if_not_exists=create_index_if_not_exists, - bulk_kwargs=bulk_kwargs, - ) +# langchain defines some sync methods as abstract in its base class +# so we have to add dummy methods for them, even though we only use the async versions +class AsyncElasticsearchStore(_AsyncElasticsearchStore): @classmethod def from_texts( cls, - texts: List[str], - embedding: Optional[Embeddings] = None, - metadatas: Optional[List[Dict[str, Any]]] = None, - bulk_kwargs: Optional[Dict] = None, - **kwargs: Any, - ) -> "ElasticsearchStore": - """Construct ElasticsearchStore wrapper from raw documents. - - Example: - .. code-block:: python - - from langchain_elasticsearch.vectorstores import ElasticsearchStore - from langchain_openai import OpenAIEmbeddings - - db = ElasticsearchStore.from_texts( - texts, - // embeddings optional if using - // a strategy that doesn't require inference - embeddings, - index_name="langchain-demo", - es_url="http://localhost:9200" - ) - - Args: - texts: List of texts to add to the Elasticsearch index. - embedding: Embedding function to use to embed the texts. - metadatas: Optional list of metadatas associated with the texts. - index_name: Name of the Elasticsearch index to create. - es_url: URL of the Elasticsearch instance to connect to. - cloud_id: Cloud ID of the Elasticsearch instance to connect to. - es_user: Username to use when connecting to Elasticsearch. - es_password: Password to use when connecting to Elasticsearch. - es_api_key: API key to use when connecting to Elasticsearch. - es_connection: Optional pre-existing Elasticsearch connection. - vector_query_field: Optional. Name of the field to - store the embedding vectors in. - query_field: Optional. Name of the field to store the texts in. - distance_strategy: Optional. Name of the distance - strategy to use. Defaults to "COSINE". - can be one of "COSINE", - "EUCLIDEAN_DISTANCE", "DOT_PRODUCT", - "MAX_INNER_PRODUCT". - bulk_kwargs: Optional. Additional arguments to pass to - Elasticsearch bulk. - """ - - index_name = kwargs.get("index_name") - if index_name is None: - raise ValueError("Please provide an index_name.") - - elasticsearchStore = ElasticsearchStore(embedding=embedding, **kwargs) - - # Encode the provided texts and add them to the newly created index. - elasticsearchStore.add_texts( - texts=texts, metadatas=metadatas, bulk_kwargs=bulk_kwargs - ) - - return elasticsearchStore - - @classmethod - def from_documents( - cls, - documents: List[Document], - embedding: Optional[Embeddings] = None, - bulk_kwargs: Optional[Dict] = None, + texts: list[str], + embedding: Embeddings, + metadatas: Optional[list[dict]] = None, + *, + ids: Optional[list[str]] = None, **kwargs: Any, - ) -> "ElasticsearchStore": - """Construct ElasticsearchStore wrapper from documents. - - Example: - .. code-block:: python - - from langchain_elasticsearch.vectorstores import ElasticsearchStore - from langchain_openai import OpenAIEmbeddings - - db = ElasticsearchStore.from_documents( - texts, - embeddings, - index_name="langchain-demo", - es_url="http://localhost:9200" - ) - - Args: - texts: List of texts to add to the Elasticsearch index. - embedding: Embedding function to use to embed the texts. - Do not provide if using a strategy - that doesn't require inference. - metadatas: Optional list of metadatas associated with the texts. - index_name: Name of the Elasticsearch index to create. - es_url: URL of the Elasticsearch instance to connect to. - cloud_id: Cloud ID of the Elasticsearch instance to connect to. - es_user: Username to use when connecting to Elasticsearch. - es_password: Password to use when connecting to Elasticsearch. - es_api_key: API key to use when connecting to Elasticsearch. - es_connection: Optional pre-existing Elasticsearch connection. - vector_query_field: Optional. Name of the field - to store the embedding vectors in. - query_field: Optional. Name of the field to store the texts in. - bulk_kwargs: Optional. Additional arguments to pass to - Elasticsearch bulk. - """ - - index_name = kwargs.get("index_name") - if index_name is None: - raise ValueError("Please provide an index_name.") - - elasticsearchStore = ElasticsearchStore(embedding=embedding, **kwargs) + ) -> "AsyncElasticsearchStore": + raise NotImplementedError("This class is asynchronous, use afrom_texts()") - # Encode the provided texts and add them to the newly created index. - elasticsearchStore.add_documents(documents, bulk_kwargs=bulk_kwargs) - - return elasticsearchStore - - @staticmethod - def ExactRetrievalStrategy() -> "ExactRetrievalStrategy": - """Used to perform brute force / exact - nearest neighbor search via script_score.""" - return ExactRetrievalStrategy() - - @staticmethod - def ApproxRetrievalStrategy( - query_model_id: Optional[str] = None, - hybrid: Optional[bool] = False, - rrf: Optional[Union[dict, bool]] = True, - ) -> "ApproxRetrievalStrategy": - """Used to perform approximate nearest neighbor search - using the HNSW algorithm. - - At build index time, this strategy will create a - dense vector field in the index and store the - embedding vectors in the index. - - At query time, the text will either be embedded using the - provided embedding function or the query_model_id - will be used to embed the text using the model - deployed to Elasticsearch. - - if query_model_id is used, do not provide an embedding function. - - Args: - query_model_id: Optional. ID of the model to use to - embed the query text within the stack. Requires - embedding model to be deployed to Elasticsearch. - hybrid: Optional. If True, will perform a hybrid search - using both the knn query and a text query. - Defaults to False. - rrf: Optional. rrf is Reciprocal Rank Fusion. - When `hybrid` is True, - and `rrf` is True, then rrf: {}. - and `rrf` is False, then rrf is omitted. - and isinstance(rrf, dict) is True, then pass in the dict values. - rrf could be passed for adjusting 'rank_constant' and 'window_size'. - """ - return ApproxRetrievalStrategy( - query_model_id=query_model_id, hybrid=hybrid, rrf=rrf + def similarity_search( + self, query: str, k: int = 4, **kwargs: Any + ) -> list[Document]: + raise NotImplementedError( + "This class is asynchronous, use asimilarity_search()" ) - - @staticmethod - def SparseVectorRetrievalStrategy( - model_id: Optional[str] = None, - ) -> "SparseRetrievalStrategy": - """Used to perform sparse vector search via text_expansion. - Used for when you want to use ELSER model to perform document search. - - At build index time, this strategy will create a pipeline that - will embed the text using the ELSER model and store the - resulting tokens in the index. - - At query time, the text will be embedded using the ELSER - model and the resulting tokens will be used to - perform a text_expansion query. - - Args: - model_id: Optional. Default is ".elser_model_1". - ID of the model to use to embed the query text - within the stack. Requires embedding model to be - deployed to Elasticsearch. - """ - return SparseRetrievalStrategy(model_id=model_id) - - @staticmethod - def BM25RetrievalStrategy( - k1: Union[float, None] = None, b: Union[float, None] = None - ) -> "BM25RetrievalStrategy": - """Used to apply BM25 without vector search. - - Args: - k1: Optional. This corresponds to the BM25 parameter, k1. Default is None, - which uses the default setting of Elasticsearch. - b: Optional. This corresponds to the BM25 parameter, b. Default is None, - which uses the default setting of Elasticsearch. - """ - return BM25RetrievalStrategy(k1=k1, b=b) diff --git a/libs/elasticsearch/pyproject.toml b/libs/elasticsearch/pyproject.toml index fef8e48..9838ff3 100644 --- a/libs/elasticsearch/pyproject.toml +++ b/libs/elasticsearch/pyproject.toml @@ -91,5 +91,6 @@ addopts = "--snapshot-warn-unused --strict-markers --strict-config --durations=5 markers = [ "requires: mark tests as requiring a specific library", "asyncio: mark tests as requiring asyncio", + "sync: mark tests as performing I/O without asyncio", ] -asyncio_mode = "auto" \ No newline at end of file +asyncio_mode = "auto" diff --git a/libs/elasticsearch/scripts/run_unasync.py b/libs/elasticsearch/scripts/run_unasync.py new file mode 100644 index 0000000..d708b76 --- /dev/null +++ b/libs/elasticsearch/scripts/run_unasync.py @@ -0,0 +1,149 @@ +import os +import subprocess +import sys +from glob import glob +from pathlib import Path + +import unasync + + +def main(check=False): + # the list of directories that need to be processed with unasync + # each entry has two paths: + # - the source path with the async sources + # - the destination path where the sync sources should be written + source_dirs = [ + ( + "langchain_elasticsearch/_async/", + "langchain_elasticsearch/_sync/", + ), + ("tests/_async/", "tests/_sync/"), + ("tests/integration_tests/_async/", "tests/integration_tests/_sync/"), + ("tests/unit_tests/_async/", "tests/unit_tests/_sync/"), + ] + + # Unasync all the generated async code + additional_replacements = { + "_async": "_sync", + "AsyncElasticsearch": "Elasticsearch", + "AsyncTransport": "Transport", + "AsyncBM25Strategy": "BM25Strategy", + "AsyncDenseVectorScriptScoreStrategy": "DenseVectorScriptScoreStrategy", + "AsyncDenseVectorStrategy": "DenseVectorStrategy", + "AsyncRetrievalStrategy": "RetrievalStrategy", + "AsyncSparseVectorStrategy": "SparseVectorStrategy", + "AsyncVectorStore": "VectorStore", + "AsyncElasticsearchStore": "ElasticsearchStore", + "AsyncElasticsearchEmbeddings": "ElasticsearchEmbeddings", + "AsyncElasticsearchEmbeddingsCache": "ElasticsearchEmbeddingsCache", + "AsyncEmbeddingServiceAdapter": "EmbeddingServiceAdapter", + "AsyncEmbeddingService": "EmbeddingService", + "AsyncElasticsearchRetriever": "ElasticsearchRetriever", + "AsyncElasticsearchCache": "ElasticsearchCache", + "AsyncElasticsearchChatMessageHistory": "ElasticsearchChatMessageHistory", + "AsyncCallbackManagerForRetrieverRun": "CallbackManagerForRetrieverRun", + "AsyncFakeEmbeddings": "FakeEmbeddings", + "AsyncConsistentFakeEmbeddings": "ConsistentFakeEmbeddings", + "AsyncRequestSavingTransport": "RequestSavingTransport", + "AsyncMock": "Mock", + "Embeddings": "Embeddings", + "AsyncGenerator": "Generator", + "AsyncIterator": "Iterator", + "create_async_elasticsearch_client": "create_elasticsearch_client", + "aadd_texts": "add_texts", + "aadd_embeddings": "add_embeddings", + "aadd_documents": "add_documents", + "afrom_texts": "from_texts", + "afrom_documents": "from_documents", + "amax_marginal_relevance_search": "max_marginal_relevance_search", + "asimilarity_search": "similarity_search", + "asimilarity_search_by_vector_with_relevance_scores": "similarity_search_by_vector_with_relevance_scores", # noqa: E501 + "asimilarity_search_with_score": "similarity_search_with_score", + "asimilarity_search_with_relevance_scores": "similarity_search_with_relevance_scores", # noqa: E501 + "adelete": "delete", + "aclose": "close", + "ainvoke": "invoke", + "aembed_documents": "embed_documents", + "aembed_query": "embed_query", + "_aget_relevant_documents": "_get_relevant_documents", + "aget_relevant_documents": "get_relevant_documents", + "alookup": "lookup", + "aupdate": "update", + "aclear": "clear", + "amget": "mget", + "amset": "mset", + "amdelete": "mdelete", + "ayield_keys": "yield_keys", + "asearch": "search", + "aget_messages": "get_messages", + "aadd_messages": "add_messages", + "aadd_message": "add_message", + "aencode_vector": "encode_vector", + "assert_awaited_with": "assert_called_with", + "async_es_client_fx": "es_client_fx", + "async_es_embeddings_cache_fx": "es_embeddings_cache_fx", + "async_es_cache_fx": "es_cache_fx", + "async_bulk": "bulk", + "async_with_user_agent_header": "with_user_agent_header", + "asyncio": "sync", + } + rules = [ + unasync.Rule( + fromdir=dir[0], + todir=f"{dir[0]}_sync_check/" if check else dir[1], + additional_replacements=additional_replacements, + ) + for dir in source_dirs + ] + + filepaths = [] + for root, _, filenames in os.walk(Path(__file__).absolute().parent.parent): + if "/site-packages" in root or "/." in root or "__pycache__" in root: + continue + for filename in filenames: + if filename.rpartition(".")[-1] in ( + "py", + "pyi", + ) and not filename.startswith("utils.py"): + filepaths.append(os.path.join(root, filename)) + + unasync.unasync_files(filepaths, rules) + for dir in source_dirs: + output_dir = f"{dir[0]}_sync_check/" if check else dir[1] + subprocess.check_call(["ruff", "format", "--target-version=py38", output_dir]) + subprocess.check_call(["isort", output_dir]) + for file in glob("*.py", root_dir=dir[0]): + subprocess.check_call( + [ + "sed", + "-i.bak", + "s/pytest.mark.asyncio/pytest.mark.sync/", + f"{output_dir}{file}", + ] + ) + subprocess.check_call( + [ + "sed", + "-i.bak", + "s/get_messages()/messages/", + f"{output_dir}{file}", + ] + ) + subprocess.check_call(["rm", f"{output_dir}{file}.bak"]) + + if check: + # make sure there are no differences between _sync and _sync_check + subprocess.check_call( + [ + "diff", + f"{dir[1]}{file}", + f"{output_dir}{file}", + ] + ) + + if check: + subprocess.check_call(["rm", "-rf", output_dir]) + + +if __name__ == "__main__": + main(check="--check" in sys.argv) diff --git a/libs/elasticsearch/tests/_async/fake_embeddings.py b/libs/elasticsearch/tests/_async/fake_embeddings.py new file mode 100644 index 0000000..af5c1cf --- /dev/null +++ b/libs/elasticsearch/tests/_async/fake_embeddings.py @@ -0,0 +1,49 @@ +"""Fake Embedding class for testing purposes.""" + +from typing import List + +from langchain_core.embeddings import Embeddings + +fake_texts = ["foo", "bar", "baz"] + + +class AsyncFakeEmbeddings(Embeddings): + """Fake embeddings functionality for testing.""" + + async def aembed_documents(self, texts: List[str]) -> List[List[float]]: + """Return simple embeddings. + Embeddings encode each text as its index.""" + return [[float(1.0)] * 9 + [float(i)] for i in range(len(texts))] + + async def aembed_query(self, text: str) -> List[float]: + """Return constant query embeddings. + Embeddings are identical to embed_documents(texts)[0]. + Distance to each text will be that text's index, + as it was passed to embed_documents.""" + return [float(1.0)] * 9 + [float(0.0)] + + +class AsyncConsistentFakeEmbeddings(AsyncFakeEmbeddings): + """Fake embeddings which remember all the texts seen so far to return consistent + vectors for the same texts.""" + + def __init__(self, dimensionality: int = 10) -> None: + self.known_texts: List[str] = [] + self.dimensionality = dimensionality + + async def aembed_documents(self, texts: List[str]) -> List[List[float]]: + """Return consistent embeddings for each text seen so far.""" + out_vectors = [] + for text in texts: + if text not in self.known_texts: + self.known_texts.append(text) + vector = [float(1.0)] * (self.dimensionality - 1) + [ + float(self.known_texts.index(text)) + ] + out_vectors.append(vector) + return out_vectors + + async def aembed_query(self, text: str) -> List[float]: + """Return consistent embeddings for the text, if seen before, or a constant + one if the text is unknown.""" + return (await self.aembed_documents([text]))[0] diff --git a/libs/elasticsearch/tests/_sync/fake_embeddings.py b/libs/elasticsearch/tests/_sync/fake_embeddings.py new file mode 100644 index 0000000..832403e --- /dev/null +++ b/libs/elasticsearch/tests/_sync/fake_embeddings.py @@ -0,0 +1,49 @@ +"""Fake Embedding class for testing purposes.""" + +from typing import List + +from langchain_core.embeddings import Embeddings + +fake_texts = ["foo", "bar", "baz"] + + +class FakeEmbeddings(Embeddings): + """Fake embeddings functionality for testing.""" + + def embed_documents(self, texts: List[str]) -> List[List[float]]: + """Return simple embeddings. + Embeddings encode each text as its index.""" + return [[float(1.0)] * 9 + [float(i)] for i in range(len(texts))] + + def embed_query(self, text: str) -> List[float]: + """Return constant query embeddings. + Embeddings are identical to embed_documents(texts)[0]. + Distance to each text will be that text's index, + as it was passed to embed_documents.""" + return [float(1.0)] * 9 + [float(0.0)] + + +class ConsistentFakeEmbeddings(FakeEmbeddings): + """Fake embeddings which remember all the texts seen so far to return consistent + vectors for the same texts.""" + + def __init__(self, dimensionality: int = 10) -> None: + self.known_texts: List[str] = [] + self.dimensionality = dimensionality + + def embed_documents(self, texts: List[str]) -> List[List[float]]: + """Return consistent embeddings for each text seen so far.""" + out_vectors = [] + for text in texts: + if text not in self.known_texts: + self.known_texts.append(text) + vector = [float(1.0)] * (self.dimensionality - 1) + [ + float(self.known_texts.index(text)) + ] + out_vectors.append(vector) + return out_vectors + + def embed_query(self, text: str) -> List[float]: + """Return consistent embeddings for the text, if seen before, or a constant + one if the text is unknown.""" + return (self.embed_documents([text]))[0] diff --git a/libs/elasticsearch/tests/conftest.py b/libs/elasticsearch/tests/conftest.py index 6d91265..cb3febf 100644 --- a/libs/elasticsearch/tests/conftest.py +++ b/libs/elasticsearch/tests/conftest.py @@ -1,21 +1,39 @@ from typing import Generator from unittest import mock -from unittest.mock import MagicMock +from unittest.mock import AsyncMock, MagicMock import pytest -from elasticsearch import Elasticsearch +from elasticsearch import AsyncElasticsearch, Elasticsearch +from elasticsearch._async.client import IndicesClient as AsyncIndicesClient from elasticsearch._sync.client import IndicesClient from langchain_community.chat_models.fake import FakeMessagesListChatModel from langchain_core.language_models import BaseChatModel from langchain_core.messages import AIMessage -from langchain_elasticsearch import ElasticsearchCache, ElasticsearchEmbeddingsCache +from langchain_elasticsearch import ( + AsyncElasticsearchCache, + AsyncElasticsearchEmbeddingsCache, + ElasticsearchCache, + ElasticsearchEmbeddingsCache, +) @pytest.fixture def es_client_fx() -> Generator[MagicMock, None, None]: client_mock = MagicMock(spec=Elasticsearch) - client_mock.indices = MagicMock(spec=IndicesClient) + client_mock.return_value.indices = MagicMock(spec=IndicesClient) + yield client_mock() + + +@pytest.fixture +def async_es_client_fx() -> Generator[MagicMock, None, None]: + client_mock = MagicMock(spec=AsyncElasticsearch) + client_mock.return_value.indices = MagicMock(spec=AsyncIndicesClient) + # coroutines need to be mocked explicitly + client_mock.return_value.indices.exists_alias = AsyncMock() + client_mock.return_value.indices.put_mapping = AsyncMock() + client_mock.return_value.indices.exists = AsyncMock() + client_mock.return_value.indices.create = AsyncMock() yield client_mock() @@ -24,7 +42,7 @@ def es_embeddings_cache_fx( es_client_fx: MagicMock, ) -> Generator[ElasticsearchEmbeddingsCache, None, None]: with mock.patch( - "langchain_elasticsearch.cache.create_elasticsearch_client", + "langchain_elasticsearch._sync.cache.create_elasticsearch_client", return_value=es_client_fx, ): yield ElasticsearchEmbeddingsCache( @@ -37,9 +55,28 @@ def es_embeddings_cache_fx( @pytest.fixture -def es_cache_fx(es_client_fx: MagicMock) -> Generator[ElasticsearchCache, None, None]: +def async_es_embeddings_cache_fx( + async_es_client_fx: MagicMock, +) -> Generator[AsyncElasticsearchEmbeddingsCache, None, None]: with mock.patch( - "langchain_elasticsearch.cache.create_elasticsearch_client", + "langchain_elasticsearch._async.cache.create_async_elasticsearch_client", + return_value=async_es_client_fx, + ): + yield AsyncElasticsearchEmbeddingsCache( + es_url="http://localhost:9200", + index_name="test_index", + store_input=True, + namespace="test", + metadata={"project": "test_project"}, + ) + + +@pytest.fixture +def es_cache_fx( + es_client_fx: MagicMock, +) -> Generator[ElasticsearchCache, None, None]: + with mock.patch( + "langchain_elasticsearch._sync.cache.create_elasticsearch_client", return_value=es_client_fx, ): yield ElasticsearchCache( @@ -51,6 +88,23 @@ def es_cache_fx(es_client_fx: MagicMock) -> Generator[ElasticsearchCache, None, ) +@pytest.fixture +def async_es_cache_fx( + async_es_client_fx: MagicMock, +) -> Generator[AsyncElasticsearchCache, None, None]: + with mock.patch( + "langchain_elasticsearch._async.cache.create_async_elasticsearch_client", + return_value=async_es_client_fx, + ): + yield AsyncElasticsearchCache( + es_url="http://localhost:30096", + index_name="test_index", + store_input=True, + store_input_params=True, + metadata={"project": "test_project"}, + ) + + @pytest.fixture def fake_chat_fx() -> Generator[BaseChatModel, None, None]: yield FakeMessagesListChatModel( diff --git a/libs/elasticsearch/tests/fake_embeddings.py b/libs/elasticsearch/tests/fake_embeddings.py index 9623dd7..099eb94 100644 --- a/libs/elasticsearch/tests/fake_embeddings.py +++ b/libs/elasticsearch/tests/fake_embeddings.py @@ -2,48 +2,29 @@ from typing import List -from langchain_core.embeddings import Embeddings - -fake_texts = ["foo", "bar", "baz"] - - -class FakeEmbeddings(Embeddings): - """Fake embeddings functionality for testing.""" - +from ._async.fake_embeddings import ( + AsyncConsistentFakeEmbeddings as _AsyncConsistentFakeEmbeddings, +) +from ._async.fake_embeddings import AsyncFakeEmbeddings as _AsyncFakeEmbeddings +from ._sync.fake_embeddings import ( # noqa: F401 + ConsistentFakeEmbeddings, + FakeEmbeddings, +) + + +# langchain defines embed_documents and embed_query as abstract in its base class +# so we have to add dummy methods for them, even though we only use the async versions +class AsyncFakeEmbeddings(_AsyncFakeEmbeddings): def embed_documents(self, texts: List[str]) -> List[List[float]]: - """Return simple embeddings. - Embeddings encode each text as its index.""" - return [[float(1.0)] * 9 + [float(i)] for i in range(len(texts))] + raise NotImplementedError("This class is asynchronous, use aembed_documents()") def embed_query(self, text: str) -> List[float]: - """Return constant query embeddings. - Embeddings are identical to embed_documents(texts)[0]. - Distance to each text will be that text's index, - as it was passed to embed_documents.""" - return [float(1.0)] * 9 + [float(0.0)] - - -class ConsistentFakeEmbeddings(FakeEmbeddings): - """Fake embeddings which remember all the texts seen so far to return consistent - vectors for the same texts.""" + raise NotImplementedError("This class is asynchronous, use aembed_query()") - def __init__(self, dimensionality: int = 10) -> None: - self.known_texts: List[str] = [] - self.dimensionality = dimensionality +class AsyncConsistentFakeEmbeddings(_AsyncConsistentFakeEmbeddings): def embed_documents(self, texts: List[str]) -> List[List[float]]: - """Return consistent embeddings for each text seen so far.""" - out_vectors = [] - for text in texts: - if text not in self.known_texts: - self.known_texts.append(text) - vector = [float(1.0)] * (self.dimensionality - 1) + [ - float(self.known_texts.index(text)) - ] - out_vectors.append(vector) - return out_vectors + raise NotImplementedError("This class is asynchronous, use aembed_documents()") def embed_query(self, text: str) -> List[float]: - """Return consistent embeddings for the text, if seen before, or a constant - one if the text is unknown.""" - return self.embed_documents([text])[0] + raise NotImplementedError("This class is asynchronous, use aembed_query()") diff --git a/libs/elasticsearch/tests/integration_tests/_async/__init__.py b/libs/elasticsearch/tests/integration_tests/_async/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/libs/elasticsearch/tests/integration_tests/_async/_test_utilities.py b/libs/elasticsearch/tests/integration_tests/_async/_test_utilities.py new file mode 100644 index 0000000..d8aa634 --- /dev/null +++ b/libs/elasticsearch/tests/integration_tests/_async/_test_utilities.py @@ -0,0 +1,77 @@ +import os +from typing import Any, Dict, List, Optional + +from elastic_transport import AsyncTransport +from elasticsearch import ( + AsyncElasticsearch, + BadRequestError, + ConflictError, + NotFoundError, +) + + +def read_env() -> Dict: + url = os.environ.get("ES_URL", "http://localhost:9200") + cloud_id = os.environ.get("ES_CLOUD_ID") + api_key = os.environ.get("ES_API_KEY") + + if cloud_id: + return {"es_cloud_id": cloud_id, "es_api_key": api_key} + return {"es_url": url} + + +class AsyncRequestSavingTransport(AsyncTransport): + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self.requests: List[Dict] = [] + + async def perform_request(self, *args, **kwargs): # type: ignore + self.requests.append(kwargs) + return await super().perform_request(*args, **kwargs) + + +def create_es_client( + es_params: Optional[Dict[str, str]] = None, + es_kwargs: Dict = {}, +) -> AsyncElasticsearch: + if es_params is None: + es_params = read_env() + if not es_kwargs: + es_kwargs = {} + + if "es_cloud_id" in es_params: + return AsyncElasticsearch( + cloud_id=es_params["es_cloud_id"], + api_key=es_params["es_api_key"], + **es_kwargs, + ) + + return AsyncElasticsearch(hosts=[es_params["es_url"]], **es_kwargs) + + +def requests_saving_es_client() -> AsyncElasticsearch: + return create_es_client(es_kwargs={"transport_class": AsyncRequestSavingTransport}) + + +async def clear_test_indices(es: AsyncElasticsearch) -> None: + index_names_response = await es.indices.get(index="_all") + index_names = index_names_response.keys() + for index_name in index_names: + if index_name.startswith("test_"): + await es.indices.delete(index=index_name) + await es.indices.refresh(index="_all") + + +async def model_is_deployed(client: AsyncElasticsearch, model_id: str) -> bool: + try: + dummy = {"x": "y"} + await client.ml.infer_trained_model(model_id=model_id, docs=[dummy]) + return True + except NotFoundError: + return False + except ConflictError: + return False + except BadRequestError: + # This error is expected because we do not know the expected document + # shape and just use a dummy doc above. + return True diff --git a/libs/elasticsearch/tests/integration_tests/_async/test_cache.py b/libs/elasticsearch/tests/integration_tests/_async/test_cache.py new file mode 100644 index 0000000..8feda80 --- /dev/null +++ b/libs/elasticsearch/tests/integration_tests/_async/test_cache.py @@ -0,0 +1,299 @@ +from typing import AsyncGenerator, Dict, Union + +import pytest +from elasticsearch.helpers import BulkIndexError +from langchain.embeddings.cache import _value_serializer +from langchain.globals import set_llm_cache +from langchain_core.language_models import BaseChatModel + +from langchain_elasticsearch import ( + AsyncElasticsearchCache, + AsyncElasticsearchEmbeddingsCache, +) + +from ._test_utilities import clear_test_indices, create_es_client, read_env + + +@pytest.fixture +async def es_env_fx() -> Union[dict, AsyncGenerator]: + params = read_env() + es = create_es_client(params) + await es.options(ignore_status=404).indices.delete(index="test_index1") + await es.options(ignore_status=404).indices.delete(index="test_index2") + await es.indices.create(index="test_index1") + await es.indices.create(index="test_index2") + await es.indices.put_alias(index="test_index1", name="test_alias") + await es.indices.put_alias( + index="test_index2", name="test_alias", is_write_index=True + ) + yield params + await es.options(ignore_status=404).indices.delete_alias( + index="test_index1,test_index2", name="test_alias" + ) + await clear_test_indices(es) + await es.close() + + +@pytest.mark.asyncio +async def test_index_llm_cache(es_env_fx: Dict, fake_chat_fx: BaseChatModel) -> None: + cache = AsyncElasticsearchCache( + **es_env_fx, index_name="test_index1", metadata={"project": "test"} + ) + es_client = cache._es_client + set_llm_cache(cache) + await fake_chat_fx.ainvoke("test") + assert (await es_client.count(index="test_index1"))["count"] == 1 + await fake_chat_fx.ainvoke("test") + assert (await es_client.count(index="test_index1"))["count"] == 1 + record = (await es_client.search(index="test_index1"))["hits"]["hits"][0]["_source"] + assert "test output" in record.get("llm_output", [""])[0] + assert record.get("llm_input") + assert record.get("timestamp") + assert record.get("llm_params") + assert record.get("metadata") == {"project": "test"} + cache2 = AsyncElasticsearchCache( + **es_env_fx, + index_name="test_index1", + metadata={"project": "test"}, + store_input=False, + store_input_params=False, + ) + set_llm_cache(cache2) + await fake_chat_fx.ainvoke("test") + assert (await es_client.count(index="test_index1"))["count"] == 1 + await fake_chat_fx.ainvoke("test2") + assert (await es_client.count(index="test_index1"))["count"] == 2 + await fake_chat_fx.ainvoke("test2") + records = [ + record["_source"] + for record in (await es_client.search(index="test_index1"))["hits"]["hits"] + ] + assert all("test output" in record.get("llm_output", [""])[0] for record in records) + assert not all(record.get("llm_input", "") for record in records) + assert all(record.get("timestamp", "") for record in records) + assert not all(record.get("llm_params", "") for record in records) + assert all(record.get("metadata") == {"project": "test"} for record in records) + + +@pytest.mark.asyncio +async def test_alias_llm_cache(es_env_fx: Dict, fake_chat_fx: BaseChatModel) -> None: + cache = AsyncElasticsearchCache( + **es_env_fx, index_name="test_alias", metadata={"project": "test"} + ) + es_client = cache._es_client + set_llm_cache(cache) + await fake_chat_fx.ainvoke("test") + assert (await es_client.count(index="test_index2"))["count"] == 1 + await fake_chat_fx.ainvoke("test2") + assert (await es_client.count(index="test_index2"))["count"] == 2 + await es_client.indices.put_alias( + index="test_index2", name="test_alias", is_write_index=False + ) + await es_client.indices.put_alias( + index="test_index1", name="test_alias", is_write_index=True + ) + await fake_chat_fx.ainvoke("test3") + assert (await es_client.count(index="test_index1"))["count"] == 1 + await fake_chat_fx.ainvoke("test2") + assert (await es_client.count(index="test_index1"))["count"] == 1 + await es_client.indices.delete_alias(index="test_index2", name="test_alias") + # we cache the response for prompt "test2" on both test_index1 and test_index2 + await fake_chat_fx.ainvoke("test2") + assert (await es_client.count(index="test_index1"))["count"] == 2 + await es_client.indices.put_alias(index="test_index2", name="test_alias") + # we just test the latter scenario is working + assert await fake_chat_fx.ainvoke("test2") + + +@pytest.mark.asyncio +async def test_clear_llm_cache(es_env_fx: Dict, fake_chat_fx: BaseChatModel) -> None: + cache = AsyncElasticsearchCache( + **es_env_fx, index_name="test_alias", metadata={"project": "test"} + ) + es_client = cache._es_client + set_llm_cache(cache) + await fake_chat_fx.ainvoke("test") + await fake_chat_fx.ainvoke("test2") + await es_client.indices.put_alias( + index="test_index2", name="test_alias", is_write_index=False + ) + await es_client.indices.put_alias( + index="test_index1", name="test_alias", is_write_index=True + ) + await fake_chat_fx.ainvoke("test3") + assert (await es_client.count(index="test_alias"))["count"] == 3 + await cache.aclear() + assert (await es_client.count(index="test_alias"))["count"] == 0 + + +@pytest.mark.asyncio +async def test_mdelete_cache_store(es_env_fx: Dict) -> None: + store = AsyncElasticsearchEmbeddingsCache( + **es_env_fx, index_name="test_alias", metadata={"project": "test"} + ) + + recors = ["my little tests", "my little tests2", "my little tests3"] + await store.amset( + [ + (recors[0], _value_serializer([1, 2, 3])), + (recors[1], _value_serializer([1, 2, 3])), + (recors[2], _value_serializer([1, 2, 3])), + ] + ) + + assert (await store._es_client.count(index="test_alias"))["count"] == 3 + + await store.amdelete(recors[:2]) + assert (await store._es_client.count(index="test_alias"))["count"] == 1 + + await store.amdelete(recors[2:]) + assert (await store._es_client.count(index="test_alias"))["count"] == 0 + + with pytest.raises(BulkIndexError): + await store.amdelete(recors) + + +@pytest.mark.asyncio +async def test_mset_cache_store(es_env_fx: Dict) -> None: + store = AsyncElasticsearchEmbeddingsCache( + **es_env_fx, index_name="test_alias", metadata={"project": "test"} + ) + + records = ["my little tests", "my little tests2", "my little tests3"] + + await store.amset([(records[0], _value_serializer([1, 2, 3]))]) + assert (await store._es_client.count(index="test_alias"))["count"] == 1 + await store.amset([(records[0], _value_serializer([1, 2, 3]))]) + assert (await store._es_client.count(index="test_alias"))["count"] == 1 + await store.amset( + [ + (records[1], _value_serializer([1, 2, 3])), + (records[2], _value_serializer([1, 2, 3])), + ] + ) + assert (await store._es_client.count(index="test_alias"))["count"] == 3 + + +@pytest.mark.asyncio +async def test_mget_cache_store(es_env_fx: Dict) -> None: + store_no_alias = AsyncElasticsearchEmbeddingsCache( + **es_env_fx, + index_name="test_index3", + metadata={"project": "test"}, + namespace="test", + ) + + records = ["my little tests", "my little tests2", "my little tests3"] + docs = [(r, _value_serializer([0.1, 2, i])) for i, r in enumerate(records)] + + await store_no_alias.amset(docs) + assert (await store_no_alias._es_client.count(index="test_index3"))["count"] == 3 + + cached_records = await store_no_alias.amget([d[0] for d in docs]) + assert all(cached_records) + assert all([r == d[1] for r, d in zip(cached_records, docs)]) + + store_alias = AsyncElasticsearchEmbeddingsCache( + **es_env_fx, + index_name="test_alias", + metadata={"project": "test"}, + namespace="test", + maximum_duplicates_allowed=1, + ) + + await store_alias.amset(docs) + assert (await store_alias._es_client.count(index="test_alias"))["count"] == 3 + + cached_records = await store_alias.amget([d[0] for d in docs]) + assert all(cached_records) + assert all([r == d[1] for r, d in zip(cached_records, docs)]) + + +@pytest.mark.asyncio +async def test_mget_cache_store_multiple_keys(es_env_fx: Dict) -> None: + """verify the logic of deduplication of keys in the cache store""" + + store_alias = AsyncElasticsearchEmbeddingsCache( + **es_env_fx, + index_name="test_alias", + metadata={"project": "test"}, + namespace="test", + maximum_duplicates_allowed=2, + ) + + es_client = store_alias._es_client + + records = ["my little tests", "my little tests2", "my little tests3"] + docs = [(r, _value_serializer([0.1, 2, i])) for i, r in enumerate(records)] + + await store_alias.amset(docs) + assert (await es_client.count(index="test_alias"))["count"] == 3 + + store_no_alias = AsyncElasticsearchEmbeddingsCache( + **es_env_fx, + index_name="test_index3", + metadata={"project": "test"}, + namespace="test", + maximum_duplicates_allowed=1, + ) + + new_records = records + ["my little tests4", "my little tests5"] + new_docs = [ + (r, _value_serializer([0.1, 2, i + 100])) for i, r in enumerate(new_records) + ] + + # store the same 3 previous records and 2 more in a fresh index + await store_no_alias.amset(new_docs) + assert (await es_client.count(index="test_index3"))["count"] == 5 + + # update the alias to point to the new index and verify the cache + await es_client.indices.update_aliases( + actions=[ + { + "add": { + "index": "test_index3", + "alias": "test_alias", + } + } + ] + ) + + # the alias now point to two indices that contains multiple records + # of the same keys, the cache store should return the latest records. + cached_records = await store_alias.amget([d[0] for d in new_docs]) + assert all(cached_records) + assert len(cached_records) == 5 + assert (await es_client.count(index="test_alias"))["count"] == 8 + assert cached_records[:3] != [ + d[1] for d in docs + ], "the first 3 records should be updated" + assert cached_records == [ + d[1] for d in new_docs + ], "new records should be returned and the updated ones" + assert all([r == d[1] for r, d in zip(cached_records, new_docs)]) + await es_client.options(ignore_status=404).indices.delete_alias( + index="test_index3", name="test_alias" + ) + + +@pytest.mark.asyncio +async def test_build_document_cache_store(es_env_fx: Dict) -> None: + store = AsyncElasticsearchEmbeddingsCache( + **es_env_fx, + index_name="test_alias", + metadata={"project": "test"}, + namespace="test", + ) + + await store.amset([("my little tests", _value_serializer([0.1, 2, 3]))]) + record = (await store._es_client.search(index="test_alias"))["hits"]["hits"][0][ + "_source" + ] + + assert record.get("metadata") == {"project": "test"} + assert record.get("namespace") == "test" + assert record.get("timestamp") + assert record.get("text_input") == "my little tests" + assert record.get("vector_dump") == AsyncElasticsearchEmbeddingsCache.encode_vector( + _value_serializer([0.1, 2, 3]) + ) diff --git a/libs/elasticsearch/tests/integration_tests/_async/test_chat_history.py b/libs/elasticsearch/tests/integration_tests/_async/test_chat_history.py new file mode 100644 index 0000000..100bb1a --- /dev/null +++ b/libs/elasticsearch/tests/integration_tests/_async/test_chat_history.py @@ -0,0 +1,72 @@ +import json +import uuid +from typing import AsyncIterator + +import pytest +from langchain.memory import ConversationBufferMemory +from langchain_core.messages import AIMessage, HumanMessage, message_to_dict + +from langchain_elasticsearch.chat_history import AsyncElasticsearchChatMessageHistory + +from ._test_utilities import clear_test_indices, create_es_client, read_env + +""" +cd tests/integration_tests +docker-compose up elasticsearch + +By default runs against local docker instance of Elasticsearch. +To run against Elastic Cloud, set the following environment variables: +- ES_CLOUD_ID +- ES_USERNAME +- ES_PASSWORD +""" + + +class TestElasticsearch: + @pytest.fixture + async def elasticsearch_connection(self) -> AsyncIterator[dict]: + params = read_env() + es = create_es_client(params) + + yield params + + await clear_test_indices(es) + await es.close() + + @pytest.fixture(scope="function") + def index_name(self) -> str: + """Return the index name.""" + return f"test_{uuid.uuid4().hex}" + + async def test_memory_with_message_store( + self, elasticsearch_connection: dict, index_name: str + ) -> None: + """Test the memory with a message store.""" + # setup Elasticsearch as a message store + message_history = AsyncElasticsearchChatMessageHistory( + **elasticsearch_connection, index=index_name, session_id="test-session" + ) + + memory = ConversationBufferMemory( + memory_key="baz", chat_memory=message_history, return_messages=True + ) + + # add some messages + await memory.chat_memory.aadd_messages( + [ + AIMessage("This is me, the AI"), + HumanMessage("This is me, the human"), + ] + ) + + # get the message history from the memory store and turn it into a json + messages = await memory.chat_memory.aget_messages() + messages_json = json.dumps([message_to_dict(msg) for msg in messages]) + + assert "This is me, the AI" in messages_json + assert "This is me, the human" in messages_json + + # remove the record from Elasticsearch, so the next test run won't pick it up + await memory.chat_memory.aclear() + + assert await memory.chat_memory.aget_messages() == [] diff --git a/libs/elasticsearch/tests/integration_tests/_async/test_embeddings.py b/libs/elasticsearch/tests/integration_tests/_async/test_embeddings.py new file mode 100644 index 0000000..799ff00 --- /dev/null +++ b/libs/elasticsearch/tests/integration_tests/_async/test_embeddings.py @@ -0,0 +1,54 @@ +"""Test elasticsearch_embeddings embeddings.""" + +import os + +import pytest +from elasticsearch import AsyncElasticsearch + +from langchain_elasticsearch.embeddings import AsyncElasticsearchEmbeddings + +from ._test_utilities import model_is_deployed + +# deployed with +# https://www.elastic.co/guide/en/machine-learning/current/ml-nlp-text-emb-vector-search-example.html +MODEL_ID = os.getenv("MODEL_ID", "sentence-transformers__msmarco-minilm-l-12-v3") +NUM_DIMENSIONS = int(os.getenv("NUM_DIMENTIONS", "384")) + +ES_URL = os.environ.get("ES_URL", "http://localhost:9200") + + +@pytest.mark.asyncio +async def test_elasticsearch_embedding_documents() -> None: + """Test Elasticsearch embedding documents.""" + client = AsyncElasticsearch(hosts=[ES_URL]) + if not (await model_is_deployed(client, MODEL_ID)): + await client.close() + pytest.skip( + reason=f"{MODEL_ID} model is not deployed in ML Node, skipping test" + ) + + documents = ["foo bar", "bar foo", "foo"] + embedding = AsyncElasticsearchEmbeddings.from_es_connection(MODEL_ID, client) + output = await embedding.aembed_documents(documents) + await client.close() + assert len(output) == 3 + assert len(output[0]) == NUM_DIMENSIONS + assert len(output[1]) == NUM_DIMENSIONS + assert len(output[2]) == NUM_DIMENSIONS + + +@pytest.mark.asyncio +async def test_elasticsearch_embedding_query() -> None: + """Test Elasticsearch embedding query.""" + client = AsyncElasticsearch(hosts=[ES_URL]) + if not (await model_is_deployed(client, MODEL_ID)): + await client.close() + pytest.skip( + reason=f"{MODEL_ID} model is not deployed in ML Node, skipping test" + ) + + document = "foo bar" + embedding = AsyncElasticsearchEmbeddings.from_es_connection(MODEL_ID, client) + output = await embedding.aembed_query(document) + await client.close() + assert len(output) == NUM_DIMENSIONS diff --git a/libs/elasticsearch/tests/integration_tests/_async/test_retrievers.py b/libs/elasticsearch/tests/integration_tests/_async/test_retrievers.py new file mode 100644 index 0000000..bd88e24 --- /dev/null +++ b/libs/elasticsearch/tests/integration_tests/_async/test_retrievers.py @@ -0,0 +1,239 @@ +"""Test ElasticsearchRetriever functionality.""" + +import os +import re +import uuid +from typing import Any, Dict + +import pytest +from elasticsearch import AsyncElasticsearch +from langchain_core.documents import Document + +from langchain_elasticsearch.retrievers import AsyncElasticsearchRetriever + +from ._test_utilities import requests_saving_es_client + +""" +cd tests/integration_tests +docker-compose up elasticsearch + +By default runs against local docker instance of Elasticsearch. +To run against Elastic Cloud, set the following environment variables: +- ES_CLOUD_ID +- ES_API_KEY +""" + + +async def index_test_data( + es_client: AsyncElasticsearch, index_name: str, field_name: str +) -> None: + docs = [(1, "foo bar"), (2, "bar"), (3, "foo"), (4, "baz"), (5, "foo baz")] + for identifier, text in docs: + await es_client.index( + index=index_name, + document={field_name: text, "another_field": 1}, + id=str(identifier), + refresh=True, + ) + + +class TestElasticsearchRetriever: + @pytest.fixture(scope="function") + async def es_client(self) -> Any: + client = requests_saving_es_client() + yield client + await client.close() + + @pytest.fixture(scope="function") + def index_name(self) -> str: + """Return the index name.""" + return f"test_{uuid.uuid4().hex}" + + @pytest.mark.asyncio + async def test_user_agent_header( + self, es_client: AsyncElasticsearch, index_name: str + ) -> None: + """Test that the user agent header is set correctly.""" + + retriever = AsyncElasticsearchRetriever( + index_name=index_name, + body_func=lambda _: {"query": {"match_all": {}}}, + content_field="text", + es_client=es_client, + ) + + assert retriever.es_client + user_agent = retriever.es_client._headers["User-Agent"] + assert ( + re.match(r"^langchain-py-r/\d+\.\d+\.\d+(?:rc\d+)?$", user_agent) + is not None + ), f"The string '{user_agent}' does not match the expected pattern." + + await index_test_data(es_client, index_name, "text") + await retriever.aget_relevant_documents("foo") + + search_request = es_client.transport.requests[-1] # type: ignore[attr-defined] + user_agent = search_request["headers"]["User-Agent"] + assert ( + re.match(r"^langchain-py-r/\d+\.\d+\.\d+(?:rc\d+)?$", user_agent) + is not None + ), f"The string '{user_agent}' does not match the expected pattern." + + @pytest.mark.asyncio + async def test_init_url(self, index_name: str) -> None: + """Test end-to-end indexing and search.""" + + text_field = "text" + + def body_func(query: str) -> Dict: + return {"query": {"match": {text_field: {"query": query}}}} + + es_url = os.environ.get("ES_URL", "http://localhost:9200") + cloud_id = os.environ.get("ES_CLOUD_ID") + api_key = os.environ.get("ES_API_KEY") + + config = ( + {"cloud_id": cloud_id, "api_key": api_key} if cloud_id else {"url": es_url} + ) + + retriever = AsyncElasticsearchRetriever.from_es_params( + index_name=index_name, + body_func=body_func, + content_field=text_field, + **config, # type: ignore[arg-type] + ) + + await index_test_data(retriever.es_client, index_name, text_field) + result = await retriever.aget_relevant_documents("foo") + + assert {r.page_content for r in result} == {"foo", "foo bar", "foo baz"} + assert {r.metadata["_id"] for r in result} == {"3", "1", "5"} + for r in result: + assert set(r.metadata.keys()) == {"_index", "_id", "_score", "_source"} + assert text_field not in r.metadata["_source"] + assert "another_field" in r.metadata["_source"] + + @pytest.mark.asyncio + async def test_init_client( + self, es_client: AsyncElasticsearch, index_name: str + ) -> None: + """Test end-to-end indexing and search.""" + + text_field = "text" + + def body_func(query: str) -> Dict: + return {"query": {"match": {text_field: {"query": query}}}} + + retriever = AsyncElasticsearchRetriever( + index_name=index_name, + body_func=body_func, + content_field=text_field, + es_client=es_client, + ) + + await index_test_data(es_client, index_name, text_field) + result = await retriever.aget_relevant_documents("foo") + + assert {r.page_content for r in result} == {"foo", "foo bar", "foo baz"} + assert {r.metadata["_id"] for r in result} == {"3", "1", "5"} + for r in result: + assert set(r.metadata.keys()) == {"_index", "_id", "_score", "_source"} + assert text_field not in r.metadata["_source"] + assert "another_field" in r.metadata["_source"] + + @pytest.mark.asyncio + async def test_multiple_index_and_content_fields( + self, es_client: AsyncElasticsearch, index_name: str + ) -> None: + """Test multiple content fields""" + index_name_1 = f"{index_name}_1" + index_name_2 = f"{index_name}_2" + text_field_1 = "text_1" + text_field_2 = "text_2" + + def body_func(query: str) -> Dict: + return { + "query": { + "multi_match": { + "query": query, + "fields": [text_field_1, text_field_2], + } + } + } + + retriever = AsyncElasticsearchRetriever( + index_name=[index_name_1, index_name_2], + content_field={index_name_1: text_field_1, index_name_2: text_field_2}, + body_func=body_func, + es_client=es_client, + ) + + await index_test_data(es_client, index_name_1, text_field_1) + await index_test_data(es_client, index_name_2, text_field_2) + result = await retriever.aget_relevant_documents("foo") + + # matches from both indices + assert sorted([(r.page_content, r.metadata["_index"]) for r in result]) == [ + ("foo", index_name_1), + ("foo", index_name_2), + ("foo bar", index_name_1), + ("foo bar", index_name_2), + ("foo baz", index_name_1), + ("foo baz", index_name_2), + ] + + @pytest.mark.asyncio + async def test_custom_mapper( + self, es_client: AsyncElasticsearch, index_name: str + ) -> None: + """Test custom document maper""" + + text_field = "text" + meta = {"some_field": 12} + + def body_func(query: str) -> Dict: + return {"query": {"match": {text_field: {"query": query}}}} + + def id_as_content(hit: Dict) -> Document: + return Document(page_content=hit["_id"], metadata=meta) + + retriever = AsyncElasticsearchRetriever( + index_name=index_name, + body_func=body_func, + document_mapper=id_as_content, + es_client=es_client, + ) + + await index_test_data(es_client, index_name, text_field) + result = await retriever.aget_relevant_documents("foo") + + assert [r.page_content for r in result] == ["3", "1", "5"] + assert [r.metadata for r in result] == [meta, meta, meta] + + @pytest.mark.asyncio + async def test_fail_content_field_and_mapper( + self, es_client: AsyncElasticsearch + ) -> None: + """Raise exception if both content_field and document_mapper are specified.""" + + with pytest.raises(ValueError): + AsyncElasticsearchRetriever( + content_field="text", + document_mapper=lambda x: x, + index_name="foo", + body_func=lambda x: x, + es_client=es_client, + ) + + @pytest.mark.asyncio + async def test_fail_neither_content_field_nor_mapper( + self, es_client: AsyncElasticsearch + ) -> None: + """Raise exception if neither content_field nor document_mapper are specified""" + + with pytest.raises(ValueError): + AsyncElasticsearchRetriever( + index_name="foo", + body_func=lambda x: x, + es_client=es_client, + ) diff --git a/libs/elasticsearch/tests/integration_tests/_async/test_vectorstores.py b/libs/elasticsearch/tests/integration_tests/_async/test_vectorstores.py new file mode 100644 index 0000000..d3c2239 --- /dev/null +++ b/libs/elasticsearch/tests/integration_tests/_async/test_vectorstores.py @@ -0,0 +1,936 @@ +"""Test AsyncElasticsearchStore functionality.""" + +import logging +import uuid +from typing import Any, AsyncIterator, Dict, List, Optional, Union + +import pytest +from elasticsearch import NotFoundError +from langchain_core.documents import Document + +from langchain_elasticsearch.vectorstores import AsyncElasticsearchStore + +from ...fake_embeddings import AsyncConsistentFakeEmbeddings, AsyncFakeEmbeddings +from ._test_utilities import clear_test_indices, create_es_client, read_env + +logging.basicConfig(level=logging.DEBUG) + +""" +cd tests/integration_tests +docker-compose up elasticsearch + +By default runs against local docker instance of Elasticsearch. +To run against Elastic Cloud, set the following environment variables: +- ES_CLOUD_ID +- ES_API_KEY +""" + + +class TestElasticsearch: + @pytest.fixture + async def es_params(self) -> AsyncIterator[dict]: + params = read_env() + es = create_es_client(params) + + yield params + + await clear_test_indices(es) + await es.close() + + @pytest.fixture(scope="function") + def index_name(self) -> str: + """Return the index name.""" + return f"test_{uuid.uuid4().hex}" + + @pytest.mark.asyncio + async def test_from_texts_similarity_search_with_doc_builder( + self, es_params: dict, index_name: str + ) -> None: + texts = ["foo", "foo", "foo"] + metadatas = [{"page": i} for i in range(len(texts))] + docsearch = await AsyncElasticsearchStore.afrom_texts( + texts, + AsyncFakeEmbeddings(), + metadatas=metadatas, + **es_params, + index_name=index_name, + ) + + def custom_document_builder(_: Dict) -> Document: + return Document( + page_content="Mock content!", + metadata={ + "page_number": -1, + "original_filename": "Mock filename!", + }, + ) + + output = await docsearch.asimilarity_search( + query="foo", k=1, doc_builder=custom_document_builder + ) + assert output[0].page_content == "Mock content!" + assert output[0].metadata["page_number"] == -1 + assert output[0].metadata["original_filename"] == "Mock filename!" + + await docsearch.aclose() + + @pytest.mark.asyncio + async def test_search_with_relevance_threshold( + self, es_params: dict, index_name: str + ) -> None: + """Test to make sure the relevance threshold is respected.""" + texts = ["foo", "bar", "baz"] + metadatas = [{"page": str(i)} for i in range(len(texts))] + embeddings = AsyncConsistentFakeEmbeddings() + + docsearch = await AsyncElasticsearchStore.afrom_texts( + index_name=index_name, + texts=texts, + embedding=embeddings, + metadatas=metadatas, + **es_params, + ) + + # Find a good threshold for testing + query_string = "foo" + top3 = await docsearch.asimilarity_search_with_relevance_scores( + query=query_string, k=3 + ) + similarity_of_second_ranked = top3[1][1] + assert len(top3) == 3 + + # Test threshold + retriever = docsearch.as_retriever( + search_type="similarity_score_threshold", + search_kwargs={"score_threshold": similarity_of_second_ranked}, + ) + output = await retriever.aget_relevant_documents(query=query_string) + + assert output == [ + top3[0][0], + top3[1][0], + # third ranked is out + ] + + await docsearch.aclose() + + @pytest.mark.asyncio + async def test_search_by_vector_with_relevance_threshold( + self, es_params: dict, index_name: str + ) -> None: + """Test to make sure the relevance threshold is respected.""" + texts = ["foo", "bar", "baz"] + metadatas = [{"page": str(i)} for i in range(len(texts))] + embeddings = AsyncConsistentFakeEmbeddings() + + docsearch = await AsyncElasticsearchStore.afrom_texts( + index_name=index_name, + texts=texts, + embedding=embeddings, + metadatas=metadatas, + **es_params, + ) + + # Find a good threshold for testing + query_string = "foo" + embedded_query = await embeddings.aembed_query(query_string) + top3 = await docsearch.asimilarity_search_by_vector_with_relevance_scores( + embedding=embedded_query, k=3 + ) + similarity_of_second_ranked = top3[1][1] + assert len(top3) == 3 + + # Test threshold + retriever = docsearch.as_retriever( + search_type="similarity_score_threshold", + search_kwargs={"score_threshold": similarity_of_second_ranked}, + ) + output = await retriever.aget_relevant_documents(query=query_string) + + assert output == [ + top3[0][0], + top3[1][0], + # third ranked is out + ] + + await docsearch.aclose() + + # Also tested in elasticsearch.helpers.vectorstore + + @pytest.mark.asyncio + async def test_similarity_search_without_metadata( + self, es_params: dict, index_name: str + ) -> None: + """Test end to end construction and search without metadata.""" + + def assert_query( + query_body: Dict[str, Any], query: Optional[str] + ) -> Dict[str, Any]: + assert query_body == { + "knn": { + "field": "vector", + "filter": [], + "k": 1, + "num_candidates": 50, + "query_vector": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + } + } + return query_body + + texts = ["foo", "bar", "baz"] + docsearch = await AsyncElasticsearchStore.afrom_texts( + texts, + AsyncFakeEmbeddings(), + **es_params, + index_name=index_name, + ) + output = await docsearch.asimilarity_search( + "foo", k=1, custom_query=assert_query + ) + assert output == [Document(page_content="foo")] + + @pytest.mark.asyncio + async def test_add_embeddings(self, es_params: dict, index_name: str) -> None: + """ + Test add_embeddings, which accepts pre-built embeddings instead of + using inference for the texts. + This allows you to separate the embeddings text and the page_content + for better proximity between user's question and embedded text. + For example, your embedding text can be a question, whereas page_content + is the answer. + """ + embeddings = AsyncConsistentFakeEmbeddings() + text_input = ["foo1", "foo2", "foo3"] + metadatas = [{"page": i} for i in range(len(text_input))] + + """In real use case, embedding_input can be questions for each text""" + embedding_input = ["foo2", "foo3", "foo1"] + embedding_vectors = await embeddings.aembed_documents(embedding_input) + + docsearch = AsyncElasticsearchStore( + embedding=embeddings, + **es_params, + index_name=index_name, + ) + await docsearch.aadd_embeddings( + list(zip(text_input, embedding_vectors)), metadatas + ) + output = await docsearch.asimilarity_search("foo1", k=1) + assert output == [Document(page_content="foo3", metadata={"page": 2})] + + @pytest.mark.asyncio + async def test_similarity_search_with_metadata( + self, es_params: dict, index_name: str + ) -> None: + """Test end to end construction and search with metadata.""" + texts = ["foo", "bar", "baz"] + metadatas = [{"page": i} for i in range(len(texts))] + docsearch = await AsyncElasticsearchStore.afrom_texts( + texts, + AsyncConsistentFakeEmbeddings(), + metadatas=metadatas, + **es_params, + index_name=index_name, + ) + + output = await docsearch.asimilarity_search("foo", k=1) + assert output == [Document(page_content="foo", metadata={"page": 0})] + + output = await docsearch.asimilarity_search("bar", k=1) + assert output == [Document(page_content="bar", metadata={"page": 1})] + + @pytest.mark.asyncio + async def test_similarity_search_with_filter( + self, es_params: dict, index_name: str + ) -> None: + """Test end to end construction and search with metadata.""" + texts = ["foo", "foo", "foo"] + metadatas = [{"page": i} for i in range(len(texts))] + docsearch = await AsyncElasticsearchStore.afrom_texts( + texts, + AsyncFakeEmbeddings(), + metadatas=metadatas, + **es_params, + index_name=index_name, + ) + + def assert_query( + query_body: Dict[str, Any], query: Optional[str] + ) -> Dict[str, Any]: + assert query_body == { + "knn": { + "field": "vector", + "filter": [{"term": {"metadata.page": "1"}}], + "k": 3, + "num_candidates": 50, + "query_vector": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + } + } + return query_body + + output = await docsearch.asimilarity_search( + query="foo", + k=3, + filter=[{"term": {"metadata.page": "1"}}], + custom_query=assert_query, + ) + assert output == [Document(page_content="foo", metadata={"page": 1})] + + @pytest.mark.asyncio + async def test_similarity_search_with_doc_builder( + self, es_params: dict, index_name: str + ) -> None: + texts = ["foo", "foo", "foo"] + metadatas = [{"page": i} for i in range(len(texts))] + docsearch = await AsyncElasticsearchStore.afrom_texts( + texts, + AsyncFakeEmbeddings(), + metadatas=metadatas, + **es_params, + index_name=index_name, + ) + + def custom_document_builder(_: Dict) -> Document: + return Document( + page_content="Mock content!", + metadata={ + "page_number": -1, + "original_filename": "Mock filename!", + }, + ) + + output = await docsearch.asimilarity_search( + query="foo", k=1, doc_builder=custom_document_builder + ) + assert output[0].page_content == "Mock content!" + assert output[0].metadata["page_number"] == -1 + assert output[0].metadata["original_filename"] == "Mock filename!" + + @pytest.mark.asyncio + async def test_similarity_search_exact_search( + self, es_params: dict, index_name: str + ) -> None: + """Test end to end construction and search with metadata.""" + texts = ["foo", "bar", "baz"] + docsearch = await AsyncElasticsearchStore.afrom_texts( + texts, + AsyncFakeEmbeddings(), + **es_params, + index_name=index_name, + strategy=AsyncElasticsearchStore.ExactRetrievalStrategy(), + ) + + expected_query = { + "query": { + "script_score": { + "query": {"match_all": {}}, + "script": { + "source": "cosineSimilarity(params.query_vector, 'vector') + 1.0", # noqa: E501 + "params": { + "query_vector": [ + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 0.0, + ] + }, + }, + } + } + } + + def assert_query( + query_body: Dict[str, Any], query: Optional[str] + ) -> Dict[str, Any]: + assert query_body == expected_query + return query_body + + output = await docsearch.asimilarity_search( + "foo", k=1, custom_query=assert_query + ) + assert output == [Document(page_content="foo")] + + @pytest.mark.asyncio + async def test_similarity_search_exact_search_with_filter( + self, es_params: dict, index_name: str + ) -> None: + """Test end to end construction and search with metadata.""" + texts = ["foo", "bar", "baz"] + metadatas = [{"page": i} for i in range(len(texts))] + docsearch = await AsyncElasticsearchStore.afrom_texts( + texts, + AsyncFakeEmbeddings(), + **es_params, + index_name=index_name, + metadatas=metadatas, + strategy=AsyncElasticsearchStore.ExactRetrievalStrategy(), + ) + + def assert_query( + query_body: Dict[str, Any], query: Optional[str] + ) -> Dict[str, Any]: + expected_query = { + "query": { + "script_score": { + "query": {"bool": {"filter": [{"term": {"metadata.page": 0}}]}}, + "script": { + "source": "cosineSimilarity(params.query_vector, 'vector') + 1.0", # noqa: E501 + "params": { + "query_vector": [ + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 0.0, + ] + }, + }, + } + } + } + assert query_body == expected_query + return query_body + + output = await docsearch.asimilarity_search( + "foo", + k=1, + custom_query=assert_query, + filter=[{"term": {"metadata.page": 0}}], + ) + assert output == [Document(page_content="foo", metadata={"page": 0})] + + @pytest.mark.asyncio + async def test_similarity_search_exact_search_distance_dot_product( + self, es_params: dict, index_name: str + ) -> None: + """Test end to end construction and search with metadata.""" + texts = ["foo", "bar", "baz"] + docsearch = await AsyncElasticsearchStore.afrom_texts( + texts, + AsyncFakeEmbeddings(), + **es_params, + index_name=index_name, + strategy=AsyncElasticsearchStore.ExactRetrievalStrategy(), + distance_strategy="DOT_PRODUCT", + ) + + def assert_query( + query_body: Dict[str, Any], query: Optional[str] + ) -> Dict[str, Any]: + assert query_body == { + "query": { + "script_score": { + "query": {"match_all": {}}, + "script": { + "source": """ + double value = dotProduct(params.query_vector, 'vector'); + return sigmoid(1, Math.E, -value); + """, + "params": { + "query_vector": [ + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 0.0, + ] + }, + }, + } + } + } + return query_body + + output = await docsearch.asimilarity_search( + "foo", k=1, custom_query=assert_query + ) + assert output == [Document(page_content="foo")] + + @pytest.mark.asyncio + async def test_similarity_search_exact_search_unknown_distance_strategy( + self, es_params: dict, index_name: str + ) -> None: + """Test end to end construction and search with unknown distance strategy.""" + + with pytest.raises(KeyError): + texts = ["foo", "bar", "baz"] + await AsyncElasticsearchStore.afrom_texts( + texts, + AsyncFakeEmbeddings(), + **es_params, + index_name=index_name, + strategy=AsyncElasticsearchStore.ExactRetrievalStrategy(), + distance_strategy="NOT_A_STRATEGY", + ) + + @pytest.mark.asyncio + async def test_max_marginal_relevance_search( + self, es_params: dict, index_name: str + ) -> None: + """Test max marginal relevance search.""" + texts = ["foo", "bar", "baz"] + docsearch = await AsyncElasticsearchStore.afrom_texts( + texts, + AsyncFakeEmbeddings(), + **es_params, + index_name=index_name, + strategy=AsyncElasticsearchStore.ExactRetrievalStrategy(), + ) + + mmr_output = await docsearch.amax_marginal_relevance_search( + texts[0], k=3, fetch_k=3 + ) + sim_output = await docsearch.asimilarity_search(texts[0], k=3) + assert mmr_output == sim_output + + mmr_output = await docsearch.amax_marginal_relevance_search( + texts[0], k=2, fetch_k=3 + ) + assert len(mmr_output) == 2 + assert mmr_output[0].page_content == texts[0] + assert mmr_output[1].page_content == texts[1] + + mmr_output = await docsearch.amax_marginal_relevance_search( + texts[0], + k=2, + fetch_k=3, + lambda_mult=0.1, # more diversity + ) + assert len(mmr_output) == 2 + assert mmr_output[0].page_content == texts[0] + assert mmr_output[1].page_content == texts[2] + + # if fetch_k < k, then the output will be less than k + mmr_output = await docsearch.amax_marginal_relevance_search( + texts[0], k=3, fetch_k=2 + ) + assert len(mmr_output) == 2 + + @pytest.mark.asyncio + async def test_similarity_search_approx_with_hybrid_search( + self, es_params: dict, index_name: str + ) -> None: + """Test end to end construction and search with metadata.""" + texts = ["foo", "bar", "baz"] + docsearch = await AsyncElasticsearchStore.afrom_texts( + texts, + AsyncFakeEmbeddings(), + **es_params, + index_name=index_name, + strategy=AsyncElasticsearchStore.ApproxRetrievalStrategy(hybrid=True), + ) + + def assert_query( + query_body: Dict[str, Any], query: Optional[str] + ) -> Dict[str, Any]: + assert query_body == { + "knn": { + "field": "vector", + "filter": [], + "k": 1, + "num_candidates": 50, + "query_vector": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + }, + "query": { + "bool": { + "filter": [], + "must": [{"match": {"text": {"query": "foo"}}}], + } + }, + "rank": {"rrf": {}}, + } + return query_body + + output = await docsearch.asimilarity_search( + "foo", k=1, custom_query=assert_query + ) + assert output == [Document(page_content="foo")] + + @pytest.mark.asyncio + async def test_similarity_search_approx_by_vector( + self, es_params: dict, index_name: str + ) -> None: + """Test end to end construction and search with metadata.""" + texts = ["foo", "bar", "baz"] + embeddings = AsyncConsistentFakeEmbeddings() + docsearch = await AsyncElasticsearchStore.afrom_texts( + texts, + embedding=embeddings, + **es_params, + index_name=index_name, + ) + query_vector = await embeddings.aembed_query("foo") + + def assert_query( + query_body: Dict[str, Any], query: Optional[str] + ) -> Dict[str, Any]: + assert query_body == { + "knn": { + "field": "vector", + "filter": [], + "k": 1, + "num_candidates": 50, + "query_vector": query_vector, + }, + } + return query_body + + # accept ndarray as query vector + output = await docsearch.asimilarity_search_by_vector_with_relevance_scores( + query_vector, + k=1, + custom_query=assert_query, + ) + assert output == [(Document(page_content="foo"), 1.0)] + + @pytest.mark.asyncio + async def test_similarity_search_approx_with_hybrid_search_rrf( + self, es_params: dict, index_name: str + ) -> None: + """Test end to end construction and rrf hybrid search with metadata.""" + from functools import partial + + # 1. check query_body is okay + rrf_test_cases: List[Optional[Union[dict, bool]]] = [ + True, + False, + {"rank_constant": 1, "window_size": 5}, + ] + for rrf_test_case in rrf_test_cases: + texts = ["foo", "bar", "baz"] + docsearch = await AsyncElasticsearchStore.afrom_texts( + texts, + AsyncFakeEmbeddings(), + **es_params, + index_name=index_name, + strategy=AsyncElasticsearchStore.ApproxRetrievalStrategy( + hybrid=True, rrf=rrf_test_case + ), + ) + + def assert_query( + query_body: Dict[str, Any], + query: Optional[str], + rrf: Optional[Union[dict, bool]] = True, + ) -> dict: + cmp_query_body = { + "knn": { + "field": "vector", + "filter": [], + "k": 3, + "num_candidates": 50, + "query_vector": [ + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 0.0, + ], + }, + "query": { + "bool": { + "filter": [], + "must": [{"match": {"text": {"query": "foo"}}}], + } + }, + } + + if isinstance(rrf, dict): + cmp_query_body["rank"] = {"rrf": rrf} + elif isinstance(rrf, bool) and rrf is True: + cmp_query_body["rank"] = {"rrf": {}} + + assert query_body == cmp_query_body + + return query_body + + ## without fetch_k parameter + output = await docsearch.asimilarity_search( + "foo", k=3, custom_query=partial(assert_query, rrf=rrf_test_case) + ) + + # 2. check query result is okay + es_output = await docsearch.client.search( + index=index_name, + query={ + "bool": { + "filter": [], + "must": [{"match": {"text": {"query": "foo"}}}], + } + }, + knn={ + "field": "vector", + "filter": [], + "k": 3, + "num_candidates": 50, + "query_vector": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + }, + size=3, + rank={"rrf": {"rank_constant": 1, "window_size": 5}}, + ) + + assert [o.page_content for o in output] == [ + e["_source"]["text"] for e in es_output["hits"]["hits"] + ] + + # 3. check rrf default option is okay + docsearch = await AsyncElasticsearchStore.afrom_texts( + texts, + AsyncFakeEmbeddings(), + **es_params, + index_name=index_name, + strategy=AsyncElasticsearchStore.ApproxRetrievalStrategy(hybrid=True), + ) + + ## with fetch_k parameter + output = await docsearch.asimilarity_search( + "foo", k=3, fetch_k=50, custom_query=assert_query + ) + + @pytest.mark.asyncio + async def test_similarity_search_approx_with_custom_query_fn( + self, es_params: dict, index_name: str + ) -> None: + """test that custom query function is called + with the query string and query body""" + + def my_custom_query( + query_body: Dict[str, Any], query: Optional[str] + ) -> Dict[str, Any]: + assert query == "foo" + assert query_body == { + "knn": { + "field": "vector", + "filter": [], + "k": 1, + "num_candidates": 50, + "query_vector": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], + } + } + return {"query": {"match": {"text": {"query": "bar"}}}} + + """Test end to end construction and search with metadata.""" + texts = ["foo", "bar", "baz"] + docsearch = await AsyncElasticsearchStore.afrom_texts( + texts, AsyncFakeEmbeddings(), **es_params, index_name=index_name + ) + output = await docsearch.asimilarity_search( + "foo", k=1, custom_query=my_custom_query + ) + assert output == [Document(page_content="bar")] + + @pytest.mark.asyncio + async def test_deployed_model_check_fails_approx( + self, es_params: dict, index_name: str + ) -> None: + """test that exceptions are raised if a specified model is not deployed""" + with pytest.raises(NotFoundError): + await AsyncElasticsearchStore.afrom_texts( + texts=["foo", "bar", "baz"], + embedding=AsyncConsistentFakeEmbeddings(10), + **es_params, + index_name=index_name, + strategy=AsyncElasticsearchStore.ApproxRetrievalStrategy( + query_model_id="non-existing model ID", + ), + ) + + @pytest.mark.asyncio + async def test_deployed_model_check_fails_sparse( + self, es_params: dict, index_name: str + ) -> None: + """test that exceptions are raised if a specified model is not deployed""" + with pytest.raises(NotFoundError): + await AsyncElasticsearchStore.afrom_texts( + texts=["foo", "bar", "baz"], + **es_params, + index_name=index_name, + strategy=AsyncElasticsearchStore.SparseVectorRetrievalStrategy( + model_id="non-existing model ID" + ), + ) + + @pytest.mark.asyncio + async def test_elasticsearch_with_relevance_score( + self, es_params: dict, index_name: str + ) -> None: + """Test to make sure the relevance score is scaled to 0-1.""" + texts = ["foo", "bar", "baz"] + metadatas = [{"page": str(i)} for i in range(len(texts))] + embeddings = AsyncFakeEmbeddings() + + docsearch = await AsyncElasticsearchStore.afrom_texts( + index_name=index_name, + texts=texts, + embedding=embeddings, + metadatas=metadatas, + **es_params, + ) + + embedded_query = await embeddings.aembed_query("foo") + output = await docsearch.asimilarity_search_by_vector_with_relevance_scores( + embedding=embedded_query, k=1 + ) + assert output == [(Document(page_content="foo", metadata={"page": "0"}), 1.0)] + + @pytest.mark.asyncio + async def test_similarity_search_bm25_search( + self, es_params: dict, index_name: str + ) -> None: + """Test end to end using the BM25 retrieval strategy.""" + texts = ["foo", "bar", "baz"] + docsearch = await AsyncElasticsearchStore.afrom_texts( + texts, + None, + **es_params, + index_name=index_name, + strategy=AsyncElasticsearchStore.BM25RetrievalStrategy(), + ) + + def assert_query( + query_body: Dict[str, Any], query: Optional[str] + ) -> Dict[str, Any]: + assert query_body == { + "query": { + "bool": { + "must": [{"match": {"text": {"query": "foo"}}}], + "filter": [], + } + } + } + return query_body + + output = await docsearch.asimilarity_search( + "foo", k=1, custom_query=assert_query + ) + assert output == [Document(page_content="foo")] + + @pytest.mark.asyncio + async def test_similarity_search_bm25_search_with_filter( + self, es_params: dict, index_name: str + ) -> None: + """Test end to using the BM25 retrieval strategy with metadata.""" + texts = ["foo", "foo", "foo"] + metadatas = [{"page": i} for i in range(len(texts))] + docsearch = await AsyncElasticsearchStore.afrom_texts( + texts, + None, + **es_params, + index_name=index_name, + metadatas=metadatas, + strategy=AsyncElasticsearchStore.BM25RetrievalStrategy(), + ) + + def assert_query( + query_body: Dict[str, Any], query: Optional[str] + ) -> Dict[str, Any]: + assert query_body == { + "query": { + "bool": { + "must": [{"match": {"text": {"query": "foo"}}}], + "filter": [{"term": {"metadata.page": 1}}], + } + } + } + return query_body + + output = await docsearch.asimilarity_search( + "foo", + k=3, + custom_query=assert_query, + filter=[{"term": {"metadata.page": 1}}], + ) + assert output == [Document(page_content="foo", metadata={"page": 1})] + + @pytest.mark.asyncio + async def test_elasticsearch_with_relevance_threshold( + self, es_params: dict, index_name: str + ) -> None: + """Test to make sure the relevance threshold is respected.""" + texts = ["foo", "bar", "baz"] + metadatas = [{"page": str(i)} for i in range(len(texts))] + embeddings = AsyncFakeEmbeddings() + + docsearch = await AsyncElasticsearchStore.afrom_texts( + index_name=index_name, + texts=texts, + embedding=embeddings, + metadatas=metadatas, + **es_params, + ) + + # Find a good threshold for testing + query_string = "foo" + embedded_query = await embeddings.aembed_query(query_string) + top3 = await docsearch.asimilarity_search_by_vector_with_relevance_scores( + embedding=embedded_query, k=3 + ) + similarity_of_second_ranked = top3[1][1] + assert len(top3) == 3 + + # Test threshold + retriever = docsearch.as_retriever( + search_type="similarity_score_threshold", + search_kwargs={"score_threshold": similarity_of_second_ranked}, + ) + output = await retriever.aget_relevant_documents(query=query_string) + + assert output == [ + top3[0][0], + top3[1][0], + # third ranked is out + ] + + @pytest.mark.asyncio + async def test_elasticsearch_delete_ids( + self, es_params: dict, index_name: str + ) -> None: + """Test delete methods from vector store.""" + texts = ["foo", "bar", "baz", "gni"] + metadatas = [{"page": i} for i in range(len(texts))] + docsearch = AsyncElasticsearchStore( + embedding=AsyncConsistentFakeEmbeddings(), + **es_params, + index_name=index_name, + ) + + ids = await docsearch.aadd_texts(texts, metadatas) + output = await docsearch.asimilarity_search("foo", k=10) + assert len(output) == 4 + + await docsearch.adelete(ids[1:3]) + output = await docsearch.asimilarity_search("foo", k=10) + assert len(output) == 2 + + await docsearch.adelete(["not-existing"]) + output = await docsearch.asimilarity_search("foo", k=10) + assert len(output) == 2 + + await docsearch.adelete([ids[0]]) + output = await docsearch.asimilarity_search("foo", k=10) + assert len(output) == 1 + + await docsearch.adelete([ids[3]]) + output = await docsearch.asimilarity_search("gni", k=10) + assert len(output) == 0 diff --git a/libs/elasticsearch/tests/integration_tests/_sync/__init__.py b/libs/elasticsearch/tests/integration_tests/_sync/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/libs/elasticsearch/tests/integration_tests/_test_utilities.py b/libs/elasticsearch/tests/integration_tests/_sync/_test_utilities.py similarity index 100% rename from libs/elasticsearch/tests/integration_tests/_test_utilities.py rename to libs/elasticsearch/tests/integration_tests/_sync/_test_utilities.py diff --git a/libs/elasticsearch/tests/integration_tests/test_cache.py b/libs/elasticsearch/tests/integration_tests/_sync/test_cache.py similarity index 81% rename from libs/elasticsearch/tests/integration_tests/test_cache.py rename to libs/elasticsearch/tests/integration_tests/_sync/test_cache.py index 4271210..ebf0591 100644 --- a/libs/elasticsearch/tests/integration_tests/test_cache.py +++ b/libs/elasticsearch/tests/integration_tests/_sync/test_cache.py @@ -7,15 +7,12 @@ from langchain_core.language_models import BaseChatModel from langchain_elasticsearch import ElasticsearchCache, ElasticsearchEmbeddingsCache -from tests.integration_tests._test_utilities import ( - clear_test_indices, - create_es_client, - read_env, -) + +from ._test_utilities import clear_test_indices, create_es_client, read_env @pytest.fixture -def es_env_fx() -> Union[dict, Generator[dict, None, None]]: +def es_env_fx() -> Union[dict, Generator]: params = read_env() es = create_es_client(params) es.options(ignore_status=404).indices.delete(index="test_index1") @@ -29,9 +26,10 @@ def es_env_fx() -> Union[dict, Generator[dict, None, None]]: index="test_index1,test_index2", name="test_alias" ) clear_test_indices(es) - return None + es.close() +@pytest.mark.sync def test_index_llm_cache(es_env_fx: Dict, fake_chat_fx: BaseChatModel) -> None: cache = ElasticsearchCache( **es_env_fx, index_name="test_index1", metadata={"project": "test"} @@ -39,10 +37,10 @@ def test_index_llm_cache(es_env_fx: Dict, fake_chat_fx: BaseChatModel) -> None: es_client = cache._es_client set_llm_cache(cache) fake_chat_fx.invoke("test") - assert es_client.count(index="test_index1")["count"] == 1 + assert (es_client.count(index="test_index1"))["count"] == 1 fake_chat_fx.invoke("test") - assert es_client.count(index="test_index1")["count"] == 1 - record = es_client.search(index="test_index1")["hits"]["hits"][0]["_source"] + assert (es_client.count(index="test_index1"))["count"] == 1 + record = (es_client.search(index="test_index1"))["hits"]["hits"][0]["_source"] assert "test output" in record.get("llm_output", [""])[0] assert record.get("llm_input") assert record.get("timestamp") @@ -57,13 +55,13 @@ def test_index_llm_cache(es_env_fx: Dict, fake_chat_fx: BaseChatModel) -> None: ) set_llm_cache(cache2) fake_chat_fx.invoke("test") - assert es_client.count(index="test_index1")["count"] == 1 + assert (es_client.count(index="test_index1"))["count"] == 1 fake_chat_fx.invoke("test2") - assert es_client.count(index="test_index1")["count"] == 2 + assert (es_client.count(index="test_index1"))["count"] == 2 fake_chat_fx.invoke("test2") records = [ record["_source"] - for record in es_client.search(index="test_index1")["hits"]["hits"] + for record in (es_client.search(index="test_index1"))["hits"]["hits"] ] assert all("test output" in record.get("llm_output", [""])[0] for record in records) assert not all(record.get("llm_input", "") for record in records) @@ -72,6 +70,7 @@ def test_index_llm_cache(es_env_fx: Dict, fake_chat_fx: BaseChatModel) -> None: assert all(record.get("metadata") == {"project": "test"} for record in records) +@pytest.mark.sync def test_alias_llm_cache(es_env_fx: Dict, fake_chat_fx: BaseChatModel) -> None: cache = ElasticsearchCache( **es_env_fx, index_name="test_alias", metadata={"project": "test"} @@ -79,9 +78,9 @@ def test_alias_llm_cache(es_env_fx: Dict, fake_chat_fx: BaseChatModel) -> None: es_client = cache._es_client set_llm_cache(cache) fake_chat_fx.invoke("test") - assert es_client.count(index="test_index2")["count"] == 1 + assert (es_client.count(index="test_index2"))["count"] == 1 fake_chat_fx.invoke("test2") - assert es_client.count(index="test_index2")["count"] == 2 + assert (es_client.count(index="test_index2"))["count"] == 2 es_client.indices.put_alias( index="test_index2", name="test_alias", is_write_index=False ) @@ -89,18 +88,19 @@ def test_alias_llm_cache(es_env_fx: Dict, fake_chat_fx: BaseChatModel) -> None: index="test_index1", name="test_alias", is_write_index=True ) fake_chat_fx.invoke("test3") - assert es_client.count(index="test_index1")["count"] == 1 + assert (es_client.count(index="test_index1"))["count"] == 1 fake_chat_fx.invoke("test2") - assert es_client.count(index="test_index1")["count"] == 1 + assert (es_client.count(index="test_index1"))["count"] == 1 es_client.indices.delete_alias(index="test_index2", name="test_alias") # we cache the response for prompt "test2" on both test_index1 and test_index2 fake_chat_fx.invoke("test2") - assert es_client.count(index="test_index1")["count"] == 2 + assert (es_client.count(index="test_index1"))["count"] == 2 es_client.indices.put_alias(index="test_index2", name="test_alias") # we just test the latter scenario is working assert fake_chat_fx.invoke("test2") +@pytest.mark.sync def test_clear_llm_cache(es_env_fx: Dict, fake_chat_fx: BaseChatModel) -> None: cache = ElasticsearchCache( **es_env_fx, index_name="test_alias", metadata={"project": "test"} @@ -116,11 +116,12 @@ def test_clear_llm_cache(es_env_fx: Dict, fake_chat_fx: BaseChatModel) -> None: index="test_index1", name="test_alias", is_write_index=True ) fake_chat_fx.invoke("test3") - assert es_client.count(index="test_alias")["count"] == 3 + assert (es_client.count(index="test_alias"))["count"] == 3 cache.clear() - assert es_client.count(index="test_alias")["count"] == 0 + assert (es_client.count(index="test_alias"))["count"] == 0 +@pytest.mark.sync def test_mdelete_cache_store(es_env_fx: Dict) -> None: store = ElasticsearchEmbeddingsCache( **es_env_fx, index_name="test_alias", metadata={"project": "test"} @@ -135,18 +136,19 @@ def test_mdelete_cache_store(es_env_fx: Dict) -> None: ] ) - assert store._es_client.count(index="test_alias")["count"] == 3 + assert (store._es_client.count(index="test_alias"))["count"] == 3 store.mdelete(recors[:2]) - assert store._es_client.count(index="test_alias")["count"] == 1 + assert (store._es_client.count(index="test_alias"))["count"] == 1 store.mdelete(recors[2:]) - assert store._es_client.count(index="test_alias")["count"] == 0 + assert (store._es_client.count(index="test_alias"))["count"] == 0 with pytest.raises(BulkIndexError): store.mdelete(recors) +@pytest.mark.sync def test_mset_cache_store(es_env_fx: Dict) -> None: store = ElasticsearchEmbeddingsCache( **es_env_fx, index_name="test_alias", metadata={"project": "test"} @@ -155,18 +157,19 @@ def test_mset_cache_store(es_env_fx: Dict) -> None: records = ["my little tests", "my little tests2", "my little tests3"] store.mset([(records[0], _value_serializer([1, 2, 3]))]) - assert store._es_client.count(index="test_alias")["count"] == 1 + assert (store._es_client.count(index="test_alias"))["count"] == 1 store.mset([(records[0], _value_serializer([1, 2, 3]))]) - assert store._es_client.count(index="test_alias")["count"] == 1 + assert (store._es_client.count(index="test_alias"))["count"] == 1 store.mset( [ (records[1], _value_serializer([1, 2, 3])), (records[2], _value_serializer([1, 2, 3])), ] ) - assert store._es_client.count(index="test_alias")["count"] == 3 + assert (store._es_client.count(index="test_alias"))["count"] == 3 +@pytest.mark.sync def test_mget_cache_store(es_env_fx: Dict) -> None: store_no_alias = ElasticsearchEmbeddingsCache( **es_env_fx, @@ -179,7 +182,7 @@ def test_mget_cache_store(es_env_fx: Dict) -> None: docs = [(r, _value_serializer([0.1, 2, i])) for i, r in enumerate(records)] store_no_alias.mset(docs) - assert store_no_alias._es_client.count(index="test_index3")["count"] == 3 + assert (store_no_alias._es_client.count(index="test_index3"))["count"] == 3 cached_records = store_no_alias.mget([d[0] for d in docs]) assert all(cached_records) @@ -194,13 +197,14 @@ def test_mget_cache_store(es_env_fx: Dict) -> None: ) store_alias.mset(docs) - assert store_alias._es_client.count(index="test_alias")["count"] == 3 + assert (store_alias._es_client.count(index="test_alias"))["count"] == 3 cached_records = store_alias.mget([d[0] for d in docs]) assert all(cached_records) assert all([r == d[1] for r, d in zip(cached_records, docs)]) +@pytest.mark.sync def test_mget_cache_store_multiple_keys(es_env_fx: Dict) -> None: """verify the logic of deduplication of keys in the cache store""" @@ -218,7 +222,7 @@ def test_mget_cache_store_multiple_keys(es_env_fx: Dict) -> None: docs = [(r, _value_serializer([0.1, 2, i])) for i, r in enumerate(records)] store_alias.mset(docs) - assert es_client.count(index="test_alias")["count"] == 3 + assert (es_client.count(index="test_alias"))["count"] == 3 store_no_alias = ElasticsearchEmbeddingsCache( **es_env_fx, @@ -235,7 +239,7 @@ def test_mget_cache_store_multiple_keys(es_env_fx: Dict) -> None: # store the same 3 previous records and 2 more in a fresh index store_no_alias.mset(new_docs) - assert es_client.count(index="test_index3")["count"] == 5 + assert (es_client.count(index="test_index3"))["count"] == 5 # update the alias to point to the new index and verify the cache es_client.indices.update_aliases( @@ -254,7 +258,7 @@ def test_mget_cache_store_multiple_keys(es_env_fx: Dict) -> None: cached_records = store_alias.mget([d[0] for d in new_docs]) assert all(cached_records) assert len(cached_records) == 5 - assert es_client.count(index="test_alias")["count"] == 8 + assert (es_client.count(index="test_alias"))["count"] == 8 assert cached_records[:3] != [ d[1] for d in docs ], "the first 3 records should be updated" @@ -267,6 +271,7 @@ def test_mget_cache_store_multiple_keys(es_env_fx: Dict) -> None: ) +@pytest.mark.sync def test_build_document_cache_store(es_env_fx: Dict) -> None: store = ElasticsearchEmbeddingsCache( **es_env_fx, @@ -276,7 +281,7 @@ def test_build_document_cache_store(es_env_fx: Dict) -> None: ) store.mset([("my little tests", _value_serializer([0.1, 2, 3]))]) - record = store._es_client.search(index="test_alias")["hits"]["hits"][0]["_source"] + record = (store._es_client.search(index="test_alias"))["hits"]["hits"][0]["_source"] assert record.get("metadata") == {"project": "test"} assert record.get("namespace") == "test" diff --git a/libs/elasticsearch/tests/integration_tests/test_chat_history.py b/libs/elasticsearch/tests/integration_tests/_sync/test_chat_history.py similarity index 87% rename from libs/elasticsearch/tests/integration_tests/test_chat_history.py rename to libs/elasticsearch/tests/integration_tests/_sync/test_chat_history.py index 56db550..8207e1a 100644 --- a/libs/elasticsearch/tests/integration_tests/test_chat_history.py +++ b/libs/elasticsearch/tests/integration_tests/_sync/test_chat_history.py @@ -4,7 +4,7 @@ import pytest from langchain.memory import ConversationBufferMemory -from langchain_core.messages import message_to_dict +from langchain_core.messages import AIMessage, HumanMessage, message_to_dict from langchain_elasticsearch.chat_history import ElasticsearchChatMessageHistory @@ -31,6 +31,7 @@ def elasticsearch_connection(self) -> Iterator[dict]: yield params clear_test_indices(es) + es.close() @pytest.fixture(scope="function") def index_name(self) -> str: @@ -51,8 +52,12 @@ def test_memory_with_message_store( ) # add some messages - memory.chat_memory.add_ai_message("This is me, the AI") - memory.chat_memory.add_user_message("This is me, the human") + memory.chat_memory.add_messages( + [ + AIMessage("This is me, the AI"), + HumanMessage("This is me, the human"), + ] + ) # get the message history from the memory store and turn it into a json messages = memory.chat_memory.messages diff --git a/libs/elasticsearch/tests/integration_tests/test_embeddings.py b/libs/elasticsearch/tests/integration_tests/_sync/test_embeddings.py similarity index 62% rename from libs/elasticsearch/tests/integration_tests/test_embeddings.py rename to libs/elasticsearch/tests/integration_tests/_sync/test_embeddings.py index 87d31b6..7420f8c 100644 --- a/libs/elasticsearch/tests/integration_tests/test_embeddings.py +++ b/libs/elasticsearch/tests/integration_tests/_sync/test_embeddings.py @@ -15,31 +15,40 @@ NUM_DIMENSIONS = int(os.getenv("NUM_DIMENTIONS", "384")) ES_URL = os.environ.get("ES_URL", "http://localhost:9200") -ES_CLIENT = Elasticsearch(hosts=[ES_URL]) -@pytest.mark.skipif( - not model_is_deployed(ES_CLIENT, MODEL_ID), - reason=f"{MODEL_ID} model is not deployed in ML Node, skipping test", -) +@pytest.mark.sync def test_elasticsearch_embedding_documents() -> None: """Test Elasticsearch embedding documents.""" + client = Elasticsearch(hosts=[ES_URL]) + if not (model_is_deployed(client, MODEL_ID)): + client.close() + pytest.skip( + reason=f"{MODEL_ID} model is not deployed in ML Node, skipping test" + ) + documents = ["foo bar", "bar foo", "foo"] - embedding = ElasticsearchEmbeddings.from_credentials(MODEL_ID) + embedding = ElasticsearchEmbeddings.from_es_connection(MODEL_ID, client) output = embedding.embed_documents(documents) + client.close() assert len(output) == 3 assert len(output[0]) == NUM_DIMENSIONS assert len(output[1]) == NUM_DIMENSIONS assert len(output[2]) == NUM_DIMENSIONS -@pytest.mark.skipif( - not model_is_deployed(ES_CLIENT, MODEL_ID), - reason=f"{MODEL_ID} model is not deployed in ML Node, skipping test", -) +@pytest.mark.sync def test_elasticsearch_embedding_query() -> None: """Test Elasticsearch embedding query.""" + client = Elasticsearch(hosts=[ES_URL]) + if not (model_is_deployed(client, MODEL_ID)): + client.close() + pytest.skip( + reason=f"{MODEL_ID} model is not deployed in ML Node, skipping test" + ) + document = "foo bar" - embedding = ElasticsearchEmbeddings.from_credentials(MODEL_ID) + embedding = ElasticsearchEmbeddings.from_es_connection(MODEL_ID, client) output = embedding.embed_query(document) + client.close() assert len(output) == NUM_DIMENSIONS diff --git a/libs/elasticsearch/tests/integration_tests/test_retrievers.py b/libs/elasticsearch/tests/integration_tests/_sync/test_retrievers.py similarity index 94% rename from libs/elasticsearch/tests/integration_tests/test_retrievers.py rename to libs/elasticsearch/tests/integration_tests/_sync/test_retrievers.py index 15d0307..24c968c 100644 --- a/libs/elasticsearch/tests/integration_tests/test_retrievers.py +++ b/libs/elasticsearch/tests/integration_tests/_sync/test_retrievers.py @@ -38,13 +38,16 @@ def index_test_data(es_client: Elasticsearch, index_name: str, field_name: str) class TestElasticsearchRetriever: @pytest.fixture(scope="function") def es_client(self) -> Any: - return requests_saving_es_client() + client = requests_saving_es_client() + yield client + client.close() @pytest.fixture(scope="function") def index_name(self) -> str: """Return the index name.""" return f"test_{uuid.uuid4().hex}" + @pytest.mark.sync def test_user_agent_header(self, es_client: Elasticsearch, index_name: str) -> None: """Test that the user agent header is set correctly.""" @@ -58,9 +61,7 @@ def test_user_agent_header(self, es_client: Elasticsearch, index_name: str) -> N assert retriever.es_client user_agent = retriever.es_client._headers["User-Agent"] assert ( - re.match( - r"^langchain-py-r/\d+\.\d+\.\d+(?:rc\d+)?(?:\.dev\d+)?$", user_agent - ) + re.match(r"^langchain-py-r/\d+\.\d+\.\d+(?:rc\d+)?$", user_agent) is not None ), f"The string '{user_agent}' does not match the expected pattern." @@ -70,12 +71,11 @@ def test_user_agent_header(self, es_client: Elasticsearch, index_name: str) -> N search_request = es_client.transport.requests[-1] # type: ignore[attr-defined] user_agent = search_request["headers"]["User-Agent"] assert ( - re.match( - r"^langchain-py-r/\d+\.\d+\.\d+(?:rc\d+)?(?:\.dev\d+)?$", user_agent - ) + re.match(r"^langchain-py-r/\d+\.\d+\.\d+(?:rc\d+)?$", user_agent) is not None ), f"The string '{user_agent}' does not match the expected pattern." + @pytest.mark.sync def test_init_url(self, index_name: str) -> None: """Test end-to-end indexing and search.""" @@ -109,6 +109,7 @@ def body_func(query: str) -> Dict: assert text_field not in r.metadata["_source"] assert "another_field" in r.metadata["_source"] + @pytest.mark.sync def test_init_client(self, es_client: Elasticsearch, index_name: str) -> None: """Test end-to-end indexing and search.""" @@ -134,6 +135,7 @@ def body_func(query: str) -> Dict: assert text_field not in r.metadata["_source"] assert "another_field" in r.metadata["_source"] + @pytest.mark.sync def test_multiple_index_and_content_fields( self, es_client: Elasticsearch, index_name: str ) -> None: @@ -174,6 +176,7 @@ def body_func(query: str) -> Dict: ("foo baz", index_name_2), ] + @pytest.mark.sync def test_custom_mapper(self, es_client: Elasticsearch, index_name: str) -> None: """Test custom document maper""" @@ -199,6 +202,7 @@ def id_as_content(hit: Dict) -> Document: assert [r.page_content for r in result] == ["3", "1", "5"] assert [r.metadata for r in result] == [meta, meta, meta] + @pytest.mark.sync def test_fail_content_field_and_mapper(self, es_client: Elasticsearch) -> None: """Raise exception if both content_field and document_mapper are specified.""" @@ -211,6 +215,7 @@ def test_fail_content_field_and_mapper(self, es_client: Elasticsearch) -> None: es_client=es_client, ) + @pytest.mark.sync def test_fail_neither_content_field_nor_mapper( self, es_client: Elasticsearch ) -> None: diff --git a/libs/elasticsearch/tests/integration_tests/test_vectorstores.py b/libs/elasticsearch/tests/integration_tests/_sync/test_vectorstores.py similarity index 97% rename from libs/elasticsearch/tests/integration_tests/test_vectorstores.py rename to libs/elasticsearch/tests/integration_tests/_sync/test_vectorstores.py index 39219f0..463f6dd 100644 --- a/libs/elasticsearch/tests/integration_tests/test_vectorstores.py +++ b/libs/elasticsearch/tests/integration_tests/_sync/test_vectorstores.py @@ -1,4 +1,4 @@ -"""Test ElasticsearchStore functionality.""" +"""Test AsyncElasticsearchStore functionality.""" import logging import uuid @@ -10,7 +10,7 @@ from langchain_elasticsearch.vectorstores import ElasticsearchStore -from ..fake_embeddings import ConsistentFakeEmbeddings, FakeEmbeddings +from ...fake_embeddings import ConsistentFakeEmbeddings, FakeEmbeddings from ._test_utilities import clear_test_indices, create_es_client, read_env logging.basicConfig(level=logging.DEBUG) @@ -42,6 +42,7 @@ def index_name(self) -> str: """Return the index name.""" return f"test_{uuid.uuid4().hex}" + @pytest.mark.sync def test_from_texts_similarity_search_with_doc_builder( self, es_params: dict, index_name: str ) -> None: @@ -73,6 +74,7 @@ def custom_document_builder(_: Dict) -> Document: docsearch.close() + @pytest.mark.sync def test_search_with_relevance_threshold( self, es_params: dict, index_name: str ) -> None: @@ -112,6 +114,7 @@ def test_search_with_relevance_threshold( docsearch.close() + @pytest.mark.sync def test_search_by_vector_with_relevance_threshold( self, es_params: dict, index_name: str ) -> None: @@ -154,6 +157,7 @@ def test_search_by_vector_with_relevance_threshold( # Also tested in elasticsearch.helpers.vectorstore + @pytest.mark.sync def test_similarity_search_without_metadata( self, es_params: dict, index_name: str ) -> None: @@ -183,20 +187,7 @@ def assert_query( output = docsearch.similarity_search("foo", k=1, custom_query=assert_query) assert output == [Document(page_content="foo")] - async def test_similarity_search_without_metadata_async( - self, es_params: dict, index_name: str - ) -> None: - """Test end to end construction and search without metadata.""" - texts = ["foo", "bar", "baz"] - docsearch = ElasticsearchStore.from_texts( - texts, - FakeEmbeddings(), - **es_params, - index_name=index_name, - ) - output = await docsearch.asimilarity_search("foo", k=1) - assert output == [Document(page_content="foo")] - + @pytest.mark.sync def test_add_embeddings(self, es_params: dict, index_name: str) -> None: """ Test add_embeddings, which accepts pre-built embeddings instead of @@ -223,6 +214,7 @@ def test_add_embeddings(self, es_params: dict, index_name: str) -> None: output = docsearch.similarity_search("foo1", k=1) assert output == [Document(page_content="foo3", metadata={"page": 2})] + @pytest.mark.sync def test_similarity_search_with_metadata( self, es_params: dict, index_name: str ) -> None: @@ -243,6 +235,7 @@ def test_similarity_search_with_metadata( output = docsearch.similarity_search("bar", k=1) assert output == [Document(page_content="bar", metadata={"page": 1})] + @pytest.mark.sync def test_similarity_search_with_filter( self, es_params: dict, index_name: str ) -> None: @@ -279,6 +272,7 @@ def assert_query( ) assert output == [Document(page_content="foo", metadata={"page": 1})] + @pytest.mark.sync def test_similarity_search_with_doc_builder( self, es_params: dict, index_name: str ) -> None: @@ -308,6 +302,7 @@ def custom_document_builder(_: Dict) -> Document: assert output[0].metadata["page_number"] == -1 assert output[0].metadata["original_filename"] == "Mock filename!" + @pytest.mark.sync def test_similarity_search_exact_search( self, es_params: dict, index_name: str ) -> None: @@ -355,6 +350,7 @@ def assert_query( output = docsearch.similarity_search("foo", k=1, custom_query=assert_query) assert output == [Document(page_content="foo")] + @pytest.mark.sync def test_similarity_search_exact_search_with_filter( self, es_params: dict, index_name: str ) -> None: @@ -408,6 +404,7 @@ def assert_query( ) assert output == [Document(page_content="foo", metadata={"page": 0})] + @pytest.mark.sync def test_similarity_search_exact_search_distance_dot_product( self, es_params: dict, index_name: str ) -> None: @@ -457,6 +454,7 @@ def assert_query( output = docsearch.similarity_search("foo", k=1, custom_query=assert_query) assert output == [Document(page_content="foo")] + @pytest.mark.sync def test_similarity_search_exact_search_unknown_distance_strategy( self, es_params: dict, index_name: str ) -> None: @@ -473,6 +471,7 @@ def test_similarity_search_exact_search_unknown_distance_strategy( distance_strategy="NOT_A_STRATEGY", ) + @pytest.mark.sync def test_max_marginal_relevance_search( self, es_params: dict, index_name: str ) -> None: @@ -509,6 +508,7 @@ def test_max_marginal_relevance_search( mmr_output = docsearch.max_marginal_relevance_search(texts[0], k=3, fetch_k=2) assert len(mmr_output) == 2 + @pytest.mark.sync def test_similarity_search_approx_with_hybrid_search( self, es_params: dict, index_name: str ) -> None: @@ -546,6 +546,7 @@ def assert_query( output = docsearch.similarity_search("foo", k=1, custom_query=assert_query) assert output == [Document(page_content="foo")] + @pytest.mark.sync def test_similarity_search_approx_by_vector( self, es_params: dict, index_name: str ) -> None: @@ -582,6 +583,7 @@ def assert_query( ) assert output == [(Document(page_content="foo"), 1.0)] + @pytest.mark.sync def test_similarity_search_approx_with_hybrid_search_rrf( self, es_params: dict, index_name: str ) -> None: @@ -690,6 +692,7 @@ def assert_query( "foo", k=3, fetch_k=50, custom_query=assert_query ) + @pytest.mark.sync def test_similarity_search_approx_with_custom_query_fn( self, es_params: dict, index_name: str ) -> None: @@ -719,6 +722,7 @@ def my_custom_query( output = docsearch.similarity_search("foo", k=1, custom_query=my_custom_query) assert output == [Document(page_content="bar")] + @pytest.mark.sync def test_deployed_model_check_fails_approx( self, es_params: dict, index_name: str ) -> None: @@ -734,6 +738,7 @@ def test_deployed_model_check_fails_approx( ), ) + @pytest.mark.sync def test_deployed_model_check_fails_sparse( self, es_params: dict, index_name: str ) -> None: @@ -748,6 +753,7 @@ def test_deployed_model_check_fails_sparse( ), ) + @pytest.mark.sync def test_elasticsearch_with_relevance_score( self, es_params: dict, index_name: str ) -> None: @@ -770,6 +776,7 @@ def test_elasticsearch_with_relevance_score( ) assert output == [(Document(page_content="foo", metadata={"page": "0"}), 1.0)] + @pytest.mark.sync def test_similarity_search_bm25_search( self, es_params: dict, index_name: str ) -> None: @@ -799,6 +806,7 @@ def assert_query( output = docsearch.similarity_search("foo", k=1, custom_query=assert_query) assert output == [Document(page_content="foo")] + @pytest.mark.sync def test_similarity_search_bm25_search_with_filter( self, es_params: dict, index_name: str ) -> None: @@ -835,6 +843,7 @@ def assert_query( ) assert output == [Document(page_content="foo", metadata={"page": 1})] + @pytest.mark.sync def test_elasticsearch_with_relevance_threshold( self, es_params: dict, index_name: str ) -> None: @@ -873,6 +882,7 @@ def test_elasticsearch_with_relevance_threshold( # third ranked is out ] + @pytest.mark.sync def test_elasticsearch_delete_ids(self, es_params: dict, index_name: str) -> None: """Test delete methods from vector store.""" texts = ["foo", "bar", "baz", "gni"] diff --git a/libs/elasticsearch/tests/unit_tests/_async/__init__.py b/libs/elasticsearch/tests/unit_tests/_async/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/libs/elasticsearch/tests/unit_tests/_async/test_cache.py b/libs/elasticsearch/tests/unit_tests/_async/test_cache.py new file mode 100644 index 0000000..fa83ceb --- /dev/null +++ b/libs/elasticsearch/tests/unit_tests/_async/test_cache.py @@ -0,0 +1,405 @@ +from datetime import datetime +from typing import Any, Dict +from unittest import mock +from unittest.mock import ANY, MagicMock, patch + +import pytest +from _pytest.fixtures import FixtureRequest +from elastic_transport import ApiResponseMeta, HttpHeaders, NodeConfig +from elasticsearch import NotFoundError +from langchain.embeddings.cache import _value_serializer +from langchain_core.load import dumps +from langchain_core.outputs import Generation + +from langchain_elasticsearch import ( + AsyncElasticsearchCache, + AsyncElasticsearchEmbeddingsCache, +) + + +def serialize_encode_vector(vector: Any) -> str: + return AsyncElasticsearchEmbeddingsCache.encode_vector(_value_serializer(vector)) + + +@pytest.mark.asyncio +async def test_initialization_llm_cache(async_es_client_fx: MagicMock) -> None: + async_es_client_fx.ping.return_value = True + async_es_client_fx.indices.exists_alias.return_value = True + with mock.patch( + "langchain_elasticsearch._sync.cache.create_elasticsearch_client", + return_value=async_es_client_fx, + ): + with mock.patch( + "langchain_elasticsearch._async.cache.create_async_elasticsearch_client", + return_value=async_es_client_fx, + ): + cache = AsyncElasticsearchCache( + es_url="http://localhost:9200", index_name="test_index" + ) + assert await cache.is_alias() + async_es_client_fx.indices.exists_alias.assert_awaited_with( + name="test_index" + ) + async_es_client_fx.indices.put_mapping.assert_awaited_with( + index="test_index", body=cache.mapping["mappings"] + ) + async_es_client_fx.indices.exists_alias.return_value = False + async_es_client_fx.indices.exists.return_value = False + cache = AsyncElasticsearchCache( + es_url="http://localhost:9200", index_name="test_index" + ) + assert not (await cache.is_alias()) + async_es_client_fx.indices.create.assert_awaited_with( + index="test_index", body=cache.mapping + ) + + +def test_mapping_llm_cache( + async_es_cache_fx: AsyncElasticsearchCache, request: FixtureRequest +) -> None: + mapping = request.getfixturevalue("es_cache_fx").mapping + assert mapping.get("mappings") + assert mapping["mappings"].get("properties") + + +def test_key_generation_llm_cache(es_cache_fx: AsyncElasticsearchCache) -> None: + key1 = es_cache_fx._key("test_prompt", "test_llm_string") + assert key1 and isinstance(key1, str) + key2 = es_cache_fx._key("test_prompt", "test_llm_string1") + assert key2 and key1 != key2 + key3 = es_cache_fx._key("test_prompt1", "test_llm_string") + assert key3 and key1 != key3 + + +def test_clear_llm_cache( + es_client_fx: MagicMock, es_cache_fx: AsyncElasticsearchCache +) -> None: + es_cache_fx.clear() + es_client_fx.delete_by_query.assert_called_once_with( + index="test_index", + body={"query": {"match_all": {}}}, + refresh=True, + wait_for_completion=True, + ) + + +def test_build_document_llm_cache(es_cache_fx: AsyncElasticsearchCache) -> None: + doc = es_cache_fx.build_document( + "test_prompt", "test_llm_string", [Generation(text="test_prompt")] + ) + assert doc["llm_input"] == "test_prompt" + assert doc["llm_params"] == "test_llm_string" + assert isinstance(doc["llm_output"], list) + assert all(isinstance(gen, str) for gen in doc["llm_output"]) + assert datetime.fromisoformat(str(doc["timestamp"])) + assert doc["metadata"] == es_cache_fx._metadata + + +def test_update_llm_cache( + es_client_fx: MagicMock, es_cache_fx: AsyncElasticsearchCache +) -> None: + es_cache_fx.update("test_prompt", "test_llm_string", [Generation(text="test")]) + timestamp = es_client_fx.index.call_args.kwargs["body"]["timestamp"] + doc = es_cache_fx.build_document( + "test_prompt", "test_llm_string", [Generation(text="test")] + ) + doc["timestamp"] = timestamp + es_client_fx.index.assert_called_once_with( + index=es_cache_fx._index_name, + id=es_cache_fx._key("test_prompt", "test_llm_string"), + body=doc, + require_alias=es_cache_fx._is_alias, + refresh=True, + ) + + +def test_lookup_llm_cache( + es_client_fx: MagicMock, es_cache_fx: AsyncElasticsearchCache +) -> None: + cache_key = es_cache_fx._key("test_prompt", "test_llm_string") + doc: Dict[str, Any] = { + "_source": { + "llm_output": [dumps(Generation(text="test"))], + "timestamp": "2024-03-07T13:25:36.410756", + } + } + es_cache_fx._is_alias = False + es_client_fx.get.side_effect = NotFoundError( + "not found", + ApiResponseMeta(404, "0", HttpHeaders(), 0, NodeConfig("http", "xxx", 80)), + "", + ) + assert es_cache_fx.lookup("test_prompt", "test_llm_string") is None + es_client_fx.get.assert_called_once_with( + index="test_index", id=cache_key, source=["llm_output"] + ) + es_client_fx.get.side_effect = None + es_client_fx.get.return_value = doc + assert es_cache_fx.lookup("test_prompt", "test_llm_string") == [ + Generation(text="test") + ] + es_cache_fx._is_alias = True + es_client_fx.search.return_value = {"hits": {"total": {"value": 0}, "hits": []}} + assert es_cache_fx.lookup("test_prompt", "test_llm_string") is None + es_client_fx.search.assert_called_once_with( + index="test_index", + body={ + "query": {"term": {"_id": cache_key}}, + "sort": {"timestamp": {"order": "asc"}}, + }, + source_includes=["llm_output"], + ) + doc2 = { + "_source": { + "llm_output": [dumps(Generation(text="test2"))], + "timestamp": "2024-03-08T13:25:36.410756", + }, + } + es_client_fx.search.return_value = { + "hits": {"total": {"value": 2}, "hits": [doc2, doc]} + } + assert es_cache_fx.lookup("test_prompt", "test_llm_string") == [ + Generation(text="test2") + ] + + +def test_key_generation_cache_store( + es_embeddings_cache_fx: AsyncElasticsearchEmbeddingsCache, +) -> None: + key1 = es_embeddings_cache_fx._key("test_text") + assert key1 and isinstance(key1, str) + key2 = es_embeddings_cache_fx._key("test_text2") + assert key2 and key1 != key2 + es_embeddings_cache_fx._namespace = "other" + key3 = es_embeddings_cache_fx._key("test_text") + assert key3 and key1 != key3 + es_embeddings_cache_fx._namespace = None + key4 = es_embeddings_cache_fx._key("test_text") + assert key4 and key1 != key4 and key3 != key4 + + +def test_build_document_cache_store( + es_embeddings_cache_fx: AsyncElasticsearchEmbeddingsCache, +) -> None: + doc = es_embeddings_cache_fx.build_document( + "test_text", _value_serializer([1.5, 2, 3.6]) + ) + assert doc["text_input"] == "test_text" + assert doc["vector_dump"] == serialize_encode_vector([1.5, 2, 3.6]) + assert datetime.fromisoformat(str(doc["timestamp"])) + assert doc["metadata"] == es_embeddings_cache_fx._metadata + + +def test_mget_cache_store( + es_client_fx: MagicMock, es_embeddings_cache_fx: AsyncElasticsearchEmbeddingsCache +) -> None: + cache_keys = [ + es_embeddings_cache_fx._key("test_text1"), + es_embeddings_cache_fx._key("test_text2"), + es_embeddings_cache_fx._key("test_text3"), + ] + docs = { + "docs": [ + {"_index": "test_index", "_id": cache_keys[0], "found": False}, + { + "_index": "test_index", + "_id": cache_keys[1], + "found": True, + "_source": {"vector_dump": serialize_encode_vector([1.5, 2, 3.6])}, + }, + { + "_index": "test_index", + "_id": cache_keys[2], + "found": True, + "_source": {"vector_dump": serialize_encode_vector([5, 6, 7.1])}, + }, + ] + } + es_embeddings_cache_fx._is_alias = False + es_client_fx.mget.return_value = docs + assert es_embeddings_cache_fx.mget([]) == [] + assert es_embeddings_cache_fx.mget(["test_text1", "test_text2", "test_text3"]) == [ + None, + _value_serializer([1.5, 2, 3.6]), + _value_serializer([5, 6, 7.1]), + ] + es_client_fx.mget.assert_called_with( + index="test_index", ids=cache_keys, source_includes=["vector_dump"] + ) + es_embeddings_cache_fx._is_alias = True + es_client_fx.search.return_value = {"hits": {"total": {"value": 0}, "hits": []}} + assert es_embeddings_cache_fx.mget([]) == [] + assert es_embeddings_cache_fx.mget(["test_text1", "test_text2", "test_text3"]) == [ + None, + None, + None, + ] + es_client_fx.search.assert_called_with( + index="test_index", + body={ + "query": {"ids": {"values": cache_keys}}, + "size": 3, + }, + source_includes=["vector_dump", "timestamp"], + ) + resp = { + "hits": {"total": {"value": 3}, "hits": [d for d in docs["docs"] if d["found"]]} + } + es_client_fx.search.return_value = resp + assert es_embeddings_cache_fx.mget(["test_text1", "test_text2", "test_text3"]) == [ + None, + _value_serializer([1.5, 2, 3.6]), + _value_serializer([5, 6, 7.1]), + ] + + +def test_deduplicate_hits( + es_embeddings_cache_fx: AsyncElasticsearchEmbeddingsCache, +) -> None: + hits = [ + { + "_id": "1", + "_source": { + "timestamp": "2022-01-01T00:00:00", + "vector_dump": serialize_encode_vector([1, 2, 3]), + }, + }, + { + "_id": "1", + "_source": { + "timestamp": "2022-01-02T00:00:00", + "vector_dump": serialize_encode_vector([4, 5, 6]), + }, + }, + { + "_id": "2", + "_source": { + "timestamp": "2022-01-01T00:00:00", + "vector_dump": serialize_encode_vector([7, 8, 9]), + }, + }, + ] + + result = es_embeddings_cache_fx._deduplicate_hits(hits) + + assert len(result) == 2 + assert result["1"] == _value_serializer([4, 5, 6]) + assert result["2"] == _value_serializer([7, 8, 9]) + + +def test_mget_duplicate_keys_cache_store( + es_client_fx: MagicMock, es_embeddings_cache_fx: AsyncElasticsearchEmbeddingsCache +) -> None: + cache_keys = [ + es_embeddings_cache_fx._key("test_text1"), + es_embeddings_cache_fx._key("test_text2"), + ] + + resp = { + "hits": { + "total": {"value": 3}, + "hits": [ + { + "_index": "test_index", + "_id": cache_keys[1], + "found": True, + "_source": { + "vector_dump": serialize_encode_vector([1.5, 2, 3.6]), + "timestamp": "2024-03-07T13:25:36.410756", + }, + }, + { + "_index": "test_index", + "_id": cache_keys[0], + "found": True, + "_source": { + "vector_dump": serialize_encode_vector([1, 6, 7.1]), + "timestamp": "2024-03-07T13:25:46.410756", + }, + }, + { + "_index": "test_index", + "_id": cache_keys[0], + "found": True, + "_source": { + "vector_dump": serialize_encode_vector([2, 6, 7.1]), + "timestamp": "2024-03-07T13:27:46.410756", + }, + }, + ], + } + } + + es_embeddings_cache_fx._is_alias = True + es_client_fx.search.return_value = resp + assert es_embeddings_cache_fx.mget(["test_text1", "test_text2"]) == [ + _value_serializer([2, 6, 7.1]), + _value_serializer([1.5, 2, 3.6]), + ] + es_client_fx.search.assert_called_with( + index="test_index", + body={ + "query": {"ids": {"values": cache_keys}}, + "size": len(cache_keys), + }, + source_includes=["vector_dump", "timestamp"], + ) + + +def _del_timestamp(doc: Dict[str, Any]) -> Dict[str, Any]: + del doc["_source"]["timestamp"] + return doc + + +def test_mset_cache_store( + es_embeddings_cache_fx: AsyncElasticsearchEmbeddingsCache, +) -> None: + input = [ + ("test_text1", _value_serializer([1.5, 2, 3.6])), + ("test_text2", _value_serializer([5, 6, 7.1])), + ] + actions = [ + { + "_op_type": "index", + "_id": es_embeddings_cache_fx._key(k), + "_source": es_embeddings_cache_fx.build_document(k, v), + } + for k, v in input + ] + es_embeddings_cache_fx._is_alias = False + with patch("elasticsearch.helpers.bulk") as bulk_mock: + es_embeddings_cache_fx.mset([]) + bulk_mock.assert_called_once() + es_embeddings_cache_fx.mset(input) + bulk_mock.assert_called_with( + client=es_embeddings_cache_fx._es_client, + actions=ANY, + index="test_index", + require_alias=False, + refresh=True, + ) + assert [_del_timestamp(d) for d in bulk_mock.call_args.kwargs["actions"]] == [ + _del_timestamp(d) for d in actions + ] + + +def test_mdelete_cache_store( + es_embeddings_cache_fx: AsyncElasticsearchEmbeddingsCache, +) -> None: + input = ["test_text1", "test_text2"] + actions = [ + {"_op_type": "delete", "_id": es_embeddings_cache_fx._key(k)} for k in input + ] + es_embeddings_cache_fx._is_alias = False + with patch("elasticsearch.helpers.bulk") as bulk_mock: + es_embeddings_cache_fx.mdelete([]) + bulk_mock.assert_called_once() + es_embeddings_cache_fx.mdelete(input) + bulk_mock.assert_called_with( + client=es_embeddings_cache_fx._es_client, + actions=ANY, + index="test_index", + require_alias=False, + refresh=True, + ) + assert list(bulk_mock.call_args.kwargs["actions"]) == actions diff --git a/libs/elasticsearch/tests/unit_tests/_async/test_vectorstores.py b/libs/elasticsearch/tests/unit_tests/_async/test_vectorstores.py new file mode 100644 index 0000000..f2a2358 --- /dev/null +++ b/libs/elasticsearch/tests/unit_tests/_async/test_vectorstores.py @@ -0,0 +1,417 @@ +"""Test Elasticsearch functionality.""" + +import re +from typing import Any, AsyncGenerator, Dict, List, Optional +from unittest.mock import AsyncMock + +import pytest +from elasticsearch import AsyncElasticsearch +from langchain_core.documents import Document + +from langchain_elasticsearch._async.vectorstores import _convert_retrieval_strategy +from langchain_elasticsearch._utilities import _hits_to_docs_scores +from langchain_elasticsearch.embeddings import AsyncEmbeddingServiceAdapter, Embeddings +from langchain_elasticsearch.vectorstores import ( + ApproxRetrievalStrategy, + AsyncBM25Strategy, + AsyncDenseVectorScriptScoreStrategy, + AsyncDenseVectorStrategy, + AsyncElasticsearchStore, + AsyncSparseVectorStrategy, + BM25RetrievalStrategy, + DistanceMetric, + DistanceStrategy, + ExactRetrievalStrategy, + SparseRetrievalStrategy, +) + +from ...fake_embeddings import AsyncConsistentFakeEmbeddings + + +class TestHitsToDocsScores: + def test_basic(self) -> None: + content_field = "content" + hits = [ + { + "_score": 11, + "_source": {content_field: "abc", "metadata": {"meta1": "one"}}, + }, + { + "_score": 22, + "_source": {content_field: "def", "metadata": {"meta2": "two"}}, + }, + ] + expected = [ + (Document("abc", metadata={"meta1": "one"}), 11), + (Document("def", metadata={"meta2": "two"}), 22), + ] + actual = _hits_to_docs_scores(hits, content_field) + assert actual == expected + + def test_custom_builder(self) -> None: + content_field = "content" + hits = [ + { + "_score": 11, + "_source": {content_field: "abc", "metadata": {"meta1": "one"}}, + }, + { + "_score": 22, + "_source": {content_field: "def", "metadata": {"meta2": "two"}}, + }, + ] + + def custom_builder(hit: Dict) -> Document: + return Document("static", metadata={"score": hit["_score"]}) + + expected = [ + (Document("static", metadata={"score": 11}), 11), + (Document("static", metadata={"score": 22}), 22), + ] + actual = _hits_to_docs_scores(hits, content_field, doc_builder=custom_builder) + assert actual == expected + + def test_fields(self) -> None: + content_field = "content" + extra_field = "extra" + hits = [ + { + "_score": 11, + "_source": { + content_field: "abc", + extra_field: "extra1", + "ignore_me": "please", + }, + }, + {"_score": 22, "_source": {content_field: "def", extra_field: "extra2"}}, + ] + expected = [ + (Document("abc", metadata={extra_field: "extra1"}), 11), + (Document("def", metadata={extra_field: "extra2"}), 22), + ] + actual = _hits_to_docs_scores(hits, content_field, fields=[extra_field]) + assert actual == expected + + def test_missing_content_field(self) -> None: + content_field = "content" + hits = [ + { + "_score": 11, + "_source": {content_field: "abc", "metadata": {"meta1": "one"}}, + }, + { + "_score": 22, + "_source": {content_field: "def", "metadata": {"meta2": "two"}}, + }, + ] + expected = [ + (Document("", metadata={"meta1": "one"}), 11), + (Document("", metadata={"meta2": "two"}), 22), + ] + actual = _hits_to_docs_scores(hits, "missing_content_field") + assert actual == expected + + def test_missing_metadata_field(self) -> None: + content_field = "content" + hits = [ + {"_score": 11, "_source": {content_field: "abc"}}, # missing metadata + ] + expected = [ + (Document("abc", metadata={}), 11), # empty metadata + ] + actual = _hits_to_docs_scores(hits, content_field) + assert actual == expected + + def test_doc_field_to_metadata(self) -> None: + content_field = "content" + other_field = "other" + hits = [ + { + "_score": 11, + "_source": { + content_field: "abc", + other_field: "foo", + "metadata": {"meta1": "one"}, + }, + }, + { + "_score": 22, + "_source": { + content_field: "def", + other_field: "bar", + "metadata": {"meta2": "two"}, + }, + }, + ] + expected = [ + (Document("abc", metadata={"meta1": "one", other_field: "foo"}), 11), + (Document("def", metadata={"meta2": "two", other_field: "bar"}), 22), + ] + actual = _hits_to_docs_scores( + hits, content_field=content_field, fields=[other_field] + ) + assert actual == expected + + +class TestConvertStrategy: + def test_dense_approx(self) -> None: + actual = _convert_retrieval_strategy( + ApproxRetrievalStrategy(query_model_id="my model", hybrid=True, rrf=False), + distance=DistanceStrategy.DOT_PRODUCT, + ) + assert isinstance(actual, AsyncDenseVectorStrategy) + assert actual.distance == DistanceMetric.DOT_PRODUCT + assert actual.model_id == "my model" + assert actual.hybrid is True + assert actual.rrf is False + + def test_dense_exact(self) -> None: + actual = _convert_retrieval_strategy( + ExactRetrievalStrategy(), distance=DistanceStrategy.EUCLIDEAN_DISTANCE + ) + assert isinstance(actual, AsyncDenseVectorScriptScoreStrategy) + assert actual.distance == DistanceMetric.EUCLIDEAN_DISTANCE + + def test_sparse(self) -> None: + actual = _convert_retrieval_strategy( + SparseRetrievalStrategy(model_id="my model ID") + ) + assert isinstance(actual, AsyncSparseVectorStrategy) + assert actual.model_id == "my model ID" + + def test_bm25(self) -> None: + actual = _convert_retrieval_strategy(BM25RetrievalStrategy(k1=1.7, b=5.4)) + assert isinstance(actual, AsyncBM25Strategy) + assert actual.k1 == 1.7 + assert actual.b == 5.4 + + +class TestVectorStore: + @pytest.fixture + def embeddings(self) -> Embeddings: + return AsyncConsistentFakeEmbeddings() + + @pytest.fixture + async def store(self) -> AsyncGenerator: + client = AsyncElasticsearch(hosts=["http://dummy:9200"]) # never connected to + store = AsyncElasticsearchStore(index_name="test_index", es_connection=client) + try: + yield store + finally: + await store.aclose() + + @pytest.fixture + async def hybrid_store(self, embeddings: Embeddings) -> AsyncGenerator: + client = AsyncElasticsearch(hosts=["http://dummy:9200"]) # never connected to + store = AsyncElasticsearchStore( + index_name="test_index", + embedding=embeddings, + strategy=ApproxRetrievalStrategy(hybrid=True), + es_connection=client, + ) + try: + yield store + finally: + await store.aclose() + + @pytest.fixture + def static_hits(self) -> List[Dict[str, Any]]: + default_content_field = "text" + return [ + {"_score": 1, "_source": {default_content_field: "test", "metadata": {}}} + ] + + @staticmethod + def dummy_custom_query(query_body: dict, query: Optional[str]) -> Dict[str, Any]: + return {"dummy": "query"} + + def test_agent_header(self, store: AsyncElasticsearchStore) -> None: + agent = store.client._headers["User-Agent"] + assert ( + re.match(r"^langchain-py-vs/\d+\.\d+\.\d+(?:rc\d+)?$", agent) is not None + ), f"The string '{agent}' does not match the expected pattern." + + @pytest.mark.asyncio + async def test_similarity_search( + self, store: AsyncElasticsearchStore, static_hits: List[Dict] + ) -> None: + store._store.search = AsyncMock(return_value=static_hits) # type: ignore[assignment] + actual1 = await store.asimilarity_search( + query="test", + k=7, + fetch_k=34, + filter=[{"f": 1}], + custom_query=self.dummy_custom_query, + ) + assert actual1 == [Document("test")] + store._store.search.assert_awaited_with( + query="test", + k=7, + num_candidates=34, + filter=[{"f": 1}], + custom_query=self.dummy_custom_query, + ) + + store._store.search = AsyncMock(return_value=static_hits) # type: ignore[assignment] + + actual2 = await store.asimilarity_search_with_score( + query="test", + k=7, + fetch_k=34, + filter=[{"f": 1}], + custom_query=self.dummy_custom_query, + ) + assert actual2 == [(Document("test"), 1)] + store._store.search.assert_awaited_with( + query="test", + k=7, + filter=[{"f": 1}], + custom_query=self.dummy_custom_query, + ) + + @pytest.mark.asyncio + async def test_similarity_search_by_vector_with_relevance_scores( + self, store: AsyncElasticsearchStore, static_hits: List[Dict] + ) -> None: + store._store.search = AsyncMock(return_value=static_hits) # type: ignore[assignment] + actual = await store.asimilarity_search_by_vector_with_relevance_scores( + embedding=[1, 2, 3], + k=7, + fetch_k=34, + filter=[{"f": 1}], + custom_query=self.dummy_custom_query, + ) + assert actual == [(Document("test"), 1)] + store._store.search.assert_awaited_with( + query=None, + query_vector=[1, 2, 3], + k=7, + filter=[{"f": 1}], + custom_query=self.dummy_custom_query, + ) + + @pytest.mark.asyncio + async def test_delete(self, store: AsyncElasticsearchStore) -> None: + store._store.delete = AsyncMock(return_value=True) # type: ignore[assignment] + actual = await store.adelete( + ids=["10", "20"], + refresh_indices=True, + ) + assert actual is True + store._store.delete.assert_awaited_with( + ids=["10", "20"], + refresh_indices=True, + ) + + @pytest.mark.asyncio + async def test_add_texts(self, store: AsyncElasticsearchStore) -> None: + store._store.add_texts = AsyncMock(return_value=["10", "20"]) # type: ignore[assignment] + actual = await store.aadd_texts( + texts=["t1", "t2"], + ) + assert actual == ["10", "20"] + store._store.add_texts.assert_awaited_with( + texts=["t1", "t2"], + metadatas=None, + ids=None, + refresh_indices=True, + create_index_if_not_exists=True, + bulk_kwargs=None, + ) + + store._store.add_texts = AsyncMock(return_value=["10", "20"]) # type: ignore[assignment] + await store.aadd_texts( + texts=["t1", "t2"], + metadatas=[{1: 2}, {3: 4}], + ids=["10", "20"], + refresh_indices=False, + create_index_if_not_exists=False, + bulk_kwargs={"x": "y"}, + ) + store._store.add_texts.assert_awaited_with( + texts=["t1", "t2"], + metadatas=[{1: 2}, {3: 4}], + ids=["10", "20"], + refresh_indices=False, + create_index_if_not_exists=False, + bulk_kwargs={"x": "y"}, + ) + + @pytest.mark.asyncio + async def test_add_embeddings(self, store: AsyncElasticsearchStore) -> None: + store._store.add_texts = AsyncMock(return_value=["10", "20"]) # type: ignore[assignment] + actual = await store.aadd_embeddings( + text_embeddings=[("t1", [1, 2, 3]), ("t2", [4, 5, 6])], + ) + assert actual == ["10", "20"] + store._store.add_texts.assert_awaited_with( + texts=["t1", "t2"], + metadatas=None, + vectors=[[1, 2, 3], [4, 5, 6]], + ids=None, + refresh_indices=True, + create_index_if_not_exists=True, + bulk_kwargs=None, + ) + + store._store.add_texts = AsyncMock(return_value=["10", "20"]) # type: ignore[assignment] + await store.aadd_embeddings( + text_embeddings=[("t1", [1, 2, 3]), ("t2", [4, 5, 6])], + metadatas=[{1: 2}, {3: 4}], + ids=["10", "20"], + refresh_indices=False, + create_index_if_not_exists=False, + bulk_kwargs={"x": "y"}, + ) + store._store.add_texts.assert_awaited_with( + texts=["t1", "t2"], + metadatas=[{1: 2}, {3: 4}], + vectors=[[1, 2, 3], [4, 5, 6]], + ids=["10", "20"], + refresh_indices=False, + create_index_if_not_exists=False, + bulk_kwargs={"x": "y"}, + ) + + @pytest.mark.asyncio + async def test_max_marginal_relevance_search( + self, + hybrid_store: AsyncElasticsearchStore, + embeddings: Embeddings, + static_hits: List[Dict], + ) -> None: + hybrid_store._store.max_marginal_relevance_search = AsyncMock( # type: ignore[assignment] + return_value=static_hits + ) + actual = await hybrid_store.amax_marginal_relevance_search( + query="qqq", + k=8, + fetch_k=19, + lambda_mult=0.3, + ) + assert actual == [Document("test")] + hybrid_store._store.max_marginal_relevance_search.assert_awaited_with( + embedding_service=AsyncEmbeddingServiceAdapter(embeddings), + query="qqq", + vector_field="vector", + k=8, + num_candidates=19, + lambda_mult=0.3, + fields=None, + custom_query=None, + ) + + @pytest.mark.asyncio + async def test_elasticsearch_hybrid_scores_guard( + self, hybrid_store: AsyncElasticsearchStore + ) -> None: + """Ensure an error is raised when search with score in hybrid mode + because in this case Elasticsearch does not return any score. + """ + with pytest.raises(ValueError): + await hybrid_store.asimilarity_search_with_score("foo") + + with pytest.raises(ValueError): + await hybrid_store.asimilarity_search_by_vector_with_relevance_scores( + [1, 2, 3] + ) diff --git a/libs/elasticsearch/tests/unit_tests/_sync/__init__.py b/libs/elasticsearch/tests/unit_tests/_sync/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/libs/elasticsearch/tests/unit_tests/test_cache.py b/libs/elasticsearch/tests/unit_tests/_sync/test_cache.py similarity index 90% rename from libs/elasticsearch/tests/unit_tests/test_cache.py rename to libs/elasticsearch/tests/unit_tests/_sync/test_cache.py index fbbe3cc..079d876 100644 --- a/libs/elasticsearch/tests/unit_tests/test_cache.py +++ b/libs/elasticsearch/tests/unit_tests/_sync/test_cache.py @@ -3,6 +3,7 @@ from unittest import mock from unittest.mock import ANY, MagicMock, patch +import pytest from _pytest.fixtures import FixtureRequest from elastic_transport import ApiResponseMeta, HttpHeaders, NodeConfig from elasticsearch import NotFoundError @@ -17,30 +18,35 @@ def serialize_encode_vector(vector: Any) -> str: return ElasticsearchEmbeddingsCache.encode_vector(_value_serializer(vector)) +@pytest.mark.sync def test_initialization_llm_cache(es_client_fx: MagicMock) -> None: es_client_fx.ping.return_value = True es_client_fx.indices.exists_alias.return_value = True with mock.patch( - "langchain_elasticsearch.cache.create_elasticsearch_client", + "langchain_elasticsearch._sync.cache.create_elasticsearch_client", return_value=es_client_fx, ): - cache = ElasticsearchCache( - es_url="http://localhost:9200", index_name="test_index" - ) - es_client_fx.indices.exists_alias.assert_called_with(name="test_index") - assert cache._is_alias - es_client_fx.indices.put_mapping.assert_called_with( - index="test_index", body=cache.mapping["mappings"] - ) - es_client_fx.indices.exists_alias.return_value = False - es_client_fx.indices.exists.return_value = False - cache = ElasticsearchCache( - es_url="http://localhost:9200", index_name="test_index" - ) - assert not cache._is_alias - es_client_fx.indices.create.assert_called_with( - index="test_index", body=cache.mapping - ) + with mock.patch( + "langchain_elasticsearch._async.cache.create_async_elasticsearch_client", + return_value=es_client_fx, + ): + cache = ElasticsearchCache( + es_url="http://localhost:9200", index_name="test_index" + ) + assert cache.is_alias() + es_client_fx.indices.exists_alias.assert_called_with(name="test_index") + es_client_fx.indices.put_mapping.assert_called_with( + index="test_index", body=cache.mapping["mappings"] + ) + es_client_fx.indices.exists_alias.return_value = False + es_client_fx.indices.exists.return_value = False + cache = ElasticsearchCache( + es_url="http://localhost:9200", index_name="test_index" + ) + assert not (cache.is_alias()) + es_client_fx.indices.create.assert_called_with( + index="test_index", body=cache.mapping + ) def test_mapping_llm_cache( @@ -242,7 +248,9 @@ def test_mget_cache_store( ] -def test_deduplicate_hits(es_embeddings_cache_fx: ElasticsearchEmbeddingsCache) -> None: +def test_deduplicate_hits( + es_embeddings_cache_fx: ElasticsearchEmbeddingsCache, +) -> None: hits = [ { "_id": "1", @@ -338,7 +346,9 @@ def _del_timestamp(doc: Dict[str, Any]) -> Dict[str, Any]: return doc -def test_mset_cache_store(es_embeddings_cache_fx: ElasticsearchEmbeddingsCache) -> None: +def test_mset_cache_store( + es_embeddings_cache_fx: ElasticsearchEmbeddingsCache, +) -> None: input = [ ("test_text1", _value_serializer([1.5, 2, 3.6])), ("test_text2", _value_serializer([5, 6, 7.1])), diff --git a/libs/elasticsearch/tests/unit_tests/test_vectorstores.py b/libs/elasticsearch/tests/unit_tests/_sync/test_vectorstores.py similarity index 96% rename from libs/elasticsearch/tests/unit_tests/test_vectorstores.py rename to libs/elasticsearch/tests/unit_tests/_sync/test_vectorstores.py index eac78d0..4645185 100644 --- a/libs/elasticsearch/tests/unit_tests/test_vectorstores.py +++ b/libs/elasticsearch/tests/unit_tests/_sync/test_vectorstores.py @@ -8,6 +8,8 @@ from elasticsearch import Elasticsearch from langchain_core.documents import Document +from langchain_elasticsearch._sync.vectorstores import _convert_retrieval_strategy +from langchain_elasticsearch._utilities import _hits_to_docs_scores from langchain_elasticsearch.embeddings import Embeddings, EmbeddingServiceAdapter from langchain_elasticsearch.vectorstores import ( ApproxRetrievalStrategy, @@ -21,11 +23,9 @@ ExactRetrievalStrategy, SparseRetrievalStrategy, SparseVectorStrategy, - _convert_retrieval_strategy, - _hits_to_docs_scores, ) -from ..fake_embeddings import ConsistentFakeEmbeddings +from ...fake_embeddings import ConsistentFakeEmbeddings class TestHitsToDocsScores: @@ -192,7 +192,7 @@ def embeddings(self) -> Embeddings: return ConsistentFakeEmbeddings() @pytest.fixture - def store(self) -> Generator[ElasticsearchStore, None, None]: + def store(self) -> Generator: client = Elasticsearch(hosts=["http://dummy:9200"]) # never connected to store = ElasticsearchStore(index_name="test_index", es_connection=client) try: @@ -201,9 +201,7 @@ def store(self) -> Generator[ElasticsearchStore, None, None]: store.close() @pytest.fixture - def hybrid_store( - self, embeddings: Embeddings - ) -> Generator[ElasticsearchStore, None, None]: + def hybrid_store(self, embeddings: Embeddings) -> Generator: client = Elasticsearch(hosts=["http://dummy:9200"]) # never connected to store = ElasticsearchStore( index_name="test_index", @@ -230,10 +228,10 @@ def dummy_custom_query(query_body: dict, query: Optional[str]) -> Dict[str, Any] def test_agent_header(self, store: ElasticsearchStore) -> None: agent = store.client._headers["User-Agent"] assert ( - re.match(r"^langchain-py-vs/\d+\.\d+\.\d+(?:rc\d+)?(?:\.dev\d+)?$", agent) - is not None + re.match(r"^langchain-py-vs/\d+\.\d+\.\d+(?:rc\d+)?$", agent) is not None ), f"The string '{agent}' does not match the expected pattern." + @pytest.mark.sync def test_similarity_search( self, store: ElasticsearchStore, static_hits: List[Dict] ) -> None: @@ -271,6 +269,7 @@ def test_similarity_search( custom_query=self.dummy_custom_query, ) + @pytest.mark.sync def test_similarity_search_by_vector_with_relevance_scores( self, store: ElasticsearchStore, static_hits: List[Dict] ) -> None: @@ -291,6 +290,7 @@ def test_similarity_search_by_vector_with_relevance_scores( custom_query=self.dummy_custom_query, ) + @pytest.mark.sync def test_delete(self, store: ElasticsearchStore) -> None: store._store.delete = Mock(return_value=True) # type: ignore[assignment] actual = store.delete( @@ -303,6 +303,7 @@ def test_delete(self, store: ElasticsearchStore) -> None: refresh_indices=True, ) + @pytest.mark.sync def test_add_texts(self, store: ElasticsearchStore) -> None: store._store.add_texts = Mock(return_value=["10", "20"]) # type: ignore[assignment] actual = store.add_texts( @@ -336,6 +337,7 @@ def test_add_texts(self, store: ElasticsearchStore) -> None: bulk_kwargs={"x": "y"}, ) + @pytest.mark.sync def test_add_embeddings(self, store: ElasticsearchStore) -> None: store._store.add_texts = Mock(return_value=["10", "20"]) # type: ignore[assignment] actual = store.add_embeddings( @@ -371,6 +373,7 @@ def test_add_embeddings(self, store: ElasticsearchStore) -> None: bulk_kwargs={"x": "y"}, ) + @pytest.mark.sync def test_max_marginal_relevance_search( self, hybrid_store: ElasticsearchStore, @@ -398,6 +401,7 @@ def test_max_marginal_relevance_search( custom_query=None, ) + @pytest.mark.sync def test_elasticsearch_hybrid_scores_guard( self, hybrid_store: ElasticsearchStore ) -> None: diff --git a/libs/elasticsearch/tests/unit_tests/test_imports.py b/libs/elasticsearch/tests/unit_tests/test_imports.py index 5cd5fc2..28c6296 100644 --- a/libs/elasticsearch/tests/unit_tests/test_imports.py +++ b/libs/elasticsearch/tests/unit_tests/test_imports.py @@ -1,25 +1,38 @@ from langchain_elasticsearch import __all__ -EXPECTED_ALL = [ - "ElasticsearchCache", - "ElasticsearchChatMessageHistory", - "ElasticsearchEmbeddings", - "ElasticsearchEmbeddingsCache", - "ElasticsearchRetriever", - "ElasticsearchStore", - # retrieval strategies - "BM25Strategy", - "DenseVectorScriptScoreStrategy", - "DenseVectorStrategy", - "DistanceMetric", - "RetrievalStrategy", - "SparseVectorStrategy", - # deprecated retrieval strategies - "ApproxRetrievalStrategy", - "BM25RetrievalStrategy", - "ExactRetrievalStrategy", - "SparseRetrievalStrategy", -] +EXPECTED_ALL = sorted( + [ + "ElasticsearchCache", + "ElasticsearchChatMessageHistory", + "ElasticsearchEmbeddings", + "ElasticsearchEmbeddingsCache", + "ElasticsearchRetriever", + "ElasticsearchStore", + "AsyncElasticsearchCache", + "AsyncElasticsearchChatMessageHistory", + "AsyncElasticsearchEmbeddings", + "AsyncElasticsearchEmbeddingsCache", + "AsyncElasticsearchRetriever", + "AsyncElasticsearchStore", + # retrieval strategies + "BM25Strategy", + "DenseVectorScriptScoreStrategy", + "DenseVectorStrategy", + "DistanceMetric", + "RetrievalStrategy", + "SparseVectorStrategy", + "AsyncBM25Strategy", + "AsyncDenseVectorScriptScoreStrategy", + "AsyncDenseVectorStrategy", + "AsyncRetrievalStrategy", + "AsyncSparseVectorStrategy", + # deprecated retrieval strategies + "ApproxRetrievalStrategy", + "BM25RetrievalStrategy", + "ExactRetrievalStrategy", + "SparseRetrievalStrategy", + ] +) def test_all_imports() -> None: