Skip to content

Commit

Permalink
♻️ refactor: improve agent runtime code (#3199)
Browse files Browse the repository at this point in the history
* ♻️ refactor: improve agent runtime code

* ♻️ refactor: clean code
  • Loading branch information
arvinxx authored Jul 11, 2024
1 parent 5e8780d commit 9f211e2
Show file tree
Hide file tree
Showing 9 changed files with 94 additions and 118 deletions.
9 changes: 7 additions & 2 deletions src/app/api/chat/agentRuntime.ts
Original file line number Diff line number Diff line change
Expand Up @@ -193,11 +193,16 @@ const getLlmOptionsFromPayload = (provider: string, payload: JWTPayload) => {
* Initializes the agent runtime with the user payload in backend
* @param provider - The provider name.
* @param payload - The JWT payload.
* @param params
* @returns A promise that resolves when the agent runtime is initialized.
*/
export const initAgentRuntimeWithUserPayload = (provider: string, payload: JWTPayload) => {
export const initAgentRuntimeWithUserPayload = (
provider: string,
payload: JWTPayload,
params: any = {},
) => {
return AgentRuntime.initializeWithProviderOptions(provider, {
[provider]: getLlmOptionsFromPayload(provider, payload),
[provider]: { ...getLlmOptionsFromPayload(provider, payload), ...params },
});
};

Expand Down
40 changes: 20 additions & 20 deletions src/libs/agent-runtime/AgentRuntime.ts
Original file line number Diff line number Diff line change
Expand Up @@ -145,93 +145,93 @@ class AgentRuntime {
}

case ModelProvider.ZhiPu: {
runtimeModel = await LobeZhipuAI.fromAPIKey(params.zhipu ?? {});
runtimeModel = await LobeZhipuAI.fromAPIKey(params.zhipu);
break;
}

case ModelProvider.Google: {
runtimeModel = new LobeGoogleAI(params.google ?? {});
runtimeModel = new LobeGoogleAI(params.google);
break;
}

case ModelProvider.Moonshot: {
runtimeModel = new LobeMoonshotAI(params.moonshot ?? {});
runtimeModel = new LobeMoonshotAI(params.moonshot);
break;
}

case ModelProvider.Bedrock: {
runtimeModel = new LobeBedrockAI(params.bedrock ?? {});
runtimeModel = new LobeBedrockAI(params.bedrock);
break;
}

case ModelProvider.Ollama: {
runtimeModel = new LobeOllamaAI(params.ollama ?? {});
runtimeModel = new LobeOllamaAI(params.ollama);
break;
}

case ModelProvider.Perplexity: {
runtimeModel = new LobePerplexityAI(params.perplexity ?? {});
runtimeModel = new LobePerplexityAI(params.perplexity);
break;
}

case ModelProvider.Anthropic: {
runtimeModel = new LobeAnthropicAI(params.anthropic ?? {});
runtimeModel = new LobeAnthropicAI(params.anthropic);
break;
}

case ModelProvider.DeepSeek: {
runtimeModel = new LobeDeepSeekAI(params.deepseek ?? {});
runtimeModel = new LobeDeepSeekAI(params.deepseek);
break;
}

case ModelProvider.Minimax: {
runtimeModel = new LobeMinimaxAI(params.minimax ?? {});
runtimeModel = new LobeMinimaxAI(params.minimax);
break;
}

case ModelProvider.Mistral: {
runtimeModel = new LobeMistralAI(params.mistral ?? {});
runtimeModel = new LobeMistralAI(params.mistral);
break;
}

case ModelProvider.Groq: {
runtimeModel = new LobeGroq(params.groq ?? {});
runtimeModel = new LobeGroq(params.groq);
break;
}

case ModelProvider.OpenRouter: {
runtimeModel = new LobeOpenRouterAI(params.openrouter ?? {});
runtimeModel = new LobeOpenRouterAI(params.openrouter);
break;
}

case ModelProvider.TogetherAI: {
runtimeModel = new LobeTogetherAI(params.togetherai ?? {});
runtimeModel = new LobeTogetherAI(params.togetherai);
break;
}

case ModelProvider.ZeroOne: {
runtimeModel = new LobeZeroOneAI(params.zeroone ?? {});
runtimeModel = new LobeZeroOneAI(params.zeroone);
break;
}

case ModelProvider.Qwen: {
runtimeModel = new LobeQwenAI(params.qwen ?? {});
runtimeModel = new LobeQwenAI(params.qwen);
break;
}

case ModelProvider.Stepfun: {
runtimeModel = new LobeStepfunAI(params.stepfun ?? {});
runtimeModel = new LobeStepfunAI(params.stepfun);
break;
}

case ModelProvider.Baichuan: {
runtimeModel = new LobeBaichuanAI(params.baichuan ?? {});
break
runtimeModel = new LobeBaichuanAI(params.baichuan);
break;
}

case ModelProvider.Taichu: {
runtimeModel = new LobeTaichuAI(params.taichu ?? {});
break
runtimeModel = new LobeTaichuAI(params.taichu);
break;
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/libs/agent-runtime/anthropic/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ export class LobeAnthropicAI implements LobeRuntimeAI {

baseURL: string;

constructor({ apiKey, baseURL = DEFAULT_BASE_URL }: ClientOptions) {
constructor({ apiKey, baseURL = DEFAULT_BASE_URL }: ClientOptions = {}) {
if (!apiKey) throw AgentRuntimeError.createError(AgentRuntimeErrorType.InvalidProviderAPIKey);

this.client = new Anthropic({ apiKey, baseURL });
Expand Down
2 changes: 1 addition & 1 deletion src/libs/agent-runtime/bedrock/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ export class LobeBedrockAI implements LobeRuntimeAI {

region: string;

constructor({ region, accessKeyId, accessKeySecret }: LobeBedrockAIParams) {
constructor({ region, accessKeyId, accessKeySecret }: LobeBedrockAIParams = {}) {
if (!(accessKeyId && accessKeySecret))
throw AgentRuntimeError.createError(AgentRuntimeErrorType.InvalidBedrockCredentials);

Expand Down
2 changes: 1 addition & 1 deletion src/libs/agent-runtime/google/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ export class LobeGoogleAI implements LobeRuntimeAI {
private client: GoogleGenerativeAI;
baseURL?: string;

constructor({ apiKey, baseURL }: { apiKey?: string; baseURL?: string }) {
constructor({ apiKey, baseURL }: { apiKey?: string; baseURL?: string } = {}) {
if (!apiKey) throw AgentRuntimeError.createError(AgentRuntimeErrorType.InvalidProviderAPIKey);

this.client = new GoogleGenerativeAI(apiKey);
Expand Down
2 changes: 1 addition & 1 deletion src/libs/agent-runtime/minimax/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ function parseMinimaxResponse(chunk: string): MinimaxResponse | undefined {
export class LobeMinimaxAI implements LobeRuntimeAI {
apiKey: string;

constructor({ apiKey }: { apiKey?: string }) {
constructor({ apiKey }: { apiKey?: string } = {}) {
if (!apiKey) throw AgentRuntimeError.createError(AgentRuntimeErrorType.InvalidProviderAPIKey);

this.apiKey = apiKey;
Expand Down
2 changes: 1 addition & 1 deletion src/libs/agent-runtime/ollama/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ export class LobeOllamaAI implements LobeRuntimeAI {

baseURL?: string;

constructor({ baseURL }: ClientOptions) {
constructor({ baseURL }: ClientOptions = {}) {
try {
if (baseURL) new URL(baseURL);
} catch {
Expand Down
151 changes: 61 additions & 90 deletions src/libs/agent-runtime/utils/openaiCompatibleFactory/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,18 @@ const CHAT_MODELS_BLOCK_LIST = [
'dall-e',
];

interface OpenAICompatibleFactoryOptions {
type ConstructorOptions<T extends Record<string, any> = any> = ClientOptions & T;

interface OpenAICompatibleFactoryOptions<T extends Record<string, any> = any> {
baseURL?: string;
chatCompletion?: {
handleError?: (error: any) => Omit<ChatCompletionErrorPayload, 'provider'> | undefined;
handleError?: (
error: any,
options: ConstructorOptions<T>,
) => Omit<ChatCompletionErrorPayload, 'provider'> | undefined;
handlePayload?: (payload: ChatStreamPayload) => OpenAI.ChatCompletionCreateParamsStreaming;
};
constructorOptions?: ClientOptions;
constructorOptions?: ConstructorOptions<T>;
debug?: {
chatCompletion: () => boolean;
};
Expand All @@ -49,15 +54,15 @@ interface OpenAICompatibleFactoryOptions {
provider: string;
}

export const LobeOpenAICompatibleFactory = ({
export const LobeOpenAICompatibleFactory = <T extends Record<string, any> = any>({
provider,
baseURL: DEFAULT_BASE_URL,
errorType,
debug,
constructorOptions,
chatCompletion,
models,
}: OpenAICompatibleFactoryOptions) => {
}: OpenAICompatibleFactoryOptions<T>) => {
const ErrorType = {
bizError: errorType?.bizError || AgentRuntimeErrorType.ProviderBizError,
invalidAPIKey: errorType?.invalidAPIKey || AgentRuntimeErrorType.InvalidProviderAPIKey,
Expand All @@ -67,8 +72,11 @@ export const LobeOpenAICompatibleFactory = ({
client: OpenAI;

baseURL: string;
private _options: ConstructorOptions<T>;

constructor({ apiKey, baseURL = DEFAULT_BASE_URL, ...res }: ClientOptions) {
constructor(options: ClientOptions & Record<string, any> = {}) {
const { apiKey, baseURL = DEFAULT_BASE_URL, ...res } = options;
this._options = options as ConstructorOptions<T>;
if (!apiKey) throw AgentRuntimeError.createError(ErrorType?.invalidAPIKey);

this.client = new OpenAI({ apiKey, baseURL, ...constructorOptions, ...res });
Expand Down Expand Up @@ -115,48 +123,7 @@ export const LobeOpenAICompatibleFactory = ({
headers: options?.headers,
});
} catch (error) {
let desensitizedEndpoint = this.baseURL;

// refs: https://github.com/lobehub/lobe-chat/issues/842
if (this.baseURL !== DEFAULT_BASE_URL) {
desensitizedEndpoint = desensitizeUrl(this.baseURL);
}

if ('status' in (error as any)) {
switch ((error as Response).status) {
case 401: {
throw AgentRuntimeError.chat({
endpoint: desensitizedEndpoint,
error: error as any,
errorType: ErrorType.invalidAPIKey,
provider: provider as any,
});
}

default: {
break;
}
}
}

if (chatCompletion?.handleError) {
const errorResult = chatCompletion.handleError(error);

if (errorResult)
throw AgentRuntimeError.chat({
...errorResult,
provider,
} as ChatCompletionErrorPayload);
}

const { errorResult, RuntimeError } = handleOpenAIError(error);

throw AgentRuntimeError.chat({
endpoint: desensitizedEndpoint,
error: errorResult,
errorType: RuntimeError || ErrorType.bizError,
provider: provider as any,
});
throw this.handleError(error);
}
}

Expand Down Expand Up @@ -191,48 +158,7 @@ export const LobeOpenAICompatibleFactory = ({
const res = await this.client.images.generate(payload);
return res.data.map((o) => o.url) as string[];
} catch (error) {
let desensitizedEndpoint = this.baseURL;

// refs: https://github.com/lobehub/lobe-chat/issues/842
if (this.baseURL !== DEFAULT_BASE_URL) {
desensitizedEndpoint = desensitizeUrl(this.baseURL);
}

if ('status' in (error as any)) {
switch ((error as Response).status) {
case 401: {
throw AgentRuntimeError.chat({
endpoint: desensitizedEndpoint,
error: error as any,
errorType: ErrorType.invalidAPIKey,
provider: provider as any,
});
}

default: {
break;
}
}
}

if (chatCompletion?.handleError) {
const errorResult = chatCompletion.handleError(error);

if (errorResult)
throw AgentRuntimeError.chat({
...errorResult,
provider,
} as ChatCompletionErrorPayload);
}

const { errorResult, RuntimeError } = handleOpenAIError(error);

throw AgentRuntimeError.chat({
endpoint: desensitizedEndpoint,
error: errorResult,
errorType: RuntimeError || ErrorType.bizError,
provider: provider as any,
});
throw this.handleError(error);
}
}

Expand Down Expand Up @@ -289,5 +215,50 @@ export const LobeOpenAICompatibleFactory = ({
},
});
}

private handleError(error: any): ChatCompletionErrorPayload {
let desensitizedEndpoint = this.baseURL;

// refs: https://github.com/lobehub/lobe-chat/issues/842
if (this.baseURL !== DEFAULT_BASE_URL) {
desensitizedEndpoint = desensitizeUrl(this.baseURL);
}

if (chatCompletion?.handleError) {
const errorResult = chatCompletion.handleError(error, this._options);

if (errorResult)
return AgentRuntimeError.chat({
...errorResult,
provider,
} as ChatCompletionErrorPayload);
}

if ('status' in (error as any)) {
switch ((error as Response).status) {
case 401: {
return AgentRuntimeError.chat({
endpoint: desensitizedEndpoint,
error: error as any,
errorType: ErrorType.invalidAPIKey,
provider: provider as any,
});
}

default: {
break;
}
}
}

const { errorResult, RuntimeError } = handleOpenAIError(error);

return AgentRuntimeError.chat({
endpoint: desensitizedEndpoint,
error: errorResult,
errorType: RuntimeError || ErrorType.bizError,
provider: provider as any,
});
}
};
};
2 changes: 1 addition & 1 deletion src/libs/agent-runtime/zhipu/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ export class LobeZhipuAI implements LobeRuntimeAI {
this.baseURL = this.client.baseURL;
}

static async fromAPIKey({ apiKey, baseURL = DEFAULT_BASE_URL, ...res }: ClientOptions) {
static async fromAPIKey({ apiKey, baseURL = DEFAULT_BASE_URL, ...res }: ClientOptions = {}) {
const invalidZhipuAPIKey = AgentRuntimeError.createError(
AgentRuntimeErrorType.InvalidProviderAPIKey,
);
Expand Down

0 comments on commit 9f211e2

Please sign in to comment.