Skip to content

Commit

Permalink
✨ feat: support openai model fetcher
Browse files Browse the repository at this point in the history
  • Loading branch information
arvinxx committed Apr 10, 2024
1 parent ef5ee2a commit 56032e6
Show file tree
Hide file tree
Showing 32 changed files with 653 additions and 118 deletions.
4 changes: 4 additions & 0 deletions src/app/api/chat/agentRuntime.ts
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,10 @@ class AgentRuntime {
});
}

/**
 * Lists the models exposed by the underlying runtime.
 * Returns `undefined` when the runtime does not implement `models`
 * (the optional call `?.()` guards against an absent method).
 */
async models() {
return this._runtime.models?.();
}

static async initializeWithUserPayload(provider: string, payload: JWTPayload) {
let runtimeModel: LobeRuntimeAI;

Expand Down
37 changes: 37 additions & 0 deletions src/app/api/chat/models/[provider]/route.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import { NextResponse } from 'next/server';

import { getPreferredRegion } from '@/app/api/config';
import { createErrorResponse } from '@/app/api/errorResponse';
import { ChatCompletionErrorPayload } from '@/libs/agent-runtime';
import { ChatErrorType } from '@/types/fetch';

import AgentRuntime from '../../agentRuntime';
import { checkAuth } from '../../auth';

// Run this route on the Edge runtime (Next.js route segment config).
export const runtime = 'edge';

// Preferred deployment region(s), resolved from server config.
export const preferredRegion = getPreferredRegion();

export const GET = checkAuth(async (req, { params, jwtPayload }) => {
const { provider } = params;

try {
const agentRuntime = await AgentRuntime.initializeWithUserPayload(provider, jwtPayload);

const list = await agentRuntime.models();

return NextResponse.json(list);
} catch (e) {
const {
errorType = ChatErrorType.InternalServerError,
error: errorContent,
...res
} = e as ChatCompletionErrorPayload;

const error = errorContent || e;
// track the error at server side
console.error(`Route: [${provider}] ${errorType}:`, error);

return createErrorResponse(errorType, { error, ...res, provider });
}
});
2 changes: 1 addition & 1 deletion src/app/api/config/route.test.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { beforeEach, describe, expect, it, vi } from 'vitest';

import { GlobalServerConfig } from '@/types/settings';
import { GlobalServerConfig } from '@/types/serverConfig';

import { GET } from './route';

Expand Down
2 changes: 1 addition & 1 deletion src/app/api/config/route.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import {
TogetherAIProviderCard,
} from '@/config/modelProviders';
import { getServerConfig } from '@/config/server';
import { GlobalServerConfig } from '@/types/settings';
import { GlobalServerConfig } from '@/types/serverConfig';
import { transformToChatModelCards } from '@/utils/parseModels';

import { parseAgentConfig } from './parseDefaultAgent';
Expand Down
1 change: 1 addition & 0 deletions src/app/settings/llm/Ollama/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ const OllamaProvider = memo(() => {
label: t('llm.checker.title'),
minWidth: undefined,
}}
modelList={{ showModelFetcher: true }}
provider={ModelProvider.Ollama}
showApiKey={false}
showEndpoint
Expand Down
7 changes: 6 additions & 1 deletion src/app/settings/llm/OpenAI/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,12 @@ import { memo } from 'react';
import ProviderConfig from '../components/ProviderConfig';

const OpenAIProvider = memo(() => (
<ProviderConfig provider={'openai'} showEndpoint title={<OpenAI.Combine size={24} />} />
<ProviderConfig
modelList={{ showModelFetcher: true }}
provider={'openai'}
showEndpoint
title={<OpenAI.Combine size={24} />}
/>
));

export default OpenAIProvider;
1 change: 1 addition & 0 deletions src/app/settings/llm/OpenRouter/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ const OpenRouterProvider = memo(() => {
return (
<ProviderConfig
checkModel={'mistralai/mistral-7b-instruct:free'}
modelList={{ showModelFetcher: true }}
provider={ModelProvider.OpenRouter}
title={
<OpenRouter.Combine
Expand Down
1 change: 1 addition & 0 deletions src/app/settings/llm/TogetherAI/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ const TogetherAIProvider = memo(() => {
return (
<ProviderConfig
checkModel={'togethercomputer/alpaca-7b'}
modelList={{ showModelFetcher: true }}
provider={'togetherai'}
title={
<Together.Combine
Expand Down
2 changes: 2 additions & 0 deletions src/app/settings/llm/components/ProviderConfig/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ interface ProviderConfigProps {
azureDeployName?: boolean;
notFoundContent?: ReactNode;
placeholder?: string;
showModelFetcher?: boolean;
};
provider: GlobalLLMProviderKey;
showApiKey?: boolean;
Expand Down Expand Up @@ -90,6 +91,7 @@ const ProviderConfig = memo<ProviderConfigProps>(
placeholder={modelList?.placeholder ?? t('llm.modelList.placeholder')}
provider={provider}
showAzureDeployName={modelList?.azureDeployName}
showModelFetcher={modelList?.showModelFetcher}
/>
),
desc: t('llm.modelList.desc'),
Expand Down
71 changes: 71 additions & 0 deletions src/app/settings/llm/components/ProviderModelList/ModelFetcher.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import { Icon, Tooltip } from '@lobehub/ui';
import { Typography } from 'antd';
import { createStyles } from 'antd-style';
import dayjs from 'dayjs';
import { LucideLoaderCircle, LucideRefreshCcwDot } from 'lucide-react';
import { memo } from 'react';
import { Flexbox } from 'react-layout-kit';

import { useGlobalStore } from '@/store/global';
import { modelConfigSelectors } from '@/store/global/selectors';
import { GlobalLLMProviderKey } from '@/types/settings';

// Style for the clickable "refresh model list" affordance:
// subtle padding/radius with a hover highlight taken from the theme tokens.
const useStyles = createStyles(({ css, token }) => ({
hover: css`
cursor: pointer;
padding: 4px 8px;
border-radius: ${token.borderRadius}px;
transition: all 0.2s ease-in-out;
&:hover {
color: ${token.colorText};
background-color: ${token.colorFillSecondary};
}
`,
}));

interface ModelFetcherProps {
// Which LLM provider's model list this fetcher controls.
provider: GlobalLLMProviderKey;
}

/**
 * Shows how many models are currently available for `provider` and lets the
 * user manually refresh the provider's model list.
 *
 * Fix: the original selector array also read `s.setModelProviderConfig` but
 * the destructuring only took the first element, so that selection was dead
 * code — it has been removed.
 */
const ModelFetcher = memo<ModelFetcherProps>(({ provider }) => {
const { styles } = useStyles();
// Pull the SWR-style fetch hook out of the global store.
const useFetchProviderModelList = useGlobalStore((s) => s.useFetchProviderModelList);
// Whether the model list should be fetched automatically for this provider.
const enabledAutoFetch = useGlobalStore(modelConfigSelectors.enabledAutoFetchModels(provider));
// Timestamp of the most recent successful fetch, shown in the tooltip.
const latestFetchTime = useGlobalStore(
(s) => modelConfigSelectors.providerConfig(provider)(s)?.latestFetchTime,
);
const totalModels = useGlobalStore(
(s) => modelConfigSelectors.providerModelCards(provider)(s).length,
);

// `mutate` triggers a manual refetch; `isValidating` drives the spinner state.
const { mutate, isValidating } = useFetchProviderModelList(provider, enabledAutoFetch);

return (
<Typography.Text style={{ fontSize: 12 }} type={'secondary'}>
<Flexbox align={'center'} gap={0} horizontal justify={'space-between'}>
<div>{totalModels} 个模型可用</div>
<Tooltip title={`上次更新时间:${dayjs(latestFetchTime).format('MM-DD HH:mm:ss')}`}>
<Flexbox
align={'center'}
className={styles.hover}
gap={4}
horizontal
onClick={() => mutate()}
>
<Icon
icon={isValidating ? LucideLoaderCircle : LucideRefreshCcwDot}
size={'small'}
spin={isValidating}
/>
<div>{isValidating ? '正在获取模型列表...' : '获取模型列表'}</div>
</Flexbox>
</Tooltip>
</Flexbox>
</Typography.Text>
);
});
export default ModelFetcher;
137 changes: 74 additions & 63 deletions src/app/settings/llm/components/ProviderModelList/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import { modelConfigSelectors, modelProviderSelectors } from '@/store/global/sel
import { GlobalLLMProviderKey } from '@/types/settings';

import ModelConfigModal from './ModelConfigModal';
import ModelFetcher from './ModelFetcher';
import OptionRender from './Option';

const styles = {
Expand All @@ -36,20 +37,24 @@ interface CustomModelSelectProps {
placeholder?: string;
provider: GlobalLLMProviderKey;
showAzureDeployName?: boolean;
showModelFetcher?: boolean;
}

const ProviderModelListSelect = memo<CustomModelSelectProps>(
({ provider, showAzureDeployName, notFoundContent, placeholder }) => {
({ showModelFetcher = false, provider, showAzureDeployName, notFoundContent, placeholder }) => {
const { t } = useTranslation('common');
const { t: transSetting } = useTranslation('setting');
const chatModelCards = useGlobalStore(
modelConfigSelectors.providerModelCards(provider),
isEqual,
);
const [setModelProviderConfig, dispatchCustomModelCards] = useGlobalStore((s) => [
s.setModelProviderConfig,
s.dispatchCustomModelCards,
s.useFetchProviderModelList,
]);

const chatModelCards = useGlobalStore(
modelConfigSelectors.providerModelCards(provider),
isEqual,
);

const defaultEnableModel = useGlobalStore(
modelProviderSelectors.defaultEnabledProviderModels(provider),
isEqual,
Expand All @@ -58,72 +63,78 @@ const ProviderModelListSelect = memo<CustomModelSelectProps>(
modelConfigSelectors.providerEnableModels(provider),
isEqual,
);

const showReset = !!enabledModels && !isEqual(defaultEnableModel, enabledModels);

return (
<div style={{ position: 'relative' }}>
<div className={cx(styles.reset)}>
{showReset && (
<ActionIcon
icon={RotateCwIcon}
onClick={() => {
setModelProviderConfig(provider, { enabledModels: null });
}}
size={'small'}
title={t('reset')}
/>
)}
</div>
<Select<string[]>
allowClear
mode="tags"
notFoundContent={notFoundContent}
onChange={(value, options) => {
setModelProviderConfig(provider, { enabledModels: value.filter(Boolean) });
<>
<Flexbox gap={8}>
<div style={{ position: 'relative' }}>
<div className={cx(styles.reset)}>
{showReset && (
<ActionIcon
icon={RotateCwIcon}
onClick={() => {
setModelProviderConfig(provider, { enabledModels: null });
}}
size={'small'}
title={t('reset')}
/>
)}
</div>
<Select<string[]>
allowClear
mode="tags"
notFoundContent={notFoundContent}
onChange={(value, options) => {
setModelProviderConfig(provider, { enabledModels: value.filter(Boolean) });

// if there is a new model, add it to `customModelCards`
options.forEach((option: { label?: string; value?: string }, index: number) => {
// if is a known model, it should have value
// if is an unknown model, the option will be {}
if (option.value) return;
// if there is a new model, add it to `customModelCards`
options.forEach((option: { label?: string; value?: string }, index: number) => {
// if is a known model, it should have value
// if is an unknown model, the option will be {}
if (option.value) return;

const modelId = value[index];
const modelId = value[index];

dispatchCustomModelCards(provider, {
modelCard: { id: modelId },
type: 'add',
});
});
}}
optionFilterProp="label"
optionRender={({ label, value }) => {
// model is in the chatModels
if (chatModelCards.some((c) => c.id === value))
return (
<OptionRender
displayName={label as string}
id={value as string}
provider={provider}
/>
);
dispatchCustomModelCards(provider, {
modelCard: { id: modelId },
type: 'add',
});
});
}}
optionFilterProp="label"
optionRender={({ label, value }) => {
// model is in the chatModels
if (chatModelCards.some((c) => c.id === value))
return (
<OptionRender
displayName={label as string}
id={value as string}
provider={provider}
/>
);

// model is defined by user in client
return (
<Flexbox align={'center'} gap={8} horizontal>
{transSetting('llm.customModelCards.addNew', { id: value })}
</Flexbox>
);
}}
options={chatModelCards.map((model) => ({
label: model.displayName || model.id,
value: model.id,
}))}
placeholder={placeholder}
popupClassName={cx(styles.popup)}
value={enabledModels ?? defaultEnableModel}
/>
// model is defined by user in client
return (
<Flexbox align={'center'} gap={8} horizontal>
{transSetting('llm.customModelCards.addNew', { id: value })}
</Flexbox>
);
}}
options={chatModelCards.map((model) => ({
label: model.displayName || model.id,
value: model.id,
}))}
placeholder={placeholder}
popupClassName={cx(styles.popup)}
value={enabledModels ?? defaultEnableModel}
/>
</div>
{showModelFetcher && <ModelFetcher provider={provider} />}
</Flexbox>
<ModelConfigModal provider={provider} showAzureDeployName={showAzureDeployName} />
</div>
</>
);
},
);
Expand Down
10 changes: 5 additions & 5 deletions src/app/settings/llm/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -32,17 +32,17 @@ export default memo<{ showOllama: boolean }>(({ showOllama }) => {
<OpenAI />
<Azure />
{showOllama && <Ollama />}
<Anthropic />
<Google />
<Groq />
<Anthropic />
<Bedrock />
<OpenRouter />
<TogetherAI />
<Groq />
<Perplexity />
<Mistral />
<OpenRouter />
<Moonshot />
<ZeroOne />
<Zhipu />
<TogetherAI />
<ZeroOne />
<Footer>
<Trans i18nKey="llm.waitingForMore" ns={'setting'}>
更多模型正在
Expand Down
Loading

0 comments on commit 56032e6

Please sign in to comment.