Skip to content

Commit

Permalink
chore: Extract common functions of the base model in Azure OpenAI Pro…
Browse files Browse the repository at this point in the history
…vider (#9907)
  • Loading branch information
yaoice authored Oct 27, 2024
1 parent 216442d commit 22776f2
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ model_credential_schema:
type: select
required: true
options:
- label:
en_US: 2024-10-01-preview
value: 2024-10-01-preview
- label:
en_US: 2024-09-01-preview
value: 2024-09-01-preview
Expand Down
27 changes: 10 additions & 17 deletions api/core/model_runtime/model_providers/azure_openai/llm/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,7 @@ def _invoke(
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
base_model_name = credentials.get("base_model_name")
if not base_model_name:
raise ValueError("Base Model Name is required")
base_model_name = self._get_base_model_name(credentials)
ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)

if ai_model_entity and ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value:
Expand Down Expand Up @@ -81,9 +79,7 @@ def get_num_tokens(
prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None,
) -> int:
base_model_name = credentials.get("base_model_name")
if not base_model_name:
raise ValueError("Base Model Name is required")
base_model_name = self._get_base_model_name(credentials)
model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)
if not model_entity:
raise ValueError(f"Base Model Name {base_model_name} is invalid")
Expand All @@ -108,9 +104,7 @@ def validate_credentials(self, model: str, credentials: dict) -> None:
if "base_model_name" not in credentials:
raise CredentialsValidateFailedError("Base Model Name is required")

base_model_name = credentials.get("base_model_name")
if not base_model_name:
raise CredentialsValidateFailedError("Base Model Name is required")
base_model_name = self._get_base_model_name(credentials)
ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)

if not ai_model_entity:
Expand Down Expand Up @@ -149,9 +143,7 @@ def validate_credentials(self, model: str, credentials: dict) -> None:
raise CredentialsValidateFailedError(str(ex))

def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
base_model_name = credentials.get("base_model_name")
if not base_model_name:
raise ValueError("Base Model Name is required")
base_model_name = self._get_base_model_name(credentials)
ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)
return ai_model_entity.entity if ai_model_entity else None

Expand Down Expand Up @@ -308,11 +300,6 @@ def _chat_generate(

if tools:
extra_model_kwargs["tools"] = [helper.dump_model(PromptMessageFunction(function=tool)) for tool in tools]
# extra_model_kwargs['functions'] = [{
# "name": tool.name,
# "description": tool.description,
# "parameters": tool.parameters
# } for tool in tools]

if stop:
extra_model_kwargs["stop"] = stop
Expand Down Expand Up @@ -769,3 +756,9 @@ def _get_ai_model_entity(base_model_name: str, model: str):
ai_model_entity_copy.entity.label.en_US = model
ai_model_entity_copy.entity.label.zh_Hans = model
return ai_model_entity_copy

def _get_base_model_name(self, credentials: dict) -> str:
base_model_name = credentials.get("base_model_name")
if not base_model_name:
raise ValueError("Base Model Name is required")
return base_model_name

0 comments on commit 22776f2

Please sign in to comment.