Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: [JS] introduced a generate utility action to make generate veneer logic reusable #759

Merged
merged 7 commits into from
Aug 12, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
275 changes: 8 additions & 267 deletions js/ai/src/generate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,47 +18,36 @@ import {
Action,
GenkitError,
StreamingCallback,
defineAction,
config as genkitConfig,
getStreamingCallback,
runWithStreamingCallback,
} from '@genkit-ai/core';
import { lookupAction } from '@genkit-ai/core/registry';
import {
parseSchema,
toJsonSchema,
validateSchema,
} from '@genkit-ai/core/schema';
import { toJsonSchema, validateSchema } from '@genkit-ai/core/schema';
import { z } from 'zod';
import { DocumentData, DocumentDataSchema } from './document.js';
import { DocumentData } from './document.js';
import { extractJson } from './extract.js';
import {
GenerateUtilParamSchema,
generateAction,
inferRoleFromParts,
} from './generateAction.js';
import {
CandidateData,
GenerateRequest,
GenerateResponseChunkData,
GenerateResponseData,
GenerateResponseSchema,
GenerationCommonConfigSchema,
GenerationUsage,
MessageData,
MessageSchema,
ModelAction,
ModelArgument,
ModelReference,
Part,
PartSchema,
Role,
ToolDefinition,
ToolDefinitionSchema,
ToolRequestPart,
ToolResponsePart,
} from './model.js';
import {
ToolAction,
ToolArgument,
resolveTools,
toToolDefinition,
} from './tool.js';
import { ToolArgument, resolveTools, toToolDefinition } from './tool.js';

/**
* Message represents a single role's contribution to a generation. Each message
Expand Down Expand Up @@ -432,62 +421,6 @@ export class GenerateResponseChunk<T = unknown>
}
}

/** Determines which conversation role a single message part implies. */
function getRoleFromPart(part: Part): Role {
  // Tool requests come from the model; tool responses come from tools.
  if (part.toolRequest !== undefined) {
    return 'model';
  }
  if (part.toolResponse !== undefined) {
    return 'tool';
  }
  // Text, media, and data parts are all user-supplied content.
  const isUserContent =
    part.text !== undefined ||
    part.media !== undefined ||
    part.data !== undefined;
  if (isUserContent) {
    return 'user';
  }
  throw new Error('No recognized fields in content');
}

/**
 * Infers the single role shared by all parts.
 * Throws as soon as two different roles are encountered.
 */
function inferRoleFromParts(parts: Part[]): Role {
  const seenRoles = new Set<Role>();
  for (const part of parts) {
    seenRoles.add(getRoleFromPart(part));
    if (seenRoles.size > 1) {
      throw new Error('Contents contain mixed roles');
    }
  }
  return [...seenRoles][0];
}

/**
 * Translates the utility-action input params into a model-level
 * GenerateRequest, appending the new prompt message to any prior history.
 */
async function actionToGenerateRequest(
  options: z.infer<typeof GenerateUtilParamSchema>,
  resolvedTools?: ToolAction[]
): Promise<GenerateRequest> {
  // Normalize the prompt (string | Part | Part[]) into a single message.
  const promptMessage: MessageData = { role: 'user', content: [] };
  const prompt = options.prompt;
  if (typeof prompt === 'string') {
    promptMessage.content.push({ text: prompt });
  } else if (Array.isArray(prompt)) {
    promptMessage.role = inferRoleFromParts(prompt);
    promptMessage.content.push(...(prompt as Part[]));
  } else {
    promptMessage.role = inferRoleFromParts([prompt]);
    promptMessage.content.push(prompt);
  }

  const messages: MessageData[] = [...(options.history || []), promptMessage];

  const out = {
    messages,
    candidates: options.candidates,
    config: options.config,
    context: options.context,
    tools: resolvedTools?.map((t) => toToolDefinition(t)) || [],
    output: {
      // Explicit format wins; otherwise infer 'json' when a schema is given.
      format:
        options.output?.format || (options.output?.jsonSchema ? 'json' : 'text'),
      schema: toJsonSchema({
        jsonSchema: options.output?.jsonSchema,
      }),
    },
  };
  // Omit the schema key entirely when no schema was produced.
  if (!out.output.schema) delete out.output.schema;
  return out;
}

export async function toGenerateRequest(
options: GenerateOptions
): Promise<GenerateRequest> {
Expand Down Expand Up @@ -559,29 +492,6 @@ export interface GenerateOptions<
streamingCallback?: StreamingCallback<GenerateResponseChunk>;
}

/**
 * Returns true when every tool request in the candidate's message refers to
 * a known tool and its input validates against that tool's input schema.
 */
const isValidCandidate = (
  candidate: CandidateData,
  tools: Action<any, any>[]
): boolean => {
  // Collect only the parts that carry a tool request.
  const toolRequestParts = candidate.message.content.filter(
    (part) => !!part.toolRequest
  );

  for (const part of toolRequestParts) {
    const matchingTool = tools?.find(
      (t) => t.__action.name === part.toolRequest?.name
    );
    // Unknown tool invalidates the candidate.
    if (!matchingTool) return false;
    const { valid } = validateSchema(part.toolRequest?.input, {
      schema: matchingTool.__action.inputSchema,
      jsonSchema: matchingTool.__action.inputJsonSchema,
    });
    if (!valid) return false;
  }
  return true;
};

async function resolveModel(options: GenerateOptions): Promise<ModelAction> {
let model = options.model;
if (!model) {
Expand Down Expand Up @@ -635,175 +545,6 @@ export class NoValidCandidatesError extends GenkitError {
}
}

/**
 * Input schema for the reusable `generate` utility action defined below.
 * Uses plain serializable references (model and tool names as strings)
 * rather than live action handles so the payload can be validated and
 * passed across action boundaries.
 */
export const GenerateUtilParamSchema = z.object({
/** A model name (e.g. `vertexai/gemini-1.0-pro`). */
model: z.string(),
/** The prompt for which to generate a response. Can be a string for a simple text prompt or one or more parts for multi-modal prompts. */
prompt: z.union([z.string(), PartSchema, z.array(PartSchema)]),
/** Retrieved documents to be used as context for this generation. */
context: z.array(DocumentDataSchema).optional(),
/** Conversation history for multi-turn prompting when supported by the underlying model. */
history: z.array(MessageSchema).optional(),
/** List of registered tool names for this generation if supported by the underlying model. */
tools: z.array(z.union([z.string(), ToolDefinitionSchema])).optional(),
/** Number of candidate messages to generate. */
candidates: z.number().optional(),
/** Configuration for the generation request. */
config: z.any().optional(),
/** Configuration for the desired output of the request. Defaults to the model's default output if unspecified. */
output: z
.object({
format: z
.union([z.literal('text'), z.literal('json'), z.literal('media')])
.optional(),
jsonSchema: z.any().optional(),
})
.optional(),
/** When true, return tool calls for manual processing instead of automatically resolving them. */
returnToolRequests: z.boolean().optional(),
});

/**
 * Reusable `generate` utility action.
 *
 * Resolves the model and tools by registry name, runs the model (streaming
 * accumulated chunks to any registered streaming callback), rejects responses
 * with no usable candidates, validates candidates against a requested JSON
 * output schema, and — unless `returnToolRequests` is set — executes the
 * selected candidate's tool requests and recurses with the tool responses
 * appended as the next conversation turn.
 *
 * Throws NoValidCandidatesError when every candidate is blocked or fails
 * schema validation; throws plain Errors for unresolvable models/tools.
 */
const generateAction = defineAction(
  {
    actionType: 'util',
    name: 'generate',
    inputSchema: GenerateUtilParamSchema,
    outputSchema: GenerateResponseSchema,
  },
  async (input) => {
    const model = (await lookupAction(`/model/${input.model}`)) as ModelAction;
    if (!model) {
      throw new Error(`Model ${input.model} not found`);
    }

    let tools: ToolAction[] | undefined;
    if (input.tools?.length) {
      if (!model.__action.metadata?.model.supports?.tools) {
        throw new Error(
          `Model ${input.model} does not support tools, but some tools were supplied to generate(). Please call generate() without tools if you would like to use this model.`
        );
      }
      tools = await Promise.all(
        input.tools.map(async (toolRef) => {
          if (typeof toolRef === 'string') {
            const tool = (await lookupAction(toolRef)) as ToolAction;
            if (!tool) {
              throw new Error(`Tool ${toolRef} not found`);
            }
            return tool;
          }
          // BUGFIX: previously `throw ''` — an empty string with no
          // diagnostic. Inline ToolDefinition objects cannot be resolved to
          // runnable actions here; only registered tool names are supported.
          throw new Error(
            `Tools must be provided as registered tool names; got: ${JSON.stringify(toolRef)}`
          );
        })
      );
    }

    const request = await actionToGenerateRequest(input, tools);

    const accumulatedChunks: GenerateResponseChunkData[] = [];

    const streamingCallback = getStreamingCallback();
    const response = await runWithStreamingCallback(
      streamingCallback
        ? (chunk: GenerateResponseChunkData) => {
            // Accumulate chunk data so each emitted chunk exposes the full
            // stream received so far. (Callback is guaranteed non-null here
            // by the surrounding ternary.)
            accumulatedChunks.push(chunk);
            streamingCallback(
              new GenerateResponseChunk(chunk, accumulatedChunks)
            );
          }
        : undefined,
      async () => new GenerateResponse(await model(request))
    );

    // Throw NoValidCandidatesError when no candidate finished normally,
    // i.e. every candidate was blocked or otherwise failed to complete.
    if (
      !response.candidates.some((c) =>
        ['stop', 'length'].includes(c.finishReason)
      )
    ) {
      throw new NoValidCandidatesError({
        message: `All candidates returned finishReason issues: ${JSON.stringify(response.candidates.map((c) => c.finishReason))}`,
        response,
      });
    }

    if (input.output?.jsonSchema && !response.toolRequests()?.length) {
      // Find a candidate whose output conforms to the requested schema.
      const candidateErrors = response.candidates.map((c) => {
        // Don't validate messages that have no text or data.
        if (c.text() === '' && c.data() === null) return null;

        try {
          parseSchema(c.output(), {
            jsonSchema: input.output?.jsonSchema,
          });
          return null;
        } catch (e) {
          return e as Error;
        }
      });
      // If every candidate produced a validation error, fail with details.
      if (candidateErrors.every((c) => !!c)) {
        throw new NoValidCandidatesError({
          message: `Generation resulted in no candidates matching provided output schema.${candidateErrors.map((e, i) => `\n\nCandidate[${i}] ${e!.toString()}`)}`,
          response,
          detail: {
            candidateErrors: candidateErrors,
          },
        });
      }
    }

    // Pick the first candidate whose tool calls all resolve and validate.
    let selected: Candidate<any> | undefined;
    for (const candidate of response.candidates) {
      if (isValidCandidate(candidate, tools || [])) {
        selected = candidate;
        break;
      }
    }

    if (!selected) {
      throw new Error('No valid candidates found');
    }

    const toolCalls = selected.message.content.filter(
      (part) => !!part.toolRequest
    );
    // Either hand tool requests back to the caller, or (with no calls) we
    // are done — return the raw response.
    if (input.returnToolRequests || toolCalls.length === 0) {
      return response.toJSON();
    }
    // Execute every requested tool in parallel and collect responses.
    const toolResponses: ToolResponsePart[] = await Promise.all(
      toolCalls.map(async (part) => {
        if (!part.toolRequest) {
          throw Error(
            'Tool request expected but not provided in tool request part'
          );
        }
        const tool = tools?.find(
          (tool) => tool.__action.name === part.toolRequest?.name
        );
        if (!tool) {
          throw Error('Tool not found');
        }
        return {
          toolResponse: {
            name: part.toolRequest.name,
            ref: part.toolRequest.ref,
            output: await tool(part.toolRequest?.input),
          },
        };
      })
    );
    // Continue the conversation: append the model turn to history and
    // re-invoke with the tool responses as the next prompt.
    input.history = request.messages;
    input.history.push(selected.message);
    input.prompt = toolResponses;
    return await generateAction(input);
  }
);

/**
* Generate calls a generative model based on the provided prompt and configuration. If
* `history` is provided, the generation will include a conversation history in its
Expand Down
Loading
Loading