Skip to content

Commit

Permalink
retrieve the complete chat history with pagination (#54)
Browse files Browse the repository at this point in the history
* retrieve the complete chat history with pagination (Fixes #48)

* extend chat history unit test to also test pagination
  • Loading branch information
miguelgrinberg authored Dec 4, 2024
1 parent 9c3bccd commit 8470936
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 54 deletions.
45 changes: 26 additions & 19 deletions libs/elasticsearch/langchain_elasticsearch/_async/chat_history.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
import logging
from time import time
from typing import TYPE_CHECKING, List, Optional, Sequence
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence

from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import BaseMessage, message_to_dict, messages_from_dict
Expand Down Expand Up @@ -101,26 +101,33 @@ async def create_if_missing(self) -> None:

async def aget_messages(self) -> List[BaseMessage]: # type: ignore[override]
"""Retrieve the messages from Elasticsearch"""
try:
from elasticsearch import ApiError
from elasticsearch import ApiError

await self.create_if_missing()
result = await self.client.search(
index=self.index,
query={"term": {"session_id": self.session_id}},
sort="created_at:asc",
)
except ApiError as err:
logger.error(f"Could not retrieve messages from Elasticsearch: {err}")
raise err
await self.create_if_missing()

if result and len(result["hits"]["hits"]) > 0:
items = [
json.loads(document["_source"]["history"])
for document in result["hits"]["hits"]
]
else:
items = []
search_after: Dict[str, Any] = {}
items = []
while True:
try:
result = await self.client.search(
index=self.index,
query={"term": {"session_id": self.session_id}},
sort="created_at:asc",
size=100,
**search_after,
)
except ApiError as err:
logger.error(f"Could not retrieve messages from Elasticsearch: {err}")
raise err

if result and len(result["hits"]["hits"]) > 0:
items += [
json.loads(document["_source"]["history"])
for document in result["hits"]["hits"]
]
search_after = {"search_after": result["hits"]["hits"][-1]["sort"]}
else:
break

return messages_from_dict(items)

Expand Down
45 changes: 26 additions & 19 deletions libs/elasticsearch/langchain_elasticsearch/_sync/chat_history.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
import logging
from time import time
from typing import TYPE_CHECKING, List, Optional, Sequence
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence

from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import BaseMessage, message_to_dict, messages_from_dict
Expand Down Expand Up @@ -101,26 +101,33 @@ def create_if_missing(self) -> None:

def get_messages(self) -> List[BaseMessage]: # type: ignore[override]
"""Retrieve the messages from Elasticsearch"""
try:
from elasticsearch import ApiError
from elasticsearch import ApiError

self.create_if_missing()
result = self.client.search(
index=self.index,
query={"term": {"session_id": self.session_id}},
sort="created_at:asc",
)
except ApiError as err:
logger.error(f"Could not retrieve messages from Elasticsearch: {err}")
raise err
self.create_if_missing()

if result and len(result["hits"]["hits"]) > 0:
items = [
json.loads(document["_source"]["history"])
for document in result["hits"]["hits"]
]
else:
items = []
search_after: Dict[str, Any] = {}
items = []
while True:
try:
result = self.client.search(
index=self.index,
query={"term": {"session_id": self.session_id}},
sort="created_at:asc",
size=100,
**search_after,
)
except ApiError as err:
logger.error(f"Could not retrieve messages from Elasticsearch: {err}")
raise err

if result and len(result["hits"]["hits"]) > 0:
items += [
json.loads(document["_source"]["history"])
for document in result["hits"]["hits"]
]
search_after = {"search_after": result["hits"]["hits"][-1]["sort"]}
else:
break

return messages_from_dict(items)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import json
import uuid
from typing import AsyncIterator

Expand Down Expand Up @@ -54,17 +53,35 @@ async def test_memory_with_message_store(
# add some messages
await memory.chat_memory.aadd_messages(
[
AIMessage("This is me, the AI"),
HumanMessage("This is me, the human"),
AIMessage("This is me, the AI (1)"),
HumanMessage("This is me, the human (1)"),
AIMessage("This is me, the AI (2)"),
HumanMessage("This is me, the human (2)"),
AIMessage("This is me, the AI (3)"),
HumanMessage("This is me, the human (3)"),
AIMessage("This is me, the AI (4)"),
HumanMessage("This is me, the human (4)"),
AIMessage("This is me, the AI (5)"),
HumanMessage("This is me, the human (5)"),
AIMessage("This is me, the AI (6)"),
HumanMessage("This is me, the human (6)"),
AIMessage("This is me, the AI (7)"),
HumanMessage("This is me, the human (7)"),
]
)

# get the message history from the memory store and turn it into a json
messages = await memory.chat_memory.aget_messages()
messages_json = json.dumps([message_to_dict(msg) for msg in messages])

assert "This is me, the AI" in messages_json
assert "This is me, the human" in messages_json
messages = [
message_to_dict(msg) for msg in (await memory.chat_memory.aget_messages())
]

assert len(messages) == 14
for i in range(7):
assert messages[i * 2]["data"]["content"] == f"This is me, the AI ({i+1})"
assert (
messages[i * 2 + 1]["data"]["content"]
== f"This is me, the human ({i+1})"
)

# remove the record from Elasticsearch, so the next test run won't pick it up
await memory.chat_memory.aclear()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import json
import uuid
from typing import Iterator

Expand Down Expand Up @@ -54,17 +53,33 @@ def test_memory_with_message_store(
# add some messages
memory.chat_memory.add_messages(
[
AIMessage("This is me, the AI"),
HumanMessage("This is me, the human"),
AIMessage("This is me, the AI (1)"),
HumanMessage("This is me, the human (1)"),
AIMessage("This is me, the AI (2)"),
HumanMessage("This is me, the human (2)"),
AIMessage("This is me, the AI (3)"),
HumanMessage("This is me, the human (3)"),
AIMessage("This is me, the AI (4)"),
HumanMessage("This is me, the human (4)"),
AIMessage("This is me, the AI (5)"),
HumanMessage("This is me, the human (5)"),
AIMessage("This is me, the AI (6)"),
HumanMessage("This is me, the human (6)"),
AIMessage("This is me, the AI (7)"),
HumanMessage("This is me, the human (7)"),
]
)

# get the message history from the memory store and turn it into a json
messages = memory.chat_memory.messages
messages_json = json.dumps([message_to_dict(msg) for msg in messages])

assert "This is me, the AI" in messages_json
assert "This is me, the human" in messages_json
messages = [message_to_dict(msg) for msg in (memory.chat_memory.messages)]

assert len(messages) == 14
for i in range(7):
assert messages[i * 2]["data"]["content"] == f"This is me, the AI ({i+1})"
assert (
messages[i * 2 + 1]["data"]["content"]
== f"This is me, the human ({i+1})"
)

# remove the record from Elasticsearch, so the next test run won't pick it up
memory.chat_memory.clear()
Expand Down

0 comments on commit 8470936

Please sign in to comment.