Skip to content

Commit

Permalink
✅ test: add test
Browse files Browse the repository at this point in the history
  • Loading branch information
arvinxx committed Apr 10, 2024
1 parent 32cd78f commit ecd1d25
Show file tree
Hide file tree
Showing 7 changed files with 487 additions and 11 deletions.
166 changes: 166 additions & 0 deletions src/libs/agent-runtime/azureOpenai/index.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
// @vitest-environment node
import { AzureKeyCredential, OpenAIClient } from '@azure/openai';
import { Mock, afterEach, beforeEach, describe, expect, it, vi } from 'vitest';

import * as debugStreamModule from '../utils/debugStream';
import { LobeAzureOpenAI } from './index';

// Mock the console.error to avoid polluting test output
vi.spyOn(console, 'error').mockImplementation(() => {});

describe('LobeAzureOpenAI', () => {
let instance: LobeAzureOpenAI;

beforeEach(() => {
instance = new LobeAzureOpenAI(
'https://test.openai.azure.com/',
'test_key',
'2023-03-15-preview',
);

// 使用 vi.spyOn 来模拟 streamChatCompletions 方法
vi.spyOn(instance['client'], 'streamChatCompletions').mockResolvedValue(
new ReadableStream() as any,
);
});

afterEach(() => {
vi.clearAllMocks();
});

describe('constructor', () => {
it('should throw InvalidAzureAPIKey error when apikey or endpoint is missing', () => {
try {
new LobeAzureOpenAI();
} catch (e) {
expect(e).toEqual({ errorType: 'InvalidAzureAPIKey' });
}
});

it('should create an instance of OpenAIClient with correct parameters', () => {
const endpoint = 'https://test.openai.azure.com/';
const apikey = 'test_key';
const apiVersion = '2023-03-15-preview';

const instance = new LobeAzureOpenAI(endpoint, apikey, apiVersion);

expect(instance.client).toBeInstanceOf(OpenAIClient);
expect(instance.baseURL).toBe(endpoint);
});
});

describe('chat', () => {
it('should return a StreamingTextResponse on successful API call', async () => {
// Arrange
const mockStream = new ReadableStream();
const mockResponse = Promise.resolve(mockStream);

(instance['client'].streamChatCompletions as Mock).mockResolvedValue(mockResponse);

// Act
const result = await instance.chat({
messages: [{ content: 'Hello', role: 'user' }],
model: 'text-davinci-003',
temperature: 0,
});

// Assert
expect(result).toBeInstanceOf(Response);
});

describe('Error', () => {
it('should return AzureBizError with DeploymentNotFound error', async () => {
// Arrange
const error = {
code: 'DeploymentNotFound',
message: 'Deployment not found',
};

(instance['client'].streamChatCompletions as Mock).mockRejectedValue(error);

// Act
try {
await instance.chat({
messages: [{ content: 'Hello', role: 'user' }],
model: 'text-davinci-003',
temperature: 0,
});
} catch (e) {
// Assert
expect(e).toEqual({
endpoint: 'https://test.openai.azure.com/',
error: {
code: 'DeploymentNotFound',
message: 'Deployment not found',
deployId: 'text-davinci-003',
},
errorType: 'AzureBizError',
provider: 'azure',
});
}
});

it('should return AgentRuntimeError for non-Azure errors', async () => {
// Arrange
const genericError = new Error('Generic Error');

(instance['client'].streamChatCompletions as Mock).mockRejectedValue(genericError);

// Act
try {
await instance.chat({
messages: [{ content: 'Hello', role: 'user' }],
model: 'text-davinci-003',
temperature: 0,
});
} catch (e) {
// Assert
expect(e).toEqual({
endpoint: 'https://test.openai.azure.com/',
errorType: 'AgentRuntimeError',
provider: 'azure',
error: {
name: genericError.name,
cause: genericError.cause,
message: genericError.message,
},
});
}
});
});

describe('DEBUG', () => {
it('should call debugStream when DEBUG_CHAT_COMPLETION is 1', async () => {
// Arrange
const mockProdStream = new ReadableStream() as any;
const mockDebugStream = new ReadableStream({
start(controller) {
controller.enqueue('Debug stream content');
controller.close();
},
}) as any;
mockDebugStream.toReadableStream = () => mockDebugStream;

(instance['client'].streamChatCompletions as Mock).mockResolvedValue({
tee: () => [mockProdStream, { toReadableStream: () => mockDebugStream }],
});

process.env.DEBUG_AZURE_CHAT_COMPLETION = '1';
vi.spyOn(debugStreamModule, 'debugStream').mockImplementation(() => Promise.resolve());

// Act
await instance.chat({
messages: [{ content: 'Hello', role: 'user' }],
model: 'text-davinci-003',
temperature: 0,
});

// Assert
expect(debugStreamModule.debugStream).toHaveBeenCalled();

// Restore
delete process.env.DEBUG_AZURE_CHAT_COMPLETION;
});
});
});
});
24 changes: 16 additions & 8 deletions src/libs/agent-runtime/azureOpenai/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,15 @@ import { AgentRuntimeErrorType } from '../error';
import { ChatStreamPayload, ModelProvider } from '../types';
import { AgentRuntimeError } from '../utils/createError';
import { debugStream } from '../utils/debugStream';
import { DEBUG_CHAT_COMPLETION } from '../utils/env';

export class LobeAzureOpenAI implements LobeRuntimeAI {
private _client: OpenAIClient;
client: OpenAIClient;

constructor(endpoint?: string, apikey?: string, apiVersion?: string) {
if (!apikey || !endpoint)
throw AgentRuntimeError.createError(AgentRuntimeErrorType.InvalidAzureAPIKey);

this._client = new OpenAIClient(endpoint, new AzureKeyCredential(apikey), { apiVersion });
this.client = new OpenAIClient(endpoint, new AzureKeyCredential(apikey), { apiVersion });

this.baseURL = endpoint;
}
Expand All @@ -34,7 +33,7 @@ export class LobeAzureOpenAI implements LobeRuntimeAI {
// ============ 2. send api ============ //

try {
const response = await this._client.streamChatCompletions(
const response = await this.client.streamChatCompletions(
model,
messages as ChatRequestMessage[],
params as GetChatCompletionsOptions,
Expand All @@ -45,25 +44,34 @@ export class LobeAzureOpenAI implements LobeRuntimeAI {

const [debug, prod] = stream.tee();

if (DEBUG_CHAT_COMPLETION) {
if (process.env.DEBUG_AZURE_CHAT_COMPLETION === '1') {
debugStream(debug).catch(console.error);
}

return new StreamingTextResponse(prod);
} catch (e) {
let error = e as { [key: string]: any; code: string; message: string };

switch (error.code) {
case 'DeploymentNotFound': {
error = { ...error, deployId: model };
if (error.code) {
switch (error.code) {
case 'DeploymentNotFound': {
error = { ...error, deployId: model };
}
}
} else {
error = {
cause: error.cause,
message: error.message,
name: error.name,
} as any;
}

const errorType = error.code
? AgentRuntimeErrorType.AzureBizError
: AgentRuntimeErrorType.AgentRuntimeError;

throw AgentRuntimeError.chat({
endpoint: this.baseURL,
error,
errorType,
provider: ModelProvider.Azure,
Expand Down
1 change: 0 additions & 1 deletion src/libs/agent-runtime/utils/env.ts

This file was deleted.

113 changes: 113 additions & 0 deletions src/store/chat/slices/share/action.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import { act, renderHook } from '@testing-library/react';
import { DEFAULT_USER_AVATAR_URL } from '@/const/meta';
import { shareGPTService } from '@/services/share';
import { useChatStore } from '@/store/chat';
import { ChatMessage } from '@/types/message';

describe('shareSlice actions', () => {
let shareGPTServiceSpy: any;
Expand Down Expand Up @@ -82,5 +83,117 @@ describe('shareSlice actions', () => {
expect(result.current.shareLoading).toBe(false);
// 注意:这里的验证可能需要你根据实际的状态管理逻辑进行调整
});

it('should include plugin information when withPluginInfo is true', async () => {
// 模拟带有插件信息的消息
const pluginMessage = {
role: 'function',
content: 'plugin content',
plugin: {
type: 'default',
arguments: '{}',
apiName: 'test-api',
identifier: 'test-identifier',
},
id: 'abc',
} as ChatMessage;

act(() => {
useChatStore.setState({ messages: [pluginMessage] });
});

const { result } = renderHook(() => useChatStore());
await act(async () => {
result.current.shareToShareGPT({ withPluginInfo: true });
});
expect(shareGPTServiceSpy).toHaveBeenCalledWith(
expect.objectContaining({
items: expect.arrayContaining([
expect.objectContaining({
from: 'gpt',
value: expect.stringContaining('Function Calling Plugin'),
}),
]),
}),
);
});

it('should not include plugin information when withPluginInfo is false', async () => {
const pluginMessage = {
role: 'function',
content: 'plugin content',
plugin: {
type: 'default',
arguments: '{}',
apiName: 'test-api',
identifier: 'test-identifier',
},
id: 'abc',
} as ChatMessage;

act(() => {
useChatStore.setState({ messages: [pluginMessage] });
});

const { result } = renderHook(() => useChatStore());
await act(async () => {
result.current.shareToShareGPT({ withPluginInfo: false });
});
expect(shareGPTServiceSpy).toHaveBeenCalledWith(
expect.objectContaining({
items: expect.not.arrayContaining([
expect.objectContaining({
from: 'gpt',
value: expect.stringContaining('Function Calling Plugin'),
}),
]),
}),
);
});

it('should handle messages from different roles correctly', async () => {
const messages = [
{ role: 'user', content: 'user message', id: '1' },
{ role: 'assistant', content: 'assistant message', id: '2' },
{
role: 'function',
content: 'plugin content',
plugin: {
type: 'default',
arguments: '{}',
apiName: 'test-api',
identifier: 'test-identifier',
},
id: '3',
},
] as ChatMessage[];

act(() => {
useChatStore.setState({ messages });
});

const { result } = renderHook(() => useChatStore());
await act(async () => {
await result.current.shareToShareGPT({
withPluginInfo: true,
withSystemRole: true,
});
});

expect(shareGPTServiceSpy).toHaveBeenCalledWith(
expect.objectContaining({
items: [
expect.objectContaining({ from: 'gpt' }), // Agent meta info
expect.objectContaining({ from: 'human', value: 'user message' }),
expect.objectContaining({ from: 'gpt', value: 'assistant message' }),
expect.objectContaining({
from: 'gpt',
value: expect.stringContaining('Function Calling Plugin'),
}),
expect.objectContaining({ from: 'gpt', value: expect.stringContaining('Share from') }), // Footer
],
}),
);
});
});
});
2 changes: 1 addition & 1 deletion src/store/chat/slices/share/action.ts
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ export interface ShareAction {
avatar?: string;
withPluginInfo?: boolean;
withSystemRole?: boolean;
}) => void;
}) => Promise<void>;
}

export const chatShare: StateCreator<ChatStore, [['zustand/devtools', never]], [], ShareAction> = (
Expand Down
Loading

0 comments on commit ecd1d25

Please sign in to comment.