Skip to content

Commit

Permalink
✨ feat: support user config model
Browse files Browse the repository at this point in the history
  • Loading branch information
arvinxx committed Apr 10, 2024
1 parent d865ca1 commit 72fd873
Show file tree
Hide file tree
Showing 11 changed files with 149 additions and 43 deletions.
1 change: 0 additions & 1 deletion src/app/settings/llm/Ollama/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ const OllamaProvider = memo(() => {
}}
provider={ModelProvider.Ollama}
showApiKey={false}
showCustomModelName
showEndpoint
title={
<Ollama.Combine color={theme.isDarkMode ? theme.colorText : theme.colorPrimary} size={24} />
Expand Down
19 changes: 15 additions & 4 deletions src/app/settings/llm/components/CustomModelList/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ import { memo } from 'react';

import { filterEnabledModels } from '@/config/modelProviders';
import { useGlobalStore } from '@/store/global';
import { modelConfigSelectors } from '@/store/global/selectors';
import { modelConfigSelectors, modelProviderSelectors } from '@/store/global/selectors';
import { GlobalLLMProviderKey } from '@/types/settings';

import { OptionRender } from './Option';

Expand All @@ -18,22 +19,31 @@ const popup = css`
`;

interface CustomModelSelectProps {
onChange?: (value: string[]) => void;
placeholder?: string;
provider: string;
value?: string[];
}

const CustomModelSelect = memo<CustomModelSelectProps>(({ provider, placeholder }) => {
const CustomModelSelect = memo<CustomModelSelectProps>(({ provider, placeholder, onChange }) => {
const providerCard = useGlobalStore(
(s) => modelConfigSelectors.modelSelectList(s).find((s) => s.id === provider),
(s) => modelProviderSelectors.providerModelList(s).find((s) => s.id === provider),
isEqual,
);
const providerConfig = useGlobalStore((s) =>
modelConfigSelectors.providerConfig(provider as GlobalLLMProviderKey)(s),
);

const defaultEnableModel = providerCard ? filterEnabledModels(providerCard) : [];

return (
<Select
<Select<string[]>
allowClear
defaultValue={defaultEnableModel}
mode="tags"
onChange={(value) => {
onChange?.(value.filter(Boolean));
}}
optionFilterProp="label"
optionRender={({ label, value }) => (
<OptionRender displayName={label as string} id={value as string} />
Expand All @@ -45,6 +55,7 @@ const CustomModelSelect = memo<CustomModelSelectProps>(({ provider, placeholder
placeholder={placeholder}
popupClassName={cx(popup)}
popupMatchSelectWidth={false}
value={providerConfig?.models.filter(Boolean)}
/>
);
});
Expand Down
2 changes: 1 addition & 1 deletion src/components/ModelSelect/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ interface ModelInfoTagsProps extends ChatModelCard {
}
export const ModelInfoTags = memo<ModelInfoTagsProps>(
({ directionReverse, placement = 'right', ...model }) => {
const { t } = useTranslation('common');
const { t } = useTranslation('components');
const { styles, cx } = useStyles();

return (
Expand Down
16 changes: 6 additions & 10 deletions src/features/AgentSetting/AgentConfig/ModelSelect.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -25,20 +25,16 @@ interface ModelOption {

const ModelSelect = memo(() => {
const [model, updateConfig] = useStore((s) => [s.config.model, s.setAgentConfig]);
const select = useGlobalStore(modelConfigSelectors.modelSelectList, isEqual);
const enabledList = useGlobalStore(modelConfigSelectors.enabledModelProviderList, isEqual);
const { styles } = useStyles();

const enabledList = select.filter((s) => s.enabled);

const options = useMemo<SelectProps['options']>(() => {
const getChatModels = (provider: ModelProviderCard) =>
provider.chatModels
.filter((c) => !c.hidden)
.map((model) => ({
label: <ModelItemRender {...model} />,
provider: provider.id,
value: model.id,
}));
provider.chatModels.map((model) => ({
label: <ModelItemRender {...model} />,
provider: provider.id,
value: model.id,
}));

if (enabledList.length === 1) {
const provider = enabledList[0];
Expand Down
51 changes: 39 additions & 12 deletions src/features/ModelSwitchPanel/index.tsx
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
import { Icon } from '@lobehub/ui';
import { Dropdown } from 'antd';
import { createStyles } from 'antd-style';
import isEqual from 'fast-deep-equal';
import { LucideArrowRight } from 'lucide-react';
import { useRouter } from 'next/navigation';
import { PropsWithChildren, memo, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { Flexbox } from 'react-layout-kit';

import { ModelItemRender, ProviderItemRender } from '@/components/ModelSelect';
import { useGlobalStore } from '@/store/global';
import { modelConfigSelectors } from '@/store/global/selectors';
import { useSessionStore } from '@/store/session';
import { agentSelectors } from '@/store/session/selectors';
import { ModelProviderCard } from '@/types/llm';
import { withBasePath } from '@/utils/basePath';

const useStyles = createStyles(({ css, prefixCls }) => ({
menu: css`
Expand All @@ -32,30 +38,51 @@ const useStyles = createStyles(({ css, prefixCls }) => ({
}));

const ModelSwitchPanel = memo<PropsWithChildren>(({ children }) => {
const { styles } = useStyles();
const { t } = useTranslation('components');
const { styles, theme } = useStyles();
const model = useSessionStore(agentSelectors.currentAgentModel);
const updateAgentConfig = useSessionStore((s) => s.updateAgentConfig);

const select = useGlobalStore(modelConfigSelectors.modelSelectList, isEqual);
const enabledList = select.filter((s) => s.enabled);
const router = useRouter();
const enabledList = useGlobalStore(modelConfigSelectors.enabledModelProviderList, isEqual);

const items = useMemo(() => {
const getModelItems = (provider: ModelProviderCard) =>
provider.chatModels
.filter((c) => !c.hidden)
.map((model) => ({
key: model.id,
label: <ModelItemRender {...model} />,
onClick: () => {
updateAgentConfig({ model: model.id, provider: provider.id });
const getModelItems = (provider: ModelProviderCard) => {
const items = provider.chatModels.map((model) => ({
key: model.id,
label: <ModelItemRender {...model} />,
onClick: () => {
updateAgentConfig({ model: model.id, provider: provider.id });
},
}));

// if there is empty items, add a placeholder guide
if (items.length === 0)
return [
{
key: 'empty',
label: (
<Flexbox gap={8} horizontal style={{ color: theme.colorTextTertiary }}>
{t('ModelSwitchPanel.emptyModel')}
<Icon icon={LucideArrowRight} />
</Flexbox>
),
onClick: () => {
router.push(withBasePath('/settings/llm'));
},
},
}));
];

return items;
};

// If there is only one provider, just remove the group, show model directly
if (enabledList.length === 1) {
const provider = enabledList[0];
return getModelItems(provider);
}

// otherwise show with provider group
return enabledList.map((provider) => ({
children: getModelItems(provider),
key: provider.id,
Expand Down
11 changes: 0 additions & 11 deletions src/locales/default/common.ts
Original file line number Diff line number Diff line change
@@ -1,18 +1,7 @@
export default {
ModelSelect: {
featureTag: {
custom: '自定义模型,默认设定同时支持函数调用与视觉识别,请根据实际情况验证上述能力的可用性',
file: '该模型支持上传文件读取与识别',
functionCall: '该模型支持函数调用(Function Call)',
tokens: '该模型单个会话最多支持 {{tokens}} Tokens',
vision: '该模型支持视觉识别',
},
},
about: '关于',
advanceSettings: '高级设置',

appInitializing: 'LobeChat 启动中,请耐心等待...',

autoGenerate: '自动补全',
autoGenerateTooltip: '基于提示词自动补全助手描述',
cancel: '取消',
Expand Down
15 changes: 15 additions & 0 deletions src/locales/default/components.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
export default {
ModelSelect: {
featureTag: {
custom: '自定义模型,默认设定同时支持函数调用与视觉识别,请根据实际情况验证上述能力的可用性',
file: '该模型支持上传文件读取与识别',
functionCall: '该模型支持函数调用(Function Call)',
tokens: '该模型单个会话最多支持 {{tokens}} Tokens',
vision: '该模型支持视觉识别',
},
},
ModelSwitchPanel: {
emptyModel: '没有启用的模型,请前往设置开启',
provider: '提供商',
},
};
2 changes: 2 additions & 0 deletions src/locales/default/index.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import tool from '../default/tool';
import chat from './chat';
import common from './common';
import components from './components';
import error from './error';
import market from './market';
import migration from './migration';
Expand All @@ -11,6 +12,7 @@ import welcome from './welcome';
const resources = {
chat,
common,
components,
error,
market,
migration,
Expand Down
30 changes: 30 additions & 0 deletions src/store/global/slices/settings/selectors/modelConfig.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import { describe, expect, it } from 'vitest';

import { DEFAULT_SETTINGS } from '@/const/settings';
import { modelProviderSelectors } from '@/store/global/slices/settings/selectors/modelProvider';
import { agentSelectors } from '@/store/session/slices/agent';
import { merge } from '@/utils/merge';

import { GlobalStore, useGlobalStore } from '../../../store';
import { GlobalSettingsState, initialSettingsState } from '../initialState';
import { modelConfigSelectors } from './modelConfig';

describe('modelConfigSelectors', () => {
describe('modelSelectList', () => {
it('visible', () => {
const s = merge(initialSettingsState, {
settings: {
languageModel: {
ollama: {
models: ['llava'],
},
},
},
} as GlobalSettingsState) as unknown as GlobalStore;

const ollamaList = modelConfigSelectors.modelSelectList(s).find((r) => r.id === 'ollama');

expect(ollamaList?.chatModels).toEqual([]);
});
});
});
43 changes: 40 additions & 3 deletions src/store/global/slices/settings/selectors/modelConfig.ts
Original file line number Diff line number Diff line change
@@ -1,13 +1,30 @@
import { ModelProviderCard } from '@/types/llm';
import { GlobalLLMProviderKey } from '@/types/settings';
import { GeneralModelProviderConfig, GlobalLLMProviderKey } from '@/types/settings';

import { GlobalStore } from '../../../store';
import { modelProviderSelectors } from './modelProvider';
import { currentSettings } from './settings';

const modelProvider = (s: GlobalStore) => currentSettings(s).languageModel;
const providerEnabled = (provider: GlobalLLMProviderKey) => (s: GlobalStore) =>
currentSettings(s).languageModel[provider]?.enabled || false;

const providerConfig = (provider: string) => (s: GlobalStore) =>
currentSettings(s).languageModel[provider as GlobalLLMProviderKey] as
| GeneralModelProviderConfig
| undefined;

const providerEnabled = (provider: GlobalLLMProviderKey) => (s: GlobalStore) => {
// TODO: we need to migrate the 'openAI' key to 'openai'
// @ts-ignore
if (provider === 'openai') return true;

return currentSettings(s).languageModel[provider]?.enabled || false;
};

const providerEnableModels =
(provider: string) =>
(s: GlobalStore): string[] | undefined => {
return providerConfig(provider)(s)?.models;
};

const openAIConfig = (s: GlobalStore) => modelProvider(s).openAI;

Expand Down Expand Up @@ -67,14 +84,34 @@ const zerooneAPIKey = (s: GlobalStore) => modelProvider(s).zeroone.apiKey;
const modelSelectList = (s: GlobalStore): ModelProviderCard[] => {
return modelProviderSelectors.providerModelList(s).map((list) => ({
...list,
chatModels: list.chatModels.map((model) => {
const models = providerEnableModels(list.id)(s);

if (!models) return model;

return {
...model,
hidden: !models?.some((m) => m === model.id),
};
}),
enabled: providerEnabled(list.id as any)(s),
}));
};

const enabledModelProviderList = (s: GlobalStore): ModelProviderCard[] =>
modelSelectList(s)
.filter((s) => s.enabled)
.map((provider) => ({
...provider,
chatModels: provider.chatModels.filter((model) => !model.hidden),
}));

/* eslint-disable sort-keys-fix/sort-keys-fix, */
export const modelConfigSelectors = {
providerEnabled,
providerConfig,
modelSelectList,
enabledModelProviderList,

// OpenAI
openAIConfig,
Expand Down
2 changes: 1 addition & 1 deletion src/types/settings/modelProvider.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
export type CustomModels = { displayName: string; id: string }[];

interface GeneralModelProviderConfig {
export interface GeneralModelProviderConfig {
apiKey?: string;
enabled: boolean;
endpoint?: string;
Expand Down

0 comments on commit 72fd873

Please sign in to comment.