diff --git a/libs/aws/README.md b/libs/aws/README.md index 4e391286..baaee44c 100644 --- a/libs/aws/README.md +++ b/libs/aws/README.md @@ -54,7 +54,7 @@ retriever = AmazonKendraRetriever( retriever.get_relevant_documents(query="What is the meaning of life?") ``` -`AmazonKnowlegeBasesRetriever` class provides a retriever to connect with Amazon Knowlege Bases. +`AmazonKnowledgeBasesRetriever` class provides a retriever to connect with Amazon Knowledge Bases. ```python from langchain_aws import AmazonKnowledgeBasesRetriever diff --git a/libs/aws/langchain_aws/retrievers/bedrock.py b/libs/aws/langchain_aws/retrievers/bedrock.py index 068b904c..0fed1548 100644 --- a/libs/aws/langchain_aws/retrievers/bedrock.py +++ b/libs/aws/langchain_aws/retrievers/bedrock.py @@ -120,6 +120,9 @@ def _get_relevant_documents( metadata={ "location": result["location"], "score": result["score"] if "score" in result else 0, + "source_metadata": ( + result["metadata"] if "metadata" in result else None + ), }, ) ) diff --git a/libs/aws/tests/integration_tests/retrievers/test_amazon_knowledgebases_retriever.py b/libs/aws/tests/integration_tests/retrievers/test_amazon_knowledgebases_retriever.py index f66a158f..236cd5f0 100644 --- a/libs/aws/tests/integration_tests/retrievers/test_amazon_knowledgebases_retriever.py +++ b/libs/aws/tests/integration_tests/retrievers/test_amazon_knowledgebases_retriever.py @@ -34,6 +34,12 @@ def test_get_relevant_documents(retriever, mock_client) -> None: # type: ignore "score": 0.8, }, {"content": {"text": "This is the third result."}, "location": "location3"}, + { + "content": {"text": "This is the fourth result."}, + "location": "location4", + "score": 0.4, + "metadata": {"url": "http://example.com", "title": "Example Title"}, + }, ] } mock_client.retrieve.return_value = response @@ -43,19 +49,30 @@ def test_get_relevant_documents(retriever, mock_client) -> None: # type: ignore expected_documents = [ Document( page_content="This is the first result.", - metadata={"location": "location1", "score": 0.9}, + metadata={"location": "location1", "score": 0.9, "source_metadata": None}, ), Document( page_content="This is the second result.", - metadata={"location": "location2", "score": 0.8}, + metadata={"location": "location2", "score": 0.8, "source_metadata": None}, ), Document( page_content="This is the third result.", - metadata={"location": "location3", "score": 0.0}, + metadata={"location": "location3", "score": 0.0, "source_metadata": None}, + ), + Document( + page_content="This is the fourth result.", + metadata={ + "location": "location4", + "score": 0.4, + "source_metadata": { + "url": "http://example.com", + "title": "Example Title", + }, + }, ), ] - documents = retriever.get_relevant_documents(query) + documents = retriever.invoke(query) assert documents == expected_documents