From f76bd26dee8aaece6856ef0f516442eda10eb4e4 Mon Sep 17 00:00:00 2001 From: Pravali Uppugunduri Date: Fri, 4 Oct 2024 19:17:35 +0000 Subject: [PATCH 1/4] add inference component support to SageMaker Endpoint --- .../langchain_aws/llms/sagemaker_endpoint.py | 54 +++++++++--- .../llms/test_sagemaker_endpoint.py | 85 +++++++++++++++++++ 2 files changed, 125 insertions(+), 14 deletions(-) diff --git a/libs/aws/langchain_aws/llms/sagemaker_endpoint.py b/libs/aws/langchain_aws/llms/sagemaker_endpoint.py index bc0a480d..4ad03e8b 100644 --- a/libs/aws/langchain_aws/llms/sagemaker_endpoint.py +++ b/libs/aws/langchain_aws/llms/sagemaker_endpoint.py @@ -183,6 +183,14 @@ class SagemakerEndpoint(LLM): region_name=region_name, credentials_profile_name=credentials_profile_name ) + + #Use with SageMaker Endpoint and Inference Component + se = SagemakerEndpoint( + endpoint_name=endpoint_name, + inference_component_name=inference_component_name, + region_name=region_name, + credentials_profile_name=credentials_profile_name + ) #Use with boto3 client client = boto3.client( @@ -203,6 +211,10 @@ class SagemakerEndpoint(LLM): """The name of the endpoint from the deployed Sagemaker model. Must be unique within an AWS Region.""" + inference_component_name: Optional[str] = None + """Optional name of the inference component to invoke + if specified with endpoint name.""" + region_name: str = "" """The aws region where the Sagemaker model is deployed, eg. `us-west-2`.""" @@ -296,6 +308,7 @@ def _identifying_params(self) -> Mapping[str, Any]: _model_kwargs = self.model_kwargs or {} return { **{"endpoint_name": self.endpoint_name}, + **{"inference_component_name": self.inference_component_name}, **{"model_kwargs": _model_kwargs}, } @@ -315,13 +328,19 @@ def _stream( _model_kwargs = {**_model_kwargs, **kwargs} _endpoint_kwargs = self.endpoint_kwargs or {} + invocation_params = { + "EndpointName": self.endpoint_name, + "Body": self.content_handler.transform_input(prompt, _model_kwargs), + "ContentType": self.content_handler.content_type, + **_endpoint_kwargs, + } + + # If inference_compoent_name is specified, append it to invocation_params + if self.inference_component_name: + invocation_params["InferenceComponentName"] = self.inference_component_name + try: - resp = self.client.invoke_endpoint_with_response_stream( - EndpointName=self.endpoint_name, - Body=self.content_handler.transform_input(prompt, _model_kwargs), - ContentType=self.content_handler.content_type, - **_endpoint_kwargs, - ) + resp = self.client.invoke_endpoint_with_response_stream(**invocation_params) iterator = LineIterator(resp["Body"]) for line in iterator: @@ -349,7 +368,8 @@ def _call( run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> str: - """Call out to Sagemaker inference endpoint. + """Call out to SageMaker inference endpoint or inference component + of SageMaker endpoint Args: prompt: The prompt to pass into the model. @@ -371,6 +391,18 @@ def _call( content_type = self.content_handler.content_type accepts = self.content_handler.accepts + invocation_params = { + "EndpointName": self.endpoint_name, + "Body": body, + "ContentType": content_type, + "Accept": accepts, + **_endpoint_kwargs, + } + + # If inference_compoent_name is specified, append it to invocation_params + if self.inference_component_name: + invocation_params["InferenceComponentName"] = self.inference_component_name + if self.streaming and run_manager: completion: str = "" for chunk in self._stream(prompt, stop, run_manager, **kwargs): @@ -378,13 +410,7 @@ def _call( return completion try: - response = self.client.invoke_endpoint( - EndpointName=self.endpoint_name, - Body=body, - ContentType=content_type, - Accept=accepts, - **_endpoint_kwargs, - ) + response = self.client.invoke_endpoint(**invocation_params) except Exception as e: logging.error(f"Error raised by inference endpoint: {e}") if run_manager is not None: diff --git a/libs/aws/tests/integration_tests/llms/test_sagemaker_endpoint.py b/libs/aws/tests/integration_tests/llms/test_sagemaker_endpoint.py index bac022a0..0ea646bf 100644 --- a/libs/aws/tests/integration_tests/llms/test_sagemaker_endpoint.py +++ b/libs/aws/tests/integration_tests/llms/test_sagemaker_endpoint.py @@ -49,6 +49,39 @@ def test_sagemaker_endpoint_invoke() -> None: ) +def test_sagemaker_endpoint_inference_component_invoke() -> None: + client = Mock() + response = { + "ContentType": "application/json", + "Body": b'[{"generated_text": "SageMaker Endpoint"}]', + } + client.invoke_endpoint.return_value = response + + llm = SagemakerEndpoint( + endpoint_name="my-endpoint", + inference_component_name="my-inference-component", + region_name="us-west-2", + content_handler=DefaultHandler(), + model_kwargs={ + "parameters": { + "max_new_tokens": 50, + } + }, + client=client, + ) + + service_response = llm.invoke("What is Sagemaker endpoints?") + + assert service_response == "SageMaker Endpoint" + client.invoke_endpoint.assert_called_once_with( + EndpointName="my-endpoint", + Body=b"What is Sagemaker endpoints?", + ContentType="application/json", + Accept="application/json", + InferenceComponentName="my-inference-component", + ) + + def test_sagemaker_endpoint_stream() -> None: class ContentHandler(LLMContentHandler): accepts = "application/json" @@ -97,3 +130,55 @@ def transform_output(self, output: bytes) -> str: Body=expected_body, ContentType="application/json", ) + + +def test_sagemaker_endpoint_inference_component_stream() -> None: + class ContentHandler(LLMContentHandler): + accepts = "application/json" + content_type = "application/json" + + def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes: + body = json.dumps({"inputs": prompt, **model_kwargs}) + return body.encode() + + def transform_output(self, output: bytes) -> str: + body = json.loads(output) + return body.get("outputs")[0] + + body = ( + {"PayloadPart": {"Bytes": b'{"outputs": ["S"]}\n'}}, + {"PayloadPart": {"Bytes": b'{"outputs": ["age"]}\n'}}, + {"PayloadPart": {"Bytes": b'{"outputs": ["Maker"]}\n'}}, + ) + + response = {"ContentType": "application/json", "Body": body} + + client = Mock() + client.invoke_endpoint_with_response_stream.return_value = response + + llm = SagemakerEndpoint( + endpoint_name="my-endpoint", + inference_component_name="my_inference_component", + region_name="us-west-2", + content_handler=ContentHandler(), + client=client, + model_kwargs={"parameters": {"max_new_tokens": 50}}, + ) + + expected_body = json.dumps( + {"inputs": "What is Sagemaker endpoints?", "parameters": {"max_new_tokens": 50}} + ).encode() + + chunks = ["S", "age", "Maker"] + service_chunks = [] + + for chunk in llm.stream("What is Sagemaker endpoints?"): + service_chunks.append(chunk) + + assert service_chunks == chunks + client.invoke_endpoint_with_response_stream.assert_called_once_with( + EndpointName="my-endpoint", + Body=expected_body, + ContentType="application/json", + InferenceComponentName="my_inference_component", + ) From f8ea01f3f7ad0cb0a2d2f857fa35fe6618409e37 Mon Sep 17 00:00:00 2001 From: Pravali Uppugunduri Date: Fri, 4 Oct 2024 19:30:34 +0000 Subject: [PATCH 2/4] correct spellings --- libs/aws/langchain_aws/llms/sagemaker_endpoint.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/libs/aws/langchain_aws/llms/sagemaker_endpoint.py b/libs/aws/langchain_aws/llms/sagemaker_endpoint.py index 4ad03e8b..f04ab1f4 100644 --- a/libs/aws/langchain_aws/llms/sagemaker_endpoint.py +++ b/libs/aws/langchain_aws/llms/sagemaker_endpoint.py @@ -335,7 +335,7 @@ def _stream( **_endpoint_kwargs, } - # If inference_compoent_name is specified, append it to invocation_params + # If inference_component_name is specified, append it to invocation_params if self.inference_component_name: invocation_params["InferenceComponentName"] = self.inference_component_name @@ -369,7 +369,7 @@ def _call( **kwargs: Any, ) -> str: """Call out to SageMaker inference endpoint or inference component - of SageMaker endpoint + of SageMaker inference endpoint. Args: prompt: The prompt to pass into the model. From 4c15b2fe271d26d4aff25829517a1cf1d57cafc1 Mon Sep 17 00:00:00 2001 From: Pravali Uppugunduri Date: Fri, 4 Oct 2024 22:54:36 +0000 Subject: [PATCH 3/4] edit documentation based on feedback --- libs/aws/langchain_aws/llms/sagemaker_endpoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/aws/langchain_aws/llms/sagemaker_endpoint.py b/libs/aws/langchain_aws/llms/sagemaker_endpoint.py index f04ab1f4..bea0d764 100644 --- a/libs/aws/langchain_aws/llms/sagemaker_endpoint.py +++ b/libs/aws/langchain_aws/llms/sagemaker_endpoint.py @@ -184,7 +184,7 @@ class SagemakerEndpoint(LLM): credentials_profile_name=credentials_profile_name ) - #Use with SageMaker Endpoint and Inference Component + # Usage with Inference Component se = SagemakerEndpoint( endpoint_name=endpoint_name, inference_component_name=inference_component_name, From 91b41aba687bbf52365c6ca409b617e5a3e2f6fc Mon Sep 17 00:00:00 2001 From: Piyush Jain Date: Fri, 4 Oct 2024 15:57:21 -0700 Subject: [PATCH 4/4] Fixed indent on docs. --- libs/aws/langchain_aws/llms/sagemaker_endpoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/aws/langchain_aws/llms/sagemaker_endpoint.py b/libs/aws/langchain_aws/llms/sagemaker_endpoint.py index bea0d764..de1a81ed 100644 --- a/libs/aws/langchain_aws/llms/sagemaker_endpoint.py +++ b/libs/aws/langchain_aws/llms/sagemaker_endpoint.py @@ -184,7 +184,7 @@ class SagemakerEndpoint(LLM): credentials_profile_name=credentials_profile_name ) - # Usage with Inference Component + # Usage with Inference Component se = SagemakerEndpoint( endpoint_name=endpoint_name, inference_component_name=inference_component_name,