diff --git a/docs/plugin-authoring.md b/docs/plugin-authoring.md index b7e7a725e..97de1533b 100644 --- a/docs/plugin-authoring.md +++ b/docs/plugin-authoring.md @@ -128,6 +128,7 @@ export const myPlugin = genkitPlugin('my-plugin', async (options: {apiKey?: stri multiturn: true, // true if your model supports conversations media: true, // true if your model supports multimodal input tools: true, // true if your model supports tool/function calling + systemRole: true, // true if your model supports the system role output: ['text', 'media', 'json'], // types of output your model supports }, // Zod schema for your model's custom configuration diff --git a/js/ai/src/model.ts b/js/ai/src/model.ts index 2f3b06df6..30d5d3e4d 100644 --- a/js/ai/src/model.ts +++ b/js/ai/src/model.ts @@ -124,6 +124,8 @@ export const ModelInfoSchema = z.object({ media: z.boolean().optional(), /** Model can perform tool calls. */ tools: z.boolean().optional(), + /** Model can accept messages with role "system". */ + systemRole: z.boolean().optional(), /** Model can output this type of data. */ output: z.array(OutputFormatSchema).optional(), }) diff --git a/js/plugins/googleai/src/gemini.ts b/js/plugins/googleai/src/gemini.ts index d18428bbe..bd0f0aefa 100644 --- a/js/plugins/googleai/src/gemini.ts +++ b/js/plugins/googleai/src/gemini.ts @@ -22,6 +22,7 @@ import { MediaPart, MessageData, ModelAction, + ModelMiddleware, modelRef, ModelReference, Part, @@ -72,6 +73,7 @@ export const geminiPro = modelRef({ multiturn: true, media: false, tools: false, + systemRole: true, }, }, configSchema: GeminiConfigSchema, @@ -87,6 +89,7 @@ export const geminiProVision = modelRef({ multiturn: true, media: true, tools: false, + systemRole: false, }, }, configSchema: GeminiConfigSchema, @@ -100,6 +103,7 @@ export const gemini15Pro = modelRef({ multiturn: true, media: true, tools: true, + systemRole: true, }, }, configSchema: GeminiConfigSchema, @@ -114,20 +118,22 @@ export const geminiUltra = modelRef({ multiturn: true, media: false, tools: false, + systemRole: true, }, }, configSchema: GeminiConfigSchema, }); -export const V1_SUPPORTED_MODELS: Record< +export const SUPPORTED_V1_MODELS: Record< string, ModelReference > = { 'gemini-pro': geminiPro, 'gemini-pro-vision': geminiProVision, + // 'gemini-ultra': geminiUltra, }; -export const V1_BETA_SUPPORTED_MODELS: Record< +export const SUPPORTED_V15_MODELS: Record< string, ModelReference > = { @@ -135,18 +141,29 @@ export const V1_BETA_SUPPORTED_MODELS: Record< }; const SUPPORTED_MODELS = { - ...V1_SUPPORTED_MODELS, - ...V1_BETA_SUPPORTED_MODELS, + ...SUPPORTED_V1_MODELS, + ...SUPPORTED_V15_MODELS, }; -function toGeminiRole(role: MessageData['role']): string { +function toGeminiRole( + role: MessageData['role'], + model?: ModelReference +): string { switch (role) { case 'user': return 'user'; case 'model': return 'model'; case 'system': - throw new Error('system role is not supported'); + if (model && SUPPORTED_V15_MODELS[model.name]) { + // We should have already pulled out the supported system messages, + // anything remaining is unsupported; throw an error. + throw new Error( + 'system role is only supported for a single message in the first position' + ); + } else { + throw new Error('system role is not supported'); + } case 'tool': return 'function'; default: @@ -195,9 +212,19 @@ function fromGeminiPart(part: GeminiPart): Part { throw new Error('Only support text for the moment.'); } -function toGeminiMessage(message: MessageData): GeminiMessage { +function toGeminiMessage( + message: MessageData, + model?: ModelReference +): GeminiMessage { + return { + role: toGeminiRole(message.role, model), + parts: message.content.map(toGeminiPart), + }; +} + +function toGeminiSystemInstruction(message: MessageData): GeminiMessage { return { - role: toGeminiRole(message.role), + role: 'user', parts: message.content.map(toGeminiPart), }; } @@ -245,25 +272,34 @@ export function googleAIModel( baseUrl?: string ): ModelAction { const modelName = `googleai/${name}`; - if (!apiKey) + + if (!apiKey) { apiKey = process.env.GOOGLE_GENAI_API_KEY || process.env.GOOGLE_API_KEY; - if (!apiKey) + } + if (!apiKey) { throw new Error( 'please pass in the API key or set the GOOGLE_GENAI_API_KEY or GOOGLE_API_KEY environment variable' ); + } + const model: ModelReference = SUPPORTED_MODELS[name]; if (!model) throw new Error(`Unsupported model: ${name}`); + + const middleware: ModelMiddleware[] = []; + if (SUPPORTED_V1_MODELS[name]) { + middleware.push(simulateSystemPrompt()); + } + if (model?.info?.supports?.media) { + // the gemini api doesn't support downloading media from http(s) + middleware.push(downloadRequestMedia({ maxBytes: 1024 * 1024 * 10 })); + } + return defineModel( { name: modelName, ...model.info, configSchema: model.configSchema, - use: [ - // simulate a system prompt since no native one is supported - simulateSystemPrompt(), - // since gemini api doesn't support downloading media from http(s) - downloadRequestMedia({ maxBytes: 1024 * 1024 * 10 }), - ], + use: middleware, }, async (request, streamingCallback) => { const options: RequestOptions = { apiClient: GENKIT_CLIENT_HEADER }; @@ -279,10 +315,27 @@ export function googleAIModel( }, options ); - const messages = request.messages.map(toGeminiMessage); + + // make a copy so that modifying the request will not produce side-effects + const messages = [...request.messages]; if (messages.length === 0) throw new Error('No messages provided.'); + + // Gemini does not support messages with role system and instead expects + // systemInstructions to be provided as a separate input. The first + // message detected with role=system will be used for systemInstructions. + // Any additional system messages may be considered to be "exceptional". + let systemInstruction: GeminiMessage | undefined = undefined; + const systemMessage = messages.find((m) => m.role === 'system'); + if (systemMessage) { + messages.splice(messages.indexOf(systemMessage), 1); + systemInstruction = toGeminiSystemInstruction(systemMessage); + } + const chatRequest = { - history: messages.slice(0, messages.length - 1), + systemInstruction, + history: messages + .slice(0, -1) + .map((message) => toGeminiMessage(message, model)), generationConfig: { candidateCount: request.candidates || undefined, temperature: request.config?.temperature, @@ -293,10 +346,11 @@ export function googleAIModel( }, safetySettings: request.config?.safetySettings, } as StartChatParams; + const msg = toGeminiMessage(messages[messages.length - 1], model); if (streamingCallback) { const result = await client .startChat(chatRequest) - .sendMessageStream(messages[messages.length - 1].parts); + .sendMessageStream(msg.parts); for await (const item of result.stream) { (item as GenerateContentResponse).candidates?.forEach((candidate) => { const c = fromGeminiCandidate(candidate); @@ -317,7 +371,7 @@ export function googleAIModel( } else { const result = await client .startChat(chatRequest) - .sendMessage(messages[messages.length - 1].parts); + .sendMessage(msg.parts); if (!result.response.candidates?.length) throw new Error('No valid candidates returned.'); const responseCandidates = diff --git a/js/plugins/googleai/src/index.ts b/js/plugins/googleai/src/index.ts index 7bdfbb863..bbc24c36b 100644 --- a/js/plugins/googleai/src/index.ts +++ b/js/plugins/googleai/src/index.ts @@ -24,8 +24,8 @@ import { geminiPro, geminiProVision, googleAIModel, - V1_BETA_SUPPORTED_MODELS, - V1_SUPPORTED_MODELS, + SUPPORTED_V15_MODELS, + SUPPORTED_V1_MODELS, } from './gemini.js'; export { gemini15Pro, geminiPro, geminiProVision }; @@ -52,14 +52,14 @@ export const googleAI: Plugin<[PluginOptions] | []> = genkitPlugin( if (apiVersions.includes('v1beta')) { (embedders = []), (models = [ - ...Object.keys(V1_BETA_SUPPORTED_MODELS).map((name) => + ...Object.keys(SUPPORTED_V15_MODELS).map((name) => googleAIModel(name, options?.apiKey, 'v1beta', options?.baseUrl) ), ]); } if (apiVersions.includes('v1')) { models = [ - ...Object.keys(V1_SUPPORTED_MODELS).map((name) => + ...Object.keys(SUPPORTED_V1_MODELS).map((name) => googleAIModel(name, options?.apiKey, undefined, options?.baseUrl) ), ]; diff --git a/js/plugins/vertexai/src/anthropic.ts b/js/plugins/vertexai/src/anthropic.ts index f7119b1eb..d6bf2ed08 100644 --- a/js/plugins/vertexai/src/anthropic.ts +++ b/js/plugins/vertexai/src/anthropic.ts @@ -44,6 +44,7 @@ export const claude3Sonnet = modelRef({ multiturn: true, media: true, tools: false, + systemRole: true, output: ['text'], }, }, @@ -59,6 +60,7 @@ export const claude3Haiku = modelRef({ multiturn: true, media: true, tools: false, + systemRole: true, output: ['text'], }, }, @@ -74,6 +76,7 @@ export const claude3Opus = modelRef({ multiturn: true, media: true, tools: false, + systemRole: true, output: ['text'], }, }, diff --git a/js/plugins/vertexai/src/gemini.ts b/js/plugins/vertexai/src/gemini.ts index ee5485865..fc87b2208 100644 --- a/js/plugins/vertexai/src/gemini.ts +++ b/js/plugins/vertexai/src/gemini.ts @@ -37,6 +37,7 @@ import { Content, FunctionDeclaration, FunctionDeclarationSchemaType, + Part as GeminiPart, GenerateContentCandidate, GenerateContentResponse, GenerateContentResult, @@ -44,7 +45,6 @@ import { HarmCategory, StartChatParams, VertexAI, - Part as VertexPart, } from '@google-cloud/vertexai'; import { z } from 'zod'; @@ -66,6 +66,7 @@ export const geminiPro = modelRef({ multiturn: true, media: false, tools: true, + systemRole: true, }, }, configSchema: GeminiConfigSchema, @@ -80,6 +81,7 @@ export const geminiProVision = modelRef({ multiturn: true, media: true, tools: false, + systemRole: false, }, }, configSchema: GeminiConfigSchema, @@ -94,27 +96,47 @@ export const gemini15ProPreview = modelRef({ multiturn: true, media: true, tools: true, + systemRole: true, }, }, configSchema: GeminiConfigSchema, version: 'gemini-1.5-pro-preview-0409', }); -export const SUPPORTED_GEMINI_MODELS = { - 'gemini-1.0-pro': geminiPro, - 'gemini-1.5-pro-preview': gemini15ProPreview, - 'gemini-1.0-pro-vision': geminiProVision, +export const SUPPORTED_V1_MODELS = { + 'gemini-pro': geminiPro, + 'gemini-pro-vision': geminiProVision, // 'gemini-ultra': geminiUltra, }; -function toGeminiRole(role: MessageData['role']): string { +export const SUPPORTED_V15_MODELS = { + 'gemini-1.5-pro-preview': gemini15ProPreview, +}; + +export const SUPPORTED_GEMINI_MODELS = { + ...SUPPORTED_V1_MODELS, + ...SUPPORTED_V15_MODELS, +}; + +function toGeminiRole( + role: MessageData['role'], + model?: ModelReference +): string { switch (role) { case 'user': return 'user'; case 'model': return 'model'; case 'system': - throw new Error('system role is not supported'); + if (model && SUPPORTED_V15_MODELS[model.name]) { + // We should have already pulled out the supported system messages, + // anything remaining is unsupported; throw an error. + throw new Error( + 'system role is only supported for a single message in the first position' + ); + } else { + throw new Error('system role is not supported'); + } case 'tool': return 'function'; default: @@ -133,7 +155,7 @@ const toGeminiTool = ( return declaration; }; -const toGeminiFileDataPart = (part: MediaPart): VertexPart => { +const toGeminiFileDataPart = (part: MediaPart): GeminiPart => { const media = part.media; if (media.url.startsWith('gs://')) { if (!media.contentType) @@ -160,7 +182,7 @@ const toGeminiFileDataPart = (part: MediaPart): VertexPart => { ); }; -const toGeminiToolRequestPart = (part: Part): VertexPart => { +const toGeminiToolRequestPart = (part: Part): GeminiPart => { if (!part?.toolRequest?.input) { throw Error( 'Could not convert genkit part to gemini tool response part: missing tool request data' @@ -174,7 +196,7 @@ const toGeminiToolRequestPart = (part: Part): VertexPart => { }; }; -const toGeminiToolResponsePart = (part: Part): VertexPart => { +const toGeminiToolResponsePart = (part: Part): GeminiPart => { if (!part?.toolResponse?.output) { throw Error( 'Could not convert genkit part to gemini tool response part: missing tool response data' @@ -191,30 +213,22 @@ const toGeminiToolResponsePart = (part: Part): VertexPart => { }; }; -export const toGeminiMessage = (message: MessageData): Content => { - const vertexRole = toGeminiRole(message.role); - const vertexAiMessage: any = { - role: vertexRole, - parts: [], +export function toGeminiSystemInstruction(message: MessageData): Content { + return { + role: 'user', + parts: message.content.map(toGeminiPart), }; +} - const parts = message.content; - parts.forEach((part) => { - if (part.text) { - vertexAiMessage.parts.push({ text: part.text }); - } - if (part.media) { - vertexAiMessage.parts.push(toGeminiFileDataPart(part)); - } - if (part.toolRequest) { - vertexAiMessage.parts.push(toGeminiToolRequestPart(part)); - } - if (part.toolResponse) { - vertexAiMessage.parts.push(toGeminiToolResponsePart(part)); - } - }); - return vertexAiMessage; -}; +export function toGeminiMessage( + message: MessageData, + model?: ModelReference +): Content { + return { + role: toGeminiRole(message.role, model), + parts: message.content.map(toGeminiPart), + }; +} function fromGeminiFinishReason( reason: GenerateContentCandidate['finishReason'] @@ -233,7 +247,21 @@ function fromGeminiFinishReason( } } -function fromGeminiInlineDataPart(part: VertexPart): MediaPart { +function toGeminiPart(part: Part): GeminiPart { + if (part.text) { + return { text: part.text }; + } else if (part.media) { + return toGeminiFileDataPart(part); + } else if (part.toolRequest) { + return toGeminiToolRequestPart(part); + } else if (part.toolResponse) { + return toGeminiToolResponsePart(part); + } else { + throw new Error('unsupported type'); + } +} + +function fromGeminiInlineDataPart(part: GeminiPart): MediaPart { // Check if the required properties exist if ( !part.inlineData || @@ -253,7 +281,7 @@ function fromGeminiInlineDataPart(part: VertexPart): MediaPart { }; } -function fromGeminiFileDataPart(part: VertexPart): MediaPart { +function fromGeminiFileDataPart(part: GeminiPart): MediaPart { if ( !part.fileData || !part.fileData.hasOwnProperty('mimeType') || @@ -272,7 +300,7 @@ function fromGeminiFileDataPart(part: VertexPart): MediaPart { }; } -function fromGeminiFunctionCallPart(part: VertexPart): Part { +function fromGeminiFunctionCallPart(part: GeminiPart): Part { if (!part.functionCall) { throw new Error( 'Invalid Gemini Function Call Part: missing function call data' @@ -286,7 +314,7 @@ function fromGeminiFunctionCallPart(part: VertexPart): Part { }; } -function fromGeminiFunctionResponsePart(part: VertexPart): Part { +function fromGeminiFunctionResponsePart(part: GeminiPart): Part { if (!part.functionResponse) { throw new Error( 'Invalid Gemini Function Call Part: missing function call data' @@ -301,7 +329,7 @@ function fromGeminiFunctionResponsePart(part: VertexPart): Part { } // Converts vertex part to genkit part -function fromGeminiPart(part: VertexPart): Part { +function fromGeminiPart(part: GeminiPart): Part { if (part.text !== undefined) return { text: part.text }; if (part.functionCall) return fromGeminiFunctionCallPart(part); if (part.functionResponse) return fromGeminiFunctionResponsePart(part); @@ -370,8 +398,12 @@ export function geminiModel(name: string, vertex: VertexAI): ModelAction { const model: ModelReference = SUPPORTED_GEMINI_MODELS[name]; if (!model) throw new Error(`Unsupported model: ${name}`); - const middlewares: ModelMiddleware[] = [simulateSystemPrompt()]; + const middlewares: ModelMiddleware[] = []; + if (SUPPORTED_V1_MODELS[name]) { + middlewares.push(simulateSystemPrompt()); + } if (model?.info?.supports?.media) { + // the gemini api doesn't support downloading media from http(s) middlewares.push(downloadRequestMedia({ maxBytes: 1024 * 1024 * 20 })); } @@ -392,16 +424,28 @@ export function geminiModel(name: string, vertex: VertexAI): ModelAction { } ); - const messages = request.messages; + // make a copy so that modifying the request will not produce side-effects + const messages = [...request.messages]; if (messages.length === 0) throw new Error('No messages provided.'); + // Gemini does not support messages with role system and instead expects + // systemInstructions to be provided as a separate input. The first + // message detected with role=system will be used for systemInstructions. + // Any additional system messages may be considered to be "exceptional". + let systemInstruction: Content | undefined = undefined; + const systemMessage = messages.find((m) => m.role === 'system'); + if (systemMessage) { + messages.splice(messages.indexOf(systemMessage), 1); + systemInstruction = toGeminiSystemInstruction(systemMessage); + } const chatRequest: StartChatParams = { + systemInstruction, tools: request.tools?.length ? [{ functionDeclarations: request.tools?.map(toGeminiTool) }] : [], history: messages .slice(0, -1) - .map((message) => toGeminiMessage(message)), + .map((message) => toGeminiMessage(message, model)), generationConfig: { candidateCount: request.candidates || undefined, temperature: request.config?.temperature, @@ -412,7 +456,7 @@ export function geminiModel(name: string, vertex: VertexAI): ModelAction { }, safetySettings: request.config?.safetySettings, }; - const msg = toGeminiMessage(messages[messages.length - 1]); + const msg = toGeminiMessage(messages[messages.length - 1], model); if (streamingCallback) { const result = await client .startChat(chatRequest) diff --git a/js/plugins/vertexai/src/imagen.ts b/js/plugins/vertexai/src/imagen.ts index 735fd50d6..54aff6418 100644 --- a/js/plugins/vertexai/src/imagen.ts +++ b/js/plugins/vertexai/src/imagen.ts @@ -49,6 +49,7 @@ export const imagen2 = modelRef({ media: false, multiturn: false, tools: false, + systemRole: false, output: ['media'], }, }, diff --git a/js/plugins/vertexai/tests/gemini_test.ts b/js/plugins/vertexai/tests/gemini_test.ts index 8e5620ea0..dbb39e8f1 100644 --- a/js/plugins/vertexai/tests/gemini_test.ts +++ b/js/plugins/vertexai/tests/gemini_test.ts @@ -14,10 +14,15 @@ * limitations under the License. */ +import { MessageData } from '@genkit-ai/ai/model'; import { GenerateContentCandidate } from '@google-cloud/vertexai'; import assert from 'node:assert'; import { describe, it } from 'node:test'; -import { fromGeminiCandidate, toGeminiMessage } from '../src/gemini.js'; +import { + fromGeminiCandidate, + toGeminiMessage, + toGeminiSystemInstruction, +} from '../src/gemini.js'; describe('toGeminiMessages', () => { const testCases = [ @@ -108,7 +113,51 @@ describe('toGeminiMessages', () => { ]; for (const test of testCases) { it(test.should, () => { - assert.deepEqual(toGeminiMessage(test.inputMessage), test.expectedOutput); + assert.deepEqual( + toGeminiMessage(test.inputMessage as MessageData), + test.expectedOutput + ); + }); + } +}); + +describe('toGeminiSystemInstruction', () => { + const testCases = [ + { + should: 'should transform from system to user', + inputMessage: { + role: 'system', + content: [{ text: 'You are an expert in all things cats.' }], + }, + expectedOutput: { + role: 'user', + parts: [{ text: 'You are an expert in all things cats.' }], + }, + }, + { + should: 'should transform from system to user with multiple parts', + inputMessage: { + role: 'system', + content: [ + { text: 'You are an expert in all things animals.' }, + { text: 'You love cats.' }, + ], + }, + expectedOutput: { + role: 'user', + parts: [ + { text: 'You are an expert in all things animals.' }, + { text: 'You love cats.' }, + ], + }, + }, + ]; + for (const test of testCases) { + it(test.should, () => { + assert.deepEqual( + toGeminiSystemInstruction(test.inputMessage as MessageData), + test.expectedOutput + ); }); } });