From 2254a7b54e69b2dae3124a7186a28048232fbba4 Mon Sep 17 00:00:00 2001 From: Pavel Jbanov Date: Wed, 7 Aug 2024 15:41:52 -0400 Subject: [PATCH 1/7] refactor: [JS] introduced an generate utility action to make generate veneer logic reusable --- js/ai/src/generate.ts | 333 +++++++++++++++++++++++++--------------- js/core/src/registry.ts | 1 + 2 files changed, 214 insertions(+), 120 deletions(-) diff --git a/js/ai/src/generate.ts b/js/ai/src/generate.ts index ebda1c322..821fe2cae 100755 --- a/js/ai/src/generate.ts +++ b/js/ai/src/generate.ts @@ -17,7 +17,9 @@ import { Action, config as genkitConfig, + defineAction, GenkitError, + getStreamingCallback, runWithStreamingCallback, StreamingCallback, } from '@genkit-ai/core'; @@ -28,21 +30,26 @@ import { validateSchema, } from '@genkit-ai/core/schema'; import { z } from 'zod'; -import { DocumentData } from './document.js'; +import { DocumentData, DocumentDataSchema } from './document.js'; import { extractJson } from './extract.js'; import { CandidateData, GenerateRequest, GenerateResponseChunkData, GenerateResponseData, + GenerateResponseSchema, GenerationCommonConfigSchema, GenerationUsage, MessageData, + MessageSchema, ModelAction, ModelArgument, ModelReference, Part, + PartSchema, Role, + ToolDefinition, + ToolDefinitionSchema, ToolRequestPart, ToolResponsePart, } from './model.js'; @@ -447,38 +454,32 @@ function inferRoleFromParts(parts: Part[]): Role { } export async function toGenerateRequest( - options: GenerateOptions + options: z.infer, + resolvedTools?: ToolAction[], ): Promise { const promptMessage: MessageData = { role: 'user', content: [] }; if (typeof options.prompt === 'string') { promptMessage.content.push({ text: options.prompt }); } else if (Array.isArray(options.prompt)) { promptMessage.role = inferRoleFromParts(options.prompt); - promptMessage.content.push(...options.prompt); + promptMessage.content.push(...(options.prompt as Part[])); } else { promptMessage.role = inferRoleFromParts([options.prompt]); promptMessage.content.push(options.prompt); } const messages: MessageData[] = [...(options.history || []), promptMessage]; - let tools: Action[] | undefined; - if (options.tools) { - tools = await resolveTools(options.tools); - } const out = { messages, candidates: options.candidates, config: options.config, context: options.context, - tools: tools?.map((tool) => toToolDefinition(tool)) || [], + tools: resolvedTools?.map((tool) => toToolDefinition(tool)) || [], output: { format: options.output?.format || - (options.output?.schema || options.output?.jsonSchema - ? 'json' - : 'text'), + (options.output?.jsonSchema ? 'json' : 'text'), schema: toJsonSchema({ - schema: options.output?.schema, jsonSchema: options.output?.jsonSchema, }), }, @@ -593,6 +594,173 @@ export class NoValidCandidatesError extends GenkitError { } } +export const GenerateUtilParamSchema = z.object({ + /** A model name (e.g. `vertexai/gemini-1.0-pro`) or reference. */ + model: z.string(), + /** The prompt for which to generate a response. Can be a string for a simple text prompt or one or more parts for multi-modal prompts. */ + prompt: z.union([z.string(), PartSchema, z.array(PartSchema)]), + /** Retrieved documents to be used as context for this generation. */ + context: z.array(DocumentDataSchema).optional(), + /** Conversation history for multi-turn prompting when supported by the underlying model. 
*/ + history: z.array(MessageSchema).optional(), + /** List of registered tool names or actions to treat as a tool for this generation if supported by the underlying model. */ + tools: z.array(z.union([z.string(), ToolDefinitionSchema])).optional(), + /** Number of candidate messages to generate. */ + candidates: z.number().optional(), + /** Configuration for the generation request. */ + config: z.any().optional(), + /** Configuration for the desired output of the request. Defaults to the model's default output if unspecified. */ + output: z + .object({ + format: z + .union([z.literal('text'), z.literal('json'), z.literal('media')]) + .optional(), + jsonSchema: z.any().optional(), + }) + .optional(), + /** When true, return tool calls for manual processing instead of automatically resolving them. */ + returnToolRequests: z.boolean().optional(), +}); + +const generateAction = defineAction( + { + actionType: 'util', + name: 'generate', + inputSchema: GenerateUtilParamSchema, + outputSchema: GenerateResponseSchema, + }, + async (input) => { + const model = (await lookupAction(`/model/${input.model}`)) as ModelAction; + if (!model) { + throw new Error(`Model ${input.model} not found`); + } + + let tools: ToolAction[] | undefined; + if (input.tools?.length) { + if (!model.__action.metadata?.model.supports?.tools) { + throw new Error( + `Model ${input.model} does not support tools, but some tools were supplied to generate(). Please call generate() without tools if you would like to use this model.` + ); + } + tools = await Promise.all(input.tools.map(async (toolRef) => { + if (typeof toolRef === 'string') { + const tool = await lookupAction(toolRef) as ToolAction + if (!tool) { + throw new Error(`Tool ${toolRef} not found`); + } + return tool + } + throw '' + })); + } + + const request = await toGenerateRequest(input, tools); + + const accumulatedChunks: GenerateResponseChunkData[] = []; + + const streamingCallback = getStreamingCallback(); + const response = await runWithStreamingCallback( + streamingCallback + ? (chunk: GenerateResponseChunkData) => { + // Store accumulated chunk data + accumulatedChunks.push(chunk); + if (streamingCallback) { + streamingCallback!( + new GenerateResponseChunk(chunk, accumulatedChunks) + ); + } + } + : undefined, + async () => new GenerateResponse(await model(request)) + ); + + // throw NoValidCandidates if all candidates are blocked or + if ( + !response.candidates.some((c) => + ['stop', 'length'].includes(c.finishReason) + ) + ) { + throw new NoValidCandidatesError({ + message: `All candidates returned finishReason issues: ${JSON.stringify(response.candidates.map((c) => c.finishReason))}`, + response, + }); + } + + if (input.output?.jsonSchema && !response.toolRequests()?.length) { + // find a candidate with valid output schema + const candidateErrors = response.candidates.map((c) => { + // don't validate messages that have no text or data + if (c.text() === '' && c.data() === null) return null; + + try { + parseSchema(c.output(), { + jsonSchema: input.output?.jsonSchema, + }); + return null; + } catch (e) { + return e as Error; + } + }); + // if all candidates have a non-null error... + if (candidateErrors.every((c) => !!c)) { + throw new NoValidCandidatesError({ + message: `Generation resulted in no candidates matching provided output schema.${candidateErrors.map((e, i) => `\n\nCandidate[${i}] ${e!.toString()}`)}`, + response, + detail: { + candidateErrors: candidateErrors, + }, + }); + } + } + + // Pick the first valid candidate. 
+ let selected: Candidate | undefined; + for (const candidate of response.candidates) { + if (isValidCandidate(candidate, tools || [])) { + selected = candidate; + break; + } + } + + if (!selected) { + throw new Error('No valid candidates found'); + } + + const toolCalls = selected.message.content.filter( + (part) => !!part.toolRequest + ); + if (input.returnToolRequests || toolCalls.length === 0) { + return response.toJSON(); + } + const toolResponses: ToolResponsePart[] = await Promise.all( + toolCalls.map(async (part) => { + if (!part.toolRequest) { + throw Error( + 'Tool request expected but not provided in tool request part' + ); + } + const tool = tools?.find( + (tool) => tool.__action.name === part.toolRequest?.name + ); + if (!tool) { + throw Error('Tool not found'); + } + return { + toolResponse: { + name: part.toolRequest.name, + ref: part.toolRequest.ref, + output: await tool(part.toolRequest?.input), + }, + }; + }) + ); + input.history = request.messages; + input.history.push(selected.message); + input.prompt = toolResponses; + return await generateAction(input); + } +); + /** * Generate calls a generative model based on the provided prompt and configuration. If * `history` is provided, the generation will include a conversation history in its @@ -604,7 +772,6 @@ export class NoValidCandidatesError extends GenkitError { * @param options The options for this generation request. * @returns The generated response based on the provided parameters. */ - export async function generate< O extends z.ZodTypeAny = z.ZodTypeAny, CustomOptions extends z.ZodTypeAny = typeof GenerationCommonConfigSchema, @@ -620,120 +787,46 @@ export async function generate< throw new Error(`Model ${JSON.stringify(resolvedOptions.model)} not found`); } - let tools: ToolAction[] | undefined; - if (resolvedOptions.tools?.length) { - if (!model.__action.metadata?.model.supports?.tools) { + // convert tools to action refs (strings). + let tools: (string | ToolDefinition)[] | undefined; + if (resolvedOptions.tools) { + tools = resolvedOptions.tools.map((t) => { + if (typeof t === 'string') { + return `/tool/${t}`; + } else if ((t as Action).__action) { + return `/${(t as Action).__action.metadata?.type}/${(t as Action).__action.name}`; + } else if (t.name) { + return `/tool/${t.name}`; + } throw new Error( - `Model ${JSON.stringify(resolvedOptions.model)} does not support tools, but some tools were supplied to generate(). Please call generate() without tools if you would like to use this model.` + `Unable to determine type of of tool: ${JSON.stringify(t)}` ); - } - tools = await resolveTools(resolvedOptions.tools); - } - - const request = await toGenerateRequest(resolvedOptions); - - const accumulatedChunks: GenerateResponseChunkData[] = []; - - const response = await runWithStreamingCallback( - resolvedOptions.streamingCallback - ? 
(chunk: GenerateResponseChunkData) => { - // Store accumulated chunk data - accumulatedChunks.push(chunk); - if (resolvedOptions.streamingCallback) { - resolvedOptions.streamingCallback!( - new GenerateResponseChunk(chunk, accumulatedChunks) - ); - } - } - : undefined, - async () => new GenerateResponse>(await model(request), request) - ); - - // throw NoValidCandidates if all candidates are blocked or - if ( - !response.candidates.some((c) => - ['stop', 'length'].includes(c.finishReason) - ) - ) { - throw new NoValidCandidatesError({ - message: `All candidates returned finishReason issues: ${JSON.stringify(response.candidates.map((c) => c.finishReason))}`, - response, }); } - if (resolvedOptions.output?.schema || resolvedOptions.output?.jsonSchema) { - // find a candidate with valid output schema - const candidateErrors = response.candidates.map((c) => { - // don't validate messages that have no text or data - if (c.text() === '' && c.data() === null) return null; - - try { - parseSchema(c.output(), { - jsonSchema: resolvedOptions.output?.jsonSchema, - schema: resolvedOptions.output?.schema, - }); - return null; - } catch (e) { - return e as Error; - } - }); - // if all candidates have a non-null error... - if (candidateErrors.every((c) => !!c)) { - throw new NoValidCandidatesError({ - message: `Generation resulted in no candidates matching provided output schema.${candidateErrors.map((e, i) => `\n\nCandidate[${i}] ${e!.toString()}`)}`, - response, - detail: { - candidateErrors: candidateErrors, - }, - }); - } - } - - // Pick the first valid candidate. - let selected: Candidate> | undefined; - for (const candidate of response.candidates) { - if (isValidCandidate(candidate, tools || [])) { - selected = candidate; - break; - } - } - - if (!selected) { - throw new Error('No valid candidates found'); - } + const params: z.infer = { + model: model.__action.name, + prompt: resolvedOptions.prompt, + context: resolvedOptions.context, + history: resolvedOptions.history, + tools, + candidates: resolvedOptions.candidates, + config: resolvedOptions.config, + output: resolvedOptions.output && { + format: resolvedOptions.output.format, + jsonSchema: resolvedOptions.output.schema + ? 
toJsonSchema({ + schema: resolvedOptions.output.schema, + jsonSchema: resolvedOptions.output.jsonSchema, + }) + : resolvedOptions.output.jsonSchema, + }, + returnToolRequests: resolvedOptions.returnToolRequests, + }; - const toolCalls = selected.message.content.filter( - (part) => !!part.toolRequest - ); - if (resolvedOptions.returnToolRequests || toolCalls.length === 0) { - return response; - } - const toolResponses: ToolResponsePart[] = await Promise.all( - toolCalls.map(async (part) => { - if (!part.toolRequest) { - throw Error( - 'Tool request expected but not provided in tool request part' - ); - } - const tool = tools?.find( - (tool) => tool.__action.name === part.toolRequest?.name - ); - if (!tool) { - throw Error('Tool not found'); - } - return { - toolResponse: { - name: part.toolRequest.name, - ref: part.toolRequest.ref, - output: await tool(part.toolRequest?.input), - }, - }; - }) + return await runWithStreamingCallback(resolvedOptions.streamingCallback, async () => + new GenerateResponse(await generateAction(params)) ); - resolvedOptions.history = request.messages; - resolvedOptions.history.push(selected.message); - resolvedOptions.prompt = toolResponses; - return await generate(resolvedOptions); } export type GenerateStreamOptions< diff --git a/js/core/src/registry.ts b/js/core/src/registry.ts index 94647c92b..bb93a13bb 100644 --- a/js/core/src/registry.ts +++ b/js/core/src/registry.ts @@ -77,6 +77,7 @@ export type ActionType = | 'flow' | 'model' | 'prompt' + | 'util' | 'tool'; /** From a4bb8668a730349ebe7e27f42822bb9a8f3159fd Mon Sep 17 00:00:00 2001 From: Pavel Jbanov Date: Wed, 7 Aug 2024 15:44:02 -0400 Subject: [PATCH 2/7] format --- js/ai/src/generate.ts | 44 +++++++++++++++++++++---------------------- 1 file changed, 21 insertions(+), 23 deletions(-) diff --git a/js/ai/src/generate.ts b/js/ai/src/generate.ts index 821fe2cae..3b603b5dd 100755 --- a/js/ai/src/generate.ts +++ b/js/ai/src/generate.ts @@ -16,12 +16,12 @@ import { Action, - config as genkitConfig, - defineAction, GenkitError, + StreamingCallback, + defineAction, + config as genkitConfig, getStreamingCallback, runWithStreamingCallback, - StreamingCallback, } from '@genkit-ai/core'; import { lookupAction } from '@genkit-ai/core/registry'; import { @@ -53,12 +53,7 @@ import { ToolRequestPart, ToolResponsePart, } from './model.js'; -import { - resolveTools, - ToolAction, - ToolArgument, - toToolDefinition, -} from './tool.js'; +import { ToolAction, ToolArgument, toToolDefinition } from './tool.js'; /** * Message represents a single role's contribution to a generation. Each message @@ -455,7 +450,7 @@ function inferRoleFromParts(parts: Part[]): Role { export async function toGenerateRequest( options: z.infer, - resolvedTools?: ToolAction[], + resolvedTools?: ToolAction[] ): Promise { const promptMessage: MessageData = { role: 'user', content: [] }; if (typeof options.prompt === 'string') { @@ -595,7 +590,7 @@ export class NoValidCandidatesError extends GenkitError { } export const GenerateUtilParamSchema = z.object({ - /** A model name (e.g. `vertexai/gemini-1.0-pro`) or reference. */ + /** A model name (e.g. `vertexai/gemini-1.0-pro`). */ model: z.string(), /** The prompt for which to generate a response. Can be a string for a simple text prompt or one or more parts for multi-modal prompts. 
*/ prompt: z.union([z.string(), PartSchema, z.array(PartSchema)]), @@ -603,7 +598,7 @@ export const GenerateUtilParamSchema = z.object({ context: z.array(DocumentDataSchema).optional(), /** Conversation history for multi-turn prompting when supported by the underlying model. */ history: z.array(MessageSchema).optional(), - /** List of registered tool names or actions to treat as a tool for this generation if supported by the underlying model. */ + /** List of registered tool names for this generation if supported by the underlying model. */ tools: z.array(z.union([z.string(), ToolDefinitionSchema])).optional(), /** Number of candidate messages to generate. */ candidates: z.number().optional(), @@ -642,16 +637,18 @@ const generateAction = defineAction( `Model ${input.model} does not support tools, but some tools were supplied to generate(). Please call generate() without tools if you would like to use this model.` ); } - tools = await Promise.all(input.tools.map(async (toolRef) => { - if (typeof toolRef === 'string') { - const tool = await lookupAction(toolRef) as ToolAction - if (!tool) { - throw new Error(`Tool ${toolRef} not found`); + tools = await Promise.all( + input.tools.map(async (toolRef) => { + if (typeof toolRef === 'string') { + const tool = (await lookupAction(toolRef)) as ToolAction; + if (!tool) { + throw new Error(`Tool ${toolRef} not found`); + } + return tool; } - return tool - } - throw '' - })); + throw ''; + }) + ); } const request = await toGenerateRequest(input, tools); @@ -824,8 +821,9 @@ export async function generate< returnToolRequests: resolvedOptions.returnToolRequests, }; - return await runWithStreamingCallback(resolvedOptions.streamingCallback, async () => - new GenerateResponse(await generateAction(params)) + return await runWithStreamingCallback( + resolvedOptions.streamingCallback, + async () => new GenerateResponse(await generateAction(params)) ); } From 05a825447fd8075a2eb2cf3f22ced98faba7ed9f Mon Sep 17 00:00:00 2001 From: Pavel Jbanov Date: Wed, 7 Aug 2024 15:53:19 -0400 Subject: [PATCH 3/7] bring back toGenerateRequest --- js/ai/src/generate.ts | 52 ++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 49 insertions(+), 3 deletions(-) diff --git a/js/ai/src/generate.ts b/js/ai/src/generate.ts index 3b603b5dd..2a6bf4663 100755 --- a/js/ai/src/generate.ts +++ b/js/ai/src/generate.ts @@ -53,7 +53,12 @@ import { ToolRequestPart, ToolResponsePart, } from './model.js'; -import { ToolAction, ToolArgument, toToolDefinition } from './tool.js'; +import { + ToolAction, + ToolArgument, + resolveTools, + toToolDefinition, +} from './tool.js'; /** * Message represents a single role's contribution to a generation. 
Each message @@ -448,7 +453,7 @@ function inferRoleFromParts(parts: Part[]): Role { return Array.from(uniqueRoles)[0]; } -export async function toGenerateRequest( +async function actionToGenerateRequest( options: z.infer, resolvedTools?: ToolAction[] ): Promise { @@ -483,6 +488,47 @@ export async function toGenerateRequest( return out; } +export async function toGenerateRequest( + options: GenerateOptions +): Promise { + const promptMessage: MessageData = { role: 'user', content: [] }; + if (typeof options.prompt === 'string') { + promptMessage.content.push({ text: options.prompt }); + } else if (Array.isArray(options.prompt)) { + promptMessage.role = inferRoleFromParts(options.prompt); + promptMessage.content.push(...options.prompt); + } else { + promptMessage.role = inferRoleFromParts([options.prompt]); + promptMessage.content.push(options.prompt); + } + const messages: MessageData[] = [...(options.history || []), promptMessage]; + let tools: Action[] | undefined; + if (options.tools) { + tools = await resolveTools(options.tools); + } + + const out = { + messages, + candidates: options.candidates, + config: options.config, + context: options.context, + tools: tools?.map((tool) => toToolDefinition(tool)) || [], + output: { + format: + options.output?.format || + (options.output?.schema || options.output?.jsonSchema + ? 'json' + : 'text'), + schema: toJsonSchema({ + schema: options.output?.schema, + jsonSchema: options.output?.jsonSchema, + }), + }, + }; + if (!out.output.schema) delete out.output.schema; + return out; +} + export interface GenerateOptions< O extends z.ZodTypeAny = z.ZodTypeAny, CustomOptions extends z.ZodTypeAny = z.ZodTypeAny, @@ -651,7 +697,7 @@ const generateAction = defineAction( ); } - const request = await toGenerateRequest(input, tools); + const request = await actionToGenerateRequest(input, tools); const accumulatedChunks: GenerateResponseChunkData[] = []; From 55486d6fd418f5a9d2743c5e15785f2c14749d48 Mon Sep 17 00:00:00 2001 From: Pavel Jbanov Date: Wed, 7 Aug 2024 16:01:11 -0400 Subject: [PATCH 4/7] fix tests --- js/plugins/google-cloud/tests/logs_test.ts | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/js/plugins/google-cloud/tests/logs_test.ts b/js/plugins/google-cloud/tests/logs_test.ts index d81b44777..543e94be4 100644 --- a/js/plugins/google-cloud/tests/logs_test.ts +++ b/js/plugins/google-cloud/tests/logs_test.ts @@ -150,19 +150,19 @@ describe('GoogleCloudLogs', () => { const logMessages = await getLogs(); assert.equal( logMessages.includes( - '[info] Config[testFlow > sub1 > sub2 > testModel, testModel]' + '[info] Config[testFlow > sub1 > sub2 > generate > testModel, testModel]' ), true ); assert.equal( logMessages.includes( - '[info] Input[testFlow > sub1 > sub2 > testModel, testModel]' + '[info] Input[testFlow > sub1 > sub2 > generate > testModel, testModel]' ), true ); assert.equal( logMessages.includes( - '[info] Output[testFlow > sub1 > sub2 > testModel, testModel]' + '[info] Output[testFlow > sub1 > sub2 > generate > testModel, testModel]' ), true ); From 4c01b564068999b6299215be213c1335f2a41ae7 Mon Sep 17 00:00:00 2001 From: Pavel Jbanov Date: Wed, 7 Aug 2024 16:11:54 -0400 Subject: [PATCH 5/7] generate action to separate file --- js/ai/src/generate.ts | 275 +-------------------------------- js/ai/src/generateAction.ts | 299 ++++++++++++++++++++++++++++++++++++ 2 files changed, 307 insertions(+), 267 deletions(-) create mode 100644 js/ai/src/generateAction.ts diff --git a/js/ai/src/generate.ts b/js/ai/src/generate.ts index 
2a6bf4663..e97ac93cc 100755 --- a/js/ai/src/generate.ts +++ b/js/ai/src/generate.ts @@ -18,47 +18,36 @@ import { Action, GenkitError, StreamingCallback, - defineAction, config as genkitConfig, - getStreamingCallback, runWithStreamingCallback, } from '@genkit-ai/core'; import { lookupAction } from '@genkit-ai/core/registry'; -import { - parseSchema, - toJsonSchema, - validateSchema, -} from '@genkit-ai/core/schema'; +import { toJsonSchema, validateSchema } from '@genkit-ai/core/schema'; import { z } from 'zod'; -import { DocumentData, DocumentDataSchema } from './document.js'; +import { DocumentData } from './document.js'; import { extractJson } from './extract.js'; +import { + GenerateUtilParamSchema, + generateAction, + inferRoleFromParts, +} from './generateAction.js'; import { CandidateData, GenerateRequest, GenerateResponseChunkData, GenerateResponseData, - GenerateResponseSchema, GenerationCommonConfigSchema, GenerationUsage, MessageData, - MessageSchema, ModelAction, ModelArgument, ModelReference, Part, - PartSchema, - Role, ToolDefinition, - ToolDefinitionSchema, ToolRequestPart, ToolResponsePart, } from './model.js'; -import { - ToolAction, - ToolArgument, - resolveTools, - toToolDefinition, -} from './tool.js'; +import { ToolArgument, resolveTools, toToolDefinition } from './tool.js'; /** * Message represents a single role's contribution to a generation. Each message @@ -432,62 +421,6 @@ export class GenerateResponseChunk } } -function getRoleFromPart(part: Part): Role { - if (part.toolRequest !== undefined) return 'model'; - if (part.toolResponse !== undefined) return 'tool'; - if (part.text !== undefined) return 'user'; - if (part.media !== undefined) return 'user'; - if (part.data !== undefined) return 'user'; - throw new Error('No recognized fields in content'); -} - -function inferRoleFromParts(parts: Part[]): Role { - const uniqueRoles = new Set(); - for (const part of parts) { - const role = getRoleFromPart(part); - uniqueRoles.add(role); - if (uniqueRoles.size > 1) { - throw new Error('Contents contain mixed roles'); - } - } - return Array.from(uniqueRoles)[0]; -} - -async function actionToGenerateRequest( - options: z.infer, - resolvedTools?: ToolAction[] -): Promise { - const promptMessage: MessageData = { role: 'user', content: [] }; - if (typeof options.prompt === 'string') { - promptMessage.content.push({ text: options.prompt }); - } else if (Array.isArray(options.prompt)) { - promptMessage.role = inferRoleFromParts(options.prompt); - promptMessage.content.push(...(options.prompt as Part[])); - } else { - promptMessage.role = inferRoleFromParts([options.prompt]); - promptMessage.content.push(options.prompt); - } - const messages: MessageData[] = [...(options.history || []), promptMessage]; - - const out = { - messages, - candidates: options.candidates, - config: options.config, - context: options.context, - tools: resolvedTools?.map((tool) => toToolDefinition(tool)) || [], - output: { - format: - options.output?.format || - (options.output?.jsonSchema ? 
'json' : 'text'), - schema: toJsonSchema({ - jsonSchema: options.output?.jsonSchema, - }), - }, - }; - if (!out.output.schema) delete out.output.schema; - return out; -} - export async function toGenerateRequest( options: GenerateOptions ): Promise { @@ -559,29 +492,6 @@ export interface GenerateOptions< streamingCallback?: StreamingCallback; } -const isValidCandidate = ( - candidate: CandidateData, - tools: Action[] -): boolean => { - // Check if tool calls are vlaid - const toolCalls = candidate.message.content.filter( - (part) => !!part.toolRequest - ); - - // make sure every tool called exists and has valid input - return toolCalls.every((toolCall) => { - const tool = tools?.find( - (tool) => tool.__action.name === toolCall.toolRequest?.name - ); - if (!tool) return false; - const { valid } = validateSchema(toolCall.toolRequest?.input, { - schema: tool.__action.inputSchema, - jsonSchema: tool.__action.inputJsonSchema, - }); - return valid; - }); -}; - async function resolveModel(options: GenerateOptions): Promise { let model = options.model; if (!model) { @@ -635,175 +545,6 @@ export class NoValidCandidatesError extends GenkitError { } } -export const GenerateUtilParamSchema = z.object({ - /** A model name (e.g. `vertexai/gemini-1.0-pro`). */ - model: z.string(), - /** The prompt for which to generate a response. Can be a string for a simple text prompt or one or more parts for multi-modal prompts. */ - prompt: z.union([z.string(), PartSchema, z.array(PartSchema)]), - /** Retrieved documents to be used as context for this generation. */ - context: z.array(DocumentDataSchema).optional(), - /** Conversation history for multi-turn prompting when supported by the underlying model. */ - history: z.array(MessageSchema).optional(), - /** List of registered tool names for this generation if supported by the underlying model. */ - tools: z.array(z.union([z.string(), ToolDefinitionSchema])).optional(), - /** Number of candidate messages to generate. */ - candidates: z.number().optional(), - /** Configuration for the generation request. */ - config: z.any().optional(), - /** Configuration for the desired output of the request. Defaults to the model's default output if unspecified. */ - output: z - .object({ - format: z - .union([z.literal('text'), z.literal('json'), z.literal('media')]) - .optional(), - jsonSchema: z.any().optional(), - }) - .optional(), - /** When true, return tool calls for manual processing instead of automatically resolving them. */ - returnToolRequests: z.boolean().optional(), -}); - -const generateAction = defineAction( - { - actionType: 'util', - name: 'generate', - inputSchema: GenerateUtilParamSchema, - outputSchema: GenerateResponseSchema, - }, - async (input) => { - const model = (await lookupAction(`/model/${input.model}`)) as ModelAction; - if (!model) { - throw new Error(`Model ${input.model} not found`); - } - - let tools: ToolAction[] | undefined; - if (input.tools?.length) { - if (!model.__action.metadata?.model.supports?.tools) { - throw new Error( - `Model ${input.model} does not support tools, but some tools were supplied to generate(). 
Please call generate() without tools if you would like to use this model.` - ); - } - tools = await Promise.all( - input.tools.map(async (toolRef) => { - if (typeof toolRef === 'string') { - const tool = (await lookupAction(toolRef)) as ToolAction; - if (!tool) { - throw new Error(`Tool ${toolRef} not found`); - } - return tool; - } - throw ''; - }) - ); - } - - const request = await actionToGenerateRequest(input, tools); - - const accumulatedChunks: GenerateResponseChunkData[] = []; - - const streamingCallback = getStreamingCallback(); - const response = await runWithStreamingCallback( - streamingCallback - ? (chunk: GenerateResponseChunkData) => { - // Store accumulated chunk data - accumulatedChunks.push(chunk); - if (streamingCallback) { - streamingCallback!( - new GenerateResponseChunk(chunk, accumulatedChunks) - ); - } - } - : undefined, - async () => new GenerateResponse(await model(request)) - ); - - // throw NoValidCandidates if all candidates are blocked or - if ( - !response.candidates.some((c) => - ['stop', 'length'].includes(c.finishReason) - ) - ) { - throw new NoValidCandidatesError({ - message: `All candidates returned finishReason issues: ${JSON.stringify(response.candidates.map((c) => c.finishReason))}`, - response, - }); - } - - if (input.output?.jsonSchema && !response.toolRequests()?.length) { - // find a candidate with valid output schema - const candidateErrors = response.candidates.map((c) => { - // don't validate messages that have no text or data - if (c.text() === '' && c.data() === null) return null; - - try { - parseSchema(c.output(), { - jsonSchema: input.output?.jsonSchema, - }); - return null; - } catch (e) { - return e as Error; - } - }); - // if all candidates have a non-null error... - if (candidateErrors.every((c) => !!c)) { - throw new NoValidCandidatesError({ - message: `Generation resulted in no candidates matching provided output schema.${candidateErrors.map((e, i) => `\n\nCandidate[${i}] ${e!.toString()}`)}`, - response, - detail: { - candidateErrors: candidateErrors, - }, - }); - } - } - - // Pick the first valid candidate. - let selected: Candidate | undefined; - for (const candidate of response.candidates) { - if (isValidCandidate(candidate, tools || [])) { - selected = candidate; - break; - } - } - - if (!selected) { - throw new Error('No valid candidates found'); - } - - const toolCalls = selected.message.content.filter( - (part) => !!part.toolRequest - ); - if (input.returnToolRequests || toolCalls.length === 0) { - return response.toJSON(); - } - const toolResponses: ToolResponsePart[] = await Promise.all( - toolCalls.map(async (part) => { - if (!part.toolRequest) { - throw Error( - 'Tool request expected but not provided in tool request part' - ); - } - const tool = tools?.find( - (tool) => tool.__action.name === part.toolRequest?.name - ); - if (!tool) { - throw Error('Tool not found'); - } - return { - toolResponse: { - name: part.toolRequest.name, - ref: part.toolRequest.ref, - output: await tool(part.toolRequest?.input), - }, - }; - }) - ); - input.history = request.messages; - input.history.push(selected.message); - input.prompt = toolResponses; - return await generateAction(input); - } -); - /** * Generate calls a generative model based on the provided prompt and configuration. 
If * `history` is provided, the generation will include a conversation history in its diff --git a/js/ai/src/generateAction.ts b/js/ai/src/generateAction.ts new file mode 100644 index 000000000..ecdf682dc --- /dev/null +++ b/js/ai/src/generateAction.ts @@ -0,0 +1,299 @@ +/** + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { + Action, + defineAction, + getStreamingCallback, + runWithStreamingCallback, +} from '@genkit-ai/core'; +import { lookupAction } from '@genkit-ai/core/registry'; +import { + parseSchema, + toJsonSchema, + validateSchema, +} from '@genkit-ai/core/schema'; +import { z } from 'zod'; +import { DocumentDataSchema } from './document.js'; +import { + Candidate, + GenerateResponse, + GenerateResponseChunk, + NoValidCandidatesError, +} from './generate.js'; +import { + CandidateData, + GenerateRequest, + GenerateResponseChunkData, + GenerateResponseSchema, + MessageData, + MessageSchema, + ModelAction, + Part, + PartSchema, + Role, + ToolDefinitionSchema, + ToolResponsePart, +} from './model.js'; +import { ToolAction, toToolDefinition } from './tool.js'; + +export const GenerateUtilParamSchema = z.object({ + /** A model name (e.g. `vertexai/gemini-1.0-pro`). */ + model: z.string(), + /** The prompt for which to generate a response. Can be a string for a simple text prompt or one or more parts for multi-modal prompts. */ + prompt: z.union([z.string(), PartSchema, z.array(PartSchema)]), + /** Retrieved documents to be used as context for this generation. */ + context: z.array(DocumentDataSchema).optional(), + /** Conversation history for multi-turn prompting when supported by the underlying model. */ + history: z.array(MessageSchema).optional(), + /** List of registered tool names for this generation if supported by the underlying model. */ + tools: z.array(z.union([z.string(), ToolDefinitionSchema])).optional(), + /** Number of candidate messages to generate. */ + candidates: z.number().optional(), + /** Configuration for the generation request. */ + config: z.any().optional(), + /** Configuration for the desired output of the request. Defaults to the model's default output if unspecified. */ + output: z + .object({ + format: z + .union([z.literal('text'), z.literal('json'), z.literal('media')]) + .optional(), + jsonSchema: z.any().optional(), + }) + .optional(), + /** When true, return tool calls for manual processing instead of automatically resolving them. 
*/ + returnToolRequests: z.boolean().optional(), +}); + +export const generateAction = defineAction( + { + actionType: 'util', + name: 'generate', + inputSchema: GenerateUtilParamSchema, + outputSchema: GenerateResponseSchema, + }, + async (input) => { + const model = (await lookupAction(`/model/${input.model}`)) as ModelAction; + if (!model) { + throw new Error(`Model ${input.model} not found`); + } + + let tools: ToolAction[] | undefined; + if (input.tools?.length) { + if (!model.__action.metadata?.model.supports?.tools) { + throw new Error( + `Model ${input.model} does not support tools, but some tools were supplied to generate(). Please call generate() without tools if you would like to use this model.` + ); + } + tools = await Promise.all( + input.tools.map(async (toolRef) => { + if (typeof toolRef === 'string') { + const tool = (await lookupAction(toolRef)) as ToolAction; + if (!tool) { + throw new Error(`Tool ${toolRef} not found`); + } + return tool; + } + throw ''; + }) + ); + } + + const request = await actionToGenerateRequest(input, tools); + + const accumulatedChunks: GenerateResponseChunkData[] = []; + + const streamingCallback = getStreamingCallback(); + const response = await runWithStreamingCallback( + streamingCallback + ? (chunk: GenerateResponseChunkData) => { + // Store accumulated chunk data + accumulatedChunks.push(chunk); + if (streamingCallback) { + streamingCallback!( + new GenerateResponseChunk(chunk, accumulatedChunks) + ); + } + } + : undefined, + async () => new GenerateResponse(await model(request)) + ); + + // throw NoValidCandidates if all candidates are blocked or + if ( + !response.candidates.some((c) => + ['stop', 'length'].includes(c.finishReason) + ) + ) { + throw new NoValidCandidatesError({ + message: `All candidates returned finishReason issues: ${JSON.stringify(response.candidates.map((c) => c.finishReason))}`, + response, + }); + } + + if (input.output?.jsonSchema && !response.toolRequests()?.length) { + // find a candidate with valid output schema + const candidateErrors = response.candidates.map((c) => { + // don't validate messages that have no text or data + if (c.text() === '' && c.data() === null) return null; + + try { + parseSchema(c.output(), { + jsonSchema: input.output?.jsonSchema, + }); + return null; + } catch (e) { + return e as Error; + } + }); + // if all candidates have a non-null error... + if (candidateErrors.every((c) => !!c)) { + throw new NoValidCandidatesError({ + message: `Generation resulted in no candidates matching provided output schema.${candidateErrors.map((e, i) => `\n\nCandidate[${i}] ${e!.toString()}`)}`, + response, + detail: { + candidateErrors: candidateErrors, + }, + }); + } + } + + // Pick the first valid candidate. 
+ let selected: Candidate | undefined; + for (const candidate of response.candidates) { + if (isValidCandidate(candidate, tools || [])) { + selected = candidate; + break; + } + } + + if (!selected) { + throw new Error('No valid candidates found'); + } + + const toolCalls = selected.message.content.filter( + (part) => !!part.toolRequest + ); + if (input.returnToolRequests || toolCalls.length === 0) { + return response.toJSON(); + } + const toolResponses: ToolResponsePart[] = await Promise.all( + toolCalls.map(async (part) => { + if (!part.toolRequest) { + throw Error( + 'Tool request expected but not provided in tool request part' + ); + } + const tool = tools?.find( + (tool) => tool.__action.name === part.toolRequest?.name + ); + if (!tool) { + throw Error('Tool not found'); + } + return { + toolResponse: { + name: part.toolRequest.name, + ref: part.toolRequest.ref, + output: await tool(part.toolRequest?.input), + }, + }; + }) + ); + input.history = request.messages; + input.history.push(selected.message); + input.prompt = toolResponses; + return await generateAction(input); + } +); + +async function actionToGenerateRequest( + options: z.infer, + resolvedTools?: ToolAction[] +): Promise { + const promptMessage: MessageData = { role: 'user', content: [] }; + if (typeof options.prompt === 'string') { + promptMessage.content.push({ text: options.prompt }); + } else if (Array.isArray(options.prompt)) { + promptMessage.role = inferRoleFromParts(options.prompt); + promptMessage.content.push(...(options.prompt as Part[])); + } else { + promptMessage.role = inferRoleFromParts([options.prompt]); + promptMessage.content.push(options.prompt); + } + const messages: MessageData[] = [...(options.history || []), promptMessage]; + + const out = { + messages, + candidates: options.candidates, + config: options.config, + context: options.context, + tools: resolvedTools?.map((tool) => toToolDefinition(tool)) || [], + output: { + format: + options.output?.format || + (options.output?.jsonSchema ? 
'json' : 'text'), + schema: toJsonSchema({ + jsonSchema: options.output?.jsonSchema, + }), + }, + }; + if (!out.output.schema) delete out.output.schema; + return out; +} + +const isValidCandidate = ( + candidate: CandidateData, + tools: Action[] +): boolean => { + // Check if tool calls are vlaid + const toolCalls = candidate.message.content.filter( + (part) => !!part.toolRequest + ); + + // make sure every tool called exists and has valid input + return toolCalls.every((toolCall) => { + const tool = tools?.find( + (tool) => tool.__action.name === toolCall.toolRequest?.name + ); + if (!tool) return false; + const { valid } = validateSchema(toolCall.toolRequest?.input, { + schema: tool.__action.inputSchema, + jsonSchema: tool.__action.inputJsonSchema, + }); + return valid; + }); +}; + +export function inferRoleFromParts(parts: Part[]): Role { + const uniqueRoles = new Set(); + for (const part of parts) { + const role = getRoleFromPart(part); + uniqueRoles.add(role); + if (uniqueRoles.size > 1) { + throw new Error('Contents contain mixed roles'); + } + } + return Array.from(uniqueRoles)[0]; +} + +function getRoleFromPart(part: Part): Role { + if (part.toolRequest !== undefined) return 'model'; + if (part.toolResponse !== undefined) return 'tool'; + if (part.text !== undefined) return 'user'; + if (part.media !== undefined) return 'user'; + if (part.data !== undefined) return 'user'; + throw new Error('No recognized fields in content'); +} From cd21a7e524dea552decda4a11770bc39fb9f1a7b Mon Sep 17 00:00:00 2001 From: Pavel Jbanov Date: Mon, 12 Aug 2024 11:39:13 -0400 Subject: [PATCH 6/7] Update js/ai/src/generateAction.ts Co-authored-by: Michael Bleigh --- js/ai/src/generateAction.ts | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/js/ai/src/generateAction.ts b/js/ai/src/generateAction.ts index ecdf682dc..e58bee210 100644 --- a/js/ai/src/generateAction.ts +++ b/js/ai/src/generateAction.ts @@ -212,10 +212,12 @@ export const generateAction = defineAction( }; }) ); - input.history = request.messages; - input.history.push(selected.message); - input.prompt = toolResponses; - return await generateAction(input); + const nextRequest = { + ...input, + history: [...request.messages, selected.message], + prompt: toolResponses, + } + return await generateAction(nextRequest); } ); From c3ef78c1443ab1337b39cf64924b2933dbd9156c Mon Sep 17 00:00:00 2001 From: Pavel Jbanov Date: Mon, 12 Aug 2024 11:40:30 -0400 Subject: [PATCH 7/7] format --- js/ai/src/generateAction.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/js/ai/src/generateAction.ts b/js/ai/src/generateAction.ts index e58bee210..2938d1ab2 100644 --- a/js/ai/src/generateAction.ts +++ b/js/ai/src/generateAction.ts @@ -216,7 +216,7 @@ export const generateAction = defineAction( ...input, history: [...request.messages, selected.message], prompt: toolResponses, - } + }; return await generateAction(nextRequest); } );
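
Usage sketch (not part of the patches above): a minimal, hypothetical example of how the refactored veneer is expected to behave after this series. generate() converts its options into GenerateUtilParamSchema params and delegates to the registered 'util/generate' action, which resolves the model and tools from the registry, streams chunks via getStreamingCallback(), and re-invokes itself with the tool responses unless returnToolRequests is set. The model name, tool name, and import paths below are illustrative assumptions based on this repo's package layout, not verified public API.

import { generate } from '@genkit-ai/ai';
import { defineTool } from '@genkit-ai/ai/tool';
import { z } from 'zod';

// Hypothetical tool; after this series the veneer passes it to the
// util action as the registry ref '/tool/getWeather'.
const getWeather = defineTool(
  {
    name: 'getWeather',
    description: 'Returns the current weather for a city',
    inputSchema: z.object({ city: z.string() }),
    outputSchema: z.string(),
  },
  async ({ city }) => `Sunny in ${city}`
);

const response = await generate({
  model: 'vertexai/gemini-1.0-pro', // illustrative model name
  prompt: 'What is the weather in Paris?',
  tools: [getWeather],
  // false (the default): generateAction runs the tool itself and loops
  // with the tool responses appended to history.
  returnToolRequests: false,
  streamingCallback: (chunk) => console.log(chunk.text()),
});

console.log(response.text());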