aws[patch]: support ChatBedrock(max_tokens, temperature) (#226)
baskaryan authored Oct 4, 2024
1 parent e2c2f7c commit ba30daa
Showing 5 changed files with 88 additions and 29 deletions.
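
In short, `ChatBedrock` gains first-class `max_tokens` and `temperature` constructor parameters; previously they had to be smuggled in through `model_kwargs`. A minimal usage sketch (the model ID is illustrative, not part of the commit):

```python
from langchain_aws import ChatBedrock

# New in this commit: top-level parameters.
llm = ChatBedrock(
    model_id="anthropic.claude-3-sonnet-20240229-v1:0",
    temperature=0,
    max_tokens=100,
)

# Legacy form still works: validate_environment() now promotes these
# model_kwargs entries into the top-level fields (see the bedrock.py diff below).
legacy = ChatBedrock(
    model_id="anthropic.claude-3-sonnet-20240229-v1:0",
    model_kwargs={"temperature": 0, "max_tokens": 100},
)
```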
8 changes: 6 additions & 2 deletions libs/aws/langchain_aws/chat_models/bedrock.py
@@ -431,9 +431,9 @@ def _get_ls_params(
             ls_model_name=self.model_id,
             ls_model_type="chat",
         )
-        if ls_temperature := params.get("temperature"):
+        if ls_temperature := params.get("temperature", self.temperature):
             ls_params["ls_temperature"] = ls_temperature
-        if ls_max_tokens := params.get("max_tokens"):
+        if ls_max_tokens := params.get("max_tokens", self.max_tokens):
             ls_params["ls_max_tokens"] = ls_max_tokens
         if ls_stop := stop or params.get("stop", None):
             ls_params["ls_stop"] = ls_stop
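
The fallback to `self.temperature` / `self.max_tokens` means LangSmith tracing metadata now reflects the constructor defaults even when an individual call does not override them.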
@@ -818,6 +818,10 @@ def _as_converse(self) -> ChatBedrockConverse:
             for k, v in (self.model_kwargs or {}).items()
             if k in ("stop", "stop_sequences", "max_tokens", "temperature", "top_p")
         }
+        if self.max_tokens:
+            kwargs["max_tokens"] = self.max_tokens
+        if self.temperature is not None:
+            kwargs["temperature"] = self.temperature
         return ChatBedrockConverse(
             model=self.model_id,
             region_name=self.region_name,
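
This hunk forwards the new top-level fields into the `ChatBedrockConverse` delegate, which previously only picked them out of `model_kwargs`. A hedged sketch of the effect; `beta_use_converse_api` is the flag name as I understand the Converse code path, so treat it as an assumption:

```python
llm = ChatBedrock(
    model_id="anthropic.claude-3-sonnet-20240229-v1:0",
    temperature=0.5,
    max_tokens=256,
    beta_use_converse_api=True,  # assumed flag that routes through _as_converse
)
# The ChatBedrockConverse built by _as_converse now receives
# max_tokens=256 and temperature=0.5 from the top-level fields.
```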
95 changes: 75 additions & 20 deletions libs/aws/langchain_aws/llms/bedrock.py
@@ -263,6 +263,9 @@ def prepare_input(
         system: Optional[str] = None,
         messages: Optional[List[Dict]] = None,
         tools: Optional[List[AnthropicTool]] = None,
+        *,
+        max_tokens: Optional[int] = None,
+        temperature: Optional[float] = None,
     ) -> Dict[str, Any]:
         input_body = {**model_kwargs}
         if provider == "anthropic":
@@ -273,18 +276,44 @@ def prepare_input(
                 input_body["messages"] = messages
                 if system:
                     input_body["system"] = system
-                if "max_tokens" not in input_body:
+                if max_tokens:
+                    input_body["max_tokens"] = max_tokens
+                elif "max_tokens" not in input_body:
                     input_body["max_tokens"] = 1024
+
             if prompt:
                 input_body["prompt"] = _human_assistant_format(prompt)
-                if "max_tokens_to_sample" not in input_body:
+                if max_tokens:
+                    input_body["max_tokens_to_sample"] = max_tokens
+                elif "max_tokens_to_sample" not in input_body:
                     input_body["max_tokens_to_sample"] = 1024
+
+            if temperature is not None:
+                input_body["temperature"] = temperature
+
         elif provider in ("ai21", "cohere", "meta", "mistral"):
             input_body["prompt"] = prompt
+            if max_tokens:
+                if provider == "cohere":
+                    input_body["max_tokens"] = max_tokens
+                elif provider == "meta":
+                    input_body["max_gen_len"] = max_tokens
+                elif provider == "mistral":
+                    input_body["max_tokens"] = max_tokens
+                else:
+                    # TODO: Add AI21 support, param depends on specific model.
+                    pass
+            if temperature is not None:
+                input_body["temperature"] = temperature
+
         elif provider == "amazon":
             input_body = dict()
             input_body["inputText"] = prompt
             input_body["textGenerationConfig"] = {**model_kwargs}
+            if max_tokens:
+                input_body["textGenerationConfig"]["maxTokenCount"] = max_tokens
+            if temperature is not None:
+                input_body["textGenerationConfig"]["temperature"] = temperature
         else:
             input_body["inputText"] = prompt
 
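The branch above encodes a per-provider spelling of the max-token parameter. A standalone sketch of just that mapping, distilled from the diff (the helper name is illustrative):

```python
from typing import Optional

def _max_tokens_key(provider: str) -> Optional[str]:
    # Distilled from prepare_input above. AI21 is deliberately unmapped
    # (TODO in the diff: the param depends on the specific model), and the
    # Anthropic prompt-style path uses "max_tokens_to_sample" instead.
    return {
        "anthropic": "max_tokens",
        "cohere": "max_tokens",
        "meta": "max_gen_len",
        "mistral": "max_tokens",
        "amazon": "maxTokenCount",  # nested under textGenerationConfig
    }.get(provider)
```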
@@ -451,7 +480,7 @@ class BedrockBase(BaseLanguageModel, ABC):
 
     client: Any = Field(default=None, exclude=True)  #: :meta private:
 
-    region_name: Optional[str] = None
+    region_name: Optional[str] = Field(default=None, alias="region")
     """The aws region, e.g., `us-west-2`. Falls back to the AWS_DEFAULT_REGION env
     variable or the region specified in ~/.aws/config in case it is not provided here.
     """
@@ -589,6 +618,9 @@ async def on_llm_error(
         ...Logic to handle guardrail intervention...
     """ # noqa: E501
 
+    temperature: Optional[float] = None
+    max_tokens: Optional[int] = None
+
     @property
     def lc_secrets(self) -> Dict[str, str]:
         return {
@@ -601,6 +633,17 @@ def lc_secrets(self) -> Dict[str, str]:
     def validate_environment(self) -> Self:
         """Validate that AWS credentials and the python package exist in the environment."""
 
+        if self.model_kwargs:
+            if "temperature" in self.model_kwargs:
+                if self.temperature is None:
+                    self.temperature = self.model_kwargs["temperature"]
+                self.model_kwargs.pop("temperature")
+
+            if "max_tokens" in self.model_kwargs:
+                if not self.max_tokens:
+                    self.max_tokens = self.model_kwargs["max_tokens"]
+                self.model_kwargs.pop("max_tokens")
+
         # Skip creating new client if passed in constructor
         if self.client is not None:
             return self
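
The net behavior: old-style `model_kwargs` values are promoted into the new fields, and explicit top-level values win when both are supplied. A sketch of the expected outcome, assuming boto3 and region configuration are available so the validator can run:

```python
llm = ChatBedrock(
    model_id="anthropic.claude-3-sonnet-20240229-v1:0",
    temperature=0.7,  # explicit field: kept, not overwritten
    model_kwargs={"temperature": 0.2, "max_tokens": 50},
)
assert llm.temperature == 0.7  # top-level value wins
assert llm.max_tokens == 50    # promoted out of model_kwargs
assert "max_tokens" not in (llm.model_kwargs or {})  # popped by the validator
```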
@@ -737,23 +780,27 @@ def _prepare_input_and_invoke(
 
         provider = self._get_provider()
         params = {**_model_kwargs, **kwargs}
-        input_body = LLMInputOutputAdapter.prepare_input(
-            provider=provider,
-            model_kwargs=params,
-            prompt=prompt,
-            system=system,
-            messages=messages,
-        )
-        if "claude-3" in self._get_model():
-            if _tools_in_params(params):
-                input_body = LLMInputOutputAdapter.prepare_input(
-                    provider=provider,
-                    model_kwargs=params,
-                    prompt=prompt,
-                    system=system,
-                    messages=messages,
-                    tools=params["tools"],
-                )
+        if "claude-3" in self._get_model() and _tools_in_params(params):
+            input_body = LLMInputOutputAdapter.prepare_input(
+                provider=provider,
+                model_kwargs=params,
+                prompt=prompt,
+                system=system,
+                messages=messages,
+                tools=params["tools"],
+                max_tokens=self.max_tokens,
+                temperature=self.temperature,
+            )
+        else:
+            input_body = LLMInputOutputAdapter.prepare_input(
+                provider=provider,
+                model_kwargs=params,
+                prompt=prompt,
+                system=system,
+                messages=messages,
+                max_tokens=self.max_tokens,
+                temperature=self.temperature,
+            )
         body = json.dumps(input_body)
         accept = "application/json"
         contentType = "application/json"
@@ -874,6 +921,8 @@ def _prepare_input_and_invoke_stream(
             system=system,
             messages=messages,
             model_kwargs=params,
+            max_tokens=self.max_tokens,
+            temperature=self.temperature,
         )
         coerce_content_to_string = True
         if "claude-3" in self._get_model():
@@ -886,6 +935,8 @@
                     system=system,
                     messages=messages,
                     tools=params["tools"],
+                    max_tokens=self.max_tokens,
+                    temperature=self.temperature,
                 )
         body = json.dumps(input_body)
 
@@ -959,6 +1010,8 @@ async def _aprepare_input_and_invoke_stream(
                 system=system,
                 messages=messages,
                 tools=params["tools"],
+                max_tokens=self.max_tokens,
+                temperature=self.temperature,
             )
         else:
             input_body = LLMInputOutputAdapter.prepare_input(
@@ -967,6 +1020,8 @@
                 system=system,
                 messages=messages,
                 model_kwargs=params,
+                max_tokens=self.max_tokens,
+                temperature=self.temperature,
             )
         body = json.dumps(input_body)
 
2 changes: 1 addition & 1 deletion libs/aws/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "langchain-aws"
-version = "0.2.1"
+version = "0.2.2"
 description = "An integration package connecting AWS and LangChain"
 authors = []
 readme = "README.md"
8 changes: 4 additions & 4 deletions libs/aws/tests/integration_tests/chat_models/test_standard.py
@@ -20,7 +20,7 @@ def chat_model_params(self) -> dict:
 
     @property
     def standard_chat_model_params(self) -> dict:
-        return {}
+        return {"temperature": 0, "max_tokens": 100}
 
     @pytest.mark.xfail(reason="Not implemented.")
     def test_stop_sequence(self, model: BaseChatModel) -> None:
@@ -57,11 +57,11 @@ def chat_model_params(self) -> dict:
     @property
     def standard_chat_model_params(self) -> dict:
         return {
+            "temperature": 0,
+            "max_tokens": 100,
             "model_kwargs": {
-                "temperature": 0,
-                "max_tokens": 100,
                 "stop": [],
-            }
+            },
         }
 
     @property
4 changes: 2 additions & 2 deletions libs/aws/tests/unit_tests/__snapshots__/test_standard.ambr
@@ -14,12 +14,11 @@
         'guardrailVersion': None,
         'trace': None,
       }),
+      'max_tokens': 100,
       'model_id': 'anthropic.claude-3-sonnet-20240229-v1:0',
       'model_kwargs': dict({
-        'max_tokens': 100,
         'stop': list([
         ]),
-        'temperature': 0,
       }),
       'provider_stop_reason_key_map': dict({
         'ai21': 'finishReason',
@@ -36,6 +35,7 @@
         'mistral': 'stop_sequences',
       }),
       'region_name': 'us-east-1',
+      'temperature': 0,
     }),
     'lc': 1,
     'name': 'ChatBedrock',
