From 9f211e26be2bbe1f01506c61d371c7ce44c17150 Mon Sep 17 00:00:00 2001 From: Arvin Xu Date: Thu, 11 Jul 2024 22:32:02 +0800 Subject: [PATCH] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20refactor:=20improve=20agen?= =?UTF-8?q?t=20runtime=20code=20(#3199)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * ♻️ refactor: improve agent runtime code * ♻️ refactor: clean code --- src/app/api/chat/agentRuntime.ts | 9 +- src/libs/agent-runtime/AgentRuntime.ts | 40 ++--- src/libs/agent-runtime/anthropic/index.ts | 2 +- src/libs/agent-runtime/bedrock/index.ts | 2 +- src/libs/agent-runtime/google/index.ts | 2 +- src/libs/agent-runtime/minimax/index.ts | 2 +- src/libs/agent-runtime/ollama/index.ts | 2 +- .../utils/openaiCompatibleFactory/index.ts | 151 +++++++----------- src/libs/agent-runtime/zhipu/index.ts | 2 +- 9 files changed, 94 insertions(+), 118 deletions(-) diff --git a/src/app/api/chat/agentRuntime.ts b/src/app/api/chat/agentRuntime.ts index 3496cc674a39..6dbd8f9c73df 100644 --- a/src/app/api/chat/agentRuntime.ts +++ b/src/app/api/chat/agentRuntime.ts @@ -193,11 +193,16 @@ const getLlmOptionsFromPayload = (provider: string, payload: JWTPayload) => { * Initializes the agent runtime with the user payload in backend * @param provider - The provider name. * @param payload - The JWT payload. + * @param params * @returns A promise that resolves when the agent runtime is initialized. */ -export const initAgentRuntimeWithUserPayload = (provider: string, payload: JWTPayload) => { +export const initAgentRuntimeWithUserPayload = ( + provider: string, + payload: JWTPayload, + params: any = {}, +) => { return AgentRuntime.initializeWithProviderOptions(provider, { - [provider]: getLlmOptionsFromPayload(provider, payload), + [provider]: { ...getLlmOptionsFromPayload(provider, payload), ...params }, }); }; diff --git a/src/libs/agent-runtime/AgentRuntime.ts b/src/libs/agent-runtime/AgentRuntime.ts index e8b986f4678c..63ded027518e 100644 --- a/src/libs/agent-runtime/AgentRuntime.ts +++ b/src/libs/agent-runtime/AgentRuntime.ts @@ -145,93 +145,93 @@ class AgentRuntime { } case ModelProvider.ZhiPu: { - runtimeModel = await LobeZhipuAI.fromAPIKey(params.zhipu ?? {}); + runtimeModel = await LobeZhipuAI.fromAPIKey(params.zhipu); break; } case ModelProvider.Google: { - runtimeModel = new LobeGoogleAI(params.google ?? {}); + runtimeModel = new LobeGoogleAI(params.google); break; } case ModelProvider.Moonshot: { - runtimeModel = new LobeMoonshotAI(params.moonshot ?? {}); + runtimeModel = new LobeMoonshotAI(params.moonshot); break; } case ModelProvider.Bedrock: { - runtimeModel = new LobeBedrockAI(params.bedrock ?? {}); + runtimeModel = new LobeBedrockAI(params.bedrock); break; } case ModelProvider.Ollama: { - runtimeModel = new LobeOllamaAI(params.ollama ?? {}); + runtimeModel = new LobeOllamaAI(params.ollama); break; } case ModelProvider.Perplexity: { - runtimeModel = new LobePerplexityAI(params.perplexity ?? {}); + runtimeModel = new LobePerplexityAI(params.perplexity); break; } case ModelProvider.Anthropic: { - runtimeModel = new LobeAnthropicAI(params.anthropic ?? {}); + runtimeModel = new LobeAnthropicAI(params.anthropic); break; } case ModelProvider.DeepSeek: { - runtimeModel = new LobeDeepSeekAI(params.deepseek ?? {}); + runtimeModel = new LobeDeepSeekAI(params.deepseek); break; } case ModelProvider.Minimax: { - runtimeModel = new LobeMinimaxAI(params.minimax ?? {}); + runtimeModel = new LobeMinimaxAI(params.minimax); break; } case ModelProvider.Mistral: { - runtimeModel = new LobeMistralAI(params.mistral ?? {}); + runtimeModel = new LobeMistralAI(params.mistral); break; } case ModelProvider.Groq: { - runtimeModel = new LobeGroq(params.groq ?? {}); + runtimeModel = new LobeGroq(params.groq); break; } case ModelProvider.OpenRouter: { - runtimeModel = new LobeOpenRouterAI(params.openrouter ?? {}); + runtimeModel = new LobeOpenRouterAI(params.openrouter); break; } case ModelProvider.TogetherAI: { - runtimeModel = new LobeTogetherAI(params.togetherai ?? {}); + runtimeModel = new LobeTogetherAI(params.togetherai); break; } case ModelProvider.ZeroOne: { - runtimeModel = new LobeZeroOneAI(params.zeroone ?? {}); + runtimeModel = new LobeZeroOneAI(params.zeroone); break; } case ModelProvider.Qwen: { - runtimeModel = new LobeQwenAI(params.qwen ?? {}); + runtimeModel = new LobeQwenAI(params.qwen); break; } case ModelProvider.Stepfun: { - runtimeModel = new LobeStepfunAI(params.stepfun ?? {}); + runtimeModel = new LobeStepfunAI(params.stepfun); break; } case ModelProvider.Baichuan: { - runtimeModel = new LobeBaichuanAI(params.baichuan ?? {}); - break + runtimeModel = new LobeBaichuanAI(params.baichuan); + break; } case ModelProvider.Taichu: { - runtimeModel = new LobeTaichuAI(params.taichu ?? {}); - break + runtimeModel = new LobeTaichuAI(params.taichu); + break; } } diff --git a/src/libs/agent-runtime/anthropic/index.ts b/src/libs/agent-runtime/anthropic/index.ts index 35e9af365d3b..0f02bf9700aa 100644 --- a/src/libs/agent-runtime/anthropic/index.ts +++ b/src/libs/agent-runtime/anthropic/index.ts @@ -20,7 +20,7 @@ export class LobeAnthropicAI implements LobeRuntimeAI { baseURL: string; - constructor({ apiKey, baseURL = DEFAULT_BASE_URL }: ClientOptions) { + constructor({ apiKey, baseURL = DEFAULT_BASE_URL }: ClientOptions = {}) { if (!apiKey) throw AgentRuntimeError.createError(AgentRuntimeErrorType.InvalidProviderAPIKey); this.client = new Anthropic({ apiKey, baseURL }); diff --git a/src/libs/agent-runtime/bedrock/index.ts b/src/libs/agent-runtime/bedrock/index.ts index 6a8cb4df13d8..1e1b7e6db729 100644 --- a/src/libs/agent-runtime/bedrock/index.ts +++ b/src/libs/agent-runtime/bedrock/index.ts @@ -28,7 +28,7 @@ export class LobeBedrockAI implements LobeRuntimeAI { region: string; - constructor({ region, accessKeyId, accessKeySecret }: LobeBedrockAIParams) { + constructor({ region, accessKeyId, accessKeySecret }: LobeBedrockAIParams = {}) { if (!(accessKeyId && accessKeySecret)) throw AgentRuntimeError.createError(AgentRuntimeErrorType.InvalidBedrockCredentials); diff --git a/src/libs/agent-runtime/google/index.ts b/src/libs/agent-runtime/google/index.ts index e15c925f3346..5659a1a09632 100644 --- a/src/libs/agent-runtime/google/index.ts +++ b/src/libs/agent-runtime/google/index.ts @@ -41,7 +41,7 @@ export class LobeGoogleAI implements LobeRuntimeAI { private client: GoogleGenerativeAI; baseURL?: string; - constructor({ apiKey, baseURL }: { apiKey?: string; baseURL?: string }) { + constructor({ apiKey, baseURL }: { apiKey?: string; baseURL?: string } = {}) { if (!apiKey) throw AgentRuntimeError.createError(AgentRuntimeErrorType.InvalidProviderAPIKey); this.client = new GoogleGenerativeAI(apiKey); diff --git a/src/libs/agent-runtime/minimax/index.ts b/src/libs/agent-runtime/minimax/index.ts index 659b8cf71c37..1f8964f96f11 100644 --- a/src/libs/agent-runtime/minimax/index.ts +++ b/src/libs/agent-runtime/minimax/index.ts @@ -63,7 +63,7 @@ function parseMinimaxResponse(chunk: string): MinimaxResponse | undefined { export class LobeMinimaxAI implements LobeRuntimeAI { apiKey: string; - constructor({ apiKey }: { apiKey?: string }) { + constructor({ apiKey }: { apiKey?: string } = {}) { if (!apiKey) throw AgentRuntimeError.createError(AgentRuntimeErrorType.InvalidProviderAPIKey); this.apiKey = apiKey; diff --git a/src/libs/agent-runtime/ollama/index.ts b/src/libs/agent-runtime/ollama/index.ts index 80a47ad48423..cd6804cefda7 100644 --- a/src/libs/agent-runtime/ollama/index.ts +++ b/src/libs/agent-runtime/ollama/index.ts @@ -18,7 +18,7 @@ export class LobeOllamaAI implements LobeRuntimeAI { baseURL?: string; - constructor({ baseURL }: ClientOptions) { + constructor({ baseURL }: ClientOptions = {}) { try { if (baseURL) new URL(baseURL); } catch { diff --git a/src/libs/agent-runtime/utils/openaiCompatibleFactory/index.ts b/src/libs/agent-runtime/utils/openaiCompatibleFactory/index.ts index e3c1d19c2196..96d952974b81 100644 --- a/src/libs/agent-runtime/utils/openaiCompatibleFactory/index.ts +++ b/src/libs/agent-runtime/utils/openaiCompatibleFactory/index.ts @@ -27,13 +27,18 @@ const CHAT_MODELS_BLOCK_LIST = [ 'dall-e', ]; -interface OpenAICompatibleFactoryOptions { +type ConstructorOptions = any> = ClientOptions & T; + +interface OpenAICompatibleFactoryOptions = any> { baseURL?: string; chatCompletion?: { - handleError?: (error: any) => Omit | undefined; + handleError?: ( + error: any, + options: ConstructorOptions, + ) => Omit | undefined; handlePayload?: (payload: ChatStreamPayload) => OpenAI.ChatCompletionCreateParamsStreaming; }; - constructorOptions?: ClientOptions; + constructorOptions?: ConstructorOptions; debug?: { chatCompletion: () => boolean; }; @@ -49,7 +54,7 @@ interface OpenAICompatibleFactoryOptions { provider: string; } -export const LobeOpenAICompatibleFactory = ({ +export const LobeOpenAICompatibleFactory = = any>({ provider, baseURL: DEFAULT_BASE_URL, errorType, @@ -57,7 +62,7 @@ export const LobeOpenAICompatibleFactory = ({ constructorOptions, chatCompletion, models, -}: OpenAICompatibleFactoryOptions) => { +}: OpenAICompatibleFactoryOptions) => { const ErrorType = { bizError: errorType?.bizError || AgentRuntimeErrorType.ProviderBizError, invalidAPIKey: errorType?.invalidAPIKey || AgentRuntimeErrorType.InvalidProviderAPIKey, @@ -67,8 +72,11 @@ export const LobeOpenAICompatibleFactory = ({ client: OpenAI; baseURL: string; + private _options: ConstructorOptions; - constructor({ apiKey, baseURL = DEFAULT_BASE_URL, ...res }: ClientOptions) { + constructor(options: ClientOptions & Record = {}) { + const { apiKey, baseURL = DEFAULT_BASE_URL, ...res } = options; + this._options = options as ConstructorOptions; if (!apiKey) throw AgentRuntimeError.createError(ErrorType?.invalidAPIKey); this.client = new OpenAI({ apiKey, baseURL, ...constructorOptions, ...res }); @@ -115,48 +123,7 @@ export const LobeOpenAICompatibleFactory = ({ headers: options?.headers, }); } catch (error) { - let desensitizedEndpoint = this.baseURL; - - // refs: https://github.com/lobehub/lobe-chat/issues/842 - if (this.baseURL !== DEFAULT_BASE_URL) { - desensitizedEndpoint = desensitizeUrl(this.baseURL); - } - - if ('status' in (error as any)) { - switch ((error as Response).status) { - case 401: { - throw AgentRuntimeError.chat({ - endpoint: desensitizedEndpoint, - error: error as any, - errorType: ErrorType.invalidAPIKey, - provider: provider as any, - }); - } - - default: { - break; - } - } - } - - if (chatCompletion?.handleError) { - const errorResult = chatCompletion.handleError(error); - - if (errorResult) - throw AgentRuntimeError.chat({ - ...errorResult, - provider, - } as ChatCompletionErrorPayload); - } - - const { errorResult, RuntimeError } = handleOpenAIError(error); - - throw AgentRuntimeError.chat({ - endpoint: desensitizedEndpoint, - error: errorResult, - errorType: RuntimeError || ErrorType.bizError, - provider: provider as any, - }); + throw this.handleError(error); } } @@ -191,48 +158,7 @@ export const LobeOpenAICompatibleFactory = ({ const res = await this.client.images.generate(payload); return res.data.map((o) => o.url) as string[]; } catch (error) { - let desensitizedEndpoint = this.baseURL; - - // refs: https://github.com/lobehub/lobe-chat/issues/842 - if (this.baseURL !== DEFAULT_BASE_URL) { - desensitizedEndpoint = desensitizeUrl(this.baseURL); - } - - if ('status' in (error as any)) { - switch ((error as Response).status) { - case 401: { - throw AgentRuntimeError.chat({ - endpoint: desensitizedEndpoint, - error: error as any, - errorType: ErrorType.invalidAPIKey, - provider: provider as any, - }); - } - - default: { - break; - } - } - } - - if (chatCompletion?.handleError) { - const errorResult = chatCompletion.handleError(error); - - if (errorResult) - throw AgentRuntimeError.chat({ - ...errorResult, - provider, - } as ChatCompletionErrorPayload); - } - - const { errorResult, RuntimeError } = handleOpenAIError(error); - - throw AgentRuntimeError.chat({ - endpoint: desensitizedEndpoint, - error: errorResult, - errorType: RuntimeError || ErrorType.bizError, - provider: provider as any, - }); + throw this.handleError(error); } } @@ -289,5 +215,50 @@ export const LobeOpenAICompatibleFactory = ({ }, }); } + + private handleError(error: any): ChatCompletionErrorPayload { + let desensitizedEndpoint = this.baseURL; + + // refs: https://github.com/lobehub/lobe-chat/issues/842 + if (this.baseURL !== DEFAULT_BASE_URL) { + desensitizedEndpoint = desensitizeUrl(this.baseURL); + } + + if (chatCompletion?.handleError) { + const errorResult = chatCompletion.handleError(error, this._options); + + if (errorResult) + return AgentRuntimeError.chat({ + ...errorResult, + provider, + } as ChatCompletionErrorPayload); + } + + if ('status' in (error as any)) { + switch ((error as Response).status) { + case 401: { + return AgentRuntimeError.chat({ + endpoint: desensitizedEndpoint, + error: error as any, + errorType: ErrorType.invalidAPIKey, + provider: provider as any, + }); + } + + default: { + break; + } + } + } + + const { errorResult, RuntimeError } = handleOpenAIError(error); + + return AgentRuntimeError.chat({ + endpoint: desensitizedEndpoint, + error: errorResult, + errorType: RuntimeError || ErrorType.bizError, + provider: provider as any, + }); + } }; }; diff --git a/src/libs/agent-runtime/zhipu/index.ts b/src/libs/agent-runtime/zhipu/index.ts index dd4aa653cbd6..5d61c120a49e 100644 --- a/src/libs/agent-runtime/zhipu/index.ts +++ b/src/libs/agent-runtime/zhipu/index.ts @@ -29,7 +29,7 @@ export class LobeZhipuAI implements LobeRuntimeAI { this.baseURL = this.client.baseURL; } - static async fromAPIKey({ apiKey, baseURL = DEFAULT_BASE_URL, ...res }: ClientOptions) { + static async fromAPIKey({ apiKey, baseURL = DEFAULT_BASE_URL, ...res }: ClientOptions = {}) { const invalidZhipuAPIKey = AgentRuntimeError.createError( AgentRuntimeErrorType.InvalidProviderAPIKey, );