From e8ed84720306d40d2bf31228d90a6214812a9714 Mon Sep 17 00:00:00 2001 From: Arvin Xu Date: Mon, 8 Apr 2024 08:55:13 +0000 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20feat:=20support=20update=20model=20?= =?UTF-8?q?config?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/app/settings/llm/OpenAI/index.tsx | 2 +- .../ProviderModelList/CustomModelOption.tsx | 94 +++++------ .../ProviderModelList/MaxTokenSlider.tsx | 88 +++++++++++ .../ProviderModelList/ModelConfigModal.tsx | 147 +++++++++++------- .../components/ProviderModelList/index.tsx | 7 +- .../general.test.ts} | 18 --- .../{action.ts => actions/general.ts} | 40 +---- .../global/slices/settings/actions/index.ts | 18 +++ .../slices/settings/actions/llm.test.ts | 35 +++++ .../global/slices/settings/actions/llm.ts | 50 ++++++ .../global/slices/settings/initialState.ts | 2 + .../settings/reducers/customModelCard.test.ts | 6 +- .../settings/reducers/customModelCard.ts | 12 +- .../slices/settings/selectors/modelConfig.ts | 19 +++ src/store/global/store.ts | 2 +- 15 files changed, 365 insertions(+), 175 deletions(-) create mode 100644 src/app/settings/llm/components/ProviderModelList/MaxTokenSlider.tsx rename src/store/global/slices/settings/{action.test.ts => actions/general.test.ts} (85%) rename src/store/global/slices/settings/{action.ts => actions/general.ts} (55%) create mode 100644 src/store/global/slices/settings/actions/index.ts create mode 100644 src/store/global/slices/settings/actions/llm.test.ts create mode 100644 src/store/global/slices/settings/actions/llm.ts diff --git a/src/app/settings/llm/OpenAI/index.tsx b/src/app/settings/llm/OpenAI/index.tsx index ceb1475d0207..028fff9b34bf 100644 --- a/src/app/settings/llm/OpenAI/index.tsx +++ b/src/app/settings/llm/OpenAI/index.tsx @@ -69,7 +69,7 @@ const LLM = memo(() => { children: ( ), desc: t('llm.openai.customModelName.desc'), diff --git a/src/app/settings/llm/components/ProviderModelList/CustomModelOption.tsx b/src/app/settings/llm/components/ProviderModelList/CustomModelOption.tsx index de7d470e2b7c..51fa800f69d7 100644 --- a/src/app/settings/llm/components/ProviderModelList/CustomModelOption.tsx +++ b/src/app/settings/llm/components/ProviderModelList/CustomModelOption.tsx @@ -1,16 +1,17 @@ import { ActionIcon } from '@lobehub/ui'; import { App, Typography } from 'antd'; +import isEqual from 'fast-deep-equal'; import { LucideSettings, LucideTrash2 } from 'lucide-react'; -import { memo, useState } from 'react'; +import { memo } from 'react'; import { useTranslation } from 'react-i18next'; import { Flexbox } from 'react-layout-kit'; import ModelIcon from '@/components/ModelIcon'; +import { ModelInfoTags } from '@/components/ModelSelect'; import { useGlobalStore } from '@/store/global'; +import { modelConfigSelectors } from '@/store/global/slices/settings/selectors'; import { GlobalLLMProviderKey } from '@/types/settings'; -import ModelConfigModal from './ModelConfigModal'; - interface CustomModelOptionProps { id: string; provider: GlobalLLMProviderKey; @@ -21,57 +22,60 @@ const CustomModelOption = memo(({ id, provider }) => { const { t: s } = useTranslation('setting'); const { modal } = App.useApp(); - const [open, setOpen] = useState(true); - const [dispatchCustomModelCards] = useGlobalStore((s) => [s.dispatchCustomModelCards]); + const [dispatchCustomModelCards, toggleEditingCustomModelCard] = useGlobalStore((s) => [ + s.dispatchCustomModelCards, + s.toggleEditingCustomModelCard, + ]); + const modelCard = useGlobalStore( + modelConfigSelectors.getCustomModelCardById({ id, provider }), + isEqual, + ); return ( - <> - + + + - - - - {id} - {/**/} - - - {id} - + + {modelCard?.displayName || id} + + + {id} + + - - { - e.stopPropagation(); - setOpen(true); - }} - title={s('llm.customModelCards.config')} - /> - { - e.stopPropagation(); - e.preventDefault(); + + { + e.stopPropagation(); + toggleEditingCustomModelCard({ id, provider }); + }} + title={s('llm.customModelCards.config')} + /> + { + e.stopPropagation(); + e.preventDefault(); - const isConfirm = await modal.confirm({ - centered: true, - content: s('llm.customModelCards.confirmDelete'), - okButtonProps: { danger: true }, - type: 'warning', - }); + const isConfirm = await modal.confirm({ + centered: true, + content: s('llm.customModelCards.confirmDelete'), + okButtonProps: { danger: true }, + type: 'warning', + }); - if (isConfirm) { - dispatchCustomModelCards(provider, { id, type: 'delete' }); - } - }} - title={t('delete')} - /> - + if (isConfirm) { + dispatchCustomModelCards(provider, { id, type: 'delete' }); + } + }} + title={t('delete')} + /> - - + ); }); diff --git a/src/app/settings/llm/components/ProviderModelList/MaxTokenSlider.tsx b/src/app/settings/llm/components/ProviderModelList/MaxTokenSlider.tsx new file mode 100644 index 000000000000..628ddbc66f88 --- /dev/null +++ b/src/app/settings/llm/components/ProviderModelList/MaxTokenSlider.tsx @@ -0,0 +1,88 @@ +import { InputNumber, Slider, SliderSingleProps } from 'antd'; +import { memo } from 'react'; +import { Flexbox } from 'react-layout-kit'; +import useMergeState from 'use-merge-value'; + +const exponent = (num: number) => Math.log2(num); +const getRealValue = (num: number) => Math.round(Math.pow(2, num)); + +const marks: SliderSingleProps['marks'] = { + [exponent(1)]: '1k', + [exponent(2)]: '2k', + [exponent(4)]: '4k', + [exponent(8)]: '8k', + [exponent(16)]: '16k', + [exponent(32)]: '32k', + [exponent(64)]: '64k', + [exponent(128)]: '128k', + [exponent(200)]: '200k', + [exponent(1000)]: '1M', +}; + +interface MaxTokenSliderProps { + defaultValue?: number; + onChange?: (value: number) => void; + value?: number; +} + +const MaxTokenSlider = memo(({ value, onChange, defaultValue }) => { + const [token, setTokens] = useMergeState(0, { + defaultValue, + onChange, + value: value, + }); + + const [powValue, setPowValue] = useMergeState(0, { + defaultValue: exponent(typeof defaultValue === 'undefined' ? 0 : defaultValue / 1000), + value: exponent(typeof value === 'undefined' ? 0 : value / 1000), + }); + + const updateWithPowValue = (value: number) => { + setPowValue(value); + + setTokens(getRealValue(value) * 1024); + }; + const updateWithRealValue = (value: number) => { + setTokens(value); + + setPowValue(exponent(value / 1024)); + }; + + return ( + + + { + if (typeof x === 'undefined') return; + + const value = getRealValue(x); + + if (value < 1000) return value.toFixed(0) + 'K'; + + return (value / 1000).toFixed(0) + 'M'; + }, + }} + value={powValue} + /> + +
+ { + if (!e) return; + + updateWithRealValue(e); + }} + step={1024} + value={token} + /> +
+
+ ); +}); +export default MaxTokenSlider; diff --git a/src/app/settings/llm/components/ProviderModelList/ModelConfigModal.tsx b/src/app/settings/llm/components/ProviderModelList/ModelConfigModal.tsx index 61538e06d9d2..975c7178002c 100644 --- a/src/app/settings/llm/components/ProviderModelList/ModelConfigModal.tsx +++ b/src/app/settings/llm/components/ProviderModelList/ModelConfigModal.tsx @@ -1,81 +1,108 @@ -import { Modal, SliderWithInput } from '@lobehub/ui'; +import { Modal } from '@lobehub/ui'; import { Checkbox, Form, Input } from 'antd'; +import isEqual from 'fast-deep-equal'; import { memo } from 'react'; import { useTranslation } from 'react-i18next'; -import { GlobalLLMProviderKey } from '@/types/settings'; +import { useGlobalStore } from '@/store/global'; +import { modelConfigSelectors } from '@/store/global/slices/settings/selectors'; -interface ModelConfigModalProps { - id: string; - onOpenChange: (open: boolean) => void; - open?: boolean; - provider: GlobalLLMProviderKey; -} -const ModelConfigModal = memo(({ open, id, onOpenChange }) => { +import MaxTokenSlider from './MaxTokenSlider'; + +const ModelConfigModal = memo(() => { const [formInstance] = Form.useForm(); const { t } = useTranslation('setting'); + const [open, id, provider, dispatchCustomModelCards, toggleEditingCustomModelCard] = + useGlobalStore((s) => [ + !!s.editingCustomCardModel, + s.editingCustomCardModel?.id, + s.editingCustomCardModel?.provider, + s.dispatchCustomModelCards, + s.toggleEditingCustomModelCard, + ]); + + const modelCard = useGlobalStore( + modelConfigSelectors.getCustomModelCardById({ id, provider }), + isEqual, + ); + + const closeModal = () => { + toggleEditingCustomModelCard(undefined); + }; + return ( { - onOpenChange(false); + closeModal(); + }} + onOk={() => { + if (!provider || !id) return; + const data = formInstance.getFieldsValue(); + + dispatchCustomModelCards(provider as any, { id, type: 'update', value: data }); + + closeModal(); }} open={open} title={t('llm.customModelCards.modelConfig.modalTitle')} > -
{ + e.stopPropagation(); + }} + onKeyDown={(e) => { + e.stopPropagation(); + }} > - - - - - - - - - - - - - - - - - - -
+ + + + + + + + + + + + + + + + + + + +
); }); diff --git a/src/app/settings/llm/components/ProviderModelList/index.tsx b/src/app/settings/llm/components/ProviderModelList/index.tsx index 07f8ee1d2f96..602aaece2bf4 100644 --- a/src/app/settings/llm/components/ProviderModelList/index.tsx +++ b/src/app/settings/llm/components/ProviderModelList/index.tsx @@ -7,6 +7,7 @@ import { memo } from 'react'; import { useTranslation } from 'react-i18next'; import { Flexbox } from 'react-layout-kit'; +import ModelConfigModal from '@/app/settings/llm/components/ProviderModelList/ModelConfigModal'; import { useGlobalStore } from '@/store/global'; import { modelConfigSelectors, modelProviderSelectors } from '@/store/global/selectors'; import { GlobalLLMProviderKey } from '@/types/settings'; @@ -23,10 +24,10 @@ const styles = { `, reset: css` position: absolute; - top: 50%; - transform: translateY(-50%); z-index: 20; + top: 50%; inset-inline-end: 28px; + transform: translateY(-50%); `, }; @@ -87,7 +88,6 @@ const ProviderModelListSelect = memo(({ provider, placeh }); }); }} - open optionFilterProp="label" optionRender={({ label, value }) => { // model is in the chatModels @@ -115,6 +115,7 @@ const ProviderModelListSelect = memo(({ provider, placeh popupClassName={cx(styles.popup)} value={enabledModels ?? defaultEnableModel} /> + ); }); diff --git a/src/store/global/slices/settings/action.test.ts b/src/store/global/slices/settings/actions/general.test.ts similarity index 85% rename from src/store/global/slices/settings/action.test.ts rename to src/store/global/slices/settings/actions/general.test.ts index 3e5f34808919..c01023be8d29 100644 --- a/src/store/global/slices/settings/action.test.ts +++ b/src/store/global/slices/settings/actions/general.test.ts @@ -66,24 +66,6 @@ describe('SettingsAction', () => { }); }); - describe('setModelProviderConfig', () => { - it('should set OpenAI configuration', async () => { - const { result } = renderHook(() => useGlobalStore()); - const openAIConfig: Partial = { OPENAI_API_KEY: 'test-key' }; - - // Perform the action - await act(async () => { - await result.current.setModelProviderConfig('openAI', openAIConfig); - }); - - // Assert that updateUserSettings was called with the correct OpenAI configuration - expect(userService.updateUserSettings).toHaveBeenCalledWith({ - languageModel: { - openAI: openAIConfig, - }, - }); - }); - }); describe('setSettings', () => { it('should set partial settings', async () => { const { result } = renderHook(() => useGlobalStore()); diff --git a/src/store/global/slices/settings/action.ts b/src/store/global/slices/settings/actions/general.ts similarity index 55% rename from src/store/global/slices/settings/action.ts rename to src/store/global/slices/settings/actions/general.ts index adcc1ef45c1e..fa45a574d08e 100644 --- a/src/store/global/slices/settings/action.ts +++ b/src/store/global/slices/settings/actions/general.ts @@ -6,48 +6,24 @@ import type { StateCreator } from 'zustand/vanilla'; import { userService } from '@/services/user'; import type { GlobalStore } from '@/store/global'; import { LobeAgentSettings } from '@/types/session'; -import { GlobalLLMConfig, GlobalLLMProviderKey, GlobalSettings } from '@/types/settings'; +import { GlobalSettings } from '@/types/settings'; import { difference } from '@/utils/difference'; import { merge } from '@/utils/merge'; -import { CustomModelCardDispatch, customModelCardsReducer } from './reducers/customModelCard'; -import { modelConfigSelectors } from './selectors/modelConfig'; - -/** - * 设置操作 - */ -export interface SettingsAction { - dispatchCustomModelCards: ( - provider: GlobalLLMProviderKey, - payload: CustomModelCardDispatch, - ) => Promise; +export interface GeneralSettingsAction { importAppSettings: (settings: GlobalSettings) => Promise; resetSettings: () => Promise; - setModelProviderConfig: ( - provider: T, - config: Partial, - ) => Promise; setSettings: (settings: DeepPartial) => Promise; switchThemeMode: (themeMode: ThemeMode) => Promise; - toggleProviderEnabled: (provider: GlobalLLMProviderKey, enabled: boolean) => Promise; updateDefaultAgent: (agent: DeepPartial) => Promise; } -export const createSettingsSlice: StateCreator< +export const generalSettingsSlice: StateCreator< GlobalStore, [['zustand/devtools', never]], [], - SettingsAction + GeneralSettingsAction > = (set, get) => ({ - dispatchCustomModelCards: async (provider, payload) => { - const prevState = modelConfigSelectors.providerConfig(provider)(get()); - - if (!prevState) return; - - const nextState = customModelCardsReducer(prevState.customModelCards, payload); - - await get().setModelProviderConfig(provider, { customModelCards: nextState }); - }, importAppSettings: async (importAppSettings) => { const { setSettings } = get(); // eslint-disable-next-line @typescript-eslint/no-unused-vars @@ -59,9 +35,6 @@ export const createSettingsSlice: StateCreator< await userService.resetUserSettings(); await get().refreshUserConfig(); }, - setModelProviderConfig: async (provider, config) => { - await get().setSettings({ languageModel: { [provider]: config } }); - }, setSettings: async (settings) => { const { settings: prevSetting, defaultSettings } = get(); @@ -70,18 +43,13 @@ export const createSettingsSlice: StateCreator< if (isEqual(prevSetting, nextSettings)) return; const diffs = difference(nextSettings, defaultSettings); - console.log(diffs); await userService.updateUserSettings(diffs); await get().refreshUserConfig(); }, - switchThemeMode: async (themeMode) => { await get().setSettings({ themeMode }); }, - toggleProviderEnabled: async (provider, enabled) => { - await get().setSettings({ languageModel: { [provider]: { enabled } } }); - }, updateDefaultAgent: async (defaultAgent) => { await get().setSettings({ defaultAgent }); }, diff --git a/src/store/global/slices/settings/actions/index.ts b/src/store/global/slices/settings/actions/index.ts new file mode 100644 index 000000000000..880ffb65f2ec --- /dev/null +++ b/src/store/global/slices/settings/actions/index.ts @@ -0,0 +1,18 @@ +import type { StateCreator } from 'zustand/vanilla'; + +import type { GlobalStore } from '@/store/global'; + +import { GeneralSettingsAction, generalSettingsSlice } from './general'; +import { LLMSettingsAction, llmSettingsSlice } from './llm'; + +export interface SettingsAction extends LLMSettingsAction, GeneralSettingsAction {} + +export const createSettingsSlice: StateCreator< + GlobalStore, + [['zustand/devtools', never]], + [], + SettingsAction +> = (...params) => ({ + ...llmSettingsSlice(...params), + ...generalSettingsSlice(...params), +}); diff --git a/src/store/global/slices/settings/actions/llm.test.ts b/src/store/global/slices/settings/actions/llm.test.ts new file mode 100644 index 000000000000..2a2a68247975 --- /dev/null +++ b/src/store/global/slices/settings/actions/llm.test.ts @@ -0,0 +1,35 @@ +import { act, renderHook, waitFor } from '@testing-library/react'; +import { describe, expect, it, vi } from 'vitest'; + +import { userService } from '@/services/user'; +import { useGlobalStore } from '@/store/global'; +import { GlobalSettings, OpenAIConfig } from '@/types/settings'; + +// Mock userService +vi.mock('@/services/user', () => ({ + userService: { + updateUserSettings: vi.fn(), + resetUserSettings: vi.fn(), + }, +})); + +describe('SettingsAction', () => { + describe('setModelProviderConfig', () => { + it('should set OpenAI configuration', async () => { + const { result } = renderHook(() => useGlobalStore()); + const openAIConfig: Partial = { OPENAI_API_KEY: 'test-key' }; + + // Perform the action + await act(async () => { + await result.current.setModelProviderConfig('openAI', openAIConfig); + }); + + // Assert that updateUserSettings was called with the correct OpenAI configuration + expect(userService.updateUserSettings).toHaveBeenCalledWith({ + languageModel: { + openAI: openAIConfig, + }, + }); + }); + }); +}); diff --git a/src/store/global/slices/settings/actions/llm.ts b/src/store/global/slices/settings/actions/llm.ts new file mode 100644 index 000000000000..54c0ce225309 --- /dev/null +++ b/src/store/global/slices/settings/actions/llm.ts @@ -0,0 +1,50 @@ +import type { StateCreator } from 'zustand/vanilla'; + +import type { GlobalStore } from '@/store/global'; +import { GlobalLLMConfig, GlobalLLMProviderKey } from '@/types/settings'; + +import { CustomModelCardDispatch, customModelCardsReducer } from '../reducers/customModelCard'; +import { modelConfigSelectors } from '../selectors/modelConfig'; + +/** + * 设置操作 + */ +export interface LLMSettingsAction { + dispatchCustomModelCards: ( + provider: GlobalLLMProviderKey, + payload: CustomModelCardDispatch, + ) => Promise; + setModelProviderConfig: ( + provider: T, + config: Partial, + ) => Promise; + + toggleEditingCustomModelCard: (params?: { id: string; provider: GlobalLLMProviderKey }) => void; + toggleProviderEnabled: (provider: GlobalLLMProviderKey, enabled: boolean) => Promise; +} + +export const llmSettingsSlice: StateCreator< + GlobalStore, + [['zustand/devtools', never]], + [], + LLMSettingsAction +> = (set, get) => ({ + dispatchCustomModelCards: async (provider, payload) => { + const prevState = modelConfigSelectors.providerConfig(provider)(get()); + + if (!prevState) return; + + const nextState = customModelCardsReducer(prevState.customModelCards, payload); + + await get().setModelProviderConfig(provider, { customModelCards: nextState }); + }, + setModelProviderConfig: async (provider, config) => { + await get().setSettings({ languageModel: { [provider]: config } }); + }, + toggleEditingCustomModelCard: (params) => { + set({ editingCustomCardModel: params }, false, 'toggleEditingCustomModelCard'); + }, + toggleProviderEnabled: async (provider, enabled) => { + await get().setSettings({ languageModel: { [provider]: { enabled } } }); + }, +}); diff --git a/src/store/global/slices/settings/initialState.ts b/src/store/global/slices/settings/initialState.ts index 1377dd5a3457..c4f3c1fccb93 100644 --- a/src/store/global/slices/settings/initialState.ts +++ b/src/store/global/slices/settings/initialState.ts @@ -6,6 +6,8 @@ import { GlobalServerConfig, GlobalSettings } from '@/types/settings'; export interface GlobalSettingsState { avatar?: string; defaultSettings: GlobalSettings; + editingCustomCardModel?: { id: string; provider: string } | undefined; + serverConfig: GlobalServerConfig; settings: DeepPartial; userId?: string; diff --git a/src/store/global/slices/settings/reducers/customModelCard.test.ts b/src/store/global/slices/settings/reducers/customModelCard.test.ts index 1feec1093350..9c43c0b723b3 100644 --- a/src/store/global/slices/settings/reducers/customModelCard.test.ts +++ b/src/store/global/slices/settings/reducers/customModelCard.test.ts @@ -106,8 +106,7 @@ describe('customModelCardsReducer', () => { const action: UpdateCustomModelCard = { type: 'update', id: 'model1', - key: 'displayName', - value: 'Updated Model 1', + value: { displayName: 'Updated Model 1' }, }; const newState = customModelCardsReducer(initialState, action); @@ -130,8 +129,7 @@ describe('customModelCardsReducer', () => { const action: UpdateCustomModelCard = { type: 'update', id: 'nonexistent', - key: 'displayName', - value: 'Updated Nonexistent Model', + value: { displayName: 'Updated Nonexistent Model' }, }; const newState = customModelCardsReducer(initialState, action); diff --git a/src/store/global/slices/settings/reducers/customModelCard.ts b/src/store/global/slices/settings/reducers/customModelCard.ts index 5bfcb92bd45a..306573ea8678 100644 --- a/src/store/global/slices/settings/reducers/customModelCard.ts +++ b/src/store/global/slices/settings/reducers/customModelCard.ts @@ -14,9 +14,8 @@ export interface DeleteCustomModelCard { export interface UpdateCustomModelCard { id: string; - key: keyof ChatModelCard; type: 'update'; - value: ChatModelCard[keyof ChatModelCard]; + value: Partial; } export type CustomModelCardDispatch = @@ -51,11 +50,10 @@ export const customModelCardsReducer = ( case 'update': { return produce(state || [], (draftState) => { const index = draftState.findIndex((card) => card.id === payload.id); - if (index === -1) return; - - const card = draftState[index]; - // @ts-ignore - card[payload.key] = payload.value; + if (index !== -1) { + const card = draftState[index]; + Object.assign(card, payload.value); + } }); } diff --git a/src/store/global/slices/settings/selectors/modelConfig.ts b/src/store/global/slices/settings/selectors/modelConfig.ts index 3a7b02c8e51e..53f056c1e2e1 100644 --- a/src/store/global/slices/settings/selectors/modelConfig.ts +++ b/src/store/global/slices/settings/selectors/modelConfig.ts @@ -124,12 +124,31 @@ const providerModelCards = return uniqBy([...builtinCards, ...userCards], 'id'); }; +const getCustomModelCardById = + ({ id, provider }: { id?: string; provider?: string }) => + (s: GlobalStore) => { + if (!provider) return; + + const config = providerConfig(provider)(s); + + return config?.customModelCards?.find((m) => m.id === id); + }; + +const currentEditingCustomModelCard = (s: GlobalStore) => { + if (!s.editingCustomCardModel) return; + const { id, provider } = s.editingCustomCardModel; + + return getCustomModelCardById({ id, provider }); +}; + /* eslint-disable sort-keys-fix/sort-keys-fix, */ export const modelConfigSelectors = { providerEnabled, providerEnableModels, providerConfig, providerModelCards, + currentEditingCustomModelCard, + getCustomModelCardById, modelSelectList, enabledModelProviderList, diff --git a/src/store/global/store.ts b/src/store/global/store.ts index c281377dc87e..ec6dbcb18305 100644 --- a/src/store/global/store.ts +++ b/src/store/global/store.ts @@ -9,7 +9,7 @@ import { createHyperStorage } from '../middleware/createHyperStorage'; import { type GlobalState, initialState } from './initialState'; import { type CommonAction, createCommonSlice } from './slices/common/action'; import { type PreferenceAction, createPreferenceSlice } from './slices/preference/action'; -import { type SettingsAction, createSettingsSlice } from './slices/settings/action'; +import { type SettingsAction, createSettingsSlice } from './slices/settings/actions'; // =============== 聚合 createStoreFn ============ //