From 867277e9c7ce69230bd157989203f3ea83753396 Mon Sep 17 00:00:00 2001 From: Piyush Jain Date: Tue, 2 Apr 2024 12:30:50 -0700 Subject: [PATCH 1/5] Added retrievers, moved gitignore to root. --- libs/aws/.gitignore => .gitignore | 0 libs/aws/langchain_aws/__init__.py | 6 +- libs/aws/langchain_aws/retrievers/__init__.py | 8 + libs/aws/langchain_aws/retrievers/bedrock.py | 128 +++++ libs/aws/langchain_aws/retrievers/kendra.py | 479 ++++++++++++++++++ libs/aws/tests/unit_tests/test_imports.py | 20 +- 6 files changed, 633 insertions(+), 8 deletions(-) rename libs/aws/.gitignore => .gitignore (100%) create mode 100644 libs/aws/langchain_aws/retrievers/__init__.py create mode 100644 libs/aws/langchain_aws/retrievers/bedrock.py create mode 100644 libs/aws/langchain_aws/retrievers/kendra.py diff --git a/libs/aws/.gitignore b/.gitignore similarity index 100% rename from libs/aws/.gitignore rename to .gitignore diff --git a/libs/aws/langchain_aws/__init__.py b/libs/aws/langchain_aws/__init__.py index 4c8cc796..c98359f9 100644 --- a/libs/aws/langchain_aws/__init__.py +++ b/libs/aws/langchain_aws/__init__.py @@ -1,3 +1,7 @@ from langchain_aws.llms import SagemakerEndpoint +from langchain_aws.retrievers import AmazonKendraRetriever -__all__ = ["SagemakerEndpoint"] +__all__ = [ + "SagemakerEndpoint", + "AmazonKendraRetriever" +] diff --git a/libs/aws/langchain_aws/retrievers/__init__.py b/libs/aws/langchain_aws/retrievers/__init__.py new file mode 100644 index 00000000..2affadec --- /dev/null +++ b/libs/aws/langchain_aws/retrievers/__init__.py @@ -0,0 +1,8 @@ +from langchain_aws.retrievers.kendra import AmazonKendraRetriever +from langchain_aws.retrievers.bedrock import AmazonKnowledgeBasesRetriever + +__all__ = [ + "AmazonKendraRetriever" + "AmazonKnowledgeBasesRetriever" +] + diff --git a/libs/aws/langchain_aws/retrievers/bedrock.py b/libs/aws/langchain_aws/retrievers/bedrock.py new file mode 100644 index 00000000..0c3d1d66 --- /dev/null +++ b/libs/aws/langchain_aws/retrievers/bedrock.py @@ -0,0 +1,128 @@ +from typing import Any, Dict, List, Optional + +from langchain_core.callbacks import CallbackManagerForRetrieverRun +from langchain_core.documents import Document +from langchain_core.pydantic_v1 import BaseModel, root_validator +from langchain_core.retrievers import BaseRetriever + + +class VectorSearchConfig(BaseModel, extra="allow"): # type: ignore[call-arg] + """Configuration for vector search.""" + + numberOfResults: int = 4 + + +class RetrievalConfig(BaseModel, extra="allow"): # type: ignore[call-arg] + """Configuration for retrieval.""" + + vectorSearchConfiguration: VectorSearchConfig + + +class AmazonKnowledgeBasesRetriever(BaseRetriever): + """`Amazon Bedrock Knowledge Bases` retrieval. + + See https://aws.amazon.com/bedrock/knowledge-bases for more info. + + Args: + knowledge_base_id: Knowledge Base ID. + region_name: The aws region e.g., `us-west-2`. + Fallback to AWS_DEFAULT_REGION env variable or region specified in + ~/.aws/config. + credentials_profile_name: The name of the profile in the ~/.aws/credentials + or ~/.aws/config files, which has either access keys or role information + specified. If not specified, the default credential profile or, if on an + EC2 instance, credentials from IMDS will be used. + client: boto3 client for bedrock agent runtime. + retrieval_config: Configuration for retrieval. + + Example: + .. code-block:: python + + from langchain_community.retrievers import AmazonKnowledgeBasesRetriever + + retriever = AmazonKnowledgeBasesRetriever( + knowledge_base_id="", + retrieval_config={ + "vectorSearchConfiguration": { + "numberOfResults": 4 + } + }, + ) + """ + + knowledge_base_id: str + region_name: Optional[str] = None + credentials_profile_name: Optional[str] = None + endpoint_url: Optional[str] = None + client: Any + retrieval_config: RetrievalConfig + + @root_validator(pre=True) + def create_client(cls, values: Dict[str, Any]) -> Dict[str, Any]: + if values.get("client") is not None: + return values + + try: + import boto3 + from botocore.client import Config + from botocore.exceptions import UnknownServiceError + + if values.get("credentials_profile_name"): + session = boto3.Session(profile_name=values["credentials_profile_name"]) + else: + # use default credentials + session = boto3.Session() + + client_params = { + "config": Config( + connect_timeout=120, read_timeout=120, retries={"max_attempts": 0} + ) + } + if values.get("region_name"): + client_params["region_name"] = values["region_name"] + + if values.get("endpoint_url"): + client_params["endpoint_url"] = values["endpoint_url"] + + values["client"] = session.client("bedrock-agent-runtime", **client_params) + + return values + except ImportError: + raise ModuleNotFoundError( + "Could not import boto3 python package. " + "Please install it with `pip install boto3`." + ) + except UnknownServiceError as e: + raise ModuleNotFoundError( + "Ensure that you have installed the latest boto3 package " + "that contains the API for `bedrock-runtime-agent`." + ) from e + except Exception as e: + raise ValueError( + "Could not load credentials to authenticate with AWS client. " + "Please check that credentials in the specified " + "profile name are valid." + ) from e + + def _get_relevant_documents( + self, query: str, *, run_manager: CallbackManagerForRetrieverRun + ) -> List[Document]: + response = self.client.retrieve( + retrievalQuery={"text": query.strip()}, + knowledgeBaseId=self.knowledge_base_id, + retrievalConfiguration=self.retrieval_config.dict(), + ) + results = response["retrievalResults"] + documents = [] + for result in results: + documents.append( + Document( + page_content=result["content"]["text"], + metadata={ + "location": result["location"], + "score": result["score"] if "score" in result else 0, + }, + ) + ) + + return documents diff --git a/libs/aws/langchain_aws/retrievers/kendra.py b/libs/aws/langchain_aws/retrievers/kendra.py new file mode 100644 index 00000000..b4480cae --- /dev/null +++ b/libs/aws/langchain_aws/retrievers/kendra.py @@ -0,0 +1,479 @@ +import re +from abc import ABC, abstractmethod +from typing import ( + Any, + Callable, + Dict, + List, + Literal, + Optional, + Sequence, + Union, +) + +from langchain_core.callbacks import CallbackManagerForRetrieverRun +from langchain_core.documents import Document +from langchain_core.pydantic_v1 import ( + BaseModel, + Extra, + Field, + root_validator, + validator, +) +from langchain_core.retrievers import BaseRetriever +from typing_extensions import Annotated + + +def clean_excerpt(excerpt: str) -> str: + """Clean an excerpt from Kendra. + + Args: + excerpt: The excerpt to clean. + + Returns: + The cleaned excerpt. + + """ + if not excerpt: + return excerpt + res = re.sub(r"\s+", " ", excerpt).replace("...", "") + return res + + +def combined_text(item: "ResultItem") -> str: + """Combine a ResultItem title and excerpt into a single string. + + Args: + item: the ResultItem of a Kendra search. + + Returns: + A combined text of the title and excerpt of the given item. + + """ + text = "" + title = item.get_title() + if title: + text += f"Document Title: {title}\n" + excerpt = clean_excerpt(item.get_excerpt()) + if excerpt: + text += f"Document Excerpt: \n{excerpt}\n" + return text + + +DocumentAttributeValueType = Union[str, int, List[str], None] +"""Possible types of a DocumentAttributeValue. + +Dates are also represented as str. +""" + + +# Unexpected keyword argument "extra" for "__init_subclass__" of "object" +class Highlight(BaseModel, extra=Extra.allow): # type: ignore[call-arg] + """Information that highlights the keywords in the excerpt.""" + + BeginOffset: int + """The zero-based location in the excerpt where the highlight starts.""" + EndOffset: int + """The zero-based location in the excerpt where the highlight ends.""" + TopAnswer: Optional[bool] + """Indicates whether the result is the best one.""" + Type: Optional[str] + """The highlight type: STANDARD or THESAURUS_SYNONYM.""" + + +# Unexpected keyword argument "extra" for "__init_subclass__" of "object" +class TextWithHighLights(BaseModel, extra=Extra.allow): # type: ignore[call-arg] + """Text with highlights.""" + + Text: str + """The text.""" + Highlights: Optional[Any] + """The highlights.""" + + +# Unexpected keyword argument "extra" for "__init_subclass__" of "object" +class AdditionalResultAttributeValue( # type: ignore[call-arg] + BaseModel, extra=Extra.allow +): + """Value of an additional result attribute.""" + + TextWithHighlightsValue: TextWithHighLights + """The text with highlights value.""" + + +# Unexpected keyword argument "extra" for "__init_subclass__" of "object" +class AdditionalResultAttribute(BaseModel, extra=Extra.allow): # type: ignore[call-arg] + """Additional result attribute.""" + + Key: str + """The key of the attribute.""" + ValueType: Literal["TEXT_WITH_HIGHLIGHTS_VALUE"] + """The type of the value.""" + Value: AdditionalResultAttributeValue + """The value of the attribute.""" + + def get_value_text(self) -> str: + return self.Value.TextWithHighlightsValue.Text + + +# Unexpected keyword argument "extra" for "__init_subclass__" of "object" +class DocumentAttributeValue(BaseModel, extra=Extra.allow): # type: ignore[call-arg] + """Value of a document attribute.""" + + DateValue: Optional[str] + """The date expressed as an ISO 8601 string.""" + LongValue: Optional[int] + """The long value.""" + StringListValue: Optional[List[str]] + """The string list value.""" + StringValue: Optional[str] + """The string value.""" + + @property + def value(self) -> DocumentAttributeValueType: + """The only defined document attribute value or None. + According to Amazon Kendra, you can only provide one + value for a document attribute. + """ + if self.DateValue: + return self.DateValue + if self.LongValue: + return self.LongValue + if self.StringListValue: + return self.StringListValue + if self.StringValue: + return self.StringValue + + return None + + +# Unexpected keyword argument "extra" for "__init_subclass__" of "object" +class DocumentAttribute(BaseModel, extra=Extra.allow): # type: ignore[call-arg] + """Document attribute.""" + + Key: str + """The key of the attribute.""" + Value: DocumentAttributeValue + """The value of the attribute.""" + + +# Unexpected keyword argument "extra" for "__init_subclass__" of "object" +class ResultItem(BaseModel, ABC, extra=Extra.allow): # type: ignore[call-arg] + """Base class of a result item.""" + + Id: Optional[str] + """The ID of the relevant result item.""" + DocumentId: Optional[str] + """The document ID.""" + DocumentURI: Optional[str] + """The document URI.""" + DocumentAttributes: Optional[List[DocumentAttribute]] = [] + """The document attributes.""" + ScoreAttributes: Optional[dict] + """The kendra score confidence""" + + @abstractmethod + def get_title(self) -> str: + """Document title.""" + + @abstractmethod + def get_excerpt(self) -> str: + """Document excerpt or passage original content as retrieved by Kendra.""" + + def get_additional_metadata(self) -> dict: + """Document additional metadata dict. + This returns any extra metadata except these: + * result_id + * document_id + * source + * title + * excerpt + * document_attributes + """ + return {} + + def get_document_attributes_dict(self) -> Dict[str, DocumentAttributeValueType]: + """Document attributes dict.""" + return {attr.Key: attr.Value.value for attr in (self.DocumentAttributes or [])} + + def get_score_attribute(self) -> str: + """Document Score Confidence""" + if self.ScoreAttributes is not None: + return self.ScoreAttributes["ScoreConfidence"] + else: + return "NOT_AVAILABLE" + + def to_doc( + self, page_content_formatter: Callable[["ResultItem"], str] = combined_text + ) -> Document: + """Converts this item to a Document.""" + page_content = page_content_formatter(self) + metadata = self.get_additional_metadata() + metadata.update( + { + "result_id": self.Id, + "document_id": self.DocumentId, + "source": self.DocumentURI, + "title": self.get_title(), + "excerpt": self.get_excerpt(), + "document_attributes": self.get_document_attributes_dict(), + "score": self.get_score_attribute(), + } + ) + return Document(page_content=page_content, metadata=metadata) + + +class QueryResultItem(ResultItem): + """Query API result item.""" + + DocumentTitle: TextWithHighLights + """The document title.""" + FeedbackToken: Optional[str] + """Identifies a particular result from a particular query.""" + Format: Optional[str] + """ + If the Type is ANSWER, then format is either: + * TABLE: a table excerpt is returned in TableExcerpt; + * TEXT: a text excerpt is returned in DocumentExcerpt. + """ + Type: Optional[str] + """Type of result: DOCUMENT or QUESTION_ANSWER or ANSWER""" + AdditionalAttributes: Optional[List[AdditionalResultAttribute]] = [] + """One or more additional attributes associated with the result.""" + DocumentExcerpt: Optional[TextWithHighLights] + """Excerpt of the document text.""" + + def get_title(self) -> str: + return self.DocumentTitle.Text + + def get_attribute_value(self) -> str: + if not self.AdditionalAttributes: + return "" + if not self.AdditionalAttributes[0]: + return "" + else: + return self.AdditionalAttributes[0].get_value_text() + + def get_excerpt(self) -> str: + if ( + self.AdditionalAttributes + and self.AdditionalAttributes[0].Key == "AnswerText" + ): + excerpt = self.get_attribute_value() + elif self.DocumentExcerpt: + excerpt = self.DocumentExcerpt.Text + else: + excerpt = "" + + return excerpt + + def get_additional_metadata(self) -> dict: + additional_metadata = {"type": self.Type} + return additional_metadata + + +class RetrieveResultItem(ResultItem): + """Retrieve API result item.""" + + DocumentTitle: Optional[str] + """The document title.""" + Content: Optional[str] + """The content of the item.""" + + def get_title(self) -> str: + return self.DocumentTitle or "" + + def get_excerpt(self) -> str: + return self.Content or "" + + +# Unexpected keyword argument "extra" for "__init_subclass__" of "object" +class QueryResult(BaseModel, extra=Extra.allow): # type: ignore[call-arg] + """`Amazon Kendra Query API` search result. + + It is composed of: + * Relevant suggested answers: either a text excerpt or table excerpt. + * Matching FAQs or questions-answer from your FAQ file. + * Documents including an excerpt of each document with its title. + """ + + ResultItems: List[QueryResultItem] + """The result items.""" + + +# Unexpected keyword argument "extra" for "__init_subclass__" of "object" +class RetrieveResult(BaseModel, extra=Extra.allow): # type: ignore[call-arg] + """`Amazon Kendra Retrieve API` search result. + + It is composed of: + * relevant passages or text excerpts given an input query. + """ + + QueryId: str + """The ID of the query.""" + ResultItems: List[RetrieveResultItem] + """The result items.""" + + +KENDRA_CONFIDENCE_MAPPING = { + "NOT_AVAILABLE": 0.0, + "LOW": 0.25, + "MEDIUM": 0.50, + "HIGH": 0.75, + "VERY_HIGH": 1.0, +} + + +class AmazonKendraRetriever(BaseRetriever): + """`Amazon Kendra Index` retriever. + + Args: + index_id: Kendra index id + + region_name: The aws region e.g., `us-west-2`. + Fallsback to AWS_DEFAULT_REGION env variable + or region specified in ~/.aws/config. + + credentials_profile_name: The name of the profile in the ~/.aws/credentials + or ~/.aws/config files, which has either access keys or role information + specified. If not specified, the default credential profile or, if on an + EC2 instance, credentials from IMDS will be used. + + top_k: No of results to return + + attribute_filter: Additional filtering of results based on metadata + See: https://docs.aws.amazon.com/kendra/latest/APIReference + + page_content_formatter: generates the Document page_content + allowing access to all result item attributes. By default, it uses + the item's title and excerpt. + + client: boto3 client for Kendra + + user_context: Provides information about the user context + See: https://docs.aws.amazon.com/kendra/latest/APIReference + + Example: + .. code-block:: python + + retriever = AmazonKendraRetriever( + index_id="c0806df7-e76b-4bce-9b5c-d5582f6b1a03" + ) + + """ + + index_id: str + region_name: Optional[str] = None + credentials_profile_name: Optional[str] = None + top_k: int = 3 + attribute_filter: Optional[Dict] = None + page_content_formatter: Callable[[ResultItem], str] = combined_text + client: Any + user_context: Optional[Dict] = None + min_score_confidence: Annotated[Optional[float], Field(ge=0.0, le=1.0)] + + @validator("top_k") + def validate_top_k(cls, value: int) -> int: + if value < 0: + raise ValueError(f"top_k ({value}) cannot be negative.") + return value + + @root_validator(pre=True) + def create_client(cls, values: Dict[str, Any]) -> Dict[str, Any]: + if values.get("client") is not None: + return values + + try: + import boto3 + + if values.get("credentials_profile_name"): + session = boto3.Session(profile_name=values["credentials_profile_name"]) + else: + # use default credentials + session = boto3.Session() + + client_params = {} + if values.get("region_name"): + client_params["region_name"] = values["region_name"] + + values["client"] = session.client("kendra", **client_params) + + return values + except ImportError: + raise ModuleNotFoundError( + "Could not import boto3 python package. " + "Please install it with `pip install boto3`." + ) + except Exception as e: + raise ValueError( + "Could not load credentials to authenticate with AWS client. " + "Please check that credentials in the specified " + "profile name are valid." + ) from e + + def _kendra_query(self, query: str) -> Sequence[ResultItem]: + kendra_kwargs = { + "IndexId": self.index_id, + # truncate the query to ensure that + # there is no validation exception from Kendra. + "QueryText": query.strip()[0:999], + "PageSize": self.top_k, + } + if self.attribute_filter is not None: + kendra_kwargs["AttributeFilter"] = self.attribute_filter + if self.user_context is not None: + kendra_kwargs["UserContext"] = self.user_context + + response = self.client.retrieve(**kendra_kwargs) + r_result = RetrieveResult.parse_obj(response) + if r_result.ResultItems: + return r_result.ResultItems + + # Retrieve API returned 0 results, fall back to Query API + response = self.client.query(**kendra_kwargs) + q_result = QueryResult.parse_obj(response) + return q_result.ResultItems + + def _get_top_k_docs(self, result_items: Sequence[ResultItem]) -> List[Document]: + top_docs = [ + item.to_doc(self.page_content_formatter) + for item in result_items[: self.top_k] + ] + return top_docs + + def _filter_by_score_confidence(self, docs: List[Document]) -> List[Document]: + """ + Filter out the records that have a score confidence + greater than the required threshold. + """ + if not self.min_score_confidence: + return docs + filtered_docs = [ + item + for item in docs + if ( + item.metadata.get("score") is not None + and isinstance(item.metadata["score"], str) + and KENDRA_CONFIDENCE_MAPPING.get(item.metadata["score"], 0.0) + >= self.min_score_confidence + ) + ] + return filtered_docs + + def _get_relevant_documents( + self, + query: str, + *, + run_manager: CallbackManagerForRetrieverRun, + ) -> List[Document]: + """Run search on Kendra index and get top k documents + + Example: + .. code-block:: python + + docs = retriever.get_relevant_documents('This is my query') + + """ + result_items = self._kendra_query(query) + top_k_docs = self._get_top_k_docs(result_items) + return self._filter_by_score_confidence(top_k_docs) diff --git a/libs/aws/tests/unit_tests/test_imports.py b/libs/aws/tests/unit_tests/test_imports.py index 27ce128d..592cde2c 100644 --- a/libs/aws/tests/unit_tests/test_imports.py +++ b/libs/aws/tests/unit_tests/test_imports.py @@ -1,9 +1,15 @@ -from langchain_aws import llms -from tests.unit_tests import assert_all_importable +import glob +import importlib +from pathlib import Path -EXPECTED_ALL_LLMS = ["SagemakerEndpoint"] - -def test_imports() -> None: - assert sorted(llms.__all__) == sorted(EXPECTED_ALL_LLMS) - assert_all_importable(llms) +def test_importable_all() -> None: + for path in glob.glob("../langchain_aws/*"): + relative_path = Path(path).parts[-1] + if relative_path.endswith(".typed"): + continue + module_name = relative_path.split(".")[0] + module = importlib.import_module("langchain_aws." + module_name) + all_ = getattr(module, "__all__", []) + for cls_ in all_: + getattr(module, cls_) From 23a712a84d068b6ed8c9fbcd636b01e9f14316c8 Mon Sep 17 00:00:00 2001 From: Piyush Jain Date: Tue, 2 Apr 2024 13:49:33 -0700 Subject: [PATCH 2/5] Fix linting --- libs/aws/langchain_aws/__init__.py | 5 +---- libs/aws/langchain_aws/retrievers/__init__.py | 8 ++++---- libs/aws/langchain_aws/retrievers/bedrock.py | 7 +++---- libs/aws/pyproject.toml | 5 +---- 4 files changed, 9 insertions(+), 16 deletions(-) diff --git a/libs/aws/langchain_aws/__init__.py b/libs/aws/langchain_aws/__init__.py index c98359f9..64110eb5 100644 --- a/libs/aws/langchain_aws/__init__.py +++ b/libs/aws/langchain_aws/__init__.py @@ -1,7 +1,4 @@ from langchain_aws.llms import SagemakerEndpoint from langchain_aws.retrievers import AmazonKendraRetriever -__all__ = [ - "SagemakerEndpoint", - "AmazonKendraRetriever" -] +__all__ = ["SagemakerEndpoint", "AmazonKendraRetriever"] diff --git a/libs/aws/langchain_aws/retrievers/__init__.py b/libs/aws/langchain_aws/retrievers/__init__.py index 2affadec..d5517c16 100644 --- a/libs/aws/langchain_aws/retrievers/__init__.py +++ b/libs/aws/langchain_aws/retrievers/__init__.py @@ -1,8 +1,8 @@ -from langchain_aws.retrievers.kendra import AmazonKendraRetriever from langchain_aws.retrievers.bedrock import AmazonKnowledgeBasesRetriever +from langchain_aws.retrievers.kendra import AmazonKendraRetriever __all__ = [ - "AmazonKendraRetriever" - "AmazonKnowledgeBasesRetriever" + "AmazonKendraRetriever", + "AmazonKendraRetriever", + "AmazonKnowledgeBasesRetriever", ] - diff --git a/libs/aws/langchain_aws/retrievers/bedrock.py b/libs/aws/langchain_aws/retrievers/bedrock.py index 0c3d1d66..068b904c 100644 --- a/libs/aws/langchain_aws/retrievers/bedrock.py +++ b/libs/aws/langchain_aws/retrievers/bedrock.py @@ -1,5 +1,8 @@ from typing import Any, Dict, List, Optional +import boto3 +from botocore.client import Config +from botocore.exceptions import UnknownServiceError from langchain_core.callbacks import CallbackManagerForRetrieverRun from langchain_core.documents import Document from langchain_core.pydantic_v1 import BaseModel, root_validator @@ -63,10 +66,6 @@ def create_client(cls, values: Dict[str, Any]) -> Dict[str, Any]: return values try: - import boto3 - from botocore.client import Config - from botocore.exceptions import UnknownServiceError - if values.get("credentials_profile_name"): session = boto3.Session(profile_name=values["credentials_profile_name"]) else: diff --git a/libs/aws/pyproject.toml b/libs/aws/pyproject.toml index 33588913..d6b523b0 100644 --- a/libs/aws/pyproject.toml +++ b/libs/aws/pyproject.toml @@ -57,12 +57,9 @@ select = [ ] [tool.mypy] +ignore_missing_imports = "True" disallow_untyped_defs = "True" -[[tool.mypy.overrides]] -module = "boto3" -ignore_missing_imports = true - [tool.coverage.run] omit = ["tests/*"] From 9e1d84c98373471ecbaf821ede5d22b189620626 Mon Sep 17 00:00:00 2001 From: Piyush Jain Date: Tue, 2 Apr 2024 18:03:23 -0700 Subject: [PATCH 3/5] Added tests for knowlegebases retriever --- libs/aws/langchain_aws/__init__.py | 11 +++- .../integration_tests/retrievers/__init__.py | 0 .../test_amazon_knowledgebases_retriever.py | 66 +++++++++++++++++++ 3 files changed, 75 insertions(+), 2 deletions(-) create mode 100644 libs/aws/tests/integration_tests/retrievers/__init__.py create mode 100644 libs/aws/tests/integration_tests/retrievers/test_amazon_knowledgebases_retriever.py diff --git a/libs/aws/langchain_aws/__init__.py b/libs/aws/langchain_aws/__init__.py index 64110eb5..00e29d09 100644 --- a/libs/aws/langchain_aws/__init__.py +++ b/libs/aws/langchain_aws/__init__.py @@ -1,4 +1,11 @@ from langchain_aws.llms import SagemakerEndpoint -from langchain_aws.retrievers import AmazonKendraRetriever +from langchain_aws.retrievers import ( + AmazonKendraRetriever, + AmazonKnowledgeBasesRetriever, +) -__all__ = ["SagemakerEndpoint", "AmazonKendraRetriever"] +__all__ = [ + "SagemakerEndpoint", + "AmazonKendraRetriever", + "AmazonKnowledgeBasesRetriever", +] diff --git a/libs/aws/tests/integration_tests/retrievers/__init__.py b/libs/aws/tests/integration_tests/retrievers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/libs/aws/tests/integration_tests/retrievers/test_amazon_knowledgebases_retriever.py b/libs/aws/tests/integration_tests/retrievers/test_amazon_knowledgebases_retriever.py new file mode 100644 index 00000000..f66a158f --- /dev/null +++ b/libs/aws/tests/integration_tests/retrievers/test_amazon_knowledgebases_retriever.py @@ -0,0 +1,66 @@ +from unittest.mock import Mock + +import pytest +from langchain_core.documents import Document + +from langchain_aws import AmazonKnowledgeBasesRetriever + + +@pytest.fixture +def mock_client() -> Mock: + return Mock() + + +@pytest.fixture +def retriever(mock_client: Mock) -> AmazonKnowledgeBasesRetriever: + return AmazonKnowledgeBasesRetriever( + knowledge_base_id="test-knowledge-base", + client=mock_client, + retrieval_config={"vectorSearchConfiguration": {"numberOfResults": 4}}, # type: ignore[arg-type] + ) + + +def test_get_relevant_documents(retriever, mock_client) -> None: # type: ignore[no-untyped-def] + response = { + "retrievalResults": [ + { + "content": {"text": "This is the first result."}, + "location": "location1", + "score": 0.9, + }, + { + "content": {"text": "This is the second result."}, + "location": "location2", + "score": 0.8, + }, + {"content": {"text": "This is the third result."}, "location": "location3"}, + ] + } + mock_client.retrieve.return_value = response + + query = "test query" + + expected_documents = [ + Document( + page_content="This is the first result.", + metadata={"location": "location1", "score": 0.9}, + ), + Document( + page_content="This is the second result.", + metadata={"location": "location2", "score": 0.8}, + ), + Document( + page_content="This is the third result.", + metadata={"location": "location3", "score": 0.0}, + ), + ] + + documents = retriever.get_relevant_documents(query) + + assert documents == expected_documents + + mock_client.retrieve.assert_called_once_with( + retrievalQuery={"text": "test query"}, + knowledgeBaseId="test-knowledge-base", + retrievalConfiguration={"vectorSearchConfiguration": {"numberOfResults": 4}}, + ) From bdcaf0c446ef5f52472b42c25c51bb10ccf6f6eb Mon Sep 17 00:00:00 2001 From: Piyush Jain Date: Tue, 2 Apr 2024 18:35:17 -0700 Subject: [PATCH 4/5] Added tests for kendra retriever. --- .../test_amazon_kendra_retriever.py | 76 +++++++++++++++++++ 1 file changed, 76 insertions(+) create mode 100644 libs/aws/tests/integration_tests/retrievers/test_amazon_kendra_retriever.py diff --git a/libs/aws/tests/integration_tests/retrievers/test_amazon_kendra_retriever.py b/libs/aws/tests/integration_tests/retrievers/test_amazon_kendra_retriever.py new file mode 100644 index 00000000..24ecb9dc --- /dev/null +++ b/libs/aws/tests/integration_tests/retrievers/test_amazon_kendra_retriever.py @@ -0,0 +1,76 @@ +from typing import Any +from unittest.mock import Mock + +import pytest + +from langchain_aws import AmazonKendraRetriever +from langchain_aws.retrievers.kendra import RetrieveResultItem + + +@pytest.fixture +def mock_client() -> Mock: + mock_client = Mock() + return mock_client + + +@pytest.fixture +def retriever(mock_client: Any) -> AmazonKendraRetriever: + return AmazonKendraRetriever( + index_id="test_index_id", client=mock_client, top_k=3, min_score_confidence=0.6 + ) + + +def test_get_relevant_documents(retriever, mock_client) -> None: # type: ignore[no-untyped-def] + # Mock data for Kendra response + mock_retrieve_result = { + "QueryId": "test_query_id", + "ResultItems": [ + RetrieveResultItem( + Id="doc1", + DocumentId="doc1", + DocumentURI="https://example.com/doc1", + DocumentTitle="Document 1", + Content="This is the content of Document 1.", + ScoreAttributes={"ScoreConfidence": "HIGH"}, + ), + RetrieveResultItem( + Id="doc2", + DocumentId="doc2", + DocumentURI="https://example.com/doc2", + DocumentTitle="Document 2", + Content="This is the content of Document 2.", + ScoreAttributes={"ScoreConfidence": "MEDIUM"}, + ), + RetrieveResultItem( + Id="doc3", + DocumentId="doc3", + DocumentURI="https://example.com/doc3", + DocumentTitle="Document 3", + Content="This is the content of Document 3.", + ScoreAttributes={"ScoreConfidence": "HIGH"}, + ), + ], + } + + mock_client.retrieve.return_value = mock_retrieve_result + + query = "test query" + + docs = retriever.get_relevant_documents(query) + + # Only documents with confidence score of HIGH are returned + assert len(docs) == 2 + assert docs[0].page_content == ( + "Document Title: Document 1\nDocument Excerpt: \n" + "This is the content of Document 1.\n" + ) + assert docs[1].page_content == ( + "Document Title: Document 3\nDocument Excerpt: \n" + "This is the content of Document 3.\n" + ) + + # Assert that the mock methods were called with the expected arguments + mock_client.retrieve.assert_called_with( + IndexId="test_index_id", QueryText="test query", PageSize=3 + ) + mock_client.query.assert_not_called() From a30a4946ae5d8ed805ecadd0667df8ef1a516599 Mon Sep 17 00:00:00 2001 From: Piyush Jain Date: Tue, 2 Apr 2024 18:44:14 -0700 Subject: [PATCH 5/5] Added coverage report. --- libs/aws/poetry.lock | 87 ++++++++++++++++++++++++++++++++++++++++- libs/aws/pyproject.toml | 3 +- 2 files changed, 88 insertions(+), 2 deletions(-) diff --git a/libs/aws/poetry.lock b/libs/aws/poetry.lock index 6f5ab5d1..3b6bd439 100644 --- a/libs/aws/poetry.lock +++ b/libs/aws/poetry.lock @@ -215,6 +215,73 @@ files = [ {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, ] +[[package]] +name = "coverage" +version = "7.4.4" +description = "Code coverage measurement for Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "coverage-7.4.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e0be5efd5127542ef31f165de269f77560d6cdef525fffa446de6f7e9186cfb2"}, + {file = "coverage-7.4.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ccd341521be3d1b3daeb41960ae94a5e87abe2f46f17224ba5d6f2b8398016cf"}, + {file = "coverage-7.4.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:09fa497a8ab37784fbb20ab699c246053ac294d13fc7eb40ec007a5043ec91f8"}, + {file = "coverage-7.4.4-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b1a93009cb80730c9bca5d6d4665494b725b6e8e157c1cb7f2db5b4b122ea562"}, + {file = "coverage-7.4.4-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:690db6517f09336559dc0b5f55342df62370a48f5469fabf502db2c6d1cffcd2"}, + {file = "coverage-7.4.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:09c3255458533cb76ef55da8cc49ffab9e33f083739c8bd4f58e79fecfe288f7"}, + {file = "coverage-7.4.4-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:8ce1415194b4a6bd0cdcc3a1dfbf58b63f910dcb7330fe15bdff542c56949f87"}, + {file = "coverage-7.4.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:b91cbc4b195444e7e258ba27ac33769c41b94967919f10037e6355e998af255c"}, + {file = "coverage-7.4.4-cp310-cp310-win32.whl", hash = "sha256:598825b51b81c808cb6f078dcb972f96af96b078faa47af7dfcdf282835baa8d"}, + {file = "coverage-7.4.4-cp310-cp310-win_amd64.whl", hash = "sha256:09ef9199ed6653989ebbcaacc9b62b514bb63ea2f90256e71fea3ed74bd8ff6f"}, + {file = "coverage-7.4.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:0f9f50e7ef2a71e2fae92774c99170eb8304e3fdf9c8c3c7ae9bab3e7229c5cf"}, + {file = "coverage-7.4.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:623512f8ba53c422fcfb2ce68362c97945095b864cda94a92edbaf5994201083"}, + {file = "coverage-7.4.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0513b9508b93da4e1716744ef6ebc507aff016ba115ffe8ecff744d1322a7b63"}, + {file = "coverage-7.4.4-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:40209e141059b9370a2657c9b15607815359ab3ef9918f0196b6fccce8d3230f"}, + {file = "coverage-7.4.4-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8a2b2b78c78293782fd3767d53e6474582f62443d0504b1554370bde86cc8227"}, + {file = "coverage-7.4.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:73bfb9c09951125d06ee473bed216e2c3742f530fc5acc1383883125de76d9cd"}, + {file = "coverage-7.4.4-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:1f384c3cc76aeedce208643697fb3e8437604b512255de6d18dae3f27655a384"}, + {file = "coverage-7.4.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:54eb8d1bf7cacfbf2a3186019bcf01d11c666bd495ed18717162f7eb1e9dd00b"}, + {file = "coverage-7.4.4-cp311-cp311-win32.whl", hash = "sha256:cac99918c7bba15302a2d81f0312c08054a3359eaa1929c7e4b26ebe41e9b286"}, + {file = "coverage-7.4.4-cp311-cp311-win_amd64.whl", hash = "sha256:b14706df8b2de49869ae03a5ccbc211f4041750cd4a66f698df89d44f4bd30ec"}, + {file = "coverage-7.4.4-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:201bef2eea65e0e9c56343115ba3814e896afe6d36ffd37bab783261db430f76"}, + {file = "coverage-7.4.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:41c9c5f3de16b903b610d09650e5e27adbfa7f500302718c9ffd1c12cf9d6818"}, + {file = "coverage-7.4.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d898fe162d26929b5960e4e138651f7427048e72c853607f2b200909794ed978"}, + {file = "coverage-7.4.4-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3ea79bb50e805cd6ac058dfa3b5c8f6c040cb87fe83de10845857f5535d1db70"}, + {file = "coverage-7.4.4-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ce4b94265ca988c3f8e479e741693d143026632672e3ff924f25fab50518dd51"}, + {file = "coverage-7.4.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:00838a35b882694afda09f85e469c96367daa3f3f2b097d846a7216993d37f4c"}, + {file = "coverage-7.4.4-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:fdfafb32984684eb03c2d83e1e51f64f0906b11e64482df3c5db936ce3839d48"}, + {file = "coverage-7.4.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:69eb372f7e2ece89f14751fbcbe470295d73ed41ecd37ca36ed2eb47512a6ab9"}, + {file = "coverage-7.4.4-cp312-cp312-win32.whl", hash = "sha256:137eb07173141545e07403cca94ab625cc1cc6bc4c1e97b6e3846270e7e1fea0"}, + {file = "coverage-7.4.4-cp312-cp312-win_amd64.whl", hash = "sha256:d71eec7d83298f1af3326ce0ff1d0ea83c7cb98f72b577097f9083b20bdaf05e"}, + {file = "coverage-7.4.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:d5ae728ff3b5401cc320d792866987e7e7e880e6ebd24433b70a33b643bb0384"}, + {file = "coverage-7.4.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:cc4f1358cb0c78edef3ed237ef2c86056206bb8d9140e73b6b89fbcfcbdd40e1"}, + {file = "coverage-7.4.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8130a2aa2acb8788e0b56938786c33c7c98562697bf9f4c7d6e8e5e3a0501e4a"}, + {file = "coverage-7.4.4-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cf271892d13e43bc2b51e6908ec9a6a5094a4df1d8af0bfc360088ee6c684409"}, + {file = "coverage-7.4.4-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a4cdc86d54b5da0df6d3d3a2f0b710949286094c3a6700c21e9015932b81447e"}, + {file = "coverage-7.4.4-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:ae71e7ddb7a413dd60052e90528f2f65270aad4b509563af6d03d53e979feafd"}, + {file = "coverage-7.4.4-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:38dd60d7bf242c4ed5b38e094baf6401faa114fc09e9e6632374388a404f98e7"}, + {file = "coverage-7.4.4-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:aa5b1c1bfc28384f1f53b69a023d789f72b2e0ab1b3787aae16992a7ca21056c"}, + {file = "coverage-7.4.4-cp38-cp38-win32.whl", hash = "sha256:dfa8fe35a0bb90382837b238fff375de15f0dcdb9ae68ff85f7a63649c98527e"}, + {file = "coverage-7.4.4-cp38-cp38-win_amd64.whl", hash = "sha256:b2991665420a803495e0b90a79233c1433d6ed77ef282e8e152a324bbbc5e0c8"}, + {file = "coverage-7.4.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:3b799445b9f7ee8bf299cfaed6f5b226c0037b74886a4e11515e569b36fe310d"}, + {file = "coverage-7.4.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:b4d33f418f46362995f1e9d4f3a35a1b6322cb959c31d88ae56b0298e1c22357"}, + {file = "coverage-7.4.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:aadacf9a2f407a4688d700e4ebab33a7e2e408f2ca04dbf4aef17585389eff3e"}, + {file = "coverage-7.4.4-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7c95949560050d04d46b919301826525597f07b33beba6187d04fa64d47ac82e"}, + {file = "coverage-7.4.4-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ff7687ca3d7028d8a5f0ebae95a6e4827c5616b31a4ee1192bdfde697db110d4"}, + {file = "coverage-7.4.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:5fc1de20b2d4a061b3df27ab9b7c7111e9a710f10dc2b84d33a4ab25065994ec"}, + {file = "coverage-7.4.4-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:c74880fc64d4958159fbd537a091d2a585448a8f8508bf248d72112723974cbd"}, + {file = "coverage-7.4.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:742a76a12aa45b44d236815d282b03cfb1de3b4323f3e4ec933acfae08e54ade"}, + {file = "coverage-7.4.4-cp39-cp39-win32.whl", hash = "sha256:d89d7b2974cae412400e88f35d86af72208e1ede1a541954af5d944a8ba46c57"}, + {file = "coverage-7.4.4-cp39-cp39-win_amd64.whl", hash = "sha256:9ca28a302acb19b6af89e90f33ee3e1906961f94b54ea37de6737b7ca9d8827c"}, + {file = "coverage-7.4.4-pp38.pp39.pp310-none-any.whl", hash = "sha256:b2c5edc4ac10a7ef6605a966c58929ec6c1bd0917fb8c15cb3363f65aa40e677"}, + {file = "coverage-7.4.4.tar.gz", hash = "sha256:c901df83d097649e257e803be22592aedfd5182f07b3cc87d640bbb9afd50f49"}, +] + +[package.dependencies] +tomli = {version = "*", optional = true, markers = "python_full_version <= \"3.11.0a6\" and extra == \"toml\""} + +[package.extras] +toml = ["tomli"] + [[package]] name = "exceptiongroup" version = "1.2.0" @@ -620,6 +687,24 @@ pytest = ">=7.0.0,<9" docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] +[[package]] +name = "pytest-cov" +version = "4.1.0" +description = "Pytest plugin for measuring coverage." +optional = false +python-versions = ">=3.7" +files = [ + {file = "pytest-cov-4.1.0.tar.gz", hash = "sha256:3904b13dfbfec47f003b8e77fd5b589cd11904a21ddf1ab38a64f204d6a10ef6"}, + {file = "pytest_cov-4.1.0-py3-none-any.whl", hash = "sha256:6ba70b9e97e69fcc3fb45bfeab2d0a138fb65c4d0d6a41ef33983ad114be8c3a"}, +] + +[package.dependencies] +coverage = {version = ">=5.2.1", extras = ["toml"]} +pytest = ">=4.6" + +[package.extras] +testing = ["fields", "hunter", "process-tests", "pytest-xdist", "six", "virtualenv"] + [[package]] name = "python-dateutil" version = "2.9.0.post0" @@ -865,4 +950,4 @@ zstd = ["zstandard (>=0.18.0)"] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<4.0" -content-hash = "f5abc531b0f400758b685366c860fdf4fa6622d2e19487898221c2f02e35e9b4" +content-hash = "7269f017eec10dcbeae1b84c8734c9b33b902253142a169c747468a1cab4f80d" diff --git a/libs/aws/pyproject.toml b/libs/aws/pyproject.toml index d6b523b0..8f6e4763 100644 --- a/libs/aws/pyproject.toml +++ b/libs/aws/pyproject.toml @@ -20,6 +20,7 @@ optional = true [tool.poetry.group.test.dependencies] pytest = "^7.4.3" +pytest-cov = "^4.1.0" syrupy = "^4.0.2" pytest-asyncio = "^0.23.2" @@ -77,7 +78,7 @@ build-backend = "poetry.core.masonry.api" # # https://github.com/tophat/syrupy # --snapshot-warn-unused Prints a warning on unused snapshots rather than fail the test suite. -addopts = "--snapshot-warn-unused --strict-markers --strict-config --durations=5" +addopts = "--snapshot-warn-unused --strict-markers --strict-config --durations=5 --cov=langchain_aws" # Registering custom markers. # https://docs.pytest.org/en/7.1.x/example/markers.html#registering-markers markers = [