From fc85c20b1ccdbeb9ddad7e10b1caf8c436c616ef Mon Sep 17 00:00:00 2001 From: Zhijie He Date: Wed, 11 Sep 2024 00:14:05 +0800 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20feat:=20Add=20Spark=20model=20provi?= =?UTF-8?q?der=20(#3098)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * ✨ feat: Add Spark model provider * 🔨 chore: split Spark API Key & Spark Secret * 💄 style: update Spark icon * 💄 style: update Spark icon size * 💄 style: update Spark icon in ProviderAvatar * 🔨 chore: update Spark models * 🔨 chore: update Spark models * 💄 style: fixed Spark 4.0 Ultra model icon display * 🔨 chore: update Spark models info * 🔨 chore: update Spark models tokens info * 🔨 chore: update Spark models info * 🐛 fix: fixed "'$.header.uid' length must be less or equal than 32" with Spark Lite * 💄 style: fix model tag icon missing * 🐛 fix: fix typo in ModelIcon * 🔨 chore: add unit test for noUserId * 🔨 chore: disable stream mode * Revert "🔨 chore: disable stream mode" (#25) This reverts commit 302e01d181a8f949053e04df85d23f69d2149039. * 💄 style: add Spark Pro-128K new model * ✨ feat: Add Spark ENV * 🐛 fix: fixed Pro-128k model id, wrong id from official document ![image](https://github.com/user-attachments/assets/7fc3fc73-b460-448c-ad78-4a56d3cae34e) * 💄style: improve APIKeyForm for Spark * 💄 style: improve custom Spark API missing form * 🔨 chore: cleanup code * 🐛 fix: fix CI issue after merge * 👷 build: add ENV * ♻️ refactor: support latest Spark HTTP SDK * ♻️ refactor: cleanup * 🔨 chore: fix rebase conflicts --------- Co-authored-by: Arvin Xu --- Dockerfile | 2 + Dockerfile.database | 2 + .../settings/llm/ProviderList/providers.tsx | 2 + src/app/api/chat/agentRuntime.ts | 7 + src/config/llm.ts | 6 + src/config/modelProviders/index.ts | 4 + src/config/modelProviders/spark.ts | 59 ++++ src/const/settings/llm.ts | 5 + src/libs/agent-runtime/AgentRuntime.ts | 7 + src/libs/agent-runtime/spark/index.test.ts | 255 ++++++++++++++++++ src/libs/agent-runtime/spark/index.ts | 13 + src/libs/agent-runtime/types/type.ts | 1 + src/server/globalConfig/index.ts | 5 +- src/types/user/settings/keyVaults.ts | 1 + 14 files changed, 368 insertions(+), 1 deletion(-) create mode 100644 src/config/modelProviders/spark.ts create mode 100644 src/libs/agent-runtime/spark/index.test.ts create mode 100644 src/libs/agent-runtime/spark/index.ts diff --git a/Dockerfile b/Dockerfile index 8142403b6e1b..055dbccacf05 100644 --- a/Dockerfile +++ b/Dockerfile @@ -141,6 +141,8 @@ ENV \ QWEN_API_KEY="" QWEN_MODEL_LIST="" \ # SiliconCloud SILICONCLOUD_API_KEY="" SILICONCLOUD_MODEL_LIST="" SILICONCLOUD_PROXY_URL="" \ + # Spark + SPARK_API_KEY="" \ # Stepfun STEPFUN_API_KEY="" \ # Taichu diff --git a/Dockerfile.database b/Dockerfile.database index 33ee0ea08f3e..a86e6aa0826a 100644 --- a/Dockerfile.database +++ b/Dockerfile.database @@ -173,6 +173,8 @@ ENV \ QWEN_API_KEY="" QWEN_MODEL_LIST="" \ # SiliconCloud SILICONCLOUD_API_KEY="" SILICONCLOUD_MODEL_LIST="" SILICONCLOUD_PROXY_URL="" \ + # Spark + SPARK_API_KEY="" \ # Stepfun STEPFUN_API_KEY="" \ # Taichu diff --git a/src/app/(main)/settings/llm/ProviderList/providers.tsx b/src/app/(main)/settings/llm/ProviderList/providers.tsx index 645c7d336796..bd1f1b93d5ac 100644 --- a/src/app/(main)/settings/llm/ProviderList/providers.tsx +++ b/src/app/(main)/settings/llm/ProviderList/providers.tsx @@ -15,6 +15,7 @@ import { PerplexityProviderCard, QwenProviderCard, SiliconCloudProviderCard, + SparkProviderCard, StepfunProviderCard, TaichuProviderCard, TogetherAIProviderCard, @@ -61,6 +62,7 @@ export const useProviderList = (): ProviderItem[] => { Ai360ProviderCard, SiliconCloudProviderCard, UpstageProviderCard, + SparkProviderCard, ], [AzureProvider, OllamaProvider, OpenAIProvider, BedrockProvider], ); diff --git a/src/app/api/chat/agentRuntime.ts b/src/app/api/chat/agentRuntime.ts index 05571c96ef0e..2bcc089f6898 100644 --- a/src/app/api/chat/agentRuntime.ts +++ b/src/app/api/chat/agentRuntime.ts @@ -213,6 +213,13 @@ const getLlmOptionsFromPayload = (provider: string, payload: JWTPayload) => { const apiKey = apiKeyManager.pick(payload?.apiKey || UPSTAGE_API_KEY); + return { apiKey }; + } + case ModelProvider.Spark: { + const { SPARK_API_KEY } = getLLMConfig(); + + const apiKey = apiKeyManager.pick(payload?.apiKey || SPARK_API_KEY); + return { apiKey }; } } diff --git a/src/config/llm.ts b/src/config/llm.ts index 43f54ef2827a..20ae1da6bb20 100644 --- a/src/config/llm.ts +++ b/src/config/llm.ts @@ -101,6 +101,9 @@ export const getLLMConfig = () => { ENABLED_UPSTAGE: z.boolean(), UPSTAGE_API_KEY: z.string().optional(), + + ENABLED_SPARK: z.boolean(), + SPARK_API_KEY: z.string().optional(), }, runtimeEnv: { API_KEY_SELECT_MODE: process.env.API_KEY_SELECT_MODE, @@ -199,6 +202,9 @@ export const getLLMConfig = () => { ENABLED_UPSTAGE: !!process.env.UPSTAGE_API_KEY, UPSTAGE_API_KEY: process.env.UPSTAGE_API_KEY, + + ENABLED_SPARK: !!process.env.SPARK_API_KEY, + SPARK_API_KEY: process.env.SPARK_API_KEY, }, }); }; diff --git a/src/config/modelProviders/index.ts b/src/config/modelProviders/index.ts index f2f382f17b9b..3840cc6cee68 100644 --- a/src/config/modelProviders/index.ts +++ b/src/config/modelProviders/index.ts @@ -18,6 +18,7 @@ import OpenRouterProvider from './openrouter'; import PerplexityProvider from './perplexity'; import QwenProvider from './qwen'; import SiliconCloudProvider from './siliconcloud'; +import SparkProvider from './spark'; import StepfunProvider from './stepfun'; import TaichuProvider from './taichu'; import TogetherAIProvider from './togetherai'; @@ -49,6 +50,7 @@ export const LOBE_DEFAULT_MODEL_LIST: ChatModelCard[] = [ Ai360Provider.chatModels, SiliconCloudProvider.chatModels, UpstageProvider.chatModels, + SparkProvider.chatModels, ].flat(); export const DEFAULT_MODEL_PROVIDER_LIST = [ @@ -76,6 +78,7 @@ export const DEFAULT_MODEL_PROVIDER_LIST = [ Ai360Provider, SiliconCloudProvider, UpstageProvider, + SparkProvider, ]; export const filterEnabledModels = (provider: ModelProviderCard) => { @@ -105,6 +108,7 @@ export { default as OpenRouterProviderCard } from './openrouter'; export { default as PerplexityProviderCard } from './perplexity'; export { default as QwenProviderCard } from './qwen'; export { default as SiliconCloudProviderCard } from './siliconcloud'; +export { default as SparkProviderCard } from './spark'; export { default as StepfunProviderCard } from './stepfun'; export { default as TaichuProviderCard } from './taichu'; export { default as TogetherAIProviderCard } from './togetherai'; diff --git a/src/config/modelProviders/spark.ts b/src/config/modelProviders/spark.ts new file mode 100644 index 000000000000..848219521363 --- /dev/null +++ b/src/config/modelProviders/spark.ts @@ -0,0 +1,59 @@ +import { ModelProviderCard } from '@/types/llm'; + +// ref https://www.xfyun.cn/doc/spark/HTTP%E8%B0%83%E7%94%A8%E6%96%87%E6%A1%A3.html#_3-%E8%AF%B7%E6%B1%82%E8%AF%B4%E6%98%8E +// ref https://www.xfyun.cn/doc/spark/Web.html#_1-%E6%8E%A5%E5%8F%A3%E8%AF%B4%E6%98%8E +const Spark: ModelProviderCard = { + chatModels: [ + { + description: '轻量级大语言模型,低延迟,全免费 支持在线联网搜索功能 响应快速、便捷,全面免费开放 适用于低算力推理与模型精调等定制化场景', + displayName: 'Spark Lite', + enabled: true, + functionCall: false, + id: 'general', + maxOutput: 4096, + tokens: 8192, + }, + { + description: '专业级大语言模型,兼顾模型效果与性能 数学、代码、医疗、教育等场景专项优化 支持联网搜索、天气、日期等多个内置插件 覆盖大部分知识问答、语言理解、文本创作等多个场景', + displayName: 'Spark Pro', + enabled: true, + functionCall: false, + id: 'generalv3', + maxOutput: 8192, + tokens: 8192, + }, + { + description: '支持最长上下文的星火大模型,长文无忧 128K星火大模型强势来袭 通读全文,旁征博引 沟通无界,逻辑连贯', + displayName: 'Spark Pro-128K', + enabled: true, + functionCall: false, + id: 'Pro-128k', + maxOutput: 4096, + tokens: 128_000, + }, + { + description: '最全面的星火大模型版本,功能丰富 支持联网搜索、天气、日期等多个内置插件 核心能力全面升级,各场景应用效果普遍提升 支持System角色人设与FunctionCall函数调用', + displayName: 'Spark3.5 Max', + enabled: true, + functionCall: false, + id: 'generalv3.5', + maxOutput: 8192, + tokens: 8192, + }, + { + description: '最强大的星火大模型版本,效果极佳 全方位提升效果,引领智能巅峰 优化联网搜索链路,提供精准回答 强化文本总结能力,提升办公生产力', + displayName: 'Spark4.0 Ultra', + enabled: true, + functionCall: false, + id: '4.0Ultra', + maxOutput: 8192, + tokens: 8192, + }, + ], + checkModel: 'generalv3', + id: 'spark', + modelList: { showModelFetcher: true }, + name: 'Spark', +}; + +export default Spark; diff --git a/src/const/settings/llm.ts b/src/const/settings/llm.ts index 6056265a074f..fc875f669f86 100644 --- a/src/const/settings/llm.ts +++ b/src/const/settings/llm.ts @@ -16,6 +16,7 @@ import { PerplexityProviderCard, QwenProviderCard, SiliconCloudProviderCard, + SparkProviderCard, StepfunProviderCard, TaichuProviderCard, TogetherAIProviderCard, @@ -100,6 +101,10 @@ export const DEFAULT_LLM_CONFIG: UserModelProviderConfig = { enabled: false, enabledModels: filterEnabledModels(SiliconCloudProviderCard), }, + spark: { + enabled: false, + enabledModels: filterEnabledModels(SparkProviderCard), + }, stepfun: { enabled: false, enabledModels: filterEnabledModels(StepfunProviderCard), diff --git a/src/libs/agent-runtime/AgentRuntime.ts b/src/libs/agent-runtime/AgentRuntime.ts index 0454e3bbde0a..41e99c9ca0e6 100644 --- a/src/libs/agent-runtime/AgentRuntime.ts +++ b/src/libs/agent-runtime/AgentRuntime.ts @@ -21,6 +21,7 @@ import { LobeOpenRouterAI } from './openrouter'; import { LobePerplexityAI } from './perplexity'; import { LobeQwenAI } from './qwen'; import { LobeSiliconCloudAI } from './siliconcloud'; +import { LobeSparkAI } from './spark'; import { LobeStepfunAI } from './stepfun'; import { LobeTaichuAI } from './taichu'; import { LobeTogetherAI } from './togetherai'; @@ -132,6 +133,7 @@ class AgentRuntime { perplexity: Partial; qwen: Partial; siliconcloud: Partial; + spark: Partial; stepfun: Partial; taichu: Partial; togetherai: Partial; @@ -268,6 +270,11 @@ class AgentRuntime { runtimeModel = new LobeUpstageAI(params.upstage); break } + + case ModelProvider.Spark: { + runtimeModel = new LobeSparkAI(params.spark); + break + } } return new AgentRuntime(runtimeModel); diff --git a/src/libs/agent-runtime/spark/index.test.ts b/src/libs/agent-runtime/spark/index.test.ts new file mode 100644 index 000000000000..7b6b1a2b1a06 --- /dev/null +++ b/src/libs/agent-runtime/spark/index.test.ts @@ -0,0 +1,255 @@ +// @vitest-environment node +import OpenAI from 'openai'; +import { Mock, afterEach, beforeEach, describe, expect, it, vi } from 'vitest'; + +import { + ChatStreamCallbacks, + LobeOpenAICompatibleRuntime, + ModelProvider, +} from '@/libs/agent-runtime'; + +import * as debugStreamModule from '../utils/debugStream'; +import { LobeSparkAI } from './index'; + +const provider = ModelProvider.Spark; +const defaultBaseURL = 'https://spark-api-open.xf-yun.com/v1'; + +const bizErrorType = 'ProviderBizError'; +const invalidErrorType = 'InvalidProviderAPIKey'; + +// Mock the console.error to avoid polluting test output +vi.spyOn(console, 'error').mockImplementation(() => {}); + +let instance: LobeOpenAICompatibleRuntime; + +beforeEach(() => { + instance = new LobeSparkAI({ apiKey: 'test' }); + + // 使用 vi.spyOn 来模拟 chat.completions.create 方法 + vi.spyOn(instance['client'].chat.completions, 'create').mockResolvedValue( + new ReadableStream() as any, + ); +}); + +afterEach(() => { + vi.clearAllMocks(); +}); + +describe('LobeSparkAI', () => { + describe('init', () => { + it('should correctly initialize with an API key', async () => { + const instance = new LobeSparkAI({ apiKey: 'test_api_key' }); + expect(instance).toBeInstanceOf(LobeSparkAI); + expect(instance.baseURL).toEqual(defaultBaseURL); + }); + }); + + describe('chat', () => { + describe('Error', () => { + it('should return OpenAIBizError with an openai error response when OpenAI.APIError is thrown', async () => { + // Arrange + const apiError = new OpenAI.APIError( + 400, + { + status: 400, + error: { + message: 'Bad Request', + }, + }, + 'Error message', + {}, + ); + + vi.spyOn(instance['client'].chat.completions, 'create').mockRejectedValue(apiError); + + // Act + try { + await instance.chat({ + messages: [{ content: 'Hello', role: 'user' }], + model: 'general', + temperature: 0, + }); + } catch (e) { + expect(e).toEqual({ + endpoint: defaultBaseURL, + error: { + error: { message: 'Bad Request' }, + status: 400, + }, + errorType: bizErrorType, + provider, + }); + } + }); + + it('should throw AgentRuntimeError with NoOpenAIAPIKey if no apiKey is provided', async () => { + try { + new LobeSparkAI({}); + } catch (e) { + expect(e).toEqual({ errorType: invalidErrorType }); + } + }); + + it('should return OpenAIBizError with the cause when OpenAI.APIError is thrown with cause', async () => { + // Arrange + const errorInfo = { + stack: 'abc', + cause: { + message: 'api is undefined', + }, + }; + const apiError = new OpenAI.APIError(400, errorInfo, 'module error', {}); + + vi.spyOn(instance['client'].chat.completions, 'create').mockRejectedValue(apiError); + + // Act + try { + await instance.chat({ + messages: [{ content: 'Hello', role: 'user' }], + model: 'general', + temperature: 0, + }); + } catch (e) { + expect(e).toEqual({ + endpoint: defaultBaseURL, + error: { + cause: { message: 'api is undefined' }, + stack: 'abc', + }, + errorType: bizErrorType, + provider, + }); + } + }); + + it('should return OpenAIBizError with an cause response with desensitize Url', async () => { + // Arrange + const errorInfo = { + stack: 'abc', + cause: { message: 'api is undefined' }, + }; + const apiError = new OpenAI.APIError(400, errorInfo, 'module error', {}); + + instance = new LobeSparkAI({ + apiKey: 'test', + + baseURL: 'https://api.abc.com/v1', + }); + + vi.spyOn(instance['client'].chat.completions, 'create').mockRejectedValue(apiError); + + // Act + try { + await instance.chat({ + messages: [{ content: 'Hello', role: 'user' }], + model: 'general', + temperature: 0, + }); + } catch (e) { + expect(e).toEqual({ + endpoint: 'https://api.***.com/v1', + error: { + cause: { message: 'api is undefined' }, + stack: 'abc', + }, + errorType: bizErrorType, + provider, + }); + } + }); + + it('should throw an InvalidSparkAPIKey error type on 401 status code', async () => { + // Mock the API call to simulate a 401 error + const error = new Error('Unauthorized') as any; + error.status = 401; + vi.mocked(instance['client'].chat.completions.create).mockRejectedValue(error); + + try { + await instance.chat({ + messages: [{ content: 'Hello', role: 'user' }], + model: 'general', + temperature: 0, + }); + } catch (e) { + // Expect the chat method to throw an error with InvalidSparkAPIKey + expect(e).toEqual({ + endpoint: defaultBaseURL, + error: new Error('Unauthorized'), + errorType: invalidErrorType, + provider, + }); + } + }); + + it('should return AgentRuntimeError for non-OpenAI errors', async () => { + // Arrange + const genericError = new Error('Generic Error'); + + vi.spyOn(instance['client'].chat.completions, 'create').mockRejectedValue(genericError); + + // Act + try { + await instance.chat({ + messages: [{ content: 'Hello', role: 'user' }], + model: 'general', + temperature: 0, + }); + } catch (e) { + expect(e).toEqual({ + endpoint: defaultBaseURL, + errorType: 'AgentRuntimeError', + provider, + error: { + name: genericError.name, + cause: genericError.cause, + message: genericError.message, + stack: genericError.stack, + }, + }); + } + }); + }); + + describe('DEBUG', () => { + it('should call debugStream and return StreamingTextResponse when DEBUG_SPARK_CHAT_COMPLETION is 1', async () => { + // Arrange + const mockProdStream = new ReadableStream() as any; // 模拟的 prod 流 + const mockDebugStream = new ReadableStream({ + start(controller) { + controller.enqueue('Debug stream content'); + controller.close(); + }, + }) as any; + mockDebugStream.toReadableStream = () => mockDebugStream; // 添加 toReadableStream 方法 + + // 模拟 chat.completions.create 返回值,包括模拟的 tee 方法 + (instance['client'].chat.completions.create as Mock).mockResolvedValue({ + tee: () => [mockProdStream, { toReadableStream: () => mockDebugStream }], + }); + + // 保存原始环境变量值 + const originalDebugValue = process.env.DEBUG_SPARK_CHAT_COMPLETION; + + // 模拟环境变量 + process.env.DEBUG_SPARK_CHAT_COMPLETION = '1'; + vi.spyOn(debugStreamModule, 'debugStream').mockImplementation(() => Promise.resolve()); + + // 执行测试 + // 运行你的测试函数,确保它会在条件满足时调用 debugStream + // 假设的测试函数调用,你可能需要根据实际情况调整 + await instance.chat({ + messages: [{ content: 'Hello', role: 'user' }], + model: 'general', + stream: true, + temperature: 0, + }); + + // 验证 debugStream 被调用 + expect(debugStreamModule.debugStream).toHaveBeenCalled(); + + // 恢复原始环境变量值 + process.env.DEBUG_SPARK_CHAT_COMPLETION = originalDebugValue; + }); + }); + }); +}); diff --git a/src/libs/agent-runtime/spark/index.ts b/src/libs/agent-runtime/spark/index.ts new file mode 100644 index 000000000000..8cc8dfe1e28e --- /dev/null +++ b/src/libs/agent-runtime/spark/index.ts @@ -0,0 +1,13 @@ +import { ModelProvider } from '../types'; +import { LobeOpenAICompatibleFactory } from '../utils/openaiCompatibleFactory'; + +export const LobeSparkAI = LobeOpenAICompatibleFactory({ + baseURL: 'https://spark-api-open.xf-yun.com/v1', + chatCompletion: { + noUserId: true, + }, + debug: { + chatCompletion: () => process.env.DEBUG_SPARK_CHAT_COMPLETION === '1', + }, + provider: ModelProvider.Spark, +}); diff --git a/src/libs/agent-runtime/types/type.ts b/src/libs/agent-runtime/types/type.ts index 8c0999f9c120..2a6c6a09f938 100644 --- a/src/libs/agent-runtime/types/type.ts +++ b/src/libs/agent-runtime/types/type.ts @@ -40,6 +40,7 @@ export enum ModelProvider { Perplexity = 'perplexity', Qwen = 'qwen', SiliconCloud = 'siliconcloud', + Spark = 'spark', Stepfun = 'stepfun', Taichu = 'taichu', TogetherAI = 'togetherai', diff --git a/src/server/globalConfig/index.ts b/src/server/globalConfig/index.ts index baa5f9d04982..77f3c70b2b28 100644 --- a/src/server/globalConfig/index.ts +++ b/src/server/globalConfig/index.ts @@ -63,7 +63,9 @@ export const getServerGlobalConfig = () => { SILICONCLOUD_MODEL_LIST, ENABLED_UPSTAGE, - + + ENABLED_SPARK, + ENABLED_AZURE_OPENAI, AZURE_MODEL_LIST, @@ -174,6 +176,7 @@ export const getServerGlobalConfig = () => { modelString: SILICONCLOUD_MODEL_LIST, }), }, + spark: { enabled: ENABLED_SPARK }, stepfun: { enabled: ENABLED_STEPFUN }, taichu: { enabled: ENABLED_TAICHU }, diff --git a/src/types/user/settings/keyVaults.ts b/src/types/user/settings/keyVaults.ts index 523a1e8aa0c5..6b250fcc60fa 100644 --- a/src/types/user/settings/keyVaults.ts +++ b/src/types/user/settings/keyVaults.ts @@ -36,6 +36,7 @@ export interface UserKeyVaults { perplexity?: OpenAICompatibleKeyVault; qwen?: OpenAICompatibleKeyVault; siliconcloud?: OpenAICompatibleKeyVault; + spark?: OpenAICompatibleKeyVault; stepfun?: OpenAICompatibleKeyVault; taichu?: OpenAICompatibleKeyVault; togetherai?: OpenAICompatibleKeyVault;