Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/provider mistralai #2598

Merged
merged 5 commits into from
Feb 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions api/core/model_runtime/model_providers/_position.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
- bedrock
- togetherai
- ollama
- mistralai
- replicate
- huggingface_hub
- zhipuai
Expand Down
Empty file.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
- open-mistral-7b
- open-mixtral-8x7b
- mistral-small-latest
- mistral-medium-latest
- mistral-large-latest
31 changes: 31 additions & 0 deletions api/core/model_runtime/model_providers/mistralai/llm/llm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from collections.abc import Generator
from typing import Optional, Union

from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel


class MistralAILargeLanguageModel(OAIAPICompatLargeLanguageModel):
def _invoke(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) \
-> Union[LLMResult, Generator]:

self._add_custom_parameters(credentials)

# mistral dose not support user/stop arguments
stop = []
user = None

return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)

def validate_credentials(self, model: str, credentials: dict) -> None:
self._add_custom_parameters(credentials)
super().validate_credentials(model, credentials)

@staticmethod
def _add_custom_parameters(credentials: dict) -> None:
credentials['mode'] = 'chat'
credentials['endpoint_url'] = 'https://api.mistral.ai/v1'
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
model: mistral-large-latest
label:
zh_Hans: mistral-large-latest
en_US: mistral-large-latest
model_type: llm
features:
- agent-thought
model_properties:
context_size: 32000
parameter_rules:
- name: temperature
use_template: temperature
default: 0.7
min: 0
max: 1
- name: top_p
use_template: top_p
default: 1
min: 0
max: 1
- name: max_tokens
use_template: max_tokens
default: 1024
min: 1
max: 8000
- name: safe_prompt
defulat: false
type: boolean
help:
en_US: Whether to inject a safety prompt before all conversations.
zh_Hans: 是否开启提示词审查
label:
en_US: SafePrompt
zh_Hans: 提示词审查
- name: random_seed
type: int
help:
en_US: The seed to use for random sampling. If set, different calls will generate deterministic results.
zh_Hans: 当开启随机数种子以后,你可以通过指定一个固定的种子来使得回答结果更加稳定
label:
en_US: RandomSeed
zh_Hans: 随机数种子
default: 0
min: 0
max: 2147483647
pricing:
input: '0.008'
output: '0.024'
unit: '0.001'
currency: USD
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
model: mistral-medium-latest
label:
zh_Hans: mistral-medium-latest
en_US: mistral-medium-latest
model_type: llm
features:
- agent-thought
model_properties:
context_size: 32000
parameter_rules:
- name: temperature
use_template: temperature
default: 0.7
min: 0
max: 1
- name: top_p
use_template: top_p
default: 1
min: 0
max: 1
- name: max_tokens
use_template: max_tokens
default: 1024
min: 1
max: 8000
- name: safe_prompt
defulat: false
type: boolean
help:
en_US: Whether to inject a safety prompt before all conversations.
zh_Hans: 是否开启提示词审查
label:
en_US: SafePrompt
zh_Hans: 提示词审查
- name: random_seed
type: int
help:
en_US: The seed to use for random sampling. If set, different calls will generate deterministic results.
zh_Hans: 当开启随机数种子以后,你可以通过指定一个固定的种子来使得回答结果更加稳定
label:
en_US: RandomSeed
zh_Hans: 随机数种子
default: 0
min: 0
max: 2147483647
pricing:
input: '0.0027'
output: '0.0081'
unit: '0.001'
currency: USD
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
model: mistral-small-latest
label:
zh_Hans: mistral-small-latest
en_US: mistral-small-latest
model_type: llm
features:
- agent-thought
model_properties:
context_size: 32000
parameter_rules:
- name: temperature
use_template: temperature
default: 0.7
min: 0
max: 1
- name: top_p
use_template: top_p
default: 1
min: 0
max: 1
- name: max_tokens
use_template: max_tokens
default: 1024
min: 1
max: 8000
- name: safe_prompt
defulat: false
type: boolean
help:
en_US: Whether to inject a safety prompt before all conversations.
zh_Hans: 是否开启提示词审查
label:
en_US: SafePrompt
zh_Hans: 提示词审查
- name: random_seed
type: int
help:
en_US: The seed to use for random sampling. If set, different calls will generate deterministic results.
zh_Hans: 当开启随机数种子以后,你可以通过指定一个固定的种子来使得回答结果更加稳定
label:
en_US: RandomSeed
zh_Hans: 随机数种子
default: 0
min: 0
max: 2147483647
pricing:
input: '0.002'
output: '0.006'
unit: '0.001'
currency: USD
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
model: open-mistral-7b
label:
zh_Hans: open-mistral-7b
en_US: open-mistral-7b
model_type: llm
features:
- agent-thought
model_properties:
context_size: 8000
parameter_rules:
- name: temperature
use_template: temperature
default: 0.7
min: 0
max: 1
- name: top_p
use_template: top_p
default: 1
min: 0
max: 1
- name: max_tokens
use_template: max_tokens
default: 1024
min: 1
max: 2048
- name: safe_prompt
defulat: false
type: boolean
help:
en_US: Whether to inject a safety prompt before all conversations.
zh_Hans: 是否开启提示词审查
label:
en_US: SafePrompt
zh_Hans: 提示词审查
- name: random_seed
type: int
help:
en_US: The seed to use for random sampling. If set, different calls will generate deterministic results.
zh_Hans: 当开启随机数种子以后,你可以通过指定一个固定的种子来使得回答结果更加稳定
label:
en_US: RandomSeed
zh_Hans: 随机数种子
default: 0
min: 0
max: 2147483647
pricing:
input: '0.00025'
output: '0.00025'
unit: '0.001'
currency: USD
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
model: open-mixtral-8x7b
label:
zh_Hans: open-mixtral-8x7b
en_US: open-mixtral-8x7b
model_type: llm
features:
- agent-thought
model_properties:
context_size: 32000
parameter_rules:
- name: temperature
use_template: temperature
default: 0.7
min: 0
max: 1
- name: top_p
use_template: top_p
default: 1
min: 0
max: 1
- name: max_tokens
use_template: max_tokens
default: 1024
min: 1
max: 8000
- name: safe_prompt
defulat: false
type: boolean
help:
en_US: Whether to inject a safety prompt before all conversations.
zh_Hans: 是否开启提示词审查
label:
en_US: SafePrompt
zh_Hans: 提示词审查
- name: random_seed
type: int
help:
en_US: The seed to use for random sampling. If set, different calls will generate deterministic results.
zh_Hans: 当开启随机数种子以后,你可以通过指定一个固定的种子来使得回答结果更加稳定
label:
en_US: RandomSeed
zh_Hans: 随机数种子
default: 0
min: 0
max: 2147483647
pricing:
input: '0.0007'
output: '0.0007'
unit: '0.001'
currency: USD
30 changes: 30 additions & 0 deletions api/core/model_runtime/model_providers/mistralai/mistralai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import logging

from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.model_provider import ModelProvider

logger = logging.getLogger(__name__)


class MistralAIProvider(ModelProvider):

def validate_provider_credentials(self, credentials: dict) -> None:
"""
Validate provider credentials
if validate failed, raise exception

:param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
"""
try:
model_instance = self.get_model_instance(ModelType.LLM)

model_instance.validate_credentials(
model='open-mistral-7b',
credentials=credentials
)
except CredentialsValidateFailedError as ex:
raise ex
except Exception as ex:
logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
raise ex
31 changes: 31 additions & 0 deletions api/core/model_runtime/model_providers/mistralai/mistralai.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
provider: mistralai
label:
en_US: MistralAI
description:
en_US: Models provided by MistralAI, such as open-mistral-7b and mistral-large-latest.
zh_Hans: MistralAI 提供的模型,例如 open-mistral-7b 和 mistral-large-latest。
icon_small:
en_US: icon_s_en.png
icon_large:
en_US: icon_l_en.png
background: "#FFFFFF"
help:
title:
en_US: Get your API Key from MistralAI
zh_Hans: 从 MistralAI 获取 API Key
url:
en_US: https://console.mistral.ai/api-keys/
supported_model_types:
- llm
configurate_methods:
- predefined-model
provider_credential_schema:
credential_form_schemas:
- variable: api_key
label:
en_US: API Key
type: secret-input
required: true
placeholder:
zh_Hans: 在此输入您的 API Key
en_US: Enter your API Key
Loading