diff --git a/libs/aws/langchain_aws/retrievers/bedrock.py b/libs/aws/langchain_aws/retrievers/bedrock.py index d0960002..55d32837 100644 --- a/libs/aws/langchain_aws/retrievers/bedrock.py +++ b/libs/aws/langchain_aws/retrievers/bedrock.py @@ -162,7 +162,9 @@ def _get_relevant_documents( response = self.client.retrieve( retrievalQuery={"text": query.strip()}, knowledgeBaseId=self.knowledge_base_id, - retrievalConfiguration=self.retrieval_config.model_dump(exclude_none=True), + retrievalConfiguration=self.retrieval_config.model_dump( + exclude_none=True, by_alias=True + ), ) results = response["retrievalResults"] documents = [] diff --git a/libs/aws/tests/unit_tests/retrievers/test_bedrock.py b/libs/aws/tests/unit_tests/retrievers/test_bedrock.py index b243db10..6d5a84de 100644 --- a/libs/aws/tests/unit_tests/retrievers/test_bedrock.py +++ b/libs/aws/tests/unit_tests/retrievers/test_bedrock.py @@ -6,6 +6,11 @@ from langchain_core.documents import Document from langchain_aws.retrievers import AmazonKnowledgeBasesRetriever +from langchain_aws.retrievers.bedrock import ( + RetrievalConfig, + SearchFilter, + VectorSearchConfig, +) @pytest.fixture @@ -15,7 +20,12 @@ def mock_client(): @pytest.fixture def mock_retriever_config(): - return {"vectorSearchConfiguration": {"numberOfResults": 4}} + return RetrievalConfig( + vectorSearchConfiguration=VectorSearchConfig( + numberOfResults=5, + filter=SearchFilter(in_={"key": "key", "value": ["value1", "value2"]}), + ), + ) @pytest.fixture @@ -43,6 +53,17 @@ def test_retriever_invoke(amazon_retriever, mock_client): } documents = amazon_retriever.invoke(query, run_manager=None) + mock_client.retrieve.assert_called_once_with( + retrievalQuery={"text": "test query"}, + knowledgeBaseId="test_kb_id", + retrievalConfiguration={ + "vectorSearchConfiguration": { + "numberOfResults": 5, + # Expecting to be called with correct "in" operatior instead of "in_" + "filter": {"in": {"key": "key", "value": ["value1", "value2"]}}, + } + }, + ) assert len(documents) == 3 assert isinstance(documents[0], Document) assert documents[0].page_content == "result1"