Skip to content

Commit

Permalink
community[patch]: Update root_validators ChatModels: ChatBaichuan, Qi…
Browse files Browse the repository at this point in the history
…anfanChatEndpoint, MiniMaxChat, ChatSparkLLM, ChatZhipuAI (#22853)

This PR updates root validators for:

- ChatModels: ChatBaichuan, QianfanChatEndpoint, MiniMaxChat,
ChatSparkLLM, ChatZhipuAI

Issues #22819

---------

Co-authored-by: Eugene Yurtsev <[email protected]>
  • Loading branch information
maang-h and eyurtsev authored Jun 20, 2024
1 parent cb6cf4b commit bc4cd9c
Show file tree
Hide file tree
Showing 7 changed files with 49 additions and 32 deletions.
7 changes: 3 additions & 4 deletions libs/community/langchain_community/chat_models/baichuan.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def lc_serializable(self) -> bool:

baichuan_api_base: str = Field(default=DEFAULT_API_BASE)
"""Baichuan custom endpoints"""
baichuan_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
baichuan_api_key: SecretStr = Field(alias="api_key")
"""Baichuan API Key"""
baichuan_secret_key: Optional[SecretStr] = None
"""[DEPRECATED, keeping it for for backward compatibility] Baichuan Secret Key"""
Expand Down Expand Up @@ -142,7 +142,7 @@ def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
values["model_kwargs"] = extra
return values

@root_validator()
@root_validator(pre=True)
def validate_environment(cls, values: Dict) -> Dict:
values["baichuan_api_base"] = get_from_dict_or_env(
values,
Expand All @@ -153,11 +153,10 @@ def validate_environment(cls, values: Dict) -> Dict:
values["baichuan_api_key"] = convert_to_secret_str(
get_from_dict_or_env(
values,
"baichuan_api_key",
["baichuan_api_key", "api_key"],
"BAICHUAN_API_KEY",
)
)

return values

@property
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ class QianfanChatEndpoint(BaseChatModel):

client: Any #: :meta private:

qianfan_ak: Optional[SecretStr] = Field(default=None, alias="api_key")
qianfan_ak: SecretStr = Field(alias="api_key")
"""Qianfan API KEY"""
qianfan_sk: Optional[SecretStr] = Field(default=None, alias="secret_key")
"""Qianfan SECRET KEY"""
Expand Down Expand Up @@ -171,35 +171,43 @@ class Config:

allow_population_by_field_name = True

@root_validator()
@root_validator(pre=True)
def validate_environment(cls, values: Dict) -> Dict:
values["qianfan_ak"] = convert_to_secret_str(
get_from_dict_or_env(
values,
"qianfan_ak",
["qianfan_ak", "api_key"],
"QIANFAN_AK",
default="",
)
)
values["qianfan_sk"] = convert_to_secret_str(
get_from_dict_or_env(
values,
"qianfan_sk",
["qianfan_sk", "secret_key"],
"QIANFAN_SK",
default="",
)
)

default_values = {
name: field.default
for name, field in cls.__fields__.items()
if field.default is not None
}
default_values.update(values)
params = {
**values.get("init_kwargs", {}),
"model": values["model"],
"stream": values["streaming"],
"model": default_values.get("model"),
"stream": default_values.get("streaming"),
}
if values["qianfan_ak"].get_secret_value() != "":
params["ak"] = values["qianfan_ak"].get_secret_value()
if values["qianfan_sk"].get_secret_value() != "":
params["sk"] = values["qianfan_sk"].get_secret_value()
if values["endpoint"] is not None and values["endpoint"] != "":
params["endpoint"] = values["endpoint"]
if (
default_values.get("endpoint") is not None
and default_values["endpoint"] != ""
):
params["endpoint"] = default_values["endpoint"]
try:
import qianfan

Expand Down
12 changes: 8 additions & 4 deletions libs/community/langchain_community/chat_models/minimax.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def _default_params(self) -> Dict[str, Any]:
)
minimax_group_id: Optional[str] = Field(default=None, alias="group_id")
"""[DEPRECATED, keeping it for for backward compatibility] Group Id"""
minimax_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
minimax_api_key: SecretStr = Field(alias="api_key")
"""Minimax API Key"""
streaming: bool = False
"""Whether to stream the results or not."""
Expand All @@ -176,14 +176,18 @@ class Config:

allow_population_by_field_name = True

@root_validator(allow_reuse=True)
@root_validator(pre=True, allow_reuse=True)
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
values["minimax_api_key"] = convert_to_secret_str(
get_from_dict_or_env(values, "minimax_api_key", "MINIMAX_API_KEY")
get_from_dict_or_env(
values,
["minimax_api_key", "api_key"],
"MINIMAX_API_KEY",
)
)
values["minimax_group_id"] = get_from_dict_or_env(
values, "minimax_group_id", "MINIMAX_GROUP_ID"
values, ["minimax_group_id", "group_id"], "MINIMAX_GROUP_ID"
)
# Get custom api url from environment.
values["minimax_api_host"] = get_from_dict_or_env(
Expand Down
18 changes: 12 additions & 6 deletions libs/community/langchain_community/chat_models/sparkllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,21 +195,21 @@ def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:

return values

@root_validator()
@root_validator(pre=True)
def validate_environment(cls, values: Dict) -> Dict:
values["spark_app_id"] = get_from_dict_or_env(
values,
"spark_app_id",
["spark_app_id", "app_id"],
"IFLYTEK_SPARK_APP_ID",
)
values["spark_api_key"] = get_from_dict_or_env(
values,
"spark_api_key",
["spark_api_key", "api_key"],
"IFLYTEK_SPARK_API_KEY",
)
values["spark_api_secret"] = get_from_dict_or_env(
values,
"spark_api_secret",
["spark_api_secret", "api_secret"],
"IFLYTEK_SPARK_API_SECRET",
)
values["spark_api_url"] = get_from_dict_or_env(
Expand All @@ -224,9 +224,15 @@ def validate_environment(cls, values: Dict) -> Dict:
"IFLYTEK_SPARK_LLM_DOMAIN",
SPARK_LLM_DOMAIN,
)

# put extra params into model_kwargs
values["model_kwargs"]["temperature"] = values["temperature"] or cls.temperature
values["model_kwargs"]["top_k"] = values["top_k"] or cls.top_k
default_values = {
name: field.default
for name, field in cls.__fields__.items()
if field.default is not None
}
values["model_kwargs"]["temperature"] = default_values.get("temperature")
values["model_kwargs"]["top_k"] = default_values.get("top_k")

values["client"] = _SparkLLMClient(
app_id=values["spark_app_id"],
Expand Down
4 changes: 2 additions & 2 deletions libs/community/langchain_community/chat_models/zhipuai.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,10 +377,10 @@ class Config:

allow_population_by_field_name = True

@root_validator()
@root_validator(pre=True)
def validate_environment(cls, values: Dict[str, Any]) -> Dict[str, Any]:
values["zhipuai_api_key"] = get_from_dict_or_env(
values, "zhipuai_api_key", "ZHIPUAI_API_KEY"
values, ["zhipuai_api_key", "api_key"], "ZHIPUAI_API_KEY"
)
values["zhipuai_api_base"] = get_from_dict_or_env(
values, "zhipuai_api_base", "ZHIPUAI_API_BASE", default=ZHIPUAI_API_BASE
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,15 @@


def test_chat_baichuan_default() -> None:
chat = ChatBaichuan(streaming=True)
chat = ChatBaichuan(streaming=True) # type: ignore[call-arg]
message = HumanMessage(content="请完整背诵将进酒,背诵5遍")
response = chat.invoke([message])
assert isinstance(response, AIMessage)
assert isinstance(response.content, str)


def test_chat_baichuan_default_non_streaming() -> None:
chat = ChatBaichuan()
chat = ChatBaichuan() # type: ignore[call-arg]
message = HumanMessage(content="请完整背诵将进酒,背诵5遍")
response = chat.invoke([message])
assert isinstance(response, AIMessage)
Expand All @@ -39,15 +39,15 @@ def test_chat_baichuan_turbo_non_streaming() -> None:


def test_chat_baichuan_with_temperature() -> None:
chat = ChatBaichuan(temperature=1.0)
chat = ChatBaichuan(temperature=1.0) # type: ignore[call-arg]
message = HumanMessage(content="Hello")
response = chat.invoke([message])
assert isinstance(response, AIMessage)
assert isinstance(response.content, str)


def test_chat_baichuan_with_kwargs() -> None:
chat = ChatBaichuan()
chat = ChatBaichuan() # type: ignore[call-arg]
message = HumanMessage(content="百川192K API是什么时候上线的?")
response = chat.invoke(
[message], temperature=0.88, top_p=0.7, with_search_enhance=True
Expand All @@ -58,7 +58,7 @@ def test_chat_baichuan_with_kwargs() -> None:


def test_extra_kwargs() -> None:
chat = ChatBaichuan(temperature=0.88, top_p=0.7, with_search_enhance=True)
chat = ChatBaichuan(temperature=0.88, top_p=0.7, with_search_enhance=True) # type: ignore[call-arg]
assert chat.temperature == 0.88
assert chat.top_p == 0.7
assert chat.with_search_enhance is True
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def test_baichuan_key_masked_when_passed_from_env(
"""Test initialization with an API key provided via an env variable"""
monkeypatch.setenv("BAICHUAN_API_KEY", "test-api-key")

chat = ChatBaichuan()
chat = ChatBaichuan() # type: ignore[call-arg]
print(chat.baichuan_api_key, end="") # noqa: T201
captured = capsys.readouterr()
assert captured.out == "**********"
Expand Down

0 comments on commit bc4cd9c

Please sign in to comment.