diff --git a/api/core/model_runtime/model_providers/cohere/llm/llm.py b/api/core/model_runtime/model_providers/cohere/llm/llm.py
index 667ba4c78c8fd..50805bce85b60 100644
--- a/api/core/model_runtime/model_providers/cohere/llm/llm.py
+++ b/api/core/model_runtime/model_providers/cohere/llm/llm.py
@@ -472,7 +472,7 @@ def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
         else:
             raise ValueError(f"Got unknown type {message}")
 
-        if message.name is not None:
+        if message.name:
             message_dict["user_name"] = message.name
 
         return message_dict
diff --git a/api/core/model_runtime/model_providers/mistralai/llm/mistral-large-latest.yaml b/api/core/model_runtime/model_providers/mistralai/llm/mistral-large-latest.yaml
index b729012c40788..b8ed8ba934bcc 100644
--- a/api/core/model_runtime/model_providers/mistralai/llm/mistral-large-latest.yaml
+++ b/api/core/model_runtime/model_providers/mistralai/llm/mistral-large-latest.yaml
@@ -24,7 +24,7 @@ parameter_rules:
     min: 1
     max: 8000
   - name: safe_prompt
-    defulat: false
+    default: false
     type: boolean
     help:
       en_US: Whether to inject a safety prompt before all conversations.
diff --git a/api/core/model_runtime/model_providers/mistralai/llm/mistral-medium-latest.yaml b/api/core/model_runtime/model_providers/mistralai/llm/mistral-medium-latest.yaml
index 6e586b48437ba..bf6f1b2d1d9a0 100644
--- a/api/core/model_runtime/model_providers/mistralai/llm/mistral-medium-latest.yaml
+++ b/api/core/model_runtime/model_providers/mistralai/llm/mistral-medium-latest.yaml
@@ -24,7 +24,7 @@ parameter_rules:
     min: 1
     max: 8000
   - name: safe_prompt
-    defulat: false
+    default: false
     type: boolean
     help:
       en_US: Whether to inject a safety prompt before all conversations.
diff --git a/api/core/model_runtime/model_providers/mistralai/llm/mistral-small-latest.yaml b/api/core/model_runtime/model_providers/mistralai/llm/mistral-small-latest.yaml
index 4e7e6147f5543..111cd054576d6 100644
--- a/api/core/model_runtime/model_providers/mistralai/llm/mistral-small-latest.yaml
+++ b/api/core/model_runtime/model_providers/mistralai/llm/mistral-small-latest.yaml
@@ -24,7 +24,7 @@ parameter_rules:
     min: 1
     max: 8000
   - name: safe_prompt
-    defulat: false
+    default: false
     type: boolean
     help:
       en_US: Whether to inject a safety prompt before all conversations.
diff --git a/api/core/model_runtime/model_providers/mistralai/llm/open-mistral-7b.yaml b/api/core/model_runtime/model_providers/mistralai/llm/open-mistral-7b.yaml
index 30454f7df22ea..4f7264866241d 100644
--- a/api/core/model_runtime/model_providers/mistralai/llm/open-mistral-7b.yaml
+++ b/api/core/model_runtime/model_providers/mistralai/llm/open-mistral-7b.yaml
@@ -24,7 +24,7 @@ parameter_rules:
     min: 1
     max: 2048
   - name: safe_prompt
-    defulat: false
+    default: false
     type: boolean
     help:
       en_US: Whether to inject a safety prompt before all conversations.
diff --git a/api/core/model_runtime/model_providers/mistralai/llm/open-mixtral-8x7b.yaml b/api/core/model_runtime/model_providers/mistralai/llm/open-mixtral-8x7b.yaml
index a35cf0a9ae4d1..719de29c3a005 100644
--- a/api/core/model_runtime/model_providers/mistralai/llm/open-mixtral-8x7b.yaml
+++ b/api/core/model_runtime/model_providers/mistralai/llm/open-mixtral-8x7b.yaml
@@ -24,7 +24,7 @@ parameter_rules:
     min: 1
     max: 8000
   - name: safe_prompt
-    defulat: false
+    default: false
    type: boolean
    help:
      en_US: Whether to inject a safety prompt before all conversations.
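Note on the `message.name` change above: switching from `is not None` to a truthiness check means a present-but-empty `name` is no longer forwarded as `user_name`. A minimal standalone sketch of the behavioral difference (hypothetical values, not part of the patch):

```python
# Truthiness vs. explicit None check: only the new form skips "".
for name in (None, "", "alice"):
    old_sends_field = name is not None   # True for ""  -> sends user_name=""
    new_sends_field = bool(name)         # False for "" -> field omitted
    print(repr(name), old_sends_field, new_sends_field)
```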
diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py b/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py
index d294fcaa9c68e..8cfec0e34b2f3 100644
--- a/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py
+++ b/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py
@@ -25,6 +25,7 @@
     AIModelEntity,
     DefaultParameterName,
     FetchFrom,
+    ModelFeature,
     ModelPropertyKey,
     ModelType,
     ParameterRule,
@@ -166,11 +167,23 @@ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIMode
         """
        generate custom model entities from credentials
        """
+        support_function_call = False
+        features = []
+        function_calling_type = credentials.get('function_calling_type', 'no_call')
+        if function_calling_type == 'function_call':
+            features = [ModelFeature.TOOL_CALL]
+            support_function_call = True
+        endpoint_url = credentials["endpoint_url"]
+        # if not endpoint_url.endswith('/'):
+        #     endpoint_url += '/'
+        # if 'https://api.openai.com/v1/' == endpoint_url:
+        #     features = [ModelFeature.STREAM_TOOL_CALL]
         entity = AIModelEntity(
             model=model,
             label=I18nObject(en_US=model),
             model_type=ModelType.LLM,
             fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
+            features=features if support_function_call else [],
             model_properties={
                 ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', "4096")),
                 ModelPropertyKey.MODE: credentials.get('mode'),
@@ -194,14 +207,6 @@ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIMode
                     max=1,
                     precision=2
                 ),
-                ParameterRule(
-                    name="top_k",
-                    label=I18nObject(en_US="Top K"),
-                    type=ParameterType.INT,
-                    default=int(credentials.get('top_k', 1)),
-                    min=1,
-                    max=100
-                ),
                 ParameterRule(
                     name=DefaultParameterName.FREQUENCY_PENALTY.value,
                     label=I18nObject(en_US="Frequency Penalty"),
@@ -232,7 +237,7 @@ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIMode
                 output=Decimal(credentials.get('output_price', 0)),
                 unit=Decimal(credentials.get('unit', 0)),
                 currency=credentials.get('currency', "USD")
-            )
+            ),
         )
 
         if credentials['mode'] == 'chat':
@@ -292,14 +297,22 @@ def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptM
             raise ValueError("Unsupported completion type for model configuration.")
 
         # annotate tools with names, descriptions, etc.
+        function_calling_type = credentials.get('function_calling_type', 'no_call')
         formatted_tools = []
         if tools:
-            data["tool_choice"] = "auto"
+            if function_calling_type == 'function_call':
+                data['functions'] = [{
+                    "name": tool.name,
+                    "description": tool.description,
+                    "parameters": tool.parameters
+                } for tool in tools]
+            elif function_calling_type == 'tool_call':
+                data["tool_choice"] = "auto"
 
-            for tool in tools:
-                formatted_tools.append(helper.dump_model(PromptMessageFunction(function=tool)))
+                for tool in tools:
+                    formatted_tools.append(helper.dump_model(PromptMessageFunction(function=tool)))
 
-            data["tools"] = formatted_tools
+                data["tools"] = formatted_tools
 
         if stop:
             data["stop"] = stop
@@ -367,9 +380,9 @@ def create_final_llm_result_chunk(index: int, message: AssistantPromptMessage, f
             for chunk in response.iter_lines(decode_unicode=True, delimiter=delimiter):
                 if chunk:
-                    #ignore sse comments
+                    # ignore sse comments
                     if chunk.startswith(':'):
-                        continue
+                        continue
                     decoded_chunk = chunk.strip().lstrip('data: ').lstrip()
                     chunk_json = None
                     try:
@@ -452,10 +465,13 @@ def _handle_generate_response(self, model: str, credentials: dict, response: req
         response_content = ''
         tool_calls = None
-
+        function_calling_type = credentials.get('function_calling_type', 'no_call')
         if completion_type is LLMMode.CHAT:
             response_content = output.get('message', {})['content']
-            tool_calls = output.get('message', {}).get('tool_calls')
+            if function_calling_type == 'tool_call':
+                tool_calls = output.get('message', {}).get('tool_calls')
+            elif function_calling_type == 'function_call':
+                tool_calls = output.get('message', {}).get('function_call')
 
         elif completion_type is LLMMode.COMPLETION:
             response_content = output['text']
@@ -463,7 +479,10 @@ def _handle_generate_response(self, model: str, credentials: dict, response: req
         assistant_message = AssistantPromptMessage(content=response_content, tool_calls=[])
 
         if tool_calls:
-            assistant_message.tool_calls = self._extract_response_tool_calls(tool_calls)
+            if function_calling_type == 'tool_call':
+                assistant_message.tool_calls = self._extract_response_tool_calls(tool_calls)
+            elif function_calling_type == 'function_call':
+                assistant_message.tool_calls = [self._extract_response_function_call(tool_calls)]
 
         usage = response_json.get("usage")
         if usage:
@@ -522,33 +541,34 @@ def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
             message = cast(AssistantPromptMessage, message)
             message_dict = {"role": "assistant", "content": message.content}
             if message.tool_calls:
-                message_dict["tool_calls"] = [helper.dump_model(PromptMessageFunction(function=tool_call)) for tool_call
-                                              in
-                                              message.tool_calls]
-                # function_call = message.tool_calls[0]
-                # message_dict["function_call"] = {
-                #     "name": function_call.function.name,
-                #     "arguments": function_call.function.arguments,
-                # }
+                # message_dict["tool_calls"] = [helper.dump_model(PromptMessageFunction(function=tool_call)) for tool_call
+                #                               in
+                #                               message.tool_calls]
+
+                function_call = message.tool_calls[0]
+                message_dict["function_call"] = {
+                    "name": function_call.function.name,
+                    "arguments": function_call.function.arguments,
+                }
         elif isinstance(message, SystemPromptMessage):
             message = cast(SystemPromptMessage, message)
             message_dict = {"role": "system", "content": message.content}
         elif isinstance(message, ToolPromptMessage):
             message = cast(ToolPromptMessage, message)
-            message_dict = {
-                "role": "tool",
-                "content": message.content,
-                "tool_call_id": message.tool_call_id
-            }
             # message_dict = {
-            #     "role": "function",
"function", + # "role": "tool", # "content": message.content, - # "name": message.tool_call_id + # "tool_call_id": message.tool_call_id # } + message_dict = { + "role": "function", + "content": message.content, + "name": message.tool_call_id + } else: raise ValueError(f"Got unknown type {message}") - if message.name is not None: + if message.name: message_dict["name"] = message.name return message_dict @@ -693,3 +713,26 @@ def _extract_response_tool_calls(self, tool_calls.append(tool_call) return tool_calls + + def _extract_response_function_call(self, response_function_call) \ + -> AssistantPromptMessage.ToolCall: + """ + Extract function call from response + + :param response_function_call: response function call + :return: tool call + """ + tool_call = None + if response_function_call: + function = AssistantPromptMessage.ToolCall.ToolCallFunction( + name=response_function_call['name'], + arguments=response_function_call['arguments'] + ) + + tool_call = AssistantPromptMessage.ToolCall( + id=response_function_call['name'], + type="function", + function=function + ) + + return tool_call diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.yaml b/api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.yaml index 213d334fe89ca..be99f7684ce3a 100644 --- a/api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.yaml +++ b/api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.yaml @@ -75,6 +75,28 @@ model_credential_schema: value: llm default: '4096' type: text-input + - variable: function_calling_type + show_on: + - variable: __model_type + value: llm + label: + en_US: Function calling + type: select + required: false + default: no_call + options: + - value: function_call + label: + en_US: Support + zh_Hans: 支持 +# - value: tool_call +# label: +# en_US: Tool Call +# zh_Hans: Tool Call + - value: no_call + label: + en_US: Not Support + zh_Hans: 不支持 - variable: stream_mode_delimiter label: zh_Hans: 流模式返回结果的分隔符