Skip to content

Commit

Permalink
🔥 refactor: clean openai azure code
Browse files Browse the repository at this point in the history
  • Loading branch information
arvinxx committed Apr 10, 2024
1 parent 6ceb818 commit be4bcca
Show file tree
Hide file tree
Showing 14 changed files with 108 additions and 259 deletions.
24 changes: 12 additions & 12 deletions src/app/api/chat/[provider]/agentRuntime.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import {
LobeOpenAI,
LobeOpenRouterAI,
LobePerplexityAI,
LobeRuntimeAI,
LobeTogetherAI,
LobeZhipuAI,
ModelProvider,
Expand Down Expand Up @@ -70,33 +71,32 @@ describe('AgentRuntime', () => {
const jwtPayload: JWTPayload = {
apiKey: 'user-azure-key',
endpoint: 'user-azure-endpoint',
useAzure: true,
azureApiVersion: '2024-02-01',
};
const runtime = await AgentRuntime.initializeWithUserPayload(
ModelProvider.OpenAI,
ModelProvider.Azure,
jwtPayload,
);

expect(runtime).toBeInstanceOf(AgentRuntime);
expect(runtime['_runtime']).toBeInstanceOf(LobeOpenAI);
expect(runtime['_runtime']).toBeInstanceOf(LobeAzureOpenAI);
expect(runtime['_runtime'].baseURL).toBe('user-azure-endpoint');
});
it('should initialize with azureOpenAIParams correctly', async () => {
const jwtPayload: JWTPayload = { apiKey: 'user-openai-key', endpoint: 'user-endpoint' };
const azureOpenAIParams = {
apiVersion: 'custom-version',
model: 'custom-model',
useAzure: true,
const jwtPayload: JWTPayload = {
apiKey: 'user-openai-key',
endpoint: 'user-endpoint',
azureApiVersion: 'custom-version',
};

const runtime = await AgentRuntime.initializeWithUserPayload(
ModelProvider.OpenAI,
ModelProvider.Azure,
jwtPayload,
azureOpenAIParams,
);

expect(runtime).toBeInstanceOf(AgentRuntime);
const openAIRuntime = runtime['_runtime'] as LobeOpenAI;
expect(openAIRuntime).toBeInstanceOf(LobeOpenAI);
const openAIRuntime = runtime['_runtime'] as LobeRuntimeAI;
expect(openAIRuntime).toBeInstanceOf(LobeAzureOpenAI);
});

it('should initialize with AzureAI correctly', async () => {
Expand Down
35 changes: 6 additions & 29 deletions src/app/api/chat/[provider]/agentRuntime.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,6 @@ import { TraceClient } from '@/libs/traces';

import apiKeyManager from '../apiKeyManager';

interface AzureOpenAIParams {
apiVersion?: string;
model: string;
useAzure?: boolean;
}

export interface AgentChatOptions {
enableTrace?: boolean;
provider: string;
Expand Down Expand Up @@ -112,18 +106,14 @@ class AgentRuntime {
});
}

static async initializeWithUserPayload(
provider: string,
payload: JWTPayload,
azureOpenAI?: AzureOpenAIParams,
) {
static async initializeWithUserPayload(provider: string, payload: JWTPayload) {
let runtimeModel: LobeRuntimeAI;

switch (provider) {
default:
case 'oneapi':
case ModelProvider.OpenAI: {
runtimeModel = this.initOpenAI(payload, azureOpenAI);
runtimeModel = this.initOpenAI(payload);
break;
}

Expand Down Expand Up @@ -196,27 +186,14 @@ class AgentRuntime {
return new AgentRuntime(runtimeModel);
}

private static initOpenAI(payload: JWTPayload, azureOpenAI?: AzureOpenAIParams) {
const { OPENAI_API_KEY, OPENAI_PROXY_URL, AZURE_API_VERSION, AZURE_API_KEY, USE_AZURE_OPENAI } =
getServerConfig();
private static initOpenAI(payload: JWTPayload) {
const { OPENAI_API_KEY, OPENAI_PROXY_URL } = getServerConfig();
const openaiApiKey = payload?.apiKey || OPENAI_API_KEY;
const baseURL = payload?.endpoint || OPENAI_PROXY_URL;

const azureApiKey = payload.apiKey || AZURE_API_KEY;
const useAzure = azureOpenAI?.useAzure || USE_AZURE_OPENAI;
const apiVersion = azureOpenAI?.apiVersion || AZURE_API_VERSION;
const apiKey = apiKeyManager.pick(openaiApiKey);

const apiKey = apiKeyManager.pick(useAzure ? azureApiKey : openaiApiKey);

return new LobeOpenAI({
apiKey,
azureOptions: {
apiVersion,
model: azureOpenAI?.model,
},
baseURL,
useAzure,
});
return new LobeOpenAI({ apiKey, baseURL });
}

private static initAzureOpenAI(payload: JWTPayload) {
Expand Down
7 changes: 1 addition & 6 deletions src/app/api/chat/[provider]/route.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ describe('POST handler', () => {
accessCode: 'test-access-code',
apiKey: 'test-api-key',
azureApiVersion: 'v1',
useAzure: true,
});

const mockRuntime: LobeRuntimeAI = { baseURL: 'abc', chat: vi.fn() };
Expand All @@ -56,11 +55,7 @@ describe('POST handler', () => {

// 验证是否正确调用了模拟函数
expect(getJWTPayload).toHaveBeenCalledWith('Bearer some-valid-token');
expect(spy).toHaveBeenCalledWith('test-provider', expect.anything(), {
apiVersion: 'v1',
model: 'test-model',
useAzure: true,
});
expect(spy).toHaveBeenCalledWith('test-provider', expect.anything());
});

it('should return Unauthorized error when LOBE_CHAT_AUTH_HEADER is missing', async () => {
Expand Down
7 changes: 1 addition & 6 deletions src/app/api/chat/[provider]/route.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,7 @@ export const POST = async (req: Request, { params }: { params: { provider: strin
const jwtPayload = await getJWTPayload(authorization);
checkAuthMethod(jwtPayload.accessCode, jwtPayload.apiKey, oauthAuthorized);

const body = await req.clone().json();
const agentRuntime = await AgentRuntime.initializeWithUserPayload(provider, jwtPayload, {
apiVersion: jwtPayload.azureApiVersion,
model: body.model,
useAzure: jwtPayload.useAzure,
});
const agentRuntime = await AgentRuntime.initializeWithUserPayload(provider, jwtPayload);

// ============ 2. create chat completion ============ //

Expand Down
1 change: 0 additions & 1 deletion src/const/auth.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ export interface JWTPayload {
endpoint?: string;

azureApiVersion?: string;
useAzure?: boolean;

awsAccessKeyId?: string;
awsRegion?: string;
Expand Down
17 changes: 10 additions & 7 deletions src/libs/agent-runtime/groq/index.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -75,13 +75,16 @@ describe('LobeGroqAI', () => {
});

// Assert
expect(instance['client'].chat.completions.create).toHaveBeenCalledWith({
max_tokens: 1024,
messages: [{ content: 'Hello', role: 'user' }],
model: 'mistralai/mistral-7b-instruct:free',
temperature: 0.7,
top_p: 1,
});
expect(instance['client'].chat.completions.create).toHaveBeenCalledWith(
{
max_tokens: 1024,
messages: [{ content: 'Hello', role: 'user' }],
model: 'mistralai/mistral-7b-instruct:free',
temperature: 0.7,
top_p: 1,
},
{ headers: { Accept: '*/*' } },
);
expect(result).toBeInstanceOf(Response);
});

Expand Down
38 changes: 22 additions & 16 deletions src/libs/agent-runtime/mistral/index.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -75,14 +75,17 @@ describe('LobeMistralAI', () => {
});

// Assert
expect(instance['client'].chat.completions.create).toHaveBeenCalledWith({
max_tokens: 1024,
messages: [{ content: 'Hello', role: 'user' }],
model: 'open-mistral-7b',
stream: true,
temperature: 0.7,
top_p: 1,
});
expect(instance['client'].chat.completions.create).toHaveBeenCalledWith(
{
max_tokens: 1024,
messages: [{ content: 'Hello', role: 'user' }],
model: 'open-mistral-7b',
stream: true,
temperature: 0.7,
top_p: 1,
},
{ headers: { Accept: '*/*' } },
);
expect(result).toBeInstanceOf(Response);
});

Expand All @@ -105,14 +108,17 @@ describe('LobeMistralAI', () => {
});

// Assert
expect(instance['client'].chat.completions.create).toHaveBeenCalledWith({
max_tokens: 1024,
messages: [{ content: 'Hello', role: 'user' }],
model: 'open-mistral-7b',
stream: true,
temperature: 0.7,
top_p: 1,
});
expect(instance['client'].chat.completions.create).toHaveBeenCalledWith(
{
max_tokens: 1024,
messages: [{ content: 'Hello', role: 'user' }],
model: 'open-mistral-7b',
stream: true,
temperature: 0.7,
top_p: 1,
},
{ headers: { Accept: '*/*' } },
);
expect(result).toBeInstanceOf(Response);
});

Expand Down
52 changes: 2 additions & 50 deletions src/libs/agent-runtime/openai/index.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import OpenAI from 'openai';
import { Mock, afterEach, beforeEach, describe, expect, it, vi } from 'vitest';

// 引入模块以便于对函数进行spy
import { ChatStreamCallbacks } from '@/libs/agent-runtime';
import { ChatStreamCallbacks, LobeOpenAICompatibleRuntime } from '@/libs/agent-runtime';

import * as debugStreamModule from '../utils/debugStream';
import { LobeOpenAI } from './index';
Expand All @@ -12,7 +12,7 @@ import { LobeOpenAI } from './index';
vi.spyOn(console, 'error').mockImplementation(() => {});

describe('LobeOpenAI', () => {
let instance: LobeOpenAI;
let instance: LobeOpenAICompatibleRuntime;

beforeEach(() => {
instance = new LobeOpenAI({ apiKey: 'test' });
Expand All @@ -27,54 +27,6 @@ describe('LobeOpenAI', () => {
vi.clearAllMocks();
});

describe('init', () => {
it('should correctly initialize with Azure options', () => {
const baseURL = 'https://abc.com';
const modelName = 'abc';
const client = new LobeOpenAI({
apiKey: 'test',
useAzure: true,
baseURL,
azureOptions: {
apiVersion: '2023-08-01-preview',
model: 'abc',
},
});

expect(client.baseURL).toEqual(baseURL + '/openai/deployments/' + modelName);
});

describe('initWithAzureOpenAI', () => {
it('should correctly initialize with Azure options', () => {
const baseURL = 'https://abc.com';
const modelName = 'abc';
const client = LobeOpenAI.initWithAzureOpenAI({
apiKey: 'test',
useAzure: true,
baseURL,
azureOptions: {
apiVersion: '2023-08-01-preview',
model: 'abc',
},
});

expect(client.baseURL).toEqual(baseURL + '/openai/deployments/' + modelName);
});

it('should use default Azure options when not explicitly provided', () => {
const baseURL = 'https://abc.com';

const client = LobeOpenAI.initWithAzureOpenAI({
apiKey: 'test',
useAzure: true,
baseURL,
});

expect(client.baseURL).toEqual(baseURL + '/openai/deployments/');
});
});
});

describe('chat', () => {
it('should return a StreamingTextResponse on successful API call', async () => {
// Arrange
Expand Down
Loading

0 comments on commit be4bcca

Please sign in to comment.