Skip to content

Commit

Permalink
[#169] Fix the broken in operator in bedrock knowledge base retriever. (
Browse files Browse the repository at this point in the history
#233)

Fixed the issue mentioned in
#169
Added unit test that would expose the bug in source code
Tested with my service code and it works with the change
  • Loading branch information
renjiexu-amzn authored Oct 15, 2024
1 parent adcfc34 commit 48535f0
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 2 deletions.
4 changes: 3 additions & 1 deletion libs/aws/langchain_aws/retrievers/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,9 @@ def _get_relevant_documents(
response = self.client.retrieve(
retrievalQuery={"text": query.strip()},
knowledgeBaseId=self.knowledge_base_id,
retrievalConfiguration=self.retrieval_config.model_dump(exclude_none=True),
retrievalConfiguration=self.retrieval_config.model_dump(
exclude_none=True, by_alias=True
),
)
results = response["retrievalResults"]
documents = []
Expand Down
23 changes: 22 additions & 1 deletion libs/aws/tests/unit_tests/retrievers/test_bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@
from langchain_core.documents import Document

from langchain_aws.retrievers import AmazonKnowledgeBasesRetriever
from langchain_aws.retrievers.bedrock import (
RetrievalConfig,
SearchFilter,
VectorSearchConfig,
)


@pytest.fixture
Expand All @@ -15,7 +20,12 @@ def mock_client():

@pytest.fixture
def mock_retriever_config():
return {"vectorSearchConfiguration": {"numberOfResults": 4}}
return RetrievalConfig(
vectorSearchConfiguration=VectorSearchConfig(
numberOfResults=5,
filter=SearchFilter(in_={"key": "key", "value": ["value1", "value2"]}),
),
)


@pytest.fixture
Expand Down Expand Up @@ -43,6 +53,17 @@ def test_retriever_invoke(amazon_retriever, mock_client):
}
documents = amazon_retriever.invoke(query, run_manager=None)

mock_client.retrieve.assert_called_once_with(
retrievalQuery={"text": "test query"},
knowledgeBaseId="test_kb_id",
retrievalConfiguration={
"vectorSearchConfiguration": {
"numberOfResults": 5,
# Expecting to be called with correct "in" operatior instead of "in_"
"filter": {"in": {"key": "key", "value": ["value1", "value2"]}},
}
},
)
assert len(documents) == 3
assert isinstance(documents[0], Document)
assert documents[0].page_content == "result1"
Expand Down

0 comments on commit 48535f0

Please sign in to comment.