diff --git a/src/libs/agent-runtime/cohere/index.ts b/src/libs/agent-runtime/cohere/index.ts index 6488801894a57..626f428606dcd 100644 --- a/src/libs/agent-runtime/cohere/index.ts +++ b/src/libs/agent-runtime/cohere/index.ts @@ -1,19 +1,15 @@ -// sort-imports-ignore -// TODO: FOR COHERE See if this shims is needed for Cohere -// import '@anthropic-ai/sdk/shims/web'; import { CohereClient, CohereError, CohereTimeoutError } from "cohere-ai"; import { ClientOptions } from 'openai'; import { LobeRuntimeAI } from '../BaseAI'; import { AgentRuntimeErrorType } from '../error'; -// TODO: FOR COHERE Add cohere to types import { ChatCompetitionOptions, ChatStreamPayload, ModelProvider, UserMessageContentPart } from '../types'; import { AgentRuntimeError } from '../utils/createError'; import { debugStream } from '../utils/debugStream'; import { desensitizeUrl } from '../utils/desensitizeUrl'; // TODO: FOR COHERE create cohereHelpers -import { buildAnthropicMessages, buildAnthropicTools } from '../utils/anthropicHelpers'; +import { buildCohereMessages, buildCohereTools } from '../utils/cohereHelpers'; import { StreamingResponse } from '../utils/response'; // TODO: FOR COHERE create stream util for cohere @@ -127,10 +123,10 @@ export class LobeCohereAI implements LobeRuntimeAI { max_tokens, model, temperature, - tools: buildAnthropicTools(tools), + tools: buildCohereTools(tools), p, message: typeof message === 'string' ? message : message.join(' '), - chatHistory: chatHistory.map((m) => ({ role: m.role, message: m.content })), + chat_history: chatHistory.map((m) => ({ role: m.role.toUpperCase(), message: m.content })), }; } diff --git a/src/libs/agent-runtime/utils/cohereHelpers.ts b/src/libs/agent-runtime/utils/cohereHelpers.ts new file mode 100644 index 0000000000000..67a4c5c7d48a4 --- /dev/null +++ b/src/libs/agent-runtime/utils/cohereHelpers.ts @@ -0,0 +1,60 @@ +import { parseDataUri } from './uriParser'; +import { OpenAIChatMessage, UserMessageContentPart } from '../types'; +import Cohere from 'cohere-ai'; + +export const buildCohereBlock = ( + content: UserMessageContentPart, +): { type: string; content?: string; data?: string; mime_type?: string } => { + switch (content.type) { + case 'text': { + return { type: 'text', content: content.text }; + } + + case 'image_url': { + const { mimeType, base64 } = parseDataUri(content.image_url.url); + return { + type: 'image', + data: base64, + mime_type: mimeType, + }; + } + + default: { + throw new Error(`Unsupported content type: ${content.type}`); + } + } +}; + +export const buildCohereMessage = ( + message: OpenAIChatMessage, +): { role: string; content: string | object } => { + const content = message.content as string | UserMessageContentPart[]; + + switch (message.role) { + case 'system': + case 'user': + case 'assistant': { + return { + role: message.role, + content: typeof content === 'string' ? content : content.map(buildCohereBlock), + }; + } + + default: { + throw new Error(`Unsupported message role: ${message.role}`); + } + } +}; + +export const buildCohereMessages = ( + oaiMessages: OpenAIChatMessage[], +): { role: string; content: string | object }[] => { + return oaiMessages.map(buildCohereMessage); +}; + +export const buildCohereTools = (tools?: OpenAI.ChatCompletionTool[]) => + tools?.map(tool => ({ + name: tool.function.name, + description: tool.function.description, + input_schema: tool.function.parameters, + }));