From ecd1d259b3b75c01f2dbc63bb69a6aa804cf2688 Mon Sep 17 00:00:00 2001 From: Arvin Xu Date: Wed, 10 Apr 2024 07:26:03 +0000 Subject: [PATCH] =?UTF-8?q?=20=E2=9C=85=20test:=20add=20test?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../agent-runtime/azureOpenai/index.test.ts | 166 ++++++++++++++++++ src/libs/agent-runtime/azureOpenai/index.ts | 24 ++- src/libs/agent-runtime/utils/env.ts | 1 - src/store/chat/slices/share/action.test.ts | 113 ++++++++++++ src/store/chat/slices/share/action.ts | 2 +- src/store/global/slices/common/action.test.ts | 165 +++++++++++++++++ .../slices/settings/actions/llm.test.ts | 27 ++- 7 files changed, 487 insertions(+), 11 deletions(-) create mode 100644 src/libs/agent-runtime/azureOpenai/index.test.ts delete mode 100644 src/libs/agent-runtime/utils/env.ts diff --git a/src/libs/agent-runtime/azureOpenai/index.test.ts b/src/libs/agent-runtime/azureOpenai/index.test.ts new file mode 100644 index 000000000000..06e0ba55b1ec --- /dev/null +++ b/src/libs/agent-runtime/azureOpenai/index.test.ts @@ -0,0 +1,166 @@ +// @vitest-environment node +import { AzureKeyCredential, OpenAIClient } from '@azure/openai'; +import { Mock, afterEach, beforeEach, describe, expect, it, vi } from 'vitest'; + +import * as debugStreamModule from '../utils/debugStream'; +import { LobeAzureOpenAI } from './index'; + +// Mock the console.error to avoid polluting test output +vi.spyOn(console, 'error').mockImplementation(() => {}); + +describe('LobeAzureOpenAI', () => { + let instance: LobeAzureOpenAI; + + beforeEach(() => { + instance = new LobeAzureOpenAI( + 'https://test.openai.azure.com/', + 'test_key', + '2023-03-15-preview', + ); + + // 使用 vi.spyOn 来模拟 streamChatCompletions 方法 + vi.spyOn(instance['client'], 'streamChatCompletions').mockResolvedValue( + new ReadableStream() as any, + ); + }); + + afterEach(() => { + vi.clearAllMocks(); + }); + + describe('constructor', () => { + it('should throw InvalidAzureAPIKey error when apikey or endpoint is missing', () => { + try { + new LobeAzureOpenAI(); + } catch (e) { + expect(e).toEqual({ errorType: 'InvalidAzureAPIKey' }); + } + }); + + it('should create an instance of OpenAIClient with correct parameters', () => { + const endpoint = 'https://test.openai.azure.com/'; + const apikey = 'test_key'; + const apiVersion = '2023-03-15-preview'; + + const instance = new LobeAzureOpenAI(endpoint, apikey, apiVersion); + + expect(instance.client).toBeInstanceOf(OpenAIClient); + expect(instance.baseURL).toBe(endpoint); + }); + }); + + describe('chat', () => { + it('should return a StreamingTextResponse on successful API call', async () => { + // Arrange + const mockStream = new ReadableStream(); + const mockResponse = Promise.resolve(mockStream); + + (instance['client'].streamChatCompletions as Mock).mockResolvedValue(mockResponse); + + // Act + const result = await instance.chat({ + messages: [{ content: 'Hello', role: 'user' }], + model: 'text-davinci-003', + temperature: 0, + }); + + // Assert + expect(result).toBeInstanceOf(Response); + }); + + describe('Error', () => { + it('should return AzureBizError with DeploymentNotFound error', async () => { + // Arrange + const error = { + code: 'DeploymentNotFound', + message: 'Deployment not found', + }; + + (instance['client'].streamChatCompletions as Mock).mockRejectedValue(error); + + // Act + try { + await instance.chat({ + messages: [{ content: 'Hello', role: 'user' }], + model: 'text-davinci-003', + temperature: 0, + }); + } catch (e) { + // Assert + expect(e).toEqual({ + endpoint: 'https://test.openai.azure.com/', + error: { + code: 'DeploymentNotFound', + message: 'Deployment not found', + deployId: 'text-davinci-003', + }, + errorType: 'AzureBizError', + provider: 'azure', + }); + } + }); + + it('should return AgentRuntimeError for non-Azure errors', async () => { + // Arrange + const genericError = new Error('Generic Error'); + + (instance['client'].streamChatCompletions as Mock).mockRejectedValue(genericError); + + // Act + try { + await instance.chat({ + messages: [{ content: 'Hello', role: 'user' }], + model: 'text-davinci-003', + temperature: 0, + }); + } catch (e) { + // Assert + expect(e).toEqual({ + endpoint: 'https://test.openai.azure.com/', + errorType: 'AgentRuntimeError', + provider: 'azure', + error: { + name: genericError.name, + cause: genericError.cause, + message: genericError.message, + }, + }); + } + }); + }); + + describe('DEBUG', () => { + it('should call debugStream when DEBUG_CHAT_COMPLETION is 1', async () => { + // Arrange + const mockProdStream = new ReadableStream() as any; + const mockDebugStream = new ReadableStream({ + start(controller) { + controller.enqueue('Debug stream content'); + controller.close(); + }, + }) as any; + mockDebugStream.toReadableStream = () => mockDebugStream; + + (instance['client'].streamChatCompletions as Mock).mockResolvedValue({ + tee: () => [mockProdStream, { toReadableStream: () => mockDebugStream }], + }); + + process.env.DEBUG_AZURE_CHAT_COMPLETION = '1'; + vi.spyOn(debugStreamModule, 'debugStream').mockImplementation(() => Promise.resolve()); + + // Act + await instance.chat({ + messages: [{ content: 'Hello', role: 'user' }], + model: 'text-davinci-003', + temperature: 0, + }); + + // Assert + expect(debugStreamModule.debugStream).toHaveBeenCalled(); + + // Restore + delete process.env.DEBUG_AZURE_CHAT_COMPLETION; + }); + }); + }); +}); diff --git a/src/libs/agent-runtime/azureOpenai/index.ts b/src/libs/agent-runtime/azureOpenai/index.ts index 3ace27daae6d..3ffba0d1970e 100644 --- a/src/libs/agent-runtime/azureOpenai/index.ts +++ b/src/libs/agent-runtime/azureOpenai/index.ts @@ -11,16 +11,15 @@ import { AgentRuntimeErrorType } from '../error'; import { ChatStreamPayload, ModelProvider } from '../types'; import { AgentRuntimeError } from '../utils/createError'; import { debugStream } from '../utils/debugStream'; -import { DEBUG_CHAT_COMPLETION } from '../utils/env'; export class LobeAzureOpenAI implements LobeRuntimeAI { - private _client: OpenAIClient; + client: OpenAIClient; constructor(endpoint?: string, apikey?: string, apiVersion?: string) { if (!apikey || !endpoint) throw AgentRuntimeError.createError(AgentRuntimeErrorType.InvalidAzureAPIKey); - this._client = new OpenAIClient(endpoint, new AzureKeyCredential(apikey), { apiVersion }); + this.client = new OpenAIClient(endpoint, new AzureKeyCredential(apikey), { apiVersion }); this.baseURL = endpoint; } @@ -34,7 +33,7 @@ export class LobeAzureOpenAI implements LobeRuntimeAI { // ============ 2. send api ============ // try { - const response = await this._client.streamChatCompletions( + const response = await this.client.streamChatCompletions( model, messages as ChatRequestMessage[], params as GetChatCompletionsOptions, @@ -45,7 +44,7 @@ export class LobeAzureOpenAI implements LobeRuntimeAI { const [debug, prod] = stream.tee(); - if (DEBUG_CHAT_COMPLETION) { + if (process.env.DEBUG_AZURE_CHAT_COMPLETION === '1') { debugStream(debug).catch(console.error); } @@ -53,10 +52,18 @@ export class LobeAzureOpenAI implements LobeRuntimeAI { } catch (e) { let error = e as { [key: string]: any; code: string; message: string }; - switch (error.code) { - case 'DeploymentNotFound': { - error = { ...error, deployId: model }; + if (error.code) { + switch (error.code) { + case 'DeploymentNotFound': { + error = { ...error, deployId: model }; + } } + } else { + error = { + cause: error.cause, + message: error.message, + name: error.name, + } as any; } const errorType = error.code @@ -64,6 +71,7 @@ export class LobeAzureOpenAI implements LobeRuntimeAI { : AgentRuntimeErrorType.AgentRuntimeError; throw AgentRuntimeError.chat({ + endpoint: this.baseURL, error, errorType, provider: ModelProvider.Azure, diff --git a/src/libs/agent-runtime/utils/env.ts b/src/libs/agent-runtime/utils/env.ts deleted file mode 100644 index 38c7b59a85e1..000000000000 --- a/src/libs/agent-runtime/utils/env.ts +++ /dev/null @@ -1 +0,0 @@ -export const DEBUG_CHAT_COMPLETION = process.env.DEBUG_CHAT_COMPLETION === '1'; diff --git a/src/store/chat/slices/share/action.test.ts b/src/store/chat/slices/share/action.test.ts index 9e6ce50fc7b2..32573782ed1e 100644 --- a/src/store/chat/slices/share/action.test.ts +++ b/src/store/chat/slices/share/action.test.ts @@ -3,6 +3,7 @@ import { act, renderHook } from '@testing-library/react'; import { DEFAULT_USER_AVATAR_URL } from '@/const/meta'; import { shareGPTService } from '@/services/share'; import { useChatStore } from '@/store/chat'; +import { ChatMessage } from '@/types/message'; describe('shareSlice actions', () => { let shareGPTServiceSpy: any; @@ -82,5 +83,117 @@ describe('shareSlice actions', () => { expect(result.current.shareLoading).toBe(false); // 注意:这里的验证可能需要你根据实际的状态管理逻辑进行调整 }); + + it('should include plugin information when withPluginInfo is true', async () => { + // 模拟带有插件信息的消息 + const pluginMessage = { + role: 'function', + content: 'plugin content', + plugin: { + type: 'default', + arguments: '{}', + apiName: 'test-api', + identifier: 'test-identifier', + }, + id: 'abc', + } as ChatMessage; + + act(() => { + useChatStore.setState({ messages: [pluginMessage] }); + }); + + const { result } = renderHook(() => useChatStore()); + await act(async () => { + result.current.shareToShareGPT({ withPluginInfo: true }); + }); + expect(shareGPTServiceSpy).toHaveBeenCalledWith( + expect.objectContaining({ + items: expect.arrayContaining([ + expect.objectContaining({ + from: 'gpt', + value: expect.stringContaining('Function Calling Plugin'), + }), + ]), + }), + ); + }); + + it('should not include plugin information when withPluginInfo is false', async () => { + const pluginMessage = { + role: 'function', + content: 'plugin content', + plugin: { + type: 'default', + arguments: '{}', + apiName: 'test-api', + identifier: 'test-identifier', + }, + id: 'abc', + } as ChatMessage; + + act(() => { + useChatStore.setState({ messages: [pluginMessage] }); + }); + + const { result } = renderHook(() => useChatStore()); + await act(async () => { + result.current.shareToShareGPT({ withPluginInfo: false }); + }); + expect(shareGPTServiceSpy).toHaveBeenCalledWith( + expect.objectContaining({ + items: expect.not.arrayContaining([ + expect.objectContaining({ + from: 'gpt', + value: expect.stringContaining('Function Calling Plugin'), + }), + ]), + }), + ); + }); + + it('should handle messages from different roles correctly', async () => { + const messages = [ + { role: 'user', content: 'user message', id: '1' }, + { role: 'assistant', content: 'assistant message', id: '2' }, + { + role: 'function', + content: 'plugin content', + plugin: { + type: 'default', + arguments: '{}', + apiName: 'test-api', + identifier: 'test-identifier', + }, + id: '3', + }, + ] as ChatMessage[]; + + act(() => { + useChatStore.setState({ messages }); + }); + + const { result } = renderHook(() => useChatStore()); + await act(async () => { + await result.current.shareToShareGPT({ + withPluginInfo: true, + withSystemRole: true, + }); + }); + + expect(shareGPTServiceSpy).toHaveBeenCalledWith( + expect.objectContaining({ + items: [ + expect.objectContaining({ from: 'gpt' }), // Agent meta info + expect.objectContaining({ from: 'human', value: 'user message' }), + expect.objectContaining({ from: 'gpt', value: 'assistant message' }), + expect.objectContaining({ + from: 'gpt', + value: expect.stringContaining('Function Calling Plugin'), + }), + expect.objectContaining({ from: 'gpt', value: expect.stringContaining('Share from') }), // Footer + ], + }), + ); + }); }); }); diff --git a/src/store/chat/slices/share/action.ts b/src/store/chat/slices/share/action.ts index 8e6c10d69bfb..b680556a6e52 100644 --- a/src/store/chat/slices/share/action.ts +++ b/src/store/chat/slices/share/action.ts @@ -47,7 +47,7 @@ export interface ShareAction { avatar?: string; withPluginInfo?: boolean; withSystemRole?: boolean; - }) => void; + }) => Promise; } export const chatShare: StateCreator = ( diff --git a/src/store/global/slices/common/action.test.ts b/src/store/global/slices/common/action.test.ts index 361f4503176f..6d912fc1dc75 100644 --- a/src/store/global/slices/common/action.test.ts +++ b/src/store/global/slices/common/action.test.ts @@ -4,8 +4,12 @@ import { afterEach, describe, expect, it, vi } from 'vitest'; import { withSWR } from '~test-utils'; import { globalService } from '@/services/global'; +import { messageService } from '@/services/message'; import { userService } from '@/services/user'; import { useGlobalStore } from '@/store/global'; +import { commonSelectors } from '@/store/global/slices/common/selectors'; +import { preferenceSelectors } from '@/store/global/slices/preference/selectors'; +import { syncSettingsSelectors } from '@/store/global/slices/settings/selectors'; import { GlobalServerConfig } from '@/types/serverConfig'; import { switchLang } from '@/utils/client/switchLang'; @@ -200,4 +204,165 @@ describe('createCommonSlice', () => { expect(useGlobalStore.getState().settings).toEqual({}); }); }); + + describe('refreshConnection', () => { + it('should not call triggerEnableSync when userId is empty', async () => { + const { result } = renderHook(() => useGlobalStore()); + const onEvent = vi.fn(); + + vi.spyOn(commonSelectors, 'userId').mockReturnValueOnce(undefined); + const triggerEnableSyncSpy = vi.spyOn(result.current, 'triggerEnableSync'); + + await act(async () => { + await result.current.refreshConnection(onEvent); + }); + + expect(triggerEnableSyncSpy).not.toHaveBeenCalled(); + }); + + it('should call triggerEnableSync when userId exists', async () => { + const { result } = renderHook(() => useGlobalStore()); + const onEvent = vi.fn(); + const userId = 'user-id'; + + vi.spyOn(commonSelectors, 'userId').mockReturnValueOnce(userId); + const triggerEnableSyncSpy = vi.spyOn(result.current, 'triggerEnableSync'); + + await act(async () => { + await result.current.refreshConnection(onEvent); + }); + + expect(triggerEnableSyncSpy).toHaveBeenCalledWith(userId, onEvent); + }); + }); + + describe('triggerEnableSync', () => { + it('should return false when sync.channelName is empty', async () => { + const { result } = renderHook(() => useGlobalStore()); + const userId = 'user-id'; + const onEvent = vi.fn(); + + vi.spyOn(syncSettingsSelectors, 'webrtcConfig').mockReturnValueOnce({ + channelName: '', + enabled: true, + }); + + const data = await act(async () => { + return result.current.triggerEnableSync(userId, onEvent); + }); + + expect(data).toBe(false); + }); + + it('should call globalService.enabledSync when sync.channelName exists', async () => { + const userId = 'user-id'; + const onEvent = vi.fn(); + const channelName = 'channel-name'; + const channelPassword = 'channel-password'; + const deviceName = 'device-name'; + const signaling = 'signaling'; + + vi.spyOn(syncSettingsSelectors, 'webrtcConfig').mockReturnValueOnce({ + channelName, + channelPassword, + signaling, + enabled: true, + }); + vi.spyOn(syncSettingsSelectors, 'deviceName').mockReturnValueOnce(deviceName); + const enabledSyncSpy = vi.spyOn(globalService, 'enabledSync').mockResolvedValueOnce(true); + const { result } = renderHook(() => useGlobalStore()); + + const data = await act(async () => { + return result.current.triggerEnableSync(userId, onEvent); + }); + + expect(enabledSyncSpy).toHaveBeenCalledWith({ + channel: { name: channelName, password: channelPassword }, + onAwarenessChange: expect.any(Function), + onSyncEvent: onEvent, + onSyncStatusChange: expect.any(Function), + signaling, + user: expect.objectContaining({ id: userId, name: deviceName }), + }); + expect(data).toBe(true); + }); + }); + + describe('useCheckTrace', () => { + it('should return false when shouldFetch is false', async () => { + const { result } = renderHook(() => useGlobalStore().useCheckTrace(false), { + wrapper: withSWR, + }); + + await waitFor(() => expect(result.current.data).toBe(false)); + }); + + it('should return false when userAllowTrace is already set', async () => { + vi.spyOn(preferenceSelectors, 'userAllowTrace').mockReturnValueOnce(true); + + const { result } = renderHook(() => useGlobalStore().useCheckTrace(true), { + wrapper: withSWR, + }); + + await waitFor(() => expect(result.current.data).toBe(false)); + }); + + it('should call messageService.messageCountToCheckTrace when needed', async () => { + vi.spyOn(preferenceSelectors, 'userAllowTrace').mockReturnValueOnce(null); + const messageCountToCheckTraceSpy = vi + .spyOn(messageService, 'messageCountToCheckTrace') + .mockResolvedValueOnce(true); + + const { result } = renderHook(() => useGlobalStore().useCheckTrace(true), { + wrapper: withSWR, + }); + + await waitFor(() => expect(result.current.data).toBe(true)); + expect(messageCountToCheckTraceSpy).toHaveBeenCalled(); + }); + }); + + describe('useEnabledSync', () => { + it('should return false when userId is empty', async () => { + const { result } = renderHook( + () => useGlobalStore().useEnabledSync(true, undefined, vi.fn()), + { wrapper: withSWR }, + ); + + await waitFor(() => expect(result.current.data).toBe(false)); + }); + + it('should call globalService.disableSync when userEnableSync is false', async () => { + const disableSyncSpy = vi.spyOn(globalService, 'disableSync').mockResolvedValueOnce(false); + + const { result } = renderHook( + () => useGlobalStore().useEnabledSync(false, 'user-id', vi.fn()), + { wrapper: withSWR }, + ); + + await waitFor(() => expect(result.current.data).toBeUndefined()); + expect(disableSyncSpy).toHaveBeenCalled(); + }); + + it('should call triggerEnableSync when userEnableSync and userId exist', async () => { + const userId = 'user-id'; + const onEvent = vi.fn(); + const triggerEnableSyncSpy = vi.fn().mockResolvedValueOnce(true); + + const { result } = renderHook(() => useGlobalStore()); + + // replace triggerEnableSync as a mock + result.current.triggerEnableSync = triggerEnableSyncSpy; + + const { result: swrResult } = renderHook( + () => result.current.useEnabledSync(true, userId, onEvent), + { + wrapper: withSWR, + }, + ); + + await waitFor(() => expect(swrResult.current.data).toBe(true)); + expect(triggerEnableSyncSpy).toHaveBeenCalledWith(userId, onEvent); + }); + }); }); diff --git a/src/store/global/slices/settings/actions/llm.test.ts b/src/store/global/slices/settings/actions/llm.test.ts index 3196fe3d0d07..0afa9575895a 100644 --- a/src/store/global/slices/settings/actions/llm.test.ts +++ b/src/store/global/slices/settings/actions/llm.test.ts @@ -3,8 +3,11 @@ import { describe, expect, it, vi } from 'vitest'; import { userService } from '@/services/user'; import { useGlobalStore } from '@/store/global'; +import { modelConfigSelectors } from '@/store/global/slices/settings/selectors'; import { GeneralModelProviderConfig } from '@/types/settings'; +import { CustomModelCardDispatch, customModelCardsReducer } from '../reducers/customModelCard'; + // Mock userService vi.mock('@/services/user', () => ({ userService: { @@ -12,8 +15,11 @@ vi.mock('@/services/user', () => ({ resetUserSettings: vi.fn(), }, })); +vi.mock('../reducers/customModelCard', () => ({ + customModelCardsReducer: vi.fn().mockReturnValue([]), +})); -describe('SettingsAction', () => { +describe('LLMSettingsSliceAction', () => { describe('setModelProviderConfig', () => { it('should set OpenAI configuration', async () => { const { result } = renderHook(() => useGlobalStore()); @@ -32,4 +38,23 @@ describe('SettingsAction', () => { }); }); }); + + describe('dispatchCustomModelCards', () => { + it('should return early when prevState does not exist', async () => { + const { result } = renderHook(() => useGlobalStore()); + const provider = 'openai'; + const payload: CustomModelCardDispatch = { type: 'add', modelCard: { id: 'test-id' } }; + + // Mock the selector to return undefined + vi.spyOn(modelConfigSelectors, 'providerConfig').mockReturnValue(() => undefined); + vi.spyOn(result.current, 'setModelProviderConfig'); + + await act(async () => { + await result.current.dispatchCustomModelCards(provider, payload); + }); + + // Assert that setModelProviderConfig was not called + expect(result.current.setModelProviderConfig).not.toHaveBeenCalled(); + }); + }); });