diff --git a/libs/elasticsearch/langchain_elasticsearch/retrievers.py b/libs/elasticsearch/langchain_elasticsearch/retrievers.py index 1e58e7c..a5a12f5 100644 --- a/libs/elasticsearch/langchain_elasticsearch/retrievers.py +++ b/libs/elasticsearch/langchain_elasticsearch/retrievers.py @@ -1,5 +1,5 @@ import logging -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Union, cast from elasticsearch import Elasticsearch from langchain_core.callbacks import CallbackManagerForRetrieverRun @@ -19,20 +19,21 @@ class ElasticsearchRetriever(BaseRetriever): Args: es_client: Elasticsearch client connection. Alternatively you can use the `from_es_params` method with parameters to initialize the client. - index_name: The name of the index to query. + index_name: The name of the index to query. Can also be a list of names. body_func: Function to create an Elasticsearch DSL query body from a search string. The returned query body must fit what you would normally send in a POST request the the _search endpoint. If applicable, it also includes parameters the `size` parameter etc. - content_field: The document field name that contains the page content. + content_field: The document field name that contains the page content. If + multiple indices are queried, specify a dict {index_name: field_name} here. document_mapper: Function to map Elasticsearch hits to LangChain Documents. """ es_client: Elasticsearch - index_name: str + index_name: Union[str, Sequence[str]] body_func: Callable[[str], Dict] - content_field: Optional[str] = None - document_mapper: Optional[Callable[[Dict], Document]] = None + content_field: Optional[Union[str, Mapping[str, str]]] = None + document_mapper: Optional[Callable[[Mapping], Document]] = None def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) @@ -45,15 +46,24 @@ def __init__(self, **kwargs: Any) -> None: "Please provide only one." ) - self.document_mapper = self.document_mapper or self._field_mapper + if not self.document_mapper: + if isinstance(self.content_field, str): + self.document_mapper = self._single_field_mapper + elif isinstance(self.content_field, Mapping): + self.document_mapper = self._multi_field_mapper + else: + raise ValueError( + "unknown type for content_field, expected string or dict." + ) + self.es_client = with_user_agent_header(self.es_client, "langchain-py-r") @staticmethod def from_es_params( - index_name: str, + index_name: Union[str, Sequence[str]], body_func: Callable[[str], Dict], - content_field: Optional[str] = None, - document_mapper: Optional[Callable[[Dict], Document]] = None, + content_field: Optional[Union[str, Mapping[str, str]]] = None, + document_mapper: Optional[Callable[[Mapping], Document]] = None, url: Optional[str] = None, cloud_id: Optional[str] = None, api_key: Optional[str] = None, @@ -93,6 +103,12 @@ def _get_relevant_documents( results = self.es_client.search(index=self.index_name, body=body) return [self.document_mapper(hit) for hit in results["hits"]["hits"]] - def _field_mapper(self, hit: Dict[str, Any]) -> Document: + def _single_field_mapper(self, hit: Mapping[str, Any]) -> Document: content = hit["_source"].pop(self.content_field) return Document(page_content=content, metadata=hit) + + def _multi_field_mapper(self, hit: Mapping[str, Any]) -> Document: + self.content_field = cast(Mapping, self.content_field) + field = self.content_field[hit["_index"]] + content = hit["_source"].pop(field) + return Document(page_content=content, metadata=hit) diff --git a/libs/elasticsearch/tests/integration_tests/test_retrievers.py b/libs/elasticsearch/tests/integration_tests/test_retrievers.py index 59c0244..5dd3121 100644 --- a/libs/elasticsearch/tests/integration_tests/test_retrievers.py +++ b/libs/elasticsearch/tests/integration_tests/test_retrievers.py @@ -128,6 +128,46 @@ def body_func(query: str) -> Dict: assert text_field not in r.metadata["_source"] assert "another_field" in r.metadata["_source"] + def test_multiple_index_and_content_fields( + self, es_client: Elasticsearch, index_name: str + ) -> None: + """Test multiple content fields""" + index_name_1 = f"{index_name}_1" + index_name_2 = f"{index_name}_2" + text_field_1 = "text_1" + text_field_2 = "text_2" + + def body_func(query: str) -> Dict: + return { + "query": { + "multi_match": { + "query": query, + "fields": [text_field_1, text_field_2], + } + } + } + + retriever = ElasticsearchRetriever( + index_name=[index_name_1, index_name_2], + content_field={index_name_1: text_field_1, index_name_2: text_field_2}, + body_func=body_func, + es_client=es_client, + ) + + index_test_data(es_client, index_name_1, text_field_1) + index_test_data(es_client, index_name_2, text_field_2) + result = retriever.get_relevant_documents("foo") + + # matches from both indices + assert sorted([(r.page_content, r.metadata["_index"]) for r in result]) == [ + ("foo", index_name_1), + ("foo", index_name_2), + ("foo bar", index_name_1), + ("foo bar", index_name_2), + ("foo baz", index_name_1), + ("foo baz", index_name_2), + ] + def test_custom_mapper(self, es_client: Elasticsearch, index_name: str) -> None: """Test custom document maper"""