Skip to content

Commit

Permalink
breaking: got rid of runWithRegistry and most vestiges of global regi…
Browse files Browse the repository at this point in the history
…stry
  • Loading branch information
pavelgj committed Oct 22, 2024
1 parent eab4bb1 commit d822d00
Show file tree
Hide file tree
Showing 40 changed files with 1,030 additions and 1,312 deletions.
28 changes: 17 additions & 11 deletions js/ai/src/embedder.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
*/

import { Action, defineAction, z } from '@genkit-ai/core';
import { lookupAction } from '@genkit-ai/core/registry';
import { Registry } from '@genkit-ai/core/registry';
import { Document, DocumentData, DocumentDataSchema } from './document.js';

export type EmbeddingBatch = { embedding: number[] }[];
Expand Down Expand Up @@ -68,6 +68,7 @@ function withMetadata<CustomOptions extends z.ZodTypeAny>(
export function defineEmbedder<
ConfigSchema extends z.ZodTypeAny = z.ZodTypeAny,
>(
registry: Registry,
options: {
name: string;
configSchema?: ConfigSchema;
Expand All @@ -76,6 +77,7 @@ export function defineEmbedder<
runner: EmbedderFn<ConfigSchema>
) {
const embedder = defineAction(
registry,
{
actionType: 'embedder',
name: options.name,
Expand Down Expand Up @@ -111,13 +113,14 @@ export type EmbedderArgument<
* A veneer for interacting with embedder models.
*/
export async function embed<CustomOptions extends z.ZodTypeAny = z.ZodTypeAny>(
registry: Registry,
params: EmbedderParams<CustomOptions>
): Promise<Embedding> {
let embedder: EmbedderAction<CustomOptions>;
if (typeof params.embedder === 'string') {
embedder = await lookupAction(`/embedder/${params.embedder}`);
embedder = await registry.lookupAction(`/embedder/${params.embedder}`);
} else if (Object.hasOwnProperty.call(params.embedder, 'info')) {
embedder = await lookupAction(
embedder = await registry.lookupAction(
`/embedder/${(params.embedder as EmbedderReference).name}`
);
} else {
Expand All @@ -141,17 +144,20 @@ export async function embed<CustomOptions extends z.ZodTypeAny = z.ZodTypeAny>(
*/
export async function embedMany<
ConfigSchema extends z.ZodTypeAny = z.ZodTypeAny,
>(params: {
embedder: EmbedderArgument<ConfigSchema>;
content: string[] | DocumentData[];
metadata?: Record<string, unknown>;
options?: z.infer<ConfigSchema>;
}): Promise<EmbeddingBatch> {
>(
registry: Registry,
params: {
embedder: EmbedderArgument<ConfigSchema>;
content: string[] | DocumentData[];
metadata?: Record<string, unknown>;
options?: z.infer<ConfigSchema>;
}
): Promise<EmbeddingBatch> {
let embedder: EmbedderAction<ConfigSchema>;
if (typeof params.embedder === 'string') {
embedder = await lookupAction(`/embedder/${params.embedder}`);
embedder = await registry.lookupAction(`/embedder/${params.embedder}`);
} else if (Object.hasOwnProperty.call(params.embedder, 'info')) {
embedder = await lookupAction(
embedder = await registry.lookupAction(
`/embedder/${(params.embedder as EmbedderReference).name}`
);
} else {
Expand Down
15 changes: 11 additions & 4 deletions js/ai/src/evaluator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import { Action, defineAction, z } from '@genkit-ai/core';
import { logger } from '@genkit-ai/core/logging';
import { lookupAction } from '@genkit-ai/core/registry';
import { Registry } from '@genkit-ai/core/registry';
import { SPAN_TYPE_ATTR, runInNewSpan } from '@genkit-ai/core/tracing';
import { randomUUID } from 'crypto';

Expand Down Expand Up @@ -127,6 +127,7 @@ export function defineEvaluator<
typeof BaseEvalDataPointSchema = typeof BaseEvalDataPointSchema,
EvaluatorOptions extends z.ZodTypeAny = z.ZodTypeAny,
>(
registry: Registry,
options: {
name: string;
displayName: string;
Expand All @@ -143,6 +144,7 @@ export function defineEvaluator<
metadata[EVALUATOR_METADATA_KEY_DISPLAY_NAME] = options.displayName;
metadata[EVALUATOR_METADATA_KEY_DEFINITION] = options.definition;
const evaluator = defineAction(
registry,
{
actionType: 'evaluator',
name: options.name,
Expand Down Expand Up @@ -239,12 +241,17 @@ export type EvaluatorArgument<
export async function evaluate<
DataPoint extends typeof BaseDataPointSchema = typeof BaseDataPointSchema,
CustomOptions extends z.ZodTypeAny = z.ZodTypeAny,
>(params: EvaluatorParams<DataPoint, CustomOptions>): Promise<EvalResponses> {
>(
registry: Registry,
params: EvaluatorParams<DataPoint, CustomOptions>
): Promise<EvalResponses> {
let evaluator: EvaluatorAction<DataPoint, CustomOptions>;
if (typeof params.evaluator === 'string') {
evaluator = await lookupAction(`/evaluator/${params.evaluator}`);
evaluator = await registry.lookupAction(`/evaluator/${params.evaluator}`);
} else if (Object.hasOwnProperty.call(params.evaluator, 'info')) {
evaluator = await lookupAction(`/evaluator/${params.evaluator.name}`);
evaluator = await registry.lookupAction(
`/evaluator/${params.evaluator.name}`
);
} else {
evaluator = params.evaluator as EvaluatorAction<DataPoint, CustomOptions>;
}
Expand Down
28 changes: 19 additions & 9 deletions js/ai/src/generate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import {
StreamingCallback,
z,
} from '@genkit-ai/core';
import { lookupAction } from '@genkit-ai/core/registry';
import { Registry } from '@genkit-ai/core/registry';
import { parseSchema, toJsonSchema } from '@genkit-ai/core/schema';
import { DocumentData } from './document.js';
import { extractJson } from './extract.js';
Expand Down Expand Up @@ -365,6 +365,7 @@ export class GenerateResponseChunk<T = unknown>
}

export async function toGenerateRequest(
registry: Registry,
options: GenerateOptions
): Promise<GenerateRequest> {
const messages: MessageData[] = [];
Expand Down Expand Up @@ -402,7 +403,7 @@ export async function toGenerateRequest(
}
let tools: Action<any, any>[] | undefined;
if (options.tools) {
tools = await resolveTools(options.tools);
tools = await resolveTools(registry, options.tools);
}

const out = {
Expand Down Expand Up @@ -464,21 +465,28 @@ interface ResolvedModel<CustomOptions extends z.ZodTypeAny = z.ZodTypeAny> {
version?: string;
}

async function resolveModel(options: GenerateOptions): Promise<ResolvedModel> {
async function resolveModel(
registry: Registry,
options: GenerateOptions
): Promise<ResolvedModel> {
let model = options.model;
if (!model) {
throw new Error('Model is required.');
}
if (typeof model === 'string') {
return {
modelAction: (await lookupAction(`/model/${model}`)) as ModelAction,
modelAction: (await registry.lookupAction(
`/model/${model}`
)) as ModelAction,
};
} else if (model.hasOwnProperty('__action')) {
return { modelAction: model as ModelAction };
} else {
const ref = model as ModelReference<any>;
return {
modelAction: (await lookupAction(`/model/${ref.name}`)) as ModelAction,
modelAction: (await registry.lookupAction(
`/model/${ref.name}`
)) as ModelAction,
config: {
...ref.config,
},
Expand Down Expand Up @@ -525,13 +533,14 @@ export async function generate<
O extends z.ZodTypeAny = z.ZodTypeAny,
CustomOptions extends z.ZodTypeAny = typeof GenerationCommonConfigSchema,
>(
registry: Registry,
options:
| GenerateOptions<O, CustomOptions>
| PromiseLike<GenerateOptions<O, CustomOptions>>
): Promise<GenerateResponse<z.infer<O>>> {
const resolvedOptions: GenerateOptions<O, CustomOptions> =
await Promise.resolve(options);
const resolvedModel = await resolveModel(resolvedOptions);
const resolvedModel = await resolveModel(registry, resolvedOptions);
const model = resolvedModel.modelAction;
if (!model) {
let modelId: string;
Expand Down Expand Up @@ -623,8 +632,8 @@ export async function generate<
resolvedOptions.streamingCallback,
async () =>
new GenerateResponse<O>(
await generateHelper(params, resolvedOptions.use),
await toGenerateRequest(resolvedOptions)
await generateHelper(registry, params, resolvedOptions.use),
await toGenerateRequest(registry, resolvedOptions)
)
);
}
Expand Down Expand Up @@ -653,6 +662,7 @@ export async function generateStream<
O extends z.ZodTypeAny = z.ZodTypeAny,
CustomOptions extends z.ZodTypeAny = typeof GenerationCommonConfigSchema,
>(
registry: Registry,
options:
| GenerateOptions<O, CustomOptions>
| PromiseLike<GenerateOptions<O, CustomOptions>>
Expand All @@ -678,7 +688,7 @@ export async function generateStream<
}

try {
generate<O, CustomOptions>({
generate<O, CustomOptions>(registry, {
...options,
streamingCallback: (chunk) => {
firstChunkSent = true;
Expand Down
12 changes: 7 additions & 5 deletions js/ai/src/generateAction.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import {
runWithStreamingCallback,
z,
} from '@genkit-ai/core';
import { lookupAction } from '@genkit-ai/core/registry';
import { Registry } from '@genkit-ai/core/registry';
import { toJsonSchema, validateSchema } from '@genkit-ai/core/schema';
import { runInNewSpan, SPAN_TYPE_ATTR } from '@genkit-ai/core/tracing';
import * as clc from 'colorette';
Expand Down Expand Up @@ -70,6 +70,7 @@ export const GenerateUtilParamSchema = z.object({
* Encapsulates all generate logic. This is similar to `generateAction` except not an action and can take middleware.
*/
export async function generateHelper(
registry: Registry,
input: z.infer<typeof GenerateUtilParamSchema>,
middleware?: Middleware[]
): Promise<GenerateResponseData> {
Expand All @@ -86,18 +87,19 @@ export async function generateHelper(
async (metadata) => {
metadata.name = 'generate';
metadata.input = input;
const output = await generate(input, middleware);
const output = await generate(registry, input, middleware);
metadata.output = JSON.stringify(output);
return output;
}
);
}

async function generate(
registry: Registry,
rawRequest: z.infer<typeof GenerateUtilParamSchema>,
middleware?: Middleware[]
): Promise<GenerateResponseData> {
const model = (await lookupAction(
const model = (await registry.lookupAction(
`/model/${rawRequest.model}`
)) as ModelAction;
if (!model) {
Expand All @@ -120,7 +122,7 @@ async function generate(
tools = await Promise.all(
rawRequest.tools.map(async (toolRef) => {
if (typeof toolRef === 'string') {
const tool = (await lookupAction(toolRef)) as ToolAction;
const tool = (await registry.lookupAction(toolRef)) as ToolAction;
if (!tool) {
throw new Error(`Tool ${toolRef} not found`);
}
Expand Down Expand Up @@ -203,7 +205,7 @@ async function generate(
messages: [...request.messages, message],
prompt: toolResponses,
};
return await generateHelper(nextRequest, middleware);
return await generateHelper(registry, nextRequest, middleware);
}

async function actionToGenerateRequest(
Expand Down
3 changes: 3 additions & 0 deletions js/ai/src/model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import {
StreamingCallback,
z,
} from '@genkit-ai/core';
import { Registry } from '@genkit-ai/core/registry';
import { toJsonSchema } from '@genkit-ai/core/schema';
import { performance } from 'node:perf_hooks';
import { DocumentDataSchema } from './document.js';
Expand Down Expand Up @@ -330,6 +331,7 @@ export type DefineModelOptions<
export function defineModel<
CustomOptionsSchema extends z.ZodTypeAny = z.ZodTypeAny,
>(
registry: Registry,
options: DefineModelOptions<CustomOptionsSchema>,
runner: (
request: GenerateRequest<CustomOptionsSchema>,
Expand All @@ -344,6 +346,7 @@ export function defineModel<
if (!options?.supports?.context) middleware.push(augmentWithContext());
middleware.push(conformOutput());
const act = defineAction(
registry,
{
actionType: 'model',
name: options.name,
Expand Down
23 changes: 14 additions & 9 deletions js/ai/src/prompt.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
*/

import { Action, defineAction, JSONSchema7, z } from '@genkit-ai/core';
import { lookupAction } from '@genkit-ai/core/registry';
import { Registry } from '@genkit-ai/core/registry';
import { DocumentData } from './document.js';
import {
GenerateOptions,
Expand Down Expand Up @@ -150,10 +150,12 @@ export interface ExecutablePrompt<
* @returns The new `PromptAction`.
*/
export function definePrompt<I extends z.ZodTypeAny>(
registry: Registry,
config: PromptConfig<I>,
fn: PromptFn<I>
): PromptAction<I> {
const a = defineAction(
registry,
{
...config,
actionType: 'prompt',
Expand All @@ -177,16 +179,19 @@ export async function renderPrompt<
I extends z.ZodTypeAny = z.ZodTypeAny,
O extends z.ZodTypeAny = z.ZodTypeAny,
CustomOptions extends z.ZodTypeAny = z.ZodTypeAny,
>(params: {
prompt: PromptArgument<I>;
input: z.infer<I>;
docs?: DocumentData[];
model: ModelArgument<CustomOptions>;
config?: z.infer<CustomOptions>;
}): Promise<GenerateOptions<O, CustomOptions>> {
>(
registry: Registry,
params: {
prompt: PromptArgument<I>;
input: z.infer<I>;
docs?: DocumentData[];
model: ModelArgument<CustomOptions>;
config?: z.infer<CustomOptions>;
}
): Promise<GenerateOptions<O, CustomOptions>> {
let prompt: PromptAction<I>;
if (typeof params.prompt === 'string') {
prompt = await lookupAction(`/prompt/${params.prompt}`);
prompt = await registry.lookupAction(`/prompt/${params.prompt}`);
} else {
prompt = params.prompt as PromptAction<I>;
}
Expand Down
9 changes: 6 additions & 3 deletions js/ai/src/reranker.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
*/

import { Action, defineAction, z } from '@genkit-ai/core';
import { lookupAction } from '@genkit-ai/core/registry';
import { Registry } from '@genkit-ai/core/registry';
import { Part, PartSchema } from './document.js';
import { Document, DocumentData, DocumentDataSchema } from './retriever.js';

Expand Down Expand Up @@ -101,6 +101,7 @@ function rerankerWithMetadata<
* Creates a reranker action for the provided {@link RerankerFn} implementation.
*/
export function defineReranker<OptionsType extends z.ZodTypeAny = z.ZodTypeAny>(
registry: Registry,
options: {
name: string;
configSchema?: OptionsType;
Expand All @@ -109,6 +110,7 @@ export function defineReranker<OptionsType extends z.ZodTypeAny = z.ZodTypeAny>(
runner: RerankerFn<OptionsType>
) {
const reranker = defineAction(
registry,
{
actionType: 'reranker',
name: options.name,
Expand Down Expand Up @@ -157,13 +159,14 @@ export type RerankerArgument<
* Reranks documents from a {@link RerankerArgument} based on the provided query.
*/
export async function rerank<CustomOptions extends z.ZodTypeAny>(
registry: Registry,
params: RerankerParams<CustomOptions>
): Promise<Array<RankedDocument>> {
let reranker: RerankerAction<CustomOptions>;
if (typeof params.reranker === 'string') {
reranker = await lookupAction(`/reranker/${params.reranker}`);
reranker = await registry.lookupAction(`/reranker/${params.reranker}`);
} else if (Object.hasOwnProperty.call(params.reranker, 'info')) {
reranker = await lookupAction(`/reranker/${params.reranker.name}`);
reranker = await registry.lookupAction(`/reranker/${params.reranker.name}`);
} else {
reranker = params.reranker as RerankerAction<CustomOptions>;
}
Expand Down
Loading

0 comments on commit d822d00

Please sign in to comment.