From 017382c10b313578c6246a77614f0cadcbaa49b2 Mon Sep 17 00:00:00 2001
From: arvinxx
Date: Fri, 25 Oct 2024 21:42:22 +0800
Subject: [PATCH] =?UTF-8?q?=F0=9F=92=84=20style:=20add=20lm=20studio?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 package.json                                   |   2 +-
 .../settings/llm/ProviderList/providers.tsx    |   4 +-
 src/config/modelProviders/index.ts             |   4 +
 src/config/modelProviders/lmstudio.ts          |  35 +++
 src/const/settings/llm.ts                      |   5 +
 src/libs/agent-runtime/AgentRuntime.ts         |   7 +
 src/libs/agent-runtime/lmstudio/index.test.ts  | 255 ++++++++++++++++++
 src/libs/agent-runtime/lmstudio/index.ts       |  11 +
 src/libs/agent-runtime/types/type.ts           |   2 +
 .../utils/openaiCompatibleFactory/index.ts     |   8 +-
 src/types/user/settings/keyVaults.ts           |   1 +
 11 files changed, 331 insertions(+), 3 deletions(-)
 create mode 100644 src/config/modelProviders/lmstudio.ts
 create mode 100644 src/libs/agent-runtime/lmstudio/index.test.ts
 create mode 100644 src/libs/agent-runtime/lmstudio/index.ts

diff --git a/package.json b/package.json
index b87e94786c4f..112f9870202e 100644
--- a/package.json
+++ b/package.json
@@ -123,7 +123,7 @@
     "@langchain/community": "^0.3.0",
     "@lobehub/chat-plugin-sdk": "^1.32.4",
     "@lobehub/chat-plugins-gateway": "^1.9.0",
-    "@lobehub/icons": "^1.35.4",
+    "@lobehub/icons": "^1.37.0",
     "@lobehub/tts": "^1.25.1",
     "@lobehub/ui": "^1.152.0",
     "@neondatabase/serverless": "^0.10.1",
diff --git a/src/app/(main)/settings/llm/ProviderList/providers.tsx b/src/app/(main)/settings/llm/ProviderList/providers.tsx
index a2e24524f98e..66efe4ce4c31 100644
--- a/src/app/(main)/settings/llm/ProviderList/providers.tsx
+++ b/src/app/(main)/settings/llm/ProviderList/providers.tsx
@@ -10,6 +10,7 @@ import {
   GoogleProviderCard,
   GroqProviderCard,
   HunyuanProviderCard,
+  LMStudioProviderCard,
   MinimaxProviderCard,
   MistralProviderCard,
   MoonshotProviderCard,
@@ -34,8 +35,8 @@ import { useGithubProvider } from './Github';
 import { useHuggingFaceProvider } from './HuggingFace';
 import { useOllamaProvider } from './Ollama';
 import { useOpenAIProvider } from './OpenAI';
-import { useWenxinProvider } from './Wenxin';
 import { useSenseNovaProvider } from './SenseNova';
+import { useWenxinProvider } from './Wenxin';
 
 export const useProviderList = (): ProviderItem[] => {
   const AzureProvider = useAzureProvider();
@@ -74,6 +75,7 @@ export const useProviderList = (): ProviderItem[] => {
     ZhiPuProviderCard,
     ZeroOneProviderCard,
     SenseNovaProvider,
+    LMStudioProviderCard,
     StepfunProviderCard,
     MoonshotProviderCard,
     BaichuanProviderCard,
diff --git a/src/config/modelProviders/index.ts b/src/config/modelProviders/index.ts
index 2237ef877b7c..98ed9b3e5721 100644
--- a/src/config/modelProviders/index.ts
+++ b/src/config/modelProviders/index.ts
@@ -13,6 +13,7 @@ import GoogleProvider from './google';
 import GroqProvider from './groq';
 import HuggingFaceProvider from './huggingface';
 import HunyuanProvider from './hunyuan';
+import LMStudioProvider from './lmstudio';
 import MinimaxProvider from './minimax';
 import MistralProvider from './mistral';
 import MoonshotProvider from './moonshot';
@@ -65,6 +66,7 @@ export const LOBE_DEFAULT_MODEL_LIST: ChatModelCard[] = [
   HunyuanProvider.chatModels,
   WenxinProvider.chatModels,
   SenseNovaProvider.chatModels,
+  LMStudioProvider.chatModels,
 ].flat();
 
 export const DEFAULT_MODEL_PROVIDER_LIST = [
@@ -100,6 +102,7 @@ export const DEFAULT_MODEL_PROVIDER_LIST = [
   Ai360Provider,
   TaichuProvider,
   SiliconCloudProvider,
+  LMStudioProvider,
 ];
 
 export const filterEnabledModels = (provider: ModelProviderCard) => {
@@ -124,6 +127,7 @@ export { default as GoogleProviderCard } from './google';
 export { default as GroqProviderCard } from './groq';
 export { default as HuggingFaceProviderCard } from './huggingface';
 export { default as HunyuanProviderCard } from './hunyuan';
+export { default as LMStudioProviderCard } from './lmstudio';
 export { default as MinimaxProviderCard } from './minimax';
 export { default as MistralProviderCard } from './mistral';
 export { default as MoonshotProviderCard } from './moonshot';
diff --git a/src/config/modelProviders/lmstudio.ts b/src/config/modelProviders/lmstudio.ts
new file mode 100644
index 000000000000..cee2d3c26e70
--- /dev/null
+++ b/src/config/modelProviders/lmstudio.ts
@@ -0,0 +1,35 @@
+import { ModelProviderCard } from '@/types/llm';
+
+// ref: https://ollama.com/library
+const LMStudio: ModelProviderCard = {
+  chatModels: [
+    {
+      description:
+        'Llama 3.1 is the leading model family from Meta, with variants of up to 405B parameters, suited to complex conversation, multilingual translation, and data analysis.',
+      displayName: 'Llama 3.1 8B',
+      enabled: true,
+      id: 'llama3.1',
+      tokens: 128_000,
+    },
+    {
+      description: 'Qwen2.5 is the new-generation large language model from Alibaba, supporting diverse application needs with excellent performance.',
+      displayName: 'Qwen2.5 14B',
+      enabled: true,
+      id: 'qwen2.5-14b-instruct',
+      tokens: 128_000,
+    },
+  ],
+  defaultShowBrowserRequest: true,
+  id: 'lmstudio',
+  modelList: { showModelFetcher: true },
+  modelsUrl: 'https://lmstudio.ai/models',
+  name: 'LM Studio',
+  showApiKey: false,
+  smoothing: {
+    speed: 2,
+    text: true,
+  },
+  url: 'https://lmstudio.ai',
+};
+
+export default LMStudio;
diff --git a/src/const/settings/llm.ts b/src/const/settings/llm.ts
index 9c478db2e95b..c48ccf0a3388 100644
--- a/src/const/settings/llm.ts
+++ b/src/const/settings/llm.ts
@@ -11,6 +11,7 @@ import {
   GroqProviderCard,
   HuggingFaceProviderCard,
   HunyuanProviderCard,
+  LMStudioProviderCard,
   MinimaxProviderCard,
   MistralProviderCard,
   MoonshotProviderCard,
@@ -87,6 +88,10 @@ export const DEFAULT_LLM_CONFIG: UserModelProviderConfig = {
     enabled: false,
     enabledModels: filterEnabledModels(HunyuanProviderCard),
   },
+  lmstudio: {
+    enabled: false,
+    enabledModels: filterEnabledModels(LMStudioProviderCard),
+  },
   minimax: {
     enabled: false,
     enabledModels: filterEnabledModels(MinimaxProviderCard),
diff --git a/src/libs/agent-runtime/AgentRuntime.ts b/src/libs/agent-runtime/AgentRuntime.ts
index ed6bc9ae16d3..cad2c06ec569 100644
--- a/src/libs/agent-runtime/AgentRuntime.ts
+++ b/src/libs/agent-runtime/AgentRuntime.ts
@@ -16,6 +16,7 @@ import { LobeGoogleAI } from './google';
 import { LobeGroq } from './groq';
 import { LobeHuggingFaceAI } from './huggingface';
 import { LobeHunyuanAI } from './hunyuan';
+import { LobeLMStudioAI } from './lmstudio';
 import { LobeMinimaxAI } from './minimax';
 import { LobeMistralAI } from './mistral';
 import { LobeMoonshotAI } from './moonshot';
@@ -138,6 +139,7 @@ class AgentRuntime {
       groq: Partial<ClientOptions>;
       huggingface: { apiKey?: string; baseURL?: string };
       hunyuan: Partial<ClientOptions>;
+      lmstudio: Partial<ClientOptions>;
       minimax: Partial<ClientOptions>;
       mistral: Partial<ClientOptions>;
       moonshot: Partial<ClientOptions>;
@@ -197,6 +199,11 @@ class AgentRuntime {
         break;
       }
 
+      case ModelProvider.LMStudio: {
+        runtimeModel = new LobeLMStudioAI(params.lmstudio);
+        break;
+      }
+
       case ModelProvider.Ollama: {
         runtimeModel = new LobeOllamaAI(params.ollama);
         break;
diff --git a/src/libs/agent-runtime/lmstudio/index.test.ts b/src/libs/agent-runtime/lmstudio/index.test.ts
new file mode 100644
index 000000000000..6e5d59d431e2
--- /dev/null
+++ b/src/libs/agent-runtime/lmstudio/index.test.ts
@@ -0,0 +1,255 @@
+// @vitest-environment node
+import OpenAI from 'openai';
+import { Mock, afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
+
+import {
+  ChatStreamCallbacks,
+  LobeOpenAICompatibleRuntime,
+  ModelProvider,
+} from '@/libs/agent-runtime';
+
+import * as debugStreamModule from '../utils/debugStream';
+import { LobeLMStudioAI } from './index';
+
+const provider = ModelProvider.LMStudio;
+const defaultBaseURL = 'http://localhost:1234/v1';
+
+const bizErrorType = 'ProviderBizError';
+const invalidErrorType = 'InvalidProviderAPIKey';
+
+// Mock the console.error to avoid polluting test output
+vi.spyOn(console, 'error').mockImplementation(() => {});
+
+let instance: LobeOpenAICompatibleRuntime;
+
+beforeEach(() => {
+  instance = new LobeLMStudioAI({ apiKey: 'test' });
+
+  // Use vi.spyOn to mock the chat.completions.create method
+  vi.spyOn(instance['client'].chat.completions, 'create').mockResolvedValue(
+    new ReadableStream() as any,
+  );
+});
+
+afterEach(() => {
+  vi.clearAllMocks();
+});
+
+describe('LobeLMStudioAI', () => {
+  describe('init', () => {
+    it('should correctly initialize with an API key', async () => {
+      const instance = new LobeLMStudioAI({ apiKey: 'test_api_key' });
+      expect(instance).toBeInstanceOf(LobeLMStudioAI);
+      expect(instance.baseURL).toEqual(defaultBaseURL);
+    });
+  });
+
+  describe('chat', () => {
+    describe('Error', () => {
+      it('should return OpenAIBizError with an openai error response when OpenAI.APIError is thrown', async () => {
+        // Arrange
+        const apiError = new OpenAI.APIError(
+          400,
+          {
+            status: 400,
+            error: {
+              message: 'Bad Request',
+            },
+          },
+          'Error message',
+          {},
+        );
+
+        vi.spyOn(instance['client'].chat.completions, 'create').mockRejectedValue(apiError);
+
+        // Act
+        try {
+          await instance.chat({
+            messages: [{ content: 'Hello', role: 'user' }],
+            model: 'deepseek-chat',
+            temperature: 0,
+          });
+        } catch (e) {
+          expect(e).toEqual({
+            endpoint: defaultBaseURL,
+            error: {
+              error: { message: 'Bad Request' },
+              status: 400,
+            },
+            errorType: bizErrorType,
+            provider,
+          });
+        }
+      });
+
+      it('should throw AgentRuntimeError with NoOpenAIAPIKey if no apiKey is provided', async () => {
+        try {
+          new LobeLMStudioAI({});
+        } catch (e) {
+          expect(e).toEqual({ errorType: invalidErrorType });
+        }
+      });
+
+      it('should return OpenAIBizError with the cause when OpenAI.APIError is thrown with cause', async () => {
+        // Arrange
+        const errorInfo = {
+          stack: 'abc',
+          cause: {
+            message: 'api is undefined',
+          },
+        };
+        const apiError = new OpenAI.APIError(400, errorInfo, 'module error', {});
+
+        vi.spyOn(instance['client'].chat.completions, 'create').mockRejectedValue(apiError);
+
+        // Act
+        try {
+          await instance.chat({
+            messages: [{ content: 'Hello', role: 'user' }],
+            model: 'deepseek-chat',
+            temperature: 0,
+          });
+        } catch (e) {
+          expect(e).toEqual({
+            endpoint: defaultBaseURL,
+            error: {
+              cause: { message: 'api is undefined' },
+              stack: 'abc',
+            },
+            errorType: bizErrorType,
+            provider,
+          });
+        }
+      });
+
+      it('should return OpenAIBizError with a desensitized URL in the cause response', async () => {
+        // Arrange
+        const errorInfo = {
+          stack: 'abc',
+          cause: { message: 'api is undefined' },
+        };
+        const apiError = new OpenAI.APIError(400, errorInfo, 'module error', {});
+
+        instance = new LobeLMStudioAI({
+          apiKey: 'test',
+
+          baseURL: 'https://api.abc.com/v1',
+        });
+
+        vi.spyOn(instance['client'].chat.completions, 'create').mockRejectedValue(apiError);
+
+        // Act
+        try {
+          await instance.chat({
+            messages: [{ content: 'Hello', role: 'user' }],
+            model: 'deepseek-chat',
+            temperature: 0,
+          });
+        } catch (e) {
+          expect(e).toEqual({
+            endpoint: 'https://api.***.com/v1',
+            error: {
+              cause: { message: 'api is undefined' },
+              stack: 'abc',
+            },
+            errorType: bizErrorType,
+            provider,
+          });
+        }
+      });
+
+      it('should throw an invalid API key error type on 401 status code', async () => {
+        // Mock the API call to simulate a 401 error
+        const error = new Error('Unauthorized') as any;
+        error.status = 401;
+        vi.mocked(instance['client'].chat.completions.create).mockRejectedValue(error);
+
+        try {
+          await instance.chat({
+            messages: [{ content: 'Hello', role: 'user' }],
+            model: 'deepseek-chat',
+            temperature: 0,
+          });
+        } catch (e) {
+          // Expect the chat method to throw the invalid API key error type
+          expect(e).toEqual({
+            endpoint: defaultBaseURL,
+            error: new Error('Unauthorized'),
+            errorType: invalidErrorType,
+            provider,
+          });
+        }
+      });
+
+      it('should return AgentRuntimeError for non-OpenAI errors', async () => {
+        // Arrange
+        const genericError = new Error('Generic Error');
+
+        vi.spyOn(instance['client'].chat.completions, 'create').mockRejectedValue(genericError);
+
+        // Act
+        try {
+          await instance.chat({
+            messages: [{ content: 'Hello', role: 'user' }],
+            model: 'deepseek-chat',
+            temperature: 0,
+          });
+        } catch (e) {
+          expect(e).toEqual({
+            endpoint: defaultBaseURL,
+            errorType: 'AgentRuntimeError',
+            provider,
+            error: {
+              name: genericError.name,
+              cause: genericError.cause,
+              message: genericError.message,
+              stack: genericError.stack,
+            },
+          });
+        }
+      });
+    });
+
+    describe('DEBUG', () => {
+      it('should call debugStream and return StreamingTextResponse when DEBUG_LMSTUDIO_CHAT_COMPLETION is 1', async () => {
+        // Arrange
+        const mockProdStream = new ReadableStream() as any; // mocked prod stream
+        const mockDebugStream = new ReadableStream({
+          start(controller) {
+            controller.enqueue('Debug stream content');
+            controller.close();
+          },
+        }) as any;
+        mockDebugStream.toReadableStream = () => mockDebugStream; // attach a toReadableStream method
+
+        // Mock the return value of chat.completions.create, including a mocked tee method
+        (instance['client'].chat.completions.create as Mock).mockResolvedValue({
+          tee: () => [mockProdStream, { toReadableStream: () => mockDebugStream }],
+        });
+
+        // Save the original environment variable value
+        const originalDebugValue = process.env.DEBUG_LMSTUDIO_CHAT_COMPLETION;
+
+        // Mock the environment variable
+        process.env.DEBUG_LMSTUDIO_CHAT_COMPLETION = '1';
+        vi.spyOn(debugStreamModule, 'debugStream').mockImplementation(() => Promise.resolve());
+
+        // Run the chat call; it should invoke debugStream while the debug flag is set
+        await instance.chat({
+          messages: [{ content: 'Hello', role: 'user' }],
+          model: 'deepseek-chat',
+          stream: true,
+          temperature: 0,
+        });
+
+        // Verify that debugStream was called
+        expect(debugStreamModule.debugStream).toHaveBeenCalled();
+
+        // Restore the original environment variable value
+        process.env.DEBUG_LMSTUDIO_CHAT_COMPLETION = originalDebugValue;
+      });
+    });
+  });
+});
diff --git a/src/libs/agent-runtime/lmstudio/index.ts b/src/libs/agent-runtime/lmstudio/index.ts
new file mode 100644
index 000000000000..4927bbf5ea15
--- /dev/null
+++ b/src/libs/agent-runtime/lmstudio/index.ts
@@ -0,0 +1,11 @@
+import { ModelProvider } from '../types';
+import { LobeOpenAICompatibleFactory } from '../utils/openaiCompatibleFactory';
+
+export const LobeLMStudioAI = LobeOpenAICompatibleFactory({
+  apiKey: 'placeholder-to-avoid-error',
+  baseURL: 'http://localhost:1234/v1',
+  debug: {
+    chatCompletion: () => process.env.DEBUG_LMSTUDIO_CHAT_COMPLETION === '1',
+  },
+  provider: ModelProvider.LMStudio,
+});
diff --git a/src/libs/agent-runtime/types/type.ts b/src/libs/agent-runtime/types/type.ts
index db64c94f23fb..7f1a6b914723 100644
--- a/src/libs/agent-runtime/types/type.ts
+++ b/src/libs/agent-runtime/types/type.ts
@@ -1,5 +1,6 @@
 import OpenAI from 'openai';
+
 import { ILobeAgentRuntimeErrorType } from '../error';
 
 import { ChatStreamPayload } from './chat';
@@ -35,6 +36,7 @@ export enum ModelProvider {
   Groq = 'groq',
   HuggingFace = 'huggingface',
   Hunyuan = 'hunyuan',
+  LMStudio = 'lmstudio',
   Minimax = 'minimax',
   Mistral = 'mistral',
   Moonshot = 'moonshot',
diff --git a/src/libs/agent-runtime/utils/openaiCompatibleFactory/index.ts b/src/libs/agent-runtime/utils/openaiCompatibleFactory/index.ts
index ac28ea2b3e69..74f76d70a8bf 100644
--- a/src/libs/agent-runtime/utils/openaiCompatibleFactory/index.ts
+++ b/src/libs/agent-runtime/utils/openaiCompatibleFactory/index.ts
@@ -50,6 +50,7 @@ export interface CustomClientOptions<T extends Record<string, any> = any> {
 }
 
 interface OpenAICompatibleFactoryOptions<T extends Record<string, any> = any> {
+  apiKey?: string;
   baseURL?: string;
   chatCompletion?: {
     handleError?: (
@@ -139,6 +140,7 @@ export function transformResponseToStream(data: OpenAI.ChatCompletion) {
 export const LobeOpenAICompatibleFactory = <T extends Record<string, any> = any>({
   provider,
   baseURL: DEFAULT_BASE_URL,
+  apiKey: DEFAULT_API_KEY,
   errorType,
   debug,
   constructorOptions,
@@ -158,7 +160,11 @@ export const LobeOpenAICompatibleFactory = <T extends Record<string, any> = any
     private _options: ConstructorOptions;
 
     constructor(options: ClientOptions & Record<string, any> = {}) {
-      const _options = { ...options, baseURL: options.baseURL?.trim() || DEFAULT_BASE_URL };
+      const _options = {
+        ...options,
+        apiKey: options.apiKey?.trim() || DEFAULT_API_KEY,
+        baseURL: options.baseURL?.trim() || DEFAULT_BASE_URL,
+      };
       const { apiKey, baseURL = DEFAULT_BASE_URL, ...res } = _options;
 
       this._options = _options as ConstructorOptions;
diff --git a/src/types/user/settings/keyVaults.ts b/src/types/user/settings/keyVaults.ts
index 8ff980fa055f..b0d8fdad19d6 100644
--- a/src/types/user/settings/keyVaults.ts
+++ b/src/types/user/settings/keyVaults.ts
@@ -40,6 +40,7 @@ export interface UserKeyVaults {
   groq?: OpenAICompatibleKeyVault;
   huggingface?: OpenAICompatibleKeyVault;
   hunyuan?: OpenAICompatibleKeyVault;
+  lmstudio?: OpenAICompatibleKeyVault;
   lobehub?: any;
   minimax?: OpenAICompatibleKeyVault;
   mistral?: OpenAICompatibleKeyVault;
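
Usage note: a minimal sketch of how the new runtime could be exercised once this patch is applied, assuming an LM Studio instance is serving its OpenAI-compatible API on the default http://localhost:1234/v1 endpoint; the model id is only an example taken from the bundled provider card, and any model loaded in LM Studio would work the same way.

import { LobeLMStudioAI } from '@/libs/agent-runtime/lmstudio';

// LM Studio does not require an API key, so the factory falls back to the
// 'placeholder-to-avoid-error' key and the default local baseURL.
const runtime = new LobeLMStudioAI({});

// Same payload shape as in the tests above.
const response = await runtime.chat({
  messages: [{ content: 'Hello', role: 'user' }],
  model: 'qwen2.5-14b-instruct',
  temperature: 0,
});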