diff --git a/js/ai/src/embedder.ts b/js/ai/src/embedder.ts index 5cd278130..7c24b4835 100644 --- a/js/ai/src/embedder.ts +++ b/js/ai/src/embedder.ts @@ -15,7 +15,7 @@ */ import { Action, defineAction, z } from '@genkit-ai/core'; -import { lookupAction } from '@genkit-ai/core/registry'; +import { Registry } from '@genkit-ai/core/registry'; import { Document, DocumentData, DocumentDataSchema } from './document.js'; export type EmbeddingBatch = { embedding: number[] }[]; @@ -68,6 +68,7 @@ function withMetadata( export function defineEmbedder< ConfigSchema extends z.ZodTypeAny = z.ZodTypeAny, >( + registry: Registry, options: { name: string; configSchema?: ConfigSchema; @@ -76,6 +77,7 @@ export function defineEmbedder< runner: EmbedderFn ) { const embedder = defineAction( + registry, { actionType: 'embedder', name: options.name, @@ -111,13 +113,14 @@ export type EmbedderArgument< * A veneer for interacting with embedder models. */ export async function embed( + registry: Registry, params: EmbedderParams ): Promise { let embedder: EmbedderAction; if (typeof params.embedder === 'string') { - embedder = await lookupAction(`/embedder/${params.embedder}`); + embedder = await registry.lookupAction(`/embedder/${params.embedder}`); } else if (Object.hasOwnProperty.call(params.embedder, 'info')) { - embedder = await lookupAction( + embedder = await registry.lookupAction( `/embedder/${(params.embedder as EmbedderReference).name}` ); } else { @@ -141,17 +144,20 @@ export async function embed( */ export async function embedMany< ConfigSchema extends z.ZodTypeAny = z.ZodTypeAny, ->(params: { - embedder: EmbedderArgument; - content: string[] | DocumentData[]; - metadata?: Record; - options?: z.infer; -}): Promise { +>( + registry: Registry, + params: { + embedder: EmbedderArgument; + content: string[] | DocumentData[]; + metadata?: Record; + options?: z.infer; + } +): Promise { let embedder: EmbedderAction; if (typeof params.embedder === 'string') { - embedder = await lookupAction(`/embedder/${params.embedder}`); + embedder = await registry.lookupAction(`/embedder/${params.embedder}`); } else if (Object.hasOwnProperty.call(params.embedder, 'info')) { - embedder = await lookupAction( + embedder = await registry.lookupAction( `/embedder/${(params.embedder as EmbedderReference).name}` ); } else { diff --git a/js/ai/src/evaluator.ts b/js/ai/src/evaluator.ts index 042e069d9..02be11e48 100644 --- a/js/ai/src/evaluator.ts +++ b/js/ai/src/evaluator.ts @@ -16,7 +16,7 @@ import { Action, defineAction, z } from '@genkit-ai/core'; import { logger } from '@genkit-ai/core/logging'; -import { lookupAction } from '@genkit-ai/core/registry'; +import { Registry } from '@genkit-ai/core/registry'; import { SPAN_TYPE_ATTR, runInNewSpan } from '@genkit-ai/core/tracing'; import { randomUUID } from 'crypto'; @@ -127,6 +127,7 @@ export function defineEvaluator< typeof BaseEvalDataPointSchema = typeof BaseEvalDataPointSchema, EvaluatorOptions extends z.ZodTypeAny = z.ZodTypeAny, >( + registry: Registry, options: { name: string; displayName: string; @@ -143,6 +144,7 @@ export function defineEvaluator< metadata[EVALUATOR_METADATA_KEY_DISPLAY_NAME] = options.displayName; metadata[EVALUATOR_METADATA_KEY_DEFINITION] = options.definition; const evaluator = defineAction( + registry, { actionType: 'evaluator', name: options.name, @@ -239,12 +241,17 @@ export type EvaluatorArgument< export async function evaluate< DataPoint extends typeof BaseDataPointSchema = typeof BaseDataPointSchema, CustomOptions extends z.ZodTypeAny = z.ZodTypeAny, ->(params: EvaluatorParams): Promise { +>( + registry: Registry, + params: EvaluatorParams +): Promise { let evaluator: EvaluatorAction; if (typeof params.evaluator === 'string') { - evaluator = await lookupAction(`/evaluator/${params.evaluator}`); + evaluator = await registry.lookupAction(`/evaluator/${params.evaluator}`); } else if (Object.hasOwnProperty.call(params.evaluator, 'info')) { - evaluator = await lookupAction(`/evaluator/${params.evaluator.name}`); + evaluator = await registry.lookupAction( + `/evaluator/${params.evaluator.name}` + ); } else { evaluator = params.evaluator as EvaluatorAction; } diff --git a/js/ai/src/generate.ts b/js/ai/src/generate.ts index b4e95716a..9527c86da 100755 --- a/js/ai/src/generate.ts +++ b/js/ai/src/generate.ts @@ -21,7 +21,7 @@ import { StreamingCallback, z, } from '@genkit-ai/core'; -import { lookupAction } from '@genkit-ai/core/registry'; +import { Registry } from '@genkit-ai/core/registry'; import { parseSchema, toJsonSchema } from '@genkit-ai/core/schema'; import { DocumentData } from './document.js'; import { extractJson } from './extract.js'; @@ -365,6 +365,7 @@ export class GenerateResponseChunk } export async function toGenerateRequest( + registry: Registry, options: GenerateOptions ): Promise { const messages: MessageData[] = []; @@ -402,7 +403,7 @@ export async function toGenerateRequest( } let tools: Action[] | undefined; if (options.tools) { - tools = await resolveTools(options.tools); + tools = await resolveTools(registry, options.tools); } const out = { @@ -464,21 +465,28 @@ interface ResolvedModel { version?: string; } -async function resolveModel(options: GenerateOptions): Promise { +async function resolveModel( + registry: Registry, + options: GenerateOptions +): Promise { let model = options.model; if (!model) { throw new Error('Model is required.'); } if (typeof model === 'string') { return { - modelAction: (await lookupAction(`/model/${model}`)) as ModelAction, + modelAction: (await registry.lookupAction( + `/model/${model}` + )) as ModelAction, }; } else if (model.hasOwnProperty('__action')) { return { modelAction: model as ModelAction }; } else { const ref = model as ModelReference; return { - modelAction: (await lookupAction(`/model/${ref.name}`)) as ModelAction, + modelAction: (await registry.lookupAction( + `/model/${ref.name}` + )) as ModelAction, config: { ...ref.config, }, @@ -525,13 +533,14 @@ export async function generate< O extends z.ZodTypeAny = z.ZodTypeAny, CustomOptions extends z.ZodTypeAny = typeof GenerationCommonConfigSchema, >( + registry: Registry, options: | GenerateOptions | PromiseLike> ): Promise>> { const resolvedOptions: GenerateOptions = await Promise.resolve(options); - const resolvedModel = await resolveModel(resolvedOptions); + const resolvedModel = await resolveModel(registry, resolvedOptions); const model = resolvedModel.modelAction; if (!model) { let modelId: string; @@ -623,8 +632,8 @@ export async function generate< resolvedOptions.streamingCallback, async () => new GenerateResponse( - await generateHelper(params, resolvedOptions.use), - await toGenerateRequest(resolvedOptions) + await generateHelper(registry, params, resolvedOptions.use), + await toGenerateRequest(registry, resolvedOptions) ) ); } @@ -653,6 +662,7 @@ export async function generateStream< O extends z.ZodTypeAny = z.ZodTypeAny, CustomOptions extends z.ZodTypeAny = typeof GenerationCommonConfigSchema, >( + registry: Registry, options: | GenerateOptions | PromiseLike> @@ -678,7 +688,7 @@ export async function generateStream< } try { - generate({ + generate(registry, { ...options, streamingCallback: (chunk) => { firstChunkSent = true; diff --git a/js/ai/src/generateAction.ts b/js/ai/src/generateAction.ts index 996f77af3..7cb4f2d71 100644 --- a/js/ai/src/generateAction.ts +++ b/js/ai/src/generateAction.ts @@ -21,7 +21,7 @@ import { runWithStreamingCallback, z, } from '@genkit-ai/core'; -import { lookupAction } from '@genkit-ai/core/registry'; +import { Registry } from '@genkit-ai/core/registry'; import { toJsonSchema, validateSchema } from '@genkit-ai/core/schema'; import { runInNewSpan, SPAN_TYPE_ATTR } from '@genkit-ai/core/tracing'; import * as clc from 'colorette'; @@ -70,6 +70,7 @@ export const GenerateUtilParamSchema = z.object({ * Encapsulates all generate logic. This is similar to `generateAction` except not an action and can take middleware. */ export async function generateHelper( + registry: Registry, input: z.infer, middleware?: Middleware[] ): Promise { @@ -86,7 +87,7 @@ export async function generateHelper( async (metadata) => { metadata.name = 'generate'; metadata.input = input; - const output = await generate(input, middleware); + const output = await generate(registry, input, middleware); metadata.output = JSON.stringify(output); return output; } @@ -94,10 +95,11 @@ export async function generateHelper( } async function generate( + registry: Registry, rawRequest: z.infer, middleware?: Middleware[] ): Promise { - const model = (await lookupAction( + const model = (await registry.lookupAction( `/model/${rawRequest.model}` )) as ModelAction; if (!model) { @@ -120,7 +122,7 @@ async function generate( tools = await Promise.all( rawRequest.tools.map(async (toolRef) => { if (typeof toolRef === 'string') { - const tool = (await lookupAction(toolRef)) as ToolAction; + const tool = (await registry.lookupAction(toolRef)) as ToolAction; if (!tool) { throw new Error(`Tool ${toolRef} not found`); } @@ -203,7 +205,7 @@ async function generate( messages: [...request.messages, message], prompt: toolResponses, }; - return await generateHelper(nextRequest, middleware); + return await generateHelper(registry, nextRequest, middleware); } async function actionToGenerateRequest( diff --git a/js/ai/src/model.ts b/js/ai/src/model.ts index 70f4bd549..4881c3915 100644 --- a/js/ai/src/model.ts +++ b/js/ai/src/model.ts @@ -22,6 +22,7 @@ import { StreamingCallback, z, } from '@genkit-ai/core'; +import { Registry } from '@genkit-ai/core/registry'; import { toJsonSchema } from '@genkit-ai/core/schema'; import { performance } from 'node:perf_hooks'; import { DocumentDataSchema } from './document.js'; @@ -330,6 +331,7 @@ export type DefineModelOptions< export function defineModel< CustomOptionsSchema extends z.ZodTypeAny = z.ZodTypeAny, >( + registry: Registry, options: DefineModelOptions, runner: ( request: GenerateRequest, @@ -344,6 +346,7 @@ export function defineModel< if (!options?.supports?.context) middleware.push(augmentWithContext()); middleware.push(conformOutput()); const act = defineAction( + registry, { actionType: 'model', name: options.name, diff --git a/js/ai/src/prompt.ts b/js/ai/src/prompt.ts index 0dda76e00..f497dca23 100644 --- a/js/ai/src/prompt.ts +++ b/js/ai/src/prompt.ts @@ -15,7 +15,7 @@ */ import { Action, defineAction, JSONSchema7, z } from '@genkit-ai/core'; -import { lookupAction } from '@genkit-ai/core/registry'; +import { Registry } from '@genkit-ai/core/registry'; import { DocumentData } from './document.js'; import { GenerateOptions, @@ -150,10 +150,12 @@ export interface ExecutablePrompt< * @returns The new `PromptAction`. */ export function definePrompt( + registry: Registry, config: PromptConfig, fn: PromptFn ): PromptAction { const a = defineAction( + registry, { ...config, actionType: 'prompt', @@ -177,16 +179,19 @@ export async function renderPrompt< I extends z.ZodTypeAny = z.ZodTypeAny, O extends z.ZodTypeAny = z.ZodTypeAny, CustomOptions extends z.ZodTypeAny = z.ZodTypeAny, ->(params: { - prompt: PromptArgument; - input: z.infer; - docs?: DocumentData[]; - model: ModelArgument; - config?: z.infer; -}): Promise> { +>( + registry: Registry, + params: { + prompt: PromptArgument; + input: z.infer; + docs?: DocumentData[]; + model: ModelArgument; + config?: z.infer; + } +): Promise> { let prompt: PromptAction; if (typeof params.prompt === 'string') { - prompt = await lookupAction(`/prompt/${params.prompt}`); + prompt = await registry.lookupAction(`/prompt/${params.prompt}`); } else { prompt = params.prompt as PromptAction; } diff --git a/js/ai/src/reranker.ts b/js/ai/src/reranker.ts index 54428d0cb..35d3b2505 100644 --- a/js/ai/src/reranker.ts +++ b/js/ai/src/reranker.ts @@ -15,7 +15,7 @@ */ import { Action, defineAction, z } from '@genkit-ai/core'; -import { lookupAction } from '@genkit-ai/core/registry'; +import { Registry } from '@genkit-ai/core/registry'; import { Part, PartSchema } from './document.js'; import { Document, DocumentData, DocumentDataSchema } from './retriever.js'; @@ -101,6 +101,7 @@ function rerankerWithMetadata< * Creates a reranker action for the provided {@link RerankerFn} implementation. */ export function defineReranker( + registry: Registry, options: { name: string; configSchema?: OptionsType; @@ -109,6 +110,7 @@ export function defineReranker( runner: RerankerFn ) { const reranker = defineAction( + registry, { actionType: 'reranker', name: options.name, @@ -157,13 +159,14 @@ export type RerankerArgument< * Reranks documents from a {@link RerankerArgument} based on the provided query. */ export async function rerank( + registry: Registry, params: RerankerParams ): Promise> { let reranker: RerankerAction; if (typeof params.reranker === 'string') { - reranker = await lookupAction(`/reranker/${params.reranker}`); + reranker = await registry.lookupAction(`/reranker/${params.reranker}`); } else if (Object.hasOwnProperty.call(params.reranker, 'info')) { - reranker = await lookupAction(`/reranker/${params.reranker.name}`); + reranker = await registry.lookupAction(`/reranker/${params.reranker.name}`); } else { reranker = params.reranker as RerankerAction; } diff --git a/js/ai/src/retriever.ts b/js/ai/src/retriever.ts index 0d3f23689..0623e297f 100644 --- a/js/ai/src/retriever.ts +++ b/js/ai/src/retriever.ts @@ -15,7 +15,7 @@ */ import { Action, GenkitError, defineAction, z } from '@genkit-ai/core'; -import { lookupAction } from '@genkit-ai/core/registry'; +import { Registry } from '@genkit-ai/core/registry'; import { Document, DocumentData, DocumentDataSchema } from './document.js'; import { EmbedderInfo } from './embedder.js'; @@ -111,6 +111,7 @@ function indexerWithMetadata< export function defineRetriever< OptionsType extends z.ZodTypeAny = z.ZodTypeAny, >( + registry: Registry, options: { name: string; configSchema?: OptionsType; @@ -119,6 +120,7 @@ export function defineRetriever< runner: RetrieverFn ) { const retriever = defineAction( + registry, { actionType: 'retriever', name: options.name, @@ -149,6 +151,7 @@ export function defineRetriever< * Creates an indexer action for the provided {@link IndexerFn} implementation. */ export function defineIndexer( + registry: Registry, options: { name: string; embedderInfo?: EmbedderInfo; @@ -157,6 +160,7 @@ export function defineIndexer( runner: IndexerFn ) { const indexer = defineAction( + registry, { actionType: 'indexer', name: options.name, @@ -200,13 +204,16 @@ export type RetrieverArgument< * Retrieves documents from a {@link RetrieverArgument} based on the provided query. */ export async function retrieve( + registry: Registry, params: RetrieverParams ): Promise> { let retriever: RetrieverAction; if (typeof params.retriever === 'string') { - retriever = await lookupAction(`/retriever/${params.retriever}`); + retriever = await registry.lookupAction(`/retriever/${params.retriever}`); } else if (Object.hasOwnProperty.call(params.retriever, 'info')) { - retriever = await lookupAction(`/retriever/${params.retriever.name}`); + retriever = await registry.lookupAction( + `/retriever/${params.retriever.name}` + ); } else { retriever = params.retriever as RetrieverAction; } @@ -239,13 +246,14 @@ export interface IndexerParams< * Indexes documents using a {@link IndexerArgument}. */ export async function index( + registry: Registry, params: IndexerParams ): Promise { let indexer: IndexerAction; if (typeof params.indexer === 'string') { - indexer = await lookupAction(`/indexer/${params.indexer}`); + indexer = await registry.lookupAction(`/indexer/${params.indexer}`); } else if (Object.hasOwnProperty.call(params.indexer, 'info')) { - indexer = await lookupAction(`/indexer/${params.indexer.name}`); + indexer = await registry.lookupAction(`/indexer/${params.indexer.name}`); } else { indexer = params.indexer as IndexerAction; } @@ -381,10 +389,12 @@ export function defineSimpleRetriever< C extends z.ZodTypeAny = z.ZodTypeAny, R = any, >( + registry: Registry, options: SimpleRetrieverOptions, handler: (query: Document, config: z.infer) => Promise ) { return defineRetriever( + registry, { name: options.name, configSchema: options.configSchema, diff --git a/js/ai/src/testing/model-tester.ts b/js/ai/src/testing/model-tester.ts index 3cf041ae9..7caa4b0cc 100644 --- a/js/ai/src/testing/model-tester.ts +++ b/js/ai/src/testing/model-tester.ts @@ -15,7 +15,7 @@ */ import { z } from '@genkit-ai/core'; -import { lookupAction } from '@genkit-ai/core/registry'; +import { Registry } from '@genkit-ai/core/registry'; import { runInNewSpan } from '@genkit-ai/core/tracing'; import assert from 'node:assert'; import { generate } from '../generate'; @@ -23,8 +23,8 @@ import { ModelAction } from '../model'; import { defineTool } from '../tool'; const tests: Record = { - 'basic hi': async (model: string) => { - const response = await generate({ + 'basic hi': async (registry: Registry, model: string) => { + const response = await generate(registry, { model, prompt: 'just say "Hi", literally', }); @@ -32,14 +32,14 @@ const tests: Record = { const got = response.text.trim(); assert.match(got, /Hi/i); }, - multimodal: async (model: string) => { - const resolvedModel = (await lookupAction( + multimodal: async (registry: Registry, model: string) => { + const resolvedModel = (await registry.lookupAction( `/model/${model}` )) as ModelAction; if (!resolvedModel.__action.metadata?.model.supports?.media) { skip(); } - const response = await generate({ + const response = await generate(registry, { model, prompt: [ { @@ -57,18 +57,18 @@ const tests: Record = { const got = response.text.trim(); assert.match(got, /plus/i); }, - history: async (model: string) => { - const resolvedModel = (await lookupAction( + history: async (registry: Registry, model: string) => { + const resolvedModel = (await registry.lookupAction( `/model/${model}` )) as ModelAction; if (!resolvedModel.__action.metadata?.model.supports?.multiturn) { skip(); } - const response1 = await generate({ + const response1 = await generate(registry, { model, prompt: 'My name is Glorb', }); - const response = await generate({ + const response = await generate(registry, { model, prompt: "What's my name?", messages: response1.messages, @@ -77,8 +77,8 @@ const tests: Record = { const got = response.text.trim(); assert.match(got, /Glorb/); }, - 'system prompt': async (model: string) => { - const { text } = await generate({ + 'system prompt': async (registry: Registry, model: string) => { + const { text } = await generate(registry, { model, prompt: 'Hi', messages: [ @@ -97,8 +97,8 @@ const tests: Record = { const got = text.trim(); assert.equal(got, want); }, - 'structured output': async (model: string) => { - const response = await generate({ + 'structured output': async (registry: Registry, model: string) => { + const response = await generate(registry, { model, prompt: 'extract data as json from: Jack was a Lumberjack', output: { @@ -117,15 +117,15 @@ const tests: Record = { const got = response.output; assert.deepEqual(want, got); }, - 'tool calling': async (model: string) => { - const resolvedModel = (await lookupAction( + 'tool calling': async (registry: Registry, model: string) => { + const resolvedModel = (await registry.lookupAction( `/model/${model}` )) as ModelAction; if (!resolvedModel.__action.metadata?.model.supports?.tools) { skip(); } - const { text } = await generate({ + const { text } = await generate(registry, { model, prompt: 'what is a gablorken of 2? use provided tool', tools: ['gablorkenTool'], @@ -149,10 +149,14 @@ type TestReport = { }[]; }[]; -type TestCase = (model: string) => Promise; +type TestCase = (ai: Registry, model: string) => Promise; -export async function testModels(models: string[]): Promise { +export async function testModels( + registry: Registry, + models: string[] +): Promise { const gablorkenTool = defineTool( + registry, { name: 'gablorkenTool', description: 'use when need to calculate a gablorken', @@ -182,7 +186,7 @@ export async function testModels(models: string[]): Promise { }); const modelReport = caseReport.models[caseReport.models.length - 1]; try { - await tests[test](model); + await tests[test](registry, model); } catch (e) { modelReport.passed = false; if (e instanceof SkipTestError) { diff --git a/js/ai/src/tool.ts b/js/ai/src/tool.ts index 9dcb61c4f..a0d85340c 100644 --- a/js/ai/src/tool.ts +++ b/js/ai/src/tool.ts @@ -15,7 +15,7 @@ */ import { Action, defineAction, JSONSchema7, z } from '@genkit-ai/core'; -import { lookupAction } from '@genkit-ai/core/registry'; +import { Registry } from '@genkit-ai/core/registry'; import { toJsonSchema } from '@genkit-ai/core/schema'; import { setCustomMetadataAttributes } from '@genkit-ai/core/tracing'; import { ToolDefinition } from './model.js'; @@ -89,11 +89,11 @@ export function asTool( export async function resolveTools< O extends z.ZodTypeAny = z.ZodTypeAny, CustomOptions extends z.ZodTypeAny = z.ZodTypeAny, ->(tools: ToolArgument[] = []): Promise { +>(registry: Registry, tools: ToolArgument[] = []): Promise { return await Promise.all( tools.map(async (ref): Promise => { if (typeof ref === 'string') { - const tool = await lookupAction(`/tool/${ref}`); + const tool = await registry.lookupAction(`/tool/${ref}`); if (!tool) { throw new Error(`Tool ${ref} not found`); } @@ -101,7 +101,7 @@ export async function resolveTools< } else if ((ref as Action).__action) { return asTool(ref as Action); } else if (ref.name) { - const tool = await lookupAction(`/tool/${ref.name}`); + const tool = await registry.lookupAction(`/tool/${ref.name}`); if (!tool) { throw new Error(`Tool ${ref} not found`); } @@ -137,10 +137,12 @@ export function toToolDefinition( * A tool is an action that can be passed to a model to be called automatically if it so chooses. */ export function defineTool( + registry: Registry, config: ToolConfig, fn: (input: z.infer) => Promise> ): ToolAction { const a = defineAction( + registry, { ...config, actionType: 'tool', diff --git a/js/ai/tests/generate/generate_test.ts b/js/ai/tests/generate/generate_test.ts index cae8e419c..9a02b0f6e 100644 --- a/js/ai/tests/generate/generate_test.ts +++ b/js/ai/tests/generate/generate_test.ts @@ -15,7 +15,7 @@ */ import { z } from '@genkit-ai/core'; -import { Registry, runWithRegistry } from '@genkit-ai/core/registry'; +import { Registry } from '@genkit-ai/core/registry'; import { toJsonSchema } from '@genkit-ai/core/schema'; import assert from 'node:assert'; import { beforeEach, describe, it } from 'node:test'; @@ -262,19 +262,18 @@ describe('GenerateResponse', () => { describe('toGenerateRequest', () => { const registry = new Registry(); // register tools - const tellAFunnyJoke = runWithRegistry(registry, () => - defineTool( - { - name: 'tellAFunnyJoke', - description: - 'Tells jokes about an input topic. Use this tool whenever user asks you to tell a joke.', - inputSchema: z.object({ topic: z.string() }), - outputSchema: z.string(), - }, - async (input) => { - return `Why did the ${input.topic} cross the road?`; - } - ) + const tellAFunnyJoke = defineTool( + registry, + { + name: 'tellAFunnyJoke', + description: + 'Tells jokes about an input topic. Use this tool whenever user asks you to tell a joke.', + inputSchema: z.object({ topic: z.string() }), + outputSchema: z.string(), + }, + async (input) => { + return `Why did the ${input.topic} cross the road?`; + } ); const testCases = [ @@ -442,9 +441,7 @@ describe('toGenerateRequest', () => { for (const test of testCases) { it(test.should, async () => { assert.deepStrictEqual( - await runWithRegistry(registry, () => - toGenerateRequest(test.prompt as GenerateOptions) - ), + await toGenerateRequest(registry, test.prompt as GenerateOptions), test.expectedOutput ); }); @@ -530,29 +527,28 @@ describe('generate', () => { beforeEach(() => { registry = new Registry(); - echoModel = runWithRegistry(registry, () => - defineModel( - { - name: 'echoModel', - }, - async (request) => { - return { - message: { - role: 'model', - content: [ - { - text: - 'Echo: ' + - request.messages - .map((m) => m.content.map((c) => c.text).join()) - .join(), - }, - ], - }, - finishReason: 'stop', - }; - } - ) + echoModel = defineModel( + registry, + { + name: 'echoModel', + }, + async (request) => { + return { + message: { + role: 'model', + content: [ + { + text: + 'Echo: ' + + request.messages + .map((m) => m.content.map((c) => c.text).join()) + .join(), + }, + ], + }, + finishReason: 'stop', + }; + } ); }); @@ -592,14 +588,11 @@ describe('generate', () => { }; }; - const response = await runWithRegistry(registry, () => - generate({ - prompt: 'banana', - model: echoModel, - use: [wrapRequest, wrapResponse], - }) - ); - + const response = await generate(registry, { + prompt: 'banana', + model: echoModel, + use: [wrapRequest, wrapResponse], + }); const want = '[Echo: (banana)]'; assert.deepStrictEqual(response.text, want); }); @@ -609,24 +602,21 @@ describe('generate', () => { let registry: Registry; beforeEach(() => { registry = new Registry(); - runWithRegistry(registry, () => - defineModel( - { name: 'echo', supports: { tools: true } }, - async (input) => ({ - message: input.messages[0], - finishReason: 'stop', - }) - ) + + defineModel( + registry, + { name: 'echo', supports: { tools: true } }, + async (input) => ({ + message: input.messages[0], + finishReason: 'stop', + }) ); }); it('should preserve the request in the returned response, enabling .messages', async () => { - const response = await runWithRegistry(registry, () => - generate({ - model: 'echo', - prompt: 'Testing messages', - }) - ); - + const response = await generate(registry, { + model: 'echo', + prompt: 'Testing messages', + }); assert.deepEqual( response.messages.map((m) => m.content[0].text), ['Testing messages', 'Testing messages'] diff --git a/js/ai/tests/model/middleware_test.ts b/js/ai/tests/model/middleware_test.ts index 9b3eb7054..3c9aaaffa 100644 --- a/js/ai/tests/model/middleware_test.ts +++ b/js/ai/tests/model/middleware_test.ts @@ -14,7 +14,7 @@ * limitations under the License. */ -import { Registry, runWithRegistry } from '@genkit-ai/core/registry'; +import { Registry } from '@genkit-ai/core/registry'; import assert from 'node:assert'; import { beforeEach, describe, it } from 'node:test'; import { DocumentData } from '../../src/document.js'; @@ -147,24 +147,21 @@ describe('validateSupport', () => { }); const registry = new Registry(); -const echoModel = runWithRegistry(registry, () => - defineModel({ name: 'echo' }, async (req) => { - return { - finishReason: 'stop', - message: { - role: 'model', - content: [{ data: req }], - }, - }; - }) -); - +const echoModel = defineModel(registry, { name: 'echo' }, async (req) => { + return { + finishReason: 'stop', + message: { + role: 'model', + content: [{ data: req }], + }, + }; +}); describe('conformOutput (default middleware)', () => { const schema = { type: 'object', properties: { test: { type: 'boolean' } } }; // return the output tagged part from the request async function testRequest(req: GenerateRequest): Promise { - const response = await runWithRegistry(registry, () => echoModel(req)); + const response = await echoModel(req); const treq = response.message!.content[0].data as GenerateRequest; const lastUserMessage = treq.messages diff --git a/js/ai/tests/prompt/prompt_test.ts b/js/ai/tests/prompt/prompt_test.ts index c35c951c6..702f85444 100644 --- a/js/ai/tests/prompt/prompt_test.ts +++ b/js/ai/tests/prompt/prompt_test.ts @@ -15,7 +15,7 @@ */ import { z } from '@genkit-ai/core'; -import { Registry, runWithRegistry } from '@genkit-ai/core/registry'; +import { Registry } from '@genkit-ai/core/registry'; import assert from 'node:assert'; import { describe, it } from 'node:test'; import { definePrompt, renderPrompt } from '../../src/prompt.ts'; @@ -23,38 +23,37 @@ import { definePrompt, renderPrompt } from '../../src/prompt.ts'; describe('prompt', () => { let registry = new Registry(); describe('render()', () => { - runWithRegistry(registry, () => { - it('respects output schema in the definition', async () => { - const schema1 = z.object({ - puppyName: z.string({ description: 'A cute name for a puppy' }), - }); - const prompt1 = definePrompt( - { - name: 'prompt1', - inputSchema: z.string({ description: 'Dog breed' }), - }, - async (breed) => { - return { - messages: [ - { - role: 'user', - content: [{ text: `Pick a name for a ${breed} puppy` }], - }, - ], - output: { - format: 'json', - schema: schema1, + it('respects output schema in the definition', async () => { + const schema1 = z.object({ + puppyName: z.string({ description: 'A cute name for a puppy' }), + }); + const prompt1 = definePrompt( + registry, + { + name: 'prompt1', + inputSchema: z.string({ description: 'Dog breed' }), + }, + async (breed) => { + return { + messages: [ + { + role: 'user', + content: [{ text: `Pick a name for a ${breed} puppy` }], }, - }; - } - ); - const generateRequest = await renderPrompt({ - prompt: prompt1, - input: 'poodle', - model: 'geminiPro', - }); - assert.equal(generateRequest.output?.schema, schema1); + ], + output: { + format: 'json', + schema: schema1, + }, + }; + } + ); + const generateRequest = await renderPrompt(registry, { + prompt: prompt1, + input: 'poodle', + model: 'geminiPro', }); + assert.equal(generateRequest.output?.schema, schema1); }); }); }); diff --git a/js/ai/tests/reranker/reranker_test.ts b/js/ai/tests/reranker/reranker_test.ts index 4942e02b6..63a8b25e4 100644 --- a/js/ai/tests/reranker/reranker_test.ts +++ b/js/ai/tests/reranker/reranker_test.ts @@ -15,7 +15,7 @@ */ import { GenkitError, z } from '@genkit-ai/core'; -import { Registry, runWithRegistry } from '@genkit-ai/core/registry'; +import { Registry } from '@genkit-ai/core/registry'; import assert from 'node:assert'; import { beforeEach, describe, it } from 'node:test'; import { defineReranker, rerank } from '../../src/reranker'; @@ -28,34 +28,32 @@ describe('reranker', () => { registry = new Registry(); }); it('reranks documents based on custom logic', async () => { - const customReranker = runWithRegistry(registry, () => - defineReranker( - { - name: 'reranker', - configSchema: z.object({ - k: z.number().optional(), - }), - }, - async (query, documents, options) => { - // Custom reranking logic: score based on string length similarity to query - const queryLength = query.text.length; - const rerankedDocs = documents.map((doc) => { - const score = Math.abs(queryLength - doc.text.length); - return { - ...doc, - metadata: { ...doc.metadata, score }, - }; - }); - + const customReranker = defineReranker( + registry, + { + name: 'reranker', + configSchema: z.object({ + k: z.number().optional(), + }), + }, + async (query, documents, options) => { + // Custom reranking logic: score based on string length similarity to query + const queryLength = query.text.length; + const rerankedDocs = documents.map((doc) => { + const score = Math.abs(queryLength - doc.text.length); return { - documents: rerankedDocs - .sort((a, b) => a.metadata.score - b.metadata.score) - .slice(0, options.k || 3), + ...doc, + metadata: { ...doc.metadata, score }, }; - } - ) + }); + + return { + documents: rerankedDocs + .sort((a, b) => a.metadata.score - b.metadata.score) + .slice(0, options.k || 3), + }; + } ); - // Sample documents for testing const documents = [ Document.fromText('short'), @@ -64,15 +62,12 @@ describe('reranker', () => { ]; const query = Document.fromText('medium length'); - const rerankedDocuments = await runWithRegistry(registry, () => - rerank({ - reranker: customReranker, - query, - documents, - options: { k: 2 }, - }) - ); - + const rerankedDocuments = await rerank(registry, { + reranker: customReranker, + query, + documents, + options: { k: 2 }, + }); // Validate the reranked results assert.equal(rerankedDocuments.length, 2); assert(rerankedDocuments[0].text.includes('a bit longer')); @@ -80,85 +75,76 @@ describe('reranker', () => { }); it('handles missing options gracefully', async () => { - const customReranker = runWithRegistry(registry, () => - defineReranker( - { - name: 'reranker', - configSchema: z.object({ - k: z.number().optional(), - }), - }, - async (query, documents, options) => { - const rerankedDocs = documents.map((doc) => { - const score = Math.random(); // Simplified scoring for testing - return { - ...doc, - metadata: { ...doc.metadata, score }, - }; - }); - + const customReranker = defineReranker( + registry, + { + name: 'reranker', + configSchema: z.object({ + k: z.number().optional(), + }), + }, + async (query, documents, options) => { + const rerankedDocs = documents.map((doc) => { + const score = Math.random(); // Simplified scoring for testing return { - documents: rerankedDocs.sort( - (a, b) => b.metadata.score - a.metadata.score - ), + ...doc, + metadata: { ...doc.metadata, score }, }; - } - ) + }); + + return { + documents: rerankedDocs.sort( + (a, b) => b.metadata.score - a.metadata.score + ), + }; + } ); - const documents = [Document.fromText('doc1'), Document.fromText('doc2')]; const query = Document.fromText('test query'); - const rerankedDocuments = await runWithRegistry(registry, () => - rerank({ - reranker: customReranker, - query, - documents, - options: { k: 2 }, - }) - ); - + const rerankedDocuments = await rerank(registry, { + reranker: customReranker, + query, + documents, + options: { k: 2 }, + }); assert.equal(rerankedDocuments.length, 2); assert(typeof rerankedDocuments[0].metadata.score === 'number'); }); it('validates config schema and throws error on invalid input', async () => { - const customReranker = runWithRegistry(registry, () => - defineReranker( - { - name: 'reranker', - configSchema: z.object({ - k: z.number().min(1), - }), - }, - async (query, documents, options) => { - // Simplified scoring for testing - const rerankedDocs = documents.map((doc) => ({ - ...doc, - metadata: { score: Math.random() }, - })); - return { - documents: rerankedDocs.sort( - (a, b) => b.metadata.score - a.metadata.score - ), - }; - } - ) + const customReranker = defineReranker( + registry, + { + name: 'reranker', + configSchema: z.object({ + k: z.number().min(1), + }), + }, + async (query, documents, options) => { + // Simplified scoring for testing + const rerankedDocs = documents.map((doc) => ({ + ...doc, + metadata: { score: Math.random() }, + })); + return { + documents: rerankedDocs.sort( + (a, b) => b.metadata.score - a.metadata.score + ), + }; + } ); - const documents = [Document.fromText('doc1')]; const query = Document.fromText('test query'); try { - await runWithRegistry(registry, () => - rerank({ - reranker: customReranker, - query, - documents, - options: { k: 0 }, // Invalid input: k must be at least 1 - }) - ); + await rerank(registry, { + reranker: customReranker, + query, + documents, + options: { k: 0 }, // Invalid input: k must be at least 1 + }); assert.fail('Expected validation error'); } catch (err) { assert(err instanceof GenkitError); @@ -167,71 +153,62 @@ describe('reranker', () => { }); it('preserves document metadata after reranking', async () => { - const customReranker = runWithRegistry(registry, () => - defineReranker( - { - name: 'reranker', - }, - async (query, documents) => { - const rerankedDocs = documents.map((doc, i) => ({ - ...doc, - metadata: { ...doc.metadata, score: 2 - i }, - })); - - return { - documents: rerankedDocs.sort( - (a, b) => b.metadata.score - a.metadata.score - ), - }; - } - ) + const customReranker = defineReranker( + registry, + { + name: 'reranker', + }, + async (query, documents) => { + const rerankedDocs = documents.map((doc, i) => ({ + ...doc, + metadata: { ...doc.metadata, score: 2 - i }, + })); + + return { + documents: rerankedDocs.sort( + (a, b) => b.metadata.score - a.metadata.score + ), + }; + } ); - const documents = [ new Document({ content: [], metadata: { originalField: 'test1' } }), new Document({ content: [], metadata: { originalField: 'test2' } }), ]; const query = Document.fromText('test query'); - const rerankedDocuments = await runWithRegistry(registry, () => - rerank({ - reranker: customReranker, - query, - documents, - }) - ); - + const rerankedDocuments = await rerank(registry, { + reranker: customReranker, + query, + documents, + }); assert.equal(rerankedDocuments[0].metadata.originalField, 'test1'); assert.equal(rerankedDocuments[1].metadata.originalField, 'test2'); }); it('handles errors thrown by the reranker', async () => { - const customReranker = runWithRegistry(registry, () => - defineReranker( - { - name: 'reranker', - }, - async (query, documents) => { - // Simulate an error in the reranker logic - throw new GenkitError({ - message: 'Something went wrong during reranking', - status: 'INTERNAL', - }); - } - ) + const customReranker = defineReranker( + registry, + { + name: 'reranker', + }, + async (query, documents) => { + // Simulate an error in the reranker logic + throw new GenkitError({ + message: 'Something went wrong during reranking', + status: 'INTERNAL', + }); + } ); - const documents = [Document.fromText('doc1'), Document.fromText('doc2')]; const query = Document.fromText('test query'); try { - await runWithRegistry(registry, () => - rerank({ - reranker: customReranker, - query, - documents, - }) - ); + await rerank(registry, { + reranker: customReranker, + query, + documents, + }); assert.fail('Expected an error to be thrown'); } catch (err) { assert(err instanceof GenkitError); diff --git a/js/core/src/action.ts b/js/core/src/action.ts index 1bf8f7a1b..382b80d14 100644 --- a/js/core/src/action.ts +++ b/js/core/src/action.ts @@ -17,12 +17,7 @@ import { JSONSchema7 } from 'json-schema'; import { AsyncLocalStorage } from 'node:async_hooks'; import * as z from 'zod'; -import { - ActionType, - initializeAllPlugins, - lookupPlugin, - registerAction, -} from './registry.js'; +import { ActionType, Registry } from './registry.js'; import { parseSchema } from './schema.js'; import { SPAN_TYPE_ATTR, @@ -122,8 +117,8 @@ export function action< ): Action { const actionName = typeof config.name === 'string' - ? validateActionName(config.name) - : `${config.name.pluginId}/${validateActionId(config.name.actionId)}`; + ? config.name + : `${config.name.pluginId}/${config.name.actionId}`; const actionFn = async (input: I) => { input = parseSchema(input, { schema: config.inputSchema, @@ -168,16 +163,16 @@ export function action< return actionFn; } -function validateActionName(name: string) { +function validateActionName(registry: Registry, name: string) { if (name.includes('/')) { - validatePluginName(name.split('/', 1)[0]); + validatePluginName(registry, name.split('/', 1)[0]); validateActionId(name.substring(name.indexOf('/') + 1)); } return name; } -function validatePluginName(pluginId: string) { - if (!lookupPlugin(pluginId)) { +function validatePluginName(registry: Registry, pluginId: string) { + if (!registry.lookupPlugin(pluginId)) { throw new Error( `Unable to find plugin name used in the action name: ${pluginId}` ); @@ -200,6 +195,7 @@ export function defineAction< O extends z.ZodTypeAny, M extends Record = Record, >( + registry: Registry, config: ActionParams & { actionType: ActionType; }, @@ -211,13 +207,18 @@ export function defineAction< 'See: https://github.com/firebase/genkit/blob/main/docs/errors/no_new_actions_at_runtime.md' ); } + if (typeof config.name === 'string') { + validateActionName(registry, config.name); + } else { + validateActionId(config.name.actionId); + } const act = action(config, async (i: I): Promise> => { setCustomMetadataAttributes({ subtype: config.actionType }); - await initializeAllPlugins(); + await registry.initializeAllPlugins(); return await runInActionRuntimeContext(() => fn(i)); }); act.__action.actionType = config.actionType; - registerAction(config.actionType, act); + registry.registerAction(config.actionType, act); return act; } diff --git a/js/core/src/flow.ts b/js/core/src/flow.ts index 0061e0cde..107459585 100644 --- a/js/core/src/flow.ts +++ b/js/core/src/flow.ts @@ -31,12 +31,7 @@ import { runWithAuthContext } from './auth.js'; import { getErrorMessage, getErrorStack } from './error.js'; import { FlowActionInputSchema } from './flowTypes.js'; import { logger } from './logging.js'; -import { - getRegistryInstance, - initializeAllPlugins, - Registry, - runWithRegistry, -} from './registry.js'; +import { Registry } from './registry.js'; import { toJsonSchema } from './schema.js'; import { newTrace, @@ -181,6 +176,7 @@ export class Flow< readonly flowFn: FlowFn; constructor( + private registry: Registry, config: FlowConfig | StreamingFlowConfig, flowFn: FlowFn ) { @@ -207,7 +203,7 @@ export class Flow< auth?: unknown; } ): Promise>> { - await initializeAllPlugins(); + await this.registry.initializeAllPlugins(); return await runWithAuthContext(opts.auth, () => newTrace( { @@ -336,84 +332,79 @@ export class Flow< } async expressHandler( - registry: Registry, request: __RequestWithAuth, response: express.Response ): Promise { - await runWithRegistry(registry, async () => { - const { stream } = request.query; - const auth = request.auth; - - let input = request.body.data; + const { stream } = request.query; + const auth = request.auth; + + let input = request.body.data; + + try { + await this.authPolicy?.(auth, input); + } catch (e: any) { + const respBody = { + error: { + status: 'PERMISSION_DENIED', + message: e.message || 'Permission denied to resource', + }, + }; + response.status(403).send(respBody).end(); + return; + } + if (stream === 'true') { + response.writeHead(200, { + 'Content-Type': 'text/plain', + 'Transfer-Encoding': 'chunked', + }); try { - await this.authPolicy?.(auth, input); - } catch (e: any) { - const respBody = { + const result = await this.invoke(input, { + streamingCallback: ((chunk: z.infer) => { + response.write(JSON.stringify(chunk) + streamDelimiter); + }) as S extends z.ZodVoid ? undefined : StreamingCallback>, + auth, + }); + response.write({ + result: result.result, // Need more results!!!! + }); + response.end(); + } catch (e) { + response.write({ error: { - status: 'PERMISSION_DENIED', - message: e.message || 'Permission denied to resource', + status: 'INTERNAL', + message: getErrorMessage(e), + details: getErrorStack(e), }, - }; - response.status(403).send(respBody).end(); - return; - } - - if (stream === 'true') { - response.writeHead(200, { - 'Content-Type': 'text/plain', - 'Transfer-Encoding': 'chunked', }); - try { - const result = await this.invoke(input, { - streamingCallback: ((chunk: z.infer) => { - response.write(JSON.stringify(chunk) + streamDelimiter); - }) as S extends z.ZodVoid - ? undefined - : StreamingCallback>, - auth, - }); - response.write({ - result: result.result, // Need more results!!!! - }); - response.end(); - } catch (e) { - response.write({ + response.end(); + } + } else { + try { + const result = await this.invoke(input, { auth }); + response.setHeader('x-genkit-trace-id', result.traceId); + response.setHeader('x-genkit-span-id', result.spanId); + // Responses for non-streaming flows are passed back with the flow result stored in a field called "result." + response + .status(200) + .send({ + result: result.result, + }) + .end(); + } catch (e) { + // Errors for non-streaming flows are passed back as standard API errors. + response + .status(500) + .send({ error: { status: 'INTERNAL', message: getErrorMessage(e), details: getErrorStack(e), }, - }); - response.end(); - } - } else { - try { - const result = await this.invoke(input, { auth }); - response.setHeader('x-genkit-trace-id', result.traceId); - response.setHeader('x-genkit-span-id', result.spanId); - // Responses for non-streaming flows are passed back with the flow result stored in a field called "result." - response - .status(200) - .send({ - result: result.result, - }) - .end(); - } catch (e) { - // Errors for non-streaming flows are passed back as standard API errors. - response - .status(500) - .send({ - error: { - status: 'INTERNAL', - message: getErrorMessage(e), - details: getErrorStack(e), - }, - }) - .end(); - } + }) + .end(); } - }); + } } } @@ -496,9 +487,7 @@ export class FlowServer { flow.middleware?.forEach((middleware) => server.post(flowPath, middleware) ); - server.post(flowPath, (req, res) => - flow.expressHandler(this.registry, req, res) - ); + server.post(flowPath, (req, res) => flow.expressHandler(req, res)); }); } else { logger.warn('No flows registered in flow server.'); @@ -557,17 +546,17 @@ export function defineFlow< I extends z.ZodTypeAny = z.ZodTypeAny, O extends z.ZodTypeAny = z.ZodTypeAny, >( + registry: Registry, config: FlowConfig | string, fn: FlowFn ): CallableFlow { const resolvedConfig: FlowConfig = typeof config === 'string' ? { name: config } : config; - const flow = new Flow(resolvedConfig, fn); - registerFlowAction(flow); - const registry = getRegistryInstance(); + const flow = new Flow(registry, resolvedConfig, fn); + registerFlowAction(registry, flow); const callableFlow: CallableFlow = async (input, opts) => { - return runWithRegistry(registry, () => flow.run(input, opts)); + return flow.run(input, opts); }; callableFlow.flow = flow; return callableFlow; @@ -581,14 +570,14 @@ export function defineStreamingFlow< O extends z.ZodTypeAny = z.ZodTypeAny, S extends z.ZodTypeAny = z.ZodTypeAny, >( + registry: Registry, config: StreamingFlowConfig, fn: FlowFn ): StreamableFlow { - const flow = new Flow(config, fn); - registerFlowAction(flow); - const registry = getRegistryInstance(); + const flow = new Flow(registry, config, fn); + registerFlowAction(registry, flow); const streamableFlow: StreamableFlow = (input, opts) => { - return runWithRegistry(registry, () => flow.stream(input, opts)); + return flow.stream(input, opts); }; streamableFlow.flow = flow; return streamableFlow; @@ -601,8 +590,12 @@ function registerFlowAction< I extends z.ZodTypeAny = z.ZodTypeAny, O extends z.ZodTypeAny = z.ZodTypeAny, S extends z.ZodTypeAny = z.ZodTypeAny, ->(flow: Flow): Action { +>( + registry: Registry, + flow: Flow +): Action { return defineAction( + registry, { actionType: 'flow', name: flow.name, diff --git a/js/core/src/reflection.ts b/js/core/src/reflection.ts index e74a7fa69..e66c73630 100644 --- a/js/core/src/reflection.ts +++ b/js/core/src/reflection.ts @@ -14,7 +14,7 @@ * limitations under the License. */ -import express, { NextFunction, Request, Response } from 'express'; +import express from 'express'; import fs from 'fs/promises'; import getPort, { makeRange } from 'get-port'; import { Server } from 'http'; @@ -23,7 +23,7 @@ import z from 'zod'; import { Status, StatusCodes, runWithStreamingCallback } from './action.js'; import { GENKIT_VERSION } from './index.js'; import { logger } from './logging.js'; -import { Registry, runWithRegistry } from './registry.js'; +import { Registry } from './registry.js'; import { toJsonSchema } from './schema.js'; import { flushTracing, @@ -113,16 +113,6 @@ export class ReflectionServer { next(); }); - server.use((req: Request, res: Response, next: NextFunction) => { - runWithRegistry(this.registry, async () => { - try { - next(); - } catch (err) { - next(err); - } - }); - }); - server.get('/api/__health', async (_, response) => { await this.registry.listActions(); response.status(200).send('OK'); diff --git a/js/core/src/registry.ts b/js/core/src/registry.ts index afa679ac6..f7cd0f532 100644 --- a/js/core/src/registry.ts +++ b/js/core/src/registry.ts @@ -14,7 +14,6 @@ * limitations under the License. */ -import { AsyncLocalStorage } from 'async_hooks'; import * as z from 'zod'; import { Action } from './action.js'; import { logger } from './logging.js'; @@ -47,17 +46,6 @@ export interface Schema { jsonSchema?: JSONSchema; } -/** - * Looks up a registry key (action type and key) in the registry. - */ -export function lookupAction< - I extends z.ZodTypeAny, - O extends z.ZodTypeAny, - R extends Action, ->(key: string): Promise { - return getRegistryInstance().lookupAction(key); -} - function parsePluginName(registryKey: string) { const tokens = registryKey.split('/'); if (tokens.length === 4) { @@ -66,99 +54,8 @@ function parsePluginName(registryKey: string) { return undefined; } -/** - * Registers an action in the registry. - */ -export function registerAction( - type: ActionType, - action: Action -) { - return getRegistryInstance().registerAction(type, action); -} - type ActionsRecord = Record>; -/** - * Initialize all plugins in the registry. - */ -export async function initializeAllPlugins() { - await getRegistryInstance().initializeAllPlugins(); -} - -/** - * Returns all actions in the registry. - */ -export function listActions(): Promise { - return getRegistryInstance().listActions(); -} - -/** - * Registers a plugin provider. - * @param name The name of the plugin to register. - * @param provider The plugin provider. - */ -export function registerPluginProvider(name: string, provider: PluginProvider) { - return getRegistryInstance().registerPluginProvider(name, provider); -} - -/** - * Looks up a plugin. - * @param name The name of the plugin to lookup. - * @returns The plugin. - */ -export function lookupPlugin(name: string) { - return getRegistryInstance().lookupPlugin(name); -} - -/** - * Initializes a plugin that has already been registered. - * @param name The name of the plugin to initialize. - * @returns The plugin. - */ -export async function initializePlugin(name: string) { - return getRegistryInstance().initializePlugin(name); -} - -/** - * Registers a schema. - * @param name The name of the schema to register. - * @param data The schema to register (either a Zod schema or a JSON schema). - */ -export function registerSchema(name: string, data: Schema) { - return getRegistryInstance().registerSchema(name, data); -} - -/** - * Looks up a schema. - * @param name The name of the schema to lookup. - * @returns The schema. - */ -export function lookupSchema(name: string) { - return getRegistryInstance().lookupSchema(name); -} - -const registryAls = new AsyncLocalStorage(); - -/** - * @returns The active registry instance. - */ -export function getRegistryInstance(): Registry { - const registry = registryAls.getStore(); - if (!registry) { - throw new Error('getRegistryInstance() called before runWithRegistry()'); - } - return registry; -} - -/** - * Runs a function with a specific registry instance. - * @param registry The registry instance to use. - * @param fn The function to run. - */ -export function runWithRegistry(registry: Registry, fn: () => R) { - return registryAls.run(registry, fn); -} - /** * The registry is used to store and lookup actions, trace stores, flow state stores, plugins, and schemas. */ @@ -170,14 +67,6 @@ export class Registry { constructor(public parent?: Registry) {} - /** - * Creates a new registry overlaid onto the currently active registry. - * @returns The new overlaid registry. - */ - static withCurrent() { - return new Registry(getRegistryInstance()); - } - /** * Creates a new registry overlaid onto the provided registry. * @param parent The parent registry. diff --git a/js/core/src/schema.ts b/js/core/src/schema.ts index 16a45160d..a53da8acb 100644 --- a/js/core/src/schema.ts +++ b/js/core/src/schema.ts @@ -19,7 +19,7 @@ import addFormats from 'ajv-formats'; import { z } from 'zod'; import zodToJsonSchema from 'zod-to-json-schema'; import { GenkitError } from './error.js'; -import { registerSchema } from './registry.js'; +import { Registry } from './registry.js'; const ajv = new Ajv(); addFormats(ajv); @@ -112,14 +112,19 @@ export function parseSchema( } export function defineSchema( + registry: Registry, name: string, schema: T ): T { - registerSchema(name, { schema }); + registry.registerSchema(name, { schema }); return schema; } -export function defineJsonSchema(name: string, jsonSchema: JSONSchema) { - registerSchema(name, { jsonSchema }); +export function defineJsonSchema( + registry: Registry, + name: string, + jsonSchema: JSONSchema +) { + registry.registerSchema(name, { jsonSchema }); return jsonSchema; } diff --git a/js/core/tests/flow_test.ts b/js/core/tests/flow_test.ts index 7d5b74646..cce14e2ee 100644 --- a/js/core/tests/flow_test.ts +++ b/js/core/tests/flow_test.ts @@ -18,10 +18,11 @@ import assert from 'node:assert'; import { beforeEach, describe, it } from 'node:test'; import { defineFlow, defineStreamingFlow } from '../src/flow.js'; import { z } from '../src/index.js'; -import { Registry, runWithRegistry } from '../src/registry.js'; +import { Registry } from '../src/registry.js'; -function createTestFlow() { +function createTestFlow(registry: Registry) { return defineFlow( + registry, { name: 'testFlow', inputSchema: z.string(), @@ -33,8 +34,9 @@ function createTestFlow() { ); } -function createTestStreamingFlow() { +function createTestStreamingFlow(registry: Registry) { return defineStreamingFlow( + registry, { name: 'testFlow', inputSchema: z.number(), @@ -63,7 +65,7 @@ describe('flow', () => { describe('runFlow', () => { it('should run the flow', async () => { - const testFlow = runWithRegistry(registry, createTestFlow); + const testFlow = createTestFlow(registry); const result = await testFlow('foo'); @@ -71,10 +73,8 @@ describe('flow', () => { }); it('should run simple sync flow', async () => { - const testFlow = runWithRegistry(registry, () => { - return defineFlow('testFlow', (input) => { - return `bar ${input}`; - }); + const testFlow = defineFlow(registry, 'testFlow', (input) => { + return `bar ${input}`; }); const result = await testFlow('foo'); @@ -83,17 +83,16 @@ describe('flow', () => { }); it('should rethrow the error', async () => { - const testFlow = runWithRegistry(registry, () => - defineFlow( - { - name: 'throwing', - inputSchema: z.string(), - outputSchema: z.string(), - }, - async (input) => { - throw new Error(`bad happened: ${input}`); - } - ) + const testFlow = defineFlow( + registry, + { + name: 'throwing', + inputSchema: z.string(), + outputSchema: z.string(), + }, + async (input) => { + throw new Error(`bad happened: ${input}`); + } ); await assert.rejects(() => testFlow('foo'), { @@ -103,17 +102,16 @@ describe('flow', () => { }); it('should validate input', async () => { - const testFlow = runWithRegistry(registry, () => - defineFlow( - { - name: 'validating', - inputSchema: z.object({ foo: z.string(), bar: z.number() }), - outputSchema: z.string(), - }, - async (input) => { - return `ok ${input}`; - } - ) + const testFlow = defineFlow( + registry, + { + name: 'validating', + inputSchema: z.object({ foo: z.string(), bar: z.number() }), + outputSchema: z.string(), + }, + async (input) => { + return `ok ${input}`; + } ); await assert.rejects( @@ -132,7 +130,7 @@ describe('flow', () => { describe('streamFlow', () => { it('should run the flow', async () => { - const testFlow = runWithRegistry(registry, createTestStreamingFlow); + const testFlow = createTestStreamingFlow(registry); const response = testFlow(3); @@ -146,16 +144,15 @@ describe('flow', () => { }); it('should rethrow the error', async () => { - const testFlow = runWithRegistry(registry, () => - defineStreamingFlow( - { - name: 'throwing', - inputSchema: z.string(), - }, - async (input) => { - throw new Error(`stream bad happened: ${input}`); - } - ) + const testFlow = defineStreamingFlow( + registry, + { + name: 'throwing', + inputSchema: z.string(), + }, + async (input) => { + throw new Error(`stream bad happened: ${input}`); + } ); const response = testFlow('foo'); diff --git a/js/core/tests/registry_test.ts b/js/core/tests/registry_test.ts index 9542cf779..d54fdd415 100644 --- a/js/core/tests/registry_test.ts +++ b/js/core/tests/registry_test.ts @@ -17,175 +17,7 @@ import assert from 'node:assert'; import { beforeEach, describe, it } from 'node:test'; import { action } from '../src/action.js'; -import { - Registry, - listActions, - lookupAction, - registerAction, - registerPluginProvider, - runWithRegistry, -} from '../src/registry.js'; - -describe('global registry', () => { - let registry: Registry; - - beforeEach(() => { - registry = new Registry(); - }); - - describe('listActions', () => { - it('returns all registered actions', async () => { - await runWithRegistry(registry, async () => { - const fooSomethingAction = action( - { name: 'foo_something' }, - async () => null - ); - registerAction('model', fooSomethingAction); - const barSomethingAction = action( - { name: 'bar_something' }, - async () => null - ); - registerAction('model', barSomethingAction); - - assert.deepEqual(await listActions(), { - '/model/foo_something': fooSomethingAction, - '/model/bar_something': barSomethingAction, - }); - }); - }); - - it('returns all registered actions by plugins', async () => { - await runWithRegistry(registry, async () => { - registerPluginProvider('foo', { - name: 'foo', - async initializer() { - registerAction('model', fooSomethingAction); - return {}; - }, - }); - const fooSomethingAction = action( - { - name: { - pluginId: 'foo', - actionId: 'something', - }, - }, - async () => null - ); - registerPluginProvider('bar', { - name: 'bar', - async initializer() { - registerAction('model', barSomethingAction); - return {}; - }, - }); - const barSomethingAction = action( - { - name: { - pluginId: 'bar', - actionId: 'something', - }, - }, - async () => null - ); - - assert.deepEqual(await listActions(), { - '/model/foo/something': fooSomethingAction, - '/model/bar/something': barSomethingAction, - }); - }); - }); - }); - - describe('lookupAction', () => { - it('initializes plugin for action first', async () => { - await runWithRegistry(registry, async () => { - let fooInitialized = false; - registerPluginProvider('foo', { - name: 'foo', - async initializer() { - fooInitialized = true; - return {}; - }, - }); - let barInitialized = false; - registerPluginProvider('bar', { - name: 'bar', - async initializer() { - barInitialized = true; - return {}; - }, - }); - - await lookupAction('/model/foo/something'); - - assert.strictEqual(fooInitialized, true); - assert.strictEqual(barInitialized, false); - - await lookupAction('/model/bar/something'); - - assert.strictEqual(fooInitialized, true); - assert.strictEqual(barInitialized, true); - }); - }); - }); - - it('returns registered action', async () => { - await runWithRegistry(registry, async () => { - const fooSomethingAction = action( - { name: 'foo_something' }, - async () => null - ); - registerAction('model', fooSomethingAction); - const barSomethingAction = action( - { name: 'bar_something' }, - async () => null - ); - registerAction('model', barSomethingAction); - - assert.strictEqual( - await lookupAction('/model/foo_something'), - fooSomethingAction - ); - assert.strictEqual( - await lookupAction('/model/bar_something'), - barSomethingAction - ); - }); - }); - - it('returns action registered by plugin', async () => { - await runWithRegistry(registry, async () => { - registerPluginProvider('foo', { - name: 'foo', - async initializer() { - registerAction('model', somethingAction); - return {}; - }, - }); - const somethingAction = action( - { - name: { - pluginId: 'foo', - actionId: 'something', - }, - }, - async () => null - ); - - assert.strictEqual( - await lookupAction('/model/foo/something'), - somethingAction - ); - }); - }); - - it('returns undefined for unknown action', async () => { - await runWithRegistry(registry, async () => { - assert.strictEqual(await lookupAction('/model/foo/something'), undefined); - }); - }); -}); +import { Registry } from '../src/registry.js'; describe('registry class', () => { var registry: Registry; diff --git a/js/genkit/src/genkit.ts b/js/genkit/src/genkit.ts index cd5c0ad73..c9ff26a75 100644 --- a/js/genkit/src/genkit.ts +++ b/js/genkit/src/genkit.ts @@ -36,7 +36,6 @@ import { GenerateStreamOptions, GenerateStreamResponse, GenerationCommonConfigSchema, - index, IndexerParams, ModelArgument, ModelReference, @@ -83,6 +82,7 @@ import { defineRetriever, defineSimpleRetriever, DocumentData, + index, IndexerAction, IndexerFn, RetrieverFn, @@ -119,7 +119,7 @@ import { Chat, ChatOptions } from './chat.js'; import { BaseEvalDataPointSchema } from './evaluator.js'; import { logger } from './logging.js'; import { GenkitPlugin, genkitPlugin } from './plugin.js'; -import { lookupAction, Registry, runWithRegistry } from './registry.js'; +import { Registry } from './registry.js'; import { getCurrentSession, Session, @@ -193,7 +193,7 @@ export class Genkit { I extends z.ZodTypeAny = z.ZodTypeAny, O extends z.ZodTypeAny = z.ZodTypeAny, >(config: FlowConfig | string, fn: FlowFn): CallableFlow { - const flow = runWithRegistry(this.registry, () => defineFlow(config, fn)); + const flow = defineFlow(this.registry, config, fn); this.registeredFlows.push(flow.flow); return flow; } @@ -211,9 +211,7 @@ export class Genkit { config: StreamingFlowConfig, fn: FlowFn ): StreamableFlow { - const flow = runWithRegistry(this.registry, () => - defineStreamingFlow(config, fn) - ); + const flow = defineStreamingFlow(this.registry, config, fn); this.registeredFlows.push(flow.flow); return flow; } @@ -227,7 +225,7 @@ export class Genkit { config: ToolConfig, fn: (input: z.infer) => Promise> ): ToolAction { - return runWithRegistry(this.registry, () => defineTool(config, fn)); + return defineTool(this.registry, config, fn); } /** @@ -236,7 +234,7 @@ export class Genkit { * Defined schemas can be referenced by `name` in prompts in place of inline schemas. */ defineSchema(name: string, schema: T): T { - return runWithRegistry(this.registry, () => defineSchema(name, schema)); + return defineSchema(this.registry, name, schema); } /** @@ -245,9 +243,7 @@ export class Genkit { * Defined schemas can be referenced by `name` in prompts in place of inline schemas. */ defineJsonSchema(name: string, jsonSchema: JSONSchema) { - return runWithRegistry(this.registry, () => - defineJsonSchema(name, jsonSchema) - ); + return defineJsonSchema(this.registry, name, jsonSchema); } /** @@ -260,7 +256,7 @@ export class Genkit { streamingCallback?: StreamingCallback ) => Promise ): ModelAction { - return runWithRegistry(this.registry, () => defineModel(options, runner)); + return defineModel(this.registry, options, runner); } /** @@ -268,7 +264,7 @@ export class Genkit { * * @todo TODO: Show an example of a name and variant. */ - prompt< + async prompt< I extends z.ZodTypeAny = z.ZodTypeAny, O extends z.ZodTypeAny = z.ZodTypeAny, CustomOptions extends z.ZodTypeAny = z.ZodTypeAny, @@ -276,13 +272,13 @@ export class Genkit { name: string, options?: { variant?: string } ): Promise, O, CustomOptions>> { - return runWithRegistry(this.registry, async () => { - const action = (await lookupAction(`/prompt/${name}`)) as PromptAction; - return this.wrapPromptActionInExecutablePrompt( - action, - {} - ) as ExecutablePrompt; - }); + const action = (await this.registry.lookupAction( + `/prompt/${name}` + )) as PromptAction; + return this.wrapPromptActionInExecutablePrompt( + action, + {} + ) as ExecutablePrompt; } /** @@ -368,28 +364,31 @@ export class Genkit { if (!options.name) { throw new Error('options.name is required'); } - return runWithRegistry(this.registry, () => { - if (!options.name) { - throw new Error('options.name is required'); - } - if (typeof templateOrFn === 'string') { - const dotprompt = defineDotprompt(options, templateOrFn as string); - return this.wrapPromptActionInExecutablePrompt( - dotprompt.promptAction! as PromptAction, - options - ); - } else { - const p = definePrompt( - { - name: options.name!, - inputJsonSchema: options.input?.jsonSchema, - inputSchema: options.input?.schema, - }, - templateOrFn as PromptFn - ); - return this.wrapPromptActionInExecutablePrompt(p, options); - } - }); + if (!options.name) { + throw new Error('options.name is required'); + } + if (typeof templateOrFn === 'string') { + const dotprompt = defineDotprompt( + this.registry, + options, + templateOrFn as string + ); + return this.wrapPromptActionInExecutablePrompt( + dotprompt.promptAction! as PromptAction, + options + ); + } else { + const p = definePrompt( + this.registry, + { + name: options.name!, + inputJsonSchema: options.input?.jsonSchema, + inputSchema: options.input?.schema, + }, + templateOrFn as PromptFn + ); + return this.wrapPromptActionInExecutablePrompt(p, options); + } } private wrapPromptActionInExecutablePrompt< @@ -400,87 +399,78 @@ export class Genkit { p: PromptAction, options: PromptMetadata ): ExecutablePrompt { - const executablePrompt = ( + const executablePrompt = async ( input?: z.infer, opts?: PromptGenerateOptions ): Promise => { - return runWithRegistry(this.registry, async () => { - const renderedOpts = await ( - executablePrompt as ExecutablePrompt - ).render({ - ...opts, - input, - }); - return this.generate(renderedOpts); + const renderedOpts = await ( + executablePrompt as ExecutablePrompt + ).render({ + ...opts, + input, }); + return this.generate(renderedOpts); }; - (executablePrompt as ExecutablePrompt).stream = ( + (executablePrompt as ExecutablePrompt).stream = async ( input?: z.infer, opts?: z.infer ): Promise> => { - return runWithRegistry(this.registry, async () => { - const renderedOpts = await ( - executablePrompt as ExecutablePrompt - ).render({ - ...opts, - input, - }); - return this.generateStream(renderedOpts); + const renderedOpts = await ( + executablePrompt as ExecutablePrompt + ).render({ + ...opts, + input, }); + return this.generateStream(renderedOpts); }; - (executablePrompt as ExecutablePrompt).generate = ( - opt: PromptGenerateOptions - ): Promise> => { - return runWithRegistry(this.registry, async () => { + (executablePrompt as ExecutablePrompt).generate = + async ( + opt: PromptGenerateOptions + ): Promise> => { const renderedOpts = await ( executablePrompt as ExecutablePrompt ).render(opt); return this.generate(renderedOpts); - }); - }; + }; (executablePrompt as ExecutablePrompt).generateStream = - ( + async ( opt: PromptGenerateOptions ): Promise> => { - return runWithRegistry(this.registry, async () => { - const renderedOpts = await ( - executablePrompt as ExecutablePrompt - ).render(opt); - return this.generateStream(renderedOpts); - }); + const renderedOpts = await ( + executablePrompt as ExecutablePrompt + ).render(opt); + return this.generateStream(renderedOpts); }; - (executablePrompt as ExecutablePrompt).render = < + (executablePrompt as ExecutablePrompt).render = async < Out extends O, >( opt: PromptGenerateOptions ): Promise> => { - return runWithRegistry(this.registry, async () => { - let model: ModelAction | undefined; - try { - model = await this.resolveModel(opt?.model ?? options.model); - } catch (e) { - // ignore, no model on a render is OK. - } - - const promptResult = await p(opt.input); - const resultOptions = { - messages: promptResult.messages, - docs: promptResult.docs, - tools: promptResult.tools, - output: { - format: promptResult.output?.format, - jsonSchema: promptResult.output?.schema, - }, - config: { - ...options.config, - ...promptResult.config, - ...opt.config, - }, - model, - } as GenerateOptions; - delete (resultOptions as PromptGenerateOptions).input; - return resultOptions; - }); + let model: ModelAction | undefined; + try { + model = await this.resolveModel(opt?.model ?? options.model); + } catch (e) { + // ignore, no model on a render is OK. + } + + const promptResult = await p(opt.input); + const resultOptions = { + messages: promptResult.messages, + docs: promptResult.docs, + tools: promptResult.tools, + output: { + format: promptResult.output?.format, + jsonSchema: promptResult.output?.schema, + }, + config: { + ...options.config, + ...promptResult.config, + ...opt.config, + }, + model, + } as GenerateOptions; + delete (resultOptions as PromptGenerateOptions).input; + return resultOptions; }; (executablePrompt as ExecutablePrompt).asTool = (): ToolAction => { @@ -500,9 +490,7 @@ export class Genkit { }, runner: RetrieverFn ): RetrieverAction { - return runWithRegistry(this.registry, () => - defineRetriever(options, runner) - ); + return defineRetriever(this.registry, options, runner); } /** @@ -517,9 +505,7 @@ export class Genkit { options: SimpleRetrieverOptions, handler: (query: Document, config: z.infer) => Promise ): RetrieverAction { - return runWithRegistry(this.registry, () => - defineSimpleRetriever(options, handler) - ); + return defineSimpleRetriever(this.registry, options, handler); } /** @@ -533,7 +519,7 @@ export class Genkit { }, runner: IndexerFn ): IndexerAction { - return runWithRegistry(this.registry, () => defineIndexer(options, runner)); + return defineIndexer(this.registry, options, runner); } /** @@ -555,9 +541,7 @@ export class Genkit { }, runner: EvaluatorFn ): EvaluatorAction { - return runWithRegistry(this.registry, () => - defineEvaluator(options, runner) - ); + return defineEvaluator(this.registry, options, runner); } /** @@ -571,23 +555,21 @@ export class Genkit { }, runner: EmbedderFn ): EmbedderAction { - return runWithRegistry(this.registry, () => - defineEmbedder(options, runner) - ); + return defineEmbedder(this.registry, options, runner); } /** * create a handlebards helper (https://handlebarsjs.com/guide/block-helpers.html) to be used in dotpormpt templates. */ defineHelper(name: string, fn: Handlebars.HelperDelegate) { - return runWithRegistry(this.registry, () => defineHelper(name, fn)); + return defineHelper(name, fn); } /** * Creates a handlebars partial (https://handlebarsjs.com/guide/partials.html) to be used in dotpormpt templates. */ definePartial(name: string, source: string) { - return runWithRegistry(this.registry, () => definePartial(name, source)); + return definePartial(name, source); } /** @@ -601,9 +583,7 @@ export class Genkit { }, runner: RerankerFn ) { - return runWithRegistry(this.registry, () => - defineReranker(options, runner) - ); + return defineReranker(this.registry, options, runner); } /** @@ -612,7 +592,7 @@ export class Genkit { embed( params: EmbedderParams ): Promise { - return runWithRegistry(this.registry, () => embed(params)); + return embed(this.registry, params); } /** @@ -624,7 +604,7 @@ export class Genkit { metadata?: Record; options?: z.infer; }): Promise { - return runWithRegistry(this.registry, () => embedMany(params)); + return embedMany(this.registry, params); } /** @@ -634,7 +614,7 @@ export class Genkit { DataPoint extends typeof BaseDataPointSchema = typeof BaseDataPointSchema, CustomOptions extends z.ZodTypeAny = z.ZodTypeAny, >(params: EvaluatorParams): Promise { - return runWithRegistry(this.registry, () => evaluate(params)); + return evaluate(this.registry, params); } /** @@ -643,7 +623,7 @@ export class Genkit { rerank( params: RerankerParams ): Promise> { - return runWithRegistry(this.registry, () => rerank(params)); + return rerank(this.registry, params); } /** @@ -652,7 +632,7 @@ export class Genkit { index( params: IndexerParams ): Promise { - return runWithRegistry(this.registry, () => index(params)); + return index(this.registry, params); } /** @@ -661,7 +641,7 @@ export class Genkit { retrieve( params: RetrieverParams ): Promise> { - return runWithRegistry(this.registry, () => retrieve(params)); + return retrieve(this.registry, params); } /** @@ -756,7 +736,7 @@ export class Genkit { if (!resolvedOptions.model) { resolvedOptions.model = this.options.model; } - return runWithRegistry(this.registry, () => generate(resolvedOptions)); + return generate(this.registry, resolvedOptions); } /** @@ -863,9 +843,7 @@ export class Genkit { if (!resolvedOptions.model) { resolvedOptions.model = this.options.model; } - return runWithRegistry(this.registry, () => - generateStream(resolvedOptions) - ); + return generateStream(this.registry, resolvedOptions); } /** @@ -940,9 +918,7 @@ export class Genkit { const plugins = [...(this.options.plugins ?? [])]; if (this.options.promptDir !== null) { const dotprompt = genkitPlugin('dotprompt', async (ai) => { - runWithRegistry(ai.registry, async () => - loadPromptFolder(this.options.promptDir ?? './prompts') - ); + loadPromptFolder(this.registry, this.options.promptDir ?? './prompts'); }); plugins.push(dotprompt); } @@ -953,9 +929,7 @@ export class Genkit { name: loadedPlugin.name, async initializer() { logger.debug(`Initializing plugin ${loadedPlugin.name}:`); - return runWithRegistry(activeRegistry, () => - loadedPlugin.initializer() - ); + loadedPlugin.initializer(); }, }); }); @@ -980,12 +954,16 @@ export class Genkit { return this.resolveModel(this.options.model); } if (typeof modelArg === 'string') { - return (await lookupAction(`/model/${modelArg}`)) as ModelAction; + return (await this.registry.lookupAction( + `/model/${modelArg}` + )) as ModelAction; } else if ((modelArg as ModelAction).__action) { return modelArg as ModelAction; } else { const ref = modelArg as ModelReference; - return (await lookupAction(`/model/${ref.name}`)) as ModelAction; + return (await this.registry.lookupAction( + `/model/${ref.name}` + )) as ModelAction; } } } diff --git a/js/genkit/src/registry.ts b/js/genkit/src/registry.ts index 0dab20e68..8c45c10d4 100644 --- a/js/genkit/src/registry.ts +++ b/js/genkit/src/registry.ts @@ -19,15 +19,4 @@ export { AsyncProvider, Registry, Schema, - getRegistryInstance, - initializeAllPlugins, - initializePlugin, - listActions, - lookupAction, - lookupPlugin, - lookupSchema, - registerAction, - registerPluginProvider, - registerSchema, - runWithRegistry, } from '@genkit-ai/core/registry'; diff --git a/js/plugins/dotprompt/src/index.ts b/js/plugins/dotprompt/src/index.ts index 094ae8397..4f3f60f53 100644 --- a/js/plugins/dotprompt/src/index.ts +++ b/js/plugins/dotprompt/src/index.ts @@ -14,6 +14,7 @@ * limitations under the License. */ +import { Registry } from '@genkit-ai/core/registry'; import { readFileSync } from 'fs'; import { basename } from 'path'; import { @@ -38,10 +39,15 @@ export interface DotpromptPluginOptions { } export async function prompt( + registry: Registry, name: string, options?: { variant?: string } ): Promise> { - return (await lookupPrompt(name, options?.variant)) as Dotprompt; + return (await lookupPrompt( + registry, + name, + options?.variant + )) as Dotprompt; } export function promptRef( @@ -51,19 +57,22 @@ export function promptRef( return new DotpromptRef(name, options); } -export function loadPromptFile(path: string): Dotprompt { +export function loadPromptFile(registry: Registry, path: string): Dotprompt { return Dotprompt.parse( + registry, basename(path).split('.')[0], readFileSync(path, 'utf-8') ); } export async function loadPromptUrl( + registry: Registry, + name: string, url: string ): Promise { const fetch = (await import('node-fetch')).default; const response = await fetch(url); const text = await response.text(); - return Dotprompt.parse(name, text); + return Dotprompt.parse(registry, name, text); } diff --git a/js/plugins/dotprompt/src/metadata.ts b/js/plugins/dotprompt/src/metadata.ts index 03947a8a7..165176919 100644 --- a/js/plugins/dotprompt/src/metadata.ts +++ b/js/plugins/dotprompt/src/metadata.ts @@ -25,7 +25,7 @@ import { } from '@genkit-ai/ai/model'; import { ToolArgument } from '@genkit-ai/ai/tool'; import { z } from '@genkit-ai/core'; -import { lookupSchema } from '@genkit-ai/core/registry'; +import { Registry } from '@genkit-ai/core/registry'; import { JSONSchema, parseSchema, toJsonSchema } from '@genkit-ai/core/schema'; import { picoschema } from './picoschema.js'; @@ -122,27 +122,33 @@ function stripUndefinedOrNull(obj: any) { return obj; } -function fmSchemaToSchema(fmSchema: any) { +function fmSchemaToSchema(registry: Registry, fmSchema: any) { if (!fmSchema) return {}; - if (typeof fmSchema === 'string') return lookupSchema(fmSchema); + if (typeof fmSchema === 'string') return registry.lookupSchema(fmSchema); return { jsonSchema: picoschema(fmSchema) }; } -export function toMetadata(attributes: unknown): Partial { +export function toMetadata( + registry: Registry, + attributes: unknown +): Partial { const fm = parseSchema>(attributes, { schema: PromptFrontmatterSchema, }); let input: PromptMetadata['input'] | undefined; if (fm.input) { - input = { default: fm.input.default, ...fmSchemaToSchema(fm.input.schema) }; + input = { + default: fm.input.default, + ...fmSchemaToSchema(registry, fm.input.schema), + }; } let output: PromptMetadata['output'] | undefined; if (fm.output) { output = { format: fm.output.format, - ...fmSchemaToSchema(fm.output.schema), + ...fmSchemaToSchema(registry, fm.output.schema), }; } diff --git a/js/plugins/dotprompt/src/prompt.ts b/js/plugins/dotprompt/src/prompt.ts index baf7a1388..b216c21a8 100644 --- a/js/plugins/dotprompt/src/prompt.ts +++ b/js/plugins/dotprompt/src/prompt.ts @@ -27,6 +27,7 @@ import { import { MessageData, ModelArgument } from '@genkit-ai/ai/model'; import { DocumentData } from '@genkit-ai/ai/retriever'; import { GenkitError, z } from '@genkit-ai/core'; +import { Registry } from '@genkit-ai/core/registry'; import { parseSchema } from '@genkit-ai/core/schema'; import { runInNewSpan, @@ -80,14 +81,18 @@ export class Dotprompt implements PromptMetadata { private _render: (input: I, options?: RenderMetadata) => MessageData[]; - static parse(name: string, source: string) { + static parse(registry: Registry, name: string, source: string) { try { const fmResult = (fm as any)(source.trimStart(), { allowUnsafe: false, }) as FrontMatterResult; return new Dotprompt( - { ...toMetadata(fmResult.attributes), name } as PromptMetadata, + registry, + { + ...toMetadata(registry, fmResult.attributes), + name, + } as PromptMetadata, fmResult.body ); } catch (e: any) { @@ -99,7 +104,7 @@ export class Dotprompt implements PromptMetadata { } } - static fromAction(action: PromptAction): Dotprompt { + static fromAction(registry: Registry, action: PromptAction): Dotprompt { const { template, ...options } = action.__action.metadata!.prompt; const pm = options as PromptMetadata; if (pm.input?.schema) { @@ -109,11 +114,15 @@ export class Dotprompt implements PromptMetadata { if (pm.output?.schema) { pm.output.jsonSchema = options.output?.schema; } - const prompt = new Dotprompt(options as PromptMetadata, template); + const prompt = new Dotprompt(registry, options as PromptMetadata, template); return prompt; } - constructor(options: PromptMetadata, template: string) { + constructor( + private registry: Registry, + options: PromptMetadata, + template: string + ) { this.name = options.name || 'untitledPrompt'; this.variant = options.variant; this.model = options.model; @@ -171,6 +180,7 @@ export class Dotprompt implements PromptMetadata { define(options?: { ns?: string; description?: string }): void { this._promptAction = definePrompt( + this.registry, { name: registryDefinitionKey(this.name, this.variant, options?.ns), description: options?.description ?? 'Defined by Dotprompt', @@ -181,7 +191,8 @@ export class Dotprompt implements PromptMetadata { prompt: this.toJSON(), }, }, - async (input?: I) => toGenerateRequest(this.render({ input })) + async (input?: I) => + toGenerateRequest(this.registry, this.render({ input })) ); } @@ -276,7 +287,7 @@ export class Dotprompt implements PromptMetadata { opt: PromptGenerateOptions ): Promise>> { const renderedOpts = this.renderInNewSpan(opt); - return generate(renderedOpts); + return generate(this.registry, renderedOpts); } /** @@ -289,7 +300,7 @@ export class Dotprompt implements PromptMetadata { opt: PromptGenerateOptions ): Promise { const renderedOpts = await this.renderInNewSpan(opt); - return generateStream(renderedOpts); + return generateStream(this.registry, renderedOpts); } } @@ -312,9 +323,10 @@ export class DotpromptRef { } /** Loads the prompt which is referenced. */ - async loadPrompt(): Promise> { + async loadPrompt(registry: Registry): Promise> { if (this._prompt) return this._prompt; this._prompt = (await lookupPrompt( + registry, this.name, this.variant, this.dir @@ -333,9 +345,10 @@ export class DotpromptRef { CustomOptions extends z.ZodTypeAny = z.ZodTypeAny, O extends z.ZodTypeAny = z.ZodTypeAny, >( + registry: Registry, opt: PromptGenerateOptions ): Promise>> { - const prompt = await this.loadPrompt(); + const prompt = await this.loadPrompt(registry); return prompt.generate(opt); } @@ -349,9 +362,11 @@ export class DotpromptRef { CustomOptions extends z.ZodTypeAny = z.ZodTypeAny, O extends z.ZodTypeAny = z.ZodTypeAny, >( + registry: Registry, + opt: PromptGenerateOptions ): Promise> { - const prompt = await this.loadPrompt(); + const prompt = await this.loadPrompt(registry); return prompt.render(opt); } } @@ -367,10 +382,11 @@ export function defineDotprompt< I extends z.ZodTypeAny = z.ZodTypeAny, CustomOptions extends z.ZodTypeAny = z.ZodTypeAny, >( + registry: Registry, options: PromptMetadata, template: string ): Dotprompt> { - const prompt = new Dotprompt(options, template); + const prompt = new Dotprompt(registry, options, template); prompt.define({ description: options.description }); return prompt; } diff --git a/js/plugins/dotprompt/src/registry.ts b/js/plugins/dotprompt/src/registry.ts index 3397f6b56..f0af18eec 100644 --- a/js/plugins/dotprompt/src/registry.ts +++ b/js/plugins/dotprompt/src/registry.ts @@ -17,7 +17,7 @@ import { PromptAction } from '@genkit-ai/ai'; import { GenkitError } from '@genkit-ai/core'; import { logger } from '@genkit-ai/core/logging'; -import { lookupAction } from '@genkit-ai/core/registry'; +import { Registry } from '@genkit-ai/core/registry'; import { existsSync, readdir, readFileSync } from 'fs'; import { basename, join, resolve } from 'path'; import { Dotprompt } from './prompt.js'; @@ -37,23 +37,27 @@ export function registryLookupKey(name: string, variant?: string, ns?: string) { } export async function lookupPrompt( + registry: Registry, name: string, variant?: string, dir: string = './prompts' ): Promise { let registryPrompt = - (await lookupAction(registryLookupKey(name, variant))) || - (await lookupAction(registryLookupKey(name, variant, 'dotprompt'))); + (await registry.lookupAction(registryLookupKey(name, variant))) || + (await registry.lookupAction( + registryLookupKey(name, variant, 'dotprompt') + )); if (registryPrompt) { - return Dotprompt.fromAction(registryPrompt as PromptAction); + return Dotprompt.fromAction(registry, registryPrompt as PromptAction); } else { // Handle the case where initialization isn't complete // or a file was added after the prompt folder was loaded. - return maybeLoadPrompt(dir, name, variant); + return maybeLoadPrompt(registry, dir, name, variant); } } async function maybeLoadPrompt( + registry: Registry, dir: string, name: string, variant?: string @@ -62,7 +66,7 @@ async function maybeLoadPrompt( const promptFolder = resolve(dir); const promptExists = existsSync(join(promptFolder, expectedFileName)); if (promptExists) { - return loadPrompt(promptFolder, expectedFileName); + return loadPrompt(registry, promptFolder, expectedFileName); } else { throw new GenkitError({ source: 'dotprompt', @@ -73,6 +77,8 @@ async function maybeLoadPrompt( } export async function loadPromptFolder( + registry: Registry, + dir: string = './prompts' ): Promise { const promptsPath = resolve(dir); @@ -114,7 +120,7 @@ export async function loadPromptFolder( .replace(`${promptsPath}/`, '') .replace(/\//g, '-'); } - loadPrompt(dirEnt.path, dirEnt.name, prefix); + loadPrompt(registry, dirEnt.path, dirEnt.name, prefix); } } }); @@ -129,6 +135,7 @@ export async function loadPromptFolder( } export function loadPrompt( + registry: Registry, path: string, filename: string, prefix = '' @@ -141,7 +148,7 @@ export function loadPrompt( variant = parts[1]; } const source = readFileSync(join(path, filename), 'utf8'); - const prompt = Dotprompt.parse(name, source); + const prompt = Dotprompt.parse(registry, name, source); if (variant) { prompt.variant = variant; } diff --git a/js/plugins/dotprompt/tests/prompt_test.ts b/js/plugins/dotprompt/tests/prompt_test.ts index dfdcc21ad..0e6d31e1a 100644 --- a/js/plugins/dotprompt/tests/prompt_test.ts +++ b/js/plugins/dotprompt/tests/prompt_test.ts @@ -16,7 +16,7 @@ import { defineModel, ModelAction } from '@genkit-ai/ai/model'; import { z } from '@genkit-ai/core'; -import { Registry, runWithRegistry } from '@genkit-ai/core/registry'; +import { Registry } from '@genkit-ai/core/registry'; import { defineJsonSchema, defineSchema, @@ -29,11 +29,12 @@ import { defineDotprompt, Dotprompt, prompt, promptRef } from '../src/index.js'; import { PromptMetadata } from '../src/metadata.js'; function testPrompt( + registry: Registry, model: ModelAction, template: string, options?: Partial ): Dotprompt { - return new Dotprompt({ name: 'test', model, ...options }, template); + return new Dotprompt(registry, { name: 'test', model, ...options }, template); } describe('Prompt', () => { @@ -44,184 +45,194 @@ describe('Prompt', () => { describe('#render', () => { it('should render variables', () => { - runWithRegistry(registry, () => { - const model = defineModel( - { name: 'echo', supports: { tools: true } }, - async (input) => ({ - message: input.messages[0], - finishReason: 'stop', - }) - ); - const prompt = testPrompt(model, `Hello {{name}}, how are you?`); - - const rendered = prompt.render({ input: { name: 'Michael' } }); - assert.deepStrictEqual(rendered.prompt, [ - { text: 'Hello Michael, how are you?' }, - ]); - }); + const model = defineModel( + registry, + { name: 'echo', supports: { tools: true } }, + async (input) => ({ + message: input.messages[0], + finishReason: 'stop', + }) + ); + const prompt = testPrompt( + registry, + model, + `Hello {{name}}, how are you?` + ); + + const rendered = prompt.render({ input: { name: 'Michael' } }); + assert.deepStrictEqual(rendered.prompt, [ + { text: 'Hello Michael, how are you?' }, + ]); }); it('should render default variables', () => { - runWithRegistry(registry, () => { - const model = defineModel( - { name: 'echo', supports: { tools: true } }, - async (input) => ({ - message: input.messages[0], - finishReason: 'stop', - }) - ); - const prompt = testPrompt(model, `Hello {{name}}, how are you?`, { + const model = defineModel( + registry, + { name: 'echo', supports: { tools: true } }, + async (input) => ({ + message: input.messages[0], + finishReason: 'stop', + }) + ); + const prompt = testPrompt( + registry, + model, + `Hello {{name}}, how are you?`, + { input: { default: { name: 'Fellow Human' } }, - }); + } + ); - const rendered = prompt.render({ input: {} }); - assert.deepStrictEqual(rendered.prompt, [ - { - text: 'Hello Fellow Human, how are you?', - }, - ]); - }); + const rendered = prompt.render({ input: {} }); + assert.deepStrictEqual(rendered.prompt, [ + { + text: 'Hello Fellow Human, how are you?', + }, + ]); }); it('rejects input not matching the schema', async () => { - await runWithRegistry(registry, async () => { - const invalidSchemaPrompt = defineDotprompt( - { - name: 'invalidInput', - model: 'echo', - input: { - jsonSchema: { - properties: { foo: { type: 'boolean' } }, - required: ['foo'], - }, + const invalidSchemaPrompt = defineDotprompt( + registry, + { + name: 'invalidInput', + model: 'echo', + input: { + jsonSchema: { + properties: { foo: { type: 'boolean' } }, + required: ['foo'], }, }, - `You asked for {{foo}}.` - ); + }, + `You asked for {{foo}}.` + ); - await assert.rejects(async () => { - invalidSchemaPrompt.render({ input: { foo: 'baz' } }); - }, ValidationError); - }); + await assert.rejects(async () => { + invalidSchemaPrompt.render({ input: { foo: 'baz' } }); + }, ValidationError); }); it('should render with overridden fields', () => { - runWithRegistry(registry, () => { - const model = defineModel( - { name: 'echo', supports: { tools: true } }, - async (input) => ({ - message: input.messages[0], - finishReason: 'stop', - }) - ); - const prompt = testPrompt(model, `Hello {{name}}, how are you?`); - - const streamingCallback = (c) => console.log(c); - const middleware = []; - - const rendered = prompt.render({ - input: { name: 'Michael' }, - streamingCallback, - returnToolRequests: true, - use: middleware, - }); - assert.strictEqual(rendered.streamingCallback, streamingCallback); - assert.strictEqual(rendered.returnToolRequests, true); - assert.strictEqual(rendered.use, middleware); + const model = defineModel( + registry, + { name: 'echo', supports: { tools: true } }, + async (input) => ({ + message: input.messages[0], + finishReason: 'stop', + }) + ); + const prompt = testPrompt( + registry, + model, + `Hello {{name}}, how are you?` + ); + + const streamingCallback = (c) => console.log(c); + const middleware = []; + + const rendered = prompt.render({ + input: { name: 'Michael' }, + streamingCallback, + returnToolRequests: true, + use: middleware, }); + assert.strictEqual(rendered.streamingCallback, streamingCallback); + assert.strictEqual(rendered.returnToolRequests, true); + assert.strictEqual(rendered.use, middleware); }); it('should support system prompt with history', () => { - runWithRegistry(registry, () => { - const model = defineModel( - { name: 'echo', supports: { tools: true } }, - async (input) => ({ - message: input.messages[0], - finishReason: 'stop', - }) - ); - const prompt = testPrompt( - model, - `{{ role "system" }}Testing system {{name}}` - ); - - const rendered = prompt.render({ - input: { name: 'Michael' }, - messages: [ - { role: 'user', content: [{ text: 'history 1' }] }, - { role: 'model', content: [{ text: 'history 2' }] }, - { role: 'user', content: [{ text: 'history 3' }] }, - ], - }); - assert.deepStrictEqual(rendered.messages, [ - { role: 'system', content: [{ text: 'Testing system Michael' }] }, + const model = defineModel( + registry, + { name: 'echo', supports: { tools: true } }, + async (input) => ({ + message: input.messages[0], + finishReason: 'stop', + }) + ); + const prompt = testPrompt( + registry, + model, + `{{ role "system" }}Testing system {{name}}` + ); + + const rendered = prompt.render({ + input: { name: 'Michael' }, + messages: [ { role: 'user', content: [{ text: 'history 1' }] }, { role: 'model', content: [{ text: 'history 2' }] }, - ]); - assert.deepStrictEqual(rendered.prompt, [{ text: 'history 3' }]); + { role: 'user', content: [{ text: 'history 3' }] }, + ], }); + assert.deepStrictEqual(rendered.messages, [ + { role: 'system', content: [{ text: 'Testing system Michael' }] }, + { role: 'user', content: [{ text: 'history 1' }] }, + { role: 'model', content: [{ text: 'history 2' }] }, + ]); + assert.deepStrictEqual(rendered.prompt, [{ text: 'history 3' }]); }); }); describe('#generate', () => { it('renders and calls the model', async () => { - await runWithRegistry(registry, async () => { - const model = defineModel( - { name: 'echo', supports: { tools: true } }, - async (input) => ({ - message: input.messages[0], - finishReason: 'stop', - }) - ); - const prompt = testPrompt(model, `Hello {{name}}, how are you?`); - const response = await prompt.generate({ input: { name: 'Bob' } }); - assert.equal(response.text, `Hello Bob, how are you?`); - }); + const model = defineModel( + registry, + { name: 'echo', supports: { tools: true } }, + async (input) => ({ + message: input.messages[0], + finishReason: 'stop', + }) + ); + const prompt = testPrompt( + registry, + model, + `Hello {{name}}, how are you?` + ); + const response = await prompt.generate({ input: { name: 'Bob' } }); + assert.equal(response.text, `Hello Bob, how are you?`); }); it('rejects input not matching the schema', async () => { - await runWithRegistry(registry, async () => { - const invalidSchemaPrompt = defineDotprompt( - { - name: 'invalidInput', - model: 'echo', - input: { - jsonSchema: { - properties: { foo: { type: 'boolean' } }, - required: ['foo'], - }, + const invalidSchemaPrompt = defineDotprompt( + registry, + { + name: 'invalidInput', + model: 'echo', + input: { + jsonSchema: { + properties: { foo: { type: 'boolean' } }, + required: ['foo'], }, }, - `You asked for {{foo}}.` - ); + }, + `You asked for {{foo}}.` + ); - await assert.rejects(async () => { - await invalidSchemaPrompt.generate({ input: { foo: 'baz' } }); - }, ValidationError); - }); + await assert.rejects(async () => { + await invalidSchemaPrompt.generate({ input: { foo: 'baz' } }); + }, ValidationError); }); }); describe('#toJSON', () => { it('should convert zod to json schema', () => { - runWithRegistry(registry, () => { - const schema = z.object({ name: z.string() }); - const model = defineModel( - { name: 'echo', supports: { tools: true } }, - async (input) => ({ - message: input.messages[0], - finishReason: 'stop', - }) - ); - const prompt = testPrompt(model, `hello {{name}}`, { - input: { schema }, - }); - - assert.deepStrictEqual( - prompt.toJSON().input?.schema, - toJsonSchema({ schema }) - ); + const schema = z.object({ name: z.string() }); + const model = defineModel( + registry, + { name: 'echo', supports: { tools: true } }, + async (input) => ({ + message: input.messages[0], + finishReason: 'stop', + }) + ); + const prompt = testPrompt(registry, model, `hello {{name}}`, { + input: { schema }, }); + + assert.deepStrictEqual( + prompt.toJSON().input?.schema, + toJsonSchema({ schema }) + ); }); }); @@ -230,6 +241,7 @@ describe('Prompt', () => { assert.throws( () => { Dotprompt.parse( + registry, 'example', `--- input: { @@ -247,6 +259,7 @@ This is the rest of the prompt` it('should parse picoschema', () => { const p = Dotprompt.parse( + registry, 'example', `--- input: @@ -277,54 +290,53 @@ output: }); it('should use registered schemas', () => { - runWithRegistry(registry, () => { - const MyInput = defineSchema('MyInput', z.number()); - defineJsonSchema('MyOutput', { type: 'boolean' }); + const MyInput = defineSchema(registry, 'MyInput', z.number()); + defineJsonSchema(registry, 'MyOutput', { type: 'boolean' }); - const p = Dotprompt.parse( - 'example2', - `--- + const p = Dotprompt.parse( + registry, + 'example2', + `--- input: schema: MyInput output: schema: MyOutput ---` - ); + ); - assert.deepEqual(p.input, { schema: MyInput }); - assert.deepEqual(p.output, { jsonSchema: { type: 'boolean' } }); - }); + assert.deepEqual(p.input, { schema: MyInput }); + assert.deepEqual(p.output, { jsonSchema: { type: 'boolean' } }); }); }); describe('defineDotprompt', () => { it('registers a prompt and its variant', async () => { - await runWithRegistry(registry, async () => { - defineDotprompt( - { - name: 'promptName', - model: 'echo', - }, - `This is a prompt.` - ); - - defineDotprompt( - { - name: 'promptName', - variant: 'variantName', - model: 'echo', - }, - `And this is its variant.` - ); - - const basePrompt = await prompt('promptName'); - assert.equal('This is a prompt.', basePrompt.template); + defineDotprompt( + registry, + { + name: 'promptName', + model: 'echo', + }, + `This is a prompt.` + ); - const variantPrompt = await prompt('promptName', { + defineDotprompt( + registry, + { + name: 'promptName', variant: 'variantName', - }); - assert.equal('And this is its variant.', variantPrompt.template); + model: 'echo', + }, + `And this is its variant.` + ); + + const basePrompt = await prompt(registry, 'promptName'); + assert.equal('This is a prompt.', basePrompt.template); + + const variantPrompt = await prompt(registry, 'promptName', { + variant: 'variantName', }); + assert.equal('And this is its variant.', variantPrompt.template); }); }); }); @@ -336,159 +348,153 @@ describe('DotpromptRef', () => { }); it('Should load a prompt correctly', async () => { - await runWithRegistry(registry, async () => { - defineDotprompt( - { - name: 'promptName', - model: 'echo', - }, - `This is a prompt.` - ); + defineDotprompt( + registry, + { + name: 'promptName', + model: 'echo', + }, + `This is a prompt.` + ); - const ref = promptRef('promptName'); + const ref = promptRef('promptName'); - const p = await ref.loadPrompt(); + const p = await ref.loadPrompt(registry); - const isDotprompt = p instanceof Dotprompt; + const isDotprompt = p instanceof Dotprompt; - assert.equal(isDotprompt, true); - assert.equal(p.template, 'This is a prompt.'); - }); + assert.equal(isDotprompt, true); + assert.equal(p.template, 'This is a prompt.'); }); it('Should generate output correctly using DotpromptRef', async () => { - await runWithRegistry(registry, async () => { - const model = defineModel( - { name: 'echo', supports: { tools: true } }, - async (input) => ({ - message: input.messages[0], - finishReason: 'stop', - }) - ); - defineDotprompt( - { - name: 'generatePrompt', - model: 'echo', - }, - `Hello {{name}}, this is a test prompt.` - ); - - const ref = promptRef('generatePrompt'); - const response = await ref.generate({ input: { name: 'Alice' } }); - - assert.equal(response.text, 'Hello Alice, this is a test prompt.'); - }); + const model = defineModel( + registry, + { name: 'echo', supports: { tools: true } }, + async (input) => ({ + message: input.messages[0], + finishReason: 'stop', + }) + ); + defineDotprompt( + registry, + { + name: 'generatePrompt', + model: 'echo', + }, + `Hello {{name}}, this is a test prompt.` + ); + + const ref = promptRef('generatePrompt'); + const response = await ref.generate(registry, { input: { name: 'Alice' } }); + + assert.equal(response.text, 'Hello Alice, this is a test prompt.'); }); it('Should render correctly using DotpromptRef', async () => { - await runWithRegistry(registry, async () => { - defineDotprompt( - { - name: 'renderPrompt', - model: 'echo', - }, - `Hi {{name}}, welcome to the system.` - ); - - const ref = promptRef('renderPrompt'); - const rendered = await ref.render({ input: { name: 'Bob' } }); - - assert.deepStrictEqual(rendered.prompt, [ - { text: 'Hi Bob, welcome to the system.' }, - ]); - }); + defineDotprompt( + registry, + { + name: 'renderPrompt', + model: 'echo', + }, + `Hi {{name}}, welcome to the system.` + ); + + const ref = promptRef('renderPrompt'); + const rendered = await ref.render(registry, { input: { name: 'Bob' } }); + + assert.deepStrictEqual(rendered.prompt, [ + { text: 'Hi Bob, welcome to the system.' }, + ]); }); it('Should handle invalid schema input in DotpromptRef', async () => { - await runWithRegistry(registry, async () => { - defineDotprompt( - { - name: 'invalidSchemaPromptRef', - model: 'echo', - input: { - jsonSchema: { - properties: { foo: { type: 'boolean' } }, - required: ['foo'], - }, + defineDotprompt( + registry, + { + name: 'invalidSchemaPromptRef', + model: 'echo', + input: { + jsonSchema: { + properties: { foo: { type: 'boolean' } }, + required: ['foo'], }, }, - `This is the prompt with foo={{foo}}.` - ); + }, + `This is the prompt with foo={{foo}}.` + ); - const ref = promptRef('invalidSchemaPromptRef'); + const ref = promptRef('invalidSchemaPromptRef'); - await assert.rejects(async () => { - await ref.generate({ input: { foo: 'not_a_boolean' } }); - }, ValidationError); - }); + await assert.rejects(async () => { + await ref.generate(registry, { input: { foo: 'not_a_boolean' } }); + }, ValidationError); }); it('Should support streamingCallback in DotpromptRef', async () => { - await runWithRegistry(registry, async () => { - defineDotprompt( - { - name: 'streamingCallbackPrompt', - model: 'echo', - }, - `Hello {{name}}, streaming test.` - ); - - const ref = promptRef('streamingCallbackPrompt'); - - const streamingCallback = (chunk) => console.log(chunk); - const options = { - input: { name: 'Charlie' }, - streamingCallback, - returnToolRequests: true, - }; - - const rendered = await ref.render(options); - - assert.strictEqual(rendered.streamingCallback, streamingCallback); - assert.strictEqual(rendered.returnToolRequests, true); - }); + defineDotprompt( + registry, + { + name: 'streamingCallbackPrompt', + model: 'echo', + }, + `Hello {{name}}, streaming test.` + ); + + const ref = promptRef('streamingCallbackPrompt'); + + const streamingCallback = (chunk) => console.log(chunk); + const options = { + input: { name: 'Charlie' }, + streamingCallback, + returnToolRequests: true, + }; + + const rendered = await ref.render(registry, options); + + assert.strictEqual(rendered.streamingCallback, streamingCallback); + assert.strictEqual(rendered.returnToolRequests, true); }); it('Should cache loaded prompt in DotpromptRef', async () => { - await runWithRegistry(registry, async () => { - defineDotprompt( - { - name: 'cacheTestPrompt', - model: 'echo', - }, - `This is a prompt for cache test.` - ); - - const ref = promptRef('cacheTestPrompt'); - const firstLoad = await ref.loadPrompt(); - const secondLoad = await ref.loadPrompt(); - - assert.strictEqual( - firstLoad, - secondLoad, - 'Loaded prompts should be identical (cached).' - ); - }); + defineDotprompt( + registry, + { + name: 'cacheTestPrompt', + model: 'echo', + }, + `This is a prompt for cache test.` + ); + + const ref = promptRef('cacheTestPrompt'); + const firstLoad = await ref.loadPrompt(registry); + const secondLoad = await ref.loadPrompt(registry); + + assert.strictEqual( + firstLoad, + secondLoad, + 'Loaded prompts should be identical (cached).' + ); }); it('should render system prompt', () => { - runWithRegistry(registry, () => { - const model = defineModel( - { name: 'echo', supports: { tools: true } }, - async (input) => ({ - message: input.messages[0], - finishReason: 'stop', - }) - ); - const prompt = testPrompt(model, `{{ role "system"}} hi`); - - const rendered = prompt.render({ input: {} }); - assert.deepStrictEqual(rendered.messages, [ - { - content: [{ text: ' hi' }], - role: 'system', - }, - ]); - }); + const model = defineModel( + registry, + { name: 'echo', supports: { tools: true } }, + async (input) => ({ + message: input.messages[0], + finishReason: 'stop', + }) + ); + const prompt = testPrompt(registry, model, `{{ role "system"}} hi`); + + const rendered = prompt.render({ input: {} }); + assert.deepStrictEqual(rendered.messages, [ + { + content: [{ text: ' hi' }], + role: 'system', + }, + ]); }); }); diff --git a/js/plugins/evaluators/src/metrics/answer_relevancy.ts b/js/plugins/evaluators/src/metrics/answer_relevancy.ts index 74e91e743..833a5ce33 100644 --- a/js/plugins/evaluators/src/metrics/answer_relevancy.ts +++ b/js/plugins/evaluators/src/metrics/answer_relevancy.ts @@ -47,6 +47,7 @@ export async function answerRelevancyScore< throw new Error('Output was not provided'); } const prompt = await loadPromptFile( + ai.registry, path.resolve(getDirName(), '../../prompts/answer_relevancy.prompt') ); const response = await ai.generate({ diff --git a/js/plugins/evaluators/src/metrics/faithfulness.ts b/js/plugins/evaluators/src/metrics/faithfulness.ts index 3b3ed9e0e..244d0f10a 100644 --- a/js/plugins/evaluators/src/metrics/faithfulness.ts +++ b/js/plugins/evaluators/src/metrics/faithfulness.ts @@ -54,6 +54,7 @@ export async function faithfulnessScore< throw new Error('Output was not provided'); } const longFormPrompt = await loadPromptFile( + ai.registry, path.resolve(getDirName(), '../../prompts/faithfulness_long_form.prompt') ); const longFormResponse = await ai.generate({ @@ -75,6 +76,7 @@ export async function faithfulnessScore< const allStatements = statements.map((s) => `statement: ${s}`).join('\n'); const allContext = context.join('\n'); const nliPrompt = await loadPromptFile( + ai.registry, path.resolve(getDirName(), '../../prompts/faithfulness_nli.prompt') ); const response = await ai.generate({ diff --git a/js/plugins/evaluators/src/metrics/maliciousness.ts b/js/plugins/evaluators/src/metrics/maliciousness.ts index 048b7a9bb..5538cbc25 100644 --- a/js/plugins/evaluators/src/metrics/maliciousness.ts +++ b/js/plugins/evaluators/src/metrics/maliciousness.ts @@ -39,6 +39,7 @@ export async function maliciousnessScore< } const prompt = await loadPromptFile( + ai.registry, path.resolve(getDirName(), '../../prompts/maliciousness.prompt') ); //TODO: safetySettings are gemini specific - pull these out so they are tied to the LLM diff --git a/js/plugins/firebase/src/functions.ts b/js/plugins/firebase/src/functions.ts index 1e0e64870..89248274f 100644 --- a/js/plugins/firebase/src/functions.ts +++ b/js/plugins/firebase/src/functions.ts @@ -131,7 +131,7 @@ function wrapHttpsFlow< } await config.authPolicy.provider(req, res, () => - flow.expressHandler(genkit.registry, req, res) + flow.expressHandler(req, res) ); } ); diff --git a/js/plugins/google-cloud/tests/metrics_test.ts b/js/plugins/google-cloud/tests/metrics_test.ts index f25bbd9da..f9ba5c2cc 100644 --- a/js/plugins/google-cloud/tests/metrics_test.ts +++ b/js/plugins/google-cloud/tests/metrics_test.ts @@ -30,7 +30,6 @@ import { } from '@opentelemetry/sdk-metrics'; import { ReadableSpan } from '@opentelemetry/sdk-trace-base'; import { GenerateResponseData, Genkit, genkit, run, z } from 'genkit'; -import { runWithRegistry } from 'genkit/registry'; import { SPAN_TYPE_ATTR, appendSpan } from 'genkit/tracing'; import assert from 'node:assert'; import { after, before, beforeEach, describe, it } from 'node:test'; @@ -188,10 +187,8 @@ describe('GoogleCloudMetrics', () => { it('writes feature metrics for an action', async () => { const testAction = createAction(ai, 'featureAction'); - await runWithRegistry(ai.registry, async () => { - await testAction(null); - await testAction(null); - }); + await testAction(null); + await testAction(null); await getExportedSpans(); @@ -213,11 +210,9 @@ describe('GoogleCloudMetrics', () => { // after PR #1029 it('writes feature metrics for generate', async () => { - await runWithRegistry(ai.registry, async () => { - const testModel = createTestModel(ai, 'helloModel'); - await ai.generate({ model: testModel, prompt: 'Hi' }); - await ai.generate({ model: testModel, prompt: 'Yo' }); - }); + const testModel = createTestModel(ai, 'helloModel'); + await ai.generate({ model: testModel, prompt: 'Hi' }); + await ai.generate({ model: testModel, prompt: 'Yo' }); const spans = await getExportedSpans(); @@ -263,9 +258,7 @@ describe('GoogleCloudMetrics', () => { }); assert.rejects(async () => { - return await runWithRegistry(ai.registry, async () => { - return testAction(null); - }); + return testAction(null); }); await getExportedSpans(); @@ -416,9 +409,7 @@ describe('GoogleCloudMetrics', () => { }); }); - await runWithRegistry(ai.registry, async () => { - testAction(null); - }); + testAction(null); await getExportedSpans(); @@ -906,13 +897,11 @@ describe('GoogleCloudMetrics', () => { name: string, fn: () => Promise = async () => {} ) { - return runWithRegistry(ai.registry, () => - ai.defineFlow( - { - name, - }, - fn - ) + return ai.defineFlow( + { + name, + }, + fn ); } @@ -923,9 +912,7 @@ describe('GoogleCloudMetrics', () => { name: string, respFn: () => Promise ) { - return runWithRegistry(ai.registry, () => - ai.defineModel({ name }, (req) => respFn()) - ); + return ai.defineModel({ name }, (req) => respFn()); } function createTestModel(ai: Genkit, name: string) { diff --git a/js/plugins/google-cloud/tests/traces_test.ts b/js/plugins/google-cloud/tests/traces_test.ts index b4b687ed4..298002acb 100644 --- a/js/plugins/google-cloud/tests/traces_test.ts +++ b/js/plugins/google-cloud/tests/traces_test.ts @@ -16,7 +16,6 @@ import { ReadableSpan } from '@opentelemetry/sdk-trace-base'; import { Genkit, genkit, run, z } from 'genkit'; -import { runWithRegistry } from 'genkit/registry'; import { appendSpan } from 'genkit/tracing'; import assert from 'node:assert'; import { after, before, beforeEach, describe, it } from 'node:test'; @@ -135,29 +134,27 @@ describe('GoogleCloudTracing', () => { }); it('adds the genkit/model label for model actions', async () => { - const echoModel = runWithRegistry(ai.registry, () => - ai.defineModel( - { - name: 'echoModel', - }, - async (request) => { - return { - message: { - role: 'model', - content: [ - { - text: - 'Echo: ' + - request.messages - .map((m) => m.content.map((c) => c.text).join()) - .join(), - }, - ], - }, - finishReason: 'stop', - }; - } - ) + const echoModel = ai.defineModel( + { + name: 'echoModel', + }, + async (request) => { + return { + message: { + role: 'model', + content: [ + { + text: + 'Echo: ' + + request.messages + .map((m) => m.content.map((c) => c.text).join()) + .join(), + }, + ], + }, + finishReason: 'stop', + }; + } ); const testFlow = createFlow(ai, 'modelFlow', async () => { return run('runFlow', async () => { diff --git a/js/testapps/byo-evaluator/src/deliciousness/deliciousness.ts b/js/testapps/byo-evaluator/src/deliciousness/deliciousness.ts index d3750cb84..1c887d50d 100644 --- a/js/testapps/byo-evaluator/src/deliciousness/deliciousness.ts +++ b/js/testapps/byo-evaluator/src/deliciousness/deliciousness.ts @@ -40,6 +40,7 @@ export async function deliciousnessScore< throw new Error('Output is required for Funniness detection'); } const finalPrompt = await loadPromptFile( + ai.registry, path.resolve(__dirname, '../../prompts/deliciousness.prompt') ); const response = await ai.generate({ diff --git a/js/testapps/byo-evaluator/src/funniness/funniness.ts b/js/testapps/byo-evaluator/src/funniness/funniness.ts index 3f38f0e1e..e1a1df5cf 100644 --- a/js/testapps/byo-evaluator/src/funniness/funniness.ts +++ b/js/testapps/byo-evaluator/src/funniness/funniness.ts @@ -42,6 +42,7 @@ export async function funninessScore( throw new Error('Output is required for Funniness detection'); } const finalPrompt = await loadPromptFile( + ai.registry, path.resolve(__dirname, '../../prompts/funniness.prompt') ); diff --git a/js/testapps/byo-evaluator/src/pii/pii_detection.ts b/js/testapps/byo-evaluator/src/pii/pii_detection.ts index b9d296f5d..d0079fdd1 100644 --- a/js/testapps/byo-evaluator/src/pii/pii_detection.ts +++ b/js/testapps/byo-evaluator/src/pii/pii_detection.ts @@ -37,6 +37,7 @@ export async function piiDetectionScore< throw new Error('Output is required for PII detection'); } const finalPrompt = await loadPromptFile( + ai.registry, path.resolve(__dirname, '../../prompts/pii_detection.prompt') ); diff --git a/js/testapps/cat-eval/src/pdf_rag_firebase.ts b/js/testapps/cat-eval/src/pdf_rag_firebase.ts index 9b42954e6..469b1e925 100644 --- a/js/testapps/cat-eval/src/pdf_rag_firebase.ts +++ b/js/testapps/cat-eval/src/pdf_rag_firebase.ts @@ -22,7 +22,6 @@ import { initializeApp } from 'firebase-admin/app'; import { getFirestore } from 'firebase-admin/firestore'; import { readFile } from 'fs/promises'; import { run, z } from 'genkit'; -import { runWithRegistry } from 'genkit/registry'; import { chunk } from 'llm-chunk'; import path from 'path'; import pdf from 'pdf-parse'; @@ -58,17 +57,15 @@ Question: ${question} Helpful Answer:`; } -export const pdfChatRetrieverFirebase = runWithRegistry(ai.registry, () => - defineFirestoreRetriever(ai, { - name: 'pdfChatRetrieverFirebase', - firestore, - collection: 'pdf-qa', - contentField: 'facts', - vectorField: 'embedding', - embedder: textEmbeddingGecko, - distanceMeasure: 'COSINE', - }) -); +export const pdfChatRetrieverFirebase = defineFirestoreRetriever(ai, { + name: 'pdfChatRetrieverFirebase', + firestore, + collection: 'pdf-qa', + contentField: 'facts', + vectorField: 'embedding', + embedder: textEmbeddingGecko, + distanceMeasure: 'COSINE', +}); // Define a simple RAG flow, we will evaluate this flow export const pdfQAFirebase = ai.defineFlow( diff --git a/js/testapps/flow-simple-ai/src/index.ts b/js/testapps/flow-simple-ai/src/index.ts index f5d58718c..29ff3cf45 100644 --- a/js/testapps/flow-simple-ai/src/index.ts +++ b/js/testapps/flow-simple-ai/src/index.ts @@ -32,7 +32,6 @@ import { AlwaysOnSampler } from '@opentelemetry/sdk-trace-base'; import { initializeApp } from 'firebase-admin/app'; import { getFirestore } from 'firebase-admin/firestore'; import { MessageSchema, genkit, run, z } from 'genkit'; -import { runWithRegistry } from 'genkit/registry'; import { Allow, parse } from 'partial-json'; enableGoogleCloudTelemetry({ @@ -274,16 +273,14 @@ export const multimodalFlow = ai.defineFlow( } ); -const destinationsRetriever = runWithRegistry(ai.registry, () => - defineFirestoreRetriever(ai, { - name: 'destinationsRetriever', - firestore: getFirestore(app), - collection: 'destinations', - contentField: 'knownFor', - embedder: textEmbeddingGecko, - vectorField: 'embedding', - }) -); +const destinationsRetriever = defineFirestoreRetriever(ai, { + name: 'destinationsRetriever', + firestore: getFirestore(app), + collection: 'destinations', + contentField: 'knownFor', + embedder: textEmbeddingGecko, + vectorField: 'embedding', +}); export const searchDestinations = ai.defineFlow( { diff --git a/js/testapps/model-tester/src/index.ts b/js/testapps/model-tester/src/index.ts index 5d874c785..cc9fb5361 100644 --- a/js/testapps/model-tester/src/index.ts +++ b/js/testapps/model-tester/src/index.ts @@ -44,7 +44,7 @@ export const ai = genkit({ ], }); -testModels([ +testModels(ai.registry, [ 'googleai/gemini-1.5-pro-latest', 'googleai/gemini-1.5-flash-latest', 'vertexai/gemini-1.5-pro',