From ba30daa400fe1799aacb44ef66cb4f6e9e95b2e2 Mon Sep 17 00:00:00 2001
From: Bagatur <22008038+baskaryan@users.noreply.github.com>
Date: Thu, 3 Oct 2024 17:17:25 -0700
Subject: [PATCH] aws[patch]: support ChatBedrock(max_tokens, temperature)
 (#226)

---
 libs/aws/langchain_aws/chat_models/bedrock.py |  8 +-
 libs/aws/langchain_aws/llms/bedrock.py        | 95 +++++++++++++++----
 libs/aws/pyproject.toml                       |  2 +-
 .../chat_models/test_standard.py              |  8 +-
 .../__snapshots__/test_standard.ambr          |  4 +-
 5 files changed, 88 insertions(+), 29 deletions(-)

diff --git a/libs/aws/langchain_aws/chat_models/bedrock.py b/libs/aws/langchain_aws/chat_models/bedrock.py
index fab5fbaf..23a842a2 100644
--- a/libs/aws/langchain_aws/chat_models/bedrock.py
+++ b/libs/aws/langchain_aws/chat_models/bedrock.py
@@ -431,9 +431,9 @@ def _get_ls_params(
             ls_model_name=self.model_id,
             ls_model_type="chat",
         )
-        if ls_temperature := params.get("temperature"):
+        if ls_temperature := params.get("temperature", self.temperature):
             ls_params["ls_temperature"] = ls_temperature
-        if ls_max_tokens := params.get("max_tokens"):
+        if ls_max_tokens := params.get("max_tokens", self.max_tokens):
             ls_params["ls_max_tokens"] = ls_max_tokens
         if ls_stop := stop or params.get("stop", None):
             ls_params["ls_stop"] = ls_stop
@@ -818,6 +818,10 @@ def _as_converse(self) -> ChatBedrockConverse:
             for k, v in (self.model_kwargs or {}).items()
             if k in ("stop", "stop_sequences", "max_tokens", "temperature", "top_p")
         }
+        if self.max_tokens:
+            kwargs["max_tokens"] = self.max_tokens
+        if self.temperature is not None:
+            kwargs["temperature"] = self.temperature
         return ChatBedrockConverse(
             model=self.model_id,
             region_name=self.region_name,
diff --git a/libs/aws/langchain_aws/llms/bedrock.py b/libs/aws/langchain_aws/llms/bedrock.py
index 0173d66a..cced42c7 100644
--- a/libs/aws/langchain_aws/llms/bedrock.py
+++ b/libs/aws/langchain_aws/llms/bedrock.py
@@ -263,6 +263,9 @@ def prepare_input(
         system: Optional[str] = None,
         messages: Optional[List[Dict]] = None,
         tools: Optional[List[AnthropicTool]] = None,
+        *,
+        max_tokens: Optional[int] = None,
+        temperature: Optional[float] = None,
     ) -> Dict[str, Any]:
         input_body = {**model_kwargs}
         if provider == "anthropic":
@@ -273,18 +276,44 @@ def prepare_input(
                 input_body["messages"] = messages
                 if system:
                     input_body["system"] = system
-                if "max_tokens" not in input_body:
+                if max_tokens:
+                    input_body["max_tokens"] = max_tokens
+                elif "max_tokens" not in input_body:
                     input_body["max_tokens"] = 1024
+
             if prompt:
                 input_body["prompt"] = _human_assistant_format(prompt)
-                if "max_tokens_to_sample" not in input_body:
+                if max_tokens:
+                    input_body["max_tokens_to_sample"] = max_tokens
+                elif "max_tokens_to_sample" not in input_body:
                     input_body["max_tokens_to_sample"] = 1024
+
+            if temperature is not None:
+                input_body["temperature"] = temperature
+
         elif provider in ("ai21", "cohere", "meta", "mistral"):
             input_body["prompt"] = prompt
+            if max_tokens:
+                if provider == "cohere":
+                    input_body["max_tokens"] = max_tokens
+                elif provider == "meta":
+                    input_body["max_gen_len"] = max_tokens
+                elif provider == "mistral":
+                    input_body["max_tokens"] = max_tokens
+                else:
+                    # TODO: Add AI21 support, param depends on specific model.
+                    pass
+            if temperature is not None:
+                input_body["temperature"] = temperature
+
         elif provider == "amazon":
             input_body = dict()
             input_body["inputText"] = prompt
             input_body["textGenerationConfig"] = {**model_kwargs}
+            if max_tokens:
+                input_body["textGenerationConfig"]["maxTokenCount"] = max_tokens
+            if temperature is not None:
+                input_body["textGenerationConfig"]["temperature"] = temperature
         else:
             input_body["inputText"] = prompt

@@ -451,7 +480,7 @@ class BedrockBase(BaseLanguageModel, ABC):

     client: Any = Field(default=None, exclude=True)  #: :meta private:

-    region_name: Optional[str] = None
+    region_name: Optional[str] = Field(default=None, alias="region")
     """The aws region e.g., `us-west-2`. Fallsback to AWS_DEFAULT_REGION env variable
     or region specified in ~/.aws/config in case it is not provided here.
     """
@@ -589,6 +618,9 @@ async def on_llm_error(
                 ...Logic to handle guardrail intervention...
     """  # noqa: E501

+    temperature: Optional[float] = None
+    max_tokens: Optional[int] = None
+
     @property
     def lc_secrets(self) -> Dict[str, str]:
         return {
@@ -601,6 +633,17 @@ def lc_secrets(self) -> Dict[str, str]:
     def validate_environment(self) -> Self:
         """Validate that AWS credentials to and python package exists in environment."""

+        if self.model_kwargs:
+            if "temperature" in self.model_kwargs:
+                if self.temperature is None:
+                    self.temperature = self.model_kwargs["temperature"]
+                self.model_kwargs.pop("temperature")
+
+            if "max_tokens" in self.model_kwargs:
+                if not self.max_tokens:
+                    self.max_tokens = self.model_kwargs["max_tokens"]
+                self.model_kwargs.pop("max_tokens")
+
         # Skip creating new client if passed in constructor
         if self.client is not None:
             return self
@@ -737,23 +780,27 @@ def _prepare_input_and_invoke(
         provider = self._get_provider()
         params = {**_model_kwargs, **kwargs}

-        input_body = LLMInputOutputAdapter.prepare_input(
-            provider=provider,
-            model_kwargs=params,
-            prompt=prompt,
-            system=system,
-            messages=messages,
-        )
-        if "claude-3" in self._get_model():
-            if _tools_in_params(params):
-                input_body = LLMInputOutputAdapter.prepare_input(
-                    provider=provider,
-                    model_kwargs=params,
-                    prompt=prompt,
-                    system=system,
-                    messages=messages,
-                    tools=params["tools"],
-                )
+        if "claude-3" in self._get_model() and _tools_in_params(params):
+            input_body = LLMInputOutputAdapter.prepare_input(
+                provider=provider,
+                model_kwargs=params,
+                prompt=prompt,
+                system=system,
+                messages=messages,
+                tools=params["tools"],
+                max_tokens=self.max_tokens,
+                temperature=self.temperature,
+            )
+        else:
+            input_body = LLMInputOutputAdapter.prepare_input(
+                provider=provider,
+                model_kwargs=params,
+                prompt=prompt,
+                system=system,
+                messages=messages,
+                max_tokens=self.max_tokens,
+                temperature=self.temperature,
+            )
         body = json.dumps(input_body)
         accept = "application/json"
         contentType = "application/json"
@@ -874,6 +921,8 @@ def _prepare_input_and_invoke_stream(
                 system=system,
                 messages=messages,
                 model_kwargs=params,
+                max_tokens=self.max_tokens,
+                temperature=self.temperature,
             )
             coerce_content_to_string = True
             if "claude-3" in self._get_model():
@@ -886,6 +935,8 @@ def _prepare_input_and_invoke_stream(
                         system=system,
                         messages=messages,
                         tools=params["tools"],
+                        max_tokens=self.max_tokens,
+                        temperature=self.temperature,
                     )

         body = json.dumps(input_body)
@@ -959,6 +1010,8 @@ async def _aprepare_input_and_invoke_stream(
                 system=system,
                 messages=messages,
                 tools=params["tools"],
+                max_tokens=self.max_tokens,
+                temperature=self.temperature,
             )
         else:
             input_body = LLMInputOutputAdapter.prepare_input(
@@ -967,6 +1020,8 @@ async def _aprepare_input_and_invoke_stream(
                 system=system,
                 messages=messages,
                 model_kwargs=params,
+                max_tokens=self.max_tokens,
+                temperature=self.temperature,
             )

         body = json.dumps(input_body)
diff --git a/libs/aws/pyproject.toml b/libs/aws/pyproject.toml
index 6ee35921..81b583c5 100644
--- a/libs/aws/pyproject.toml
+++ b/libs/aws/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "langchain-aws"
-version = "0.2.1"
+version = "0.2.2"
 description = "An integration package connecting AWS and LangChain"
 authors = []
 readme = "README.md"
diff --git a/libs/aws/tests/integration_tests/chat_models/test_standard.py b/libs/aws/tests/integration_tests/chat_models/test_standard.py
index 43434169..eed7ed51 100644
--- a/libs/aws/tests/integration_tests/chat_models/test_standard.py
+++ b/libs/aws/tests/integration_tests/chat_models/test_standard.py
@@ -20,7 +20,7 @@ def chat_model_params(self) -> dict:

     @property
     def standard_chat_model_params(self) -> dict:
-        return {}
+        return {"temperature": 0, "max_tokens": 100}

     @pytest.mark.xfail(reason="Not implemented.")
     def test_stop_sequence(self, model: BaseChatModel) -> None:
@@ -57,11 +57,11 @@ def chat_model_params(self) -> dict:
     @property
     def standard_chat_model_params(self) -> dict:
         return {
+            "temperature": 0,
+            "max_tokens": 100,
             "model_kwargs": {
-                "temperature": 0,
-                "max_tokens": 100,
                 "stop": [],
-            }
+            },
         }

     @property
diff --git a/libs/aws/tests/unit_tests/__snapshots__/test_standard.ambr b/libs/aws/tests/unit_tests/__snapshots__/test_standard.ambr
index 6f98f04b..8c7a7418 100644
--- a/libs/aws/tests/unit_tests/__snapshots__/test_standard.ambr
+++ b/libs/aws/tests/unit_tests/__snapshots__/test_standard.ambr
@@ -14,12 +14,11 @@
         'guardrailVersion': None,
         'trace': None,
       }),
+      'max_tokens': 100,
       'model_id': 'anthropic.claude-3-sonnet-20240229-v1:0',
       'model_kwargs': dict({
-        'max_tokens': 100,
         'stop': list([
         ]),
-        'temperature': 0,
       }),
       'provider_stop_reason_key_map': dict({
         'ai21': 'finishReason',
@@ -36,6 +35,7 @@
         'mistral': 'stop_sequences',
       }),
       'region_name': 'us-east-1',
+      'temperature': 0,
     }),
     'lc': 1,
     'name': 'ChatBedrock',
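
Below is a minimal usage sketch of what this patch enables; it is illustrative only
and not part of the commit. The model ID and region are placeholders.

    from langchain_aws import ChatBedrock

    # With this patch, max_tokens and temperature are first-class fields on
    # ChatBedrock/BedrockBase instead of model_kwargs entries. prepare_input()
    # maps max_tokens onto the provider-specific key: max_tokens or
    # max_tokens_to_sample (Anthropic), max_gen_len (Meta), maxTokenCount (Amazon).
    llm = ChatBedrock(
        model_id="anthropic.claude-3-sonnet-20240229-v1:0",  # placeholder model
        region_name="us-east-1",  # placeholder; the patch also adds a "region" alias
        max_tokens=100,
        temperature=0,
    )

    # The older style still works: validate_environment() migrates these two
    # keys out of model_kwargs onto the new fields.
    legacy = ChatBedrock(
        model_id="anthropic.claude-3-sonnet-20240229-v1:0",  # placeholder model
        region_name="us-east-1",  # placeholder region
        model_kwargs={"temperature": 0, "max_tokens": 100},
    )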