diff --git a/libs/elasticsearch/langchain_elasticsearch/_async/chat_history.py b/libs/elasticsearch/langchain_elasticsearch/_async/chat_history.py index fb98458..2de6060 100644 --- a/libs/elasticsearch/langchain_elasticsearch/_async/chat_history.py +++ b/libs/elasticsearch/langchain_elasticsearch/_async/chat_history.py @@ -1,7 +1,7 @@ import json import logging from time import time -from typing import TYPE_CHECKING, List, Optional, Sequence +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence from langchain_core.chat_history import BaseChatMessageHistory from langchain_core.messages import BaseMessage, message_to_dict, messages_from_dict @@ -101,26 +101,33 @@ async def create_if_missing(self) -> None: async def aget_messages(self) -> List[BaseMessage]: # type: ignore[override] """Retrieve the messages from Elasticsearch""" - try: - from elasticsearch import ApiError + from elasticsearch import ApiError - await self.create_if_missing() - result = await self.client.search( - index=self.index, - query={"term": {"session_id": self.session_id}}, - sort="created_at:asc", - ) - except ApiError as err: - logger.error(f"Could not retrieve messages from Elasticsearch: {err}") - raise err + await self.create_if_missing() - if result and len(result["hits"]["hits"]) > 0: - items = [ - json.loads(document["_source"]["history"]) - for document in result["hits"]["hits"] - ] - else: - items = [] + search_after: Dict[str, Any] = {} + items = [] + while True: + try: + result = await self.client.search( + index=self.index, + query={"term": {"session_id": self.session_id}}, + sort="created_at:asc", + size=100, + **search_after, + ) + except ApiError as err: + logger.error(f"Could not retrieve messages from Elasticsearch: {err}") + raise err + + if result and len(result["hits"]["hits"]) > 0: + items += [ + json.loads(document["_source"]["history"]) + for document in result["hits"]["hits"] + ] + search_after = {"search_after": result["hits"]["hits"][-1]["sort"]} + else: + break return messages_from_dict(items) diff --git a/libs/elasticsearch/langchain_elasticsearch/_sync/chat_history.py b/libs/elasticsearch/langchain_elasticsearch/_sync/chat_history.py index 5c5614d..2d7d1cf 100644 --- a/libs/elasticsearch/langchain_elasticsearch/_sync/chat_history.py +++ b/libs/elasticsearch/langchain_elasticsearch/_sync/chat_history.py @@ -1,7 +1,7 @@ import json import logging from time import time -from typing import TYPE_CHECKING, List, Optional, Sequence +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence from langchain_core.chat_history import BaseChatMessageHistory from langchain_core.messages import BaseMessage, message_to_dict, messages_from_dict @@ -101,26 +101,33 @@ def create_if_missing(self) -> None: def get_messages(self) -> List[BaseMessage]: # type: ignore[override] """Retrieve the messages from Elasticsearch""" - try: - from elasticsearch import ApiError + from elasticsearch import ApiError - self.create_if_missing() - result = self.client.search( - index=self.index, - query={"term": {"session_id": self.session_id}}, - sort="created_at:asc", - ) - except ApiError as err: - logger.error(f"Could not retrieve messages from Elasticsearch: {err}") - raise err + self.create_if_missing() - if result and len(result["hits"]["hits"]) > 0: - items = [ - json.loads(document["_source"]["history"]) - for document in result["hits"]["hits"] - ] - else: - items = [] + search_after: Dict[str, Any] = {} + items = [] + while True: + try: + result = self.client.search( + index=self.index, + query={"term": {"session_id": self.session_id}}, + sort="created_at:asc", + size=100, + **search_after, + ) + except ApiError as err: + logger.error(f"Could not retrieve messages from Elasticsearch: {err}") + raise err + + if result and len(result["hits"]["hits"]) > 0: + items += [ + json.loads(document["_source"]["history"]) + for document in result["hits"]["hits"] + ] + search_after = {"search_after": result["hits"]["hits"][-1]["sort"]} + else: + break return messages_from_dict(items) diff --git a/libs/elasticsearch/tests/integration_tests/_async/test_chat_history.py b/libs/elasticsearch/tests/integration_tests/_async/test_chat_history.py index 100bb1a..91e5dbe 100644 --- a/libs/elasticsearch/tests/integration_tests/_async/test_chat_history.py +++ b/libs/elasticsearch/tests/integration_tests/_async/test_chat_history.py @@ -1,4 +1,3 @@ -import json import uuid from typing import AsyncIterator @@ -54,17 +53,35 @@ async def test_memory_with_message_store( # add some messages await memory.chat_memory.aadd_messages( [ - AIMessage("This is me, the AI"), - HumanMessage("This is me, the human"), + AIMessage("This is me, the AI (1)"), + HumanMessage("This is me, the human (1)"), + AIMessage("This is me, the AI (2)"), + HumanMessage("This is me, the human (2)"), + AIMessage("This is me, the AI (3)"), + HumanMessage("This is me, the human (3)"), + AIMessage("This is me, the AI (4)"), + HumanMessage("This is me, the human (4)"), + AIMessage("This is me, the AI (5)"), + HumanMessage("This is me, the human (5)"), + AIMessage("This is me, the AI (6)"), + HumanMessage("This is me, the human (6)"), + AIMessage("This is me, the AI (7)"), + HumanMessage("This is me, the human (7)"), ] ) # get the message history from the memory store and turn it into a json - messages = await memory.chat_memory.aget_messages() - messages_json = json.dumps([message_to_dict(msg) for msg in messages]) - - assert "This is me, the AI" in messages_json - assert "This is me, the human" in messages_json + messages = [ + message_to_dict(msg) for msg in (await memory.chat_memory.aget_messages()) + ] + + assert len(messages) == 14 + for i in range(7): + assert messages[i * 2]["data"]["content"] == f"This is me, the AI ({i+1})" + assert ( + messages[i * 2 + 1]["data"]["content"] + == f"This is me, the human ({i+1})" + ) # remove the record from Elasticsearch, so the next test run won't pick it up await memory.chat_memory.aclear() diff --git a/libs/elasticsearch/tests/integration_tests/_sync/test_chat_history.py b/libs/elasticsearch/tests/integration_tests/_sync/test_chat_history.py index 8207e1a..ec9793b 100644 --- a/libs/elasticsearch/tests/integration_tests/_sync/test_chat_history.py +++ b/libs/elasticsearch/tests/integration_tests/_sync/test_chat_history.py @@ -1,4 +1,3 @@ -import json import uuid from typing import Iterator @@ -54,17 +53,33 @@ def test_memory_with_message_store( # add some messages memory.chat_memory.add_messages( [ - AIMessage("This is me, the AI"), - HumanMessage("This is me, the human"), + AIMessage("This is me, the AI (1)"), + HumanMessage("This is me, the human (1)"), + AIMessage("This is me, the AI (2)"), + HumanMessage("This is me, the human (2)"), + AIMessage("This is me, the AI (3)"), + HumanMessage("This is me, the human (3)"), + AIMessage("This is me, the AI (4)"), + HumanMessage("This is me, the human (4)"), + AIMessage("This is me, the AI (5)"), + HumanMessage("This is me, the human (5)"), + AIMessage("This is me, the AI (6)"), + HumanMessage("This is me, the human (6)"), + AIMessage("This is me, the AI (7)"), + HumanMessage("This is me, the human (7)"), ] ) # get the message history from the memory store and turn it into a json - messages = memory.chat_memory.messages - messages_json = json.dumps([message_to_dict(msg) for msg in messages]) - - assert "This is me, the AI" in messages_json - assert "This is me, the human" in messages_json + messages = [message_to_dict(msg) for msg in (memory.chat_memory.messages)] + + assert len(messages) == 14 + for i in range(7): + assert messages[i * 2]["data"]["content"] == f"This is me, the AI ({i+1})" + assert ( + messages[i * 2 + 1]["data"]["content"] + == f"This is me, the human ({i+1})" + ) # remove the record from Elasticsearch, so the next test run won't pick it up memory.chat_memory.clear()