Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: [JS] introduced a generate utility action to make generate veneer logic reusable #759

Merged
merged 7 commits into from
Aug 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
210 changes: 44 additions & 166 deletions js/ai/src/generate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,21 @@

import {
Action,
config as genkitConfig,
GenkitError,
runWithStreamingCallback,
StreamingCallback,
config as genkitConfig,
runWithStreamingCallback,
} from '@genkit-ai/core';
import { lookupAction } from '@genkit-ai/core/registry';
import {
parseSchema,
toJsonSchema,
validateSchema,
} from '@genkit-ai/core/schema';
import { toJsonSchema, validateSchema } from '@genkit-ai/core/schema';
import { z } from 'zod';
import { DocumentData } from './document.js';
import { extractJson } from './extract.js';
import {
GenerateUtilParamSchema,
generateAction,
inferRoleFromParts,
} from './generateAction.js';
import {
CandidateData,
GenerateRequest,
Expand All @@ -42,16 +43,11 @@ import {
ModelArgument,
ModelReference,
Part,
Role,
ToolDefinition,
ToolRequestPart,
ToolResponsePart,
} from './model.js';
import {
resolveTools,
ToolAction,
ToolArgument,
toToolDefinition,
} from './tool.js';
import { ToolArgument, resolveTools, toToolDefinition } from './tool.js';

/**
* Message represents a single role's contribution to a generation. Each message
Expand Down Expand Up @@ -425,27 +421,6 @@ export class GenerateResponseChunk<T = unknown>
}
}

/**
 * Determines which role a single content part implies.
 *
 * Fields are checked in a fixed order: a tool request maps to 'model', a tool
 * response maps to 'tool', and text, media or data all map to 'user'.
 *
 * @param part The content part to inspect.
 * @returns The role implied by the first recognized field that is defined.
 * @throws Error if none of the recognized fields is defined on the part.
 */
function getRoleFromPart(part: Part): Role {
  if (part.toolRequest !== undefined) {
    return 'model';
  }
  if (part.toolResponse !== undefined) {
    return 'tool';
  }
  // Text, media and data are all treated as user-supplied content.
  const hasUserContent =
    part.text !== undefined ||
    part.media !== undefined ||
    part.data !== undefined;
  if (hasUserContent) {
    return 'user';
  }
  throw new Error('No recognized fields in content');
}

/**
 * Infers the single role shared by every part in a list.
 *
 * Parts are examined in order; as soon as a part implies a role different
 * from the ones seen so far, an error is thrown and later parts are not
 * inspected.
 *
 * @param parts The content parts to inspect.
 * @returns The role common to all parts. For an empty list nothing has been
 *   collected, so the result is `undefined`.
 * @throws Error if the parts imply more than one role, or if
 *   `getRoleFromPart` rejects a part.
 */
function inferRoleFromParts(parts: Part[]): Role {
  const seenRoles: Role[] = [];
  for (const current of parts) {
    const currentRole = getRoleFromPart(current);
    if (!seenRoles.includes(currentRole)) {
      seenRoles.push(currentRole);
    }
    if (seenRoles.length > 1) {
      throw new Error('Contents contain mixed roles');
    }
  }
  const [onlyRole] = seenRoles;
  return onlyRole;
}

export async function toGenerateRequest(
options: GenerateOptions
): Promise<GenerateRequest> {
Expand Down Expand Up @@ -517,29 +492,6 @@ export interface GenerateOptions<
streamingCallback?: StreamingCallback<GenerateResponseChunk>;
}

/**
 * Checks whether every tool request in a candidate's message can be served
 * by the supplied tools.
 *
 * A candidate without tool requests is always valid. Otherwise each requested
 * tool must be present in `tools` (matched by action name) and the request
 * input must pass `validateSchema` against that tool's input schema.
 *
 * @param candidate The candidate whose message content is inspected.
 * @param tools The tool actions available for this generation.
 * @returns `true` if all tool requests are resolvable and valid, else `false`.
 */
const isValidCandidate = (
  candidate: CandidateData,
  tools: Action<any, any>[]
): boolean => {
  for (const part of candidate.message.content) {
    // Parts that carry no tool request have no bearing on validity.
    if (!part.toolRequest) {
      continue;
    }

    // The requested tool must exist among the supplied tools.
    const requestedName = part.toolRequest?.name;
    const matchingTool = tools?.find(
      (candidateTool) => candidateTool.__action.name === requestedName
    );
    if (!matchingTool) {
      return false;
    }

    // The request input must conform to the tool's declared input schema.
    const validation = validateSchema(part.toolRequest?.input, {
      schema: matchingTool.__action.inputSchema,
      jsonSchema: matchingTool.__action.inputJsonSchema,
    });
    if (!validation.valid) {
      return false;
    }
  }
  return true;
};

async function resolveModel(options: GenerateOptions): Promise<ModelAction> {
let model = options.model;
if (!model) {
Expand Down Expand Up @@ -604,7 +556,6 @@ export class NoValidCandidatesError extends GenkitError {
* @param options The options for this generation request.
* @returns The generated response based on the provided parameters.
*/

export async function generate<
O extends z.ZodTypeAny = z.ZodTypeAny,
CustomOptions extends z.ZodTypeAny = typeof GenerationCommonConfigSchema,
Expand All @@ -620,120 +571,47 @@ export async function generate<
throw new Error(`Model ${JSON.stringify(resolvedOptions.model)} not found`);
}

let tools: ToolAction[] | undefined;
if (resolvedOptions.tools?.length) {
if (!model.__action.metadata?.model.supports?.tools) {
// convert tools to action refs (strings).
let tools: (string | ToolDefinition)[] | undefined;
if (resolvedOptions.tools) {
tools = resolvedOptions.tools.map((t) => {
if (typeof t === 'string') {
return `/tool/${t}`;
} else if ((t as Action).__action) {
return `/${(t as Action).__action.metadata?.type}/${(t as Action).__action.name}`;
} else if (t.name) {
return `/tool/${t.name}`;
}
throw new Error(
`Model ${JSON.stringify(resolvedOptions.model)} does not support tools, but some tools were supplied to generate(). Please call generate() without tools if you would like to use this model.`
`Unable to determine type of of tool: ${JSON.stringify(t)}`
);
}
tools = await resolveTools(resolvedOptions.tools);
}

const request = await toGenerateRequest(resolvedOptions);

const accumulatedChunks: GenerateResponseChunkData[] = [];

const response = await runWithStreamingCallback(
resolvedOptions.streamingCallback
? (chunk: GenerateResponseChunkData) => {
// Store accumulated chunk data
accumulatedChunks.push(chunk);
if (resolvedOptions.streamingCallback) {
resolvedOptions.streamingCallback!(
new GenerateResponseChunk(chunk, accumulatedChunks)
);
}
}
: undefined,
async () => new GenerateResponse<z.infer<O>>(await model(request), request)
);

// throw NoValidCandidatesError if no candidate finished with reason 'stop' or 'length'
if (
!response.candidates.some((c) =>
['stop', 'length'].includes(c.finishReason)
)
) {
throw new NoValidCandidatesError({
message: `All candidates returned finishReason issues: ${JSON.stringify(response.candidates.map((c) => c.finishReason))}`,
response,
});
}

if (resolvedOptions.output?.schema || resolvedOptions.output?.jsonSchema) {
// find a candidate with valid output schema
const candidateErrors = response.candidates.map((c) => {
// don't validate messages that have no text or data
if (c.text() === '' && c.data() === null) return null;

try {
parseSchema(c.output(), {
jsonSchema: resolvedOptions.output?.jsonSchema,
schema: resolvedOptions.output?.schema,
});
return null;
} catch (e) {
return e as Error;
}
});
// if all candidates have a non-null error...
if (candidateErrors.every((c) => !!c)) {
throw new NoValidCandidatesError({
message: `Generation resulted in no candidates matching provided output schema.${candidateErrors.map((e, i) => `\n\nCandidate[${i}] ${e!.toString()}`)}`,
response,
detail: {
candidateErrors: candidateErrors,
},
});
}
}

// Pick the first valid candidate.
let selected: Candidate<z.TypeOf<O>> | undefined;
for (const candidate of response.candidates) {
if (isValidCandidate(candidate, tools || [])) {
selected = candidate;
break;
}
}

if (!selected) {
throw new Error('No valid candidates found');
}
const params: z.infer<typeof GenerateUtilParamSchema> = {
model: model.__action.name,
prompt: resolvedOptions.prompt,
context: resolvedOptions.context,
history: resolvedOptions.history,
tools,
candidates: resolvedOptions.candidates,
config: resolvedOptions.config,
output: resolvedOptions.output && {
format: resolvedOptions.output.format,
jsonSchema: resolvedOptions.output.schema
? toJsonSchema({
schema: resolvedOptions.output.schema,
jsonSchema: resolvedOptions.output.jsonSchema,
})
: resolvedOptions.output.jsonSchema,
},
returnToolRequests: resolvedOptions.returnToolRequests,
};

const toolCalls = selected.message.content.filter(
(part) => !!part.toolRequest
);
if (resolvedOptions.returnToolRequests || toolCalls.length === 0) {
return response;
}
const toolResponses: ToolResponsePart[] = await Promise.all(
toolCalls.map(async (part) => {
if (!part.toolRequest) {
throw Error(
'Tool request expected but not provided in tool request part'
);
}
const tool = tools?.find(
(tool) => tool.__action.name === part.toolRequest?.name
);
if (!tool) {
throw Error('Tool not found');
}
return {
toolResponse: {
name: part.toolRequest.name,
ref: part.toolRequest.ref,
output: await tool(part.toolRequest?.input),
},
};
})
return await runWithStreamingCallback(
resolvedOptions.streamingCallback,
async () => new GenerateResponse<O>(await generateAction(params))
);
resolvedOptions.history = request.messages;
resolvedOptions.history.push(selected.message);
resolvedOptions.prompt = toolResponses;
return await generate(resolvedOptions);
}

export type GenerateStreamOptions<
Expand Down
Loading
Loading