From 56032e625b5d57199a231fa3e274cd086ac6415a Mon Sep 17 00:00:00 2001 From: arvinxx Date: Wed, 10 Apr 2024 00:02:22 +0800 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20feat:=20support=20openai=20model=20?= =?UTF-8?q?fetcher?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/app/api/chat/agentRuntime.ts | 4 + src/app/api/chat/models/[provider]/route.ts | 37 ++++ src/app/api/config/route.test.ts | 2 +- src/app/api/config/route.ts | 2 +- src/app/settings/llm/Ollama/index.tsx | 1 + src/app/settings/llm/OpenAI/index.tsx | 7 +- src/app/settings/llm/OpenRouter/index.tsx | 1 + src/app/settings/llm/TogetherAI/index.tsx | 1 + .../llm/components/ProviderConfig/index.tsx | 2 + .../ProviderModelList/ModelFetcher.tsx | 71 ++++++++ .../components/ProviderModelList/index.tsx | 137 +++++++------- src/app/settings/llm/index.tsx | 10 +- src/config/modelProviders/openai.ts | 6 + src/libs/agent-runtime/BaseAI.ts | 11 +- .../openai/__snapshots__/index.test.ts.snap | 93 ++++++++++ .../openai/fixtures/openai-models.json | 170 ++++++++++++++++++ src/libs/agent-runtime/openai/index.test.ts | 13 ++ .../utils/openaiCompatibleFactory/index.ts | 19 ++ src/migrations/FromV3ToV4/index.ts | 7 +- src/services/_url.ts | 1 + src/services/global.ts | 19 +- src/store/global/slices/common/action.test.ts | 2 +- src/store/global/slices/common/action.ts | 3 +- .../global/slices/settings/actions/llm.ts | 28 +++ .../global/slices/settings/initialState.ts | 4 +- .../slices/settings/selectors/modelConfig.ts | 7 + .../settings/selectors/modelProvider.ts | 56 ++++-- src/types/serverConfig.ts | 22 +++ src/types/settings/index.ts | 11 -- src/types/settings/modelProvider.ts | 16 +- src/utils/fetch.ts | 4 +- src/utils/parseModels.ts | 4 +- 32 files changed, 653 insertions(+), 118 deletions(-) create mode 100644 src/app/api/chat/models/[provider]/route.ts create mode 100644 src/app/settings/llm/components/ProviderModelList/ModelFetcher.tsx create mode 100644 src/libs/agent-runtime/openai/__snapshots__/index.test.ts.snap create mode 100644 src/libs/agent-runtime/openai/fixtures/openai-models.json create mode 100644 src/types/serverConfig.ts diff --git a/src/app/api/chat/agentRuntime.ts b/src/app/api/chat/agentRuntime.ts index 0bce3098cb34..11d32e3324a9 100644 --- a/src/app/api/chat/agentRuntime.ts +++ b/src/app/api/chat/agentRuntime.ts @@ -106,6 +106,10 @@ class AgentRuntime { }); } + async models() { + return this._runtime.models?.(); + } + static async initializeWithUserPayload(provider: string, payload: JWTPayload) { let runtimeModel: LobeRuntimeAI; diff --git a/src/app/api/chat/models/[provider]/route.ts b/src/app/api/chat/models/[provider]/route.ts new file mode 100644 index 000000000000..ba553b139f87 --- /dev/null +++ b/src/app/api/chat/models/[provider]/route.ts @@ -0,0 +1,37 @@ +import { NextResponse } from 'next/server'; + +import { getPreferredRegion } from '@/app/api/config'; +import { createErrorResponse } from '@/app/api/errorResponse'; +import { ChatCompletionErrorPayload } from '@/libs/agent-runtime'; +import { ChatErrorType } from '@/types/fetch'; + +import AgentRuntime from '../../agentRuntime'; +import { checkAuth } from '../../auth'; + +export const runtime = 'edge'; + +export const preferredRegion = getPreferredRegion(); + +export const GET = checkAuth(async (req, { params, jwtPayload }) => { + const { provider } = params; + + try { + const agentRuntime = await AgentRuntime.initializeWithUserPayload(provider, jwtPayload); + + const list = await agentRuntime.models(); + + return NextResponse.json(list); + } catch (e) { + const { + errorType = ChatErrorType.InternalServerError, + error: errorContent, + ...res + } = e as ChatCompletionErrorPayload; + + const error = errorContent || e; + // track the error at server side + console.error(`Route: [${provider}] ${errorType}:`, error); + + return createErrorResponse(errorType, { error, ...res, provider }); + } +}); diff --git a/src/app/api/config/route.test.ts b/src/app/api/config/route.test.ts index 6835ff2920b2..717557d7bb79 100644 --- a/src/app/api/config/route.test.ts +++ b/src/app/api/config/route.test.ts @@ -1,6 +1,6 @@ import { beforeEach, describe, expect, it, vi } from 'vitest'; -import { GlobalServerConfig } from '@/types/settings'; +import { GlobalServerConfig } from '@/types/serverConfig'; import { GET } from './route'; diff --git a/src/app/api/config/route.ts b/src/app/api/config/route.ts index 96ee7ed0506a..91fe56af12b7 100644 --- a/src/app/api/config/route.ts +++ b/src/app/api/config/route.ts @@ -4,7 +4,7 @@ import { TogetherAIProviderCard, } from '@/config/modelProviders'; import { getServerConfig } from '@/config/server'; -import { GlobalServerConfig } from '@/types/settings'; +import { GlobalServerConfig } from '@/types/serverConfig'; import { transformToChatModelCards } from '@/utils/parseModels'; import { parseAgentConfig } from './parseDefaultAgent'; diff --git a/src/app/settings/llm/Ollama/index.tsx b/src/app/settings/llm/Ollama/index.tsx index 4752f09e77e7..5d43e0f65a88 100644 --- a/src/app/settings/llm/Ollama/index.tsx +++ b/src/app/settings/llm/Ollama/index.tsx @@ -20,6 +20,7 @@ const OllamaProvider = memo(() => { label: t('llm.checker.title'), minWidth: undefined, }} + modelList={{ showModelFetcher: true }} provider={ModelProvider.Ollama} showApiKey={false} showEndpoint diff --git a/src/app/settings/llm/OpenAI/index.tsx b/src/app/settings/llm/OpenAI/index.tsx index b67ec5d96f3c..ddbb5089d734 100644 --- a/src/app/settings/llm/OpenAI/index.tsx +++ b/src/app/settings/llm/OpenAI/index.tsx @@ -4,7 +4,12 @@ import { memo } from 'react'; import ProviderConfig from '../components/ProviderConfig'; const OpenAIProvider = memo(() => ( - } /> + } + /> )); export default OpenAIProvider; diff --git a/src/app/settings/llm/OpenRouter/index.tsx b/src/app/settings/llm/OpenRouter/index.tsx index 37584a7b0678..d2f282debc72 100644 --- a/src/app/settings/llm/OpenRouter/index.tsx +++ b/src/app/settings/llm/OpenRouter/index.tsx @@ -12,6 +12,7 @@ const OpenRouterProvider = memo(() => { return ( { return ( ( placeholder={modelList?.placeholder ?? t('llm.modelList.placeholder')} provider={provider} showAzureDeployName={modelList?.azureDeployName} + showModelFetcher={modelList?.showModelFetcher} /> ), desc: t('llm.modelList.desc'), diff --git a/src/app/settings/llm/components/ProviderModelList/ModelFetcher.tsx b/src/app/settings/llm/components/ProviderModelList/ModelFetcher.tsx new file mode 100644 index 000000000000..5d055f94405e --- /dev/null +++ b/src/app/settings/llm/components/ProviderModelList/ModelFetcher.tsx @@ -0,0 +1,71 @@ +import { Icon, Tooltip } from '@lobehub/ui'; +import { Typography } from 'antd'; +import { createStyles } from 'antd-style'; +import dayjs from 'dayjs'; +import { LucideLoaderCircle, LucideRefreshCcwDot } from 'lucide-react'; +import { memo } from 'react'; +import { Flexbox } from 'react-layout-kit'; + +import { useGlobalStore } from '@/store/global'; +import { modelConfigSelectors } from '@/store/global/selectors'; +import { GlobalLLMProviderKey } from '@/types/settings'; + +const useStyles = createStyles(({ css, token }) => ({ + hover: css` + cursor: pointer; + padding: 4px 8px; + border-radius: ${token.borderRadius}px; + transition: all 0.2s ease-in-out; + + &:hover { + color: ${token.colorText}; + background-color: ${token.colorFillSecondary}; + } + `, +})); + +interface ModelFetcherProps { + provider: GlobalLLMProviderKey; +} + +const ModelFetcher = memo(({ provider }) => { + const { styles } = useStyles(); + const [useFetchProviderModelList] = useGlobalStore((s) => [ + s.useFetchProviderModelList, + s.setModelProviderConfig, + ]); + const enabledAutoFetch = useGlobalStore(modelConfigSelectors.enabledAutoFetchModels(provider)); + const latestFetchTime = useGlobalStore( + (s) => modelConfigSelectors.providerConfig(provider)(s)?.latestFetchTime, + ); + const totalModels = useGlobalStore( + (s) => modelConfigSelectors.providerModelCards(provider)(s).length, + ); + + const { mutate, isValidating } = useFetchProviderModelList(provider, enabledAutoFetch); + + return ( + + +
共 {totalModels} 个模型可用
+ + mutate()} + > + +
{isValidating ? '正在获取模型列表...' : '获取模型列表'}
+
+
+
+
+ ); +}); +export default ModelFetcher; diff --git a/src/app/settings/llm/components/ProviderModelList/index.tsx b/src/app/settings/llm/components/ProviderModelList/index.tsx index 041c47eba776..39c7924253cb 100644 --- a/src/app/settings/llm/components/ProviderModelList/index.tsx +++ b/src/app/settings/llm/components/ProviderModelList/index.tsx @@ -12,6 +12,7 @@ import { modelConfigSelectors, modelProviderSelectors } from '@/store/global/sel import { GlobalLLMProviderKey } from '@/types/settings'; import ModelConfigModal from './ModelConfigModal'; +import ModelFetcher from './ModelFetcher'; import OptionRender from './Option'; const styles = { @@ -36,20 +37,24 @@ interface CustomModelSelectProps { placeholder?: string; provider: GlobalLLMProviderKey; showAzureDeployName?: boolean; + showModelFetcher?: boolean; } const ProviderModelListSelect = memo( - ({ provider, showAzureDeployName, notFoundContent, placeholder }) => { + ({ showModelFetcher = false, provider, showAzureDeployName, notFoundContent, placeholder }) => { const { t } = useTranslation('common'); const { t: transSetting } = useTranslation('setting'); - const chatModelCards = useGlobalStore( - modelConfigSelectors.providerModelCards(provider), - isEqual, - ); const [setModelProviderConfig, dispatchCustomModelCards] = useGlobalStore((s) => [ s.setModelProviderConfig, s.dispatchCustomModelCards, + s.useFetchProviderModelList, ]); + + const chatModelCards = useGlobalStore( + modelConfigSelectors.providerModelCards(provider), + isEqual, + ); + const defaultEnableModel = useGlobalStore( modelProviderSelectors.defaultEnabledProviderModels(provider), isEqual, @@ -58,72 +63,78 @@ const ProviderModelListSelect = memo( modelConfigSelectors.providerEnableModels(provider), isEqual, ); + const showReset = !!enabledModels && !isEqual(defaultEnableModel, enabledModels); return ( -
-
- {showReset && ( - { - setModelProviderConfig(provider, { enabledModels: null }); - }} - size={'small'} - title={t('reset')} - /> - )} -
- - allowClear - mode="tags" - notFoundContent={notFoundContent} - onChange={(value, options) => { - setModelProviderConfig(provider, { enabledModels: value.filter(Boolean) }); + <> + +
+
+ {showReset && ( + { + setModelProviderConfig(provider, { enabledModels: null }); + }} + size={'small'} + title={t('reset')} + /> + )} +
+ + allowClear + mode="tags" + notFoundContent={notFoundContent} + onChange={(value, options) => { + setModelProviderConfig(provider, { enabledModels: value.filter(Boolean) }); - // if there is a new model, add it to `customModelCards` - options.forEach((option: { label?: string; value?: string }, index: number) => { - // if is a known model, it should have value - // if is an unknown model, the option will be {} - if (option.value) return; + // if there is a new model, add it to `customModelCards` + options.forEach((option: { label?: string; value?: string }, index: number) => { + // if is a known model, it should have value + // if is an unknown model, the option will be {} + if (option.value) return; - const modelId = value[index]; + const modelId = value[index]; - dispatchCustomModelCards(provider, { - modelCard: { id: modelId }, - type: 'add', - }); - }); - }} - optionFilterProp="label" - optionRender={({ label, value }) => { - // model is in the chatModels - if (chatModelCards.some((c) => c.id === value)) - return ( - - ); + dispatchCustomModelCards(provider, { + modelCard: { id: modelId }, + type: 'add', + }); + }); + }} + optionFilterProp="label" + optionRender={({ label, value }) => { + // model is in the chatModels + if (chatModelCards.some((c) => c.id === value)) + return ( + + ); - // model is defined by user in client - return ( - - {transSetting('llm.customModelCards.addNew', { id: value })} - - ); - }} - options={chatModelCards.map((model) => ({ - label: model.displayName || model.id, - value: model.id, - }))} - placeholder={placeholder} - popupClassName={cx(styles.popup)} - value={enabledModels ?? defaultEnableModel} - /> + // model is defined by user in client + return ( + + {transSetting('llm.customModelCards.addNew', { id: value })} + + ); + }} + options={chatModelCards.map((model) => ({ + label: model.displayName || model.id, + value: model.id, + }))} + placeholder={placeholder} + popupClassName={cx(styles.popup)} + value={enabledModels ?? defaultEnableModel} + /> +
+ {showModelFetcher && } +
-
+ ); }, ); diff --git a/src/app/settings/llm/index.tsx b/src/app/settings/llm/index.tsx index 599296e62918..a2da21366b53 100644 --- a/src/app/settings/llm/index.tsx +++ b/src/app/settings/llm/index.tsx @@ -32,17 +32,17 @@ export default memo<{ showOllama: boolean }>(({ showOllama }) => { {showOllama && } - - + + + + - - - +