From d737afe87343b55f1822e91f14b04bfb6c70e6bf Mon Sep 17 00:00:00 2001 From: arvinxx Date: Tue, 9 Apr 2024 20:16:57 +0800 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20feat:=20refactor=20to=20support=20a?= =?UTF-8?q?zure=20openai=20provider?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/app/api/config/route.test.ts | 2 - src/app/api/config/route.ts | 15 +++-- src/app/settings/llm/Azure/index.tsx | 16 ++++- .../ProviderModelList/CustomModelOption.tsx | 14 ++++- .../ProviderModelList/ModelConfigModal.tsx | 9 ++- src/config/modelProviders/azure.ts | 43 +------------ src/config/modelProviders/index.ts | 27 ++++---- src/const/settings/index.ts | 53 ++++++++-------- src/locales/default/setting.ts | 5 +- src/migrations/FromV3ToV4/index.ts | 6 +- src/migrations/FromV3ToV4/types/v4.ts | 5 +- src/services/chat.ts | 22 +++++-- .../slices/settings/selectors/modelConfig.ts | 2 +- .../settings/selectors/modelProvider.ts | 63 ++++++++++--------- src/types/settings/modelProvider.ts | 8 +-- src/utils/parseModels.ts | 4 +- 16 files changed, 153 insertions(+), 141 deletions(-) diff --git a/src/app/api/config/route.test.ts b/src/app/api/config/route.test.ts index 12ae8ca28a31..6835ff2920b2 100644 --- a/src/app/api/config/route.test.ts +++ b/src/app/api/config/route.test.ts @@ -1,7 +1,5 @@ import { beforeEach, describe, expect, it, vi } from 'vitest'; -import { OllamaProvider, OpenRouterProvider, TogetherAIProvider } from '@/config/modelProviders'; -import { getServerConfig } from '@/config/server'; import { GlobalServerConfig } from '@/types/settings'; import { GET } from './route'; diff --git a/src/app/api/config/route.ts b/src/app/api/config/route.ts index bea31e50ab04..96ee7ed0506a 100644 --- a/src/app/api/config/route.ts +++ b/src/app/api/config/route.ts @@ -1,4 +1,8 @@ -import { OllamaProvider, OpenRouterProvider, TogetherAIProvider } from '@/config/modelProviders'; +import { + OllamaProviderCard, + OpenRouterProviderCard, + TogetherAIProviderCard, +} from '@/config/modelProviders'; import { getServerConfig } from '@/config/server'; import { GlobalServerConfig } from '@/types/settings'; import { transformToChatModelCards } from '@/utils/parseModels'; @@ -54,7 +58,10 @@ export const GET = async () => { ollama: { enabled: ENABLE_OLLAMA, - serverModelCards: transformToChatModelCards(OLLAMA_MODEL_LIST, OllamaProvider.chatModels), + serverModelCards: transformToChatModelCards( + OLLAMA_MODEL_LIST, + OllamaProviderCard.chatModels, + ), }, openai: { serverModelCards: transformToChatModelCards(OPENAI_MODEL_LIST), @@ -63,7 +70,7 @@ export const GET = async () => { enabled: ENABLED_OPENROUTER, serverModelCards: transformToChatModelCards( OPENROUTER_MODEL_LIST, - OpenRouterProvider.chatModels, + OpenRouterProviderCard.chatModels, ), }, perplexity: { enabled: ENABLED_PERPLEXITY }, @@ -72,7 +79,7 @@ export const GET = async () => { enabled: ENABLED_TOGETHERAI, serverModelCards: transformToChatModelCards( TOGETHERAI_MODEL_LIST, - TogetherAIProvider.chatModels, + TogetherAIProviderCard.chatModels, ), }, diff --git a/src/app/settings/llm/Azure/index.tsx b/src/app/settings/llm/Azure/index.tsx index 8943b6c56b66..f54d04771cfe 100644 --- a/src/app/settings/llm/Azure/index.tsx +++ b/src/app/settings/llm/Azure/index.tsx @@ -7,6 +7,8 @@ import { useTranslation } from 'react-i18next'; import { Flexbox } from 'react-layout-kit'; import { ModelProvider } from '@/libs/agent-runtime'; +import { useGlobalStore } from '@/store/global'; +import { modelConfigSelectors } from '@/store/global/selectors'; import ProviderConfig from '../components/ProviderConfig'; import { LLMProviderApiTokenKey, LLMProviderBaseUrlKey, LLMProviderConfigKey } from '../const'; @@ -30,6 +32,17 @@ const AzureOpenAIProvider = memo(() => { const { styles } = useStyles(); + // Get the first model card's deployment name as the check model + const checkModel = useGlobalStore((s) => { + const chatModelCards = modelConfigSelectors.providerModelCards(providerKey)(s); + + if (chatModelCards.length > 0) { + return chatModelCards[0].deploymentName; + } + + return 'gpt-35-turbo'; + }); + return ( { name: [LLMProviderConfigKey, providerKey, 'apiVersion'], }, ]} - checkModel={'gpt-3.5-turbo'} + checkModel={checkModel} modelList={{ azureDeployName: true, + notFoundContent: t('llm.azure.empty'), placeholder: t('llm.azure.modelListPlaceholder'), }} provider={providerKey} diff --git a/src/app/settings/llm/components/ProviderModelList/CustomModelOption.tsx b/src/app/settings/llm/components/ProviderModelList/CustomModelOption.tsx index 665cd548db78..dac605d0315a 100644 --- a/src/app/settings/llm/components/ProviderModelList/CustomModelOption.tsx +++ b/src/app/settings/llm/components/ProviderModelList/CustomModelOption.tsx @@ -1,7 +1,7 @@ -import { ActionIcon } from '@lobehub/ui'; +import { ActionIcon, Icon } from '@lobehub/ui'; import { App, Typography } from 'antd'; import isEqual from 'fast-deep-equal'; -import { LucideSettings, LucideTrash2 } from 'lucide-react'; +import { LucideArrowRight, LucideSettings, LucideTrash2 } from 'lucide-react'; import { memo } from 'react'; import { useTranslation } from 'react-i18next'; import { Flexbox } from 'react-layout-kit'; @@ -43,7 +43,15 @@ const CustomModelOption = memo(({ id, provider }) => { - {id} + + {id} + {!!modelCard?.deploymentName && ( + <> + + {modelCard?.deploymentName} + + )} + diff --git a/src/app/settings/llm/components/ProviderModelList/ModelConfigModal.tsx b/src/app/settings/llm/components/ProviderModelList/ModelConfigModal.tsx index cd63f0ae2577..8fd9bb235bd1 100644 --- a/src/app/settings/llm/components/ProviderModelList/ModelConfigModal.tsx +++ b/src/app/settings/llm/components/ProviderModelList/ModelConfigModal.tsx @@ -69,13 +69,18 @@ const ModelConfigModal = memo(({ showAzureDeployName, pro style={{ marginTop: 16 }} wrapperCol={{ offset: 1, span: 18 }} > - + {showAzureDeployName && ( { return provider.chatModels.filter((v) => v.enabled).map((m) => m.id); }; -export { default as AnthropicProvider } from './anthropic'; -export { default as BedrockProvider } from './bedrock'; -export { default as GoogleProvider } from './google'; -export { default as GroqProvider } from './groq'; -export { default as MistralProvider } from './mistral'; -export { default as MoonshotProvider } from './moonshot'; -export { default as OllamaProvider } from './ollama'; -export { default as OpenAIProvider } from './openai'; -export { default as OpenRouterProvider } from './openrouter'; -export { default as PerplexityProvider } from './perplexity'; -export { default as TogetherAIProvider } from './togetherai'; -export { default as ZeroOneProvider } from './zeroone'; -export { default as ZhiPuProvider } from './zhipu'; +export { default as AnthropicProviderCard } from './anthropic'; +export { default as AzureProviderCard } from './azure'; +export { default as BedrockProviderCard } from './bedrock'; +export { default as GoogleProviderCard } from './google'; +export { default as GroqProviderCard } from './groq'; +export { default as MistralProviderCard } from './mistral'; +export { default as MoonshotProviderCard } from './moonshot'; +export { default as OllamaProviderCard } from './ollama'; +export { default as OpenAIProviderCard } from './openai'; +export { default as OpenRouterProviderCard } from './openrouter'; +export { default as PerplexityProviderCard } from './perplexity'; +export { default as TogetherAIProviderCard } from './togetherai'; +export { default as ZeroOneProviderCard } from './zeroone'; +export { default as ZhiPuProviderCard } from './zhipu'; diff --git a/src/const/settings/index.ts b/src/const/settings/index.ts index 3a3791e99a87..2fb27b0512e0 100644 --- a/src/const/settings/index.ts +++ b/src/const/settings/index.ts @@ -1,17 +1,17 @@ import { - AnthropicProvider, - BedrockProvider, - GoogleProvider, - GroqProvider, - MistralProvider, - MoonshotProvider, - OllamaProvider, - OpenAIProvider, - OpenRouterProvider, - PerplexityProvider, - TogetherAIProvider, - ZeroOneProvider, - ZhiPuProvider, + AnthropicProviderCard, + BedrockProviderCard, + GoogleProviderCard, + GroqProviderCard, + MistralProviderCard, + MoonshotProviderCard, + OllamaProviderCard, + OpenAIProviderCard, + OpenRouterProviderCard, + PerplexityProviderCard, + TogetherAIProviderCard, + ZeroOneProviderCard, + ZhiPuProviderCard, filterEnabledModels, } from '@/config/modelProviders'; import { DEFAULT_AGENT_META } from '@/const/meta'; @@ -66,75 +66,74 @@ export const DEFAULT_LLM_CONFIG: GlobalLLMConfig = { anthropic: { apiKey: '', enabled: false, - enabledModels: filterEnabledModels(AnthropicProvider), + enabledModels: filterEnabledModels(AnthropicProviderCard), }, azure: { apiKey: '', - deployments: '', enabled: false, endpoint: '', }, bedrock: { accessKeyId: '', enabled: false, - enabledModels: filterEnabledModels(BedrockProvider), + enabledModels: filterEnabledModels(BedrockProviderCard), region: 'us-east-1', secretAccessKey: '', }, google: { apiKey: '', enabled: false, - enabledModels: filterEnabledModels(GoogleProvider), + enabledModels: filterEnabledModels(GoogleProviderCard), }, groq: { apiKey: '', enabled: false, - enabledModels: filterEnabledModels(GroqProvider), + enabledModels: filterEnabledModels(GroqProviderCard), }, mistral: { apiKey: '', enabled: false, - enabledModels: filterEnabledModels(MistralProvider), + enabledModels: filterEnabledModels(MistralProviderCard), }, moonshot: { apiKey: '', enabled: false, - enabledModels: filterEnabledModels(MoonshotProvider), + enabledModels: filterEnabledModels(MoonshotProviderCard), }, ollama: { enabled: false, - enabledModels: filterEnabledModels(OllamaProvider), + enabledModels: filterEnabledModels(OllamaProviderCard), endpoint: '', }, openai: { apiKey: '', enabled: true, - enabledModels: filterEnabledModels(OpenAIProvider), + enabledModels: filterEnabledModels(OpenAIProviderCard), }, openrouter: { apiKey: '', enabled: false, - enabledModels: filterEnabledModels(OpenRouterProvider), + enabledModels: filterEnabledModels(OpenRouterProviderCard), }, perplexity: { apiKey: '', enabled: false, - enabledModels: filterEnabledModels(PerplexityProvider), + enabledModels: filterEnabledModels(PerplexityProviderCard), }, togetherai: { apiKey: '', enabled: false, - enabledModels: filterEnabledModels(TogetherAIProvider), + enabledModels: filterEnabledModels(TogetherAIProviderCard), }, zeroone: { apiKey: '', enabled: false, - enabledModels: filterEnabledModels(ZeroOneProvider), + enabledModels: filterEnabledModels(ZeroOneProviderCard), }, zhipu: { apiKey: '', enabled: false, - enabledModels: filterEnabledModels(ZhiPuProvider), + enabledModels: filterEnabledModels(ZhiPuProviderCard), }, }; diff --git a/src/locales/default/setting.ts b/src/locales/default/setting.ts index 5e776f28ff88..b8c2942cfe58 100644 --- a/src/locales/default/setting.ts +++ b/src/locales/default/setting.ts @@ -50,12 +50,13 @@ export default { fetch: '获取列表', title: 'Azure Api Version', }, + empty: '请输入模型 ID 添加第一个模型', endpoint: { desc: '从 Azure 门户检查资源时,可在“密钥和终结点”部分中找到此值', placeholder: 'https://docs-test-001.openai.azure.com', title: 'Azure API 地址', }, - modelListPlaceholder: '请选择或添加你的部署模型', + modelListPlaceholder: '请选择或添加你部署的 OpenAI 模型', title: 'Azure OpenAI', token: { desc: '从 Azure 门户检查资源时,可在“密钥和终结点”部分中找到此值。 可以使用 KEY1 或 KEY2', @@ -96,6 +97,7 @@ export default { confirmDelete: '即将删除该自定义模型,删除后将不可恢复,请谨慎操作。', modelConfig: { azureDeployName: { + extra: '在 Azure OpenAI 中实际请求的字段', placeholder: '请输入 Azure 中的模型部署名称', title: '模型部署名称', }, @@ -114,6 +116,7 @@ export default { title: '支持函数调用', }, id: { + extra: '将作为模型标签进行展示', placeholder: '请输入模型id,例如 gpt-4-turbo-preview 或 claude-2.1', title: '模型 ID', }, diff --git a/src/migrations/FromV3ToV4/index.ts b/src/migrations/FromV3ToV4/index.ts index 6ca902501b37..c8cd3a1c2ba0 100644 --- a/src/migrations/FromV3ToV4/index.ts +++ b/src/migrations/FromV3ToV4/index.ts @@ -2,7 +2,7 @@ import type { Migration, MigrationData } from '@/migrations/VersionController'; import { transformToChatModelCards } from '@/utils/parseModels'; import { V3ConfigState, V3LegacyConfig, V3OpenAIConfig, V3Settings } from './types/v3'; -import { V4ConfigState, V4ProviderConfig, V4Settings } from './types/v4'; +import { V4AzureOpenAIConfig, V4ConfigState, V4ProviderConfig, V4Settings } from './types/v4'; export class MigrationV3ToV4 implements Migration { // from this version to start migration @@ -40,13 +40,11 @@ export class MigrationV3ToV4 implements Migration { static migrateOpenAI = ( openai: V3OpenAIConfig, - ): { azure: V4ProviderConfig; openai: V4ProviderConfig } => { + ): { azure: V4AzureOpenAIConfig; openai: V4ProviderConfig } => { if (openai.useAzure) { return { azure: { apiKey: openai.OPENAI_API_KEY, - // TODO: 要确认下 azure 的 api version 是放到 customModelCard 里还是怎么样 - // @ts-ignore apiVersion: openai.azureApiVersion, enabled: true, endpoint: openai.endpoint, diff --git a/src/migrations/FromV3ToV4/types/v4.ts b/src/migrations/FromV3ToV4/types/v4.ts index 90485da102f5..87f9fbc8a59c 100644 --- a/src/migrations/FromV3ToV4/types/v4.ts +++ b/src/migrations/FromV3ToV4/types/v4.ts @@ -12,10 +12,13 @@ export interface V4ProviderConfig { enabledModels?: string[] | null; endpoint?: string; } +export interface V4AzureOpenAIConfig extends V4ProviderConfig { + apiVersion?: string; +} export interface V4lLLMConfig extends Omit { - azure: V4ProviderConfig; + azure: V4AzureOpenAIConfig; ollama: V4ProviderConfig; openai: V4ProviderConfig; openrouter: V4ProviderConfig; diff --git a/src/services/chat.ts b/src/services/chat.ts index ddce4b6a7869..f4548716be49 100644 --- a/src/services/chat.ts +++ b/src/services/chat.ts @@ -9,6 +9,7 @@ import { filesSelectors, useFileStore } from '@/store/file'; import { useGlobalStore } from '@/store/global'; import { commonSelectors, + modelConfigSelectors, modelProviderSelectors, preferenceSelectors, } from '@/store/global/selectors'; @@ -131,13 +132,22 @@ class ChatService { const { signal } = options ?? {}; const { provider = ModelProvider.OpenAI, ...res } = params; + + let model = res.model || DEFAULT_AGENT_CONFIG.model; + + // if the provider is Azure, get the deployment name as the request model + if (provider === ModelProvider.Azure) { + const chatModelCards = modelConfigSelectors.providerModelCards(provider)( + useGlobalStore.getState(), + ); + + const deploymentName = chatModelCards.find((i) => i.id === model)?.deploymentName; + if (deploymentName) model = deploymentName; + } + const payload = merge( - { - model: DEFAULT_AGENT_CONFIG.model, - stream: true, - ...DEFAULT_AGENT_CONFIG.params, - }, - res, + { stream: true, ...DEFAULT_AGENT_CONFIG.params }, + { ...res, model: res.model }, ); const traceHeader = createTraceHeader({ ...options?.trace }); diff --git a/src/store/global/slices/settings/selectors/modelConfig.ts b/src/store/global/slices/settings/selectors/modelConfig.ts index 40b5da5b4068..9a9e2f78e30f 100644 --- a/src/store/global/slices/settings/selectors/modelConfig.ts +++ b/src/store/global/slices/settings/selectors/modelConfig.ts @@ -94,7 +94,7 @@ const providerModelCards = isCustom: true, })); - return uniqBy([...builtinCards, ...userCards], 'id'); + return uniqBy([...userCards, ...builtinCards], 'id'); }; const getCustomModelCardById = diff --git a/src/store/global/slices/settings/selectors/modelProvider.ts b/src/store/global/slices/settings/selectors/modelProvider.ts index e1ed88d9e05c..60e98046962f 100644 --- a/src/store/global/slices/settings/selectors/modelProvider.ts +++ b/src/store/global/slices/settings/selectors/modelProvider.ts @@ -1,17 +1,18 @@ import { - AnthropicProvider, - BedrockProvider, - GoogleProvider, - GroqProvider, - MistralProvider, - MoonshotProvider, - OllamaProvider, - OpenAIProvider, - OpenRouterProvider, - PerplexityProvider, - TogetherAIProvider, - ZeroOneProvider, - ZhiPuProvider, + AnthropicProviderCard, + AzureProviderCard, + BedrockProviderCard, + GoogleProviderCard, + GroqProviderCard, + MistralProviderCard, + MoonshotProviderCard, + OllamaProviderCard, + OpenAIProviderCard, + OpenRouterProviderCard, + PerplexityProviderCard, + TogetherAIProviderCard, + ZeroOneProviderCard, + ZhiPuProviderCard, filterEnabledModels, } from '@/config/modelProviders'; import { ChatModelCard, ModelProviderCard } from '@/types/llm'; @@ -51,22 +52,28 @@ const providerModelList = (s: GlobalStore): ModelProviderCard[] => { return [ { - ...OpenAIProvider, - chatModels: openaiChatModels ?? OpenAIProvider.chatModels, + ...OpenAIProviderCard, + chatModels: openaiChatModels ?? OpenAIProviderCard.chatModels, }, - // { ...azureModelList(s), enabled: enableAzure(s) }, - { ...OllamaProvider, chatModels: ollamaChatModels ?? OllamaProvider.chatModels }, - AnthropicProvider, - GoogleProvider, - { ...OpenRouterProvider, chatModels: openrouterChatModels ?? OpenRouterProvider.chatModels }, - { ...TogetherAIProvider, chatModels: togetheraiChatModels ?? TogetherAIProvider.chatModels }, - BedrockProvider, - PerplexityProvider, - MistralProvider, - GroqProvider, - MoonshotProvider, - ZeroOneProvider, - ZhiPuProvider, + { ...AzureProviderCard, chatModels: [] }, + { ...OllamaProviderCard, chatModels: ollamaChatModels ?? OllamaProviderCard.chatModels }, + AnthropicProviderCard, + GoogleProviderCard, + { + ...OpenRouterProviderCard, + chatModels: openrouterChatModels ?? OpenRouterProviderCard.chatModels, + }, + { + ...TogetherAIProviderCard, + chatModels: togetheraiChatModels ?? TogetherAIProviderCard.chatModels, + }, + BedrockProviderCard, + PerplexityProviderCard, + MistralProviderCard, + GroqProviderCard, + MoonshotProviderCard, + ZeroOneProviderCard, + ZhiPuProviderCard, ]; }; diff --git a/src/types/settings/modelProvider.ts b/src/types/settings/modelProvider.ts index fc92f81e7f5e..0c73fae27033 100644 --- a/src/types/settings/modelProvider.ts +++ b/src/types/settings/modelProvider.ts @@ -13,17 +13,13 @@ export interface GeneralModelProviderConfig { endpoint?: string; /** - * the model cards defined in server config + * the model cards defined in server */ serverModelCards?: ChatModelCard[]; } -export interface AzureOpenAIConfig { - apiKey: string; +export interface AzureOpenAIConfig extends GeneralModelProviderConfig { apiVersion?: string; - deployments: string; - enabled: boolean; - endpoint?: string; } export interface AWSBedrockConfig extends Omit { diff --git a/src/utils/parseModels.ts b/src/utils/parseModels.ts index 21d0bee7f08a..d8bb65d5be38 100644 --- a/src/utils/parseModels.ts +++ b/src/utils/parseModels.ts @@ -1,6 +1,6 @@ import { produce } from 'immer'; -import { LOBE_DEFAULT_MODEL_LIST, OpenAIProvider } from '@/config/modelProviders'; +import { LOBE_DEFAULT_MODEL_LIST, OpenAIProviderCard } from '@/config/modelProviders'; import { ChatModelCard } from '@/types/llm'; import { CustomModels } from '@/types/settings'; @@ -48,7 +48,7 @@ export const parseModelString = (modelString: string = '') => { */ export const transformToChatModelCards = ( modelString: string = '', - defaultChartModels = OpenAIProvider.chatModels, + defaultChartModels = OpenAIProviderCard.chatModels, ): ChatModelCard[] => { const modelConfig = parseModelString(modelString); let chatModels = modelConfig.removeAll ? [] : defaultChartModels;