From 13090a5638037d13acc8662420c47584473c7398 Mon Sep 17 00:00:00 2001 From: Piyush Jain Date: Thu, 28 Mar 2024 14:29:26 -0700 Subject: [PATCH 1/3] Added Sagemaker Endpoint import tests. --- libs/aws/.gitignore | 181 ++++++++- libs/aws/langchain_aws/__init__.py | 12 +- .../aws/langchain_aws/chat_models/__init__.py | 0 libs/aws/langchain_aws/graphs/__init__.py | 0 libs/aws/langchain_aws/llms/__init__.py | 3 + .../langchain_aws/llms/sagemaker_endpoint.py | 375 ++++++++++++++++++ libs/aws/poetry.lock | 129 +++++- libs/aws/pyproject.toml | 2 + .../integration_tests/test_chat_models.py | 63 --- .../tests/integration_tests/test_compile.py | 7 - .../integration_tests/test_embeddings.py | 19 - libs/aws/tests/integration_tests/test_llms.py | 63 --- libs/aws/tests/unit_tests/__init__.py | 7 + libs/aws/tests/unit_tests/test_chat_models.py | 9 - libs/aws/tests/unit_tests/test_embeddings.py | 9 - libs/aws/tests/unit_tests/test_imports.py | 15 +- libs/aws/tests/unit_tests/test_llms.py | 7 - .../aws/tests/unit_tests/test_vectorstores.py | 6 - 18 files changed, 701 insertions(+), 206 deletions(-) create mode 100644 libs/aws/langchain_aws/chat_models/__init__.py create mode 100644 libs/aws/langchain_aws/graphs/__init__.py create mode 100644 libs/aws/langchain_aws/llms/__init__.py create mode 100644 libs/aws/langchain_aws/llms/sagemaker_endpoint.py delete mode 100644 libs/aws/tests/integration_tests/test_chat_models.py delete mode 100644 libs/aws/tests/integration_tests/test_compile.py delete mode 100644 libs/aws/tests/integration_tests/test_embeddings.py delete mode 100644 libs/aws/tests/integration_tests/test_llms.py delete mode 100644 libs/aws/tests/unit_tests/test_chat_models.py delete mode 100644 libs/aws/tests/unit_tests/test_embeddings.py delete mode 100644 libs/aws/tests/unit_tests/test_llms.py delete mode 100644 libs/aws/tests/unit_tests/test_vectorstores.py diff --git a/libs/aws/.gitignore b/libs/aws/.gitignore index bee8a64b..aed12c91 100644 --- a/libs/aws/.gitignore +++ b/libs/aws/.gitignore @@ -1 +1,180 @@ -__pycache__ +.vs/ +.vscode/ +.idea/ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# Google GitHub Actions credentials files created by: +# https://github.com/google-github-actions/auth +# +# That action recommends adding this gitignore to prevent accidentally committing keys. +gha-creds-*.json + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ +docs/docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints +notebooks/ + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.envrc +.venv* +venv* +env/ +ENV/ +env.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# macOS display setting files +.DS_Store + +# Wandb directory +wandb/ + +# asdf tool versions +.tool-versions +/.ruff_cache/ + +*.pkl +*.bin + +# integration test artifacts +data_map* +\[('_type', 'fake'), ('stop', None)] + +# Replit files +*replit* + +node_modules +docs/.yarn/ +docs/node_modules/ +docs/.docusaurus/ +docs/.cache-loader/ +docs/_dist +docs/api_reference/*api_reference.rst +docs/api_reference/_build +docs/api_reference/*/ +!docs/api_reference/_static/ +!docs/api_reference/templates/ +!docs/api_reference/themes/ +docs/docs/build +docs/docs/node_modules +docs/docs/yarn.lock +_dist +docs/docs/templates + +prof diff --git a/libs/aws/langchain_aws/__init__.py b/libs/aws/langchain_aws/__init__.py index f5982912..4c8cc796 100644 --- a/libs/aws/langchain_aws/__init__.py +++ b/libs/aws/langchain_aws/__init__.py @@ -1,11 +1,3 @@ -from langchain_aws.chat_models import ChatBedrock -from langchain_aws.embeddings import BedrockEmbeddings -from langchain_aws.llms import BedrockLLM -from langchain_aws.vectorstores import BedrockVectorStore +from langchain_aws.llms import SagemakerEndpoint -__all__ = [ - "BedrockLLM", - "ChatBedrock", - "BedrockVectorStore", - "BedrockEmbeddings", -] +__all__ = ["SagemakerEndpoint"] diff --git a/libs/aws/langchain_aws/chat_models/__init__.py b/libs/aws/langchain_aws/chat_models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/libs/aws/langchain_aws/graphs/__init__.py b/libs/aws/langchain_aws/graphs/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/libs/aws/langchain_aws/llms/__init__.py b/libs/aws/langchain_aws/llms/__init__.py new file mode 100644 index 00000000..1c1157b6 --- /dev/null +++ b/libs/aws/langchain_aws/llms/__init__.py @@ -0,0 +1,3 @@ +from langchain_aws.llms.sagemaker_endpoint import SagemakerEndpoint + +__all__ = ["SagemakerEndpoint"] diff --git a/libs/aws/langchain_aws/llms/sagemaker_endpoint.py b/libs/aws/langchain_aws/llms/sagemaker_endpoint.py new file mode 100644 index 00000000..b0994730 --- /dev/null +++ b/libs/aws/langchain_aws/llms/sagemaker_endpoint.py @@ -0,0 +1,375 @@ +"""Sagemaker InvokeEndpoint API.""" +import io +import json +import re +from abc import abstractmethod +from typing import Any, Dict, Generic, Iterator, List, Mapping, Optional, TypeVar, Union + +from langchain_core.callbacks import CallbackManagerForLLMRun +from langchain_core.language_models.llms import LLM +from langchain_core.pydantic_v1 import Extra, root_validator + +INPUT_TYPE = TypeVar("INPUT_TYPE", bound=Union[str, List[str]]) +OUTPUT_TYPE = TypeVar("OUTPUT_TYPE", bound=Union[str, List[List[float]], Iterator]) + + +def enforce_stop_tokens(text: str, stop: List[str]) -> str: + """Cut off the text as soon as any stop words occur.""" + return re.split("|".join(stop), text, maxsplit=1)[0] + + +class LineIterator: + """ + A helper class for parsing the byte stream input. + + The output of the model will be in the following format: + + b'{"outputs": [" a"]}\n' + b'{"outputs": [" challenging"]}\n' + b'{"outputs": [" problem"]}\n' + ... + + While usually each PayloadPart event from the event stream will + contain a byte array with a full json, this is not guaranteed + and some of the json objects may be split acrossPayloadPart events. + + For example: + + {'PayloadPart': {'Bytes': b'{"outputs": '}} + {'PayloadPart': {'Bytes': b'[" problem"]}\n'}} + + + This class accounts for this by concatenating bytes written via the 'write' function + and then exposing a method which will return lines (ending with a '\n' character) + within the buffer via the 'scan_lines' function. + It maintains the position of the last read position to ensure + that previous bytes are not exposed again. + + For more details see: + https://aws.amazon.com/blogs/machine-learning/elevating-the-generative-ai-experience-introducing-streaming-support-in-amazon-sagemaker-hosting/ + """ + + def __init__(self, stream: Any) -> None: + self.byte_iterator = iter(stream) + self.buffer = io.BytesIO() + self.read_pos = 0 + + def __iter__(self) -> "LineIterator": + return self + + def __next__(self) -> Any: + while True: + self.buffer.seek(self.read_pos) + line = self.buffer.readline() + if line and line[-1] == ord("\n"): + self.read_pos += len(line) + return line[:-1] + try: + chunk = next(self.byte_iterator) + except StopIteration: + if self.read_pos < self.buffer.getbuffer().nbytes: + continue + raise + if "PayloadPart" not in chunk: + # Unknown Event Type + continue + self.buffer.seek(0, io.SEEK_END) + self.buffer.write(chunk["PayloadPart"]["Bytes"]) + + +class ContentHandlerBase(Generic[INPUT_TYPE, OUTPUT_TYPE]): + """A handler class to transform input from LLM to a + format that SageMaker endpoint expects. + + Similarly, the class handles transforming output from the + SageMaker endpoint to a format that LLM class expects. + """ + + """ + Example: + .. code-block:: python + + class ContentHandler(ContentHandlerBase): + content_type = "application/json" + accepts = "application/json" + + def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes: + input_str = json.dumps({prompt: prompt, **model_kwargs}) + return input_str.encode('utf-8') + + def transform_output(self, output: bytes) -> str: + response_json = json.loads(output.read().decode("utf-8")) + return response_json[0]["generated_text"] + """ + + content_type: Optional[str] = "text/plain" + """The MIME type of the input data passed to endpoint""" + + accepts: Optional[str] = "text/plain" + """The MIME type of the response data returned from endpoint""" + + @abstractmethod + def transform_input(self, prompt: INPUT_TYPE, model_kwargs: Dict) -> bytes: + """Transforms the input to a format that model can accept + as the request Body. Should return bytes or seekable file + like object in the format specified in the content_type + request header. + """ + + @abstractmethod + def transform_output(self, output: bytes) -> OUTPUT_TYPE: + """Transforms the output from the model to string that + the LLM class expects. + """ + + +class LLMContentHandler(ContentHandlerBase[str, str]): + """Content handler for LLM class.""" + + +class SagemakerEndpoint(LLM): + """Sagemaker Inference Endpoint models. + + To use, you must supply the endpoint name from your deployed + Sagemaker model & the region where it is deployed. + + To authenticate, the AWS client uses the following methods to + automatically load credentials: + https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html + + If a specific credential profile should be used, you must pass + the name of the profile from the ~/.aws/credentials file that is to be used. + + Make sure the credentials / roles used have the required policies to + access the Sagemaker endpoint. + See: https://docs.aws.amazon.com/IAM/latest/UserGuide/access_policies.html + """ + + """ + Args: + + region_name: The aws region e.g., `us-west-2`. + Fallsback to AWS_DEFAULT_REGION env variable + or region specified in ~/.aws/config. + + credentials_profile_name: The name of the profile in the ~/.aws/credentials + or ~/.aws/config files, which has either access keys or role information + specified. If not specified, the default credential profile or, if on an + EC2 instance, credentials from IMDS will be used. + + client: boto3 client for Sagemaker Endpoint + + content_handler: Implementation for model specific LLMContentHandler + + + Example: + .. code-block:: python + + from langchain_community.llms import SagemakerEndpoint + endpoint_name = ( + "my-endpoint-name" + ) + region_name = ( + "us-west-2" + ) + credentials_profile_name = ( + "default" + ) + se = SagemakerEndpoint( + endpoint_name=endpoint_name, + region_name=region_name, + credentials_profile_name=credentials_profile_name + ) + + #Use with boto3 client + client = boto3.client( + "sagemaker-runtime", + region_name=region_name + ) + + se = SagemakerEndpoint( + endpoint_name=endpoint_name, + client=client + ) + + """ + client: Any = None + """Boto3 client for sagemaker runtime""" + + endpoint_name: str = "" + """The name of the endpoint from the deployed Sagemaker model. + Must be unique within an AWS Region.""" + + region_name: str = "" + """The aws region where the Sagemaker model is deployed, eg. `us-west-2`.""" + + credentials_profile_name: Optional[str] = None + """The name of the profile in the ~/.aws/credentials or ~/.aws/config files, which + has either access keys or role information specified. + If not specified, the default credential profile or, if on an EC2 instance, + credentials from IMDS will be used. + See: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html + """ + + content_handler: LLMContentHandler + """The content handler class that provides an input and + output transform functions to handle formats between LLM + and the endpoint. + """ + + streaming: bool = False + """Whether to stream the results.""" + + """ + Example: + .. code-block:: python + + from langchain_community.llms.sagemaker_endpoint import LLMContentHandler + + class ContentHandler(LLMContentHandler): + content_type = "application/json" + accepts = "application/json" + + def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes: + input_str = json.dumps({prompt: prompt, **model_kwargs}) + return input_str.encode('utf-8') + + def transform_output(self, output: bytes) -> str: + response_json = json.loads(output.read().decode("utf-8")) + return response_json[0]["generated_text"] + """ + + model_kwargs: Optional[Dict] = None + """Keyword arguments to pass to the model.""" + + endpoint_kwargs: Optional[Dict] = None + """Optional attributes passed to the invoke_endpoint + function. See `boto3`_. docs for more info. + .. _boto3: + """ + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + + @root_validator() + def validate_environment(cls, values: Dict) -> Dict: + """Dont do anything if client provided externally""" + if values.get("client") is not None: + return values + + """Validate that AWS credentials to and python package exists in environment.""" + try: + import boto3 + + try: + if values["credentials_profile_name"] is not None: + session = boto3.Session( + profile_name=values["credentials_profile_name"] + ) + else: + # use default credentials + session = boto3.Session() + + values["client"] = session.client( + "sagemaker-runtime", region_name=values["region_name"] + ) + + except Exception as e: + raise ValueError( + "Could not load credentials to authenticate with AWS client. " + "Please check that credentials in the specified " + "profile name are valid." + ) from e + + except ImportError: + raise ImportError( + "Could not import boto3 python package. " + "Please install it with `pip install boto3`." + ) + return values + + @property + def _identifying_params(self) -> Mapping[str, Any]: + """Get the identifying parameters.""" + _model_kwargs = self.model_kwargs or {} + return { + **{"endpoint_name": self.endpoint_name}, + **{"model_kwargs": _model_kwargs}, + } + + @property + def _llm_type(self) -> str: + """Return type of llm.""" + return "sagemaker_endpoint" + + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> str: + """Call out to Sagemaker inference endpoint. + + Args: + prompt: The prompt to pass into the model. + stop: Optional list of stop words to use when generating. + + Returns: + The string generated by the model. + + Example: + .. code-block:: python + + response = se("Tell me a joke.") + """ + _model_kwargs = self.model_kwargs or {} + _model_kwargs = {**_model_kwargs, **kwargs} + _endpoint_kwargs = self.endpoint_kwargs or {} + + body = self.content_handler.transform_input(prompt, _model_kwargs) + content_type = self.content_handler.content_type + accepts = self.content_handler.accepts + + if self.streaming and run_manager: + try: + resp = self.client.invoke_endpoint_with_response_stream( + EndpointName=self.endpoint_name, + Body=body, + ContentType=self.content_handler.content_type, + **_endpoint_kwargs, + ) + iterator = LineIterator(resp["Body"]) + current_completion: str = "" + for line in iterator: + resp = json.loads(line) + resp_output = resp.get("outputs")[0] + if stop is not None: + # Uses same approach as below + resp_output = enforce_stop_tokens(resp_output, stop) + current_completion += resp_output + run_manager.on_llm_new_token(resp_output) + return current_completion + except Exception as e: + raise ValueError(f"Error raised by streaming inference endpoint: {e}") + else: + try: + response = self.client.invoke_endpoint( + EndpointName=self.endpoint_name, + Body=body, + ContentType=content_type, + Accept=accepts, + **_endpoint_kwargs, + ) + except Exception as e: + raise ValueError(f"Error raised by inference endpoint: {e}") + + text = self.content_handler.transform_output(response["Body"]) + if stop is not None: + # This is a bit hacky, but I can't figure out a better way to enforce + # stop tokens when making calls to the sagemaker endpoint. + text = enforce_stop_tokens(text, stop) + + return text diff --git a/libs/aws/poetry.lock b/libs/aws/poetry.lock index 4b16314e..6f5ab5d1 100644 --- a/libs/aws/poetry.lock +++ b/libs/aws/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand. [[package]] name = "annotated-types" @@ -36,6 +36,47 @@ doc = ["Sphinx (>=7)", "packaging", "sphinx-autodoc-typehints (>=1.2.0)", "sphin test = ["anyio[trio]", "coverage[toml] (>=7)", "exceptiongroup (>=1.2.0)", "hypothesis (>=4.0)", "psutil (>=5.9)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "uvloop (>=0.17)"] trio = ["trio (>=0.23)"] +[[package]] +name = "boto3" +version = "1.34.72" +description = "The AWS SDK for Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "boto3-1.34.72-py3-none-any.whl", hash = "sha256:a33585ef0d811ee0dffd92a96108344997a3059262c57349be0761d7885f6ae7"}, + {file = "boto3-1.34.72.tar.gz", hash = "sha256:cbfabd99c113bbb1708c2892e864b6dd739593b97a76fbb2e090a7d965b63b82"}, +] + +[package.dependencies] +botocore = ">=1.34.72,<1.35.0" +jmespath = ">=0.7.1,<2.0.0" +s3transfer = ">=0.10.0,<0.11.0" + +[package.extras] +crt = ["botocore[crt] (>=1.21.0,<2.0a0)"] + +[[package]] +name = "botocore" +version = "1.34.72" +description = "Low-level, data-driven core of boto 3." +optional = false +python-versions = ">=3.8" +files = [ + {file = "botocore-1.34.72-py3-none-any.whl", hash = "sha256:a6b92735a73c19a7e540d77320420da3af3f32c91fa661c738c0b8c9f912d782"}, + {file = "botocore-1.34.72.tar.gz", hash = "sha256:342edb6f91d5839e790411822fc39f9c712c87cdaa7f3b1999f50b1ca16c4a14"}, +] + +[package.dependencies] +jmespath = ">=0.7.1,<2.0.0" +python-dateutil = ">=2.1,<3.0.0" +urllib3 = [ + {version = ">=1.25.4,<1.27", markers = "python_version < \"3.10\""}, + {version = ">=1.25.4,<2.2.0 || >2.2.0,<3", markers = "python_version >= \"3.10\""}, +] + +[package.extras] +crt = ["awscrt (==0.19.19)"] + [[package]] name = "certifi" version = "2024.2.2" @@ -210,6 +251,17 @@ files = [ {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"}, ] +[[package]] +name = "jmespath" +version = "1.0.1" +description = "JSON Matching Expressions" +optional = false +python-versions = ">=3.7" +files = [ + {file = "jmespath-1.0.1-py3-none-any.whl", hash = "sha256:02e2e4cc71b5bcab88332eebf907519190dd9e6e82107fa7f83b1003a6252980"}, + {file = "jmespath-1.0.1.tar.gz", hash = "sha256:90261b206d6defd58fdd5e85f478bf633a2901798906be2ad389150c5c60edbe"}, +] + [[package]] name = "jsonpatch" version = "1.33" @@ -568,6 +620,20 @@ pytest = ">=7.0.0,<9" docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] +[[package]] +name = "python-dateutil" +version = "2.9.0.post0" +description = "Extensions to the standard Python datetime module" +optional = false +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" +files = [ + {file = "python-dateutil-2.9.0.post0.tar.gz", hash = "sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3"}, + {file = "python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427"}, +] + +[package.dependencies] +six = ">=1.5" + [[package]] name = "pyyaml" version = "6.0.1" @@ -593,7 +659,6 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, - {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, @@ -675,6 +740,34 @@ files = [ {file = "ruff-0.1.15.tar.gz", hash = "sha256:f6dfa8c1b21c913c326919056c390966648b680966febcb796cc9d1aaab8564e"}, ] +[[package]] +name = "s3transfer" +version = "0.10.1" +description = "An Amazon S3 Transfer Manager" +optional = false +python-versions = ">= 3.8" +files = [ + {file = "s3transfer-0.10.1-py3-none-any.whl", hash = "sha256:ceb252b11bcf87080fb7850a224fb6e05c8a776bab8f2b64b7f25b969464839d"}, + {file = "s3transfer-0.10.1.tar.gz", hash = "sha256:5683916b4c724f799e600f41dd9e10a9ff19871bf87623cc8f491cb4f5fa0a19"}, +] + +[package.dependencies] +botocore = ">=1.33.2,<2.0a.0" + +[package.extras] +crt = ["botocore[crt] (>=1.33.2,<2.0a.0)"] + +[[package]] +name = "six" +version = "1.16.0" +description = "Python 2 and 3 compatibility utilities" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*" +files = [ + {file = "six-1.16.0-py2.py3-none-any.whl", hash = "sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254"}, + {file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"}, +] + [[package]] name = "sniffio" version = "1.3.1" @@ -686,6 +779,20 @@ files = [ {file = "sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc"}, ] +[[package]] +name = "syrupy" +version = "4.6.1" +description = "Pytest Snapshot Test Utility" +optional = false +python-versions = ">=3.8.1,<4" +files = [ + {file = "syrupy-4.6.1-py3-none-any.whl", hash = "sha256:203e52f9cb9fa749cf683f29bd68f02c16c3bc7e7e5fe8f2fc59bdfe488ce133"}, + {file = "syrupy-4.6.1.tar.gz", hash = "sha256:37a835c9ce7857eeef86d62145885e10b3cb9615bc6abeb4ce404b3f18e1bb36"}, +] + +[package.dependencies] +pytest = ">=7.0.0,<9.0.0" + [[package]] name = "tenacity" version = "8.2.3" @@ -722,6 +829,22 @@ files = [ {file = "typing_extensions-4.10.0.tar.gz", hash = "sha256:b0abd7c89e8fb96f98db18d86106ff1d90ab692004eb746cf6eda2682f91b3cb"}, ] +[[package]] +name = "urllib3" +version = "1.26.18" +description = "HTTP library with thread-safe connection pooling, file post, and more." +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*" +files = [ + {file = "urllib3-1.26.18-py2.py3-none-any.whl", hash = "sha256:34b97092d7e0a3a8cf7cd10e386f401b3737364026c45e622aa02903dffe0f07"}, + {file = "urllib3-1.26.18.tar.gz", hash = "sha256:f8ecc1bba5667413457c529ab955bf8c67b45db799d159066261719e328580a0"}, +] + +[package.extras] +brotli = ["brotli (==1.0.9)", "brotli (>=1.0.9)", "brotlicffi (>=0.8.0)", "brotlipy (>=0.6.0)"] +secure = ["certifi", "cryptography (>=1.3.4)", "idna (>=2.0.0)", "ipaddress", "pyOpenSSL (>=0.14)", "urllib3-secure-extra"] +socks = ["PySocks (>=1.5.6,!=1.5.7,<2.0)"] + [[package]] name = "urllib3" version = "2.2.1" @@ -742,4 +865,4 @@ zstd = ["zstandard (>=0.18.0)"] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<4.0" -content-hash = "cb37517ca81743a3ed239231d386a90a6d380ad58227f13508aaedaa56a0538f" +content-hash = "f5abc531b0f400758b685366c860fdf4fa6622d2e19487898221c2f02e35e9b4" diff --git a/libs/aws/pyproject.toml b/libs/aws/pyproject.toml index bc120073..40d34f6d 100644 --- a/libs/aws/pyproject.toml +++ b/libs/aws/pyproject.toml @@ -13,12 +13,14 @@ license = "MIT" [tool.poetry.dependencies] python = ">=3.8.1,<4.0" langchain-core = "^0.1" +boto3 = "^1.34.72" [tool.poetry.group.test] optional = true [tool.poetry.group.test.dependencies] pytest = "^7.4.3" +syrupy = "^4.0.2" pytest-asyncio = "^0.23.2" [tool.poetry.group.codespell] diff --git a/libs/aws/tests/integration_tests/test_chat_models.py b/libs/aws/tests/integration_tests/test_chat_models.py deleted file mode 100644 index 0ec212d0..00000000 --- a/libs/aws/tests/integration_tests/test_chat_models.py +++ /dev/null @@ -1,63 +0,0 @@ -"""Test ChatBedrock chat model.""" -from langchain_aws.chat_models import ChatBedrock - - -def test_stream() -> None: - """Test streaming tokens from OpenAI.""" - llm = ChatBedrock() - - for token in llm.stream("I'm Pickle Rick"): - assert isinstance(token.content, str) - - -async def test_astream() -> None: - """Test streaming tokens from OpenAI.""" - llm = ChatBedrock() - - async for token in llm.astream("I'm Pickle Rick"): - assert isinstance(token.content, str) - - -async def test_abatch() -> None: - """Test streaming tokens from ChatBedrock.""" - llm = ChatBedrock() - - result = await llm.abatch(["I'm Pickle Rick", "I'm not Pickle Rick"]) - for token in result: - assert isinstance(token.content, str) - - -async def test_abatch_tags() -> None: - """Test batch tokens from ChatBedrock.""" - llm = ChatBedrock() - - result = await llm.abatch( - ["I'm Pickle Rick", "I'm not Pickle Rick"], config={"tags": ["foo"]} - ) - for token in result: - assert isinstance(token.content, str) - - -def test_batch() -> None: - """Test batch tokens from ChatBedrock.""" - llm = ChatBedrock() - - result = llm.batch(["I'm Pickle Rick", "I'm not Pickle Rick"]) - for token in result: - assert isinstance(token.content, str) - - -async def test_ainvoke() -> None: - """Test invoke tokens from ChatBedrock.""" - llm = ChatBedrock() - - result = await llm.ainvoke("I'm Pickle Rick", config={"tags": ["foo"]}) - assert isinstance(result.content, str) - - -def test_invoke() -> None: - """Test invoke tokens from ChatBedrock.""" - llm = ChatBedrock() - - result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"])) - assert isinstance(result.content, str) diff --git a/libs/aws/tests/integration_tests/test_compile.py b/libs/aws/tests/integration_tests/test_compile.py deleted file mode 100644 index 33ecccdf..00000000 --- a/libs/aws/tests/integration_tests/test_compile.py +++ /dev/null @@ -1,7 +0,0 @@ -import pytest - - -@pytest.mark.compile -def test_placeholder() -> None: - """Used for compiling integration tests without running any real tests.""" - pass diff --git a/libs/aws/tests/integration_tests/test_embeddings.py b/libs/aws/tests/integration_tests/test_embeddings.py deleted file mode 100644 index 406b053a..00000000 --- a/libs/aws/tests/integration_tests/test_embeddings.py +++ /dev/null @@ -1,19 +0,0 @@ -"""Test Bedrock embeddings.""" -from langchain_aws.embeddings import BedrockEmbeddings - - -def test_langchain_aws_embedding_documents() -> None: - """Test cohere embeddings.""" - documents = ["foo bar"] - embedding = BedrockEmbeddings() - output = embedding.embed_documents(documents) - assert len(output) == 1 - assert len(output[0]) > 0 - - -def test_langchain_aws_embedding_query() -> None: - """Test cohere embeddings.""" - document = "foo bar" - embedding = BedrockEmbeddings() - output = embedding.embed_query(document) - assert len(output) > 0 diff --git a/libs/aws/tests/integration_tests/test_llms.py b/libs/aws/tests/integration_tests/test_llms.py deleted file mode 100644 index bbfea02c..00000000 --- a/libs/aws/tests/integration_tests/test_llms.py +++ /dev/null @@ -1,63 +0,0 @@ -"""Test BedrockLLM llm.""" -from langchain_aws.llms import BedrockLLM - - -def test_stream() -> None: - """Test streaming tokens from OpenAI.""" - llm = BedrockLLM() - - for token in llm.stream("I'm Pickle Rick"): - assert isinstance(token, str) - - -async def test_astream() -> None: - """Test streaming tokens from OpenAI.""" - llm = BedrockLLM() - - async for token in llm.astream("I'm Pickle Rick"): - assert isinstance(token, str) - - -async def test_abatch() -> None: - """Test streaming tokens from BedrockLLM.""" - llm = BedrockLLM() - - result = await llm.abatch(["I'm Pickle Rick", "I'm not Pickle Rick"]) - for token in result: - assert isinstance(token, str) - - -async def test_abatch_tags() -> None: - """Test batch tokens from BedrockLLM.""" - llm = BedrockLLM() - - result = await llm.abatch( - ["I'm Pickle Rick", "I'm not Pickle Rick"], config={"tags": ["foo"]} - ) - for token in result: - assert isinstance(token, str) - - -def test_batch() -> None: - """Test batch tokens from BedrockLLM.""" - llm = BedrockLLM() - - result = llm.batch(["I'm Pickle Rick", "I'm not Pickle Rick"]) - for token in result: - assert isinstance(token, str) - - -async def test_ainvoke() -> None: - """Test invoke tokens from BedrockLLM.""" - llm = BedrockLLM() - - result = await llm.ainvoke("I'm Pickle Rick", config={"tags": ["foo"]}) - assert isinstance(result, str) - - -def test_invoke() -> None: - """Test invoke tokens from BedrockLLM.""" - llm = BedrockLLM() - - result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"])) - assert isinstance(result, str) diff --git a/libs/aws/tests/unit_tests/__init__.py b/libs/aws/tests/unit_tests/__init__.py index e69de29b..800bc7f3 100644 --- a/libs/aws/tests/unit_tests/__init__.py +++ b/libs/aws/tests/unit_tests/__init__.py @@ -0,0 +1,7 @@ +"""All unit tests (lightweight tests).""" +from typing import Any + + +def assert_all_importable(module: Any) -> None: + for attr in module.__all__: + getattr(module, attr) diff --git a/libs/aws/tests/unit_tests/test_chat_models.py b/libs/aws/tests/unit_tests/test_chat_models.py deleted file mode 100644 index da522d6a..00000000 --- a/libs/aws/tests/unit_tests/test_chat_models.py +++ /dev/null @@ -1,9 +0,0 @@ -"""Test chat model integration.""" - - -from langchain_aws.chat_models import ChatBedrock - - -def test_initialization() -> None: - """Test chat model initialization.""" - ChatBedrock() diff --git a/libs/aws/tests/unit_tests/test_embeddings.py b/libs/aws/tests/unit_tests/test_embeddings.py deleted file mode 100644 index c04511b2..00000000 --- a/libs/aws/tests/unit_tests/test_embeddings.py +++ /dev/null @@ -1,9 +0,0 @@ -"""Test embedding model integration.""" - - -from langchain_aws.embeddings import BedrockEmbeddings - - -def test_initialization() -> None: - """Test embedding model initialization.""" - BedrockEmbeddings() diff --git a/libs/aws/tests/unit_tests/test_imports.py b/libs/aws/tests/unit_tests/test_imports.py index 0e706797..27ce128d 100644 --- a/libs/aws/tests/unit_tests/test_imports.py +++ b/libs/aws/tests/unit_tests/test_imports.py @@ -1,12 +1,9 @@ -from langchain_aws import __all__ +from langchain_aws import llms +from tests.unit_tests import assert_all_importable -EXPECTED_ALL = [ - "BedrockLLM", - "ChatBedrock", - "BedrockVectorStore", - "BedrockEmbeddings", -] +EXPECTED_ALL_LLMS = ["SagemakerEndpoint"] -def test_all_imports() -> None: - assert sorted(EXPECTED_ALL) == sorted(__all__) +def test_imports() -> None: + assert sorted(llms.__all__) == sorted(EXPECTED_ALL_LLMS) + assert_all_importable(llms) diff --git a/libs/aws/tests/unit_tests/test_llms.py b/libs/aws/tests/unit_tests/test_llms.py deleted file mode 100644 index d9f19b30..00000000 --- a/libs/aws/tests/unit_tests/test_llms.py +++ /dev/null @@ -1,7 +0,0 @@ -"""Test Bedrock Chat API wrapper.""" -from langchain_aws import BedrockLLM - - -def test_initialization() -> None: - """Test integration initialization.""" - BedrockLLM() diff --git a/libs/aws/tests/unit_tests/test_vectorstores.py b/libs/aws/tests/unit_tests/test_vectorstores.py deleted file mode 100644 index 9077fa15..00000000 --- a/libs/aws/tests/unit_tests/test_vectorstores.py +++ /dev/null @@ -1,6 +0,0 @@ -from langchain_aws.vectorstores import BedrockVectorStore - - -def test_initialization() -> None: - """Test integration vectorstore initialization.""" - BedrockVectorStore() From f1cccf84530ccd7625b6c7b8c197c32c322f257e Mon Sep 17 00:00:00 2001 From: Piyush Jain Date: Thu, 28 Mar 2024 14:56:35 -0700 Subject: [PATCH 2/3] Fixed lint error. --- libs/aws/pyproject.toml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/libs/aws/pyproject.toml b/libs/aws/pyproject.toml index 40d34f6d..3d60dab8 100644 --- a/libs/aws/pyproject.toml +++ b/libs/aws/pyproject.toml @@ -59,6 +59,10 @@ select = [ [tool.mypy] disallow_untyped_defs = "True" +[[tool.mypy.overrides]] +module = "boto3" +ignore_missing_imports = true + [tool.coverage.run] omit = ["tests/*"] From 8f58ed3d70fefdf2c8442137c91de16287344514 Mon Sep 17 00:00:00 2001 From: Piyush Jain Date: Thu, 28 Mar 2024 15:33:59 -0700 Subject: [PATCH 3/3] Added integration tests. --- libs/aws/pyproject.toml | 3 +++ .../integration_tests/chat_models/__init__.py | 0 .../integration_tests/graphs/__init__.py | 0 .../tests/integration_tests/llms/__init__.py | 0 .../llms/test_sagemaker_endpoint.py | 21 +++++++++++++++++++ .../tests/integration_tests/test_compile.py | 7 +++++++ 6 files changed, 31 insertions(+) create mode 100644 libs/aws/tests/integration_tests/chat_models/__init__.py create mode 100644 libs/aws/tests/integration_tests/graphs/__init__.py create mode 100644 libs/aws/tests/integration_tests/llms/__init__.py create mode 100644 libs/aws/tests/integration_tests/llms/test_sagemaker_endpoint.py create mode 100644 libs/aws/tests/integration_tests/test_compile.py diff --git a/libs/aws/pyproject.toml b/libs/aws/pyproject.toml index 3d60dab8..33588913 100644 --- a/libs/aws/pyproject.toml +++ b/libs/aws/pyproject.toml @@ -84,6 +84,9 @@ addopts = "--snapshot-warn-unused --strict-markers --strict-config --durations=5 # Registering custom markers. # https://docs.pytest.org/en/7.1.x/example/markers.html#registering-markers markers = [ + "requires: mark tests as requiring a specific library", + "asyncio: mark tests as requiring asyncio", "compile: mark placeholder test used to compile integration tests without running them", + "scheduled: mark tests to run in scheduled testing", ] asyncio_mode = "auto" diff --git a/libs/aws/tests/integration_tests/chat_models/__init__.py b/libs/aws/tests/integration_tests/chat_models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/libs/aws/tests/integration_tests/graphs/__init__.py b/libs/aws/tests/integration_tests/graphs/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/libs/aws/tests/integration_tests/llms/__init__.py b/libs/aws/tests/integration_tests/llms/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/libs/aws/tests/integration_tests/llms/test_sagemaker_endpoint.py b/libs/aws/tests/integration_tests/llms/test_sagemaker_endpoint.py new file mode 100644 index 00000000..cb5e32ef --- /dev/null +++ b/libs/aws/tests/integration_tests/llms/test_sagemaker_endpoint.py @@ -0,0 +1,21 @@ +from typing import Dict + +from langchain_aws.llms import SagemakerEndpoint +from langchain_aws.llms.sagemaker_endpoint import LLMContentHandler + + +class ContentHandler(LLMContentHandler): + def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes: + return b"" + + def transform_output(self, output: bytes) -> str: + return "" + + +def test_sagemaker_endpoint_name_param() -> None: + llm = SagemakerEndpoint( + endpoint_name="foo", + content_handler=ContentHandler(), + region_name="us-west-2", + ) + assert llm.endpoint_name == "foo" diff --git a/libs/aws/tests/integration_tests/test_compile.py b/libs/aws/tests/integration_tests/test_compile.py new file mode 100644 index 00000000..33ecccdf --- /dev/null +++ b/libs/aws/tests/integration_tests/test_compile.py @@ -0,0 +1,7 @@ +import pytest + + +@pytest.mark.compile +def test_placeholder() -> None: + """Used for compiling integration tests without running any real tests.""" + pass