diff --git a/src/app/api/chat/[provider]/agentRuntime.test.ts b/src/app/api/chat/[provider]/agentRuntime.test.ts
index e2b182659fe0..c77f5c581e91 100644
--- a/src/app/api/chat/[provider]/agentRuntime.test.ts
+++ b/src/app/api/chat/[provider]/agentRuntime.test.ts
@@ -18,6 +18,7 @@ import {
   LobeOpenAI,
   LobeOpenRouterAI,
   LobePerplexityAI,
+  LobeRuntimeAI,
   LobeTogetherAI,
   LobeZhipuAI,
   ModelProvider,
@@ -70,33 +71,32 @@ describe('AgentRuntime', () => {
     const jwtPayload: JWTPayload = {
       apiKey: 'user-azure-key',
       endpoint: 'user-azure-endpoint',
-      useAzure: true,
+      azureApiVersion: '2024-02-01',
     };
     const runtime = await AgentRuntime.initializeWithUserPayload(
-      ModelProvider.OpenAI,
+      ModelProvider.Azure,
       jwtPayload,
     );

     expect(runtime).toBeInstanceOf(AgentRuntime);
-    expect(runtime['_runtime']).toBeInstanceOf(LobeOpenAI);
+    expect(runtime['_runtime']).toBeInstanceOf(LobeAzureOpenAI);
     expect(runtime['_runtime'].baseURL).toBe('user-azure-endpoint');
   });

   it('should initialize with azureOpenAIParams correctly', async () => {
-    const jwtPayload: JWTPayload = { apiKey: 'user-openai-key', endpoint: 'user-endpoint' };
-    const azureOpenAIParams = {
-      apiVersion: 'custom-version',
-      model: 'custom-model',
-      useAzure: true,
+    const jwtPayload: JWTPayload = {
+      apiKey: 'user-openai-key',
+      endpoint: 'user-endpoint',
+      azureApiVersion: 'custom-version',
     };
+
     const runtime = await AgentRuntime.initializeWithUserPayload(
-      ModelProvider.OpenAI,
+      ModelProvider.Azure,
       jwtPayload,
-      azureOpenAIParams,
     );

     expect(runtime).toBeInstanceOf(AgentRuntime);
-    const openAIRuntime = runtime['_runtime'] as LobeOpenAI;
-    expect(openAIRuntime).toBeInstanceOf(LobeOpenAI);
+    const openAIRuntime = runtime['_runtime'] as LobeRuntimeAI;
+    expect(openAIRuntime).toBeInstanceOf(LobeAzureOpenAI);
   });

   it('should initialize with AzureAI correctly', async () => {
diff --git a/src/app/api/chat/[provider]/agentRuntime.ts b/src/app/api/chat/[provider]/agentRuntime.ts
index 07793f9c1bc1..1ef3a2793690 100644
--- a/src/app/api/chat/[provider]/agentRuntime.ts
+++ b/src/app/api/chat/[provider]/agentRuntime.ts
@@ -30,12 +30,6 @@ import { TraceClient } from '@/libs/traces';

 import apiKeyManager from '../apiKeyManager';

-interface AzureOpenAIParams {
-  apiVersion?: string;
-  model: string;
-  useAzure?: boolean;
-}
-
 export interface AgentChatOptions {
   enableTrace?: boolean;
   provider: string;
@@ -112,18 +106,14 @@ class AgentRuntime {
     });
   }

-  static async initializeWithUserPayload(
-    provider: string,
-    payload: JWTPayload,
-    azureOpenAI?: AzureOpenAIParams,
-  ) {
+  static async initializeWithUserPayload(provider: string, payload: JWTPayload) {
     let runtimeModel: LobeRuntimeAI;

     switch (provider) {
       default:
       case 'oneapi':
       case ModelProvider.OpenAI: {
-        runtimeModel = this.initOpenAI(payload, azureOpenAI);
+        runtimeModel = this.initOpenAI(payload);
         break;
       }
@@ -196,27 +186,14 @@ class AgentRuntime {
     return new AgentRuntime(runtimeModel);
   }

-  private static initOpenAI(payload: JWTPayload, azureOpenAI?: AzureOpenAIParams) {
-    const { OPENAI_API_KEY, OPENAI_PROXY_URL, AZURE_API_VERSION, AZURE_API_KEY, USE_AZURE_OPENAI } =
-      getServerConfig();
+  private static initOpenAI(payload: JWTPayload) {
+    const { OPENAI_API_KEY, OPENAI_PROXY_URL } = getServerConfig();
     const openaiApiKey = payload?.apiKey || OPENAI_API_KEY;
     const baseURL = payload?.endpoint || OPENAI_PROXY_URL;
-    const azureApiKey = payload.apiKey || AZURE_API_KEY;
-    const useAzure = azureOpenAI?.useAzure || USE_AZURE_OPENAI;
-    const apiVersion = azureOpenAI?.apiVersion || AZURE_API_VERSION;
+    const apiKey = apiKeyManager.pick(openaiApiKey);
-    const apiKey = apiKeyManager.pick(useAzure ? azureApiKey : openaiApiKey);
-
-    return new LobeOpenAI({
-      apiKey,
-      azureOptions: {
-        apiVersion,
-        model: azureOpenAI?.model,
-      },
-      baseURL,
-      useAzure,
-    });
+    return new LobeOpenAI({ apiKey, baseURL });
   }

   private static initAzureOpenAI(payload: JWTPayload) {
diff --git a/src/app/api/chat/[provider]/route.test.ts b/src/app/api/chat/[provider]/route.test.ts
index 57df4c470b59..9ba00d88ba0d 100644
--- a/src/app/api/chat/[provider]/route.test.ts
+++ b/src/app/api/chat/[provider]/route.test.ts
@@ -42,7 +42,6 @@ describe('POST handler', () => {
       accessCode: 'test-access-code',
       apiKey: 'test-api-key',
       azureApiVersion: 'v1',
-      useAzure: true,
     });

     const mockRuntime: LobeRuntimeAI = { baseURL: 'abc', chat: vi.fn() };
@@ -56,11 +55,7 @@

     // verify the mocked functions were called correctly
     expect(getJWTPayload).toHaveBeenCalledWith('Bearer some-valid-token');
-    expect(spy).toHaveBeenCalledWith('test-provider', expect.anything(), {
-      apiVersion: 'v1',
-      model: 'test-model',
-      useAzure: true,
-    });
+    expect(spy).toHaveBeenCalledWith('test-provider', expect.anything());
   });

   it('should return Unauthorized error when LOBE_CHAT_AUTH_HEADER is missing', async () => {
diff --git a/src/app/api/chat/[provider]/route.ts b/src/app/api/chat/[provider]/route.ts
index 9f5d719dba4d..72b622a29e2e 100644
--- a/src/app/api/chat/[provider]/route.ts
+++ b/src/app/api/chat/[provider]/route.ts
@@ -29,12 +29,7 @@ export const POST = async (req: Request, { params }: { params: { provider: strin
     const jwtPayload = await getJWTPayload(authorization);
     checkAuthMethod(jwtPayload.accessCode, jwtPayload.apiKey, oauthAuthorized);

-    const body = await req.clone().json();
-    const agentRuntime = await AgentRuntime.initializeWithUserPayload(provider, jwtPayload, {
-      apiVersion: jwtPayload.azureApiVersion,
-      model: body.model,
-      useAzure: jwtPayload.useAzure,
-    });
+    const agentRuntime = await AgentRuntime.initializeWithUserPayload(provider, jwtPayload);

     // ============ 2. create chat completion ============ //
diff --git a/src/const/auth.ts b/src/const/auth.ts
index 5d05a6d1db56..5c148dfe347e 100644
--- a/src/const/auth.ts
+++ b/src/const/auth.ts
@@ -24,7 +24,6 @@ export interface JWTPayload {
   endpoint?: string;

   azureApiVersion?: string;
-  useAzure?: boolean;

   awsAccessKeyId?: string;
   awsRegion?: string;
diff --git a/src/libs/agent-runtime/groq/index.test.ts b/src/libs/agent-runtime/groq/index.test.ts
index e74652aee5cf..67779be2eb26 100644
--- a/src/libs/agent-runtime/groq/index.test.ts
+++ b/src/libs/agent-runtime/groq/index.test.ts
@@ -75,13 +75,16 @@ describe('LobeGroqAI', () => {
     });

     // Assert
-    expect(instance['client'].chat.completions.create).toHaveBeenCalledWith({
-      max_tokens: 1024,
-      messages: [{ content: 'Hello', role: 'user' }],
-      model: 'mistralai/mistral-7b-instruct:free',
-      temperature: 0.7,
-      top_p: 1,
-    });
+    expect(instance['client'].chat.completions.create).toHaveBeenCalledWith(
+      {
+        max_tokens: 1024,
+        messages: [{ content: 'Hello', role: 'user' }],
+        model: 'mistralai/mistral-7b-instruct:free',
+        temperature: 0.7,
+        top_p: 1,
+      },
+      { headers: { Accept: '*/*' } },
+    );
     expect(result).toBeInstanceOf(Response);
   });

diff --git a/src/libs/agent-runtime/mistral/index.test.ts b/src/libs/agent-runtime/mistral/index.test.ts
index 208307599f88..1250d69ddc8b 100644
--- a/src/libs/agent-runtime/mistral/index.test.ts
+++ b/src/libs/agent-runtime/mistral/index.test.ts
@@ -75,14 +75,17 @@ describe('LobeMistralAI', () => {
     });

     // Assert
-    expect(instance['client'].chat.completions.create).toHaveBeenCalledWith({
-      max_tokens: 1024,
-      messages: [{ content: 'Hello', role: 'user' }],
-      model: 'open-mistral-7b',
-      stream: true,
-      temperature: 0.7,
-      top_p: 1,
-    });
+    expect(instance['client'].chat.completions.create).toHaveBeenCalledWith(
+      {
+        max_tokens: 1024,
+        messages: [{ content: 'Hello', role: 'user' }],
+        model: 'open-mistral-7b',
+        stream: true,
+        temperature: 0.7,
+        top_p: 1,
+      },
+      { headers: { Accept: '*/*' } },
+    );
     expect(result).toBeInstanceOf(Response);
   });

@@ -105,14 +108,17 @@ describe('LobeMistralAI', () => {
     });

     // Assert
-    expect(instance['client'].chat.completions.create).toHaveBeenCalledWith({
-      max_tokens: 1024,
-      messages: [{ content: 'Hello', role: 'user' }],
-      model: 'open-mistral-7b',
-      stream: true,
-      temperature: 0.7,
-      top_p: 1,
-    });
+    expect(instance['client'].chat.completions.create).toHaveBeenCalledWith(
+      {
+        max_tokens: 1024,
+        messages: [{ content: 'Hello', role: 'user' }],
+        model: 'open-mistral-7b',
+        stream: true,
+        temperature: 0.7,
+        top_p: 1,
+      },
+      { headers: { Accept: '*/*' } },
+    );
     expect(result).toBeInstanceOf(Response);
   });

diff --git a/src/libs/agent-runtime/openai/index.test.ts b/src/libs/agent-runtime/openai/index.test.ts
index 5eb9138bfce5..f32830bdb749 100644
--- a/src/libs/agent-runtime/openai/index.test.ts
+++ b/src/libs/agent-runtime/openai/index.test.ts
@@ -3,7 +3,7 @@ import OpenAI from 'openai';
 import { Mock, afterEach, beforeEach, describe, expect, it, vi } from 'vitest';

 // import the module so we can spy on its functions
-import { ChatStreamCallbacks } from '@/libs/agent-runtime';
+import { ChatStreamCallbacks, LobeOpenAICompatibleRuntime } from '@/libs/agent-runtime';

 import * as debugStreamModule from '../utils/debugStream';
 import { LobeOpenAI } from './index';
@@ -12,7 +12,7 @@ import { LobeOpenAI } from './index';
 vi.spyOn(console, 'error').mockImplementation(() => {});

 describe('LobeOpenAI', () => {
-  let instance: LobeOpenAI;
+  let instance: LobeOpenAICompatibleRuntime;

   beforeEach(() => {
     instance = new LobeOpenAI({ apiKey: 'test' });
@@ -27,54 +27,6 @@ describe('LobeOpenAI', () => {
     vi.clearAllMocks();
   });

-  describe('init', () => {
-    it('should correctly initialize with Azure options', () => {
-      const baseURL = 'https://abc.com';
-      const modelName = 'abc';
-      const client = new LobeOpenAI({
-        apiKey: 'test',
-        useAzure: true,
-        baseURL,
-        azureOptions: {
-          apiVersion: '2023-08-01-preview',
-          model: 'abc',
-        },
-      });
-
-      expect(client.baseURL).toEqual(baseURL + '/openai/deployments/' + modelName);
-    });
-
-    describe('initWithAzureOpenAI', () => {
-      it('should correctly initialize with Azure options', () => {
-        const baseURL = 'https://abc.com';
-        const modelName = 'abc';
-        const client = LobeOpenAI.initWithAzureOpenAI({
-          apiKey: 'test',
-          useAzure: true,
-          baseURL,
-          azureOptions: {
-            apiVersion: '2023-08-01-preview',
-            model: 'abc',
-          },
-        });
-
-        expect(client.baseURL).toEqual(baseURL + '/openai/deployments/' + modelName);
-      });
-
-      it('should use default Azure options when not explicitly provided', () => {
-        const baseURL = 'https://abc.com';
-
-        const client = LobeOpenAI.initWithAzureOpenAI({
-          apiKey: 'test',
-          useAzure: true,
-          baseURL,
-        });
-
-        expect(client.baseURL).toEqual(baseURL + '/openai/deployments/');
-      });
-    });
-  });
-
   describe('chat', () => {
     it('should return a StreamingTextResponse on successful API call', async () => {
       // Arrange
diff --git a/src/libs/agent-runtime/openai/index.ts b/src/libs/agent-runtime/openai/index.ts
index f3d6137b6fa9..288982b5e52f 100644
--- a/src/libs/agent-runtime/openai/index.ts
+++ b/src/libs/agent-runtime/openai/index.ts
@@ -1,108 +1,16 @@
-import { OpenAIStream, StreamingTextResponse } from 'ai';
-import OpenAI, { ClientOptions } from 'openai';
-import urlJoin from 'url-join';
-
-import { ChatStreamPayload } from '@/types/openai/chat';
-
-import { LobeRuntimeAI } from '../BaseAI';
 import { AgentRuntimeErrorType } from '../error';
-import { ChatCompetitionOptions, ModelProvider } from '../types';
-import { AgentRuntimeError } from '../utils/createError';
-import { debugStream } from '../utils/debugStream';
-import { desensitizeUrl } from '../utils/desensitizeUrl';
-import { handleOpenAIError } from '../utils/handleOpenAIError';
-
-const DEFAULT_BASE_URL = 'https://api.openai.com/v1';
-
-interface LobeOpenAIOptions extends ClientOptions {
-  azureOptions?: {
-    apiVersion?: string;
-    model?: string;
-  };
-  useAzure?: boolean;
-}
-
-export class LobeOpenAI implements LobeRuntimeAI {
-  private client: OpenAI;
-
-  constructor(options: LobeOpenAIOptions) {
-    if (!options.apiKey) throw AgentRuntimeError.createError(AgentRuntimeErrorType.NoOpenAIAPIKey);
-
-    if (options.useAzure) {
-      this.client = LobeOpenAI.initWithAzureOpenAI(options);
-    } else {
-      this.client = new OpenAI(options);
-    }
-
-    this.baseURL = this.client.baseURL;
-  }
-
-  baseURL: string;
-
-  async chat(payload: ChatStreamPayload, options?: ChatCompetitionOptions) {
-    // ============ 1. preprocess messages ============ //
-    const { messages, ...params } = payload;
-
-    // ============ 2. send api ============ //
-
-    try {
-      const response = await this.client.chat.completions.create(
-        {
-          messages,
-          ...params,
-          stream: true,
-        } as unknown as OpenAI.ChatCompletionCreateParamsStreaming,
-        { headers: { Accept: '*/*' } },
-      );
-
-      const [prod, debug] = response.tee();
-
-      if (process.env.DEBUG_OPENAI_CHAT_COMPLETION === '1') {
-        debugStream(debug.toReadableStream()).catch(console.error);
-      }
-
-      return new StreamingTextResponse(OpenAIStream(prod, options?.callback), {
-        headers: options?.headers,
-      });
-    } catch (error) {
-      const { errorResult, RuntimeError } = handleOpenAIError(error);
-
-      const errorType = RuntimeError || AgentRuntimeErrorType.OpenAIBizError;
-
-      let desensitizedEndpoint = this.baseURL;
-
-      // refs: https://github.com/lobehub/lobe-chat/issues/842
-      if (this.baseURL !== DEFAULT_BASE_URL) {
-        desensitizedEndpoint = desensitizeUrl(this.baseURL);
-      }
-
-      throw AgentRuntimeError.chat({
-        endpoint: desensitizedEndpoint,
-        error: errorResult,
-        errorType,
-        provider: ModelProvider.OpenAI,
-      });
-    }
-  }
-
-  static initWithAzureOpenAI(options: LobeOpenAIOptions) {
-    const endpoint = options.baseURL!;
-    const model = options.azureOptions?.model || '';
-
-    // refs: https://test-001.openai.azure.com/openai/deployments/gpt-35-turbo
-    const baseURL = urlJoin(endpoint, `/openai/deployments/${model.replace('.', '')}`);
-
-    const apiVersion = options.azureOptions?.apiVersion || '2023-08-01-preview';
-    const apiKey = options.apiKey!;
-
-    const config: ClientOptions = {
-      ...options,
-      apiKey,
-      baseURL,
-      defaultHeaders: { 'api-key': apiKey },
-      defaultQuery: { 'api-version': apiVersion },
-    };
-
-    return new OpenAI(config);
-  }
-}
+import { ModelProvider } from '../types';
+import { LobeOpenAICompatibleFactory } from '../utils/openaiCompatibleFactory';
+
+export const LobeOpenAI = LobeOpenAICompatibleFactory({
+  baseURL: 'https://api.openai.com/v1',
+  debug: {
+    chatCompletion: () => process.env.DEBUG_OPENAI_CHAT_COMPLETION === '1',
+  },
+  errorType: {
+    bizError: AgentRuntimeErrorType.OpenAIBizError,
+    invalidAPIKey: AgentRuntimeErrorType.NoOpenAIAPIKey,
+  },
+
+  provider: ModelProvider.OpenAI,
+});
diff --git a/src/libs/agent-runtime/openrouter/index.test.ts b/src/libs/agent-runtime/openrouter/index.test.ts
index ac2c72818d04..b4cc9bd718e8 100644
--- a/src/libs/agent-runtime/openrouter/index.test.ts
+++ b/src/libs/agent-runtime/openrouter/index.test.ts
@@ -75,13 +75,16 @@ describe('LobeOpenRouterAI', () => {
     });

     // Assert
-    expect(instance['client'].chat.completions.create).toHaveBeenCalledWith({
-      max_tokens: 1024,
-      messages: [{ content: 'Hello', role: 'user' }],
-      model: 'mistralai/mistral-7b-instruct:free',
-      temperature: 0.7,
-      top_p: 1,
-    });
+    expect(instance['client'].chat.completions.create).toHaveBeenCalledWith(
+      {
+        max_tokens: 1024,
+        messages: [{ content: 'Hello', role: 'user' }],
+        model: 'mistralai/mistral-7b-instruct:free',
+        temperature: 0.7,
+        top_p: 1,
+      },
+      { headers: { Accept: '*/*' } },
+    );
     expect(result).toBeInstanceOf(Response);
   });

diff --git a/src/libs/agent-runtime/togetherai/index.test.ts b/src/libs/agent-runtime/togetherai/index.test.ts
index 31b3c9b32097..80e22bc13487 100644
--- a/src/libs/agent-runtime/togetherai/index.test.ts
+++ b/src/libs/agent-runtime/togetherai/index.test.ts
@@ -75,13 +75,16 @@ describe('LobeTogetherAI', () => {
     });

     // Assert
-    expect(instance['client'].chat.completions.create).toHaveBeenCalledWith({
-      max_tokens: 1024,
-      messages: [{ content: 'Hello', role: 'user' }],
-      model: 'mistralai/mistral-7b-instruct:free',
-      temperature: 0.7,
-      top_p: 1,
-    });
+    expect(instance['client'].chat.completions.create).toHaveBeenCalledWith(
+      {
+        max_tokens: 1024,
+        messages: [{ content: 'Hello', role: 'user' }],
+        model: 'mistralai/mistral-7b-instruct:free',
+        temperature: 0.7,
+        top_p: 1,
+      },
+      { headers: { Accept: '*/*' } },
+    );
     expect(result).toBeInstanceOf(Response);
   });

diff --git a/src/libs/agent-runtime/utils/openaiCompatibleFactory/index.ts b/src/libs/agent-runtime/utils/openaiCompatibleFactory/index.ts
index 4304734741f2..1ee56be41dc0 100644
--- a/src/libs/agent-runtime/utils/openaiCompatibleFactory/index.ts
+++ b/src/libs/agent-runtime/utils/openaiCompatibleFactory/index.ts
@@ -51,7 +51,11 @@ export const LobeOpenAICompatibleFactory = ({
         ? chatCompletion.handlePayload(payload)
         : (payload as unknown as OpenAI.ChatCompletionCreateParamsStreaming);

-      const response = await this.client.chat.completions.create(postPayload);
+      const response = await this.client.chat.completions.create(postPayload, {
+        // https://github.com/lobehub/lobe-chat/pull/318
+        headers: { Accept: '*/*' },
+      });
+
       const [prod, useForDebug] = response.tee();

       if (debug?.chatCompletion?.()) {
@@ -64,6 +68,7 @@ export const LobeOpenAICompatibleFactory = ({
     } catch (error) {
       let desensitizedEndpoint = this.baseURL;

+      // refs: https://github.com/lobehub/lobe-chat/issues/842
       if (this.baseURL !== DEFAULT_BASE_URL) {
         desensitizedEndpoint = desensitizeUrl(this.baseURL);
       }
diff --git a/src/libs/agent-runtime/zeroone/index.test.ts b/src/libs/agent-runtime/zeroone/index.test.ts
index ec5291428a36..b949c4412c0d 100644
--- a/src/libs/agent-runtime/zeroone/index.test.ts
+++ b/src/libs/agent-runtime/zeroone/index.test.ts
@@ -58,7 +58,7 @@ describe('LobeZeroOneAI', () => {
     expect(result).toBeInstanceOf(Response);
   });

-  it('should call OpenRouter API with corresponding options', async () => {
+  it('should call ZeroOne API with corresponding options', async () => {
     // Arrange
     const mockStream = new ReadableStream();
     const mockResponse = Promise.resolve(mockStream);
@@ -75,13 +75,16 @@ describe('LobeZeroOneAI', () => {
     });

     // Assert
-    expect(instance['client'].chat.completions.create).toHaveBeenCalledWith({
-      max_tokens: 1024,
-      messages: [{ content: 'Hello', role: 'user' }],
-      model: 'mistralai/mistral-7b-instruct:free',
-      temperature: 0.7,
-      top_p: 1,
-    });
+    expect(instance['client'].chat.completions.create).toHaveBeenCalledWith(
+      {
+        max_tokens: 1024,
+        messages: [{ content: 'Hello', role: 'user' }],
+        model: 'mistralai/mistral-7b-instruct:free',
+        temperature: 0.7,
+        top_p: 1,
+      },
+      { headers: { Accept: '*/*' } },
+    );
     expect(result).toBeInstanceOf(Response);
   });

diff --git a/src/services/chat.ts b/src/services/chat.ts
index f4548716be49..6c638b0766eb 100644
--- a/src/services/chat.ts
+++ b/src/services/chat.ts
@@ -146,8 +146,8 @@ class ChatService {
     }

     const payload = merge(
-      { stream: true, ...DEFAULT_AGENT_CONFIG.params },
-      { ...res, model: res.model },
+      { model: DEFAULT_AGENT_CONFIG.model, stream: true, ...DEFAULT_AGENT_CONFIG.params },
+      { ...res, model },
     );

     const traceHeader = createTraceHeader({ ...options?.trace });
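
Taken together, the changes above remove the `useAzure` escape hatch: Azure is now selected by provider id alone, and the API version travels inside the JWT payload instead of a third argument. A minimal sketch of the new call path, based on the updated tests (the import style of `agentRuntime.ts` and the endpoint/version values are assumptions, not taken from the diff):

```ts
import { JWTPayload } from '@/const/auth';
import { ModelProvider } from '@/libs/agent-runtime';

import AgentRuntime from './agentRuntime'; // export style assumed

// The Azure fields now ride along in the JWT payload; values are placeholders.
const jwtPayload: JWTPayload = {
  apiKey: 'user-azure-key',
  azureApiVersion: '2024-02-01',
  endpoint: 'https://my-resource.openai.azure.com',
};

// No third argument anymore: choosing ModelProvider.Azure is what routes the
// request to LobeAzureOpenAI, so callers no longer juggle a useAzure flag.
const runtime = await AgentRuntime.initializeWithUserPayload(ModelProvider.Azure, jwtPayload);
```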
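
With the Azure branch gone, `LobeOpenAI` becomes a plain product of `LobeOpenAICompatibleFactory`, which now also owns the per-request `Accept: */*` header that every provider test asserts on. A hypothetical provider wired the same way; only the option names come from the diff, while `LobeExampleAI`, its `baseURL`, and the env var are invented for illustration:

```ts
import { AgentRuntimeErrorType } from '../error';
import { ModelProvider } from '../types';
import { LobeOpenAICompatibleFactory } from '../utils/openaiCompatibleFactory';

// Hypothetical runtime built on the shared factory. Every runtime produced this
// way inherits the same chat() implementation, so the Accept header and the
// endpoint desensitization live in one place instead of one copy per provider.
export const LobeExampleAI = LobeOpenAICompatibleFactory({
  baseURL: 'https://api.example.com/v1',
  debug: {
    // gate stream debugging behind an env flag, as the new LobeOpenAI does
    chatCompletion: () => process.env.DEBUG_EXAMPLE_CHAT_COMPLETION === '1',
  },
  errorType: {
    bizError: AgentRuntimeErrorType.OpenAIBizError,
    invalidAPIKey: AgentRuntimeErrorType.NoOpenAIAPIKey,
  },
  // a real provider would register its own ModelProvider member
  provider: ModelProvider.OpenAI,
});
```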
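
Finally, the `services/chat.ts` hunk seeds the merge with `DEFAULT_AGENT_CONFIG.model`, so a request that omits `model` no longer ends up without one. A rough illustration of the fallback, assuming a lodash-style deep `merge` and placeholder config values (`model` here stands in for whatever the surrounding scope actually provides):

```ts
import { merge } from 'lodash-es';

// Placeholder defaults; the real values live in DEFAULT_AGENT_CONFIG.
const DEFAULT_AGENT_CONFIG = { model: 'gpt-3.5-turbo', params: { temperature: 0.6 } };

const res: { model?: string; messages: unknown[] } = {
  messages: [{ content: 'Hi', role: 'user' }], // caller sent no model
};
const model = res.model ?? DEFAULT_AGENT_CONFIG.model;

const payload = merge(
  { model: DEFAULT_AGENT_CONFIG.model, stream: true, ...DEFAULT_AGENT_CONFIG.params },
  { ...res, model },
);
// payload.model === 'gpt-3.5-turbo' even though the caller omitted it
```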