Skip to content

Commit

Permalink
refactor: [JS] introduced a generate utility action to make generate …
Browse files Browse the repository at this point in the history
…veneer logic reusable (#759)

Co-authored-by: Michael Bleigh <[email protected]>
  • Loading branch information
pavelgj and mbleigh authored Aug 12, 2024
1 parent 1dba5f0 commit b1e96f1
Show file tree
Hide file tree
Showing 4 changed files with 349 additions and 169 deletions.
210 changes: 44 additions & 166 deletions js/ai/src/generate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,21 @@

import {
Action,
config as genkitConfig,
GenkitError,
runWithStreamingCallback,
StreamingCallback,
config as genkitConfig,
runWithStreamingCallback,
} from '@genkit-ai/core';
import { lookupAction } from '@genkit-ai/core/registry';
import {
parseSchema,
toJsonSchema,
validateSchema,
} from '@genkit-ai/core/schema';
import { toJsonSchema, validateSchema } from '@genkit-ai/core/schema';
import { z } from 'zod';
import { DocumentData } from './document.js';
import { extractJson } from './extract.js';
import {
GenerateUtilParamSchema,
generateAction,
inferRoleFromParts,
} from './generateAction.js';
import {
CandidateData,
GenerateRequest,
Expand All @@ -42,16 +43,11 @@ import {
ModelArgument,
ModelReference,
Part,
Role,
ToolDefinition,
ToolRequestPart,
ToolResponsePart,
} from './model.js';
import {
resolveTools,
ToolAction,
ToolArgument,
toToolDefinition,
} from './tool.js';
import { ToolArgument, resolveTools, toToolDefinition } from './tool.js';

/**
* Message represents a single role's contribution to a generation. Each message
Expand Down Expand Up @@ -425,27 +421,6 @@ export class GenerateResponseChunk<T = unknown>
}
}

function getRoleFromPart(part: Part): Role {
if (part.toolRequest !== undefined) return 'model';
if (part.toolResponse !== undefined) return 'tool';
if (part.text !== undefined) return 'user';
if (part.media !== undefined) return 'user';
if (part.data !== undefined) return 'user';
throw new Error('No recognized fields in content');
}

function inferRoleFromParts(parts: Part[]): Role {
const uniqueRoles = new Set<Role>();
for (const part of parts) {
const role = getRoleFromPart(part);
uniqueRoles.add(role);
if (uniqueRoles.size > 1) {
throw new Error('Contents contain mixed roles');
}
}
return Array.from(uniqueRoles)[0];
}

export async function toGenerateRequest(
options: GenerateOptions
): Promise<GenerateRequest> {
Expand Down Expand Up @@ -517,29 +492,6 @@ export interface GenerateOptions<
streamingCallback?: StreamingCallback<GenerateResponseChunk>;
}

const isValidCandidate = (
candidate: CandidateData,
tools: Action<any, any>[]
): boolean => {
// Check if tool calls are vlaid
const toolCalls = candidate.message.content.filter(
(part) => !!part.toolRequest
);

// make sure every tool called exists and has valid input
return toolCalls.every((toolCall) => {
const tool = tools?.find(
(tool) => tool.__action.name === toolCall.toolRequest?.name
);
if (!tool) return false;
const { valid } = validateSchema(toolCall.toolRequest?.input, {
schema: tool.__action.inputSchema,
jsonSchema: tool.__action.inputJsonSchema,
});
return valid;
});
};

async function resolveModel(options: GenerateOptions): Promise<ModelAction> {
let model = options.model;
if (!model) {
Expand Down Expand Up @@ -604,7 +556,6 @@ export class NoValidCandidatesError extends GenkitError {
* @param options The options for this generation request.
* @returns The generated response based on the provided parameters.
*/

export async function generate<
O extends z.ZodTypeAny = z.ZodTypeAny,
CustomOptions extends z.ZodTypeAny = typeof GenerationCommonConfigSchema,
Expand All @@ -620,120 +571,47 @@ export async function generate<
throw new Error(`Model ${JSON.stringify(resolvedOptions.model)} not found`);
}

let tools: ToolAction[] | undefined;
if (resolvedOptions.tools?.length) {
if (!model.__action.metadata?.model.supports?.tools) {
// convert tools to action refs (strings).
let tools: (string | ToolDefinition)[] | undefined;
if (resolvedOptions.tools) {
tools = resolvedOptions.tools.map((t) => {
if (typeof t === 'string') {
return `/tool/${t}`;
} else if ((t as Action).__action) {
return `/${(t as Action).__action.metadata?.type}/${(t as Action).__action.name}`;
} else if (t.name) {
return `/tool/${t.name}`;
}
throw new Error(
`Model ${JSON.stringify(resolvedOptions.model)} does not support tools, but some tools were supplied to generate(). Please call generate() without tools if you would like to use this model.`
`Unable to determine type of of tool: ${JSON.stringify(t)}`
);
}
tools = await resolveTools(resolvedOptions.tools);
}

const request = await toGenerateRequest(resolvedOptions);

const accumulatedChunks: GenerateResponseChunkData[] = [];

const response = await runWithStreamingCallback(
resolvedOptions.streamingCallback
? (chunk: GenerateResponseChunkData) => {
// Store accumulated chunk data
accumulatedChunks.push(chunk);
if (resolvedOptions.streamingCallback) {
resolvedOptions.streamingCallback!(
new GenerateResponseChunk(chunk, accumulatedChunks)
);
}
}
: undefined,
async () => new GenerateResponse<z.infer<O>>(await model(request), request)
);

// throw NoValidCandidates if all candidates are blocked or
if (
!response.candidates.some((c) =>
['stop', 'length'].includes(c.finishReason)
)
) {
throw new NoValidCandidatesError({
message: `All candidates returned finishReason issues: ${JSON.stringify(response.candidates.map((c) => c.finishReason))}`,
response,
});
}

if (resolvedOptions.output?.schema || resolvedOptions.output?.jsonSchema) {
// find a candidate with valid output schema
const candidateErrors = response.candidates.map((c) => {
// don't validate messages that have no text or data
if (c.text() === '' && c.data() === null) return null;

try {
parseSchema(c.output(), {
jsonSchema: resolvedOptions.output?.jsonSchema,
schema: resolvedOptions.output?.schema,
});
return null;
} catch (e) {
return e as Error;
}
});
// if all candidates have a non-null error...
if (candidateErrors.every((c) => !!c)) {
throw new NoValidCandidatesError({
message: `Generation resulted in no candidates matching provided output schema.${candidateErrors.map((e, i) => `\n\nCandidate[${i}] ${e!.toString()}`)}`,
response,
detail: {
candidateErrors: candidateErrors,
},
});
}
}

// Pick the first valid candidate.
let selected: Candidate<z.TypeOf<O>> | undefined;
for (const candidate of response.candidates) {
if (isValidCandidate(candidate, tools || [])) {
selected = candidate;
break;
}
}

if (!selected) {
throw new Error('No valid candidates found');
}
const params: z.infer<typeof GenerateUtilParamSchema> = {
model: model.__action.name,
prompt: resolvedOptions.prompt,
context: resolvedOptions.context,
history: resolvedOptions.history,
tools,
candidates: resolvedOptions.candidates,
config: resolvedOptions.config,
output: resolvedOptions.output && {
format: resolvedOptions.output.format,
jsonSchema: resolvedOptions.output.schema
? toJsonSchema({
schema: resolvedOptions.output.schema,
jsonSchema: resolvedOptions.output.jsonSchema,
})
: resolvedOptions.output.jsonSchema,
},
returnToolRequests: resolvedOptions.returnToolRequests,
};

const toolCalls = selected.message.content.filter(
(part) => !!part.toolRequest
);
if (resolvedOptions.returnToolRequests || toolCalls.length === 0) {
return response;
}
const toolResponses: ToolResponsePart[] = await Promise.all(
toolCalls.map(async (part) => {
if (!part.toolRequest) {
throw Error(
'Tool request expected but not provided in tool request part'
);
}
const tool = tools?.find(
(tool) => tool.__action.name === part.toolRequest?.name
);
if (!tool) {
throw Error('Tool not found');
}
return {
toolResponse: {
name: part.toolRequest.name,
ref: part.toolRequest.ref,
output: await tool(part.toolRequest?.input),
},
};
})
return await runWithStreamingCallback(
resolvedOptions.streamingCallback,
async () => new GenerateResponse<O>(await generateAction(params))
);
resolvedOptions.history = request.messages;
resolvedOptions.history.push(selected.message);
resolvedOptions.prompt = toolResponses;
return await generate(resolvedOptions);
}

export type GenerateStreamOptions<
Expand Down
Loading

0 comments on commit b1e96f1

Please sign in to comment.