Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Cloudflare official API, including model list. #38

Merged
merged 11 commits into from
Aug 1, 2024
11 changes: 5 additions & 6 deletions src/app/(main)/settings/llm/ProviderList/Cloudflare/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,13 @@ export const useCloudflareProvider = (): ProviderItem => {
},
{
children: (
<Input.Password
autoComplete={'new-password'}
placeholder={t(`${providerKey}.accountID.placeholder`)}
<Input
placeholder={t(`${providerKey}.baseURLOrAccountID.placeholder`)}
/>
),
desc: t(`${providerKey}.accountID.desc`),
label: t(`${providerKey}.accountID.title`),
name: [KeyVaultsConfigKey, providerKey, 'accountID'],
desc: t(`${providerKey}.baseURLOrAccountID.desc`),
label: t(`${providerKey}.baseURLOrAccountID.title`),
name: [KeyVaultsConfigKey, providerKey, 'baseURLOrAccountID'],
},
],
title: <CloudflareBrand />,
Expand Down
8 changes: 4 additions & 4 deletions src/app/api/chat/agentRuntime.ts
Original file line number Diff line number Diff line change
Expand Up @@ -190,12 +190,12 @@ const getLlmOptionsFromPayload = (provider: string, payload: JWTPayload) => {
const { CLOUDFLARE_API_KEY, CLOUDFLARE_ACCOUNT_ID } = getLLMConfig();

const apiKey = apiKeyManager.pick(payload?.apiKey || CLOUDFLARE_API_KEY);
const accountID =
payload.apiKey && payload.cloudflareAccountID
? payload.cloudflareAccountID
const baseURLOrAccountID =
payload.apiKey && payload.cloudflareBaseURLOrAccountID
? payload.cloudflareBaseURLOrAccountID
: CLOUDFLARE_ACCOUNT_ID;

return { accountID, apiKey };
return { apiKey, baseURLOrAccountID };
}
}
};
Expand Down
13 changes: 12 additions & 1 deletion src/config/modelProviders/cloudflare.ts
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,20 @@ const Cloudflare: ModelProviderCard = {
id: '@hf/thebloke/zephyr-7b-beta-awq',
tokens: 32_768,
},
{
description:
'Generation over generation, Meta Llama 3 demonstrates state-of-the-art performance on a wide range of industry benchmarks and offers new capabilities, including improved reasoning.\t',
displayName: 'meta-llama-3-8b-instruct',
enabled: true,
functionCall: false,
id: '@hf/meta-llama/meta-llama-3-8b-instruct',
BrandonStudio marked this conversation as resolved.
Show resolved Hide resolved
},
],
checkModel: '@hf/thebloke/deepseek-coder-6.7b-instruct-awq',
checkModel: '@hf/meta-llama/meta-llama-3-8b-instruct',
id: 'cloudflare',
modelList: {
showModelFetcher: true,
},
name: 'Cloudflare Workers AI',
};

Expand Down
2 changes: 1 addition & 1 deletion src/const/auth.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ export interface JWTPayload {
awsRegion?: string;
awsSecretAccessKey?: string;

cloudflareAccountID?: string;
cloudflareBaseURLOrAccountID?: string;
/**
* user id
* in client db mode it's a uuid
Expand Down
8 changes: 2 additions & 6 deletions src/libs/agent-runtime/AgentRuntime.ts
Original file line number Diff line number Diff line change
Expand Up @@ -233,15 +233,11 @@ class AgentRuntime {

case ModelProvider.Taichu: {
runtimeModel = new LobeTaichuAI(params.taichu ?? {});
break
break;
}

case ModelProvider.Cloudflare: {
const cloudflareParams = params.cloudflare ?? {};
runtimeModel = new LobeCloudflareAI({
...cloudflareParams,
baseURL: `https://api.cloudflare.com/client/v4/accounts/${params.cloudflare?.accountID}/ai/v1`,
});
runtimeModel = new LobeCloudflareAI(params.cloudflare ?? {});
break;
}
}
Expand Down
226 changes: 216 additions & 10 deletions src/libs/agent-runtime/cloudflare/index.ts
Original file line number Diff line number Diff line change
@@ -1,16 +1,222 @@
import { ChatModelCard } from '@/types/llm';

import { ModelProvider } from '../types';
import { LobeOpenAICompatibleFactory } from '../utils/openaiCompatibleFactory';
import { LobeRuntimeAI } from '../BaseAI';
import { AgentRuntimeErrorType } from '../error';
import { ChatCompetitionOptions, ChatStreamPayload, ModelProvider } from '../types';
import { AgentRuntimeError } from '../utils/createError';
import { desensitizeUrl } from '../utils/desensitizeUrl';
import { StreamingResponse } from '../utils/response';

const DEFAULT_BASE_URL_PREFIX = 'https://api.cloudflare.com';

export interface LobeCloudflareParams {
accountID?: string;
apiKey?: string;
baseURLOrAccountID?: string;
}

function fillUrl(accountID: string): string {
return `${DEFAULT_BASE_URL_PREFIX}/client/v4/accounts/${accountID}/ai/run/`;
}

function desensitizeAccountId(path: string): string {
return path.replace(/\/[\dA-Fa-f]{32}\//, '/****/');
}

function desensitizeCloudflareUrl(url: string): string {
const urlObj = new URL(url);
let { protocol, hostname, port, pathname, search } = urlObj;
if (url.startsWith(DEFAULT_BASE_URL_PREFIX)) {
return `${protocol}//${hostname}${port ? `:${port}` : ''}${desensitizeAccountId(pathname)}${search}`;
} else {
const desensitizedUrl = desensitizeUrl(`${protocol}//${hostname}${port ? `:${port}` : ''}`);
return `${desensitizedUrl}${desensitizeAccountId(pathname)}${search}`;
}
}

const CF_PROPERTY_NAME = 'property_id';

function getModelBeta(model: any): boolean {
try {
const betaProperty = model['properties'].filter(
(property: any) => property[CF_PROPERTY_NAME] === 'beta',
);
if (betaProperty.length === 1) {
// eslint-disable-next-line eqeqeq
return betaProperty[0]['value'] == true; // This is a string now.
}
return false;
} catch {
return false;
}
}

function getModelDisplayName(model: any, beta: boolean): string {
const modelId = model['name'];
let name = modelId.split('/').at(-1)!;
if (beta) {
name += ' (Beta)';
}
return name;
}

function getModelFunctionCalling(model: any): boolean {
return false;
// eslint-disable-next-line no-unreachable
try {
const fcProperty = model['properties'].filter(
(property: any) => property[CF_PROPERTY_NAME] === 'function_calling',
);
if (fcProperty.length === 1) {
// eslint-disable-next-line eqeqeq
return fcProperty[0]['value'] == true;
}
return false;
} catch {
return false;
}
}

function getModelTokens(model: any): number | undefined {
try {
const tokensProperty = model['properties'].filter(
(property: any) => property[CF_PROPERTY_NAME] === 'max_total_tokens',
);
if (tokensProperty.length === 1) {
return parseInt(tokensProperty[0]['value']);
}
return undefined;
} catch {
return undefined;
}
}

class CloudflareStreamTransformer {
private textDecoder = new TextDecoder();
private buffer: string = '';

private parseChunk(chunk: string, controller: TransformStreamDefaultController) {
const dataPrefix = /^data: /;
const json = chunk.replace(dataPrefix, '');
const parsedChunk = JSON.parse(json);
controller.enqueue(`event: text\n`);
controller.enqueue(`data: ${JSON.stringify(parsedChunk.response)}\n\n`);
}

public async transform(chunk: Uint8Array, controller: TransformStreamDefaultController) {
let textChunk = this.textDecoder.decode(chunk);
if (this.buffer.trim() !== '') {
textChunk = this.buffer + textChunk;
this.buffer = '';
}
const splits = textChunk.split('\n\n');
for (let i = 0; i < splits.length - 1; i++) {
if (/\[DONE]/.test(splits[i].trim())) {
return;
}
this.parseChunk(splits[i], controller);
}
const lastChunk = splits.at(-1)!;
if (lastChunk.trim() !== '') {
this.buffer += lastChunk; // does not need to be trimmed.
} // else drop.
}
}

export const LobeCloudflareAI = LobeOpenAICompatibleFactory({
baseURL: `https://api.cloudflare.com/client/v4/accounts/${process.env.CLOUDFLARE_ACCOUNT_ID}/ai/v1`,
debug: {
chatCompletion: () => process.env.DEBUG_CLOUDFLARE_CHAT_COMPLETION === '1',
},
provider: ModelProvider.Cloudflare,
});
export class LobeCloudflareAI implements LobeRuntimeAI {
baseURL: string;
accountID: string;
apiKey?: string;

constructor({ apiKey, baseURLOrAccountID }: LobeCloudflareParams) {
if (!baseURLOrAccountID) {
throw AgentRuntimeError.createError(AgentRuntimeErrorType.InvalidProviderAPIKey);
}
if (baseURLOrAccountID.startsWith('http')) {
this.baseURL = baseURLOrAccountID.endsWith('/')
? baseURLOrAccountID
: baseURLOrAccountID + '/';
// Try get accountID from baseURL
this.accountID = baseURLOrAccountID.replaceAll(/^.*\/([\dA-Fa-f]{32})\/.*$/g, '$1');
} else {
this.accountID = baseURLOrAccountID;
this.baseURL = fillUrl(baseURLOrAccountID);
}
this.apiKey = apiKey;
}

async chat(payload: ChatStreamPayload, options?: ChatCompetitionOptions): Promise<Response> {
// Implement your logic here
// This method should handle the chat functionality using the provided payload and options
// It should return a Promise that resolves to a Response object
// You can make API calls, perform computations, or any other necessary operations

// Example implementation:
try {
const { model, tools, ...restPayload } = payload;
const functions = tools?.map((tool) => tool.function);
const headers = options?.headers || {};
if (this.apiKey) {
headers['Authorization'] = `Bearer ${this.apiKey}`;
}
const url = new URL(model, this.baseURL);
const response = await fetch(url, {
body: JSON.stringify({ tools: functions, ...restPayload }),
headers: { 'Content-Type': 'application/json', ...headers },
method: 'POST',
});

const desensitizedEndpoint = desensitizeCloudflareUrl(this.baseURL);

switch (response.status) {
case 400: {
throw AgentRuntimeError.chat({
endpoint: desensitizedEndpoint,
error: response,
errorType: AgentRuntimeErrorType.ProviderBizError,
provider: ModelProvider.Cloudflare,
});
}
}

return StreamingResponse(
response.body!.pipeThrough(new TransformStream(new CloudflareStreamTransformer())),
);
} catch (error) {
const desensitizedEndpoint = desensitizeCloudflareUrl(this.baseURL);

throw AgentRuntimeError.chat({
endpoint: desensitizedEndpoint,
error: error as any,
errorType: AgentRuntimeErrorType.ProviderBizError,
provider: ModelProvider.Cloudflare,
});
}
}

async models(): Promise<ChatModelCard[]> {
const url = `${DEFAULT_BASE_URL_PREFIX}/client/v4/accounts/${this.accountID}/ai/models/search`;
const response = await fetch(url, {
headers: {
'Authorization': `Bearer ${this.apiKey}`,
'Content-Type': 'application/json',
},
method: 'GET',
});
const j = await response.json();
const models: any[] = j['result'].filter(
(model: any) => model['task']['name'] === 'Text Generation',
);
const chatModels: ChatModelCard[] = models.map((model) => {
const modelBeta = getModelBeta(model);
return {
description: model['description'],
displayName: getModelDisplayName(model, modelBeta),
enabled: !modelBeta,
functionCall: getModelFunctionCalling(model),
id: model['name'],
tokens: getModelTokens(model),
};
});
return chatModels;
}
}
10 changes: 5 additions & 5 deletions src/locales/default/modelProvider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -47,16 +47,16 @@ export default {
},
},
cloudflare: {
accountID: {
desc: '请填写 Cloudflare 账户 ID',
placeholder: 'Cloudflare Account ID',
title: 'Cloudflare Account ID',
},
apiKey: {
desc: '请填写 Cloudflare API Key',
placeholder: 'Cloudflare API Key',
title: 'Cloudflare API Key',
},
baseURLOrAccountID: {
desc: '填入 Cloudflare 账户 ID 或 自定义 API 地址',
placeholder: 'Cloudflare Account ID / custom API URL',
title: 'Cloudflare 账户 ID / API 地址',
},
},
ollama: {
checker: {
Expand Down
6 changes: 3 additions & 3 deletions src/services/_auth.ts
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ export const getProviderAuthPayload = (provider: string) => {
case ModelProvider.Cloudflare: {
const config = keyVaultsConfigSelectors.cloudflareConfig(useUserStore.getState());

return {
apiKey: config?.apiKey,
cloudflareAccountID: config?.accountID,
return {
apiKey: config?.apiKey,
cloudflareBaseURLOrAccountID: config?.baseURLOrAccountID,
};
}

Expand Down
2 changes: 1 addition & 1 deletion src/services/chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -160,8 +160,8 @@ export function initializeWithClientStore(provider: string, payload: any) {
}
case ModelProvider.Cloudflare: {
providerOptions = {
accountID: providerAuthPayload?.cloudflareAccountID,
apikey: providerAuthPayload?.apiKey,
baseURLOrAccountID: providerAuthPayload?.cloudflareBaseURLOrAccountID,
};
break;
}
Expand Down
4 changes: 4 additions & 0 deletions src/store/user/slices/modelList/action.ts
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,10 @@ export const createModelListSlice: StateCreator<

const togetherai = draft.find((d) => d.id === ModelProvider.TogetherAI);
if (togetherai) togetherai.chatModels = mergeModels('togetherai', togetherai.chatModels);

const cloudflare = draft.find((d) => d.id === ModelProvider.Cloudflare);
if (cloudflare)
cloudflare.chatModels = mergeModels('cloudflare', cloudflare.chatModels);
});

set({ defaultModelProviderList }, false, `refreshDefaultModelList - ${params?.trigger}`);
Expand Down
2 changes: 1 addition & 1 deletion src/types/user/settings/keyVaults.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ export interface AWSBedrockKeyVault {
}

export interface CloudflareKeyVault {
accountID?: string;
apiKey?: string;
baseURLOrAccountID?: string;
}

export interface UserKeyVaults {
Expand Down