Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

✨ feat: Support Cloudflare Workers AI #2966

Closed
wants to merge 48 commits into from
Closed
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
48 commits
Select commit Hold shift + click to select a range
f2996eb
Delete .nvmrc
sxjeru May 31, 2024
6de6000
Merge branch 'lobehub:main' into cf
sxjeru Jun 21, 2024
4b1d4c6
feat: Add Cloudflare as a model provider
sxjeru Jun 21, 2024
b774549
fix
sxjeru Jun 21, 2024
1336aae
fix
sxjeru Jun 21, 2024
ff5361e
fix
sxjeru Jun 21, 2024
6d658bd
fix
sxjeru Jun 21, 2024
5a0a4da
fix
sxjeru Jun 21, 2024
3608659
fix
sxjeru Jun 21, 2024
161af46
fix
sxjeru Jun 21, 2024
aa609af
fix icon
sxjeru Jun 21, 2024
0972e11
fix
sxjeru Jun 21, 2024
8ad1100
Create .nvmrc
sxjeru Jun 21, 2024
ed2f3c0
Delete src/config/modelProviders/.nvmrc
sxjeru Jun 21, 2024
e47aee5
CF -> CLOUDFLARE
sxjeru Jun 21, 2024
1909a89
Merge branch 'cf' of https://github.com/sxjeru/lobe-chat into cf
sxjeru Jun 21, 2024
5a1180c
revert
sxjeru Jun 21, 2024
7648bde
chore: Update agentRuntime.ts and auth.ts to support Cloudflare accou…
sxjeru Jun 21, 2024
9d036ee
Add provider setting
sxjeru Jun 21, 2024
7fe9401
fix
sxjeru Jun 21, 2024
fa23ba4
Update cloudflare.ts
sxjeru Jun 21, 2024
4414320
fix
sxjeru Jun 24, 2024
8d1f973
Update cloudflare.ts
sxjeru Jun 24, 2024
3b57709
Merge branch 'main' into cf
sxjeru Jun 24, 2024
7efaab9
accountID
sxjeru Jul 1, 2024
87f0721
fix
sxjeru Jul 1, 2024
7844a5b
Merge branch 'main' into cf
sxjeru Jul 1, 2024
26de0f1
i18n
sxjeru Jul 1, 2024
65463e0
Merge branch 'main' into cf
sxjeru Jul 10, 2024
7fe207a
Merge branch 'main' into cf
sxjeru Jul 25, 2024
e0f541a
Update index.ts
sxjeru Jul 27, 2024
bc26fd8
Update baichuan.ts
sxjeru Jul 27, 2024
0f5462f
Merge branch 'main' into cf
sxjeru Jul 27, 2024
bb02954
Update cloudflare.ts
sxjeru Jul 27, 2024
85021aa
save changes
BrandonStudio Jul 31, 2024
cb7dd1c
commit check
BrandonStudio Jul 31, 2024
ac8d4f2
disable function calling for now
BrandonStudio Jul 31, 2024
eefacf5
does not catch errors when fetching models
BrandonStudio Jul 31, 2024
5fc4c81
ready to add base url
BrandonStudio Jul 31, 2024
52ff9d1
commit check
BrandonStudio Jul 31, 2024
b8492e2
revert change
BrandonStudio Aug 1, 2024
b452d30
revert string boolean check
BrandonStudio Aug 1, 2024
b46c642
fix type error on Vercel.
BrandonStudio Aug 1, 2024
2dca07d
i18n by groq/llama-3.1-8b-instant
BrandonStudio Aug 1, 2024
0f40d15
rename env var
BrandonStudio Aug 1, 2024
8469931
Merge branch 'cf' into pr/BrandonStudio/38
sxjeru Aug 1, 2024
b3351d8
Merge branch 'main' into cf
sxjeru Aug 1, 2024
65c0bd2
Merge branch 'main' into cf
sxjeru Aug 4, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion src/app/(main)/settings/llm/ProviderList/providers.tsx
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import {
Anthropic,
Claude,
Cloudflare,
DeepSeek,
Gemini,
Google,
Expand All @@ -13,6 +14,7 @@ import {
Stepfun,
Together,
Tongyi,
WorkersAI,
ZeroOne,
Zhipu,
} from '@lobehub/icons';
Expand All @@ -24,6 +26,7 @@ import urlJoin from 'url-join';

import {
AnthropicProviderCard,
CloudflareProviderCard,
DeepSeekProviderCard,
GoogleProviderCard,
GroqProviderCard,
Expand Down Expand Up @@ -76,6 +79,14 @@ const GoogleBrand = () => (
</Flexbox>
);

const CloudflareBrand = () => (
<Flexbox align={'center'} gap={8} horizontal>
<Cloudflare.Combine size={22} type={'color'} />
<Divider style={{ margin: '0 4px' }} type={'vertical'} />
<WorkersAI.Combine size={22} type={'color'} />
</Flexbox>
);

export const useProviderList = (): ProviderItem[] => {
const azureProvider = useAzureProvider();
const ollamaProvider = useOllamaProvider();
Expand Down Expand Up @@ -170,7 +181,12 @@ export const useProviderList = (): ProviderItem[] => {
docUrl: urlJoin(BASE_DOC_URL, 'stepfun'),
title: <Stepfun.Combine size={20} type={'color'} />,
},
{
...CloudflareProviderCard,
docUrl: urlJoin(BASE_DOC_URL, 'cloudflare'),
title: <CloudflareBrand />,
},
],
[azureProvider, ollamaProvider, ollamaProvider, bedrockProvider],
);
};
};
10 changes: 9 additions & 1 deletion src/app/api/chat/agentRuntime.ts
Original file line number Diff line number Diff line change
Expand Up @@ -167,11 +167,19 @@ const getLlmOptionsFromPayload = (provider: string, payload: JWTPayload) => {
}
case ModelProvider.Stepfun: {
const { STEPFUN_API_KEY } = getLLMConfig();

const apiKey = apiKeyManager.pick(payload?.apiKey || STEPFUN_API_KEY);

return { apiKey };
}
case ModelProvider.Cloudflare: {
const { CLOUDFLARE_API_KEY, CLOUDFLARE_ACCOUNT_ID } = getLLMConfig();

const apiKey = apiKeyManager.pick(payload?.apiKey || CLOUDFLARE_API_KEY);
const accountID = CLOUDFLARE_ACCOUNT_ID;

return { accountID, apiKey };
}
}
};

Expand Down
5 changes: 5 additions & 0 deletions src/components/ModelProviderIcon/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import {
Anthropic,
Azure,
Bedrock,
Cloudflare,
DeepSeek,
Google,
Groq,
Expand Down Expand Up @@ -114,6 +115,10 @@ const ModelProviderIcon = memo<ModelProviderIconProps>(({ provider }) => {
return <Stepfun size={20} />;
}

case ModelProvider.Cloudflare: {
return <Cloudflare size={20} />;
}

default: {
return null;
}
Expand Down
8 changes: 8 additions & 0 deletions src/config/llm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,10 @@ export const getLLMConfig = () => {

ENABLED_STEPFUN: z.boolean(),
STEPFUN_API_KEY: z.string().optional(),

ENABLED_CLOUDFLARE: z.boolean(),
CLOUDFLARE_API_KEY: z.string().optional(),
CLOUDFLARE_ACCOUNT_ID: z.string().optional(),
},
runtimeEnv: {
API_KEY_SELECT_MODE: process.env.API_KEY_SELECT_MODE,
Expand Down Expand Up @@ -194,6 +198,10 @@ export const getLLMConfig = () => {

ENABLED_STEPFUN: !!process.env.STEPFUN_API_KEY,
STEPFUN_API_KEY: process.env.STEPFUN_API_KEY,

ENABLED_CLOUDFLARE: !!process.env.CLOUDFLARE_API_KEY && !!process.env.CLOUDFLARE_ACCOUNT_ID,
CLOUDFLARE_API_KEY: process.env.CLOUDFLARE_API_KEY,
CLOUDFLARE_ACCOUNT_ID: process.env.CLOUDFLARE_ACCOUNT_ID,
},
});
};
Expand Down
20 changes: 20 additions & 0 deletions src/config/modelProviders/cloudflare.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import { ModelProviderCard } from '@/types/llm';

// ref https://developers.cloudflare.com/workers-ai/models/#text-generation
// api https://developers.cloudflare.com/workers-ai/configuration/open-ai-compatibility
const Cloudflare: ModelProviderCard = {
chatModels: [
{
displayName: 'LLaMA2-7B-chat',
sxjeru marked this conversation as resolved.
Show resolved Hide resolved
enabled: true,
// functionCall: true,
id: '@cf/meta/llama-2-7b-chat-fp16',
tokens: 3072,
},
],
checkModel: '@cf/meta/llama-2-7b-chat-fp16',
id: 'cloudflare',
name: 'Cloudflare Workers AI',
};

export default Cloudflare;
4 changes: 4 additions & 0 deletions src/config/modelProviders/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import { ChatModelCard, ModelProviderCard } from '@/types/llm';
import AnthropicProvider from './anthropic';
import AzureProvider from './azure';
import BedrockProvider from './bedrock';
import CloudflareProvider from './cloudflare';
import DeepSeekProvider from './deepseek';
import GoogleProvider from './google';
import GroqProvider from './groq';
Expand Down Expand Up @@ -37,6 +38,7 @@ export const LOBE_DEFAULT_MODEL_LIST: ChatModelCard[] = [
AnthropicProvider.chatModels,
ZeroOneProvider.chatModels,
StepfunProvider.chatModels,
CloudflareProvider.chatModels,
].flat();

export const DEFAULT_MODEL_PROVIDER_LIST = [
Expand All @@ -58,6 +60,7 @@ export const DEFAULT_MODEL_PROVIDER_LIST = [
ZeroOneProvider,
ZhiPuProvider,
StepfunProvider,
CloudflareProvider,
];

export const filterEnabledModels = (provider: ModelProviderCard) => {
Expand All @@ -67,6 +70,7 @@ export const filterEnabledModels = (provider: ModelProviderCard) => {
export { default as AnthropicProviderCard } from './anthropic';
export { default as AzureProviderCard } from './azure';
export { default as BedrockProviderCard } from './bedrock';
export { default as CloudflareProviderCard } from './cloudflare';
export { default as DeepSeekProviderCard } from './deepseek';
export { default as GoogleProviderCard } from './google';
export { default as GroqProviderCard } from './groq';
Expand Down
5 changes: 5 additions & 0 deletions src/const/settings/llm.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import {
AnthropicProviderCard,
BedrockProviderCard,
CloudflareProviderCard,
DeepSeekProviderCard,
GoogleProviderCard,
GroqProviderCard,
Expand Down Expand Up @@ -33,6 +34,10 @@ export const DEFAULT_LLM_CONFIG: UserModelProviderConfig = {
enabled: false,
enabledModels: filterEnabledModels(BedrockProviderCard),
},
cloudflare: {
enabled: false,
enabledModels: filterEnabledModels(CloudflareProviderCard),
},
deepseek: {
enabled: false,
enabledModels: filterEnabledModels(DeepSeekProviderCard),
Expand Down
8 changes: 7 additions & 1 deletion src/libs/agent-runtime/AgentRuntime.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import { LobeRuntimeAI } from './BaseAI';
import { LobeAnthropicAI } from './anthropic';
import { LobeAzureOpenAI } from './azureOpenai';
import { LobeBedrockAI, LobeBedrockAIParams } from './bedrock';
import { LobeCloudflareAI, LobeCloudflareParams } from './cloudflare';
import { LobeDeepSeekAI } from './deepseek';
import { LobeGoogleAI } from './google';
import { LobeGroq } from './groq';
Expand Down Expand Up @@ -104,6 +105,7 @@ class AgentRuntime {
anthropic: Partial<ClientOptions>;
azure: { apiVersion?: string; apikey?: string; endpoint?: string };
bedrock: Partial<LobeBedrockAIParams>;
cloudflare: Partial<LobeCloudflareParams>;
deepseek: Partial<ClientOptions>;
google: { apiKey?: string; baseURL?: string };
groq: Partial<ClientOptions>;
Expand Down Expand Up @@ -219,8 +221,12 @@ class AgentRuntime {
runtimeModel = new LobeStepfunAI(params.stepfun ?? {});
break;
}
}

case ModelProvider.Cloudflare: {
runtimeModel = new LobeCloudflareAI(params.cloudflare ?? {});
Copy link
Contributor

@arvinxx arvinxx Jul 1, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

考虑下在这里做 accountId 的拼接传入。或者在 LobeOpenAICompatibleFactory 方法中加个 baseURL 允许是个入参函数的方法。 baseURL: (params)=> `https://api.cloudflare.com/client/v4/accounts/${params.accountId}/ai/v1`

个人感觉后者会更加泛化一些,未来有别的特殊 baseURL 也可以兼容

break;
}
}
return new AgentRuntime(runtimeModel);
}
}
Expand Down
15 changes: 15 additions & 0 deletions src/libs/agent-runtime/cloudflare/index.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import { ModelProvider } from '../types';
import { LobeOpenAICompatibleFactory } from '../utils/openaiCompatibleFactory';

export interface LobeCloudflareParams {
accountID?: string;
apiKey?: string;
}

export const LobeCloudflareAI = LobeOpenAICompatibleFactory({
baseURL: `https://api.cloudflare.com/client/v4/accounts/${process.env.CLOUDFLARE_ACCOUNT_ID}/ai/v1`,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里要用 llmEnv.XXX

Copy link
Contributor Author

@sxjeru sxjeru Jun 21, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

换成 llmEnv.CLOUDFLARE_ACCOUNT_ID 之后报下述错误:
`Error: ❌ Attempted to access a server-side environment variable on the client

请教一下有什么办法能让设置页的 Cloudflare Account ID 同步到此处 baseURL?
配置项写在 src/app/api/chat/agentRuntime.ts

case ModelProvider.Cloudflare: {
      const { CLOUDFLARE_API_KEY, CLOUDFLARE_ACCOUNT_ID } = getLLMConfig();
      const apiKey = apiKeyManager.pick(payload?.apiKey || CLOUDFLARE_API_KEY);
      const accountID = payload.apiKey && payload.cloudflareAccountID ? payload.cloudflareAccountID : CLOUDFLARE_ACCOUNT_ID;
      return { accountID, apiKey };

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

accountId 作为参数传给 server端。然后 server 端的 LobeCloudflareAI 初始化时候 baseURL 的逻辑变成 accountId 拼接的?

debug: {
chatCompletion: () => process.env.DEBUG_CLOUDFLARE_CHAT_COMPLETION === '1',
},
provider: ModelProvider.Cloudflare,
});
1 change: 1 addition & 0 deletions src/libs/agent-runtime/types/type.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ export enum ModelProvider {
Anthropic = 'anthropic',
Azure = 'azure',
Bedrock = 'bedrock',
Cloudflare = 'cloudflare',
DeepSeek = 'deepseek',
Google = 'google',
Groq = 'groq',
Expand Down
2 changes: 2 additions & 0 deletions src/server/globalConfig/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ export const getServerGlobalConfig = () => {
ENABLED_MISTRAL,
ENABLED_QWEN,
ENABLED_STEPFUN,
ENABLED_CLOUDFLARE,

ENABLED_AZURE_OPENAI,
AZURE_MODEL_LIST,
Expand Down Expand Up @@ -71,6 +72,7 @@ export const getServerGlobalConfig = () => {
}),
},
bedrock: { enabled: ENABLED_AWS_BEDROCK },
cloudflare: { enabled: ENABLED_CLOUDFLARE },
deepseek: { enabled: ENABLED_DEEPSEEK },
google: { enabled: ENABLED_GOOGLE },
groq: { enabled: ENABLED_GROQ },
Expand Down
6 changes: 6 additions & 0 deletions src/types/user/settings/keyVaults.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,16 @@ export interface AWSBedrockKeyVault {
secretAccessKey?: string;
}

export interface CloudflareKeyVault {
accountID?: string;
apiKey?: string;
}

export interface UserKeyVaults {
anthropic?: OpenAICompatibleKeyVault;
azure?: AzureOpenAIKeyVault;
bedrock?: AWSBedrockKeyVault;
cloudflare?: CloudflareKeyVault;
deepseek?: OpenAICompatibleKeyVault;
google?: OpenAICompatibleKeyVault;
groq?: OpenAICompatibleKeyVault;
Expand Down
Loading