aws[patch]: support ChatBedrock(max_tokens, temperature) #226

Merged
merged 2 commits on Oct 4, 2024
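This PR makes `max_tokens` and `temperature` first-class constructor arguments on `ChatBedrock` (and `BedrockBase`), instead of requiring them inside `model_kwargs`. A minimal usage sketch, using the model id and region that appear in the tests below; it only illustrates the surface described by the diff:

```python
from langchain_aws import ChatBedrock

# New style introduced by this PR: top-level fields.
llm = ChatBedrock(
    model_id="anthropic.claude-3-sonnet-20240229-v1:0",
    region_name="us-east-1",
    max_tokens=100,
    temperature=0,
)

# Pre-existing style: per the validate_environment change below, these keys are
# popped out of model_kwargs and promoted to the top-level fields, so the two
# spellings should end up equivalent.
llm_via_kwargs = ChatBedrock(
    model_id="anthropic.claude-3-sonnet-20240229-v1:0",
    region_name="us-east-1",
    model_kwargs={"max_tokens": 100, "temperature": 0},
)
```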
8 changes: 6 additions & 2 deletions libs/aws/langchain_aws/chat_models/bedrock.py
@@ -431,9 +431,9 @@ def _get_ls_params(
ls_model_name=self.model_id,
ls_model_type="chat",
)
if ls_temperature := params.get("temperature"):
if ls_temperature := params.get("temperature", self.temperature):
ls_params["ls_temperature"] = ls_temperature
if ls_max_tokens := params.get("max_tokens"):
if ls_max_tokens := params.get("max_tokens", self.max_tokens):
ls_params["ls_max_tokens"] = ls_max_tokens
if ls_stop := stop or params.get("stop", None):
ls_params["ls_stop"] = ls_stop
@@ -818,6 +818,10 @@ def _as_converse(self) -> ChatBedrockConverse:
for k, v in (self.model_kwargs or {}).items()
if k in ("stop", "stop_sequences", "max_tokens", "temperature", "top_p")
}
if self.max_tokens:
kwargs["max_tokens"] = self.max_tokens
if self.temperature is not None:
kwargs["temperature"] = self.temperature
return ChatBedrockConverse(
model=self.model_id,
region_name=self.region_name,
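With the `_get_ls_params` hunk above, LangSmith tracing metadata now falls back to the instance-level fields when a call does not override them, and `_as_converse` forwards the same fields when delegating to `ChatBedrockConverse`. A standalone sketch of that fallback logic (hypothetical helper name, not the actual method):

```python
from typing import Any, Dict, Optional

def build_ls_params(
    params: Dict[str, Any],
    instance_temperature: Optional[float] = None,
    instance_max_tokens: Optional[int] = None,
) -> Dict[str, Any]:
    """Mimic the fallback added above: call-time params win, else instance fields."""
    ls_params: Dict[str, Any] = {}
    if ls_temperature := params.get("temperature", instance_temperature):
        ls_params["ls_temperature"] = ls_temperature
    if ls_max_tokens := params.get("max_tokens", instance_max_tokens):
        ls_params["ls_max_tokens"] = ls_max_tokens
    return ls_params

# No call-time overrides: the constructor values flow into the trace metadata.
assert build_ls_params({}, instance_temperature=0.5, instance_max_tokens=256) == {
    "ls_temperature": 0.5,
    "ls_max_tokens": 256,
}
```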
95 changes: 75 additions & 20 deletions libs/aws/langchain_aws/llms/bedrock.py
@@ -263,6 +263,9 @@ def prepare_input(
system: Optional[str] = None,
messages: Optional[List[Dict]] = None,
tools: Optional[List[AnthropicTool]] = None,
*,
max_tokens: Optional[int] = None,
temperature: Optional[float] = None,
) -> Dict[str, Any]:
input_body = {**model_kwargs}
if provider == "anthropic":
@@ -273,18 +276,44 @@
input_body["messages"] = messages
if system:
input_body["system"] = system
if "max_tokens" not in input_body:
if max_tokens:
input_body["max_tokens"] = max_tokens
elif "max_tokens" not in input_body:
input_body["max_tokens"] = 1024

if prompt:
input_body["prompt"] = _human_assistant_format(prompt)
if "max_tokens_to_sample" not in input_body:
if max_tokens:
input_body["max_tokens_to_sample"] = max_tokens
elif "max_tokens_to_sample" not in input_body:
input_body["max_tokens_to_sample"] = 1024

if temperature is not None:
input_body["temperature"] = temperature

elif provider in ("ai21", "cohere", "meta", "mistral"):
input_body["prompt"] = prompt
if max_tokens:
if provider == "cohere":
input_body["max_tokens"] = max_tokens
elif provider == "meta":
input_body["max_gen_len"] = max_tokens
elif provider == "mistral":
input_body["max_tokens"] = max_tokens
else:
# TODO: Add AI21 support, param depends on specific model.
pass
if temperature is not None:
input_body["temperature"] = temperature

elif provider == "amazon":
input_body = dict()
input_body["inputText"] = prompt
input_body["textGenerationConfig"] = {**model_kwargs}
if max_tokens:
input_body["textGenerationConfig"]["maxTokenCount"] = max_tokens
if temperature is not None:
input_body["textGenerationConfig"]["temperature"] = temperature
else:
input_body["inputText"] = prompt

@@ -453,7 +482,7 @@ class BedrockBase(BaseLanguageModel, ABC):

client: Any = Field(default=None, exclude=True) #: :meta private:

region_name: Optional[str] = None
region_name: Optional[str] = Field(default=None, alias="region")
"""The aws region e.g., `us-west-2`. Fallsback to AWS_DEFAULT_REGION env variable
or region specified in ~/.aws/config in case it is not provided here.
"""
@@ -591,6 +620,9 @@ async def on_llm_error(
...Logic to handle guardrail intervention...
""" # noqa: E501

temperature: Optional[float] = None
max_tokens: Optional[int] = None

@property
def lc_secrets(self) -> Dict[str, str]:
return {
@@ -603,6 +635,17 @@ def lc_secrets(self) -> Dict[str, str]:
def validate_environment(self) -> Self:
"""Validate that AWS credentials to and python package exists in environment."""

if self.model_kwargs:
if "temperature" in self.model_kwargs:
if self.temperature is None:
self.temperature = self.model_kwargs["temperature"]
self.model_kwargs.pop("temperature")

if "max_tokens" in self.model_kwargs:
if not self.max_tokens:
self.max_tokens = self.model_kwargs["max_tokens"]
self.model_kwargs.pop("max_tokens")

# Skip creating new client if passed in constructor
if self.client is not None:
return self
@@ -739,23 +782,27 @@ def _prepare_input_and_invoke(

provider = self._get_provider()
params = {**_model_kwargs, **kwargs}
input_body = LLMInputOutputAdapter.prepare_input(
provider=provider,
model_kwargs=params,
prompt=prompt,
system=system,
messages=messages,
)
if "claude-3" in self._get_model():
if _tools_in_params(params):
input_body = LLMInputOutputAdapter.prepare_input(
provider=provider,
model_kwargs=params,
prompt=prompt,
system=system,
messages=messages,
tools=params["tools"],
)
if "claude-3" in self._get_model() and _tools_in_params(params):
input_body = LLMInputOutputAdapter.prepare_input(
provider=provider,
model_kwargs=params,
prompt=prompt,
system=system,
messages=messages,
tools=params["tools"],
max_tokens=self.max_tokens,
temperature=self.temperature,
)
else:
input_body = LLMInputOutputAdapter.prepare_input(
provider=provider,
model_kwargs=params,
prompt=prompt,
system=system,
messages=messages,
max_tokens=self.max_tokens,
temperature=self.temperature,
)
body = json.dumps(input_body)
accept = "application/json"
contentType = "application/json"
@@ -876,6 +923,8 @@ def _prepare_input_and_invoke_stream(
system=system,
messages=messages,
model_kwargs=params,
max_tokens=self.max_tokens,
temperature=self.temperature,
)
coerce_content_to_string = True
if "claude-3" in self._get_model():
@@ -888,6 +937,8 @@
system=system,
messages=messages,
tools=params["tools"],
max_tokens=self.max_tokens,
temperature=self.temperature,
)
body = json.dumps(input_body)

@@ -961,6 +1012,8 @@ async def _aprepare_input_and_invoke_stream(
system=system,
messages=messages,
tools=params["tools"],
max_tokens=self.max_tokens,
temperature=self.temperature,
)
else:
input_body = LLMInputOutputAdapter.prepare_input(
@@ -969,6 +1022,8 @@
system=system,
messages=messages,
model_kwargs=params,
max_tokens=self.max_tokens,
temperature=self.temperature,
)
body = json.dumps(input_body)

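The `prepare_input` changes route the new keyword-only `max_tokens`/`temperature` arguments into the provider-specific request body. A rough sketch of the expected shapes, based only on the diff above (illustrative values; the Anthropic body will also contain the provider's other required fields, and the call is not verified here):

```python
from langchain_aws.llms.bedrock import LLMInputOutputAdapter

# Anthropic messages path: per the diff, max_tokens and temperature are placed
# at the top level of the body when not already present in model_kwargs.
anthropic_body = LLMInputOutputAdapter.prepare_input(
    provider="anthropic",
    model_kwargs={},
    messages=[{"role": "user", "content": "Hello"}],
    max_tokens=256,
    temperature=0.1,
)
# expected to include: {"messages": [...], "max_tokens": 256, "temperature": 0.1, ...}

# Amazon Titan path: the same values are nested under textGenerationConfig.
amazon_body = LLMInputOutputAdapter.prepare_input(
    provider="amazon",
    model_kwargs={},
    prompt="Hello",
    max_tokens=256,
    temperature=0.1,
)
# expected: {"inputText": "Hello",
#            "textGenerationConfig": {"maxTokenCount": 256, "temperature": 0.1}}
```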
2 changes: 1 addition & 1 deletion libs/aws/pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "langchain-aws"
version = "0.2.1"
version = "0.2.2"
description = "An integration package connecting AWS and LangChain"
authors = []
readme = "README.md"
8 changes: 4 additions & 4 deletions libs/aws/tests/integration_tests/chat_models/test_standard.py
@@ -20,7 +20,7 @@ def chat_model_params(self) -> dict:

@property
def standard_chat_model_params(self) -> dict:
return {}
return {"temperature": 0, "max_tokens": 100}

@pytest.mark.xfail(reason="Not implemented.")
def test_stop_sequence(self, model: BaseChatModel) -> None:
@@ -57,11 +57,11 @@ def chat_model_params(self) -> dict:
@property
def standard_chat_model_params(self) -> dict:
return {
"temperature": 0,
"max_tokens": 100,
"model_kwargs": {
"temperature": 0,
"max_tokens": 100,
"stop": [],
}
},
}

@property
4 changes: 2 additions & 2 deletions libs/aws/tests/unit_tests/__snapshots__/test_standard.ambr
@@ -14,12 +14,11 @@
'guardrailVersion': None,
'trace': None,
}),
'max_tokens': 100,
'model_id': 'anthropic.claude-3-sonnet-20240229-v1:0',
'model_kwargs': dict({
'max_tokens': 100,
'stop': list([
]),
'temperature': 0,
}),
'provider_stop_reason_key_map': dict({
'ai21': 'finishReason',
@@ -36,6 +35,7 @@
'mistral': 'stop_sequences',
}),
'region_name': 'us-east-1',
'temperature': 0,
}),
'lc': 1,
'name': 'ChatBedrock',