♻️ refactor: refactor agent runtime with openai compatible factory
arvinxx committed Apr 10, 2024
1 parent 78a1aac commit 89adf9d
Showing 14 changed files with 922 additions and 500 deletions.
13 changes: 12 additions & 1 deletion src/libs/agent-runtime/BaseAI.ts
@@ -1,12 +1,23 @@
import { StreamingTextResponse } from 'ai';
import OpenAI from 'openai';

import { ChatCompetitionOptions, ChatStreamPayload } from './types';

export interface LobeRuntimeAI {
  baseURL?: string;

  chat(
    payload: ChatStreamPayload,
    options?: ChatCompetitionOptions,
  ): Promise<StreamingTextResponse>;
}

export abstract class LobeOpenAICompatibleRuntime {
  abstract chat(
    payload: ChatStreamPayload,
    options?: ChatCompetitionOptions,
  ): Promise<StreamingTextResponse>;

  abstract client: OpenAI;

  abstract baseURL: string;
}
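
The factory itself is not part of this excerpt, but the abstract class above hints at its shape: every OpenAI-compatible provider exposes a chat method, an OpenAI client, and a baseURL. Below is a minimal sketch of what such a factory could look like, assuming the Vercel AI SDK's OpenAIStream/StreamingTextResponse helpers; the names createOpenAICompatibleRuntime, FactoryOptions, and invalidAPIKeyErrorType are illustrative, not the committed API.

import { OpenAIStream, StreamingTextResponse } from 'ai';
import OpenAI, { ClientOptions } from 'openai';

import { ChatCompetitionOptions, ChatStreamPayload } from './types';

// Hypothetical option shape; the real factory in this commit may differ.
interface FactoryOptions {
  baseURL: string;
  invalidAPIKeyErrorType: string;
}

const createOpenAICompatibleRuntime = ({ baseURL, invalidAPIKeyErrorType }: FactoryOptions) =>
  class {
    baseURL: string;
    client: OpenAI;

    constructor({ apiKey, ...rest }: ClientOptions = {}) {
      // Fail fast when no API key is supplied, matching the runtime's error contract.
      if (!apiKey) throw { errorType: invalidAPIKeyErrorType };
      this.client = new OpenAI({ apiKey, baseURL, ...rest });
      this.baseURL = this.client.baseURL;
    }

    async chat(payload: ChatStreamPayload, options?: ChatCompetitionOptions) {
      const response = await this.client.chat.completions.create({
        ...payload,
        stream: true,
      } as any);
      // Pipe the provider stream through the AI SDK, forwarding callbacks and headers.
      const stream = OpenAIStream(response as any, options?.callback);
      return new StreamingTextResponse(stream, { headers: options?.headers });
    }
  };

A provider runtime such as Groq would then collapse to a single call, e.g. const LobeGroq = createOpenAICompatibleRuntime({ baseURL: 'https://api.groq.com/openai/v1', invalidAPIKeyErrorType: 'InvalidGroqAPIKey' }), which is what the test file below exercises.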
347 changes: 347 additions & 0 deletions src/libs/agent-runtime/groq/index.test.ts
@@ -0,0 +1,347 @@
// @vitest-environment node
import OpenAI from 'openai';
import { Mock, afterEach, beforeEach, describe, expect, it, vi } from 'vitest';

import { ChatStreamCallbacks, LobeOpenAICompatibleRuntime } from '@/libs/agent-runtime';

import * as debugStreamModule from '../utils/debugStream';
import { LobeGroq } from './index';

const provider = 'groq';
const defaultBaseURL = 'https://api.groq.com/openai/v1';
const bizErrorType = 'GroqBizError';
const invalidErrorType = 'InvalidGroqAPIKey';

// Mock the console.error to avoid polluting test output
vi.spyOn(console, 'error').mockImplementation(() => {});

let instance: LobeOpenAICompatibleRuntime;

beforeEach(() => {
  instance = new LobeGroq({ apiKey: 'test' });

  // Use vi.spyOn to mock the chat.completions.create method
  vi.spyOn(instance['client'].chat.completions, 'create').mockResolvedValue(
    new ReadableStream() as any,
  );
});

afterEach(() => {
  vi.clearAllMocks();
});

describe('LobeGroqAI', () => {
  describe('init', () => {
    it('should correctly initialize with an API key', async () => {
      const instance = new LobeGroq({ apiKey: 'test_api_key' });
      expect(instance).toBeInstanceOf(LobeGroq);
      expect(instance.baseURL).toEqual(defaultBaseURL);
    });
  });

  describe('chat', () => {
    it('should return a StreamingTextResponse on successful API call', async () => {
      // Arrange
      const mockStream = new ReadableStream();
      const mockResponse = Promise.resolve(mockStream);

      (instance['client'].chat.completions.create as Mock).mockResolvedValue(mockResponse);

      // Act
      const result = await instance.chat({
        messages: [{ content: 'Hello', role: 'user' }],
        model: 'mistralai/mistral-7b-instruct:free',
        temperature: 0,
      });

      // Assert
      expect(result).toBeInstanceOf(Response);
    });

    it('should call Groq API with corresponding options', async () => {
      // Arrange
      const mockStream = new ReadableStream();
      const mockResponse = Promise.resolve(mockStream);

      (instance['client'].chat.completions.create as Mock).mockResolvedValue(mockResponse);

      // Act
      const result = await instance.chat({
        max_tokens: 1024,
        messages: [{ content: 'Hello', role: 'user' }],
        model: 'mistralai/mistral-7b-instruct:free',
        temperature: 0.7,
        top_p: 1,
      });

      // Assert
      expect(instance['client'].chat.completions.create).toHaveBeenCalledWith({
        max_tokens: 1024,
        messages: [{ content: 'Hello', role: 'user' }],
        model: 'mistralai/mistral-7b-instruct:free',
        temperature: 0.7,
        top_p: 1,
      });
      expect(result).toBeInstanceOf(Response);
    });

    describe('Error', () => {
      it('should return GroqBizError with an openai error response when OpenAI.APIError is thrown', async () => {
        // Arrange
        const apiError = new OpenAI.APIError(
          400,
          {
            status: 400,
            error: {
              message: 'Bad Request',
            },
          },
          'Error message',
          {},
        );

        vi.spyOn(instance['client'].chat.completions, 'create').mockRejectedValue(apiError);

        // Act
        try {
          await instance.chat({
            messages: [{ content: 'Hello', role: 'user' }],
            model: 'mistralai/mistral-7b-instruct:free',
            temperature: 0,
          });
        } catch (e) {
          expect(e).toEqual({
            endpoint: defaultBaseURL,
            error: {
              error: { message: 'Bad Request' },
              status: 400,
            },
            errorType: bizErrorType,
            provider,
          });
        }
      });

      it('should throw AgentRuntimeError with InvalidGroqAPIKey if no apiKey is provided', async () => {
        try {
          new LobeGroq({});
        } catch (e) {
          expect(e).toEqual({ errorType: invalidErrorType });
        }
      });

      it('should return GroqBizError with the cause when OpenAI.APIError is thrown with cause', async () => {
        // Arrange
        const errorInfo = {
          stack: 'abc',
          cause: {
            message: 'api is undefined',
          },
        };
        const apiError = new OpenAI.APIError(400, errorInfo, 'module error', {});

        vi.spyOn(instance['client'].chat.completions, 'create').mockRejectedValue(apiError);

        // Act
        try {
          await instance.chat({
            messages: [{ content: 'Hello', role: 'user' }],
            model: 'mistralai/mistral-7b-instruct:free',
            temperature: 0,
          });
        } catch (e) {
          expect(e).toEqual({
            endpoint: defaultBaseURL,
            error: {
              cause: { message: 'api is undefined' },
              stack: 'abc',
            },
            errorType: bizErrorType,
            provider,
          });
        }
      });

      it('should return GroqBizError with a cause response with a desensitized URL', async () => {
        // Arrange
        const errorInfo = {
          stack: 'abc',
          cause: { message: 'api is undefined' },
        };
        const apiError = new OpenAI.APIError(400, errorInfo, 'module error', {});

        instance = new LobeGroq({
          apiKey: 'test',

          baseURL: 'https://api.abc.com/v1',
        });

        vi.spyOn(instance['client'].chat.completions, 'create').mockRejectedValue(apiError);

        // Act
        try {
          await instance.chat({
            messages: [{ content: 'Hello', role: 'user' }],
            model: 'mistralai/mistral-7b-instruct:free',
            temperature: 0,
          });
        } catch (e) {
          expect(e).toEqual({
            endpoint: 'https://api.***.com/v1',
            error: {
              cause: { message: 'api is undefined' },
              stack: 'abc',
            },
            errorType: bizErrorType,
            provider,
          });
        }
      });

      it('should throw an InvalidGroqAPIKey error type on 401 status code', async () => {
        // Mock the API call to simulate a 401 error
        const error = new Error('Unauthorized') as any;
        error.status = 401;
        vi.mocked(instance['client'].chat.completions.create).mockRejectedValue(error);

        try {
          await instance.chat({
            messages: [{ content: 'Hello', role: 'user' }],
            model: 'mistralai/mistral-7b-instruct:free',
            temperature: 0,
          });
        } catch (e) {
          // Expect the chat method to throw an error with InvalidGroqAPIKey
          expect(e).toEqual({
            endpoint: defaultBaseURL,
            error: new Error('Unauthorized'),
            errorType: invalidErrorType,
            provider,
          });
        }
      });

      it('should return AgentRuntimeError for non-OpenAI errors', async () => {
        // Arrange
        const genericError = new Error('Generic Error');

        vi.spyOn(instance['client'].chat.completions, 'create').mockRejectedValue(genericError);

        // Act
        try {
          await instance.chat({
            messages: [{ content: 'Hello', role: 'user' }],
            model: 'mistralai/mistral-7b-instruct:free',
            temperature: 0,
          });
        } catch (e) {
          expect(e).toEqual({
            endpoint: defaultBaseURL,
            errorType: 'AgentRuntimeError',
            provider,
            error: {
              name: genericError.name,
              cause: genericError.cause,
              message: genericError.message,
              stack: genericError.stack,
            },
          });
        }
      });
    });

    describe('LobeGroqAI chat with callback and headers', () => {
      it('should handle callback and headers correctly', async () => {
        // Mock chat.completions.create to return a readable stream
        const mockCreateMethod = vi
          .spyOn(instance['client'].chat.completions, 'create')
          .mockResolvedValue(
            new ReadableStream({
              start(controller) {
                controller.enqueue({
                  id: 'chatcmpl-8xDx5AETP8mESQN7UB30GxTN2H1SO',
                  object: 'chat.completion.chunk',
                  created: 1709125675,
                  model: 'mistralai/mistral-7b-instruct:free',
                  system_fingerprint: 'fp_86156a94a0',
                  choices: [
                    { index: 0, delta: { content: 'hello' }, logprobs: null, finish_reason: null },
                  ],
                });
                controller.close();
              },
            }) as any,
          );

        // Prepare callback and headers
        const mockCallback: ChatStreamCallbacks = {
          onStart: vi.fn(),
          onToken: vi.fn(),
        };
        const mockHeaders = { 'Custom-Header': 'TestValue' };

        // Run the chat call under test
        const result = await instance.chat(
          {
            messages: [{ content: 'Hello', role: 'user' }],
            model: 'mistralai/mistral-7b-instruct:free',
            temperature: 0,
          },
          { callback: mockCallback, headers: mockHeaders },
        );

        // Verify the callbacks were invoked
        await result.text(); // make sure the stream is consumed
        expect(mockCallback.onStart).toHaveBeenCalled();
        expect(mockCallback.onToken).toHaveBeenCalledWith('hello');

        // Verify the headers were passed through correctly
        expect(result.headers.get('Custom-Header')).toEqual('TestValue');

        // Clean up
        mockCreateMethod.mockRestore();
      });
    });

    describe('DEBUG', () => {
      it('should call debugStream and return StreamingTextResponse when DEBUG_GROQ_CHAT_COMPLETION is 1', async () => {
        // Arrange
        const mockProdStream = new ReadableStream() as any; // mocked prod stream
        const mockDebugStream = new ReadableStream({
          start(controller) {
            controller.enqueue('Debug stream content');
            controller.close();
          },
        }) as any;
        mockDebugStream.toReadableStream = () => mockDebugStream; // attach a toReadableStream method

        // Mock the chat.completions.create return value, including a mocked tee method
        (instance['client'].chat.completions.create as Mock).mockResolvedValue({
          tee: () => [mockProdStream, { toReadableStream: () => mockDebugStream }],
        });

        // Save the original environment variable value
        const originalDebugValue = process.env.DEBUG_GROQ_CHAT_COMPLETION;

        // Mock the environment variable
        process.env.DEBUG_GROQ_CHAT_COMPLETION = '1';
        vi.spyOn(debugStreamModule, 'debugStream').mockImplementation(() => Promise.resolve());

        // Act: the chat call should invoke debugStream while the flag is set
        await instance.chat({
          messages: [{ content: 'Hello', role: 'user' }],
          model: 'mistralai/mistral-7b-instruct:free',
          temperature: 0,
        });

        // Assert: verify debugStream was called
        expect(debugStreamModule.debugStream).toHaveBeenCalled();

        // Restore the original environment variable value
        process.env.DEBUG_GROQ_CHAT_COMPLETION = originalDebugValue;
      });
    });
  });
});
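
One detail worth calling out from the tests above: when a custom baseURL is in play, the error payload reports a desensitized endpoint (https://api.abc.com/v1 becomes https://api.***.com/v1). The helper that performs the masking is not in this diff; the following is a minimal sketch consistent with that expectation, with a hypothetical name and logic.

// Hypothetical sketch: mask interior host labels so error payloads never leak a private endpoint.
const desensitizeUrl = (url: string): string => {
  const u = new URL(url);
  const labels = u.hostname.split('.');
  // Keep the first label and the TLD, mask everything in between:
  // 'api.abc.com' -> 'api.***.com'
  const masked = labels.map((label, i) =>
    i === 0 || i === labels.length - 1 ? label : '***',
  );
  return `${u.protocol}//${masked.join('.')}${u.pathname}`;
};

console.log(desensitizeUrl('https://api.abc.com/v1')); // https://api.***.com/v1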