From aa4cc954c2304fbb77c618bd1ac0eacff25abdad Mon Sep 17 00:00:00 2001 From: Michael Doyle Date: Wed, 23 Oct 2024 14:21:36 -0400 Subject: [PATCH] Improve error handling and messaging for cloud-logging --- .github/workflows/builder.yml | 4 +- .github/workflows/e2e-tests.yml | 2 +- .github/workflows/formatter.yml | 4 +- .github/workflows/tests.yml | 4 +- README.md | 11 +- docs/dotprompt.md | 8 +- docs/evaluation.md | 2 +- docs/get-started.md | 2 +- docs/models.md | 14 +- docs/plugin-authoring-evaluator.md | 4 +- docs/plugins/google-cloud.md | 90 +- docs/prompts.md | 2 +- docs/rag.md | 6 +- docs/tool-calling.md | 2 +- .../cli/config/firebase.index.ts.template | 2 +- .../cli/config/nextjs.genkit.ts.template | 2 +- .../cli/config/nodejs.index.ts.template | 2 +- genkit-tools/cli/package.json | 4 +- genkit-tools/cli/src/commands/flow-run.ts | 10 +- genkit-tools/cli/src/commands/ui-start.ts | 19 +- genkit-tools/cli/src/utils/server-harness.ts | 25 + genkit-tools/common/package.json | 4 +- genkit-tools/common/src/eval/evaluate.ts | 22 +- .../common/src/eval/localFileDatasetStore.ts | 27 +- .../common/src/eval/localFileEvalStore.ts | 8 +- genkit-tools/common/src/manager/manager.ts | 4 + genkit-tools/common/src/server/server.ts | 11 +- genkit-tools/common/src/types/apis.ts | 12 +- genkit-tools/common/src/types/eval.ts | 13 + genkit-tools/common/src/utils/analytics.ts | 2 +- genkit-tools/common/src/utils/utils.ts | 17 +- .../tests/eval/localFileDatasetStore_test.ts | 85 +- .../tests/eval/localFileEvalStore_test.ts | 22 +- genkit-tools/package.json | 2 +- genkit-tools/pnpm-lock.yaml | 14 +- genkit-tools/telemetry-server/package.json | 2 +- js/ai/package.json | 2 +- js/ai/src/document.ts | 4 +- js/ai/src/embedder.ts | 96 +- js/ai/src/evaluator.ts | 17 +- js/ai/src/generate.ts | 190 +++- js/ai/src/generateAction.ts | 42 +- js/ai/src/index.ts | 2 + js/ai/src/model.ts | 59 +- js/ai/src/model/middleware.ts | 6 +- js/ai/src/prompt.ts | 110 ++- js/ai/src/reranker.ts | 11 +- js/ai/src/retriever.ts | 24 +- js/ai/src/testing/model-tester.ts | 58 +- js/ai/src/tool.ts | 10 +- js/ai/tests/generate/generate_test.ts | 148 ++- js/ai/tests/model/document_test.ts | 6 +- js/ai/tests/model/middleware_test.ts | 31 +- js/ai/tests/prompt/prompt_test.ts | 61 +- js/ai/tests/reranker/reranker_test.ts | 271 +++--- js/core/package.json | 2 +- js/core/src/action.ts | 29 +- js/core/src/flow-client/client.ts | 4 +- js/core/src/flow.ts | 159 ++-- js/core/src/reflection.ts | 15 +- js/core/src/registry.ts | 111 --- js/core/src/schema.ts | 13 +- js/core/src/tracing/exporter.ts | 2 +- js/core/tests/flow_test.ts | 79 +- js/core/tests/registry_test.ts | 170 +--- js/genkit/package.json | 17 +- js/genkit/src/chat.ts | 183 ++++ js/genkit/src/embedder.ts | 13 +- js/genkit/src/evaluator.ts | 21 +- js/genkit/src/extract.ts | 2 +- js/genkit/src/genkit.ts | 889 +++++++++++------- js/genkit/src/index.ts | 140 ++- js/genkit/src/logging.ts | 2 +- js/genkit/src/middleware.ts | 9 +- js/genkit/src/model.ts | 54 +- js/genkit/src/plugin.ts | 41 + js/genkit/src/registry.ts | 7 +- js/genkit/src/reranker.ts | 14 +- js/genkit/src/retriever.ts | 27 +- js/genkit/src/schema.ts | 11 +- js/genkit/src/session.ts | 241 +++++ js/genkit/src/testing.ts | 2 +- js/genkit/src/tool.ts | 8 +- js/genkit/src/tracing.ts | 34 +- js/genkit/tests/chat_test.ts | 153 +++ js/genkit/tests/embed_test.ts | 140 +++ js/genkit/tests/generate_test.ts | 158 ++++ js/genkit/tests/helpers.ts | 30 +- js/genkit/tests/models_test.ts | 75 -- js/genkit/tests/prompts_test.ts | 85 +- 
js/genkit/tests/session_test.ts | 326 +++++++ js/package.json | 2 +- js/plugins/chroma/package.json | 2 +- js/plugins/chroma/src/index.ts | 120 ++- js/plugins/dev-local-vectorstore/package.json | 2 +- js/plugins/dev-local-vectorstore/src/index.ts | 72 +- js/plugins/dotprompt/package.json | 2 +- js/plugins/dotprompt/src/index.ts | 36 +- js/plugins/dotprompt/src/metadata.ts | 21 +- js/plugins/dotprompt/src/prompt.ts | 78 +- js/plugins/dotprompt/src/registry.ts | 23 +- js/plugins/dotprompt/tests/prompt_test.ts | 595 ++++++------ js/plugins/evaluators/package.json | 3 +- js/plugins/evaluators/src/index.ts | 41 +- .../src/metrics/answer_relevancy.ts | 19 +- .../evaluators/src/metrics/faithfulness.ts | 14 +- .../evaluators/src/metrics/maliciousness.ts | 11 +- js/plugins/firebase/jest.config.ts | 2 +- js/plugins/firebase/package.json | 6 +- js/plugins/firebase/src/firestoreRetriever.ts | 63 +- js/plugins/firebase/src/functions.ts | 2 +- js/plugins/google-cloud/jest.config.ts | 48 + js/plugins/google-cloud/package.json | 12 +- js/plugins/google-cloud/src/auth.ts | 42 +- js/plugins/google-cloud/src/gcpLogger.ts | 25 + .../google-cloud/src/gcpOpenTelemetry.ts | 120 ++- js/plugins/google-cloud/src/index.ts | 2 +- .../google-cloud/src/telemetry/action.ts | 11 +- js/plugins/google-cloud/src/types.ts | 5 + js/plugins/google-cloud/src/utils.ts | 61 ++ .../google-cloud/tests/logs_no_io_test.ts | 48 +- js/plugins/google-cloud/tests/logs_test.ts | 50 +- js/plugins/google-cloud/tests/metrics_test.ts | 132 +-- js/plugins/google-cloud/tests/traces_test.ts | 86 +- js/plugins/googleai/package.json | 2 +- js/plugins/googleai/src/embedder.ts | 49 +- js/plugins/googleai/src/gemini.ts | 151 +-- js/plugins/googleai/src/index.ts | 84 +- js/plugins/langchain/package.json | 2 +- js/plugins/langchain/src/evaluators.ts | 9 +- js/plugins/langchain/src/index.ts | 23 +- js/plugins/langchain/src/model.ts | 15 +- js/plugins/ollama/package.json | 2 +- js/plugins/ollama/src/embeddings.ts | 35 +- js/plugins/ollama/src/index.ts | 40 +- .../ollama/tests/embedding_live_test.ts | 7 +- js/plugins/ollama/tests/embeddings_test.ts | 124 +-- js/plugins/pinecone/package.json | 2 +- js/plugins/pinecone/src/index.ts | 74 +- js/plugins/vertexai/package.json | 4 +- js/plugins/vertexai/src/anthropic.ts | 5 +- js/plugins/vertexai/src/embedder.ts | 125 +-- js/plugins/vertexai/src/evaluation.ts | 35 +- js/plugins/vertexai/src/evaluator_factory.ts | 7 +- js/plugins/vertexai/src/gemini.ts | 74 +- js/plugins/vertexai/src/imagen.ts | 8 +- js/plugins/vertexai/src/index.ts | 104 +- js/plugins/vertexai/src/model_garden.ts | 7 +- .../vertexai/src/openai_compatibility.ts | 10 +- js/plugins/vertexai/src/reranker.ts | 16 +- .../vertexai/src/vector-search/indexers.ts | 10 +- .../vertexai/src/vector-search/retrievers.ts | 8 +- .../query_public_endpoint_test.ts | 3 +- .../vector-search/upsert_datapoints_test.ts | 3 +- js/pnpm-lock.yaml | 424 +++------ js/testapps/anthropic-models/src/index.ts | 2 +- js/testapps/basic-gemini/src/index.ts | 25 +- .../src/deliciousness/deliciousness.ts | 5 +- .../deliciousness/deliciousness_evaluator.ts | 11 +- .../byo-evaluator/src/funniness/funniness.ts | 5 +- .../src/funniness/funniness_evaluator.ts | 11 +- js/testapps/byo-evaluator/src/index.ts | 72 +- .../byo-evaluator/src/pii/pii_detection.ts | 5 +- .../byo-evaluator/src/pii/pii_evaluator.ts | 11 +- .../src/regex/regex_evaluator.ts | 5 +- js/testapps/cat-eval/package.json | 6 +- js/testapps/cat-eval/src/genkit.ts | 65 ++ js/testapps/cat-eval/src/index.ts | 51 - 
js/testapps/cat-eval/src/pdf_rag.ts | 14 +- js/testapps/cat-eval/src/pdf_rag_firebase.ts | 35 +- js/testapps/cat-eval/src/setup.ts | 2 +- js/testapps/dev-ui-gallery/package.json | 3 +- js/testapps/dev-ui-gallery/src/genkit.ts | 19 +- .../src/main/flows-firebase-functions.ts | 6 +- .../dev-ui-gallery/src/main/prompts.ts | 30 +- js/testapps/dev-ui-gallery/src/main/tools.ts | 7 +- js/testapps/docs-menu-basic/src/index.ts | 6 +- js/testapps/docs-menu-rag/src/index.ts | 4 +- js/testapps/docs-menu-rag/src/menuQA.ts | 8 +- js/testapps/eval/src/index.ts | 6 +- js/testapps/evaluator-gut-check/src/index.ts | 10 +- js/testapps/express/src/index.ts | 2 +- .../functions/package.json | 4 +- .../functions/src/index.ts | 2 +- js/testapps/flow-simple-ai/package.json | 2 +- js/testapps/flow-simple-ai/src/index.ts | 80 +- .../google-ai-code-execution/src/index.ts | 2 +- js/testapps/menu/src/01/prompts.ts | 8 +- js/testapps/menu/src/02/flows.ts | 2 +- js/testapps/menu/src/02/prompts.ts | 4 +- js/testapps/menu/src/03/flows.ts | 6 +- js/testapps/menu/src/04/flows.ts | 2 +- js/testapps/menu/src/04/prompts.ts | 4 +- js/testapps/menu/src/05/flows.ts | 4 +- js/testapps/menu/src/05/prompts.ts | 6 +- js/testapps/menu/src/index.ts | 4 +- js/testapps/model-tester/src/index.ts | 10 +- js/testapps/prompt-file/src/index.ts | 17 +- js/testapps/rag/src/genkit.ts | 21 +- js/testapps/rag/src/pdf_rag.ts | 10 +- js/testapps/rag/src/prompt.ts | 10 +- js/testapps/rag/src/simple_rag.ts | 8 +- js/testapps/vertexai-reranker/README.md | 2 +- js/testapps/vertexai-reranker/src/index.ts | 2 +- .../package.json | 2 +- package.json | 2 +- samples/chatbot/server/src/index.ts | 6 +- samples/js-angular/server/src/agent.ts | 4 +- .../js-angular/server/src/jsonStreaming.ts | 2 +- samples/js-coffee-shop/src/index.ts | 5 +- samples/js-menu/src/02/flows.ts | 2 +- samples/js-menu/src/03/flows.ts | 2 +- samples/js-menu/src/04/flows.ts | 2 +- samples/js-menu/src/05/flows.ts | 4 +- samples/prompts/src/index.ts | 2 +- scripts/release_main.sh | 78 ++ scripts/release_next.sh | 83 ++ tests/test_js_app/src/index.ts | 4 +- 218 files changed, 5489 insertions(+), 3441 deletions(-) create mode 100644 js/genkit/src/chat.ts create mode 100644 js/genkit/src/plugin.ts create mode 100644 js/genkit/src/session.ts create mode 100644 js/genkit/tests/chat_test.ts create mode 100644 js/genkit/tests/embed_test.ts create mode 100644 js/genkit/tests/generate_test.ts delete mode 100644 js/genkit/tests/models_test.ts create mode 100644 js/genkit/tests/session_test.ts create mode 100644 js/plugins/google-cloud/jest.config.ts create mode 100644 js/testapps/cat-eval/src/genkit.ts create mode 100755 scripts/release_main.sh create mode 100755 scripts/release_next.sh diff --git a/.github/workflows/builder.yml b/.github/workflows/builder.yml index 5eb75618c..af790a05f 100644 --- a/.github/workflows/builder.yml +++ b/.github/workflows/builder.yml @@ -27,10 +27,10 @@ jobs: steps: - uses: actions/checkout@v3 - uses: pnpm/action-setup@v3 - - name: Set up node v20 + - name: Set up node v21 uses: actions/setup-node@v4 with: - node-version: 20.x + node-version: 21.x cache: 'pnpm' - name: Install dependencies run: pnpm install diff --git a/.github/workflows/e2e-tests.yml b/.github/workflows/e2e-tests.yml index f2a25852c..92f369d92 100644 --- a/.github/workflows/e2e-tests.yml +++ b/.github/workflows/e2e-tests.yml @@ -29,7 +29,7 @@ jobs: - name: Set up node v20 uses: actions/setup-node@v4 with: - node-version: 20.x + node-version: 21.x cache: 'pnpm' - name: Install dependencies run: pnpm 
install diff --git a/.github/workflows/formatter.yml b/.github/workflows/formatter.yml index 852b085dc..23b012902 100644 --- a/.github/workflows/formatter.yml +++ b/.github/workflows/formatter.yml @@ -27,10 +27,10 @@ jobs: steps: - uses: actions/checkout@v3 - uses: pnpm/action-setup@v3 - - name: Set up node v20 + - name: Set up node v21 uses: actions/setup-node@v4 with: - node-version: 20.x + node-version: 21.x cache: 'pnpm' - name: Install dependencies run: pnpm install diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 9e0589f79..41241a936 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -27,10 +27,10 @@ jobs: steps: - uses: actions/checkout@v3 - uses: pnpm/action-setup@v3 - - name: Set up node v20 + - name: Set up node v21 uses: actions/setup-node@v4 with: - node-version: 20.x + node-version: 21.x cache: 'pnpm' - name: Install dependencies and build run: pnpm build:genkit-tools diff --git a/README.md b/README.md index 1af64b3d3..8f076de7a 100644 --- a/README.md +++ b/README.md @@ -88,9 +88,14 @@ Find excellent examples of community-built plugins for OpenAI, Anthropic, Cohere ## Try Genkit on IDX -Project IDX logo - -Want to try Genkit without a local setup? [Explore it on Project IDX](https://idx.google.com/new/genkit), Google's AI-assisted workspace for full-stack app development in the cloud. +Want to skip the local setup? Click below to try out Genkit using [Project IDX](https://idx.dev), Google's AI-assisted workspace for full-stack app development in the cloud. + + + Try in IDX + ## Sample apps diff --git a/docs/dotprompt.md b/docs/dotprompt.md index c0a286092..706ec906e 100644 --- a/docs/dotprompt.md +++ b/docs/dotprompt.md @@ -53,7 +53,7 @@ const result = await greetingPrompt.generate({ }, }); -console.log(result.text()); +console.log(result.text); ``` Dotprompt's syntax is based on the [Handlebars](https://handlebarsjs.com/guide/) @@ -183,7 +183,7 @@ const myPrompt = promptRef("myPrompt"); const result = await myPrompt.generate({...}); // now strongly typed as MySchema -result.output(); +result.output; ``` ## Overriding Prompt Metadata @@ -237,7 +237,7 @@ const menu = await createMenuPrompt.generate({ }, }); -console.log(menu.output()); +console.log(menu.output); ``` Output conformance is achieved by inserting additional instructions into the @@ -340,7 +340,7 @@ const result = await describeImagePrompt.generate({ }, }); -console.log(result.text()); +console.log(result.text); ``` ## Partials diff --git a/docs/evaluation.md b/docs/evaluation.md index aa750f04d..613a9f77f 100644 --- a/docs/evaluation.md +++ b/docs/evaluation.md @@ -204,7 +204,7 @@ export const synthesizeQuestions = defineFlow( text: `Generate one question about the text below: ${chunks[i]}`, }, }); - questions.push(qResponse.text()); + questions.push(qResponse.text); } return questions; } diff --git a/docs/get-started.md b/docs/get-started.md index 0b7c02b89..694f89c02 100644 --- a/docs/get-started.md +++ b/docs/get-started.md @@ -131,7 +131,7 @@ so that it can be used outside of a Node project. // Handle the response from the model API. In this sample, we just convert // it to a string, but more complicated flows might coerce the response into // structured output or chain the response into another LLM call, etc. - return llmResponse.text(); + return llmResponse.text; } ); diff --git a/docs/models.md b/docs/models.md index ccffe308a..c1731469d 100644 --- a/docs/models.md +++ b/docs/models.md @@ -109,7 +109,7 @@ configureGenkit(/* ... 
*/); prompt: 'Invent a menu item for a pirate themed restaurant.', }); - console.log(await llmResponse.text()); + console.log(await llmResponse.text); })(); ``` @@ -339,7 +339,7 @@ object's `output()` method: ```ts type MenuItem = z.infer; -const output: MenuItem | null = llmResponse.output(); +const output: MenuItem | null = llmResponse.output; ``` #### Handling errors @@ -425,7 +425,7 @@ Handle each of these chunks as they become available: ```ts for await (const responseChunkData of llmResponseStream.stream()) { const responseChunk = responseChunkData as GenerateResponseChunk; - console.log(responseChunk.text()); + console.log(responseChunk.text); } ``` @@ -454,7 +454,7 @@ const llmResponseStream = await generateStream({ for await (const responseChunkData of llmResponseStream.stream()) { const responseChunk = responseChunkData as GenerateResponseChunk; // output() returns an object representing the entire output so far - const output: Menu | null = responseChunk.output(); + const output: Menu | null = responseChunk.output; console.log(output); } ``` @@ -605,7 +605,7 @@ your users will not be interacting directly with the model in this way, the conversational style of prompting is a powerful way to influence the output generated by an AI model. -To generate message history from a model response, call the `toHistory()` +To generate message history from a model response, call the `.messages` method: ```ts @@ -613,7 +613,7 @@ let response = await generate({ model: gemini15Flash, prompt: "How do you say 'dog' in French?", }); -let history = response.toHistory(); +let history = response.messages; ``` You can serialize this history and persist it in a database or session storage. @@ -625,7 +625,7 @@ response = await generate({ prompt: 'How about in Spanish?', history, }); -history = response.toHistory(); +history = response.messages; ``` If the model you're using supports the `system` role, you can use the initial diff --git a/docs/plugin-authoring-evaluator.md b/docs/plugin-authoring-evaluator.md index bba9d3211..b5604229b 100644 --- a/docs/plugin-authoring-evaluator.md +++ b/docs/plugin-authoring-evaluator.md @@ -110,9 +110,9 @@ export async function deliciousnessScore< }); // Parse the output - const parsedResponse = response.output(); + const parsedResponse = response.output; if (!parsedResponse) { - throw new Error(`Unable to parse evaluator response: ${response.text()}`); + throw new Error(`Unable to parse evaluator response: ${response.text}`); } // Return a scored response diff --git a/docs/plugins/google-cloud.md b/docs/plugins/google-cloud.md index 5314558bc..6135eb8f2 100644 --- a/docs/plugins/google-cloud.md +++ b/docs/plugins/google-cloud.md @@ -221,32 +221,84 @@ Common dimensions include: - `topK` - the inference topK [value](https://ai.google.dev/docs/concepts#model-parameters). - `topP` - the inference topP [value](https://ai.google.dev/docs/concepts#model-parameters). -### Flow-level metrics +### Feature-level metrics + +Features are the top-level entry-point to your Genkit code. In most cases, this +will be a flow, but if you do not use flows, this will be the top-most span in a trace. 
+ +| Name | Type | Description | +| ----------------------- | --------- | ----------------------- | +| genkit/feature/requests | Counter | Number of requests | +| genkit/feature/latency | Histogram | Execution latency in ms | + +Each feature-level metric contains the following dimensions: + +| Name | Description | +| ------------- | -------------------------------------------------------------------------------- | +| name | The name of the feature. In most cases, this is the top-level Genkit flow | +| status | 'success' or 'failure' depending on whether or not the feature request succeeded | +| error | Only set when `status=failure`. Contains the error type that caused the failure | +| source | The Genkit source language. Eg. 'ts' | +| sourceVersion | The Genkit framework version | -| Name | Dimensions | -| -------------------- | ------------------------------------ | -| genkit/flow/requests | flow_name, error_code, error_message | -| genkit/flow/latency | flow_name | ### Action-level metrics -| Name | Dimensions | -| ---------------------- | ------------------------------------ | -| genkit/action/requests | flow_name, error_code, error_message | -| genkit/action/latency | flow_name | +Actions represent a generic step of execution within Genkit. Each of these steps +will have the following metrics tracked: + +| Name | Type | Description | +| ----------------------- | --------- | --------------------------------------------- | +| genkit/action/requests | Counter | Number of times this action has been executed | +| genkit/action/latency | Histogram | Execution latency in ms | + +Each action-level metric contains the following dimensions: + +| Name | Description | +| ------------- | ---------------------------------------------------------------------------------------------------- | +| name | The name of the action | +| featureName | The name of the parent feature being executed | +| path | The path of execution from the feature root to this action. eg. '/myFeature/parentAction/thisAction' | +| status | 'success' or 'failure' depending on whether or not the action succeeded | +| error | Only set when `status=failure`. Contains the error type that caused the failure | +| source | The Genkit source language. Eg. 'ts' | +| sourceVersion | The Genkit framework version | ### Generate-level metrics -| Name | Dimensions | -| ------------------------------------ | -------------------------------------------------------------------- | -| genkit/ai/generate | flow_path, model, temperature, topK, topP, error_code, error_message | -| genkit/ai/generate/input_tokens | flow_path, model, temperature, topK, topP | -| genkit/ai/generate/output_tokens | flow_path, model, temperature, topK, topP | -| genkit/ai/generate/input_characters | flow_path, model, temperature, topK, topP | -| genkit/ai/generate/output_characters | flow_path, model, temperature, topK, topP | -| genkit/ai/generate/input_images | flow_path, model, temperature, topK, topP | -| genkit/ai/generate/output_images | flow_path, model, temperature, topK, topP | -| genkit/ai/generate/latency | flow_path, model, temperature, topK, topP, error_code, error_message | +These are special action metrics relating to actions that interact with a model. +In addition to requests and latency, input and output are also tracked, with model +specific dimensions that make debugging and configuration tuning easier. 
+ +| Name | Type | Description | +| ------------------------------------ | --------- | ------------------------------------------ | +| genkit/ai/generate/requests | Counter | Number of times this model has been called | +| genkit/ai/generate/latency | Histogram | Execution latency in ms | +| genkit/ai/generate/input/tokens | Counter | Input tokens | +| genkit/ai/generate/output/tokens | Counter | Output tokens | +| genkit/ai/generate/input/characters | Counter | Input characters | +| genkit/ai/generate/output/characters | Counter | Output characters | +| genkit/ai/generate/input/images | Counter | Input images | +| genkit/ai/generate/output/images | Counter | Output images | +| genkit/ai/generate/input/audio | Counter | Input audio files | +| genkit/ai/generate/output/audio | Counter | Output audio files | + +Each generate-level metric contains the following dimensions: + +| Name | Description | +| --------------- | ---------------------------------------------------------------------------------------------------- | +| modelName | The name of the model | +| featureName | The name of the parent feature being executed | +| path | The path of execution from the feature root to this action. eg. '/myFeature/parentAction/thisAction' | +| temperature | The temperature parameter passed to the model | +| maxOutputTokens | The maxOutputTokens parameter passed to the model | +| topK | The topK parameter passed to the model | +| topP | The topP parameter passed to the model | +| latencyMs | The response time taken by the model | +| status | 'success' or 'failure' depending on whether or not the feature request succeeded | +| error | Only set when `status=failure`. Contains the error type that caused the failure | +| source | The Genkit source language. Eg. 'ts' | +| sourceVersion | The Genkit framework version | Visualizing metrics can be done through the Metrics Explorer. Using the side menu, select 'Logging' and click 'Metrics explorer' diff --git a/docs/prompts.md b/docs/prompts.md index 1ecf5715a..f3e915570 100644 --- a/docs/prompts.md +++ b/docs/prompts.md @@ -100,7 +100,7 @@ const response = await (threeGreetingsPrompt.generate( { input: { name: 'Fred' } } )); -response.output()?.likeAPirate +response.output?.likeAPirate // "Ahoy there, Fred! May the winds be ever in your favor!" 
``` diff --git a/docs/rag.md b/docs/rag.md index 0adff7691..23b459ed0 100644 --- a/docs/rag.md +++ b/docs/rag.md @@ -296,7 +296,7 @@ export const menuQAFlow = defineFlow( context: docs, }); - const output = llmResponse.text(); + const output = llmResponse.text; return output; } ); @@ -333,7 +333,7 @@ defineSimpleRetriever({ // and several keys to use as metadata metadata: ['from', 'to', 'subject'], } async (query, config) => { - const result = await searchEmails(query.text(), {limit: config.limit}); + const result = await searchEmails(query.text, {limit: config.limit}); return result.data.emails; }); ``` @@ -433,7 +433,7 @@ export const rerankFlow = defineFlow( }); return rerankedDocuments.map((doc) => ({ - text: doc.text(), + text: doc.text, score: doc.metadata.score, })); } diff --git a/docs/tool-calling.md b/docs/tool-calling.md index 1b475efb9..0a846718c 100644 --- a/docs/tool-calling.md +++ b/docs/tool-calling.md @@ -169,7 +169,7 @@ while (true) { throw Error('Tool not found'); } })); - generateOptions.history = llmResponse.toHistory(); + generateOptions.history = llmResponse.messages; generateOptions.prompt = toolResponses; } ``` diff --git a/genkit-tools/cli/config/firebase.index.ts.template b/genkit-tools/cli/config/firebase.index.ts.template index c7030f9f1..f8c8217bf 100644 --- a/genkit-tools/cli/config/firebase.index.ts.template +++ b/genkit-tools/cli/config/firebase.index.ts.template @@ -59,6 +59,6 @@ export const menuSuggestionFlow = onFlow( // convert it to a string, but more complicated flows might coerce the // response into structured output or chain the response into another // LLM call, etc. - return llmResponse.text(); + return llmResponse.text; } ); diff --git a/genkit-tools/cli/config/nextjs.genkit.ts.template b/genkit-tools/cli/config/nextjs.genkit.ts.template index ed619a3be..c1c05d8f8 100644 --- a/genkit-tools/cli/config/nextjs.genkit.ts.template +++ b/genkit-tools/cli/config/nextjs.genkit.ts.template @@ -39,7 +39,7 @@ const menuSuggestionFlow = ai.defineFlow( // convert it to a string, but more complicated flows might coerce the // response into structured output or chain the response into another // LLM call, etc. - return llmResponse.text(); + return llmResponse.text; } ); diff --git a/genkit-tools/cli/config/nodejs.index.ts.template b/genkit-tools/cli/config/nodejs.index.ts.template index 5f480bc47..28388b239 100644 --- a/genkit-tools/cli/config/nodejs.index.ts.template +++ b/genkit-tools/cli/config/nodejs.index.ts.template @@ -39,6 +39,6 @@ export const menuSuggestionFlow = ai.defineFlow( // Handle the response from the model API. In this sample, we just convert // it to a string, but more complicated flows might coerce the response into // structured output or chain the response into another LLM call, etc. 
- return llmResponse.text(); + return llmResponse.text; } ); diff --git a/genkit-tools/cli/package.json b/genkit-tools/cli/package.json index 335225733..de2ecd1cb 100644 --- a/genkit-tools/cli/package.json +++ b/genkit-tools/cli/package.json @@ -1,6 +1,6 @@ { "name": "genkit-cli", - "version": "0.6.0-dev.2", + "version": "0.9.0-dev.1", "description": "CLI for interacting with the Google Genkit AI framework", "license": "Apache-2.0", "keywords": [ @@ -28,7 +28,7 @@ "dependencies": { "@genkit-ai/tools-common": "workspace:*", "@genkit-ai/telemetry-server": "workspace:*", - "axios": "^1.6.7", + "axios": "^1.7.7", "colorette": "^2.0.20", "commander": "^11.1.0", "extract-zip": "^2.0.1", diff --git a/genkit-tools/cli/src/commands/flow-run.ts b/genkit-tools/cli/src/commands/flow-run.ts index bf5ca8555..ca9a4aac0 100644 --- a/genkit-tools/cli/src/commands/flow-run.ts +++ b/genkit-tools/cli/src/commands/flow-run.ts @@ -18,7 +18,7 @@ import { FlowInvokeEnvelopeMessage, FlowState } from '@genkit-ai/tools-common'; import { logger } from '@genkit-ai/tools-common/utils'; import { Command } from 'commander'; import { writeFile } from 'fs/promises'; -import { runWithManager, waitForFlowToComplete } from '../utils/manager-utils'; +import { runWithManager } from '../utils/manager-utils'; interface FlowRunOptions { wait?: boolean; @@ -63,13 +63,7 @@ export const flowRun = new Command('flow:run') ) ).result as FlowState; - if (!state.operation.done && options.wait) { - logger.info('Started flow run, waiting for it to complete...'); - state = await waitForFlowToComplete(manager, flowName, state.flowId); - } - logger.info( - 'Flow operation:\n' + JSON.stringify(state.operation, undefined, ' ') - ); + logger.info('Flow response:\n' + JSON.stringify(state, undefined, ' ')); if (options.output && state.operation.result?.response) { await writeFile( diff --git a/genkit-tools/cli/src/commands/ui-start.ts b/genkit-tools/cli/src/commands/ui-start.ts index 8cbb613cb..79c74be06 100644 --- a/genkit-tools/cli/src/commands/ui-start.ts +++ b/genkit-tools/cli/src/commands/ui-start.ts @@ -88,7 +88,7 @@ export const uiStart = new Command('ui:start') logger.debug('No UI running. Starting a new one...'); } logger.info('Starting...'); - await startAndWaitUntilHealthy(port).catch((error) => { + await startAndWaitUntilHealthy(port, serversDir).catch((error) => { logger.error(`Failed to start Genkit Developer UI: ${error}`); return; }); @@ -122,15 +122,20 @@ export const uiStart = new Command('ui:start') /** * Starts the UI server in a child process and waits until it is healthy. Once it's healthy, the child process is detached. */ -async function startAndWaitUntilHealthy(port: number): Promise { +async function startAndWaitUntilHealthy( + port: number, + serversDir: string +): Promise { return new Promise((resolve, reject) => { const serverPath = path.join(__dirname, '../utils/server-harness.js'); - const child = spawn('node', [serverPath, port.toString()], { - stdio: ['ignore', 'pipe', 'pipe'], - }); + const child = spawn( + 'node', + [serverPath, port.toString(), serversDir + '/devui.log'], + { + stdio: ['ignore', 'ignore', 'ignore'], + } + ); // Only print out logs from the child process to debug output. 
- child.stdout.on('data', (data) => logger.debug(data)); - child.stderr.on('data', (data) => logger.debug(data)); child.on('error', (error) => reject(error)); child.on('exit', (code) => reject(new Error(`UI process exited (code ${code}) unexpectedly`)) diff --git a/genkit-tools/cli/src/utils/server-harness.ts b/genkit-tools/cli/src/utils/server-harness.ts index 247866ab7..24de70ad8 100644 --- a/genkit-tools/cli/src/utils/server-harness.ts +++ b/genkit-tools/cli/src/utils/server-harness.ts @@ -15,14 +15,39 @@ */ import { startServer } from '@genkit-ai/tools-common/server'; +import fs from 'fs'; import { startManager } from './manager-utils'; const args = process.argv.slice(2); const port = parseInt(args[0]) || 4100; +redirectStdoutToFile(args[1]); async function start() { const manager = await startManager(true); await startServer(manager, port); } +function redirectStdoutToFile(logFile: string) { + var myLogFileStream = fs.createWriteStream(logFile); + + var originalStdout = process.stdout.write; + function writeStdout() { + originalStdout.apply(process.stdout, arguments as any); + myLogFileStream.write.apply(myLogFileStream, arguments as any); + } + + process.stdout.write = writeStdout as any; + process.stderr.write = process.stdout.write; +} + +process.on('error', (error): void => { + console.log(`Error in tools process: ${error}`); +}); +process.on('uncaughtException', (err, somethingelse) => { + console.log(`Uncaught error in tools process: ${err} ${somethingelse}`); +}); +process.on('unhandledRejection', function (reason, p) { + console.log(`Unhandled rejection in tools process: ${reason}`); +}); + start(); diff --git a/genkit-tools/common/package.json b/genkit-tools/common/package.json index f77bbbca7..97fedae33 100644 --- a/genkit-tools/common/package.json +++ b/genkit-tools/common/package.json @@ -1,6 +1,6 @@ { "name": "@genkit-ai/tools-common", - "version": "0.6.0-dev.2", + "version": "0.9.0-dev.1", "scripts": { "compile": "tsc -b ./tsconfig.cjs.json ./tsconfig.esm.json ./tsconfig.types.json", "build:clean": "rimraf ./lib", @@ -12,7 +12,7 @@ "@asteasolutions/zod-to-openapi": "^7.0.0", "@trpc/server": "10.45.0", "adm-zip": "^0.5.12", - "axios": "^1.6.7", + "axios": "^1.7.7", "body-parser": "^1.20.2", "chokidar": "^3.5.3", "colorette": "^2.0.20", diff --git a/genkit-tools/common/src/eval/evaluate.ts b/genkit-tools/common/src/eval/evaluate.ts index 4e3f6088a..a0727a95b 100644 --- a/genkit-tools/common/src/eval/evaluate.ts +++ b/genkit-tools/common/src/eval/evaluate.ts @@ -186,7 +186,7 @@ export async function getMatchingEvaluatorActions( } const allEvaluatorActions = await getAllEvaluatorActions(manager); const filteredEvaluatorActions = allEvaluatorActions.filter((action) => - evaluators.includes(action.name) + evaluators.includes(action.key) ); if (filteredEvaluatorActions.length === 0) { if (allEvaluatorActions.length == 0) { @@ -216,11 +216,12 @@ async function bulkRunAction(params: { testCaseId: c.testCaseId ?? 
generateTestCaseId(), })); + let states: InferenceRunState[] = []; let evalInputs: EvalInput[] = []; for (const testCase of testCases) { - logger.info(`Running '${actionRef}' ...`); + logger.info(`Running inference '${actionRef}' ...`); if (isModelAction) { - evalInputs.push( + states.push( await runModelAction({ manager, actionRef, @@ -229,7 +230,7 @@ async function bulkRunAction(params: { }) ); } else { - evalInputs.push( + states.push( await runFlowAction({ manager, actionRef, @@ -239,6 +240,11 @@ async function bulkRunAction(params: { ); } } + + logger.info(`Gathering evalInputs...`); + for (const state of states) { + evalInputs.push(await gatherEvalInput({ manager, actionRef, state })); + } return evalInputs; } @@ -247,7 +253,7 @@ async function runFlowAction(params: { actionRef: string; testCase: TestCase; auth?: any; -}): Promise { +}): Promise { const { manager, actionRef, testCase, auth } = { ...params }; let state: InferenceRunState; try { @@ -274,7 +280,7 @@ async function runFlowAction(params: { evalError: `Error when running inference. Details: ${e?.message ?? e}`, }; } - return gatherEvalInput({ manager, actionRef, state }); + return state; } async function runModelAction(params: { @@ -282,7 +288,7 @@ async function runModelAction(params: { actionRef: string; testCase: TestCase; modelConfig?: any; -}): Promise { +}): Promise { const { manager, actionRef, modelConfig, testCase } = { ...params }; let state: InferenceRunState; try { @@ -304,7 +310,7 @@ async function runModelAction(params: { evalError: `Error when running inference. Details: ${e?.message ?? e}`, }; } - return gatherEvalInput({ manager, actionRef, state }); + return state; } async function gatherEvalInput(params: { diff --git a/genkit-tools/common/src/eval/localFileDatasetStore.ts b/genkit-tools/common/src/eval/localFileDatasetStore.ts index a0dbe2559..79f160873 100644 --- a/genkit-tools/common/src/eval/localFileDatasetStore.ts +++ b/genkit-tools/common/src/eval/localFileDatasetStore.ts @@ -14,7 +14,6 @@ * limitations under the License. 
*/ -import crypto from 'crypto'; import fs from 'fs'; import { readFile, rm, writeFile } from 'fs/promises'; import path from 'path'; @@ -66,13 +65,13 @@ export class LocalFileDatasetStore implements DatasetStore { } async createDataset(req: CreateDatasetRequest): Promise { - return this.createDatasetInternal(req.data, req.datasetId); + return this.createDatasetInternal(req); } private async createDatasetInternal( - data: EvalInferenceInput, - datasetId?: string + req: CreateDatasetRequest ): Promise { + const { data, datasetId, schema, targetAction } = req; const id = await this.generateDatasetId(datasetId); const filePath = path.resolve(this.storeRoot, this.generateFileName(id)); @@ -87,8 +86,10 @@ export class LocalFileDatasetStore implements DatasetStore { await writeFile(filePath, JSON.stringify(dataset)); const now = new Date().toString(); - const metadata = { + const metadata: DatasetMetadata = { datasetId: id, + schema, + targetAction, size: dataset.length, version: 1, createTime: now, @@ -107,7 +108,7 @@ export class LocalFileDatasetStore implements DatasetStore { } async updateDataset(req: UpdateDatasetRequest): Promise { - const datasetId = req.datasetId; + const { datasetId, data, schema, targetAction } = req; const filePath = path.resolve( this.storeRoot, this.generateFileName(datasetId) @@ -121,7 +122,7 @@ export class LocalFileDatasetStore implements DatasetStore { if (!prevMetadata) { throw new Error(`Update dataset failed: dataset metadata not found`); } - const patch = this.getDatasetFromInferenceInput(req.data ?? []); + const patch = this.getDatasetFromInferenceInput(data ?? []); let newSize = prevMetadata.size; if (patch.length > 0) { logger.info(`Updating Dataset at ` + filePath); @@ -132,7 +133,9 @@ export class LocalFileDatasetStore implements DatasetStore { const newMetadata = { datasetId: datasetId, size: newSize, - version: prevMetadata.version + 1, + schema: schema ? schema : prevMetadata.schema, + targetAction: targetAction ? targetAction : prevMetadata.targetAction, + version: data ? prevMetadata.version + 1 : prevMetadata.version, createTime: prevMetadata.createTime, updateTime: now, }; @@ -190,11 +193,7 @@ export class LocalFileDatasetStore implements DatasetStore { } private static generateRootPath(): string { - const rootHash = crypto - .createHash('md5') - .update(process.cwd() || 'unknown') - .digest('hex'); - return path.resolve(process.cwd(), `.genkit/${rootHash}/datasets`); + return path.resolve(process.cwd(), `.genkit/datasets`); } /** Visible for testing */ @@ -230,7 +229,7 @@ export class LocalFileDatasetStore implements DatasetStore { return path.resolve(this.storeRoot, 'index.json'); } - private async getMetadataMap(): Promise { + private async getMetadataMap(): Promise> { if (!fs.existsSync(this.indexFile)) { return Promise.resolve({} as any); } diff --git a/genkit-tools/common/src/eval/localFileEvalStore.ts b/genkit-tools/common/src/eval/localFileEvalStore.ts index e9a1c44ea..bf4a543dc 100644 --- a/genkit-tools/common/src/eval/localFileEvalStore.ts +++ b/genkit-tools/common/src/eval/localFileEvalStore.ts @@ -14,10 +14,8 @@ * limitations under the License. 
*/ -import crypto from 'crypto'; import fs from 'fs'; import { appendFile, readFile, writeFile } from 'fs/promises'; -import os from 'os'; import path from 'path'; import { logger } from '../utils/logger'; @@ -134,10 +132,6 @@ export class LocalFileEvalStore implements EvalStore { } private generateRootPath(): string { - const rootHash = crypto - .createHash('md5') - .update(process.cwd() || 'unknown') - .digest('hex'); - return path.resolve(os.tmpdir(), `.genkit/${rootHash}/evals`); + return path.resolve(process.cwd(), `.genkit/evals`); } } diff --git a/genkit-tools/common/src/manager/manager.ts b/genkit-tools/common/src/manager/manager.ts index 20c3c7117..3a33457ad 100644 --- a/genkit-tools/common/src/manager/manager.ts +++ b/genkit-tools/common/src/manager/manager.ts @@ -284,6 +284,10 @@ export class RuntimeManager { if (this.manageHealth) { watcher.on('unlink', (filePath) => this.handleRemovedRuntime(filePath)); } + // eagerly check existing runtimes on first load. + for (const runtime of await fs.readdir(runtimesDir)) { + await this.handleNewRuntime(path.resolve(runtimesDir, runtime)); + } } catch (error) { logger.error('Failed to set up runtimes watcher:', error); } diff --git a/genkit-tools/common/src/server/server.ts b/genkit-tools/common/src/server/server.ts index 8feb52b87..9863c4e1b 100644 --- a/genkit-tools/common/src/server/server.ts +++ b/genkit-tools/common/src/server/server.ts @@ -84,6 +84,7 @@ export async function startServer(manager: RuntimeManager, port: number) { }); app.post('/api/__quitquitquit', (_, res) => { + logger.info('Shutting down tools API'); res.status(200).send('Server is shutting down'); server.close(() => { process.exit(0); @@ -104,6 +105,10 @@ export async function startServer(manager: RuntimeManager, port: number) { }) ); + app.all('*', (_, res) => { + res.status(200).sendFile('/', { root: UI_ASSETS_SERVE_PATH }); + }); + const errorHandler: ErrorRequestHandler = ( error, request, @@ -119,11 +124,7 @@ export async function startServer(manager: RuntimeManager, port: number) { }; app.use(errorHandler); - app.all('*', (_, res) => { - res.status(200).sendFile('/', { root: UI_ASSETS_SERVE_PATH }); - }); - - app.listen(port, () => { + server = app.listen(port, () => { const uiUrl = 'http://localhost:' + port; logger.info(`${clc.green(clc.bold('Genkit Developer UI:'))} ${uiUrl}`); }); diff --git a/genkit-tools/common/src/types/apis.ts b/genkit-tools/common/src/types/apis.ts index 9e280e972..f1889b050 100644 --- a/genkit-tools/common/src/types/apis.ts +++ b/genkit-tools/common/src/types/apis.ts @@ -15,7 +15,11 @@ */ import { z } from 'zod'; -import { EvalInferenceInputSchema, EvalRunKeySchema } from './eval'; +import { + DatasetSchemaSchema, + EvalInferenceInputSchema, + EvalRunKeySchema, +} from './eval'; import { FlowStateSchema } from './flow'; import { GenerationCommonConfigSchema, @@ -135,13 +139,17 @@ export type GetEvalRunRequest = z.infer; export const CreateDatasetRequestSchema = z.object({ data: EvalInferenceInputSchema, datasetId: z.string().optional(), + schema: DatasetSchemaSchema.optional(), + targetAction: z.string().optional(), }); export type CreateDatasetRequest = z.infer; export const UpdateDatasetRequestSchema = z.object({ - data: EvalInferenceInputSchema.optional(), datasetId: z.string(), + data: EvalInferenceInputSchema.optional(), + schema: DatasetSchemaSchema.optional(), + targetAction: z.string().optional(), }); export type UpdateDatasetRequest = z.infer; diff --git a/genkit-tools/common/src/types/eval.ts 
b/genkit-tools/common/src/types/eval.ts index 0d0a0e89e..bcf3ff1fb 100644 --- a/genkit-tools/common/src/types/eval.ts +++ b/genkit-tools/common/src/types/eval.ts @@ -160,6 +160,17 @@ export interface EvalStore { list(query?: ListEvalKeysRequest): Promise; } +export const DatasetSchemaSchema = z.object({ + inputSchema: z + .record(z.any()) + .describe('Valid JSON Schema for the `input` field of dataset entry.') + .optional(), + referenceSchema: z + .record(z.any()) + .describe('Valid JSON Schema for the `reference` field of dataset entry.') + .optional(), +}); + /** * Metadata for Dataset objects containing version, create and update time, etc. */ @@ -167,6 +178,8 @@ export const DatasetMetadataSchema = z.object({ /** unique. user-provided or auto-generated */ datasetId: z.string(), size: z.number(), + schema: DatasetSchemaSchema.optional(), + targetAction: z.string().optional(), /** 1 for v1, 2 for v2, etc */ version: z.number(), createTime: z.string(), diff --git a/genkit-tools/common/src/utils/analytics.ts b/genkit-tools/common/src/utils/analytics.ts index 67190b88a..5cc1d34ed 100644 --- a/genkit-tools/common/src/utils/analytics.ts +++ b/genkit-tools/common/src/utils/analytics.ts @@ -292,7 +292,7 @@ async function recordInternal( if (!response.ok) { logger.warn(`Analytics validation HTTP error: ${response.status}`); } - const respBody = await response.text(); + const respBody = await response.text; logger.info(`Analytics validation result: ${respBody}`); } // response.ok / response.status intentionally ignored, see comment below. diff --git a/genkit-tools/common/src/utils/utils.ts b/genkit-tools/common/src/utils/utils.ts index eab8ed631..b36a3b58e 100644 --- a/genkit-tools/common/src/utils/utils.ts +++ b/genkit-tools/common/src/utils/utils.ts @@ -14,7 +14,6 @@ * limitations under the License. 
*/ -import axios from 'axios'; import * as fs from 'fs/promises'; import * as path from 'path'; import { Runtime } from '../manager/types'; @@ -78,10 +77,13 @@ export async function detectRuntime(directory: string): Promise { */ export async function checkServerHealth(url: string): Promise { try { - const response = await axios.get(`${url}/api/__health`); + const response = await fetch(`${url}/api/__health`); return response.status === 200; } catch (error) { - if (axios.isAxiosError(error) && error.code === 'ECONNREFUSED') { + if ( + error instanceof Error && + (error.cause as any).code === 'ECONNREFUSED' + ) { return false; } } @@ -98,7 +100,7 @@ export async function waitUntilHealthy( const startTime = Date.now(); while (Date.now() - startTime < maxTimeout) { try { - const response = await axios.get(`${url}/api/__health`); + const response = await fetch(`${url}/api/__health`); if (response.status === 200) { return true; } @@ -120,9 +122,12 @@ export async function waitUntilUnresponsive( const startTime = Date.now(); while (Date.now() - startTime < maxTimeout) { try { - await axios.get(`${url}/api/__health`); + const health = await fetch(`${url}/api/__health`); } catch (error) { - if (axios.isAxiosError(error) && error.code === 'ECONNREFUSED') { + if ( + error instanceof Error && + (error.cause as any).code === 'ECONNREFUSED' + ) { return true; } } diff --git a/genkit-tools/common/tests/eval/localFileDatasetStore_test.ts b/genkit-tools/common/tests/eval/localFileDatasetStore_test.ts index fa404cc71..c658208c8 100644 --- a/genkit-tools/common/tests/eval/localFileDatasetStore_test.ts +++ b/genkit-tools/common/tests/eval/localFileDatasetStore_test.ts @@ -89,6 +89,21 @@ const CREATE_DATASET_REQUEST = CreateDatasetRequestSchema.parse({ data: { samples: SAMPLE_DATASET_1_V1 }, }); +const CREATE_DATASET_REQUEST_WITH_SCHEMA = CreateDatasetRequestSchema.parse({ + data: { samples: SAMPLE_DATASET_1_V1 }, + schema: { + inputSchema: { + type: 'string', + $schema: 'http://json-schema.org/draft-07/schema#', + }, + referenceSchema: { + type: 'number', + $schema: 'http://json-schema.org/draft-07/schema#', + }, + }, + targetAction: '/flow/my-flow', +}); + const UPDATE_DATASET_REQUEST = UpdateDatasetRequestSchema.parse({ data: { samples: SAMPLE_DATASET_1_V2 }, datasetId: SAMPLE_DATASET_ID_1, @@ -104,11 +119,9 @@ const SAMPLE_DATASET_METADATA_2 = { updateTime: FAKE_TIME.toString(), }; -jest.mock('crypto', () => { +jest.mock('process', () => { return { - createHash: jest.fn().mockReturnThis(), - update: jest.fn().mockReturnThis(), - digest: jest.fn(() => 'store-root'), + cwd: jest.fn(() => 'store-root'), }; }); @@ -209,6 +222,37 @@ describe('localFileDatasetStore', () => { expect(datasetMetadata).toMatchObject(SAMPLE_DATASET_METADATA_1_V1); }); + it('creates new dataset, with schema', async () => { + fs.promises.writeFile = jest.fn(async () => Promise.resolve(undefined)); + fs.promises.appendFile = jest.fn(async () => Promise.resolve(undefined)); + // For index file reads + fs.promises.readFile = jest.fn(async () => + Promise.resolve(JSON.stringify({}) as any) + ); + fs.existsSync = jest.fn(() => false); + const dataset: Dataset = SAMPLE_DATASET_1_V1.map((s) => ({ + testCaseId: TEST_CASE_ID, + ...s, + })); + + const datasetMetadata = await DatasetStore.createDataset({ + ...CREATE_DATASET_REQUEST_WITH_SCHEMA, + datasetId: SAMPLE_DATASET_ID_1, + }); + + expect(datasetMetadata.schema).toMatchObject({ + inputSchema: { + type: 'string', + $schema: 'http://json-schema.org/draft-07/schema#', + }, + referenceSchema: { 
+ type: 'number', + $schema: 'http://json-schema.org/draft-07/schema#', + }, + }); + expect(datasetMetadata.targetAction).toEqual('/flow/my-flow'); + }); + it('fails request if dataset already exists', async () => { fs.existsSync = jest.fn(() => true); @@ -355,6 +399,39 @@ describe('localFileDatasetStore', () => { expect(datasetMetadata).toMatchObject(SAMPLE_DATASET_METADATA_1_V2); }); + it('succeeds for existing dataset -- with schema', async () => { + fs.existsSync = jest.fn(() => true); + let metadataMap = { + [SAMPLE_DATASET_ID_1]: SAMPLE_DATASET_METADATA_1_V1, + [SAMPLE_DATASET_ID_2]: SAMPLE_DATASET_METADATA_2, + }; + // For index file reads + fs.promises.readFile = jest.fn(async () => + Promise.resolve(JSON.stringify(metadataMap) as any) + ); + fs.promises.writeFile = jest.fn(async () => Promise.resolve(undefined)); + fs.promises.appendFile = jest.fn(async () => Promise.resolve(undefined)); + + const datasetMetadata = await DatasetStore.updateDataset({ + datasetId: SAMPLE_DATASET_ID_1, + schema: { + inputSchema: { + type: 'string', + $schema: 'http://json-schema.org/draft-07/schema#', + }, + }, + targetAction: '/flow/my-flow-2', + }); + + expect(datasetMetadata.schema).toMatchObject({ + inputSchema: { + type: 'string', + $schema: 'http://json-schema.org/draft-07/schema#', + }, + }); + expect(datasetMetadata.targetAction).toEqual('/flow/my-flow-2'); + }); + it('fails for non existing dataset', async () => { fs.existsSync = jest.fn(() => false); diff --git a/genkit-tools/common/tests/eval/localFileEvalStore_test.ts b/genkit-tools/common/tests/eval/localFileEvalStore_test.ts index 72fb8758e..4fc403dc1 100644 --- a/genkit-tools/common/tests/eval/localFileEvalStore_test.ts +++ b/genkit-tools/common/tests/eval/localFileEvalStore_test.ts @@ -26,18 +26,6 @@ import fs from 'fs'; import { LocalFileEvalStore } from '../../src/eval/localFileEvalStore'; import { EvalResult, EvalRunSchema, EvalStore } from '../../src/types/eval'; -jest.mock('crypto', () => { - return { - createHash: jest.fn().mockReturnThis(), - update: jest.fn().mockReturnThis(), - digest: jest.fn(() => 'store-root'), - }; -}); - -jest.mock('os', () => { - return { tmpdir: jest.fn(() => '/tmp/') }; -}); - const EVAL_RESULTS: EvalResult[] = [ { testCaseId: 'alakjdshfalsdkjh', @@ -107,6 +95,8 @@ describe('localFileEvalStore', () => { let evalStore: EvalStore; beforeEach(() => { + // For storeRoot setup + fs.existsSync = jest.fn(() => true); LocalFileEvalStore.reset(); evalStore = LocalFileEvalStore.getEvalStore() as EvalStore; }); @@ -125,11 +115,11 @@ describe('localFileEvalStore', () => { await evalStore.save(EVAL_RUN_WITH_ACTION); expect(fs.promises.writeFile).toHaveBeenCalledWith( - `/tmp/.genkit/store-root/evals/abc1234.json`, + expect.stringContaining(`evals/abc1234.json`), JSON.stringify(EVAL_RUN_WITH_ACTION) ); expect(fs.promises.appendFile).toHaveBeenCalledWith( - `/tmp/.genkit/store-root/evals/index.txt`, + expect.stringContaining(`evals/index.txt`), JSON.stringify(EVAL_RUN_WITH_ACTION.key) + '\n' ); }); @@ -138,11 +128,11 @@ describe('localFileEvalStore', () => { await evalStore.save(EVAL_RUN_WITHOUT_ACTION); expect(fs.promises.writeFile).toHaveBeenCalledWith( - `/tmp/.genkit/store-root/evals/def456.json`, + expect.stringContaining(`evals/def456.json`), JSON.stringify(EVAL_RUN_WITHOUT_ACTION) ); expect(fs.promises.appendFile).toHaveBeenCalledWith( - `/tmp/.genkit/store-root/evals/index.txt`, + expect.stringContaining(`evals/index.txt`), JSON.stringify(EVAL_RUN_WITHOUT_ACTION.key) + '\n' ); }); diff --git 
a/genkit-tools/package.json b/genkit-tools/package.json index 09833d453..641914d51 100644 --- a/genkit-tools/package.json +++ b/genkit-tools/package.json @@ -23,5 +23,5 @@ "zod": "^3.22.4", "zod-to-json-schema": "^3.22.4" }, - "packageManager": "pnpm@9.12.0+sha256.a61b67ff6cc97af864564f4442556c22a04f2e5a7714fbee76a1011361d9b726" + "packageManager": "pnpm@9.12.2+sha256.2ef6e547b0b07d841d605240dce4d635677831148cd30f6d564b8f4f928f73d2" } diff --git a/genkit-tools/pnpm-lock.yaml b/genkit-tools/pnpm-lock.yaml index 7c873c3a8..c8fe635f0 100644 --- a/genkit-tools/pnpm-lock.yaml +++ b/genkit-tools/pnpm-lock.yaml @@ -36,8 +36,8 @@ importers: specifier: workspace:* version: link:../common axios: - specifier: ^1.6.7 - version: 1.6.8 + specifier: ^1.7.7 + version: 1.7.7 colorette: specifier: ^2.0.20 version: 2.0.20 @@ -97,8 +97,8 @@ importers: specifier: ^0.5.12 version: 0.5.12 axios: - specifier: ^1.6.7 - version: 1.6.8 + specifier: ^1.7.7 + version: 1.7.7 body-parser: specifier: ^1.20.2 version: 1.20.2 @@ -1179,8 +1179,8 @@ packages: resolution: {integrity: sha512-wvUjBtSGN7+7SjNpq/9M2Tg350UZD3q62IFZLbRAR1bSMlCo1ZaeW+BJ+D090e4hIIZLBcTDWe4Mh4jvUDajzQ==} engines: {node: '>= 0.4'} - axios@1.6.8: - resolution: {integrity: sha512-v/ZHtJDU39mDpyBoFVkETcd/uNdxrWRrg3bKpOKzXFA6Bvqopts6ALSMU3y6ijYxbw2B+wPrIv46egTzJXCLGQ==} + axios@1.7.7: + resolution: {integrity: sha512-S4kL7XrjgBmvdGut0sN3yJxqYzrDOnivkBiN0OFs6hLiUam3UPvswUo0kqGyhqUZGEOytHyumEdXsAkgCOUf3Q==} babel-jest@29.7.0: resolution: {integrity: sha512-BrvGY3xZSwEcCzKvKsCi2GgHqDqsYkOP4/by5xCgIwGXQxIEh+8ew3gmrE1y7XRR6LHZIj6yLYnUi/mm2KXKBg==} @@ -4201,7 +4201,7 @@ snapshots: dependencies: possible-typed-array-names: 1.0.0 - axios@1.6.8: + axios@1.7.7: dependencies: follow-redirects: 1.15.6 form-data: 4.0.0 diff --git a/genkit-tools/telemetry-server/package.json b/genkit-tools/telemetry-server/package.json index 3665f657a..a49050aa1 100644 --- a/genkit-tools/telemetry-server/package.json +++ b/genkit-tools/telemetry-server/package.json @@ -7,7 +7,7 @@ "genai", "generative-ai" ], - "version": "0.6.0-dev.2", + "version": "0.9.0-dev.1", "type": "commonjs", "scripts": { "compile": "tsc -b ./tsconfig.cjs.json ./tsconfig.esm.json ./tsconfig.types.json", diff --git a/js/ai/package.json b/js/ai/package.json index 047817be9..e9e6d2cde 100644 --- a/js/ai/package.json +++ b/js/ai/package.json @@ -7,7 +7,7 @@ "genai", "generative-ai" ], - "version": "0.6.0-dev.2", + "version": "0.9.0-dev.1", "type": "commonjs", "scripts": { "check": "tsc", diff --git a/js/ai/src/document.ts b/js/ai/src/document.ts index d9ff71f11..70ed10f17 100644 --- a/js/ai/src/document.ts +++ b/js/ai/src/document.ts @@ -70,7 +70,7 @@ export class Document implements DocumentData { * Concatenates all `text` parts present in the document with no delimiter. * @returns A string of all concatenated text parts. */ - text(): string { + get text(): string { return this.content.map((part) => part.text || '').join(''); } @@ -79,7 +79,7 @@ export class Document implements DocumentData { * (for example) an image. * @returns The first detected `media` part in the document. 
*/ - media(): { url: string; contentType?: string } | null { + get media(): { url: string; contentType?: string } | null { return this.content.find((part) => part.media)?.media || null; } diff --git a/js/ai/src/embedder.ts b/js/ai/src/embedder.ts index 962ac24dd..89b050253 100644 --- a/js/ai/src/embedder.ts +++ b/js/ai/src/embedder.ts @@ -15,7 +15,7 @@ */ import { Action, defineAction, z } from '@genkit-ai/core'; -import { lookupAction } from '@genkit-ai/core/registry'; +import { Registry } from '@genkit-ai/core/registry'; import { Document, DocumentData, DocumentDataSchema } from './document.js'; export type EmbeddingBatch = { embedding: number[] }[]; @@ -23,7 +23,7 @@ export type EmbeddingBatch = { embedding: number[] }[]; export const EmbeddingSchema = z.array(z.number()); export type Embedding = z.infer; -type EmbedderFn = ( +export type EmbedderFn = ( input: Document[], embedderOpts?: z.infer ) => Promise; @@ -68,6 +68,7 @@ function withMetadata( export function defineEmbedder< ConfigSchema extends z.ZodTypeAny = z.ZodTypeAny, >( + registry: Registry, options: { name: string; configSchema?: ConfigSchema; @@ -76,6 +77,7 @@ export function defineEmbedder< runner: EmbedderFn ) { const embedder = defineAction( + registry, { actionType: 'embedder', name: options.name, @@ -111,47 +113,91 @@ export type EmbedderArgument< * A veneer for interacting with embedder models. */ export async function embed( + registry: Registry, params: EmbedderParams ): Promise { - let embedder: EmbedderAction; - if (typeof params.embedder === 'string') { - embedder = await lookupAction(`/embedder/${params.embedder}`); - } else if (Object.hasOwnProperty.call(params.embedder, 'info')) { - embedder = await lookupAction( - `/embedder/${(params.embedder as EmbedderReference).name}` - ); - } else { - embedder = params.embedder as EmbedderAction; - } - if (!embedder) { - throw new Error('Unable to utilize the provided embedder'); + let embedder = await resolveEmbedder(registry, params); + if (!embedder.embedderAction) { + let embedderId: string; + if (typeof params.embedder === 'string') { + embedderId = params.embedder; + } else if ((params.embedder as EmbedderAction)?.__action?.name) { + embedderId = (params.embedder as EmbedderAction).__action.name; + } else { + embedderId = (params.embedder as EmbedderReference).name; + } + throw new Error(`Unable to resolve embedder ${embedderId}`); } - const response = await embedder({ + const response = await embedder.embedderAction({ input: typeof params.content === 'string' ? 
[Document.fromText(params.content, params.metadata)] : [params.content], - options: params.options, + options: { + version: embedder.version, + ...embedder.config, + ...params.options, + }, }); return response.embeddings[0].embedding; } +interface ResolvedEmbedder { + embedderAction: EmbedderAction; + config?: z.infer; + version?: string; +} + +async function resolveEmbedder< + CustomOptions extends z.ZodTypeAny = z.ZodTypeAny, +>( + registry: Registry, + params: EmbedderParams +): Promise> { + if (typeof params.embedder === 'string') { + return { + embedderAction: await registry.lookupAction( + `/embedder/${params.embedder}` + ), + }; + } else if (Object.hasOwnProperty.call(params.embedder, '__action')) { + return { + embedderAction: params.embedder as EmbedderAction, + }; + } else if (Object.hasOwnProperty.call(params.embedder, 'name')) { + const ref = params.embedder as EmbedderReference; + return { + embedderAction: await registry.lookupAction( + `/embedder/${(params.embedder as EmbedderReference).name}` + ), + config: { + ...ref.config, + }, + version: ref.version, + }; + } + throw new Error(`failed to resolve embedder ${params.embedder}`); +} + /** * A veneer for interacting with embedder models in bulk. */ export async function embedMany< ConfigSchema extends z.ZodTypeAny = z.ZodTypeAny, ->(params: { - embedder: EmbedderArgument; - content: string[] | DocumentData[]; - metadata?: Record; - options?: z.infer; -}): Promise { +>( + registry: Registry, + params: { + embedder: EmbedderArgument; + content: string[] | DocumentData[]; + metadata?: Record; + options?: z.infer; + } +): Promise { let embedder: EmbedderAction; if (typeof params.embedder === 'string') { - embedder = await lookupAction(`/embedder/${params.embedder}`); + embedder = await registry.lookupAction(`/embedder/${params.embedder}`); } else if (Object.hasOwnProperty.call(params.embedder, 'info')) { - embedder = await lookupAction( + embedder = await registry.lookupAction( `/embedder/${(params.embedder as EmbedderReference).name}` ); } else { @@ -192,6 +238,8 @@ export interface EmbedderReference< name: string; configSchema?: CustomOptions; info?: EmbedderInfo; + config?: z.infer; + version?: string; } /** diff --git a/js/ai/src/evaluator.ts b/js/ai/src/evaluator.ts index 1e2db1084..02be11e48 100644 --- a/js/ai/src/evaluator.ts +++ b/js/ai/src/evaluator.ts @@ -16,7 +16,7 @@ import { Action, defineAction, z } from '@genkit-ai/core'; import { logger } from '@genkit-ai/core/logging'; -import { lookupAction } from '@genkit-ai/core/registry'; +import { Registry } from '@genkit-ai/core/registry'; import { SPAN_TYPE_ATTR, runInNewSpan } from '@genkit-ai/core/tracing'; import { randomUUID } from 'crypto'; @@ -73,7 +73,7 @@ export type EvalResponse = z.infer; export const EvalResponsesSchema = z.array(EvalResponseSchema); export type EvalResponses = z.infer; -type EvaluatorFn< +export type EvaluatorFn< EvalDataPoint extends typeof BaseEvalDataPointSchema = typeof BaseEvalDataPointSchema, CustomOptions extends z.ZodTypeAny = z.ZodTypeAny, @@ -127,6 +127,7 @@ export function defineEvaluator< typeof BaseEvalDataPointSchema = typeof BaseEvalDataPointSchema, EvaluatorOptions extends z.ZodTypeAny = z.ZodTypeAny, >( + registry: Registry, options: { name: string; displayName: string; @@ -143,6 +144,7 @@ export function defineEvaluator< metadata[EVALUATOR_METADATA_KEY_DISPLAY_NAME] = options.displayName; metadata[EVALUATOR_METADATA_KEY_DEFINITION] = options.definition; const evaluator = defineAction( + registry, { actionType: 'evaluator', 
name: options.name, @@ -239,12 +241,17 @@ export type EvaluatorArgument< export async function evaluate< DataPoint extends typeof BaseDataPointSchema = typeof BaseDataPointSchema, CustomOptions extends z.ZodTypeAny = z.ZodTypeAny, ->(params: EvaluatorParams): Promise { +>( + registry: Registry, + params: EvaluatorParams +): Promise { let evaluator: EvaluatorAction; if (typeof params.evaluator === 'string') { - evaluator = await lookupAction(`/evaluator/${params.evaluator}`); + evaluator = await registry.lookupAction(`/evaluator/${params.evaluator}`); } else if (Object.hasOwnProperty.call(params.evaluator, 'info')) { - evaluator = await lookupAction(`/evaluator/${params.evaluator.name}`); + evaluator = await registry.lookupAction( + `/evaluator/${params.evaluator.name}` + ); } else { evaluator = params.evaluator as EvaluatorAction; } diff --git a/js/ai/src/generate.ts b/js/ai/src/generate.ts index 44af678e3..cf5231520 100755 --- a/js/ai/src/generate.ts +++ b/js/ai/src/generate.ts @@ -21,7 +21,7 @@ import { StreamingCallback, z, } from '@genkit-ai/core'; -import { lookupAction } from '@genkit-ai/core/registry'; +import { Registry } from '@genkit-ai/core/registry'; import { parseSchema, toJsonSchema } from '@genkit-ai/core/schema'; import { DocumentData } from './document.js'; import { extractJson } from './extract.js'; @@ -70,8 +70,8 @@ export class Message implements MessageData { * * @returns The structured output contained in the message. */ - output(): T { - return this.data() || extractJson(this.text()); + get output(): T { + return this.data || extractJson(this.text); } toolResponseParts(): ToolResponsePart[] { @@ -83,7 +83,7 @@ export class Message implements MessageData { * Concatenates all `text` parts present in the message with no delimiter. * @returns A string of all concatenated text parts. */ - text(): string { + get text(): string { return this.content.map((part) => part.text || '').join(''); } @@ -92,7 +92,7 @@ export class Message implements MessageData { * (for example) an image from a generation expected to create one. * @returns The first detected `media` part in the message. */ - media(): { url: string; contentType?: string } | null { + get media(): { url: string; contentType?: string } | null { return this.content.find((part) => part.media)?.media || null; } @@ -100,7 +100,7 @@ export class Message implements MessageData { * Returns the first detected `data` part of a message. * @returns The first `data` part detected in the message (if any). */ - data(): T | null { + get data(): T | null { return this.content.find((part) => part.data)?.data as T | null; } @@ -108,7 +108,7 @@ export class Message implements MessageData { * Returns all tool request found in this message. * @returns Array of all tool request found in this message. */ - toolRequests(): ToolRequestPart[] { + get toolRequests(): ToolRequestPart[] { return this.content.filter( (part) => !!part.toolRequest ) as ToolRequestPart[]; @@ -187,7 +187,7 @@ export class GenerateResponse implements ModelResponseData { } if (request?.output?.schema || this.request?.output?.schema) { - const o = this.output(); + const o = this.output; parseSchema(o, { jsonSchema: request?.output?.schema || this.request?.output?.schema, }); @@ -211,8 +211,8 @@ export class GenerateResponse implements ModelResponseData { * @param index The candidate index from which to extract output. If not provided, finds first candidate that conforms to output schema. * @returns The structured output contained in the selected candidate. 
*/ - output(): O | null { - return this.message?.output() || null; + get output(): O | null { + return this.message?.output || null; } /** @@ -220,8 +220,8 @@ export class GenerateResponse implements ModelResponseData { * @param index The candidate index from which to extract text, defaults to first candidate. * @returns A string of all concatenated text parts. */ - text(): string { - return this.message?.text() || ''; + get text(): string { + return this.message?.text || ''; } /** @@ -230,8 +230,8 @@ export class GenerateResponse implements ModelResponseData { * @param index The candidate index from which to extract media, defaults to first candidate. * @returns The first detected `media` part in the candidate. */ - media(): { url: string; contentType?: string } | null { - return this.message?.media() || null; + get media(): { url: string; contentType?: string } | null { + return this.message?.media || null; } /** @@ -239,8 +239,8 @@ export class GenerateResponse implements ModelResponseData { * @param index The candidate index from which to extract data, defaults to first candidate. * @returns The first `data` part detected in the candidate (if any). */ - data(): O | null { - return this.message?.data() || null; + get data(): O | null { + return this.message?.data || null; } /** @@ -248,8 +248,8 @@ export class GenerateResponse implements ModelResponseData { * @param index The candidate index from which to extract tool requests, defaults to first candidate. * @returns Array of all tool request found in the candidate. */ - toolRequests(): ToolRequestPart[] { - return this.message?.toolRequests() || []; + get toolRequests(): ToolRequestPart[] { + return this.message?.toolRequests || []; } /** @@ -259,7 +259,7 @@ export class GenerateResponse implements ModelResponseData { * @param index The candidate index to utilize during conversion, defaults to first candidate. * @returns A serializable list of messages compatible with `generate({history})`. */ - toHistory(): MessageData[] { + get messages(): MessageData[] { if (!this.request) throw new Error( "Can't construct history for response without request reference." @@ -271,6 +271,10 @@ export class GenerateResponse implements ModelResponseData { return [...this.request?.messages, this.message.toJSON()]; } + get raw(): unknown { + return this.raw ?? this.custom; + } + toJSON(): ModelResponseData { const out = { message: this.message?.toJSON(), @@ -312,7 +316,7 @@ export class GenerateResponseChunk * Concatenates all `text` parts present in the chunk with no delimiter. * @returns A string of all concatenated text parts. */ - text(): string { + get text(): string { return this.content.map((part) => part.text || '').join(''); } @@ -321,7 +325,7 @@ export class GenerateResponseChunk * (for example) an image from a generation expected to create one. * @returns The first detected `media` part in the chunk. */ - media(): { url: string; contentType?: string } | null { + get media(): { url: string; contentType?: string } | null { return this.content.find((part) => part.media)?.media || null; } @@ -329,7 +333,7 @@ export class GenerateResponseChunk * Returns the first detected `data` part of a chunk. * @returns The first `data` part detected in the chunk (if any). */ - data(): T | null { + get data(): T | null { return this.content.find((part) => part.data)?.data as T | null; } @@ -337,7 +341,7 @@ export class GenerateResponseChunk * Returns all tool request found in this chunk. * @returns Array of all tool request found in this chunk. 
*/ - toolRequests(): ToolRequestPart[] { + get toolRequests(): ToolRequestPart[] { return this.content.filter( (part) => !!part.toolRequest ) as ToolRequestPart[]; @@ -347,7 +351,7 @@ export class GenerateResponseChunk * Attempts to extract the longest valid JSON substring from the accumulated chunks. * @returns The longest valid JSON substring found in the accumulated chunks. */ - output(): T | null { + get output(): T | null { if (!this.accumulatedChunks) return null; const accumulatedText = this.accumulatedChunks .map((chunk) => chunk.content.map((part) => part.text || '').join('')) @@ -361,9 +365,26 @@ export class GenerateResponseChunk } export async function toGenerateRequest( + registry: Registry, options: GenerateOptions ): Promise { - const messages: MessageData[] = [...(options.messages || [])]; + const messages: MessageData[] = []; + if (options.system) { + const systemMessage: MessageData = { role: 'system', content: [] }; + if (typeof options.system === 'string') { + systemMessage.content.push({ text: options.system }); + } else if (Array.isArray(options.system)) { + systemMessage.role = inferRoleFromParts(options.system); + systemMessage.content.push(...(options.system as Part[])); + } else { + systemMessage.role = inferRoleFromParts([options.system]); + systemMessage.content.push(options.system); + } + messages.push(systemMessage); + } + if (options.messages) { + messages.push(...options.messages); + } if (options.prompt) { const promptMessage: MessageData = { role: 'user', content: [] }; if (typeof options.prompt === 'string') { @@ -382,13 +403,13 @@ export async function toGenerateRequest( } let tools: Action[] | undefined; if (options.tools) { - tools = await resolveTools(options.tools); + tools = await resolveTools(registry, options.tools); } const out = { messages, config: options.config, - context: options.context, + docs: options.docs, tools: tools?.map((tool) => toToolDefinition(tool)) || [], output: { format: @@ -412,11 +433,13 @@ export interface GenerateOptions< > { /** A model name (e.g. `vertexai/gemini-1.0-pro`) or reference. */ model?: ModelArgument; + /** The system prompt to be included in the generate request. Can be a string for a simple text prompt or one or more parts for multi-modal prompts (subject to model support). */ + system?: string | Part | Part[]; /** The prompt for which to generate a response. Can be a string for a simple text prompt or one or more parts for multi-modal prompts. */ prompt?: string | Part | Part[]; /** Retrieved documents to be used as context for this generation. */ - context?: DocumentData[]; - /** Conversation history for multi-turn prompting when supported by the underlying model. */ + docs?: DocumentData[]; + /** Conversation messages (history) for multi-turn prompting when supported by the underlying model. */ messages?: MessageData[]; /** List of registered tool names or actions to treat as a tool for this generation if supported by the underlying model. 
*/ tools?: ToolArgument[]; @@ -436,18 +459,39 @@ export interface GenerateOptions< use?: ModelMiddleware[]; } -async function resolveModel(options: GenerateOptions): Promise { +interface ResolvedModel { + modelAction: ModelAction; + config?: z.infer; + version?: string; +} + +async function resolveModel( + registry: Registry, + options: GenerateOptions +): Promise { let model = options.model; if (!model) { - throw new Error('Unable to resolve model.'); + throw new Error('Model is required.'); } if (typeof model === 'string') { - return (await lookupAction(`/model/${model}`)) as ModelAction; - } else if (model.hasOwnProperty('info')) { - const ref = model as ModelReference; - return (await lookupAction(`/model/${ref.name}`)) as ModelAction; + return { + modelAction: (await registry.lookupAction( + `/model/${model}` + )) as ModelAction, + }; + } else if (model.hasOwnProperty('__action')) { + return { modelAction: model as ModelAction }; } else { - return model as ModelAction; + const ref = model as ModelReference; + return { + modelAction: (await registry.lookupAction( + `/model/${ref.name}` + )) as ModelAction, + config: { + ...ref.config, + }, + version: ref.version, + }; } } @@ -489,21 +533,23 @@ export async function generate< O extends z.ZodTypeAny = z.ZodTypeAny, CustomOptions extends z.ZodTypeAny = typeof GenerationCommonConfigSchema, >( + registry: Registry, options: | GenerateOptions | PromiseLike> ): Promise>> { const resolvedOptions: GenerateOptions = await Promise.resolve(options); - const model = await resolveModel(resolvedOptions); + const resolvedModel = await resolveModel(registry, resolvedOptions); + const model = resolvedModel.modelAction; if (!model) { let modelId: string; if (typeof resolvedOptions.model === 'string') { modelId = resolvedOptions.model; - } else if ((resolvedOptions.model as ModelReference).name) { - modelId = (resolvedOptions.model as ModelReference).name; - } else { + } else if ((resolvedOptions.model as ModelAction)?.__action?.name) { modelId = (resolvedOptions.model as ModelAction).__action.name; + } else { + modelId = (resolvedOptions.model as ModelReference).name; } throw new Error(`Model ${modelId} not found`); } @@ -525,13 +571,51 @@ export async function generate< }); } + const messages: MessageData[] = []; + if (resolvedOptions.system) { + const systemMessage: MessageData = { role: 'system', content: [] }; + if (typeof resolvedOptions.system === 'string') { + systemMessage.content.push({ text: resolvedOptions.system }); + } else if (Array.isArray(resolvedOptions.system)) { + systemMessage.role = inferRoleFromParts(resolvedOptions.system); + systemMessage.content.push(...(resolvedOptions.system as Part[])); + } else { + systemMessage.role = inferRoleFromParts([resolvedOptions.system]); + systemMessage.content.push(resolvedOptions.system); + } + messages.push(systemMessage); + } + if (resolvedOptions.messages) { + messages.push(...resolvedOptions.messages); + } + if (resolvedOptions.prompt) { + const promptMessage: MessageData = { role: 'user', content: [] }; + if (typeof resolvedOptions.prompt === 'string') { + promptMessage.content.push({ text: resolvedOptions.prompt }); + } else if (Array.isArray(resolvedOptions.prompt)) { + promptMessage.role = inferRoleFromParts(resolvedOptions.prompt); + promptMessage.content.push(...(resolvedOptions.prompt as Part[])); + } else { + promptMessage.role = inferRoleFromParts([resolvedOptions.prompt]); + promptMessage.content.push(resolvedOptions.prompt); + } + messages.push(promptMessage); + } + + if 
(messages.length === 0) { + throw new Error('at least one message is required in generate request'); + } + const params: z.infer = { model: model.__action.name, - prompt: resolvedOptions.prompt, - context: resolvedOptions.context, - messages: resolvedOptions.messages, + docs: resolvedOptions.docs, + messages, tools, - config: resolvedOptions.config, + config: { + version: resolvedModel.version, + ...stripUndefinedOptions(resolvedModel.config), + ...stripUndefinedOptions(resolvedOptions.config), + }, output: resolvedOptions.output && { format: resolvedOptions.output.format, jsonSchema: resolvedOptions.output.schema @@ -548,12 +632,23 @@ export async function generate< resolvedOptions.streamingCallback, async () => new GenerateResponse( - await generateHelper(params, resolvedOptions.use), - await toGenerateRequest(resolvedOptions) + await generateHelper(registry, params, resolvedOptions.use), + await toGenerateRequest(registry, resolvedOptions) ) ); } +function stripUndefinedOptions(input?: any): any { + if (!input) return input; + const copy = { ...input }; + Object.keys(input).forEach((key) => { + if (copy[key] === undefined) { + delete copy[key]; + } + }); + return copy; +} + export type GenerateStreamOptions< O extends z.ZodTypeAny = z.ZodTypeAny, CustomOptions extends z.ZodTypeAny = typeof GenerationCommonConfigSchema, @@ -578,6 +673,7 @@ export async function generateStream< O extends z.ZodTypeAny = z.ZodTypeAny, CustomOptions extends z.ZodTypeAny = typeof GenerationCommonConfigSchema, >( + registry: Registry, options: | GenerateOptions | PromiseLike> @@ -603,7 +699,7 @@ export async function generateStream< } try { - generate({ + generate(registry, { ...options, streamingCallback: (chunk) => { firstChunkSent = true; diff --git a/js/ai/src/generateAction.ts b/js/ai/src/generateAction.ts index dfe2eda22..7cb4f2d71 100644 --- a/js/ai/src/generateAction.ts +++ b/js/ai/src/generateAction.ts @@ -21,7 +21,7 @@ import { runWithStreamingCallback, z, } from '@genkit-ai/core'; -import { lookupAction } from '@genkit-ai/core/registry'; +import { Registry } from '@genkit-ai/core/registry'; import { toJsonSchema, validateSchema } from '@genkit-ai/core/schema'; import { runInNewSpan, SPAN_TYPE_ATTR } from '@genkit-ai/core/tracing'; import * as clc from 'colorette'; @@ -33,11 +33,9 @@ import { GenerateRequestSchema, GenerateResponseChunkData, GenerateResponseData, - MessageData, MessageSchema, ModelAction, Part, - PartSchema, Role, ToolDefinitionSchema, ToolResponsePart, @@ -47,12 +45,10 @@ import { ToolAction, toToolDefinition } from './tool.js'; export const GenerateUtilParamSchema = z.object({ /** A model name (e.g. `vertexai/gemini-1.0-pro`). */ model: z.string(), - /** The prompt for which to generate a response. Can be a string for a simple text prompt or one or more parts for multi-modal prompts. */ - prompt: z.union([z.string(), PartSchema, z.array(PartSchema)]).optional(), /** Retrieved documents to be used as context for this generation. */ - context: z.array(DocumentDataSchema).optional(), + docs: z.array(DocumentDataSchema).optional(), /** Conversation history for multi-turn prompting when supported by the underlying model. */ - messages: z.array(MessageSchema).optional(), + messages: z.array(MessageSchema), /** List of registered tool names for this generation if supported by the underlying model. */ tools: z.array(z.union([z.string(), ToolDefinitionSchema])).optional(), /** Configuration for the generation request. 
*/ @@ -74,6 +70,7 @@ export const GenerateUtilParamSchema = z.object({ * Encapsulates all generate logic. This is similar to `generateAction` except not an action and can take middleware. */ export async function generateHelper( + registry: Registry, input: z.infer, middleware?: Middleware[] ): Promise { @@ -90,7 +87,7 @@ export async function generateHelper( async (metadata) => { metadata.name = 'generate'; metadata.input = input; - const output = await generate(input, middleware); + const output = await generate(registry, input, middleware); metadata.output = JSON.stringify(output); return output; } @@ -98,10 +95,11 @@ export async function generateHelper( } async function generate( + registry: Registry, rawRequest: z.infer, middleware?: Middleware[] ): Promise { - const model = (await lookupAction( + const model = (await registry.lookupAction( `/model/${rawRequest.model}` )) as ModelAction; if (!model) { @@ -124,7 +122,7 @@ async function generate( tools = await Promise.all( rawRequest.tools.map(async (toolRef) => { if (typeof toolRef === 'string') { - const tool = (await lookupAction(toolRef)) as ToolAction; + const tool = (await registry.lookupAction(toolRef)) as ToolAction; if (!tool) { throw new Error(`Tool ${toolRef} not found`); } @@ -207,35 +205,17 @@ async function generate( messages: [...request.messages, message], prompt: toolResponses, }; - return await generateHelper(nextRequest, middleware); + return await generateHelper(registry, nextRequest, middleware); } async function actionToGenerateRequest( options: z.infer, resolvedTools?: ToolAction[] ): Promise { - const messages: MessageData[] = [...(options.messages || [])]; - if (options.prompt) { - const promptMessage: MessageData = { role: 'user', content: [] }; - if (typeof options.prompt === 'string') { - promptMessage.content.push({ text: options.prompt }); - } else if (Array.isArray(options.prompt)) { - promptMessage.role = inferRoleFromParts(options.prompt); - promptMessage.content.push(...(options.prompt as Part[])); - } else { - promptMessage.role = inferRoleFromParts([options.prompt]); - promptMessage.content.push(options.prompt); - } - messages.push(promptMessage); - } - if (messages.length === 0) { - throw new Error('at least one message is required in generate request'); - } - const out = { - messages, + messages: options.messages, config: options.config, - context: options.context, + docs: options.docs, tools: resolvedTools?.map((tool) => toToolDefinition(tool)) || [], output: { format: diff --git a/js/ai/src/index.ts b/js/ai/src/index.ts index 06dbbc98a..b8236e3b4 100644 --- a/js/ai/src/index.ts +++ b/js/ai/src/index.ts @@ -72,9 +72,11 @@ export { export { definePrompt, renderPrompt, + type ExecutablePrompt, type PromptAction, type PromptConfig, type PromptFn, + type PromptGenerateOptions, } from './prompt.js'; export { rerank, diff --git a/js/ai/src/model.ts b/js/ai/src/model.ts index 0e1769421..4a176e70b 100644 --- a/js/ai/src/model.ts +++ b/js/ai/src/model.ts @@ -22,6 +22,7 @@ import { StreamingCallback, z, } from '@genkit-ai/core'; +import { Registry } from '@genkit-ai/core/registry'; import { toJsonSchema } from '@genkit-ai/core/schema'; import { performance } from 'node:perf_hooks'; import { DocumentDataSchema } from './document.js'; @@ -194,7 +195,7 @@ export const ModelRequestSchema = z.object({ config: z.any().optional(), tools: z.array(ToolDefinitionSchema).optional(), output: OutputConfigSchema.optional(), - context: z.array(DocumentDataSchema).optional(), + docs: 
z.array(DocumentDataSchema).optional(), }); /** ModelRequest represents the parameters that are passed to a model when generating content. */ export interface ModelRequest< @@ -204,8 +205,6 @@ export interface ModelRequest< } export const GenerateRequestSchema = ModelRequestSchema.extend({ - /** @deprecated Use `docs` instead. */ - context: z.array(DocumentDataSchema).optional(), /** @deprecated All responses now return a single candidate. This will always be `undefined`. */ candidates: z.number().optional(), }); @@ -260,7 +259,9 @@ export const ModelResponseSchema = z.object({ finishMessage: z.string().optional(), latencyMs: z.number().optional(), usage: GenerationUsageSchema.optional(), + /** @deprecated use `raw` instead */ custom: z.unknown(), + raw: z.unknown(), request: GenerateRequestSchema.optional(), }); export type ModelResponseData = z.infer; @@ -330,6 +331,7 @@ export type DefineModelOptions< export function defineModel< CustomOptionsSchema extends z.ZodTypeAny = z.ZodTypeAny, >( + registry: Registry, options: DefineModelOptions, runner: ( request: GenerateRequest, @@ -344,6 +346,7 @@ export function defineModel< if (!options?.supports?.context) middleware.push(augmentWithContext()); middleware.push(conformOutput()); const act = defineAction( + registry, { actionType: 'model', name: options.name, @@ -385,17 +388,37 @@ export interface ModelReference { configSchema?: CustomOptions; info?: ModelInfo; version?: string; + config?: z.infer; + + withConfig(cfg: z.infer): ModelReference; + withVersion(version: string): ModelReference; } -/** - * - */ +/** Creates a model reference. */ export function modelRef< CustomOptionsSchema extends z.ZodTypeAny = z.ZodTypeAny, >( - options: ModelReference + options: Omit< + ModelReference, + 'withConfig' | 'withVersion' + > ): ModelReference { - return { ...options }; + const ref: Partial> = { ...options }; + ref.withConfig = ( + cfg: z.infer + ): ModelReference => { + return modelRef({ + ...options, + config: cfg, + }); + }; + ref.withVersion = (version: string): ModelReference => { + return modelRef({ + ...options, + version, + }); + }; + return ref as ModelReference; } /** Container for counting usage stats for a single input/output {Part} */ @@ -434,16 +457,20 @@ export function getBasicUsageStats( function getPartCounts(parts: Part[]): PartCounts { return parts.reduce( (counts, part) => { + const isImage = + part.media?.contentType?.startsWith('image') || + part.media?.url?.startsWith('data:image'); + const isVideo = + part.media?.contentType?.startsWith('video') || + part.media?.url?.startsWith('data:video'); + const isAudio = + part.media?.contentType?.startsWith('audio') || + part.media?.url?.startsWith('data:audio'); return { characters: counts.characters + (part.text?.length || 0), images: - counts.images + - (part.media?.contentType?.startsWith('image') ? 1 : 0), - videos: - counts.videos + - (part.media?.contentType?.startsWith('video') ? 1 : 0), - audio: - counts.audio + (part.media?.contentType?.startsWith('audio') ? 1 : 0), + images: counts.images + (isImage ? 1 : 0), + videos: counts.videos + (isVideo ? 1 : 0), + audio: counts.audio + (isAudio ?
1 : 0), }; }, { characters: 0, images: 0, videos: 0, audio: 0 } diff --git a/js/ai/src/model/middleware.ts b/js/ai/src/model/middleware.ts index 5a16cc1f9..11c320cb9 100644 --- a/js/ai/src/model/middleware.ts +++ b/js/ai/src/model/middleware.ts @@ -230,7 +230,7 @@ const CONTEXT_ITEM_TEMPLATE = ( } else if (options?.citationKey === undefined) { out += `[${d.metadata?.['ref'] || d.metadata?.['id'] || index}]: `; } - out += d.text() + '\n'; + out += d.text + '\n'; return out; }; @@ -242,7 +242,7 @@ export function augmentWithContext( const itemTemplate = options?.itemTemplate || CONTEXT_ITEM_TEMPLATE; return (req, next) => { // if there is no context in the request, no-op - if (!req.context?.length) return next(req); + if (!req.docs?.length) return next(req); const userMessage = lastUserMessage(req.messages); // if there are no messages, no-op if (!userMessage) return next(req); @@ -257,7 +257,7 @@ export function augmentWithContext( return next(req); } let out = `${preface || ''}`; - req.context?.forEach((d, i) => { + req.docs?.forEach((d, i) => { out += itemTemplate(new Document(d), i, options); }); out += '\n'; diff --git a/js/ai/src/prompt.ts b/js/ai/src/prompt.ts index fb0dc06bb..f497dca23 100644 --- a/js/ai/src/prompt.ts +++ b/js/ai/src/prompt.ts @@ -15,14 +15,19 @@ */ import { Action, defineAction, JSONSchema7, z } from '@genkit-ai/core'; -import { lookupAction } from '@genkit-ai/core/registry'; +import { Registry } from '@genkit-ai/core/registry'; import { DocumentData } from './document.js'; -import { GenerateOptions } from './generate.js'; +import { + GenerateOptions, + GenerateResponse, + GenerateStreamResponse, +} from './generate.js'; import { GenerateRequest, GenerateRequestSchema, ModelArgument, } from './model.js'; +import { ToolAction } from './tool.js'; export type PromptFn< I extends z.ZodTypeAny = z.ZodTypeAny, @@ -58,6 +63,84 @@ export function isPrompt(arg: any): boolean { ); } +export type PromptGenerateOptions< + I = undefined, + CustomOptions extends z.ZodTypeAny = z.ZodTypeAny, +> = Omit< + GenerateOptions, + 'prompt' | 'input' | 'model' +> & { + model?: ModelArgument; + input?: I; +}; + +/** + * A prompt that can be executed as a function. + */ +export interface ExecutablePrompt< + I = undefined, + O extends z.ZodTypeAny = z.ZodTypeAny, + CustomOptions extends z.ZodTypeAny = z.ZodTypeAny, +> { + /** + * Generates a response by rendering the prompt template with given user input and then calling the model. + * + * @param input Prompt inputs. + * @param opt Options for the prompt template, including user input variables and custom model configuration options. + * @returns the model response as a promise of `GenerateStreamResponse`. + */ + ( + input?: I, + opts?: PromptGenerateOptions + ): Promise>>; + + /** + * Generates a response by rendering the prompt template with given user input and then calling the model. + * @param input Prompt inputs. + * @param opt Options for the prompt template, including user input variables and custom model configuration options. + * @returns the model response as a promise of `GenerateStreamResponse`. + */ + stream( + input?: I, + opts?: PromptGenerateOptions + ): Promise>>; + + /** + * Generates a response by rendering the prompt template with given user input and additional generate options and then calling the model. + * + * @param opt Options for the prompt template, including user input variables and custom model configuration options. + * @returns the model response as a promise of `GenerateResponse`. 
+ */ + generate( + opt: PromptGenerateOptions + ): Promise>>; + + /** + * Generates a streaming response by rendering the prompt template with given user input and additional generate options and then calling the model. + * + * @param opt Options for the prompt template, including user input variables and custom model configuration options. + * @returns the model response as a promise of `GenerateStreamResponse`. + */ + generateStream( + opt: PromptGenerateOptions + ): Promise>>; + + /** + * Renders the prompt template based on user input. + * + * @param opt Options for the prompt template, including user input variables and custom model configuration options. + * @returns a `GenerateOptions` object to be used with the `generate()` function from @genkit-ai/ai. + */ + render( + opt: PromptGenerateOptions + ): Promise>; + + /** + * Returns the prompt usable as a tool. + */ + asTool(): ToolAction; +} + /** * Defines and registers a prompt action. The action can be called to obtain * a `GenerateRequest` which can be passed to a model action. The given @@ -67,10 +150,12 @@ export function isPrompt(arg: any): boolean { * @returns The new `PromptAction`. */ export function definePrompt( + registry: Registry, config: PromptConfig, fn: PromptFn ): PromptAction { const a = defineAction( + registry, { ...config, actionType: 'prompt', @@ -94,16 +179,19 @@ export async function renderPrompt< I extends z.ZodTypeAny = z.ZodTypeAny, O extends z.ZodTypeAny = z.ZodTypeAny, CustomOptions extends z.ZodTypeAny = z.ZodTypeAny, ->(params: { - prompt: PromptArgument; - input: z.infer; - context?: DocumentData[]; - model: ModelArgument; - config?: z.infer; -}): Promise> { +>( + registry: Registry, + params: { + prompt: PromptArgument; + input: z.infer; + docs?: DocumentData[]; + model: ModelArgument; + config?: z.infer; + } +): Promise> { let prompt: PromptAction; if (typeof params.prompt === 'string') { - prompt = await lookupAction(`/prompt/${params.prompt}`); + prompt = await registry.lookupAction(`/prompt/${params.prompt}`); } else { prompt = params.prompt as PromptAction; } @@ -115,7 +203,7 @@ export async function renderPrompt< config: { ...(rendered.config || {}), ...params.config }, messages: rendered.messages.slice(0, rendered.messages.length - 1), prompt: rendered.messages[rendered.messages.length - 1].content, - context: params.context, + docs: params.docs, output: { format: rendered.output?.format, schema: rendered.output?.schema, diff --git a/js/ai/src/reranker.ts b/js/ai/src/reranker.ts index fab6de633..35d3b2505 100644 --- a/js/ai/src/reranker.ts +++ b/js/ai/src/reranker.ts @@ -15,11 +15,11 @@ */ import { Action, defineAction, z } from '@genkit-ai/core'; -import { lookupAction } from '@genkit-ai/core/registry'; +import { Registry } from '@genkit-ai/core/registry'; import { Part, PartSchema } from './document.js'; import { Document, DocumentData, DocumentDataSchema } from './retriever.js'; -type RerankerFn = ( +export type RerankerFn = ( query: Document, documents: Document[], queryOpts: z.infer @@ -101,6 +101,7 @@ function rerankerWithMetadata< * Creates a reranker action for the provided {@link RerankerFn} implementation. 
*/ export function defineReranker( + registry: Registry, options: { name: string; configSchema?: OptionsType; @@ -109,6 +110,7 @@ export function defineReranker( runner: RerankerFn ) { const reranker = defineAction( + registry, { actionType: 'reranker', name: options.name, @@ -157,13 +159,14 @@ export type RerankerArgument< * Reranks documents from a {@link RerankerArgument} based on the provided query. */ export async function rerank( + registry: Registry, params: RerankerParams ): Promise> { let reranker: RerankerAction; if (typeof params.reranker === 'string') { - reranker = await lookupAction(`/reranker/${params.reranker}`); + reranker = await registry.lookupAction(`/reranker/${params.reranker}`); } else if (Object.hasOwnProperty.call(params.reranker, 'info')) { - reranker = await lookupAction(`/reranker/${params.reranker.name}`); + reranker = await registry.lookupAction(`/reranker/${params.reranker.name}`); } else { reranker = params.reranker as RerankerAction; } diff --git a/js/ai/src/retriever.ts b/js/ai/src/retriever.ts index 5444672f2..0623e297f 100644 --- a/js/ai/src/retriever.ts +++ b/js/ai/src/retriever.ts @@ -15,7 +15,7 @@ */ import { Action, GenkitError, defineAction, z } from '@genkit-ai/core'; -import { lookupAction } from '@genkit-ai/core/registry'; +import { Registry } from '@genkit-ai/core/registry'; import { Document, DocumentData, DocumentDataSchema } from './document.js'; import { EmbedderInfo } from './embedder.js'; @@ -28,12 +28,12 @@ export { type TextPart, } from './document.js'; -type RetrieverFn = ( +export type RetrieverFn = ( query: Document, queryOpts: z.infer ) => Promise; -type IndexerFn = ( +export type IndexerFn = ( docs: Array, indexerOpts: z.infer ) => Promise; @@ -111,6 +111,7 @@ function indexerWithMetadata< export function defineRetriever< OptionsType extends z.ZodTypeAny = z.ZodTypeAny, >( + registry: Registry, options: { name: string; configSchema?: OptionsType; @@ -119,6 +120,7 @@ export function defineRetriever< runner: RetrieverFn ) { const retriever = defineAction( + registry, { actionType: 'retriever', name: options.name, @@ -149,6 +151,7 @@ export function defineRetriever< * Creates an indexer action for the provided {@link IndexerFn} implementation. */ export function defineIndexer( + registry: Registry, options: { name: string; embedderInfo?: EmbedderInfo; @@ -157,6 +160,7 @@ export function defineIndexer( runner: IndexerFn ) { const indexer = defineAction( + registry, { actionType: 'indexer', name: options.name, @@ -200,13 +204,16 @@ export type RetrieverArgument< * Retrieves documents from a {@link RetrieverArgument} based on the provided query. */ export async function retrieve( + registry: Registry, params: RetrieverParams ): Promise> { let retriever: RetrieverAction; if (typeof params.retriever === 'string') { - retriever = await lookupAction(`/retriever/${params.retriever}`); + retriever = await registry.lookupAction(`/retriever/${params.retriever}`); } else if (Object.hasOwnProperty.call(params.retriever, 'info')) { - retriever = await lookupAction(`/retriever/${params.retriever.name}`); + retriever = await registry.lookupAction( + `/retriever/${params.retriever.name}` + ); } else { retriever = params.retriever as RetrieverAction; } @@ -239,13 +246,14 @@ export interface IndexerParams< * Indexes documents using a {@link IndexerArgument}. 
*/ export async function index( + registry: Registry, params: IndexerParams ): Promise { let indexer: IndexerAction; if (typeof params.indexer === 'string') { - indexer = await lookupAction(`/indexer/${params.indexer}`); + indexer = await registry.lookupAction(`/indexer/${params.indexer}`); } else if (Object.hasOwnProperty.call(params.indexer, 'info')) { - indexer = await lookupAction(`/indexer/${params.indexer.name}`); + indexer = await registry.lookupAction(`/indexer/${params.indexer.name}`); } else { indexer = params.indexer as IndexerAction; } @@ -381,10 +389,12 @@ export function defineSimpleRetriever< C extends z.ZodTypeAny = z.ZodTypeAny, R = any, >( + registry: Registry, options: SimpleRetrieverOptions, handler: (query: Document, config: z.infer) => Promise ) { return defineRetriever( + registry, { name: options.name, configSchema: options.configSchema, diff --git a/js/ai/src/testing/model-tester.ts b/js/ai/src/testing/model-tester.ts index 81b0b5c5c..7caa4b0cc 100644 --- a/js/ai/src/testing/model-tester.ts +++ b/js/ai/src/testing/model-tester.ts @@ -15,7 +15,7 @@ */ import { z } from '@genkit-ai/core'; -import { lookupAction } from '@genkit-ai/core/registry'; +import { Registry } from '@genkit-ai/core/registry'; import { runInNewSpan } from '@genkit-ai/core/tracing'; import assert from 'node:assert'; import { generate } from '../generate'; @@ -23,23 +23,23 @@ import { ModelAction } from '../model'; import { defineTool } from '../tool'; const tests: Record = { - 'basic hi': async (model: string) => { - const response = await generate({ + 'basic hi': async (registry: Registry, model: string) => { + const response = await generate(registry, { model, prompt: 'just say "Hi", literally', }); - const got = response.text().trim(); + const got = response.text.trim(); assert.match(got, /Hi/i); }, - multimodal: async (model: string) => { - const resolvedModel = (await lookupAction( + multimodal: async (registry: Registry, model: string) => { + const resolvedModel = (await registry.lookupAction( `/model/${model}` )) as ModelAction; if (!resolvedModel.__action.metadata?.model.supports?.media) { skip(); } - const response = await generate({ + const response = await generate(registry, { model, prompt: [ { @@ -54,31 +54,31 @@ const tests: Record = { }); const want = ''; - const got = response.text().trim(); + const got = response.text.trim(); assert.match(got, /plus/i); }, - history: async (model: string) => { - const resolvedModel = (await lookupAction( + history: async (registry: Registry, model: string) => { + const resolvedModel = (await registry.lookupAction( `/model/${model}` )) as ModelAction; if (!resolvedModel.__action.metadata?.model.supports?.multiturn) { skip(); } - const response1 = await generate({ + const response1 = await generate(registry, { model, prompt: 'My name is Glorb', }); - const response = await generate({ + const response = await generate(registry, { model, prompt: "What's my name?", - messages: response1.toHistory(), + messages: response1.messages, }); - const got = response.text().trim(); + const got = response.text.trim(); assert.match(got, /Glorb/); }, - 'system prompt': async (model: string) => { - const response = await generate({ + 'system prompt': async (registry: Registry, model: string) => { + const { text } = await generate(registry, { model, prompt: 'Hi', messages: [ @@ -94,11 +94,11 @@ const tests: Record = { }); const want = 'Bye'; - const got = response.text().trim(); + const got = text.trim(); assert.equal(got, want); }, - 'structured output': async 
(model: string) => { - const response = await generate({ + 'structured output': async (registry: Registry, model: string) => { + const response = await generate(registry, { model, prompt: 'extract data as json from: Jack was a Lumberjack', output: { @@ -114,24 +114,24 @@ const tests: Record = { name: 'Jack', occupation: 'Lumberjack', }; - const got = response.output(); + const got = response.output; assert.deepEqual(want, got); }, - 'tool calling': async (model: string) => { - const resolvedModel = (await lookupAction( + 'tool calling': async (registry: Registry, model: string) => { + const resolvedModel = (await registry.lookupAction( `/model/${model}` )) as ModelAction; if (!resolvedModel.__action.metadata?.model.supports?.tools) { skip(); } - const response = await generate({ + const { text } = await generate(registry, { model, prompt: 'what is a gablorken of 2? use provided tool', tools: ['gablorkenTool'], }); - const got = response.text().trim(); + const got = text.trim(); assert.match(got, /9.407/); }, }; @@ -149,10 +149,14 @@ type TestReport = { }[]; }[]; -type TestCase = (model: string) => Promise; +type TestCase = (ai: Registry, model: string) => Promise; -export async function testModels(models: string[]): Promise { +export async function testModels( + registry: Registry, + models: string[] +): Promise { const gablorkenTool = defineTool( + registry, { name: 'gablorkenTool', description: 'use when need to calculate a gablorken', @@ -182,7 +186,7 @@ export async function testModels(models: string[]): Promise { }); const modelReport = caseReport.models[caseReport.models.length - 1]; try { - await tests[test](model); + await tests[test](registry, model); } catch (e) { modelReport.passed = false; if (e instanceof SkipTestError) { diff --git a/js/ai/src/tool.ts b/js/ai/src/tool.ts index 9dcb61c4f..a0d85340c 100644 --- a/js/ai/src/tool.ts +++ b/js/ai/src/tool.ts @@ -15,7 +15,7 @@ */ import { Action, defineAction, JSONSchema7, z } from '@genkit-ai/core'; -import { lookupAction } from '@genkit-ai/core/registry'; +import { Registry } from '@genkit-ai/core/registry'; import { toJsonSchema } from '@genkit-ai/core/schema'; import { setCustomMetadataAttributes } from '@genkit-ai/core/tracing'; import { ToolDefinition } from './model.js'; @@ -89,11 +89,11 @@ export function asTool( export async function resolveTools< O extends z.ZodTypeAny = z.ZodTypeAny, CustomOptions extends z.ZodTypeAny = z.ZodTypeAny, ->(tools: ToolArgument[] = []): Promise { +>(registry: Registry, tools: ToolArgument[] = []): Promise { return await Promise.all( tools.map(async (ref): Promise => { if (typeof ref === 'string') { - const tool = await lookupAction(`/tool/${ref}`); + const tool = await registry.lookupAction(`/tool/${ref}`); if (!tool) { throw new Error(`Tool ${ref} not found`); } @@ -101,7 +101,7 @@ export async function resolveTools< } else if ((ref as Action).__action) { return asTool(ref as Action); } else if (ref.name) { - const tool = await lookupAction(`/tool/${ref.name}`); + const tool = await registry.lookupAction(`/tool/${ref.name}`); if (!tool) { throw new Error(`Tool ${ref} not found`); } @@ -137,10 +137,12 @@ export function toToolDefinition( * A tool is an action that can be passed to a model to be called automatically if it so chooses. 
*/ export function defineTool( + registry: Registry, config: ToolConfig, fn: (input: z.infer) => Promise> ): ToolAction { const a = defineAction( + registry, { ...config, actionType: 'tool', diff --git a/js/ai/tests/generate/generate_test.ts b/js/ai/tests/generate/generate_test.ts index be7a4a079..9a02b0f6e 100644 --- a/js/ai/tests/generate/generate_test.ts +++ b/js/ai/tests/generate/generate_test.ts @@ -15,7 +15,7 @@ */ import { z } from '@genkit-ai/core'; -import { Registry, runWithRegistry } from '@genkit-ai/core/registry'; +import { Registry } from '@genkit-ai/core/registry'; import { toJsonSchema } from '@genkit-ai/core/schema'; import assert from 'node:assert'; import { beforeEach, describe, it } from 'node:test'; @@ -108,7 +108,7 @@ describe('GenerateResponse', () => { const response = new GenerateResponse( test.responseData as GenerateResponseData ); - assert.deepStrictEqual(response.output(), test.expectedOutput); + assert.deepStrictEqual(response.output, test.expectedOutput); }); } }); @@ -213,7 +213,7 @@ describe('GenerateResponse', () => { }), finishReason: 'stop', }); - assert.deepStrictEqual(response.toolRequests(), []); + assert.deepStrictEqual(response.toolRequests, []); }); it('returns tool call if present', () => { const toolCall = { @@ -230,7 +230,7 @@ describe('GenerateResponse', () => { }), finishReason: 'stop', }); - assert.deepStrictEqual(response.toolRequests(), [toolCall]); + assert.deepStrictEqual(response.toolRequests, [toolCall]); }); it('returns all tool calls', () => { const toolCall1 = { @@ -254,7 +254,7 @@ describe('GenerateResponse', () => { }), finishReason: 'stop', }); - assert.deepStrictEqual(response.toolRequests(), [toolCall1, toolCall2]); + assert.deepStrictEqual(response.toolRequests, [toolCall1, toolCall2]); }); }); }); @@ -262,19 +262,18 @@ describe('GenerateResponse', () => { describe('toGenerateRequest', () => { const registry = new Registry(); // register tools - const tellAFunnyJoke = runWithRegistry(registry, () => - defineTool( - { - name: 'tellAFunnyJoke', - description: - 'Tells jokes about an input topic. Use this tool whenever user asks you to tell a joke.', - inputSchema: z.object({ topic: z.string() }), - outputSchema: z.string(), - }, - async (input) => { - return `Why did the ${input.topic} cross the road?`; - } - ) + const tellAFunnyJoke = defineTool( + registry, + { + name: 'tellAFunnyJoke', + description: + 'Tells jokes about an input topic. Use this tool whenever user asks you to tell a joke.', + inputSchema: z.object({ topic: z.string() }), + outputSchema: z.string(), + }, + async (input) => { + return `Why did the ${input.topic} cross the road?`; + } ); const testCases = [ @@ -289,7 +288,7 @@ describe('toGenerateRequest', () => { { role: 'user', content: [{ text: 'Tell a joke about dogs.' }] }, ], config: undefined, - context: undefined, + docs: undefined, tools: [], output: { format: 'text' }, }, @@ -307,7 +306,7 @@ describe('toGenerateRequest', () => { { role: 'user', content: [{ text: 'Tell a joke about dogs.' }] }, ], config: undefined, - context: undefined, + docs: undefined, tools: [ { name: 'tellAFunnyJoke', @@ -342,7 +341,7 @@ describe('toGenerateRequest', () => { { role: 'user', content: [{ text: 'Tell a joke about dogs.' 
}] }, ], config: undefined, - context: undefined, + docs: undefined, tools: [ { name: 'tellAFunnyJoke', @@ -394,7 +393,7 @@ describe('toGenerateRequest', () => { }, ], config: undefined, - context: undefined, + docs: undefined, tools: [], output: { format: 'text' }, }, @@ -416,7 +415,7 @@ describe('toGenerateRequest', () => { { role: 'user', content: [{ text: 'Tell a joke about dogs.' }] }, ], config: undefined, - context: undefined, + docs: undefined, tools: [], output: { format: 'text' }, }, @@ -426,14 +425,14 @@ describe('toGenerateRequest', () => { prompt: { model: 'vertexai/gemini-1.0-pro', prompt: 'Tell a joke with context.', - context: [{ content: [{ text: 'context here' }] }], + docs: [{ content: [{ text: 'context here' }] }], }, expectedOutput: { messages: [ { content: [{ text: 'Tell a joke with context.' }], role: 'user' }, ], config: undefined, - context: [{ content: [{ text: 'context here' }] }], + docs: [{ content: [{ text: 'context here' }] }], tools: [], output: { format: 'text' }, }, @@ -442,9 +441,7 @@ describe('toGenerateRequest', () => { for (const test of testCases) { it(test.should, async () => { assert.deepStrictEqual( - await runWithRegistry(registry, () => - toGenerateRequest(test.prompt as GenerateOptions) - ), + await toGenerateRequest(registry, test.prompt as GenerateOptions), test.expectedOutput ); }); @@ -515,7 +512,7 @@ describe('GenerateResponseChunk', () => { const responseChunk: GenerateResponseChunk = new GenerateResponseChunk(chunkData, accumulatedChunks); - const output = responseChunk.output(); + const output = responseChunk.output; assert.deepStrictEqual(output, test.correctJson); }); @@ -530,29 +527,28 @@ describe('generate', () => { beforeEach(() => { registry = new Registry(); - echoModel = runWithRegistry(registry, () => - defineModel( - { - name: 'echoModel', - }, - async (request) => { - return { - message: { - role: 'model', - content: [ - { - text: - 'Echo: ' + - request.messages - .map((m) => m.content.map((c) => c.text).join()) - .join(), - }, - ], - }, - finishReason: 'stop', - }; - } - ) + echoModel = defineModel( + registry, + { + name: 'echoModel', + }, + async (request) => { + return { + message: { + role: 'model', + content: [ + { + text: + 'Echo: ' + + request.messages + .map((m) => m.content.map((c) => c.text).join()) + .join(), + }, + ], + }, + finishReason: 'stop', + }; + } ); }); @@ -592,16 +588,13 @@ describe('generate', () => { }; }; - const response = await runWithRegistry(registry, () => - generate({ - prompt: 'banana', - model: echoModel, - use: [wrapRequest, wrapResponse], - }) - ); - + const response = await generate(registry, { + prompt: 'banana', + model: echoModel, + use: [wrapRequest, wrapResponse], + }); const want = '[Echo: (banana)]'; - assert.deepStrictEqual(response.text(), want); + assert.deepStrictEqual(response.text, want); }); }); @@ -609,27 +602,24 @@ describe('generate', () => { let registry: Registry; beforeEach(() => { registry = new Registry(); - runWithRegistry(registry, () => - defineModel( - { name: 'echo', supports: { tools: true } }, - async (input) => ({ - message: input.messages[0], - finishReason: 'stop', - }) - ) - ); - }); - it('should preserve the request in the returned response, enabling toHistory()', async () => { - const response = await runWithRegistry(registry, () => - generate({ - model: 'echo', - prompt: 'Testing toHistory', + + defineModel( + registry, + { name: 'echo', supports: { tools: true } }, + async (input) => ({ + message: input.messages[0], + finishReason: 'stop', }) ); - + }); 
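For orientation, a minimal usage sketch of the registry-first API and the getter-style accessors that these tests exercise; the model name and the '@genkit-ai/ai' import path are illustrative assumptions, and the echo model mirrors the test fixture above.

import { Registry } from '@genkit-ai/core/registry';
import { defineModel, generate } from '@genkit-ai/ai'; // assumed entry point for these exports

async function example() {
  const registry = new Registry();

  // Actions are defined against an explicit Registry instead of an implicit global one.
  defineModel(registry, { name: 'echoModel' }, async (request) => ({
    message: {
      role: 'model',
      content: [
        {
          text:
            'Echo: ' +
            request.messages
              .map((m) => m.content.map((c) => c.text || '').join(''))
              .join(' '),
        },
      ],
    },
    finishReason: 'stop',
  }));

  // generate() now takes the registry as its first argument; `system` is a first-class option.
  const response = await generate(registry, {
    model: 'echoModel',
    system: 'Answer tersely.',
    prompt: 'Hello there',
  });

  // Accessors are getters: response.text and response.messages replace text() and toHistory().
  console.log(response.text);
  console.log(response.messages.length);
}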
+ it('should preserve the request in the returned response, enabling .messages', async () => { + const response = await generate(registry, { + model: 'echo', + prompt: 'Testing messages', + }); assert.deepEqual( - response.toHistory().map((m) => m.content[0].text), - ['Testing toHistory', 'Testing toHistory'] + response.messages.map((m) => m.content[0].text), + ['Testing messages', 'Testing messages'] ); }); }); diff --git a/js/ai/tests/model/document_test.ts b/js/ai/tests/model/document_test.ts index d91bdb657..c1adaa97f 100644 --- a/js/ai/tests/model/document_test.ts +++ b/js/ai/tests/model/document_test.ts @@ -23,13 +23,13 @@ describe('document', () => { it('retuns single text part', () => { const doc = new Document({ content: [{ text: 'foo' }] }); - assert.equal(doc.text(), 'foo'); + assert.equal(doc.text, 'foo'); }); it('retuns concatenated text part', () => { const doc = new Document({ content: [{ text: 'foo' }, { text: 'bar' }] }); - assert.equal(doc.text(), 'foobar'); + assert.equal(doc.text, 'foobar'); }); }); @@ -42,7 +42,7 @@ describe('document', () => { ], }); - assert.deepEqual(doc.media(), { url: 'data:foo' }); + assert.deepEqual(doc.media, { url: 'data:foo' }); }); }); diff --git a/js/ai/tests/model/middleware_test.ts b/js/ai/tests/model/middleware_test.ts index 94fb048db..3c9aaaffa 100644 --- a/js/ai/tests/model/middleware_test.ts +++ b/js/ai/tests/model/middleware_test.ts @@ -14,7 +14,7 @@ * limitations under the License. */ -import { Registry, runWithRegistry } from '@genkit-ai/core/registry'; +import { Registry } from '@genkit-ai/core/registry'; import assert from 'node:assert'; import { beforeEach, describe, it } from 'node:test'; import { DocumentData } from '../../src/document.js'; @@ -147,24 +147,21 @@ describe('validateSupport', () => { }); const registry = new Registry(); -const echoModel = runWithRegistry(registry, () => - defineModel({ name: 'echo' }, async (req) => { - return { - finishReason: 'stop', - message: { - role: 'model', - content: [{ data: req }], - }, - }; - }) -); - +const echoModel = defineModel(registry, { name: 'echo' }, async (req) => { + return { + finishReason: 'stop', + message: { + role: 'model', + content: [{ data: req }], + }, + }; +}); describe('conformOutput (default middleware)', () => { const schema = { type: 'object', properties: { test: { type: 'boolean' } } }; // return the output tagged part from the request async function testRequest(req: GenerateRequest): Promise { - const response = await runWithRegistry(registry, () => echoModel(req)); + const response = await echoModel(req); const treq = response.message!.content[0].data as GenerateRequest; const lastUserMessage = treq.messages @@ -191,7 +188,7 @@ describe('conformOutput (default middleware)', () => { { role: 'user', content: [{ text: 'hello again' }] }, ], output: { format: 'json', schema }, - context: [{ content: [{ text: 'hi' }] }], + docs: [{ content: [{ text: 'hi' }] }], }); assert( part?.text?.includes(JSON.stringify(schema)), @@ -341,7 +338,7 @@ describe('augmentWithContext', () => { augmentWithContext(options)( { messages, - context, + docs: context, }, resolve as any ); @@ -521,7 +518,7 @@ describe('augmentWithContext', () => { metadata: { uid: 'second' }, }, ], - { itemTemplate: (d) => `* (${d.metadata!.uid}) -- ${d.text()}\n` } + { itemTemplate: (d) => `* (${d.metadata!.uid}) -- ${d.text}\n` } ); assert.deepEqual(result[0].content.at(-1), { text: `${CONTEXT_PREFACE}* (first) -- i am context\n* (second) -- i am more context\n\n`, diff --git 
a/js/ai/tests/prompt/prompt_test.ts b/js/ai/tests/prompt/prompt_test.ts index c35c951c6..702f85444 100644 --- a/js/ai/tests/prompt/prompt_test.ts +++ b/js/ai/tests/prompt/prompt_test.ts @@ -15,7 +15,7 @@ */ import { z } from '@genkit-ai/core'; -import { Registry, runWithRegistry } from '@genkit-ai/core/registry'; +import { Registry } from '@genkit-ai/core/registry'; import assert from 'node:assert'; import { describe, it } from 'node:test'; import { definePrompt, renderPrompt } from '../../src/prompt.ts'; @@ -23,38 +23,37 @@ import { definePrompt, renderPrompt } from '../../src/prompt.ts'; describe('prompt', () => { let registry = new Registry(); describe('render()', () => { - runWithRegistry(registry, () => { - it('respects output schema in the definition', async () => { - const schema1 = z.object({ - puppyName: z.string({ description: 'A cute name for a puppy' }), - }); - const prompt1 = definePrompt( - { - name: 'prompt1', - inputSchema: z.string({ description: 'Dog breed' }), - }, - async (breed) => { - return { - messages: [ - { - role: 'user', - content: [{ text: `Pick a name for a ${breed} puppy` }], - }, - ], - output: { - format: 'json', - schema: schema1, + it('respects output schema in the definition', async () => { + const schema1 = z.object({ + puppyName: z.string({ description: 'A cute name for a puppy' }), + }); + const prompt1 = definePrompt( + registry, + { + name: 'prompt1', + inputSchema: z.string({ description: 'Dog breed' }), + }, + async (breed) => { + return { + messages: [ + { + role: 'user', + content: [{ text: `Pick a name for a ${breed} puppy` }], }, - }; - } - ); - const generateRequest = await renderPrompt({ - prompt: prompt1, - input: 'poodle', - model: 'geminiPro', - }); - assert.equal(generateRequest.output?.schema, schema1); + ], + output: { + format: 'json', + schema: schema1, + }, + }; + } + ); + const generateRequest = await renderPrompt(registry, { + prompt: prompt1, + input: 'poodle', + model: 'geminiPro', }); + assert.equal(generateRequest.output?.schema, schema1); }); }); }); diff --git a/js/ai/tests/reranker/reranker_test.ts b/js/ai/tests/reranker/reranker_test.ts index 1b67a663f..63a8b25e4 100644 --- a/js/ai/tests/reranker/reranker_test.ts +++ b/js/ai/tests/reranker/reranker_test.ts @@ -15,7 +15,7 @@ */ import { GenkitError, z } from '@genkit-ai/core'; -import { Registry, runWithRegistry } from '@genkit-ai/core/registry'; +import { Registry } from '@genkit-ai/core/registry'; import assert from 'node:assert'; import { beforeEach, describe, it } from 'node:test'; import { defineReranker, rerank } from '../../src/reranker'; @@ -28,34 +28,32 @@ describe('reranker', () => { registry = new Registry(); }); it('reranks documents based on custom logic', async () => { - const customReranker = runWithRegistry(registry, () => - defineReranker( - { - name: 'reranker', - configSchema: z.object({ - k: z.number().optional(), - }), - }, - async (query, documents, options) => { - // Custom reranking logic: score based on string length similarity to query - const queryLength = query.text().length; - const rerankedDocs = documents.map((doc) => { - const score = Math.abs(queryLength - doc.text().length); - return { - ...doc, - metadata: { ...doc.metadata, score }, - }; - }); - + const customReranker = defineReranker( + registry, + { + name: 'reranker', + configSchema: z.object({ + k: z.number().optional(), + }), + }, + async (query, documents, options) => { + // Custom reranking logic: score based on string length similarity to query + const queryLength = 
query.text.length; + const rerankedDocs = documents.map((doc) => { + const score = Math.abs(queryLength - doc.text.length); return { - documents: rerankedDocs - .sort((a, b) => a.metadata.score - b.metadata.score) - .slice(0, options.k || 3), + ...doc, + metadata: { ...doc.metadata, score }, }; - } - ) + }); + + return { + documents: rerankedDocs + .sort((a, b) => a.metadata.score - b.metadata.score) + .slice(0, options.k || 3), + }; + } ); - // Sample documents for testing const documents = [ Document.fromText('short'), @@ -64,101 +62,89 @@ describe('reranker', () => { ]; const query = Document.fromText('medium length'); - const rerankedDocuments = await runWithRegistry(registry, () => - rerank({ - reranker: customReranker, - query, - documents, - options: { k: 2 }, - }) - ); - + const rerankedDocuments = await rerank(registry, { + reranker: customReranker, + query, + documents, + options: { k: 2 }, + }); // Validate the reranked results assert.equal(rerankedDocuments.length, 2); - assert(rerankedDocuments[0].text().includes('a bit longer')); - assert(rerankedDocuments[1].text().includes('short')); + assert(rerankedDocuments[0].text.includes('a bit longer')); + assert(rerankedDocuments[1].text.includes('short')); }); it('handles missing options gracefully', async () => { - const customReranker = runWithRegistry(registry, () => - defineReranker( - { - name: 'reranker', - configSchema: z.object({ - k: z.number().optional(), - }), - }, - async (query, documents, options) => { - const rerankedDocs = documents.map((doc) => { - const score = Math.random(); // Simplified scoring for testing - return { - ...doc, - metadata: { ...doc.metadata, score }, - }; - }); - + const customReranker = defineReranker( + registry, + { + name: 'reranker', + configSchema: z.object({ + k: z.number().optional(), + }), + }, + async (query, documents, options) => { + const rerankedDocs = documents.map((doc) => { + const score = Math.random(); // Simplified scoring for testing return { - documents: rerankedDocs.sort( - (a, b) => b.metadata.score - a.metadata.score - ), + ...doc, + metadata: { ...doc.metadata, score }, }; - } - ) + }); + + return { + documents: rerankedDocs.sort( + (a, b) => b.metadata.score - a.metadata.score + ), + }; + } ); - const documents = [Document.fromText('doc1'), Document.fromText('doc2')]; const query = Document.fromText('test query'); - const rerankedDocuments = await runWithRegistry(registry, () => - rerank({ - reranker: customReranker, - query, - documents, - options: { k: 2 }, - }) - ); - + const rerankedDocuments = await rerank(registry, { + reranker: customReranker, + query, + documents, + options: { k: 2 }, + }); assert.equal(rerankedDocuments.length, 2); assert(typeof rerankedDocuments[0].metadata.score === 'number'); }); it('validates config schema and throws error on invalid input', async () => { - const customReranker = runWithRegistry(registry, () => - defineReranker( - { - name: 'reranker', - configSchema: z.object({ - k: z.number().min(1), - }), - }, - async (query, documents, options) => { - // Simplified scoring for testing - const rerankedDocs = documents.map((doc) => ({ - ...doc, - metadata: { score: Math.random() }, - })); - return { - documents: rerankedDocs.sort( - (a, b) => b.metadata.score - a.metadata.score - ), - }; - } - ) + const customReranker = defineReranker( + registry, + { + name: 'reranker', + configSchema: z.object({ + k: z.number().min(1), + }), + }, + async (query, documents, options) => { + // Simplified scoring for testing + const rerankedDocs = 
documents.map((doc) => ({ + ...doc, + metadata: { score: Math.random() }, + })); + return { + documents: rerankedDocs.sort( + (a, b) => b.metadata.score - a.metadata.score + ), + }; + } ); - const documents = [Document.fromText('doc1')]; const query = Document.fromText('test query'); try { - await runWithRegistry(registry, () => - rerank({ - reranker: customReranker, - query, - documents, - options: { k: 0 }, // Invalid input: k must be at least 1 - }) - ); + await rerank(registry, { + reranker: customReranker, + query, + documents, + options: { k: 0 }, // Invalid input: k must be at least 1 + }); assert.fail('Expected validation error'); } catch (err) { assert(err instanceof GenkitError); @@ -167,71 +153,62 @@ describe('reranker', () => { }); it('preserves document metadata after reranking', async () => { - const customReranker = runWithRegistry(registry, () => - defineReranker( - { - name: 'reranker', - }, - async (query, documents) => { - const rerankedDocs = documents.map((doc, i) => ({ - ...doc, - metadata: { ...doc.metadata, score: 2 - i }, - })); - - return { - documents: rerankedDocs.sort( - (a, b) => b.metadata.score - a.metadata.score - ), - }; - } - ) + const customReranker = defineReranker( + registry, + { + name: 'reranker', + }, + async (query, documents) => { + const rerankedDocs = documents.map((doc, i) => ({ + ...doc, + metadata: { ...doc.metadata, score: 2 - i }, + })); + + return { + documents: rerankedDocs.sort( + (a, b) => b.metadata.score - a.metadata.score + ), + }; + } ); - const documents = [ new Document({ content: [], metadata: { originalField: 'test1' } }), new Document({ content: [], metadata: { originalField: 'test2' } }), ]; const query = Document.fromText('test query'); - const rerankedDocuments = await runWithRegistry(registry, () => - rerank({ - reranker: customReranker, - query, - documents, - }) - ); - + const rerankedDocuments = await rerank(registry, { + reranker: customReranker, + query, + documents, + }); assert.equal(rerankedDocuments[0].metadata.originalField, 'test1'); assert.equal(rerankedDocuments[1].metadata.originalField, 'test2'); }); it('handles errors thrown by the reranker', async () => { - const customReranker = runWithRegistry(registry, () => - defineReranker( - { - name: 'reranker', - }, - async (query, documents) => { - // Simulate an error in the reranker logic - throw new GenkitError({ - message: 'Something went wrong during reranking', - status: 'INTERNAL', - }); - } - ) + const customReranker = defineReranker( + registry, + { + name: 'reranker', + }, + async (query, documents) => { + // Simulate an error in the reranker logic + throw new GenkitError({ + message: 'Something went wrong during reranking', + status: 'INTERNAL', + }); + } ); - const documents = [Document.fromText('doc1'), Document.fromText('doc2')]; const query = Document.fromText('test query'); try { - await runWithRegistry(registry, () => - rerank({ - reranker: customReranker, - query, - documents, - }) - ); + await rerank(registry, { + reranker: customReranker, + query, + documents, + }); assert.fail('Expected an error to be thrown'); } catch (err) { assert(err instanceof GenkitError); diff --git a/js/core/package.json b/js/core/package.json index 3a26e2536..03fdcd451 100644 --- a/js/core/package.json +++ b/js/core/package.json @@ -7,7 +7,7 @@ "genai", "generative-ai" ], - "version": "0.6.0-dev.2", + "version": "0.9.0-dev.1", "type": "commonjs", "scripts": { "check": "tsc", diff --git a/js/core/src/action.ts b/js/core/src/action.ts index 1bf8f7a1b..382b80d14 100644 
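The reranker test updates above follow the pattern applied throughout this change: registry-scoped helpers take a `Registry` as their first argument instead of running inside `runWithRegistry`. A minimal sketch of the new calling convention, assuming the `@genkit-ai/ai` package exports used below (names and scoring logic are illustrative):

```ts
import { Document, rerank } from '@genkit-ai/ai';
import { defineReranker } from '@genkit-ai/ai/reranker';
import { Registry } from '@genkit-ai/core/registry';

const registry = new Registry();

// Illustrative reranker that scores documents by text length.
const lengthReranker = defineReranker(
  registry,
  { name: 'lengthReranker' },
  async (query, documents) => ({
    documents: documents
      .map((doc) => ({ ...doc, metadata: { ...doc.metadata, score: doc.text.length } }))
      .sort((a, b) => b.metadata.score - a.metadata.score),
  })
);

const ranked = await rerank(registry, {
  reranker: lengthReranker,
  query: Document.fromText('example query'),
  documents: [Document.fromText('short'), Document.fromText('quite a bit longer')],
});
```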
--- a/js/core/src/action.ts +++ b/js/core/src/action.ts @@ -17,12 +17,7 @@ import { JSONSchema7 } from 'json-schema'; import { AsyncLocalStorage } from 'node:async_hooks'; import * as z from 'zod'; -import { - ActionType, - initializeAllPlugins, - lookupPlugin, - registerAction, -} from './registry.js'; +import { ActionType, Registry } from './registry.js'; import { parseSchema } from './schema.js'; import { SPAN_TYPE_ATTR, @@ -122,8 +117,8 @@ export function action< ): Action { const actionName = typeof config.name === 'string' - ? validateActionName(config.name) - : `${config.name.pluginId}/${validateActionId(config.name.actionId)}`; + ? config.name + : `${config.name.pluginId}/${config.name.actionId}`; const actionFn = async (input: I) => { input = parseSchema(input, { schema: config.inputSchema, @@ -168,16 +163,16 @@ export function action< return actionFn; } -function validateActionName(name: string) { +function validateActionName(registry: Registry, name: string) { if (name.includes('/')) { - validatePluginName(name.split('/', 1)[0]); + validatePluginName(registry, name.split('/', 1)[0]); validateActionId(name.substring(name.indexOf('/') + 1)); } return name; } -function validatePluginName(pluginId: string) { - if (!lookupPlugin(pluginId)) { +function validatePluginName(registry: Registry, pluginId: string) { + if (!registry.lookupPlugin(pluginId)) { throw new Error( `Unable to find plugin name used in the action name: ${pluginId}` ); @@ -200,6 +195,7 @@ export function defineAction< O extends z.ZodTypeAny, M extends Record = Record, >( + registry: Registry, config: ActionParams & { actionType: ActionType; }, @@ -211,13 +207,18 @@ export function defineAction< 'See: https://github.com/firebase/genkit/blob/main/docs/errors/no_new_actions_at_runtime.md' ); } + if (typeof config.name === 'string') { + validateActionName(registry, config.name); + } else { + validateActionId(config.name.actionId); + } const act = action(config, async (i: I): Promise> => { setCustomMetadataAttributes({ subtype: config.actionType }); - await initializeAllPlugins(); + await registry.initializeAllPlugins(); return await runInActionRuntimeContext(() => fn(i)); }); act.__action.actionType = config.actionType; - registerAction(config.actionType, act); + registry.registerAction(config.actionType, act); return act; } diff --git a/js/core/src/flow-client/client.ts b/js/core/src/flow-client/client.ts index 46656ce77..111e4e479 100644 --- a/js/core/src/flow-client/client.ts +++ b/js/core/src/flow-client/client.ts @@ -28,10 +28,10 @@ const __flowStreamDelimiter = '\n'; * url: 'https://my-flow-deployed-url', * input: 'foo', * }); - * for await (const chunk of response.stream()) { + * for await (const chunk of response.stream) { * console.log(chunk); * } - * console.log(await response.output()); + * console.log(await response.output); * ``` */ export function streamFlow({ diff --git a/js/core/src/flow.ts b/js/core/src/flow.ts index 0061e0cde..107459585 100644 --- a/js/core/src/flow.ts +++ b/js/core/src/flow.ts @@ -31,12 +31,7 @@ import { runWithAuthContext } from './auth.js'; import { getErrorMessage, getErrorStack } from './error.js'; import { FlowActionInputSchema } from './flowTypes.js'; import { logger } from './logging.js'; -import { - getRegistryInstance, - initializeAllPlugins, - Registry, - runWithRegistry, -} from './registry.js'; +import { Registry } from './registry.js'; import { toJsonSchema } from './schema.js'; import { newTrace, @@ -181,6 +176,7 @@ export class Flow< readonly flowFn: FlowFn; constructor( 
+ private registry: Registry, config: FlowConfig | StreamingFlowConfig, flowFn: FlowFn ) { @@ -207,7 +203,7 @@ export class Flow< auth?: unknown; } ): Promise>> { - await initializeAllPlugins(); + await this.registry.initializeAllPlugins(); return await runWithAuthContext(opts.auth, () => newTrace( { @@ -336,84 +332,79 @@ export class Flow< } async expressHandler( - registry: Registry, request: __RequestWithAuth, response: express.Response ): Promise { - await runWithRegistry(registry, async () => { - const { stream } = request.query; - const auth = request.auth; - - let input = request.body.data; + const { stream } = request.query; + const auth = request.auth; + + let input = request.body.data; + + try { + await this.authPolicy?.(auth, input); + } catch (e: any) { + const respBody = { + error: { + status: 'PERMISSION_DENIED', + message: e.message || 'Permission denied to resource', + }, + }; + response.status(403).send(respBody).end(); + return; + } + if (stream === 'true') { + response.writeHead(200, { + 'Content-Type': 'text/plain', + 'Transfer-Encoding': 'chunked', + }); try { - await this.authPolicy?.(auth, input); - } catch (e: any) { - const respBody = { + const result = await this.invoke(input, { + streamingCallback: ((chunk: z.infer) => { + response.write(JSON.stringify(chunk) + streamDelimiter); + }) as S extends z.ZodVoid ? undefined : StreamingCallback>, + auth, + }); + response.write({ + result: result.result, // Need more results!!!! + }); + response.end(); + } catch (e) { + response.write({ error: { - status: 'PERMISSION_DENIED', - message: e.message || 'Permission denied to resource', + status: 'INTERNAL', + message: getErrorMessage(e), + details: getErrorStack(e), }, - }; - response.status(403).send(respBody).end(); - return; - } - - if (stream === 'true') { - response.writeHead(200, { - 'Content-Type': 'text/plain', - 'Transfer-Encoding': 'chunked', }); - try { - const result = await this.invoke(input, { - streamingCallback: ((chunk: z.infer) => { - response.write(JSON.stringify(chunk) + streamDelimiter); - }) as S extends z.ZodVoid - ? undefined - : StreamingCallback>, - auth, - }); - response.write({ - result: result.result, // Need more results!!!! - }); - response.end(); - } catch (e) { - response.write({ + response.end(); + } + } else { + try { + const result = await this.invoke(input, { auth }); + response.setHeader('x-genkit-trace-id', result.traceId); + response.setHeader('x-genkit-span-id', result.spanId); + // Responses for non-streaming flows are passed back with the flow result stored in a field called "result." + response + .status(200) + .send({ + result: result.result, + }) + .end(); + } catch (e) { + // Errors for non-streaming flows are passed back as standard API errors. + response + .status(500) + .send({ error: { status: 'INTERNAL', message: getErrorMessage(e), details: getErrorStack(e), }, - }); - response.end(); - } - } else { - try { - const result = await this.invoke(input, { auth }); - response.setHeader('x-genkit-trace-id', result.traceId); - response.setHeader('x-genkit-span-id', result.spanId); - // Responses for non-streaming flows are passed back with the flow result stored in a field called "result." - response - .status(200) - .send({ - result: result.result, - }) - .end(); - } catch (e) { - // Errors for non-streaming flows are passed back as standard API errors. 
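With the `Flow` class now holding a reference to its `Registry`, flow definition no longer needs an ambient registry either; `defineFlow` and `defineStreamingFlow` (below) take the registry as their first parameter. A rough sketch of defining and invoking a flow under the new signature (flow name and body are illustrative):

```ts
import { defineFlow, z } from '@genkit-ai/core';
import { Registry } from '@genkit-ai/core/registry';

const registry = new Registry();

const greet = defineFlow(
  registry,
  { name: 'greet', inputSchema: z.string(), outputSchema: z.string() },
  async (name) => `hello ${name}`
);

const greeting = await greet('Genkit'); // resolves to 'hello Genkit'
```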
- response - .status(500) - .send({ - error: { - status: 'INTERNAL', - message: getErrorMessage(e), - details: getErrorStack(e), - }, - }) - .end(); - } + }) + .end(); } - }); + } } } @@ -496,9 +487,7 @@ export class FlowServer { flow.middleware?.forEach((middleware) => server.post(flowPath, middleware) ); - server.post(flowPath, (req, res) => - flow.expressHandler(this.registry, req, res) - ); + server.post(flowPath, (req, res) => flow.expressHandler(req, res)); }); } else { logger.warn('No flows registered in flow server.'); @@ -557,17 +546,17 @@ export function defineFlow< I extends z.ZodTypeAny = z.ZodTypeAny, O extends z.ZodTypeAny = z.ZodTypeAny, >( + registry: Registry, config: FlowConfig | string, fn: FlowFn ): CallableFlow { const resolvedConfig: FlowConfig = typeof config === 'string' ? { name: config } : config; - const flow = new Flow(resolvedConfig, fn); - registerFlowAction(flow); - const registry = getRegistryInstance(); + const flow = new Flow(registry, resolvedConfig, fn); + registerFlowAction(registry, flow); const callableFlow: CallableFlow = async (input, opts) => { - return runWithRegistry(registry, () => flow.run(input, opts)); + return flow.run(input, opts); }; callableFlow.flow = flow; return callableFlow; @@ -581,14 +570,14 @@ export function defineStreamingFlow< O extends z.ZodTypeAny = z.ZodTypeAny, S extends z.ZodTypeAny = z.ZodTypeAny, >( + registry: Registry, config: StreamingFlowConfig, fn: FlowFn ): StreamableFlow { - const flow = new Flow(config, fn); - registerFlowAction(flow); - const registry = getRegistryInstance(); + const flow = new Flow(registry, config, fn); + registerFlowAction(registry, flow); const streamableFlow: StreamableFlow = (input, opts) => { - return runWithRegistry(registry, () => flow.stream(input, opts)); + return flow.stream(input, opts); }; streamableFlow.flow = flow; return streamableFlow; @@ -601,8 +590,12 @@ function registerFlowAction< I extends z.ZodTypeAny = z.ZodTypeAny, O extends z.ZodTypeAny = z.ZodTypeAny, S extends z.ZodTypeAny = z.ZodTypeAny, ->(flow: Flow): Action { +>( + registry: Registry, + flow: Flow +): Action { return defineAction( + registry, { actionType: 'flow', name: flow.name, diff --git a/js/core/src/reflection.ts b/js/core/src/reflection.ts index 7a2b1bf24..e66c73630 100644 --- a/js/core/src/reflection.ts +++ b/js/core/src/reflection.ts @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -import express, { NextFunction, Request, Response } from 'express'; +import express from 'express'; import fs from 'fs/promises'; import getPort, { makeRange } from 'get-port'; import { Server } from 'http'; @@ -23,7 +23,7 @@ import z from 'zod'; import { Status, StatusCodes, runWithStreamingCallback } from './action.js'; import { GENKIT_VERSION } from './index.js'; import { logger } from './logging.js'; -import { Registry, runWithRegistry } from './registry.js'; +import { Registry } from './registry.js'; import { toJsonSchema } from './schema.js'; import { flushTracing, @@ -113,16 +113,6 @@ export class ReflectionServer { next(); }); - server.use((req: Request, res: Response, next: NextFunction) => { - runWithRegistry(this.registry, async () => { - try { - next(); - } catch (err) { - next(err); - } - }); - }); - server.get('/api/__health', async (_, response) => { await this.registry.listActions(); response.status(200).send('OK'); @@ -213,6 +203,7 @@ export class ReflectionServer { return await action(input); } ); + await flushTracing(); response.send({ result, telemetry: traceId diff --git a/js/core/src/registry.ts b/js/core/src/registry.ts index afa679ac6..f7cd0f532 100644 --- a/js/core/src/registry.ts +++ b/js/core/src/registry.ts @@ -14,7 +14,6 @@ * limitations under the License. */ -import { AsyncLocalStorage } from 'async_hooks'; import * as z from 'zod'; import { Action } from './action.js'; import { logger } from './logging.js'; @@ -47,17 +46,6 @@ export interface Schema { jsonSchema?: JSONSchema; } -/** - * Looks up a registry key (action type and key) in the registry. - */ -export function lookupAction< - I extends z.ZodTypeAny, - O extends z.ZodTypeAny, - R extends Action, ->(key: string): Promise { - return getRegistryInstance().lookupAction(key); -} - function parsePluginName(registryKey: string) { const tokens = registryKey.split('/'); if (tokens.length === 4) { @@ -66,99 +54,8 @@ function parsePluginName(registryKey: string) { return undefined; } -/** - * Registers an action in the registry. - */ -export function registerAction( - type: ActionType, - action: Action -) { - return getRegistryInstance().registerAction(type, action); -} - type ActionsRecord = Record>; -/** - * Initialize all plugins in the registry. - */ -export async function initializeAllPlugins() { - await getRegistryInstance().initializeAllPlugins(); -} - -/** - * Returns all actions in the registry. - */ -export function listActions(): Promise { - return getRegistryInstance().listActions(); -} - -/** - * Registers a plugin provider. - * @param name The name of the plugin to register. - * @param provider The plugin provider. - */ -export function registerPluginProvider(name: string, provider: PluginProvider) { - return getRegistryInstance().registerPluginProvider(name, provider); -} - -/** - * Looks up a plugin. - * @param name The name of the plugin to lookup. - * @returns The plugin. - */ -export function lookupPlugin(name: string) { - return getRegistryInstance().lookupPlugin(name); -} - -/** - * Initializes a plugin that has already been registered. - * @param name The name of the plugin to initialize. - * @returns The plugin. - */ -export async function initializePlugin(name: string) { - return getRegistryInstance().initializePlugin(name); -} - -/** - * Registers a schema. - * @param name The name of the schema to register. - * @param data The schema to register (either a Zod schema or a JSON schema). 
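The module-level registry helpers removed here (and continuing below) were thin wrappers around the ambient registry; their replacements are the corresponding instance methods on `Registry`. A brief sketch, with plugin and action names purely illustrative:

```ts
import { Registry } from '@genkit-ai/core/registry';

const registry = new Registry();

// Plugins now register against an explicit registry instance.
registry.registerPluginProvider('myPlugin', {
  name: 'myPlugin',
  async initializer() {
    // actions contributed by the plugin would be registered here
    return {};
  },
});

const model = await registry.lookupAction('/model/myPlugin/someModel');
const actions = await registry.listActions();
```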
- */ -export function registerSchema(name: string, data: Schema) { - return getRegistryInstance().registerSchema(name, data); -} - -/** - * Looks up a schema. - * @param name The name of the schema to lookup. - * @returns The schema. - */ -export function lookupSchema(name: string) { - return getRegistryInstance().lookupSchema(name); -} - -const registryAls = new AsyncLocalStorage(); - -/** - * @returns The active registry instance. - */ -export function getRegistryInstance(): Registry { - const registry = registryAls.getStore(); - if (!registry) { - throw new Error('getRegistryInstance() called before runWithRegistry()'); - } - return registry; -} - -/** - * Runs a function with a specific registry instance. - * @param registry The registry instance to use. - * @param fn The function to run. - */ -export function runWithRegistry(registry: Registry, fn: () => R) { - return registryAls.run(registry, fn); -} - /** * The registry is used to store and lookup actions, trace stores, flow state stores, plugins, and schemas. */ @@ -170,14 +67,6 @@ export class Registry { constructor(public parent?: Registry) {} - /** - * Creates a new registry overlaid onto the currently active registry. - * @returns The new overlaid registry. - */ - static withCurrent() { - return new Registry(getRegistryInstance()); - } - /** * Creates a new registry overlaid onto the provided registry. * @param parent The parent registry. diff --git a/js/core/src/schema.ts b/js/core/src/schema.ts index 16a45160d..a53da8acb 100644 --- a/js/core/src/schema.ts +++ b/js/core/src/schema.ts @@ -19,7 +19,7 @@ import addFormats from 'ajv-formats'; import { z } from 'zod'; import zodToJsonSchema from 'zod-to-json-schema'; import { GenkitError } from './error.js'; -import { registerSchema } from './registry.js'; +import { Registry } from './registry.js'; const ajv = new Ajv(); addFormats(ajv); @@ -112,14 +112,19 @@ export function parseSchema( } export function defineSchema( + registry: Registry, name: string, schema: T ): T { - registerSchema(name, { schema }); + registry.registerSchema(name, { schema }); return schema; } -export function defineJsonSchema(name: string, jsonSchema: JSONSchema) { - registerSchema(name, { jsonSchema }); +export function defineJsonSchema( + registry: Registry, + name: string, + jsonSchema: JSONSchema +) { + registry.registerSchema(name, { jsonSchema }); return jsonSchema; } diff --git a/js/core/src/tracing/exporter.ts b/js/core/src/tracing/exporter.ts index 4693c5843..a267c5360 100644 --- a/js/core/src/tracing/exporter.ts +++ b/js/core/src/tracing/exporter.ts @@ -123,7 +123,7 @@ export class TraceServerExporter implements SpanExporter { await this.save(traceId, traces[traceId]); } catch (e) { error = true; - logger.error('Failed to save trace ${traceId}', e); + logger.error(`Failed to save trace ${traceId}`, e); } if (done) { return done({ diff --git a/js/core/tests/flow_test.ts b/js/core/tests/flow_test.ts index 7d5b74646..cce14e2ee 100644 --- a/js/core/tests/flow_test.ts +++ b/js/core/tests/flow_test.ts @@ -18,10 +18,11 @@ import assert from 'node:assert'; import { beforeEach, describe, it } from 'node:test'; import { defineFlow, defineStreamingFlow } from '../src/flow.js'; import { z } from '../src/index.js'; -import { Registry, runWithRegistry } from '../src/registry.js'; +import { Registry } from '../src/registry.js'; -function createTestFlow() { +function createTestFlow(registry: Registry) { return defineFlow( + registry, { name: 'testFlow', inputSchema: z.string(), @@ -33,8 +34,9 @@ function 
createTestFlow() { ); } -function createTestStreamingFlow() { +function createTestStreamingFlow(registry: Registry) { return defineStreamingFlow( + registry, { name: 'testFlow', inputSchema: z.number(), @@ -63,7 +65,7 @@ describe('flow', () => { describe('runFlow', () => { it('should run the flow', async () => { - const testFlow = runWithRegistry(registry, createTestFlow); + const testFlow = createTestFlow(registry); const result = await testFlow('foo'); @@ -71,10 +73,8 @@ describe('flow', () => { }); it('should run simple sync flow', async () => { - const testFlow = runWithRegistry(registry, () => { - return defineFlow('testFlow', (input) => { - return `bar ${input}`; - }); + const testFlow = defineFlow(registry, 'testFlow', (input) => { + return `bar ${input}`; }); const result = await testFlow('foo'); @@ -83,17 +83,16 @@ describe('flow', () => { }); it('should rethrow the error', async () => { - const testFlow = runWithRegistry(registry, () => - defineFlow( - { - name: 'throwing', - inputSchema: z.string(), - outputSchema: z.string(), - }, - async (input) => { - throw new Error(`bad happened: ${input}`); - } - ) + const testFlow = defineFlow( + registry, + { + name: 'throwing', + inputSchema: z.string(), + outputSchema: z.string(), + }, + async (input) => { + throw new Error(`bad happened: ${input}`); + } ); await assert.rejects(() => testFlow('foo'), { @@ -103,17 +102,16 @@ describe('flow', () => { }); it('should validate input', async () => { - const testFlow = runWithRegistry(registry, () => - defineFlow( - { - name: 'validating', - inputSchema: z.object({ foo: z.string(), bar: z.number() }), - outputSchema: z.string(), - }, - async (input) => { - return `ok ${input}`; - } - ) + const testFlow = defineFlow( + registry, + { + name: 'validating', + inputSchema: z.object({ foo: z.string(), bar: z.number() }), + outputSchema: z.string(), + }, + async (input) => { + return `ok ${input}`; + } ); await assert.rejects( @@ -132,7 +130,7 @@ describe('flow', () => { describe('streamFlow', () => { it('should run the flow', async () => { - const testFlow = runWithRegistry(registry, createTestStreamingFlow); + const testFlow = createTestStreamingFlow(registry); const response = testFlow(3); @@ -146,16 +144,15 @@ describe('flow', () => { }); it('should rethrow the error', async () => { - const testFlow = runWithRegistry(registry, () => - defineStreamingFlow( - { - name: 'throwing', - inputSchema: z.string(), - }, - async (input) => { - throw new Error(`stream bad happened: ${input}`); - } - ) + const testFlow = defineStreamingFlow( + registry, + { + name: 'throwing', + inputSchema: z.string(), + }, + async (input) => { + throw new Error(`stream bad happened: ${input}`); + } ); const response = testFlow('foo'); diff --git a/js/core/tests/registry_test.ts b/js/core/tests/registry_test.ts index 9542cf779..d54fdd415 100644 --- a/js/core/tests/registry_test.ts +++ b/js/core/tests/registry_test.ts @@ -17,175 +17,7 @@ import assert from 'node:assert'; import { beforeEach, describe, it } from 'node:test'; import { action } from '../src/action.js'; -import { - Registry, - listActions, - lookupAction, - registerAction, - registerPluginProvider, - runWithRegistry, -} from '../src/registry.js'; - -describe('global registry', () => { - let registry: Registry; - - beforeEach(() => { - registry = new Registry(); - }); - - describe('listActions', () => { - it('returns all registered actions', async () => { - await runWithRegistry(registry, async () => { - const fooSomethingAction = action( - { name: 
'foo_something' }, - async () => null - ); - registerAction('model', fooSomethingAction); - const barSomethingAction = action( - { name: 'bar_something' }, - async () => null - ); - registerAction('model', barSomethingAction); - - assert.deepEqual(await listActions(), { - '/model/foo_something': fooSomethingAction, - '/model/bar_something': barSomethingAction, - }); - }); - }); - - it('returns all registered actions by plugins', async () => { - await runWithRegistry(registry, async () => { - registerPluginProvider('foo', { - name: 'foo', - async initializer() { - registerAction('model', fooSomethingAction); - return {}; - }, - }); - const fooSomethingAction = action( - { - name: { - pluginId: 'foo', - actionId: 'something', - }, - }, - async () => null - ); - registerPluginProvider('bar', { - name: 'bar', - async initializer() { - registerAction('model', barSomethingAction); - return {}; - }, - }); - const barSomethingAction = action( - { - name: { - pluginId: 'bar', - actionId: 'something', - }, - }, - async () => null - ); - - assert.deepEqual(await listActions(), { - '/model/foo/something': fooSomethingAction, - '/model/bar/something': barSomethingAction, - }); - }); - }); - }); - - describe('lookupAction', () => { - it('initializes plugin for action first', async () => { - await runWithRegistry(registry, async () => { - let fooInitialized = false; - registerPluginProvider('foo', { - name: 'foo', - async initializer() { - fooInitialized = true; - return {}; - }, - }); - let barInitialized = false; - registerPluginProvider('bar', { - name: 'bar', - async initializer() { - barInitialized = true; - return {}; - }, - }); - - await lookupAction('/model/foo/something'); - - assert.strictEqual(fooInitialized, true); - assert.strictEqual(barInitialized, false); - - await lookupAction('/model/bar/something'); - - assert.strictEqual(fooInitialized, true); - assert.strictEqual(barInitialized, true); - }); - }); - }); - - it('returns registered action', async () => { - await runWithRegistry(registry, async () => { - const fooSomethingAction = action( - { name: 'foo_something' }, - async () => null - ); - registerAction('model', fooSomethingAction); - const barSomethingAction = action( - { name: 'bar_something' }, - async () => null - ); - registerAction('model', barSomethingAction); - - assert.strictEqual( - await lookupAction('/model/foo_something'), - fooSomethingAction - ); - assert.strictEqual( - await lookupAction('/model/bar_something'), - barSomethingAction - ); - }); - }); - - it('returns action registered by plugin', async () => { - await runWithRegistry(registry, async () => { - registerPluginProvider('foo', { - name: 'foo', - async initializer() { - registerAction('model', somethingAction); - return {}; - }, - }); - const somethingAction = action( - { - name: { - pluginId: 'foo', - actionId: 'something', - }, - }, - async () => null - ); - - assert.strictEqual( - await lookupAction('/model/foo/something'), - somethingAction - ); - }); - }); - - it('returns undefined for unknown action', async () => { - await runWithRegistry(registry, async () => { - assert.strictEqual(await lookupAction('/model/foo/something'), undefined); - }); - }); -}); +import { Registry } from '../src/registry.js'; describe('registry class', () => { var registry: Registry; diff --git a/js/genkit/package.json b/js/genkit/package.json index 80e4b1026..c5d18f544 100644 --- a/js/genkit/package.json +++ b/js/genkit/package.json @@ -7,7 +7,7 @@ "genai", "generative-ai" ], - "version": "0.6.0-dev.2", + "version": 
"0.9.0-dev.1", "type": "commonjs", "main": "./lib/cjs/index.js", "scripts": { @@ -16,7 +16,8 @@ "build:clean": "rimraf ./lib", "build": "npm-run-all build:clean check compile", "build:watch": "tsup-node --watch", - "test": "node --import tsx --test tests/*_test.ts" + "test": "node --import tsx --test tests/*_test.ts", + "test:watch": "node --watch --import tsx --test tests/*_test.ts" }, "repository": { "type": "git", @@ -38,7 +39,8 @@ "tsup": "^8.0.2", "typescript": "^4.9.0", "tsx": "^4.7.1", - "@types/body-parser": "^1.19.5" + "@types/body-parser": "^1.19.5", + "uuid": "^10.0.0" }, "files": [ "genkit-ui", @@ -131,6 +133,12 @@ "require": "./lib/tool.js", "import": "./lib/tool.mjs", "default": "./lib/tool.js" + }, + "./plugin": { + "types": "./lib/plugin.d.ts", + "require": "./lib/plugin.js", + "import": "./lib/plugin.mjs", + "default": "./lib/plugin.js" } }, "typesVersions": { @@ -179,6 +187,9 @@ ], "tool": [ "lib/tool" + ], + "plugin": [ + "lib/plugin" ] } } diff --git a/js/genkit/src/chat.ts b/js/genkit/src/chat.ts new file mode 100644 index 000000000..7587c941a --- /dev/null +++ b/js/genkit/src/chat.ts @@ -0,0 +1,183 @@ +/** + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { + ExecutablePrompt, + GenerateOptions, + GenerateResponse, + GenerateStreamOptions, + GenerateStreamResponse, + GenerationCommonConfigSchema, + MessageData, + Part, +} from '@genkit-ai/ai'; +import { z } from '@genkit-ai/core'; +import { Genkit } from './genkit'; +import { Session, SessionStore } from './session'; + +export const MAIN_THREAD = 'main'; + +export type BaseGenerateOptions = Omit; + +export interface PromptRenderOptions { + prompt: ExecutablePrompt; + input?: I; +} + +export type ChatOptions< + I = undefined, + S extends z.ZodTypeAny = z.ZodTypeAny, +> = (PromptRenderOptions | BaseGenerateOptions) & { + store?: SessionStore; + sessionId?: string; +}; + +/** + * Chat encapsulates a statful execution environment for chat. + * Chat session executed within a session in this environment will have acesss to + * session convesation history. 
+ * + * ```ts + * const ai = genkit({...}); + * const chat = ai.chat(); // create a Chat + * let response = await chat.send('hi, my name is Genkit'); + * response = await chat.send('what is my name?'); // chat history aware conversation + * ``` + */ +export class Chat { + readonly requestBase?: Promise; + readonly sessionId: string; + readonly schema?: S; + private _messages?: MessageData[]; + private threadName: string; + + constructor( + readonly session: Session, + requestBase: Promise, + options: { + id: string; + thread: string; + messages?: MessageData[]; + } + ) { + this.sessionId = options.id; + this.threadName = options.thread; + this.requestBase = requestBase?.then((rb) => { + const requestBase = { ...rb }; + // this is handling dotprompt render case + if (requestBase && requestBase['prompt']) { + const basePrompt = requestBase['prompt'] as string | Part | Part[]; + let promptMessage: MessageData; + if (typeof basePrompt === 'string') { + promptMessage = { + role: 'user', + content: [{ text: basePrompt }], + }; + } else if (Array.isArray(basePrompt)) { + promptMessage = { + role: 'user', + content: basePrompt, + }; + } else { + promptMessage = { + role: 'user', + content: [basePrompt], + }; + } + requestBase.messages = [...(requestBase.messages ?? []), promptMessage]; + } + requestBase.messages = [ + ...(options.messages ?? []), + ...(requestBase.messages ?? []), + ]; + this._messages = requestBase.messages; + return requestBase; + }); + this._messages = options.messages; + } + + async send< + O extends z.ZodTypeAny = z.ZodTypeAny, + CustomOptions extends z.ZodTypeAny = typeof GenerationCommonConfigSchema, + >( + options: string | Part[] | GenerateOptions + ): Promise>> { + // string + if (typeof options === 'string') { + options = { + prompt: options, + } as GenerateOptions; + } + // Part[] + if (Array.isArray(options)) { + options = { + prompt: options, + } as GenerateOptions; + } + const response = await this.genkit.generate({ + ...(await this.requestBase), + messages: this.messages, + ...options, + }); + await this.updateMessages(response.messages); + return response; + } + + async sendStream< + O extends z.ZodTypeAny = z.ZodTypeAny, + CustomOptions extends z.ZodTypeAny = typeof GenerationCommonConfigSchema, + >( + options: string | Part[] | GenerateStreamOptions + ): Promise>> { + // string + if (typeof options === 'string') { + options = { + prompt: options, + } as GenerateOptions; + } + // Part[] + if (Array.isArray(options)) { + options = { + prompt: options, + } as GenerateOptions; + } + const { response, stream } = await this.genkit.generateStream({ + ...(await this.requestBase), + messages: this.messages, + ...options, + }); + + return { + response: response.finally(async () => { + this.updateMessages((await response).messages); + }), + stream, + }; + } + + private get genkit(): Genkit { + return this.session.genkit; + } + + get messages(): MessageData[] { + return this._messages ?? []; + } + + async updateMessages(messages: MessageData[]): Promise { + this._messages = messages; + await this.session.updateMessages(this.threadName, messages); + } +} diff --git a/js/genkit/src/embedder.ts b/js/genkit/src/embedder.ts index 9f6ac86c2..620a1e7db 100644 --- a/js/genkit/src/embedder.ts +++ b/js/genkit/src/embedder.ts @@ -14,4 +14,15 @@ * limitations under the License. 
*/ -export * from '@genkit-ai/ai/embedder'; +export { + EmbedderInfoSchema, + EmbeddingSchema, + embedderRef, + type EmbedderAction, + type EmbedderArgument, + type EmbedderInfo, + type EmbedderParams, + type EmbedderReference, + type Embedding, + type EmbeddingBatch, +} from '@genkit-ai/ai/embedder'; diff --git a/js/genkit/src/evaluator.ts b/js/genkit/src/evaluator.ts index 2c4523124..fc905ae01 100644 --- a/js/genkit/src/evaluator.ts +++ b/js/genkit/src/evaluator.ts @@ -14,4 +14,23 @@ * limitations under the License. */ -export * from '@genkit-ai/ai/evaluator'; +export { + BaseDataPointSchema, + BaseEvalDataPointSchema, + EvalResponseSchema, + EvalResponsesSchema, + EvaluatorInfoSchema, + ScoreSchema, + evaluatorRef, + type BaseDataPoint, + type BaseEvalDataPoint, + type Dataset, + type EvalResponse, + type EvalResponses, + type EvaluatorAction, + type EvaluatorArgument, + type EvaluatorInfo, + type EvaluatorParams, + type EvaluatorReference, + type Score, +} from '@genkit-ai/ai/evaluator'; diff --git a/js/genkit/src/extract.ts b/js/genkit/src/extract.ts index 4d687bcb4..1dc634e13 100644 --- a/js/genkit/src/extract.ts +++ b/js/genkit/src/extract.ts @@ -14,4 +14,4 @@ * limitations under the License. */ -export * from '@genkit-ai/ai/extract'; +export { extractJson, parsePartialJson } from '@genkit-ai/ai/extract'; diff --git a/js/genkit/src/genkit.ts b/js/genkit/src/genkit.ts index 012e42b71..c9ff26a75 100644 --- a/js/genkit/src/genkit.ts +++ b/js/genkit/src/genkit.ts @@ -20,11 +20,13 @@ import { defineTool, Document, embed, + EmbedderInfo, EmbedderParams, Embedding, EvalResponses, evaluate, EvaluatorParams, + ExecutablePrompt, generate, GenerateOptions, GenerateRequest, @@ -34,20 +36,58 @@ import { GenerateStreamOptions, GenerateStreamResponse, GenerationCommonConfigSchema, - index, IndexerParams, ModelArgument, ModelReference, + Part, PromptAction, PromptFn, + PromptGenerateOptions, RankedDocument, rerank, RerankerParams, retrieve, + RetrieverAction, + RetrieverInfo, RetrieverParams, ToolAction, ToolConfig, } from '@genkit-ai/ai'; +import { + defineEmbedder, + EmbedderAction, + EmbedderArgument, + EmbedderFn, + EmbeddingBatch, + embedMany, +} from '@genkit-ai/ai/embedder'; +import { + defineEvaluator, + EvaluatorAction, + EvaluatorFn, +} from '@genkit-ai/ai/evaluator'; +import { + defineModel, + DefineModelOptions, + GenerateResponseChunkData, + ModelAction, +} from '@genkit-ai/ai/model'; +import { + defineReranker, + RerankerFn, + RerankerInfo, +} from '@genkit-ai/ai/reranker'; +import { + defineIndexer, + defineRetriever, + defineSimpleRetriever, + DocumentData, + index, + IndexerAction, + IndexerFn, + RetrieverFn, + SimpleRetrieverOptions, +} from '@genkit-ai/ai/retriever'; import { CallableFlow, defineFlow, @@ -61,7 +101,6 @@ import { FlowServerOptions, isDevEnv, JSONSchema, - PluginProvider, ReflectionServer, StreamableFlow, StreamingCallback, @@ -70,26 +109,31 @@ import { } from '@genkit-ai/core'; import { defineDotprompt, - Dotprompt, - prompt, - PromptGenerateOptions, + defineHelper, + definePartial, + loadPromptFolder, PromptMetadata, } from '@genkit-ai/dotprompt'; +import { v4 as uuidv4 } from 'uuid'; +import { Chat, ChatOptions } from './chat.js'; +import { BaseEvalDataPointSchema } from './evaluator.js'; import { logger } from './logging.js'; +import { GenkitPlugin, genkitPlugin } from './plugin.js'; +import { Registry } from './registry.js'; import { - defineModel, - DefineModelOptions, - GenerateResponseChunkData, - ModelAction, -} from './model.js'; -import { lookupAction, 
Registry, runWithRegistry } from './registry.js'; + getCurrentSession, + Session, + SessionData, + SessionError, + SessionOptions, +} from './session.js'; /** * Options for initializing Genkit. */ export interface GenkitOptions { /** List of plugins to load. */ - plugins?: PluginProvider[]; + plugins?: GenkitPlugin[]; /** Directory where dotprompts are stored. */ promptDir?: string; /** Default model to use if no model is specified. */ @@ -99,66 +143,6 @@ export interface GenkitOptions { flowServer?: FlowServerOptions | boolean; } -export interface ExecutablePrompt< - I extends z.ZodTypeAny = z.ZodTypeAny, - O extends z.ZodTypeAny = z.ZodTypeAny, - CustomOptions extends z.ZodTypeAny = z.ZodTypeAny, -> { - /** - * Generates a response by rendering the prompt template with given user input and then calling the model. - * - * @param input Prompt inputs. - * @param opt Options for the prompt template, including user input variables and custom model configuration options. - * @returns the model response as a promise of `GenerateStreamResponse`. - */ - ( - input?: z.infer, - opts?: z.infer - ): Promise>>; - - /** - * Generates a streaming response by rendering the prompt template with given user input and then calling the model. - * - * @param input Prompt inputs. - * @param opt Options for the prompt template, including user input variables and custom model configuration options. - * @returns the model response as a promise of `GenerateStreamResponse`. - */ - stream( - input?: z.infer, - opts?: z.infer - ): Promise>>; - - /** - * Generates a response by rendering the prompt template with given user input and additional generate options and then calling the model. - * - * @param opt Options for the prompt template, including user input variables and custom model configuration options. - * @returns the model response as a promise of `GenerateResponse`. - */ - generate( - opt: PromptGenerateOptions, CustomOptions> - ): Promise>>; - - /** - * Generates a streaming response by rendering the prompt template with given user input and additional generate options and then calling the model. - * - * @param opt Options for the prompt template, including user input variables and custom model configuration options. - * @returns the model response as a promise of `GenerateStreamResponse`. - */ - generateStream( - opt: PromptGenerateOptions, CustomOptions> - ): Promise>>; - - /** - * Renders the prompt template based on user input. - * - * @param opt Options for the prompt template, including user input variables and custom model configuration options. - * @returns a `GenerateOptions` object to be used with the `generate()` function from @genkit-ai/ai. - */ - render( - opt: PromptGenerateOptions, CustomOptions> - ): Promise>; -} - /** * `Genkit` encapsulates a single Genkit instance including the {@link Registry}, {@link ReflectionServer}, {@link FlowServer}, and configuration. 
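Plugins passed to `genkit()` are now `GenkitPlugin` functions created with `genkitPlugin()` rather than raw `PluginProvider` objects. A sketch of initialization under the new plugin type; the package import specifiers follow the export map added to `js/genkit/package.json` above and are otherwise assumptions:

```ts
import { genkit } from 'genkit';
import { genkitPlugin } from 'genkit/plugin';

// Illustrative plugin; real plugins would define models, retrievers, etc. on the instance.
const myPlugin = genkitPlugin('my-plugin', async (ai) => {
  // e.g. ai.defineModel(...), ai.defineRetriever(...)
});

const ai = genkit({
  plugins: [myPlugin],
  promptDir: './prompts',
});
```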
* @@ -209,7 +193,7 @@ export class Genkit { I extends z.ZodTypeAny = z.ZodTypeAny, O extends z.ZodTypeAny = z.ZodTypeAny, >(config: FlowConfig | string, fn: FlowFn): CallableFlow { - const flow = runWithRegistry(this.registry, () => defineFlow(config, fn)); + const flow = defineFlow(this.registry, config, fn); this.registeredFlows.push(flow.flow); return flow; } @@ -227,9 +211,7 @@ export class Genkit { config: StreamingFlowConfig, fn: FlowFn ): StreamableFlow { - const flow = runWithRegistry(this.registry, () => - defineStreamingFlow(config, fn) - ); + const flow = defineStreamingFlow(this.registry, config, fn); this.registeredFlows.push(flow.flow); return flow; } @@ -243,7 +225,7 @@ export class Genkit { config: ToolConfig, fn: (input: z.infer) => Promise> ): ToolAction { - return runWithRegistry(this.registry, () => defineTool(config, fn)); + return defineTool(this.registry, config, fn); } /** @@ -252,7 +234,7 @@ export class Genkit { * Defined schemas can be referenced by `name` in prompts in place of inline schemas. */ defineSchema(name: string, schema: T): T { - return runWithRegistry(this.registry, () => defineSchema(name, schema)); + return defineSchema(this.registry, name, schema); } /** @@ -261,9 +243,7 @@ export class Genkit { * Defined schemas can be referenced by `name` in prompts in place of inline schemas. */ defineJsonSchema(name: string, jsonSchema: JSONSchema) { - return runWithRegistry(this.registry, () => - defineJsonSchema(name, jsonSchema) - ); + return defineJsonSchema(this.registry, name, jsonSchema); } /** @@ -276,7 +256,7 @@ export class Genkit { streamingCallback?: StreamingCallback ) => Promise ): ModelAction { - return runWithRegistry(this.registry, () => defineModel(options, runner)); + return defineModel(this.registry, options, runner); } /** @@ -284,41 +264,42 @@ export class Genkit { * * @todo TODO: Show an example of a name and variant. */ - prompt< + async prompt< I extends z.ZodTypeAny = z.ZodTypeAny, O extends z.ZodTypeAny = z.ZodTypeAny, CustomOptions extends z.ZodTypeAny = z.ZodTypeAny, >( name: string, options?: { variant?: string } - ): Promise> { - return runWithRegistry(this.registry, async () => { - const action = (await lookupAction(`/prompt/${name}`)) as PromptAction; - if ( - action.__action?.metadata?.prompt && - Object.keys(action.__action.metadata.prompt).length > 0 - ) { - const p = await prompt(name, options); - return this.wrapDotpromptInExecutablePrompt(p, {}) as ExecutablePrompt< - I, - O, - CustomOptions - >; - } else { - return this.wrapPromptActionInExecutablePrompt( - action, - {} - ) as ExecutablePrompt; - } - }); + ): Promise, O, CustomOptions>> { + const action = (await this.registry.lookupAction( + `/prompt/${name}` + )) as PromptAction; + return this.wrapPromptActionInExecutablePrompt( + action, + {} + ) as ExecutablePrompt; } /** * Defines and registers a dotprompt. * - * This replaces defining and importing a .dotprompt file. + * This is an alternative to defining and importing a .prompt file. * - * @todo TODO: Improve this documentation (show an example, etc). 
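`prompt()` is now async and resolves the registered prompt action directly into an `ExecutablePrompt`. A short sketch of looking up and running a prompt, assuming an initialized `ai` instance and a prompt registered under the illustrative name `greeting`:

```ts
// Resolve the registered prompt, then call it like a function.
const greeting = await ai.prompt('greeting');
const { text } = await greeting({ name: 'Genkit' });
```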
+ * ```ts + * const hi = ai.definePrompt( + * { + * name: 'hi', + * input: { + * schema: z.object({ + * name: z.string(), + * }), + * }, + * }, + * 'hi {{ name }}' + * ); + * const { text } = await hi({ name: 'Genkit' }); + * ``` */ definePrompt< I extends z.ZodTypeAny = z.ZodTypeAny, @@ -330,8 +311,33 @@ export class Genkit { name: string; }, template: string - ): ExecutablePrompt; + ): ExecutablePrompt, O, CustomOptions>; + /** + * Defines and registers a function-based prompt. + * + * ```ts + * const hi = ai.definePrompt( + * { + * name: 'hi', + * input: { + * schema: z.object({ + * name: z.string(), + * }), + * }, + * config: { + * temperature: 1, + * }, + * }, + * async (input) => { + * return { + * messages: [ { role: 'user', content: [{ text: `hi ${input.name}` }] } ], + * }; + * } + * ); + * const { text } = await hi({ name: 'Genkit' }); + * ``` + */ definePrompt< I extends z.ZodTypeAny = z.ZodTypeAny, O extends z.ZodTypeAny = z.ZodTypeAny, @@ -342,7 +348,7 @@ export class Genkit { name: string; }, fn: PromptFn - ): ExecutablePrompt; + ): ExecutablePrompt, O, CustomOptions>; definePrompt< I extends z.ZodTypeAny = z.ZodTypeAny, @@ -354,106 +360,35 @@ export class Genkit { name: string; }, templateOrFn: string | PromptFn - ): ExecutablePrompt { + ): ExecutablePrompt, O, CustomOptions> { if (!options.name) { throw new Error('options.name is required'); } - return runWithRegistry(this.registry, () => { - if (!options.name) { - throw new Error('options.name is required'); - } - if (typeof templateOrFn === 'string') { - const dotprompt = defineDotprompt(options, templateOrFn as string); - return this.wrapDotpromptInExecutablePrompt(dotprompt, options); - } else { - const p = definePrompt( - { - name: options.name!, - inputJsonSchema: options.input?.jsonSchema, - inputSchema: options.input?.schema, - }, - templateOrFn as PromptFn - ); - return this.wrapPromptActionInExecutablePrompt(p, options); - } - }); - } - - private wrapDotpromptInExecutablePrompt< - I extends z.ZodTypeAny = z.ZodTypeAny, - O extends z.ZodTypeAny = z.ZodTypeAny, - CustomOptions extends z.ZodTypeAny = z.ZodTypeAny, - >( - dotprompt: Dotprompt>, - options: PromptMetadata - ): ExecutablePrompt { - const executablePrompt = ( - input?: z.infer, - opts?: z.infer - ): Promise> => { - return runWithRegistry(this.registry, async () => { - const model = await this.resolveModel(options.model); - return dotprompt.generate({ - model, - input, - config: opts, - }); - }); - }; - (executablePrompt as ExecutablePrompt).stream = ( - input?: z.infer, - opts?: z.infer - ): Promise> => { - return runWithRegistry(this.registry, async () => { - const model = await this.resolveModel(options.model); - return dotprompt.generateStream({ - model, - input, - config: opts, - }) as Promise>; - }); - }; - (executablePrompt as ExecutablePrompt).generate = ( - opt: PromptGenerateOptions - ): Promise> => { - return runWithRegistry(this.registry, async () => { - const model = !opt.model - ? await this.resolveModel(options.model) - : undefined; - return dotprompt.generate({ - model, - ...opt, - }); - }); - }; - (executablePrompt as ExecutablePrompt).generateStream = - ( - opt: PromptGenerateOptions - ): Promise> => { - return runWithRegistry(this.registry, async () => { - const model = !opt.model - ? 
await this.resolveModel(options.model) - : undefined; - return dotprompt.generateStream({ - model, - ...opt, - }) as Promise>; - }); - }; - (executablePrompt as ExecutablePrompt).render = < - Out extends O, - >( - opt: PromptGenerateOptions - ): Promise> => { - return runWithRegistry( + if (!options.name) { + throw new Error('options.name is required'); + } + if (typeof templateOrFn === 'string') { + const dotprompt = defineDotprompt( this.registry, - async () => - dotprompt.render({ - ...opt, - }) as GenerateOptions + options, + templateOrFn as string ); - }; - return executablePrompt as ExecutablePrompt; + return this.wrapPromptActionInExecutablePrompt( + dotprompt.promptAction! as PromptAction, + options + ); + } else { + const p = definePrompt( + this.registry, + { + name: options.name!, + inputJsonSchema: options.input?.jsonSchema, + inputSchema: options.input?.schema, + }, + templateOrFn as PromptFn + ); + return this.wrapPromptActionInExecutablePrompt(p, options); + } } private wrapPromptActionInExecutablePrompt< @@ -464,145 +399,212 @@ export class Genkit { p: PromptAction, options: PromptMetadata ): ExecutablePrompt { - const executablePrompt = ( + const executablePrompt = async ( input?: z.infer, - opts?: z.infer + opts?: PromptGenerateOptions ): Promise => { - return runWithRegistry(this.registry, async () => { - const model = await this.resolveModel(options.model); - const promptResult = await p(input); - return this.generate({ - model, - messages: promptResult.messages, - context: promptResult.context, - tools: promptResult.tools, - output: { - format: promptResult.output?.format, - jsonSchema: promptResult.output?.schema, - }, - config: { - ...options.config, - ...opts, - ...promptResult.config, - }, - }); + const renderedOpts = await ( + executablePrompt as ExecutablePrompt + ).render({ + ...opts, + input, }); + return this.generate(renderedOpts); }; - (executablePrompt as ExecutablePrompt).stream = ( + (executablePrompt as ExecutablePrompt).stream = async ( input?: z.infer, opts?: z.infer ): Promise> => { - return runWithRegistry(this.registry, async () => { - const model = await this.resolveModel(options.model); - const promptResult = await p(input); - return this.generateStream({ - model, - messages: promptResult.messages, - context: promptResult.context, - tools: promptResult.tools, - output: { - format: promptResult.output?.format, - jsonSchema: promptResult.output?.schema, - }, - config: { - ...options.config, - ...promptResult.config, - ...opts, - }, - }); - }); - }; - (executablePrompt as ExecutablePrompt).generate = ( - opt: PromptGenerateOptions - ): Promise> => { - return runWithRegistry(this.registry, async () => { - const model = !opt.model - ? 
await this.resolveModel(options.model) - : undefined; - const promptResult = await p(opt.input); - return this.generate({ - model, - messages: promptResult.messages, - context: promptResult.context, - tools: promptResult.tools, - output: { - format: promptResult.output?.format, - jsonSchema: promptResult.output?.schema, - }, - ...opt, - config: { - ...options.config, - ...promptResult.config, - ...opt.config, - }, - }); + const renderedOpts = await ( + executablePrompt as ExecutablePrompt + ).render({ + ...opts, + input, }); + return this.generateStream(renderedOpts); }; + (executablePrompt as ExecutablePrompt).generate = + async ( + opt: PromptGenerateOptions + ): Promise> => { + const renderedOpts = await ( + executablePrompt as ExecutablePrompt + ).render(opt); + return this.generate(renderedOpts); + }; (executablePrompt as ExecutablePrompt).generateStream = - ( + async ( opt: PromptGenerateOptions ): Promise> => { - return runWithRegistry(this.registry, async () => { - const model = !opt.model - ? await this.resolveModel(options.model) - : undefined; - const promptResult = await p(opt.input); - return this.generateStream({ - model, - messages: promptResult.messages, - context: promptResult.context, - tools: promptResult.tools, - output: { - format: promptResult.output?.format, - jsonSchema: promptResult.output?.schema, - } as any /* FIXME - schema type inference is borken */, - ...opt, - config: { - ...options.config, - ...promptResult.config, - ...opt.config, - }, - }); - }); + const renderedOpts = await ( + executablePrompt as ExecutablePrompt + ).render(opt); + return this.generateStream(renderedOpts); }; - (executablePrompt as ExecutablePrompt).render = < + (executablePrompt as ExecutablePrompt).render = async < Out extends O, >( opt: PromptGenerateOptions ): Promise> => { - return runWithRegistry(this.registry, async () => { - const model = !opt.model - ? await this.resolveModel(options.model) - : undefined; - const promptResult = await p(opt.input); - return { - model, - messages: promptResult.messages, - context: promptResult.context, - tools: promptResult.tools, - output: { - format: promptResult.output?.format, - jsonSchema: promptResult.output?.schema, - }, - ...opt, - config: { - ...options.config, - ...promptResult.config, - ...opt.config, - }, - } as GenerateOptions; - }); + let model: ModelAction | undefined; + try { + model = await this.resolveModel(opt?.model ?? options.model); + } catch (e) { + // ignore, no model on a render is OK. + } + + const promptResult = await p(opt.input); + const resultOptions = { + messages: promptResult.messages, + docs: promptResult.docs, + tools: promptResult.tools, + output: { + format: promptResult.output?.format, + jsonSchema: promptResult.output?.schema, + }, + config: { + ...options.config, + ...promptResult.config, + ...opt.config, + }, + model, + } as GenerateOptions; + delete (resultOptions as PromptGenerateOptions).input; + return resultOptions; }; + (executablePrompt as ExecutablePrompt).asTool = + (): ToolAction => { + return p as unknown as ToolAction; + }; return executablePrompt as ExecutablePrompt; } + /** + * Creates a retriever action for the provided {@link RetrieverFn} implementation. 
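Every `ExecutablePrompt` variant now funnels through `render()`, which resolves the prompt into plain `GenerateOptions`; prompts can also be handed to models as tools via `asTool()`. A sketch using the `hi` prompt from the examples above (input values are illustrative):

```ts
// Render first, then generate with the resolved options.
const renderedOpts = await hi.render({ input: { name: 'Genkit' } });
const response = await ai.generate(renderedOpts);

// Or expose the prompt as a tool for tool-calling models.
const hiTool = hi.asTool();
```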
+ */ + defineRetriever( + options: { + name: string; + configSchema?: OptionsType; + info?: RetrieverInfo; + }, + runner: RetrieverFn + ): RetrieverAction { + return defineRetriever(this.registry, options, runner); + } + + /** + * defineSimpleRetriever makes it easy to map existing data into documents that + * can be used for prompt augmentation. + * + * @param options Configuration options for the retriever. + * @param handler A function that queries a datastore and returns items from which to extract documents. + * @returns A Genkit retriever. + */ + defineSimpleRetriever( + options: SimpleRetrieverOptions, + handler: (query: Document, config: z.infer) => Promise + ): RetrieverAction { + return defineSimpleRetriever(this.registry, options, handler); + } + + /** + * Creates an indexer action for the provided {@link IndexerFn} implementation. + */ + defineIndexer( + options: { + name: string; + embedderInfo?: EmbedderInfo; + configSchema?: IndexerOptions; + }, + runner: IndexerFn + ): IndexerAction { + return defineIndexer(this.registry, options, runner); + } + + /** + * Creates an evaluator action for the provided {@link EvaluatorFn} implementation. + */ + defineEvaluator< + DataPoint extends typeof BaseDataPointSchema = typeof BaseDataPointSchema, + EvalDataPoint extends + typeof BaseEvalDataPointSchema = typeof BaseEvalDataPointSchema, + EvaluatorOptions extends z.ZodTypeAny = z.ZodTypeAny, + >( + options: { + name: string; + displayName: string; + definition: string; + dataPointType?: DataPoint; + configSchema?: EvaluatorOptions; + isBilled?: boolean; + }, + runner: EvaluatorFn + ): EvaluatorAction { + return defineEvaluator(this.registry, options, runner); + } + + /** + * Creates an embedder model for the provided {@link EmbedderFn} implementation. + */ + defineEmbedder( + options: { + name: string; + configSchema?: ConfigSchema; + info?: EmbedderInfo; + }, + runner: EmbedderFn + ): EmbedderAction { + return defineEmbedder(this.registry, options, runner); + } + + /** + * Creates a handlebars helper (https://handlebarsjs.com/guide/block-helpers.html) to be used in dotprompt templates. + */ + defineHelper(name: string, fn: Handlebars.HelperDelegate) { + return defineHelper(name, fn); + } + + /** + * Creates a handlebars partial (https://handlebarsjs.com/guide/partials.html) to be used in dotprompt templates. + */ + definePartial(name: string, source: string) { + return definePartial(name, source); + } + + /** + * Creates a reranker action for the provided {@link RerankerFn} implementation. + */ + defineReranker( + options: { + name: string; + configSchema?: OptionsType; + info?: RerankerInfo; + }, + runner: RerankerFn + ) { + return defineReranker(this.registry, options, runner); + } + + /** + * Embeds the given `content` using the specified `embedder`. + */ embed( params: EmbedderParams ): Promise { - return runWithRegistry(this.registry, () => embed(params)); + return embed(this.registry, params); + } + + /** + * A veneer for interacting with embedder models in bulk.
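A small sketch of the new helper/partial registration methods; the helper name, partial name, and template syntax shown are ordinary Handlebars usage rather than anything defined in this patch:

```ts
// Usable in .prompt templates as {{uppercase name}}.
ai.defineHelper('uppercase', (text: string) => text.toUpperCase());

// Usable in .prompt templates as {{> greetingPreamble}}.
ai.definePartial('greetingPreamble', 'Always greet the user by name.');
```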
+ */ + embedMany(params: { + embedder: EmbedderArgument; + content: string[] | DocumentData[]; + metadata?: Record; + options?: z.infer; + }): Promise { + return embedMany(this.registry, params); } /** @@ -612,7 +614,7 @@ export class Genkit { DataPoint extends typeof BaseDataPointSchema = typeof BaseDataPointSchema, CustomOptions extends z.ZodTypeAny = z.ZodTypeAny, >(params: EvaluatorParams): Promise { - return runWithRegistry(this.registry, () => evaluate(params)); + return evaluate(this.registry, params); } /** @@ -621,7 +623,7 @@ export class Genkit { rerank( params: RerankerParams ): Promise> { - return runWithRegistry(this.registry, () => rerank(params)); + return rerank(this.registry, params); } /** @@ -630,7 +632,7 @@ export class Genkit { index( params: IndexerParams ): Promise { - return runWithRegistry(this.registry, () => index(params)); + return index(this.registry, params); } /** @@ -639,9 +641,44 @@ export class Genkit { retrieve( params: RetrieverParams ): Promise> { - return runWithRegistry(this.registry, () => retrieve(params)); + return retrieve(this.registry, params); } + /** + * Make a generate call to the default model with a simple text prompt. + * + * ```ts + * const ai = genkit({ + * plugins: [googleAI()], + * model: gemini15Flash, // default model + * }) + * + * const { text } = await ai.generate('hi'); + * ``` + */ + generate( + strPrompt: string + ): Promise>>; + + /** + * Make a generate call to the default model with a multipart request. + * + * ```ts + * const ai = genkit({ + * plugins: [googleAI()], + * model: gemini15Flash, // default model + * }) + * + * const { text } = await ai.generate([ + * { media: {url: 'http://....'} }, + * { text: 'describe this image' } + * ]); + * ``` + */ + generate( + parts: Part[] + ): Promise>>; + /** * Generate calls a generative model based on the provided prompt and configuration. If * `messages` is provided, the generation will include a conversation history in its @@ -649,12 +686,40 @@ export class Genkit { * tool calls returned from the model unless `returnToolRequests` is set to `true`. * * See {@link GenerateOptions} for detailed information about available options. + * + * ```ts + * const ai = genkit({ + * plugins: [googleAI()], + * }) + * + * const { text } = await ai.generate({ + * system: 'talk like a pirate', + * prompt: [ + * { media: { url: 'http://....' } }, + * { text: 'describe this image' } + * ], + * messages: conversationHistory, + * tools: [ userInfoLookup ], + * model: gemini15Flash, + * }); + * ``` */ + generate< + O extends z.ZodTypeAny = z.ZodTypeAny, + CustomOptions extends z.ZodTypeAny = typeof GenerationCommonConfigSchema, + >( + opts: + | GenerateOptions + | PromiseLike> + ): Promise>>; + async generate< O extends z.ZodTypeAny = z.ZodTypeAny, CustomOptions extends z.ZodTypeAny = typeof GenerationCommonConfigSchema, >( options: + | string + | Part[] | GenerateOptions | PromiseLike> ): Promise>> { @@ -671,23 +736,97 @@ export class Genkit { if (!resolvedOptions.model) { resolvedOptions.model = this.options.model; } - return runWithRegistry(this.registry, () => generate(options)); + return generate(this.registry, resolvedOptions); } /** - * Generates a stream of responses from a generative model based on the provided prompt - * and configuration. If `history` is provided, the generation will include a conversation - * history in its request. If `tools` are provided, the generate method will automatically - * resolve tool calls returned from the model unless `returnToolRequests` is set to `true`. 
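The `embedMany` veneer introduced above batches embedding requests through a single call. A sketch, where `textEmbedder` stands in for whatever embedder reference the app has configured:

```ts
const embeddings = await ai.embedMany({
  embedder: textEmbedder, // placeholder for a configured embedder reference
  content: ['first passage', 'second passage'],
});
```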
+ * Make a streaming generate call to the default model with a simple text prompt. + * + * ```ts + * const ai = genkit({ + * plugins: [googleAI()], + * model: gemini15Flash, // default model + * }) + * + * const { response, stream } = await ai.generateStream('hi'); + * for await (const chunk of stream) { + * console.log(chunk.text); + * } + * console.log((await response).text); + * ``` + */ + generateStream( + strPrompt: string + ): Promise>>; + + /** + * Make a streaming generate call to the default model with a multipart request. + * + * ```ts + * const ai = genkit({ + * plugins: [googleAI()], + * model: gemini15Flash, // default model + * }) + * + * const { response, stream } = await ai.generateStream([ + * { media: {url: 'http://....'} }, + * { text: 'describe this image' } + * ]); + * for await (const chunk of stream) { + * console.log(chunk.text); + * } + * console.log((await response).text); + * ``` + */ + generateStream( + parts: Part[] + ): Promise>>; + + /** + * Streaming generate calls a generative model based on the provided prompt and configuration. If + * `messages` is provided, the generation will include a conversation history in its + * request. If `tools` are provided, the generate method will automatically resolve * tool calls returned from the model unless `returnToolRequests` is set to `true`. * - * See {@link GenerateStreamOptions} for detailed information about available options. + * See {@link GenerateOptions} for detailed information about available options. + * + * ```ts + * const ai = genkit({ + * plugins: [googleAI()], + * }) + * + * const { response, stream } = await ai.generateStream({ + * system: 'talk like a pirate', + * prompt: [ + * { media: { url: 'http://....' } }, + * { text: 'describe this image' } + * ], + * messages: conversationHistory, + * tools: [ userInfoLookup ], + * model: gemini15Flash, + * }); + * for await (const chunk of stream) { + * console.log(chunk.text); + * } + * console.log((await response).text); + * ``` */ + generateStream< + O extends z.ZodTypeAny = z.ZodTypeAny, + CustomOptions extends z.ZodTypeAny = typeof GenerationCommonConfigSchema, + >( + parts: + | GenerateOptions + | PromiseLike> + ): Promise>>; + async generateStream< O extends z.ZodTypeAny = z.ZodTypeAny, CustomOptions extends z.ZodTypeAny = typeof GenerationCommonConfigSchema, >( options: + | string + | Part[] | GenerateStreamOptions | PromiseLike> ): Promise>> { @@ -704,21 +843,93 @@ export class Genkit { if (!resolvedOptions.model) { resolvedOptions.model = this.options.model; } - return runWithRegistry(this.registry, () => generateStream(options)); + return generateStream(this.registry, resolvedOptions); + } + + /** + * Create a chat session with the provided options. + * + * ```ts + * const chat = ai.chat({ + * system: 'talk like a pirate', + * }) + * let response = await chat.send('tell me a joke') + * response = await chat.send('another one') + * ``` + */ + chat(options?: ChatOptions): Chat { + const session = this.createSession(); + return session.chat(options); + } + + /** + * Create a session for this environment. + */ + createSession(options?: SessionOptions): Session { + const sessionId = uuidv4(); + const sessionData: SessionData = { + id: sessionId, + state: options?.initialState, + }; + return new Session(this, { + id: sessionId, + sessionData, + stateSchema: options?.stateSchema, + store: options?.store, + }); + } + + /** + * Loads a session from the store. 
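The streaming and chat veneers sketched above can be combined roughly like this (assuming the same `ai` instance with a default model; prompts are illustrative):

```ts
// Streaming: chunks arrive as they are generated; the aggregated response is
// available afterwards via the `response` promise.
const { stream, response } = await ai.generateStream('tell me a story');
for await (const chunk of stream) {
  process.stdout.write(chunk.text);
}
console.log((await response).text);

// Chat: ai.chat() creates an implicit session that accumulates history
// across send() calls.
const chat = ai.chat({ system: 'talk like a pirate' });
await chat.send('hi');
await chat.send('tell me a joke');
```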
+ */ + async loadSession( + sessionId: string, + options: SessionOptions + ): Promise { + if (!options.store) { + throw new Error('options.store is required'); + } + const sessionData = await options.store.get(sessionId); + + return new Session(this, { + id: sessionId, + sessionData, + stateSchema: options?.stateSchema, + store: options.store, + }); + } + + /** + * Gets the current session from async local storage. + */ + get currentSession(): Session { + const currentSession = getCurrentSession(); + if (!currentSession) { + throw new SessionError('not running within a session'); + } + return currentSession as Session; } /** * Configures the Genkit instance. */ private configure() { - this.options.plugins?.forEach((plugin) => { - logger.debug(`Registering plugin ${plugin.name}...`); - const activeRegistry = this.registry; - activeRegistry.registerPluginProvider(plugin.name, { - name: plugin.name, + const activeRegistry = this.registry; + const plugins = [...(this.options.plugins ?? [])]; + if (this.options.promptDir !== null) { + const dotprompt = genkitPlugin('dotprompt', async (ai) => { + loadPromptFolder(this.registry, this.options.promptDir ?? './prompts'); + }); + plugins.push(dotprompt); + } + plugins.forEach((plugin) => { + const loadedPlugin = plugin(this); + logger.debug(`Registering plugin ${loadedPlugin.name}...`); + activeRegistry.registerPluginProvider(loadedPlugin.name, { + name: loadedPlugin.name, async initializer() { - logger.debug(`Initializing plugin ${plugin.name}:`); - return runWithRegistry(activeRegistry, () => plugin.initializer()); + logger.debug(`Initializing plugin ${loadedPlugin.name}:`); + loadedPlugin.initializer(); }, }); }); @@ -743,12 +954,16 @@ export class Genkit { return this.resolveModel(this.options.model); } if (typeof modelArg === 'string') { - return (await lookupAction(`/model/${modelArg}`)) as ModelAction; - } else if (modelArg.hasOwnProperty('name')) { - const ref = modelArg as ModelReference; - return (await lookupAction(`/model/${ref.name}`)) as ModelAction; - } else { + return (await this.registry.lookupAction( + `/model/${modelArg}` + )) as ModelAction; + } else if ((modelArg as ModelAction).__action) { return modelArg as ModelAction; + } else { + const ref = modelArg as ModelReference; + return (await this.registry.lookupAction( + `/model/${ref.name}` + )) as ModelAction; } } } diff --git a/js/genkit/src/index.ts b/js/genkit/src/index.ts index d92d864d9..45670276c 100644 --- a/js/genkit/src/index.ts +++ b/js/genkit/src/index.ts @@ -14,7 +14,141 @@ * limitations under the License. 
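A sketch of how `createSession`/`loadSession` are expected to combine with a user-supplied store (the `MyDbSessionStore` class is hypothetical; any `SessionStore` implementation with `get`/`save` works):

```ts
const store = new MyDbSessionStore();

// Start a persisted, stateful session and chat in it.
const session = ai.createSession({ store, initialState: { userName: 'Ada' } });
await session.chat().send('hi');

// Later (e.g. on another request), resume the same conversation by id.
const resumed = await ai.loadSession(session.id, { store });
await resumed.chat().send('where were we?');
```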
*/ -export * from '@genkit-ai/ai'; -export * from '@genkit-ai/core'; -export * from '@genkit-ai/dotprompt'; +export { + BaseDataPointSchema, + CommonLlmOptions, + Document, + DocumentData, + DocumentDataSchema, + EmbedderAction, + EmbedderArgument, + EmbedderInfo, + EmbedderParams, + EmbedderReference, + Embedding, + EvalResponses, + EvaluatorAction, + EvaluatorInfo, + EvaluatorParams, + EvaluatorReference, + GenerateOptions, + GenerateRequest, + GenerateRequestData, + GenerateResponse, + GenerateResponseData, + GenerateStreamOptions, + GenerateStreamResponse, + GenerationBlockedError, + GenerationCommonConfigSchema, + GenerationResponseError, + GenerationUsage, + IndexerAction, + IndexerArgument, + IndexerInfo, + IndexerParams, + IndexerReference, + LlmResponse, + LlmResponseSchema, + LlmStats, + LlmStatsSchema, + MediaPart, + Message, + MessageData, + MessageSchema, + ModelArgument, + ModelId, + ModelIdSchema, + ModelReference, + ModelRequest, + ModelRequestSchema, + ModelResponseData, + ModelResponseSchema, + Part, + PartSchema, + PromptAction, + PromptConfig, + PromptFn, + RankedDocument, + RerankerAction, + RerankerArgument, + RerankerInfo, + RerankerParams, + RerankerReference, + RetrieverAction, + RetrieverArgument, + RetrieverInfo, + RetrieverParams, + RetrieverReference, + Role, + RoleSchema, + Tool, + ToolAction, + ToolArgument, + ToolCall, + ToolCallSchema, + ToolConfig, + ToolRequestPart, + ToolResponsePart, + ToolSchema, + asTool, + embedderRef, + evaluatorRef, + indexerRef, + rerankerRef, + retrieverRef, + toGenerateRequest, + toToolWireFormat, +} from '@genkit-ai/ai'; +export { + Action, + ActionMetadata, + CallableFlow, + Flow, + FlowActionInput, + FlowActionInputSchema, + FlowAuthPolicy, + FlowConfig, + FlowError, + FlowErrorSchema, + FlowFn, + FlowInvokeEnvelopeMessage, + FlowInvokeEnvelopeMessageSchema, + FlowResponseSchema, + FlowResultSchema, + FlowServer, + FlowServerOptions, + GENKIT_CLIENT_HEADER, + GENKIT_VERSION, + GenkitError, + JSONSchema, + JSONSchema7, + Middleware, + ReflectionServer, + ReflectionServerOptions, + RunActionResponse, + RunActionResponseSchema, + SideChannelData, + Status, + StatusCodes, + StatusSchema, + StreamableFlow, + StreamingCallback, + StreamingFlowConfig, + TelemetryConfig, + __RequestWithAuth, + defineFlow, + defineJsonSchema, + defineSchema, + defineStreamingFlow, + deleteUndefinedProps, + flowMetadataPrefix, + getCurrentEnv, + getFlowAuth, + getStreamingCallback, + isDevEnv, + run, + runWithStreamingCallback, + z, +} from '@genkit-ai/core'; +export { loadPromptFile } from '@genkit-ai/dotprompt'; export * from './genkit.js'; diff --git a/js/genkit/src/logging.ts b/js/genkit/src/logging.ts index 9f22a94d6..5334a58f7 100644 --- a/js/genkit/src/logging.ts +++ b/js/genkit/src/logging.ts @@ -14,4 +14,4 @@ * limitations under the License. */ -export * from '@genkit-ai/core/logging'; +export { logger } from '@genkit-ai/core/logging'; diff --git a/js/genkit/src/middleware.ts b/js/genkit/src/middleware.ts index 3778af112..443ac450e 100644 --- a/js/genkit/src/middleware.ts +++ b/js/genkit/src/middleware.ts @@ -14,4 +14,11 @@ * limitations under the License. 
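With the wildcard re-exports replaced by explicit export lists, common application imports are still expected to come from the top-level entry point, for example (illustrative):

```ts
import { genkit, z, Document, type GenerateOptions, type MessageData } from 'genkit';
```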
*/ -export * from '@genkit-ai/ai/model/middleware'; +export { + augmentWithContext, + conformOutput, + downloadRequestMedia, + simulateSystemPrompt, + validateSupport, + type AugmentWithContextOptions, +} from '@genkit-ai/ai/model/middleware'; diff --git a/js/genkit/src/model.ts b/js/genkit/src/model.ts index 6ff7c5855..87a29f0fa 100644 --- a/js/genkit/src/model.ts +++ b/js/genkit/src/model.ts @@ -14,4 +14,56 @@ * limitations under the License. */ -export * from '@genkit-ai/ai/model'; +export { + CandidateErrorSchema, + CandidateSchema, + CustomPartSchema, + DataPartSchema, + GenerateRequestSchema, + GenerateResponseChunkSchema, + GenerateResponseSchema, + GenerationCommonConfigSchema, + GenerationUsageSchema, + MediaPartSchema, + MessageSchema, + ModelInfoSchema, + ModelRequestSchema, + ModelResponseChunkSchema, + ModelResponseSchema, + PartSchema, + RoleSchema, + TextPartSchema, + ToolDefinitionSchema, + ToolRequestPartSchema, + ToolResponsePartSchema, + getBasicUsageStats, + modelRef, + type CandidateData, + type CandidateError, + type CustomPart, + type DataPart, + type DefineModelOptions, + type GenerateRequest, + type GenerateRequestData, + type GenerateResponseChunkData, + type GenerateResponseData, + type GenerationCommonConfig, + type GenerationUsage, + type MediaPart, + type MessageData, + type ModelAction, + type ModelArgument, + type ModelInfo, + type ModelMiddleware, + type ModelReference, + type ModelRequest, + type ModelResponseChunkData, + type ModelResponseData, + type OutputConfig, + type Part, + type Role, + type TextPart, + type ToolDefinition, + type ToolRequestPart, + type ToolResponsePart, +} from '@genkit-ai/ai/model'; diff --git a/js/genkit/src/plugin.ts b/js/genkit/src/plugin.ts new file mode 100644 index 000000000..d0baf3318 --- /dev/null +++ b/js/genkit/src/plugin.ts @@ -0,0 +1,41 @@ +/** + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { Genkit } from './genkit.js'; + +export interface PluginProvider { + name: string; + initializer: () => void | Promise; +} + +type PluginInit = (genkit: Genkit) => void | Promise; + +export type GenkitPlugin = (genkit: Genkit) => PluginProvider; + +/** + * Defines a Genkit plugin. + */ +export function genkitPlugin( + pluginName: string, + initFn: T +): GenkitPlugin { + return (genkit: Genkit) => ({ + name: pluginName, + initializer: async () => { + await initFn(genkit); + }, + }); +} diff --git a/js/genkit/src/registry.ts b/js/genkit/src/registry.ts index 2f62006d2..8c45c10d4 100644 --- a/js/genkit/src/registry.ts +++ b/js/genkit/src/registry.ts @@ -14,4 +14,9 @@ * limitations under the License. */ -export * from '@genkit-ai/core/registry'; +export { + ActionType, + AsyncProvider, + Registry, + Schema, +} from '@genkit-ai/core/registry'; diff --git a/js/genkit/src/reranker.ts b/js/genkit/src/reranker.ts index 5aa131891..52b7f8f1f 100644 --- a/js/genkit/src/reranker.ts +++ b/js/genkit/src/reranker.ts @@ -14,4 +14,16 @@ * limitations under the License. 
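The new `genkit/plugin` entry point is what the chroma and dev-local-vectorstore refactors later in this patch build on. A sketch of a third-party plugin written against it (plugin and action names are illustrative):

```ts
import { Genkit } from 'genkit';
import { GenkitPlugin, genkitPlugin } from 'genkit/plugin';
import { Document } from 'genkit/retriever';

export function myPlugin(): GenkitPlugin {
  // The init function receives the Genkit instance and registers actions
  // directly on it instead of returning provider maps.
  return genkitPlugin('myPlugin', async (ai: Genkit) => {
    ai.defineRetriever({ name: 'myPlugin/docs' }, async (query) => ({
      documents: [Document.fromText(`no results for: ${query.text}`)],
    }));
  });
}
```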
  */
-export * from '@genkit-ai/ai/reranker';
+export {
+  CommonRerankerOptionsSchema,
+  RankedDocument,
+  RankedDocumentDataSchema,
+  RerankerInfoSchema,
+  rerankerRef,
+  type RankedDocumentData,
+  type RerankerAction,
+  type RerankerArgument,
+  type RerankerInfo,
+  type RerankerParams,
+  type RerankerReference,
+} from '@genkit-ai/ai/reranker';
diff --git a/js/genkit/src/retriever.ts b/js/genkit/src/retriever.ts
index c5d708e8c..2b020f40e 100644
--- a/js/genkit/src/retriever.ts
+++ b/js/genkit/src/retriever.ts
@@ -14,4 +14,29 @@
  * limitations under the License.
  */
-export * from '@genkit-ai/ai/retriever';
+export {
+  CommonRetrieverOptionsSchema,
+  Document,
+  DocumentData,
+  DocumentDataSchema,
+  IndexerInfoSchema,
+  MediaPart,
+  Part,
+  RetrieverInfoSchema,
+  TextPart,
+  indexerRef,
+  retrieverRef,
+  type IndexerAction,
+  type IndexerArgument,
+  type IndexerFn,
+  type IndexerInfo,
+  type IndexerParams,
+  type IndexerReference,
+  type RetrieverAction,
+  type RetrieverArgument,
+  type RetrieverFn,
+  type RetrieverInfo,
+  type RetrieverParams,
+  type RetrieverReference,
+  type SimpleRetrieverOptions,
+} from '@genkit-ai/ai/retriever';
diff --git a/js/genkit/src/schema.ts b/js/genkit/src/schema.ts
index 9bcf89551..efeb4b728 100644
--- a/js/genkit/src/schema.ts
+++ b/js/genkit/src/schema.ts
@@ -14,4 +14,13 @@
  * limitations under the License.
  */
-export * from '@genkit-ai/core/schema';
+export {
+  ValidationError,
+  parseSchema,
+  toJsonSchema,
+  validateSchema,
+  type JSONSchema,
+  type ProvidedSchema,
+  type ValidationErrorDetail,
+  type ValidationResponse,
+} from '@genkit-ai/core/schema';
diff --git a/js/genkit/src/session.ts b/js/genkit/src/session.ts
new file mode 100644
index 000000000..78824e15b
--- /dev/null
+++ b/js/genkit/src/session.ts
@@ -0,0 +1,241 @@
+/**
+ * Copyright 2024 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import { GenerateOptions, MessageData } from '@genkit-ai/ai';
+import { z } from '@genkit-ai/core';
+import { AsyncLocalStorage } from 'node:async_hooks';
+import { v4 as uuidv4 } from 'uuid';
+import { Chat, ChatOptions, MAIN_THREAD, PromptRenderOptions } from './chat';
+import { Genkit } from './genkit';
+
+export type BaseGenerateOptions = Omit;
+
+export interface SessionOptions {
+  /** Schema describing the state. */
+  stateSchema?: S;
+  /** Session store implementation for persisting the session state. */
+  store?: SessionStore;
+  /** Initial state of the session. */
+  initialState?: z.infer;
+  /** Custom session ID. */
+  sessionId?: string;
+}
+
+/**
+ * Session encapsulates a stateful execution environment for chat.
+ * Chat sessions executed within this environment will have access to the
+ * session's conversation history.
+ *
+ * ```ts
+ * const ai = genkit({...});
+ * const chat = ai.chat(); // create a Session
+ * let response = await chat.send('hi'); // session/history aware conversation
+ * response = await chat.send('tell me a story');
+ * ```
+ */
+export class Session {
+  readonly id: string;
+  readonly schema?: S;
+  private sessionData?: SessionData;
+  private store: SessionStore;
+
+  constructor(
+    readonly genkit: Genkit,
+    options?: {
+      id?: string;
+      stateSchema?: S;
+      sessionData?: SessionData;
+      store?: SessionStore;
+    }
+  ) {
+    this.id = options?.id ?? uuidv4();
+    this.schema = options?.stateSchema;
+    this.sessionData = options?.sessionData ?? {
+      id: this.id,
+    };
+    if (!this.sessionData) {
+      this.sessionData = { id: this.id };
+    }
+    if (!this.sessionData.threads) {
+      this.sessionData!.threads = {};
+    }
+    this.store = options?.store ?? new InMemorySessionStore();
+  }
+
+  get state(): z.infer {
+    // We always get state from the parent. Parent session is the source of truth.
+    if (this.genkit instanceof Session) {
+      return this.genkit.state;
+    }
+    return this.sessionData!.state;
+  }
+
+  /**
+   * Update session state data.
+   */
+  async updateState(data: z.infer): Promise {
+    let sessionData = this.sessionData;
+    if (!sessionData) {
+      sessionData = {} as SessionData;
+    }
+    sessionData.state = data;
+    this.sessionData = sessionData;
+
+    await this.store.save(this.id, sessionData);
+  }
+
+  /**
+   * Update messages for a given thread.
+   */
+  async updateMessages(
+    thread: string,
+    messages: MessageData[]
+  ): Promise {
+    let sessionData = this.sessionData;
+    if (!sessionData) {
+      sessionData = {} as SessionData;
+    }
+    if (!sessionData.threads) {
+      sessionData.threads = {};
+    }
+    sessionData.threads[thread] = messages;
+    this.sessionData = sessionData;
+
+    await this.store.save(this.id, sessionData);
+  }
+
+  /**
+   * Create a chat session with the provided options.
+   *
+   * ```ts
+   * const chat = ai.chat({
+   *   system: 'talk like a pirate',
+   * })
+   * let response = await chat.send('tell me a joke')
+   * response = await chat.send('another one')
+   * ```
+   */
+  chat(options?: ChatOptions): Chat;
+
+  /**
+   * Create a separate chat conversation ("thread") within the same session state.
+   *
+   * ```ts
+   * const lawyerChat = ai.chat('lawyerThread', {
+   *   system: 'talk like a lawyer',
+   * })
+   * const pirateChat = ai.chat('pirateThread', {
+   *   system: 'talk like a pirate',
+   * })
+   * await lawyerChat.send('tell me a joke')
+   * await pirateChat.send('tell me a joke')
+   * ```
+   */
+  chat(threadName: string, options?: ChatOptions): Chat;
+
+  chat(
+    optionsOrThreadName?: ChatOptions | string,
+    maybeOptions?: ChatOptions
+  ): Chat {
+    let options: ChatOptions | undefined;
+    let threadName = MAIN_THREAD;
+    if (maybeOptions) {
+      threadName = optionsOrThreadName as string;
+      options = maybeOptions as ChatOptions;
+    } else if (optionsOrThreadName) {
+      if (typeof optionsOrThreadName === 'string') {
+        threadName = optionsOrThreadName as string;
+      } else {
+        options = optionsOrThreadName as ChatOptions;
+      }
+    }
+    let requestBase: Promise;
+    if (!!(options as PromptRenderOptions)?.prompt?.render) {
+      const renderOptions = options as PromptRenderOptions;
+      requestBase = renderOptions.prompt.render({
+        input: renderOptions.input,
+      });
+    } else {
+      requestBase = Promise.resolve(options as BaseGenerateOptions);
+    }
+    return new Chat(this, requestBase, {
+      thread: threadName,
+      id: this.id,
+      messages:
+        (this.sessionData?.threads && this.sessionData?.threads[threadName]) ??
+ [], + }); + } + + toJSON() { + return this.sessionData; + } +} + +export interface SessionData { + id: string; + state?: z.infer; + threads?: Record; +} + +const sessionAls = new AsyncLocalStorage>(); + +/** + * Executes provided function within the provided session state. + */ +export function runWithSession( + session: Session, + fn: () => O +): O { + return sessionAls.run(session, fn); +} + +/** Returns the current session. */ +export function getCurrentSession(): + | Session + | undefined { + return sessionAls.getStore(); +} + +/** Throw when session state errors occur, ex. missing state, etc. */ +export class SessionError extends Error { + constructor(msg: string) { + super(msg); + } +} + +/** Session store persists session data such as state and chat messages. */ +export interface SessionStore { + get(sessionId: string): Promise | undefined>; + + save(sessionId: string, data: Omit, 'id'>): Promise; +} + +export function inMemorySessionStore() { + return new InMemorySessionStore(); +} + +class InMemorySessionStore implements SessionStore { + private data: Record> = {}; + + async get(sessionId: string): Promise | undefined> { + return this.data[sessionId]; + } + + async save(sessionId: string, sessionData: SessionData): Promise { + this.data[sessionId] = sessionData; + } +} diff --git a/js/genkit/src/testing.ts b/js/genkit/src/testing.ts index 77ab01fe8..34b2c08cd 100644 --- a/js/genkit/src/testing.ts +++ b/js/genkit/src/testing.ts @@ -14,4 +14,4 @@ * limitations under the License. */ -export * from '@genkit-ai/ai/testing'; +export { testModels } from '@genkit-ai/ai/testing'; diff --git a/js/genkit/src/tool.ts b/js/genkit/src/tool.ts index 0233774b3..34d67d654 100644 --- a/js/genkit/src/tool.ts +++ b/js/genkit/src/tool.ts @@ -14,4 +14,10 @@ * limitations under the License. */ -export * from '@genkit-ai/ai/tool'; +export { + asTool, + toToolDefinition, + type ToolAction, + type ToolArgument, + type ToolConfig, +} from '@genkit-ai/ai/tool'; diff --git a/js/genkit/src/tracing.ts b/js/genkit/src/tracing.ts index 4e41a0c04..10cdc4a3b 100644 --- a/js/genkit/src/tracing.ts +++ b/js/genkit/src/tracing.ts @@ -14,4 +14,36 @@ * limitations under the License. */ -export * from '@genkit-ai/core/tracing'; +export { + GenkitSpanProcessorWrapper, + InstrumentationLibrarySchema, + LinkSchema, + PathMetadata, + PathMetadataSchema, + SPAN_TYPE_ATTR, + SpanContextSchema, + SpanData, + SpanDataSchema, + SpanMetadata, + SpanMetadataSchema, + SpanStatusSchema, + TimeEventSchema, + TraceData, + TraceDataSchema, + TraceMetadata, + TraceMetadataSchema, + TraceServerExporter, + appendSpan, + cleanUpTracing, + enableTelemetry, + ensureBasicTelemetryInstrumentation, + flushTracing, + newTrace, + runInNewSpan, + setCustomMetadataAttribute, + setCustomMetadataAttributes, + setTelemetryServerUrl, + spanMetadataAls, + toDisplayPath, + traceMetadataAls, +} from '@genkit-ai/core/tracing'; diff --git a/js/genkit/tests/chat_test.ts b/js/genkit/tests/chat_test.ts new file mode 100644 index 000000000..f6e28eb74 --- /dev/null +++ b/js/genkit/tests/chat_test.ts @@ -0,0 +1,153 @@ +/** + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import assert from 'node:assert'; +import { beforeEach, describe, it } from 'node:test'; +import { Genkit, genkit } from '../src/genkit'; +import { defineEchoModel } from './helpers'; + +describe('session', () => { + let ai: Genkit; + + beforeEach(() => { + ai = genkit({ + model: 'echoModel', + }); + defineEchoModel(ai); + }); + + it('maintains history in the session', async () => { + const session = ai.chat(); + let response = await session.send('hi'); + + assert.strictEqual(response.text, 'Echo: hi; config: {}'); + + response = await session.send('bye'); + + assert.strictEqual( + response.text, + 'Echo: hi,Echo: hi,; config: {},bye; config: {}' + ); + assert.deepStrictEqual(response.messages, [ + { content: [{ text: 'hi' }], role: 'user' }, + { + content: [{ text: 'Echo: hi' }, { text: '; config: {}' }], + role: 'model', + }, + { content: [{ text: 'bye' }], role: 'user' }, + { + content: [ + { text: 'Echo: hi,Echo: hi,; config: {},bye' }, + { text: '; config: {}' }, + ], + role: 'model', + }, + ]); + }); + + it('maintains history in the session with streaming', async () => { + const chat = ai.chat(); + let { response, stream } = await chat.sendStream('hi'); + + let chunks: string[] = []; + for await (const chunk of stream) { + chunks.push(chunk.text); + } + assert.strictEqual((await response).text, 'Echo: hi; config: {}'); + assert.deepStrictEqual(chunks, ['3', '2', '1']); + + ({ response, stream } = await chat.sendStream('bye')); + + chunks = []; + for await (const chunk of stream) { + chunks.push(chunk.text); + } + + assert.deepStrictEqual(chunks, ['3', '2', '1']); + assert.strictEqual( + (await response).text, + 'Echo: hi,Echo: hi,; config: {},bye; config: {}' + ); + assert.deepStrictEqual((await response).messages, [ + { content: [{ text: 'hi' }], role: 'user' }, + { + content: [{ text: 'Echo: hi' }, { text: '; config: {}' }], + role: 'model', + }, + { content: [{ text: 'bye' }], role: 'user' }, + { + content: [ + { text: 'Echo: hi,Echo: hi,; config: {},bye' }, + { text: '; config: {}' }, + ], + role: 'model', + }, + ]); + }); + + it('can init a session with a prompt', async () => { + const prompt = ai.definePrompt({ name: 'hi' }, 'hi {{ name }}'); + + const session = await ai.chat( + await prompt.render({ + input: { name: 'Genkit' }, + config: { temperature: 11 }, + }) + ); + const response = await session.send('hi'); + + assert.strictEqual( + response.text, + 'Echo: hi Genkit,hi; config: {"temperature":11}' + ); + }); + + it('can start chat from a prompt', async () => { + const prompt = ai.definePrompt( + { name: 'hi', config: { version: 'abc' } }, + 'hi {{ name }} from template' + ); + const session = await ai.chat({ + prompt, + input: { name: 'Genkit' }, + }); + const response = await session.send('send it'); + + assert.strictEqual( + response.text, + 'Echo: hi Genkit from template,send it; config: {"version":"abc"}' + ); + }); + + it('can send a rendered prompt to chat', async () => { + const prompt = ai.definePrompt( + { name: 'hi', config: { version: 'abc' } }, + 'hi {{ name }}' + ); + const session = ai.chat(); + const response = await session.send( + 
await prompt.render({ + input: { name: 'Genkit' }, + config: { temperature: 11 }, + }) + ); + + assert.strictEqual( + response.text, + 'Echo: hi Genkit; config: {"version":"abc","temperature":11}' + ); + }); +}); diff --git a/js/genkit/tests/embed_test.ts b/js/genkit/tests/embed_test.ts new file mode 100644 index 000000000..18d940dde --- /dev/null +++ b/js/genkit/tests/embed_test.ts @@ -0,0 +1,140 @@ +/** + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { Document, EmbedderAction, embedderRef } from '@genkit-ai/ai'; +import assert from 'node:assert'; +import { beforeEach, describe, it } from 'node:test'; +import { Genkit, genkit } from '../src/genkit'; + +describe('embed', () => { + describe('default model', () => { + let ai: Genkit; + let embedder: EmbedderAction; + + beforeEach(() => { + ai = genkit({}); + embedder = defineTestEmbedder(ai); + }); + + it('passes string content as docs', async () => { + const response = await ai.embed({ + embedder: 'echoEmbedder', + content: 'hi', + }); + assert.deepStrictEqual((embedder as any).lastRequest, [ + [Document.fromText('hi')], + { + version: undefined, + }, + ]); + assert.deepStrictEqual(response, [1, 2, 3, 4]); + }); + + it('passes docs content as docs', async () => { + const response = await ai.embed({ + embedder: 'echoEmbedder', + content: Document.fromText('hi'), + }); + assert.deepStrictEqual((embedder as any).lastRequest, [ + [Document.fromText('hi')], + { + version: undefined, + }, + ]); + assert.deepStrictEqual(response, [1, 2, 3, 4]); + }); + }); + + describe('config', () => { + let ai: Genkit; + let embedder: EmbedderAction; + + beforeEach(() => { + ai = genkit({}); + embedder = defineTestEmbedder(ai); + }); + + it('takes config passed to generate', async () => { + const response = await ai.embed({ + embedder: 'echoEmbedder', + content: 'hi', + options: { + temperature: 11, + }, + }); + assert.deepStrictEqual(response, [1, 2, 3, 4]); + assert.deepStrictEqual((embedder as any).lastRequest[1], { + temperature: 11, + version: undefined, + }); + }); + + it('merges config from the ref', async () => { + const response = await ai.embed({ + embedder: embedderRef({ + name: 'echoEmbedder', + config: { + version: 'abc', + }, + }), + content: 'hi', + options: { + temperature: 11, + }, + }); + assert.deepStrictEqual(response, [1, 2, 3, 4]); + assert.deepStrictEqual((embedder as any).lastRequest[1], { + temperature: 11, + version: 'abc', + }); + }); + + it('picks up the top-level version from the ref', async () => { + const response = await ai.embed({ + embedder: embedderRef({ + name: 'echoEmbedder', + version: 'abc', + }), + content: 'hi', + options: { + temperature: 11, + }, + }); + assert.deepStrictEqual(response, [1, 2, 3, 4]); + assert.deepStrictEqual((embedder as any).lastRequest[1], { + temperature: 11, + version: 'abc', + }); + }); + }); +}); + +function defineTestEmbedder(ai: Genkit) { + const embedder = ai.defineEmbedder( + { name: 'echoEmbedder' }, + async (input, config) => { + 
(embedder as any).lastRequest = [input, config]; + return { + embeddings: [ + { + embedding: [1, 2, 3, 4], + }, + ], + }; + } + ); + return embedder; +} diff --git a/js/genkit/tests/generate_test.ts b/js/genkit/tests/generate_test.ts new file mode 100644 index 000000000..1c63499f2 --- /dev/null +++ b/js/genkit/tests/generate_test.ts @@ -0,0 +1,158 @@ +/** + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import assert from 'node:assert'; +import { beforeEach, describe, it } from 'node:test'; +import { modelRef } from '../../ai/src/model'; +import { Genkit, genkit } from '../src/genkit'; +import { defineEchoModel } from './helpers'; + +describe('generate', () => { + describe('default model', () => { + let ai: Genkit; + + beforeEach(() => { + ai = genkit({ + model: 'echoModel', + }); + defineEchoModel(ai); + }); + + it('calls the default model', async () => { + const response = await ai.generate({ + prompt: 'hi', + }); + assert.strictEqual(response.text, 'Echo: hi; config: {}'); + }); + + it('calls the default model with just a string prompt', async () => { + const response = await ai.generate('hi'); + assert.strictEqual(response.text, 'Echo: hi; config: {}'); + }); + + it('calls the default model with just parts prompt', async () => { + const response = await ai.generate([{ text: 'hi' }]); + assert.strictEqual(response.text, 'Echo: hi; config: {}'); + }); + + it('calls the default model system', async () => { + const response = await ai.generate({ + prompt: 'hi', + system: 'talk like a pirate', + }); + assert.strictEqual( + response.text, + 'Echo: system: talk like a pirate,hi; config: {}' + ); + assert.deepStrictEqual(response.request, { + config: undefined, + docs: undefined, + messages: [ + { + role: 'system', + content: [{ text: 'talk like a pirate' }], + }, + { + role: 'user', + content: [{ text: 'hi' }], + }, + ], + output: { + format: 'text', + }, + tools: [], + }); + }); + + it('streams the default model', async () => { + const { response, stream } = await ai.generateStream('hi'); + + const chunks: string[] = []; + for await (const chunk of stream) { + chunks.push(chunk.text); + } + assert.strictEqual((await response).text, 'Echo: hi; config: {}'); + assert.deepStrictEqual(chunks, ['3', '2', '1']); + }); + }); + + describe('default model', () => { + let ai: Genkit; + + beforeEach(() => { + ai = genkit({}); + defineEchoModel(ai); + }); + + it('calls the explicitly passed in model', async () => { + const response = await ai.generate({ + model: 'echoModel', + prompt: 'hi', + }); + assert.strictEqual(response.text, 'Echo: hi; config: {}'); + }); + }); + + describe('config', () => { + let ai: Genkit; + + beforeEach(() => { + ai = genkit({}); + defineEchoModel(ai); + }); + + it('takes config passed to generate', async () => { + const response = await ai.generate({ + prompt: 'hi', + model: 'echoModel', + config: { + temperature: 11, + }, + }); + assert.strictEqual(response.text, 'Echo: hi; config: {"temperature":11}'); + }); + + it('merges config 
from the ref', async () => { + const response = await ai.generate({ + prompt: 'hi', + model: modelRef({ name: 'echoModel' }).withConfig({ + version: 'abc', + }), + config: { + temperature: 11, + }, + }); + assert.strictEqual( + response.text, + 'Echo: hi; config: {"version":"abc","temperature":11}' + ); + }); + + it('picks up the top-level version from the ref', async () => { + const response = await ai.generate({ + prompt: 'hi', + model: modelRef({ name: 'echoModel' }).withVersion('bcd'), + config: { + temperature: 11, + }, + }); + assert.strictEqual( + response.text, + 'Echo: hi; config: {"version":"bcd","temperature":11}' + ); + }); + }); +}); diff --git a/js/genkit/tests/helpers.ts b/js/genkit/tests/helpers.ts index f83c9cbd5..54ebc3d1c 100644 --- a/js/genkit/tests/helpers.ts +++ b/js/genkit/tests/helpers.ts @@ -16,6 +16,8 @@ import { MessageData } from '@genkit-ai/ai'; import { ModelAction } from '@genkit-ai/ai/model'; +import { z } from '@genkit-ai/core'; +import { SessionData, SessionStore } from '../src/environment'; import { Genkit } from '../src/genkit'; export function defineEchoModel(ai: Genkit): ModelAction { @@ -61,7 +63,12 @@ export function defineEchoModel(ai: Genkit): ModelAction { text: 'Echo: ' + request.messages - .map((m) => m.content.map((c) => c.text).join()) + .map( + (m) => + (m.role === 'user' || m.role === 'model' + ? '' + : `${m.role}: `) + m.content.map((c) => c.text).join() + ) .join(), }, { @@ -75,6 +82,23 @@ export function defineEchoModel(ai: Genkit): ModelAction { ); } +async function runAsync(fn: () => O): Promise { + return Promise.resolve(fn()); +} + +export class TestMemorySessionStore + implements SessionStore +{ + private data: Record> = {}; + + async get(sessionId: string): Promise | undefined> { + return this.data[sessionId]; + } + + async save(sessionId: string, sessionData: SessionData): Promise { + this.data[sessionId] = sessionData; + } +} export function defineStaticResponseModel( ai: Genkit, message: MessageData @@ -90,7 +114,3 @@ export function defineStaticResponseModel( } ); } - -async function runAsync(fn: () => O): Promise { - return Promise.resolve(fn()); -} diff --git a/js/genkit/tests/models_test.ts b/js/genkit/tests/models_test.ts deleted file mode 100644 index 8b7c03a61..000000000 --- a/js/genkit/tests/models_test.ts +++ /dev/null @@ -1,75 +0,0 @@ -/** - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -import assert from 'node:assert'; -import { beforeEach, describe, it } from 'node:test'; -import { Genkit, genkit } from '../src/genkit'; -import { defineEchoModel } from './helpers'; - -describe('models', () => { - describe('generate', () => { - describe('default model', () => { - let ai: Genkit; - - beforeEach(() => { - ai = genkit({ - model: 'echoModel', - }); - defineEchoModel(ai); - }); - - it('calls the default model', async () => { - const response = await ai.generate({ - prompt: 'hi', - }); - assert.strictEqual(response.text(), 'Echo: hi; config: undefined'); - }); - - it('streams the default model', async () => { - const { response, stream } = await ai.generateStream({ - prompt: 'hi', - }); - - const chunks: string[] = []; - for await (const chunk of stream) { - chunks.push(chunk.text()); - } - assert.strictEqual( - (await response).text(), - 'Echo: hi; config: undefined' - ); - assert.deepStrictEqual(chunks, ['3', '2', '1']); - }); - }); - - describe('default model', () => { - let ai: Genkit; - - beforeEach(() => { - ai = genkit({}); - defineEchoModel(ai); - }); - - it('calls the explicitly passed in model', async () => { - const response = await ai.generate({ - model: 'echoModel', - prompt: 'hi', - }); - assert.strictEqual(response.text(), 'Echo: hi; config: undefined'); - }); - }); - }); -}); diff --git a/js/genkit/tests/prompts_test.ts b/js/genkit/tests/prompts_test.ts index 0422fd416..627c400f3 100644 --- a/js/genkit/tests/prompts_test.ts +++ b/js/genkit/tests/prompts_test.ts @@ -46,7 +46,7 @@ describe('definePrompt - dotprompt', () => { ); const response = await hi({ name: 'Genkit' }); - assert.strictEqual(response.text(), 'Echo: hi Genkit; config: {}'); + assert.strictEqual(response.text, 'Echo: hi Genkit; config: {}'); }); it('calls dotprompt with default model with config', async () => { @@ -67,7 +67,7 @@ describe('definePrompt - dotprompt', () => { const response = await hi({ name: 'Genkit' }); assert.strictEqual( - response.text(), + response.text, 'Echo: hi Genkit; config: {"temperature":11}' ); }); @@ -93,8 +93,8 @@ describe('definePrompt - dotprompt', () => { config: { version: 'abc' }, }); assert.strictEqual( - response.text(), - 'Echo: hi Genkit; config: {"temperature":11,"version":"abc"}' + response.text, + 'Echo: hi Genkit; config: {"version":"abc","temperature":11}' ); }); @@ -114,7 +114,7 @@ describe('definePrompt - dotprompt', () => { const hi = await ai.prompt('hi'); const response = await hi({ name: 'Genkit' }); - assert.strictEqual(response.text(), 'Echo: hi Genkit; config: {}'); + assert.strictEqual(response.text, 'Echo: hi Genkit; config: {}'); }); }); @@ -144,7 +144,7 @@ describe('definePrompt - dotprompt', () => { ); const response = await hi({ name: 'Genkit' }); - assert.strictEqual(response.text(), 'Echo: hi Genkit; config: {}'); + assert.strictEqual(response.text, 'Echo: hi Genkit; config: {}'); }); it('infers output schema', async () => { @@ -176,7 +176,7 @@ describe('definePrompt - dotprompt', () => { ); const response = await hi({ name: 'Genkit' }); - const foo: z.infer = response.output(); + const foo: z.infer = response.output; assert.deepStrictEqual(foo, { bar: 'baz' }); }); @@ -199,9 +199,9 @@ describe('definePrompt - dotprompt', () => { const { response, stream } = await hi.stream({ name: 'Genkit' }); const chunks: string[] = []; for await (const chunk of stream) { - chunks.push(chunk.text()); + chunks.push(chunk.text); } - const responseText = (await response).text(); + const responseText = (await response).text; assert.strictEqual( 
responseText, @@ -232,13 +232,13 @@ describe('definePrompt - dotprompt', () => { }); const chunks: string[] = []; for await (const chunk of stream) { - chunks.push(chunk.text()); + chunks.push(chunk.text); } - const responseText = (await response).text(); + const responseText = (await response).text; assert.strictEqual( responseText, - 'Echo: hi Genkit; config: {"temperature":11,"version":"abc"}' + 'Echo: hi Genkit; config: {"version":"abc","temperature":11}' ); assert.deepStrictEqual(chunks, ['3', '2', '1']); }); @@ -259,7 +259,7 @@ describe('definePrompt - dotprompt', () => { const hi = await ai.prompt('hi'); const response = await hi({ name: 'Genkit' }); - assert.strictEqual(response.text(), 'Echo: hi Genkit; config: {}'); + assert.strictEqual(response.text, 'Echo: hi Genkit; config: {}'); }); }); @@ -286,7 +286,7 @@ describe('definePrompt - dotprompt', () => { ); const response = await hi({ name: 'Genkit' }); - assert.strictEqual(response.text(), 'Echo: hi Genkit; config: {}'); + assert.strictEqual(response.text, 'Echo: hi Genkit; config: {}'); }); it('calls dotprompt with default model with config', async () => { @@ -308,7 +308,7 @@ describe('definePrompt - dotprompt', () => { const response = await hi({ name: 'Genkit' }); assert.strictEqual( - response.text(), + response.text, 'Echo: hi Genkit; config: {"temperature":11}' ); }); @@ -339,22 +339,22 @@ describe('definePrompt - dotprompt', () => { delete response.model; // ignore assert.deepStrictEqual(response, { config: {}, - context: undefined, - messages: [], - prompt: [ + docs: undefined, + messages: [ { - text: 'hi Genkit', + content: [ + { + text: 'hi Genkit', + }, + ], + role: 'user', }, ], output: { - format: undefined, + format: 'text', jsonSchema: undefined, - schema: undefined, }, - returnToolRequests: undefined, - streamingCallback: undefined, tools: [], - use: undefined, }); }); }); @@ -391,7 +391,7 @@ describe('definePrompt', () => { ); const response = await hi({ name: 'Genkit' }); - assert.strictEqual(response.text(), 'Echo: hi Genkit; config: {}'); + assert.strictEqual(response.text, 'Echo: hi Genkit; config: {}'); }); it('calls dotprompt with default model with config', async () => { @@ -418,7 +418,7 @@ describe('definePrompt', () => { const response = await hi({ name: 'Genkit' }); assert.strictEqual( - response.text(), + response.text, 'Echo: hi Genkit; config: {"temperature":11}' ); }); @@ -445,7 +445,7 @@ describe('definePrompt', () => { const hi = await ai.prompt('hi'); const response = await hi({ name: 'Genkit' }); - assert.strictEqual(response.text(), 'Echo: hi Genkit; config: {}'); + assert.strictEqual(response.text, 'Echo: hi Genkit; config: {}'); }); }); @@ -481,7 +481,7 @@ describe('definePrompt', () => { ); const response = await hi({ name: 'Genkit' }); - assert.strictEqual(response.text(), 'Echo: hi Genkit; config: {}'); + assert.strictEqual(response.text, 'Echo: hi Genkit; config: {}'); }); it('streams dotprompt with default model', async () => { @@ -509,9 +509,9 @@ describe('definePrompt', () => { const { response, stream } = await hi.stream({ name: 'Genkit' }); const chunks: string[] = []; for await (const chunk of stream) { - chunks.push(chunk.text()); + chunks.push(chunk.text); } - const responseText = (await response).text(); + const responseText = (await response).text; assert.strictEqual( responseText, @@ -550,7 +550,7 @@ describe('definePrompt', () => { ); const response = await hi({ name: 'Genkit' }); - assert.strictEqual(response.text(), 'Echo: hi Genkit; config: {}'); + 
assert.strictEqual(response.text, 'Echo: hi Genkit; config: {}'); }); it('calls dotprompt with default model with config', async () => { @@ -578,7 +578,7 @@ describe('definePrompt', () => { const response = await hi({ name: 'Genkit' }); assert.strictEqual( - response.text(), + response.text, 'Echo: hi Genkit; config: {"temperature":11}' ); }); @@ -609,12 +609,14 @@ describe('definePrompt', () => { const response = await hi( { name: 'Genkit' }, { - version: 'abc', + config: { + version: 'abc', + }, } ); assert.strictEqual( - response.text(), - 'Echo: hi Genkit; config: {"temperature":11,"version":"abc"}' + response.text, + 'Echo: hi Genkit; config: {"version":"abc","temperature":11}' ); }); @@ -639,7 +641,7 @@ describe('definePrompt', () => { ); const response = await hi.generate({ input: { name: 'Genkit' } }); - assert.strictEqual(response.text(), 'Echo: hi Genkit; config: {}'); + assert.strictEqual(response.text, 'Echo: hi Genkit; config: {}'); }); it('streams dotprompt with .generateStream', async () => { @@ -671,13 +673,13 @@ describe('definePrompt', () => { }); const chunks: string[] = []; for await (const chunk of stream) { - chunks.push(chunk.text()); + chunks.push(chunk.text); } - const responseText = (await response).text(); + const responseText = (await response).text; assert.strictEqual( responseText, - 'Echo: hi Genkit; config: {"temperature":11,"version":"abc"}' + 'Echo: hi Genkit; config: {"version":"abc","temperature":11}' ); assert.deepStrictEqual(chunks, ['3', '2', '1']); }); @@ -715,10 +717,7 @@ describe('definePrompt', () => { delete response.model; // ignore assert.deepStrictEqual(response, { config: {}, - context: undefined, - input: { - name: 'Genkit', - }, + docs: undefined, messages: [ { content: [ diff --git a/js/genkit/tests/session_test.ts b/js/genkit/tests/session_test.ts new file mode 100644 index 000000000..be2ff5ada --- /dev/null +++ b/js/genkit/tests/session_test.ts @@ -0,0 +1,326 @@ +/** + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +import assert from 'node:assert'; +import { beforeEach, describe, it } from 'node:test'; +import { Genkit, genkit } from '../src/genkit'; +import { TestMemorySessionStore, defineEchoModel } from './helpers'; + +describe('session', () => { + let ai: Genkit; + + beforeEach(() => { + ai = genkit({ + model: 'echoModel', + }); + defineEchoModel(ai); + }); + + it('maintains history in the session', async () => { + const session = ai.createSession(); + const chat = session.chat(); + let response = await chat.send('hi'); + + assert.strictEqual(response.text, 'Echo: hi; config: {}'); + + response = await chat.send('bye'); + + assert.strictEqual( + response.text, + 'Echo: hi,Echo: hi,; config: {},bye; config: {}' + ); + assert.deepStrictEqual(response.messages, [ + { content: [{ text: 'hi' }], role: 'user' }, + { + content: [{ text: 'Echo: hi' }, { text: '; config: {}' }], + role: 'model', + }, + { content: [{ text: 'bye' }], role: 'user' }, + { + content: [ + { text: 'Echo: hi,Echo: hi,; config: {},bye' }, + { text: '; config: {}' }, + ], + role: 'model', + }, + ]); + }); + + it('maintains multithreaded history in the session', async () => { + const store = new TestMemorySessionStore(); + const session = ai.createSession({ + store, + initialState: { + name: 'Genkit', + }, + }); + + let mainChat = session.chat(); + let response = await mainChat.send('hi main'); + assert.strictEqual(response.text, 'Echo: hi main; config: {}'); + + const lawyerChat = session.chat('lawyerChat', { + system: 'talk like a lawyer', + }); + response = await lawyerChat.send('hi lawyerChat'); + assert.strictEqual( + response.text, + 'Echo: system: talk like a lawyer,hi lawyerChat; config: {}' + ); + + const pirateChat = session.chat('pirateChat', { + system: 'talk like a pirate', + }); + response = await pirateChat.send('hi pirateChat'); + assert.strictEqual( + response.text, + 'Echo: system: talk like a pirate,hi pirateChat; config: {}' + ); + + const gotState = await store.get(session.id); + delete gotState.id; // ignore + assert.deepStrictEqual(gotState, { + state: { + name: 'Genkit', + }, + threads: { + main: [ + { content: [{ text: 'hi main' }], role: 'user' }, + { + content: [{ text: 'Echo: hi main' }, { text: '; config: {}' }], + role: 'model', + }, + ], + lawyerChat: [ + { content: [{ text: 'talk like a lawyer' }], role: 'system' }, + { content: [{ text: 'hi lawyerChat' }], role: 'user' }, + { + content: [ + { text: 'Echo: system: talk like a lawyer,hi lawyerChat' }, + { text: '; config: {}' }, + ], + role: 'model', + }, + ], + pirateChat: [ + { content: [{ text: 'talk like a pirate' }], role: 'system' }, + { content: [{ text: 'hi pirateChat' }], role: 'user' }, + { + content: [ + { text: 'Echo: system: talk like a pirate,hi pirateChat' }, + { text: '; config: {}' }, + ], + role: 'model', + }, + ], + }, + }); + }); + + it('maintains history in the session with streaming', async () => { + const session = ai.createSession(); + const chat = session.chat(); + + let { response, stream } = await chat.sendStream('hi'); + + let chunks: string[] = []; + for await (const chunk of stream) { + chunks.push(chunk.text); + } + assert.strictEqual((await response).text, 'Echo: hi; config: {}'); + assert.deepStrictEqual(chunks, ['3', '2', '1']); + + ({ response, stream } = await chat.sendStream('bye')); + + chunks = []; + for await (const chunk of stream) { + chunks.push(chunk.text); + } + + assert.deepStrictEqual(chunks, ['3', '2', '1']); + assert.strictEqual( + (await response).text, + 'Echo: hi,Echo: hi,; config: {},bye; 
config: {}' + ); + assert.deepStrictEqual((await response).messages, [ + { content: [{ text: 'hi' }], role: 'user' }, + { + role: 'model', + content: [{ text: 'Echo: hi' }, { text: '; config: {}' }], + }, + { content: [{ text: 'bye' }], role: 'user' }, + { + role: 'model', + content: [ + { text: 'Echo: hi,Echo: hi,; config: {},bye' }, + { text: '; config: {}' }, + ], + }, + ]); + }); + + it('stores state and messages in the store', async () => { + const store = new TestMemorySessionStore(); + const session = ai.createSession({ + store, + initialState: { + foo: 'bar', + }, + }); + const chat = session.chat(); + + await chat.send('hi'); + await chat.send('bye'); + + const state = await store.get(session.id); + delete state.id; + assert.deepStrictEqual(state, { + state: { + foo: 'bar', + }, + threads: { + main: [ + { content: [{ text: 'hi' }], role: 'user' }, + { + content: [{ text: 'Echo: hi' }, { text: '; config: {}' }], + role: 'model', + }, + { content: [{ text: 'bye' }], role: 'user' }, + { + content: [ + { text: 'Echo: hi,Echo: hi,; config: {},bye' }, + { text: '; config: {}' }, + ], + role: 'model', + }, + ], + }, + }); + }); + + describe('loadChat', () => { + it('loads session from store', async () => { + const store = new TestMemorySessionStore(); + // init the store + const originalSession = ai.createSession({ store }); + const originalMainChat = originalSession.chat({ + config: { + temperature: 1, + }, + }); + await originalMainChat.send('hi'); + await originalMainChat.send('bye'); + + const sessionId = originalSession.id; + + // load + const session = await ai.loadSession(sessionId, { store }); + const mainChat = session.chat(); + assert.deepStrictEqual(mainChat.messages, [ + { content: [{ text: 'hi' }], role: 'user' }, + { + role: 'model', + content: [ + { text: 'Echo: hi' }, + { text: '; config: {"temperature":1}' }, + ], + }, + { + content: [{ text: 'bye' }], + role: 'user', + }, + { + content: [ + { text: 'Echo: hi,Echo: hi,; config: {"temperature":1},bye' }, + { text: '; config: {"temperature":1}' }, + ], + role: 'model', + }, + ]); + let response = await mainChat.send('hi again'); + assert.strictEqual( + response.text, + 'Echo: hi,Echo: hi,; config: {"temperature":1},bye,Echo: hi,Echo: hi,; config: {"temperature":1},bye,; config: {"temperature":1},hi again; config: {}' + ); + assert.deepStrictEqual(mainChat.messages, [ + { content: [{ text: 'hi' }], role: 'user' }, + { + role: 'model', + content: [ + { text: 'Echo: hi' }, + { text: '; config: {"temperature":1}' }, + ], + }, + { + content: [{ text: 'bye' }], + role: 'user', + }, + { + content: [ + { text: 'Echo: hi,Echo: hi,; config: {"temperature":1},bye' }, + { text: '; config: {"temperature":1}' }, + ], + role: 'model', + }, + { content: [{ text: 'hi again' }], role: 'user' }, + { + role: 'model', + content: [ + { + text: 'Echo: hi,Echo: hi,; config: {"temperature":1},bye,Echo: hi,Echo: hi,; config: {"temperature":1},bye,; config: {"temperature":1},hi again', + }, + { text: '; config: {}' }, + ], + }, + ]); + + const state = await store.get(sessionId); + assert.deepStrictEqual(state?.threads, { + main: [ + { content: [{ text: 'hi' }], role: 'user' }, + { + role: 'model', + content: [ + { text: 'Echo: hi' }, + { text: '; config: {"temperature":1}' }, + ], + }, + { + content: [{ text: 'bye' }], + role: 'user', + }, + { + content: [ + { text: 'Echo: hi,Echo: hi,; config: {"temperature":1},bye' }, + { text: '; config: {"temperature":1}' }, + ], + role: 'model', + }, + { content: [{ text: 'hi again' }], role: 'user' }, + { + 
role: 'model', + content: [ + { + text: 'Echo: hi,Echo: hi,; config: {"temperature":1},bye,Echo: hi,Echo: hi,; config: {"temperature":1},bye,; config: {"temperature":1},hi again', + }, + { text: '; config: {}' }, + ], + }, + ], + }); + }); + }); +}); diff --git a/js/package.json b/js/package.json index bf2aac4ea..368fbb71d 100644 --- a/js/package.json +++ b/js/package.json @@ -22,5 +22,5 @@ "only-allow": "^1.2.1", "typescript": "^4.9.0" }, - "packageManager": "pnpm@9.12.0+sha256.a61b67ff6cc97af864564f4442556c22a04f2e5a7714fbee76a1011361d9b726" + "packageManager": "pnpm@9.12.2+sha256.2ef6e547b0b07d841d605240dce4d635677831148cd30f6d564b8f4f928f73d2" } diff --git a/js/plugins/chroma/package.json b/js/plugins/chroma/package.json index c14cb31fc..14a79d909 100644 --- a/js/plugins/chroma/package.json +++ b/js/plugins/chroma/package.json @@ -13,7 +13,7 @@ "genai", "generative-ai" ], - "version": "0.6.0-dev.2", + "version": "0.9.0-dev.1", "type": "commonjs", "scripts": { "check": "tsc", diff --git a/js/plugins/chroma/src/index.ts b/js/plugins/chroma/src/index.ts index 4c71e52f7..6b2ba5dcc 100644 --- a/js/plugins/chroma/src/index.ts +++ b/js/plugins/chroma/src/index.ts @@ -28,18 +28,13 @@ import { import { Document, EmbedderArgument, - PluginProvider, - embed, - genkitPlugin, + Genkit, indexerRef, retrieverRef, z, } from 'genkit'; -import { - CommonRetrieverOptionsSchema, - defineIndexer, - defineRetriever, -} from 'genkit/retriever'; +import { GenkitPlugin, genkitPlugin } from 'genkit/plugin'; +import { CommonRetrieverOptionsSchema } from 'genkit/retriever'; import { Md5 } from 'ts-md5'; export { IncludeEnum }; @@ -59,38 +54,28 @@ type ChromaClientParams = | NativeChromaClientParams | (() => Promise); +type ChromaPluginParams< + EmbedderCustomOptions extends z.ZodTypeAny = z.ZodTypeAny, +> = { + clientParams?: ChromaClientParams; + collectionName: string; + createCollectionIfMissing?: boolean; + embedder: EmbedderArgument; + embedderOptions?: z.infer; +}[]; + /** * Chroma plugin that provides the Chroma retriever and indexer */ export function chroma( - params: { - clientParams?: ChromaClientParams; - collectionName: string; - createCollectionIfMissing?: boolean; - embedder: EmbedderArgument; - embedderOptions?: z.infer; - }[] -): PluginProvider { - const plugin = genkitPlugin( - 'chroma', - async ( - params: { - clientParams?: ChromaClientParams; - collectionName: string; - createCollectionIfMissing?: boolean; - embedder: EmbedderArgument; - embedderOptions?: z.infer; - }[] - ) => ({ - retrievers: params.map((i) => chromaRetriever(i)), - indexers: params.map((i) => chromaIndexer(i)), - }) - ); - return plugin(params); + params: ChromaPluginParams +): GenkitPlugin { + return genkitPlugin('chroma', async (ai: Genkit) => { + params.map((i) => chromaRetriever(ai, i)); + params.map((i) => chromaIndexer(ai, i)); + }); } -export default chroma; - export const chromaRetrieverRef = (params: { collectionName: string; displayName?: string; @@ -120,17 +105,18 @@ export const chromaIndexerRef = (params: { /** * Configures a Chroma vector store retriever. 
*/ -export function chromaRetriever< - EmbedderCustomOptions extends z.ZodTypeAny, ->(params: { - clientParams?: ChromaClientParams; - collectionName: string; - createCollectionIfMissing?: boolean; - embedder: EmbedderArgument; - embedderOptions?: z.infer; -}) { +export function chromaRetriever( + ai: Genkit, + params: { + clientParams?: ChromaClientParams; + collectionName: string; + createCollectionIfMissing?: boolean; + embedder: EmbedderArgument; + embedderOptions?: z.infer; + } +) { const { embedder, collectionName, embedderOptions } = params; - return defineRetriever( + return ai.defineRetriever( { name: `chroma/${collectionName}`, configSchema: ChromaRetrieverOptionsSchema.optional(), @@ -149,7 +135,7 @@ export function chromaRetriever< }); } - const embedding = await embed({ + const embedding = await ai.embed({ embedder, content, options: embedderOptions, @@ -191,20 +177,21 @@ export function chromaRetriever< /** * Configures a Chroma indexer. */ -export function chromaIndexer< - EmbedderCustomOptions extends z.ZodTypeAny, ->(params: { - clientParams?: ChromaClientParams; - collectionName: string; - createCollectionIfMissing?: boolean; - embedder: EmbedderArgument; - embedderOptions?: z.infer; -}) { +export function chromaIndexer( + ai: Genkit, + params: { + clientParams?: ChromaClientParams; + collectionName: string; + createCollectionIfMissing?: boolean; + embedder: EmbedderArgument; + embedderOptions?: z.infer; + } +) { const { collectionName, embedder, embedderOptions } = { ...params, }; - return defineIndexer( + return ai.defineIndexer( { name: `chroma/${params.collectionName}`, configSchema: ChromaIndexerOptionsSchema, @@ -226,7 +213,7 @@ export function chromaIndexer< const embeddings = await Promise.all( docs.map((doc) => - embed({ + ai.embed({ embedder, content: doc, options: embedderOptions, @@ -243,7 +230,7 @@ export function chromaIndexer< return { id, value, - document: docs[i].text(), + document: docs[i].text, metadata, }; }); @@ -262,13 +249,16 @@ export function chromaIndexer< */ export async function createChromaCollection< EmbedderCustomOptions extends z.ZodTypeAny, ->(params: { - name: string; - clientParams?: ChromaClientParams; - metadata?: CollectionMetadata; - embedder?: EmbedderArgument; - embedderOptions?: z.infer; -}) { +>( + ai: Genkit, + params: { + name: string; + clientParams?: ChromaClientParams; + metadata?: CollectionMetadata; + embedder?: EmbedderArgument; + embedderOptions?: z.infer; + } +) { let chromaEmbedder: IEmbeddingFunction | undefined = undefined; const embedder = params.embedder; if (!!embedder) { @@ -276,7 +266,7 @@ export async function createChromaCollection< generate(texts: string[]) { return Promise.all( texts.map((text) => - embed({ + ai.embed({ embedder, content: text, options: params.embedderOptions, diff --git a/js/plugins/dev-local-vectorstore/package.json b/js/plugins/dev-local-vectorstore/package.json index 63ad4a706..deb3a56af 100644 --- a/js/plugins/dev-local-vectorstore/package.json +++ b/js/plugins/dev-local-vectorstore/package.json @@ -10,7 +10,7 @@ "genai", "generative-ai" ], - "version": "0.6.0-dev.2", + "version": "0.9.0-dev.1", "type": "commonjs", "scripts": { "check": "tsc", diff --git a/js/plugins/dev-local-vectorstore/src/index.ts b/js/plugins/dev-local-vectorstore/src/index.ts index 2caeb80fa..34859b916 100644 --- a/js/plugins/dev-local-vectorstore/src/index.ts +++ b/js/plugins/dev-local-vectorstore/src/index.ts @@ -16,12 +16,11 @@ import similarity from 'compute-cosine-similarity'; import * as fs from 'fs'; 
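With the refactor above, the chroma plugin is configured through `genkit({ plugins: [...] })` and registers its retriever and indexer on the instance. A configuration sketch (the package name and embedder reference are assumptions; any embedder registered by another plugin can be passed):

```ts
import { embedderRef, genkit } from 'genkit';
// Package name is an assumption; the plugin lives at js/plugins/chroma in this repo.
import { chroma, chromaRetrieverRef } from 'genkitx-chromadb';

const ai = genkit({
  plugins: [
    chroma([
      {
        collectionName: 'docs',
        createCollectionIfMissing: true,
        // Placeholder reference; use an embedder provided by another plugin.
        embedder: embedderRef({ name: 'myEmbedders/textEmbedder' }),
      },
    ]),
  ],
});

// The retriever is registered as `chroma/docs`.
const docs = await ai.retrieve({
  retriever: chromaRetrieverRef({ collectionName: 'docs' }),
  query: 'how do I configure the plugin?',
});
```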
-import { genkitPlugin, PluginProvider, z } from 'genkit'; -import { embed, EmbedderArgument } from 'genkit/embedder'; +import { Genkit, z } from 'genkit'; +import { EmbedderArgument } from 'genkit/embedder'; +import { GenkitPlugin, genkitPlugin } from 'genkit/plugin'; import { CommonRetrieverOptionsSchema, - defineIndexer, - defineRetriever, Document, DocumentData, indexerRef, @@ -72,15 +71,11 @@ interface Params { */ export function devLocalVectorstore( params: Params[] -): PluginProvider { - const plugin = genkitPlugin( - 'devLocalVectorstore', - async (params: Params[]) => ({ - retrievers: params.map((p) => configureDevLocalRetriever(p)), - indexers: params.map((p) => configureDevLocalIndexer(p)), - }) - ); - return plugin(params); +): GenkitPlugin { + return genkitPlugin('devLocalVectorstore', async (ai) => { + params.map((p) => configureDevLocalRetriever(ai, p)); + params.map((p) => configureDevLocalIndexer(ai, p)); + }); } export default devLocalVectorstore; @@ -113,18 +108,21 @@ export function devLocalIndexerRef(indexName: string) { async function importDocumentsToLocalVectorstore< EmbedderCustomOptions extends z.ZodTypeAny, ->(params: { - indexName: string; - docs: Array; - embedder: EmbedderArgument; - embedderOptions?: z.infer; -}) { +>( + ai: Genkit, + params: { + indexName: string; + docs: Array; + embedder: EmbedderArgument; + embedderOptions?: z.infer; + } +) { const { docs, embedder, embedderOptions } = { ...params }; const data = loadFilestore(params.indexName); await Promise.all( docs.map(async (doc) => { - const embedding = await embed({ + const embedding = await ai.embed({ embedder, content: doc, options: embedderOptions, @@ -168,13 +166,16 @@ async function getClosestDocuments< */ export function configureDevLocalRetriever< EmbedderCustomOptions extends z.ZodTypeAny, ->(params: { - indexName: string; - embedder: EmbedderArgument; - embedderOptions?: z.infer; -}) { +>( + ai: Genkit, + params: { + indexName: string; + embedder: EmbedderArgument; + embedderOptions?: z.infer; + } +) { const { embedder, embedderOptions } = params; - const vectorstore = defineRetriever( + const vectorstore = ai.defineRetriever( { name: `devLocalVectorstore/${params.indexName}`, configSchema: CommonRetrieverOptionsSchema, @@ -182,7 +183,7 @@ export function configureDevLocalRetriever< async (content, options) => { const db = loadFilestore(params.indexName); - const embedding = await embed({ + const embedding = await ai.embed({ embedder, content, options: embedderOptions, @@ -204,16 +205,19 @@ export function configureDevLocalRetriever< */ export function configureDevLocalIndexer< EmbedderCustomOptions extends z.ZodTypeAny, ->(params: { - indexName: string; - embedder: EmbedderArgument; - embedderOptions?: z.infer; -}) { +>( + ai: Genkit, + params: { + indexName: string; + embedder: EmbedderArgument; + embedderOptions?: z.infer; + } +) { const { embedder, embedderOptions } = params; - const vectorstore = defineIndexer( + const vectorstore = ai.defineIndexer( { name: `devLocalVectorstore/${params.indexName}` }, async (docs) => { - await importDocumentsToLocalVectorstore({ + await importDocumentsToLocalVectorstore(ai, { indexName: params.indexName, docs, embedder, diff --git a/js/plugins/dotprompt/package.json b/js/plugins/dotprompt/package.json index 2a65c33c8..f069e57d1 100644 --- a/js/plugins/dotprompt/package.json +++ b/js/plugins/dotprompt/package.json @@ -9,7 +9,7 @@ "prompting", "templating" ], - "version": "0.6.0-dev.2", + "version": "0.9.0-dev.1", "type": "commonjs", "scripts": { 
"check": "tsc", diff --git a/js/plugins/dotprompt/src/index.ts b/js/plugins/dotprompt/src/index.ts index 8adf105a8..4f3f60f53 100644 --- a/js/plugins/dotprompt/src/index.ts +++ b/js/plugins/dotprompt/src/index.ts @@ -14,12 +14,7 @@ * limitations under the License. */ -import { - genkitPlugin, - InitializedPlugin, - PluginProvider, -} from '@genkit-ai/core'; - +import { Registry } from '@genkit-ai/core/registry'; import { readFileSync } from 'fs'; import { basename } from 'path'; import { @@ -32,7 +27,7 @@ import { loadPromptFolder, lookupPrompt } from './registry.js'; export { type PromptMetadata } from './metadata.js'; export { defineHelper, definePartial } from './template.js'; -export { defineDotprompt, Dotprompt, PromptGenerateOptions }; +export { defineDotprompt, Dotprompt, loadPromptFolder, PromptGenerateOptions }; export interface DotpromptPluginOptions { // Directory to look for .prompt files. @@ -43,24 +38,16 @@ export interface DotpromptPluginOptions { dir: string; } -export function dotprompt( - params: DotpromptPluginOptions = { dir: './prompts' } -): PluginProvider { - const plugin = genkitPlugin( - 'dotprompt', - async (options: DotpromptPluginOptions): Promise => { - await loadPromptFolder(options.dir); - return {}; - } - ); - return plugin(params); -} - export async function prompt( + registry: Registry, name: string, options?: { variant?: string } ): Promise> { - return (await lookupPrompt(name, options?.variant)) as Dotprompt; + return (await lookupPrompt( + registry, + name, + options?.variant + )) as Dotprompt; } export function promptRef( @@ -70,19 +57,22 @@ export function promptRef( return new DotpromptRef(name, options); } -export function loadPromptFile(path: string): Dotprompt { +export function loadPromptFile(registry: Registry, path: string): Dotprompt { return Dotprompt.parse( + registry, basename(path).split('.')[0], readFileSync(path, 'utf-8') ); } export async function loadPromptUrl( + registry: Registry, + name: string, url: string ): Promise { const fetch = (await import('node-fetch')).default; const response = await fetch(url); const text = await response.text(); - return Dotprompt.parse(name, text); + return Dotprompt.parse(registry, name, text); } diff --git a/js/plugins/dotprompt/src/metadata.ts b/js/plugins/dotprompt/src/metadata.ts index 79e59e122..165176919 100644 --- a/js/plugins/dotprompt/src/metadata.ts +++ b/js/plugins/dotprompt/src/metadata.ts @@ -25,7 +25,7 @@ import { } from '@genkit-ai/ai/model'; import { ToolArgument } from '@genkit-ai/ai/tool'; import { z } from '@genkit-ai/core'; -import { lookupSchema } from '@genkit-ai/core/registry'; +import { Registry } from '@genkit-ai/core/registry'; import { JSONSchema, parseSchema, toJsonSchema } from '@genkit-ai/core/schema'; import { picoschema } from './picoschema.js'; @@ -39,6 +39,9 @@ export interface PromptMetadata< /** The name of the prompt. */ name?: string; + /** Description (intent) of the prompt, used when prompt passed as tool to an LLM. */ + description?: string; + /** The variant name for the prompt. 
*/ variant?: string; @@ -119,27 +122,33 @@ function stripUndefinedOrNull(obj: any) { return obj; } -function fmSchemaToSchema(fmSchema: any) { +function fmSchemaToSchema(registry: Registry, fmSchema: any) { if (!fmSchema) return {}; - if (typeof fmSchema === 'string') return lookupSchema(fmSchema); + if (typeof fmSchema === 'string') return registry.lookupSchema(fmSchema); return { jsonSchema: picoschema(fmSchema) }; } -export function toMetadata(attributes: unknown): Partial { +export function toMetadata( + registry: Registry, + attributes: unknown +): Partial { const fm = parseSchema>(attributes, { schema: PromptFrontmatterSchema, }); let input: PromptMetadata['input'] | undefined; if (fm.input) { - input = { default: fm.input.default, ...fmSchemaToSchema(fm.input.schema) }; + input = { + default: fm.input.default, + ...fmSchemaToSchema(registry, fm.input.schema), + }; } let output: PromptMetadata['output'] | undefined; if (fm.output) { output = { format: fm.output.format, - ...fmSchemaToSchema(fm.output.schema), + ...fmSchemaToSchema(registry, fm.output.schema), }; } diff --git a/js/plugins/dotprompt/src/prompt.ts b/js/plugins/dotprompt/src/prompt.ts index c6d45d140..b216c21a8 100644 --- a/js/plugins/dotprompt/src/prompt.ts +++ b/js/plugins/dotprompt/src/prompt.ts @@ -27,6 +27,7 @@ import { import { MessageData, ModelArgument } from '@genkit-ai/ai/model'; import { DocumentData } from '@genkit-ai/ai/retriever'; import { GenkitError, z } from '@genkit-ai/core'; +import { Registry } from '@genkit-ai/core/registry'; import { parseSchema } from '@genkit-ai/core/schema'; import { runInNewSpan, @@ -49,13 +50,16 @@ export type PromptData = PromptFrontmatter & { template: string }; export type PromptGenerateOptions< V = unknown, CustomOptions extends z.ZodTypeAny = z.ZodTypeAny, -> = Omit, 'prompt' | 'model'> & { +> = Omit< + GenerateOptions, + 'prompt' | 'input' | 'model' +> & { model?: ModelArgument; input?: V; }; interface RenderMetadata { - context?: DocumentData[]; + docs?: DocumentData[]; messages?: MessageData[]; } @@ -73,16 +77,22 @@ export class Dotprompt implements PromptMetadata { tools?: PromptMetadata['tools']; config?: PromptMetadata['config']; + private _promptAction?: PromptAction; + private _render: (input: I, options?: RenderMetadata) => MessageData[]; - static parse(name: string, source: string) { + static parse(registry: Registry, name: string, source: string) { try { const fmResult = (fm as any)(source.trimStart(), { allowUnsafe: false, }) as FrontMatterResult; return new Dotprompt( - { ...toMetadata(fmResult.attributes), name } as PromptMetadata, + registry, + { + ...toMetadata(registry, fmResult.attributes), + name, + } as PromptMetadata, fmResult.body ); } catch (e: any) { @@ -94,7 +104,7 @@ export class Dotprompt implements PromptMetadata { } } - static fromAction(action: PromptAction): Dotprompt { + static fromAction(registry: Registry, action: PromptAction): Dotprompt { const { template, ...options } = action.__action.metadata!.prompt; const pm = options as PromptMetadata; if (pm.input?.schema) { @@ -104,11 +114,15 @@ export class Dotprompt implements PromptMetadata { if (pm.output?.schema) { pm.output.jsonSchema = options.output?.schema; } - const prompt = new Dotprompt(options as PromptMetadata, template); + const prompt = new Dotprompt(registry, options as PromptMetadata, template); return prompt; } - constructor(options: PromptMetadata, template: string) { + constructor( + private registry: Registry, + options: PromptMetadata, + template: string + ) { this.name = 
options.name || 'untitledPrompt'; this.variant = options.variant; this.model = options.model; @@ -164,11 +178,12 @@ export class Dotprompt implements PromptMetadata { return { ...toFrontmatter(this), template: this.template }; } - define(options?: { ns: string }): void { - definePrompt( + define(options?: { ns?: string; description?: string }): void { + this._promptAction = definePrompt( + this.registry, { name: registryDefinitionKey(this.name, this.variant, options?.ns), - description: 'Defined by Dotprompt', + description: options?.description ?? 'Defined by Dotprompt', inputSchema: this.input?.schema, inputJsonSchema: this.input?.jsonSchema, metadata: { @@ -176,24 +191,38 @@ export class Dotprompt implements PromptMetadata { prompt: this.toJSON(), }, }, - async (input?: I) => toGenerateRequest(this.render({ input })) + async (input?: I) => + toGenerateRequest(this.registry, this.render({ input })) ); } + get promptAction(): PromptAction | undefined { + return this._promptAction; + } + private _generateOptions< O extends z.ZodTypeAny = z.ZodTypeAny, CustomOptions extends z.ZodTypeAny = z.ZodTypeAny, >(options: PromptGenerateOptions): GenerateOptions { const messages = this.renderMessages(options.input, { messages: options.messages, - context: options.context, + docs: options.docs, }); + let renderedPrompt; + let renderedMessages; + if (messages.length > 0 && messages[messages.length - 1].role === 'user') { + renderedPrompt = messages[messages.length - 1].content; + renderedMessages = messages.slice(0, messages.length - 1); + } else { + renderedPrompt = undefined; + renderedMessages = messages; + } return { model: options.model || this.model!, config: { ...this.config, ...options.config }, - messages: messages.slice(0, messages.length - 1), - prompt: messages[messages.length - 1].content, - context: options.context, + messages: renderedMessages, + prompt: renderedPrompt, + docs: options.docs, output: { format: options.output?.format || this.output?.format || undefined, schema: options.output?.schema || this.output?.schema, @@ -258,7 +287,7 @@ export class Dotprompt implements PromptMetadata { opt: PromptGenerateOptions ): Promise>> { const renderedOpts = this.renderInNewSpan(opt); - return generate(renderedOpts); + return generate(this.registry, renderedOpts); } /** @@ -271,7 +300,7 @@ export class Dotprompt implements PromptMetadata { opt: PromptGenerateOptions ): Promise { const renderedOpts = await this.renderInNewSpan(opt); - return generateStream(renderedOpts); + return generateStream(this.registry, renderedOpts); } } @@ -294,9 +323,10 @@ export class DotpromptRef { } /** Loads the prompt which is referenced. 
*/ - async loadPrompt(): Promise> { + async loadPrompt(registry: Registry): Promise> { if (this._prompt) return this._prompt; this._prompt = (await lookupPrompt( + registry, this.name, this.variant, this.dir @@ -315,9 +345,10 @@ export class DotpromptRef { CustomOptions extends z.ZodTypeAny = z.ZodTypeAny, O extends z.ZodTypeAny = z.ZodTypeAny, >( + registry: Registry, opt: PromptGenerateOptions ): Promise>> { - const prompt = await this.loadPrompt(); + const prompt = await this.loadPrompt(registry); return prompt.generate(opt); } @@ -331,9 +362,11 @@ export class DotpromptRef { CustomOptions extends z.ZodTypeAny = z.ZodTypeAny, O extends z.ZodTypeAny = z.ZodTypeAny, >( + registry: Registry, + opt: PromptGenerateOptions ): Promise> { - const prompt = await this.loadPrompt(); + const prompt = await this.loadPrompt(registry); return prompt.render(opt); } } @@ -349,10 +382,11 @@ export function defineDotprompt< I extends z.ZodTypeAny = z.ZodTypeAny, CustomOptions extends z.ZodTypeAny = z.ZodTypeAny, >( + registry: Registry, options: PromptMetadata, template: string ): Dotprompt> { - const prompt = new Dotprompt(options, template); - prompt.define(); + const prompt = new Dotprompt(registry, options, template); + prompt.define({ description: options.description }); return prompt; } diff --git a/js/plugins/dotprompt/src/registry.ts b/js/plugins/dotprompt/src/registry.ts index 3397f6b56..f0af18eec 100644 --- a/js/plugins/dotprompt/src/registry.ts +++ b/js/plugins/dotprompt/src/registry.ts @@ -17,7 +17,7 @@ import { PromptAction } from '@genkit-ai/ai'; import { GenkitError } from '@genkit-ai/core'; import { logger } from '@genkit-ai/core/logging'; -import { lookupAction } from '@genkit-ai/core/registry'; +import { Registry } from '@genkit-ai/core/registry'; import { existsSync, readdir, readFileSync } from 'fs'; import { basename, join, resolve } from 'path'; import { Dotprompt } from './prompt.js'; @@ -37,23 +37,27 @@ export function registryLookupKey(name: string, variant?: string, ns?: string) { } export async function lookupPrompt( + registry: Registry, name: string, variant?: string, dir: string = './prompts' ): Promise { let registryPrompt = - (await lookupAction(registryLookupKey(name, variant))) || - (await lookupAction(registryLookupKey(name, variant, 'dotprompt'))); + (await registry.lookupAction(registryLookupKey(name, variant))) || + (await registry.lookupAction( + registryLookupKey(name, variant, 'dotprompt') + )); if (registryPrompt) { - return Dotprompt.fromAction(registryPrompt as PromptAction); + return Dotprompt.fromAction(registry, registryPrompt as PromptAction); } else { // Handle the case where initialization isn't complete // or a file was added after the prompt folder was loaded. 
- return maybeLoadPrompt(dir, name, variant); + return maybeLoadPrompt(registry, dir, name, variant); } } async function maybeLoadPrompt( + registry: Registry, dir: string, name: string, variant?: string @@ -62,7 +66,7 @@ async function maybeLoadPrompt( const promptFolder = resolve(dir); const promptExists = existsSync(join(promptFolder, expectedFileName)); if (promptExists) { - return loadPrompt(promptFolder, expectedFileName); + return loadPrompt(registry, promptFolder, expectedFileName); } else { throw new GenkitError({ source: 'dotprompt', @@ -73,6 +77,8 @@ async function maybeLoadPrompt( } export async function loadPromptFolder( + registry: Registry, + dir: string = './prompts' ): Promise { const promptsPath = resolve(dir); @@ -114,7 +120,7 @@ export async function loadPromptFolder( .replace(`${promptsPath}/`, '') .replace(/\//g, '-'); } - loadPrompt(dirEnt.path, dirEnt.name, prefix); + loadPrompt(registry, dirEnt.path, dirEnt.name, prefix); } } }); @@ -129,6 +135,7 @@ export async function loadPromptFolder( } export function loadPrompt( + registry: Registry, path: string, filename: string, prefix = '' @@ -141,7 +148,7 @@ export function loadPrompt( variant = parts[1]; } const source = readFileSync(join(path, filename), 'utf8'); - const prompt = Dotprompt.parse(name, source); + const prompt = Dotprompt.parse(registry, name, source); if (variant) { prompt.variant = variant; } diff --git a/js/plugins/dotprompt/tests/prompt_test.ts b/js/plugins/dotprompt/tests/prompt_test.ts index f39c316da..0e6d31e1a 100644 --- a/js/plugins/dotprompt/tests/prompt_test.ts +++ b/js/plugins/dotprompt/tests/prompt_test.ts @@ -16,7 +16,7 @@ import { defineModel, ModelAction } from '@genkit-ai/ai/model'; import { z } from '@genkit-ai/core'; -import { Registry, runWithRegistry } from '@genkit-ai/core/registry'; +import { Registry } from '@genkit-ai/core/registry'; import { defineJsonSchema, defineSchema, @@ -29,11 +29,12 @@ import { defineDotprompt, Dotprompt, prompt, promptRef } from '../src/index.js'; import { PromptMetadata } from '../src/metadata.js'; function testPrompt( + registry: Registry, model: ModelAction, template: string, options?: Partial ): Dotprompt { - return new Dotprompt({ name: 'test', model, ...options }, template); + return new Dotprompt(registry, { name: 'test', model, ...options }, template); } describe('Prompt', () => { @@ -44,184 +45,194 @@ describe('Prompt', () => { describe('#render', () => { it('should render variables', () => { - runWithRegistry(registry, () => { - const model = defineModel( - { name: 'echo', supports: { tools: true } }, - async (input) => ({ - message: input.messages[0], - finishReason: 'stop', - }) - ); - const prompt = testPrompt(model, `Hello {{name}}, how are you?`); - - const rendered = prompt.render({ input: { name: 'Michael' } }); - assert.deepStrictEqual(rendered.prompt, [ - { text: 'Hello Michael, how are you?' }, - ]); - }); + const model = defineModel( + registry, + { name: 'echo', supports: { tools: true } }, + async (input) => ({ + message: input.messages[0], + finishReason: 'stop', + }) + ); + const prompt = testPrompt( + registry, + model, + `Hello {{name}}, how are you?` + ); + + const rendered = prompt.render({ input: { name: 'Michael' } }); + assert.deepStrictEqual(rendered.prompt, [ + { text: 'Hello Michael, how are you?' 
}, + ]); }); it('should render default variables', () => { - runWithRegistry(registry, () => { - const model = defineModel( - { name: 'echo', supports: { tools: true } }, - async (input) => ({ - message: input.messages[0], - finishReason: 'stop', - }) - ); - const prompt = testPrompt(model, `Hello {{name}}, how are you?`, { + const model = defineModel( + registry, + { name: 'echo', supports: { tools: true } }, + async (input) => ({ + message: input.messages[0], + finishReason: 'stop', + }) + ); + const prompt = testPrompt( + registry, + model, + `Hello {{name}}, how are you?`, + { input: { default: { name: 'Fellow Human' } }, - }); + } + ); - const rendered = prompt.render({ input: {} }); - assert.deepStrictEqual(rendered.prompt, [ - { - text: 'Hello Fellow Human, how are you?', - }, - ]); - }); + const rendered = prompt.render({ input: {} }); + assert.deepStrictEqual(rendered.prompt, [ + { + text: 'Hello Fellow Human, how are you?', + }, + ]); }); it('rejects input not matching the schema', async () => { - await runWithRegistry(registry, async () => { - const invalidSchemaPrompt = defineDotprompt( - { - name: 'invalidInput', - model: 'echo', - input: { - jsonSchema: { - properties: { foo: { type: 'boolean' } }, - required: ['foo'], - }, + const invalidSchemaPrompt = defineDotprompt( + registry, + { + name: 'invalidInput', + model: 'echo', + input: { + jsonSchema: { + properties: { foo: { type: 'boolean' } }, + required: ['foo'], }, }, - `You asked for {{foo}}.` - ); + }, + `You asked for {{foo}}.` + ); - await assert.rejects(async () => { - invalidSchemaPrompt.render({ input: { foo: 'baz' } }); - }, ValidationError); - }); + await assert.rejects(async () => { + invalidSchemaPrompt.render({ input: { foo: 'baz' } }); + }, ValidationError); }); it('should render with overridden fields', () => { - runWithRegistry(registry, () => { - const model = defineModel( - { name: 'echo', supports: { tools: true } }, - async (input) => ({ - message: input.messages[0], - finishReason: 'stop', - }) - ); - const prompt = testPrompt(model, `Hello {{name}}, how are you?`); - - const streamingCallback = (c) => console.log(c); - const middleware = []; - - const rendered = prompt.render({ - input: { name: 'Michael' }, - streamingCallback, - returnToolRequests: true, - use: middleware, - }); - assert.strictEqual(rendered.streamingCallback, streamingCallback); - assert.strictEqual(rendered.returnToolRequests, true); - assert.strictEqual(rendered.use, middleware); + const model = defineModel( + registry, + { name: 'echo', supports: { tools: true } }, + async (input) => ({ + message: input.messages[0], + finishReason: 'stop', + }) + ); + const prompt = testPrompt( + registry, + model, + `Hello {{name}}, how are you?` + ); + + const streamingCallback = (c) => console.log(c); + const middleware = []; + + const rendered = prompt.render({ + input: { name: 'Michael' }, + streamingCallback, + returnToolRequests: true, + use: middleware, }); + assert.strictEqual(rendered.streamingCallback, streamingCallback); + assert.strictEqual(rendered.returnToolRequests, true); + assert.strictEqual(rendered.use, middleware); }); it('should support system prompt with history', () => { - runWithRegistry(registry, () => { - const model = defineModel( - { name: 'echo', supports: { tools: true } }, - async (input) => ({ - message: input.messages[0], - finishReason: 'stop', - }) - ); - const prompt = testPrompt( - model, - `{{ role "system" }}Testing system {{name}}` - ); - - const rendered = prompt.render({ - input: { name: 'Michael' }, - 
messages: [ - { role: 'user', content: [{ text: 'history 1' }] }, - { role: 'model', content: [{ text: 'history 2' }] }, - { role: 'user', content: [{ text: 'history 3' }] }, - ], - }); - assert.deepStrictEqual(rendered.messages, [ - { role: 'system', content: [{ text: 'Testing system Michael' }] }, + const model = defineModel( + registry, + { name: 'echo', supports: { tools: true } }, + async (input) => ({ + message: input.messages[0], + finishReason: 'stop', + }) + ); + const prompt = testPrompt( + registry, + model, + `{{ role "system" }}Testing system {{name}}` + ); + + const rendered = prompt.render({ + input: { name: 'Michael' }, + messages: [ { role: 'user', content: [{ text: 'history 1' }] }, { role: 'model', content: [{ text: 'history 2' }] }, - ]); - assert.deepStrictEqual(rendered.prompt, [{ text: 'history 3' }]); + { role: 'user', content: [{ text: 'history 3' }] }, + ], }); + assert.deepStrictEqual(rendered.messages, [ + { role: 'system', content: [{ text: 'Testing system Michael' }] }, + { role: 'user', content: [{ text: 'history 1' }] }, + { role: 'model', content: [{ text: 'history 2' }] }, + ]); + assert.deepStrictEqual(rendered.prompt, [{ text: 'history 3' }]); }); }); describe('#generate', () => { it('renders and calls the model', async () => { - await runWithRegistry(registry, async () => { - const model = defineModel( - { name: 'echo', supports: { tools: true } }, - async (input) => ({ - message: input.messages[0], - finishReason: 'stop', - }) - ); - const prompt = testPrompt(model, `Hello {{name}}, how are you?`); - const response = await prompt.generate({ input: { name: 'Bob' } }); - assert.equal(response.text(), `Hello Bob, how are you?`); - }); + const model = defineModel( + registry, + { name: 'echo', supports: { tools: true } }, + async (input) => ({ + message: input.messages[0], + finishReason: 'stop', + }) + ); + const prompt = testPrompt( + registry, + model, + `Hello {{name}}, how are you?` + ); + const response = await prompt.generate({ input: { name: 'Bob' } }); + assert.equal(response.text, `Hello Bob, how are you?`); }); it('rejects input not matching the schema', async () => { - await runWithRegistry(registry, async () => { - const invalidSchemaPrompt = defineDotprompt( - { - name: 'invalidInput', - model: 'echo', - input: { - jsonSchema: { - properties: { foo: { type: 'boolean' } }, - required: ['foo'], - }, + const invalidSchemaPrompt = defineDotprompt( + registry, + { + name: 'invalidInput', + model: 'echo', + input: { + jsonSchema: { + properties: { foo: { type: 'boolean' } }, + required: ['foo'], }, }, - `You asked for {{foo}}.` - ); + }, + `You asked for {{foo}}.` + ); - await assert.rejects(async () => { - await invalidSchemaPrompt.generate({ input: { foo: 'baz' } }); - }, ValidationError); - }); + await assert.rejects(async () => { + await invalidSchemaPrompt.generate({ input: { foo: 'baz' } }); + }, ValidationError); }); }); describe('#toJSON', () => { it('should convert zod to json schema', () => { - runWithRegistry(registry, () => { - const schema = z.object({ name: z.string() }); - const model = defineModel( - { name: 'echo', supports: { tools: true } }, - async (input) => ({ - message: input.messages[0], - finishReason: 'stop', - }) - ); - const prompt = testPrompt(model, `hello {{name}}`, { - input: { schema }, - }); - - assert.deepStrictEqual( - prompt.toJSON().input?.schema, - toJsonSchema({ schema }) - ); + const schema = z.object({ name: z.string() }); + const model = defineModel( + registry, + { name: 'echo', supports: { tools: true } 
}, + async (input) => ({ + message: input.messages[0], + finishReason: 'stop', + }) + ); + const prompt = testPrompt(registry, model, `hello {{name}}`, { + input: { schema }, }); + + assert.deepStrictEqual( + prompt.toJSON().input?.schema, + toJsonSchema({ schema }) + ); }); }); @@ -230,6 +241,7 @@ describe('Prompt', () => { assert.throws( () => { Dotprompt.parse( + registry, 'example', `--- input: { @@ -247,6 +259,7 @@ This is the rest of the prompt` it('should parse picoschema', () => { const p = Dotprompt.parse( + registry, 'example', `--- input: @@ -277,54 +290,53 @@ output: }); it('should use registered schemas', () => { - runWithRegistry(registry, () => { - const MyInput = defineSchema('MyInput', z.number()); - defineJsonSchema('MyOutput', { type: 'boolean' }); + const MyInput = defineSchema(registry, 'MyInput', z.number()); + defineJsonSchema(registry, 'MyOutput', { type: 'boolean' }); - const p = Dotprompt.parse( - 'example2', - `--- + const p = Dotprompt.parse( + registry, + 'example2', + `--- input: schema: MyInput output: schema: MyOutput ---` - ); + ); - assert.deepEqual(p.input, { schema: MyInput }); - assert.deepEqual(p.output, { jsonSchema: { type: 'boolean' } }); - }); + assert.deepEqual(p.input, { schema: MyInput }); + assert.deepEqual(p.output, { jsonSchema: { type: 'boolean' } }); }); }); describe('defineDotprompt', () => { it('registers a prompt and its variant', async () => { - await runWithRegistry(registry, async () => { - defineDotprompt( - { - name: 'promptName', - model: 'echo', - }, - `This is a prompt.` - ); - - defineDotprompt( - { - name: 'promptName', - variant: 'variantName', - model: 'echo', - }, - `And this is its variant.` - ); - - const basePrompt = await prompt('promptName'); - assert.equal('This is a prompt.', basePrompt.template); + defineDotprompt( + registry, + { + name: 'promptName', + model: 'echo', + }, + `This is a prompt.` + ); - const variantPrompt = await prompt('promptName', { + defineDotprompt( + registry, + { + name: 'promptName', variant: 'variantName', - }); - assert.equal('And this is its variant.', variantPrompt.template); + model: 'echo', + }, + `And this is its variant.` + ); + + const basePrompt = await prompt(registry, 'promptName'); + assert.equal('This is a prompt.', basePrompt.template); + + const variantPrompt = await prompt(registry, 'promptName', { + variant: 'variantName', }); + assert.equal('And this is its variant.', variantPrompt.template); }); }); }); @@ -336,138 +348,153 @@ describe('DotpromptRef', () => { }); it('Should load a prompt correctly', async () => { - await runWithRegistry(registry, async () => { - defineDotprompt( - { - name: 'promptName', - model: 'echo', - }, - `This is a prompt.` - ); + defineDotprompt( + registry, + { + name: 'promptName', + model: 'echo', + }, + `This is a prompt.` + ); - const ref = promptRef('promptName'); + const ref = promptRef('promptName'); - const p = await ref.loadPrompt(); + const p = await ref.loadPrompt(registry); - const isDotprompt = p instanceof Dotprompt; + const isDotprompt = p instanceof Dotprompt; - assert.equal(isDotprompt, true); - assert.equal(p.template, 'This is a prompt.'); - }); + assert.equal(isDotprompt, true); + assert.equal(p.template, 'This is a prompt.'); }); it('Should generate output correctly using DotpromptRef', async () => { - await runWithRegistry(registry, async () => { - const model = defineModel( - { name: 'echo', supports: { tools: true } }, - async (input) => ({ - message: input.messages[0], - finishReason: 'stop', - }) - ); - defineDotprompt( - 
{ - name: 'generatePrompt', - model: 'echo', - }, - `Hello {{name}}, this is a test prompt.` - ); - - const ref = promptRef('generatePrompt'); - const response = await ref.generate({ input: { name: 'Alice' } }); - - assert.equal(response.text(), 'Hello Alice, this is a test prompt.'); - }); + const model = defineModel( + registry, + { name: 'echo', supports: { tools: true } }, + async (input) => ({ + message: input.messages[0], + finishReason: 'stop', + }) + ); + defineDotprompt( + registry, + { + name: 'generatePrompt', + model: 'echo', + }, + `Hello {{name}}, this is a test prompt.` + ); + + const ref = promptRef('generatePrompt'); + const response = await ref.generate(registry, { input: { name: 'Alice' } }); + + assert.equal(response.text, 'Hello Alice, this is a test prompt.'); }); it('Should render correctly using DotpromptRef', async () => { - await runWithRegistry(registry, async () => { - defineDotprompt( - { - name: 'renderPrompt', - model: 'echo', - }, - `Hi {{name}}, welcome to the system.` - ); - - const ref = promptRef('renderPrompt'); - const rendered = await ref.render({ input: { name: 'Bob' } }); - - assert.deepStrictEqual(rendered.prompt, [ - { text: 'Hi Bob, welcome to the system.' }, - ]); - }); + defineDotprompt( + registry, + { + name: 'renderPrompt', + model: 'echo', + }, + `Hi {{name}}, welcome to the system.` + ); + + const ref = promptRef('renderPrompt'); + const rendered = await ref.render(registry, { input: { name: 'Bob' } }); + + assert.deepStrictEqual(rendered.prompt, [ + { text: 'Hi Bob, welcome to the system.' }, + ]); }); it('Should handle invalid schema input in DotpromptRef', async () => { - await runWithRegistry(registry, async () => { - defineDotprompt( - { - name: 'invalidSchemaPromptRef', - model: 'echo', - input: { - jsonSchema: { - properties: { foo: { type: 'boolean' } }, - required: ['foo'], - }, + defineDotprompt( + registry, + { + name: 'invalidSchemaPromptRef', + model: 'echo', + input: { + jsonSchema: { + properties: { foo: { type: 'boolean' } }, + required: ['foo'], }, }, - `This is the prompt with foo={{foo}}.` - ); + }, + `This is the prompt with foo={{foo}}.` + ); - const ref = promptRef('invalidSchemaPromptRef'); + const ref = promptRef('invalidSchemaPromptRef'); - await assert.rejects(async () => { - await ref.generate({ input: { foo: 'not_a_boolean' } }); - }, ValidationError); - }); + await assert.rejects(async () => { + await ref.generate(registry, { input: { foo: 'not_a_boolean' } }); + }, ValidationError); }); it('Should support streamingCallback in DotpromptRef', async () => { - await runWithRegistry(registry, async () => { - defineDotprompt( - { - name: 'streamingCallbackPrompt', - model: 'echo', - }, - `Hello {{name}}, streaming test.` - ); - - const ref = promptRef('streamingCallbackPrompt'); - - const streamingCallback = (chunk) => console.log(chunk); - const options = { - input: { name: 'Charlie' }, - streamingCallback, - returnToolRequests: true, - }; - - const rendered = await ref.render(options); - - assert.strictEqual(rendered.streamingCallback, streamingCallback); - assert.strictEqual(rendered.returnToolRequests, true); - }); + defineDotprompt( + registry, + { + name: 'streamingCallbackPrompt', + model: 'echo', + }, + `Hello {{name}}, streaming test.` + ); + + const ref = promptRef('streamingCallbackPrompt'); + + const streamingCallback = (chunk) => console.log(chunk); + const options = { + input: { name: 'Charlie' }, + streamingCallback, + returnToolRequests: true, + }; + + const rendered = await ref.render(registry, 
options); + + assert.strictEqual(rendered.streamingCallback, streamingCallback); + assert.strictEqual(rendered.returnToolRequests, true); }); it('Should cache loaded prompt in DotpromptRef', async () => { - await runWithRegistry(registry, async () => { - defineDotprompt( - { - name: 'cacheTestPrompt', - model: 'echo', - }, - `This is a prompt for cache test.` - ); - - const ref = promptRef('cacheTestPrompt'); - const firstLoad = await ref.loadPrompt(); - const secondLoad = await ref.loadPrompt(); + defineDotprompt( + registry, + { + name: 'cacheTestPrompt', + model: 'echo', + }, + `This is a prompt for cache test.` + ); + + const ref = promptRef('cacheTestPrompt'); + const firstLoad = await ref.loadPrompt(registry); + const secondLoad = await ref.loadPrompt(registry); + + assert.strictEqual( + firstLoad, + secondLoad, + 'Loaded prompts should be identical (cached).' + ); + }); - assert.strictEqual( - firstLoad, - secondLoad, - 'Loaded prompts should be identical (cached).' - ); - }); + it('should render system prompt', () => { + const model = defineModel( + registry, + { name: 'echo', supports: { tools: true } }, + async (input) => ({ + message: input.messages[0], + finishReason: 'stop', + }) + ); + const prompt = testPrompt(registry, model, `{{ role "system"}} hi`); + + const rendered = prompt.render({ input: {} }); + assert.deepStrictEqual(rendered.messages, [ + { + content: [{ text: ' hi' }], + role: 'system', + }, + ]); }); }); diff --git a/js/plugins/evaluators/package.json b/js/plugins/evaluators/package.json index 8543d5252..82eee18c9 100644 --- a/js/plugins/evaluators/package.json +++ b/js/plugins/evaluators/package.json @@ -11,7 +11,7 @@ "genai", "generative-ai" ], - "version": "0.6.0-dev.2", + "version": "0.9.0-dev.1", "type": "commonjs", "scripts": { "check": "tsc", @@ -30,6 +30,7 @@ "author": "genkit", "license": "Apache-2.0", "dependencies": { + "@genkit-ai/dotprompt": "workspace:*", "compute-cosine-similarity": "^1.1.0", "node-fetch": "^3.3.2", "path": "^0.12.7" diff --git a/js/plugins/evaluators/src/index.ts b/js/plugins/evaluators/src/index.ts index 072d0e459..41973ca2e 100644 --- a/js/plugins/evaluators/src/index.ts +++ b/js/plugins/evaluators/src/index.ts @@ -14,20 +14,14 @@ * limitations under the License. 
*/ -import { - EmbedderReference, - genkitPlugin, - ModelReference, - PluginProvider, - z, -} from 'genkit'; +import { EmbedderReference, Genkit, ModelReference, z } from 'genkit'; import { BaseEvalDataPoint, - defineEvaluator, EvalResponse, - evaluatorRef, Score, + evaluatorRef, } from 'genkit/evaluator'; +import { GenkitPlugin, genkitPlugin } from 'genkit/plugin'; import { answerRelevancyScore, faithfulnessScore, @@ -70,16 +64,10 @@ export function genkitEval< EmbedderCustomOptions extends z.ZodTypeAny, >( params: PluginOptions -): PluginProvider { - const plugin = genkitPlugin( - `${PLUGIN_NAME}`, - async ( - params: PluginOptions - ) => ({ - evaluators: [...genkitEvaluators(params)], - }) - ); - return plugin(params); +): GenkitPlugin { + return genkitPlugin(`${PLUGIN_NAME}`, async (ai: Genkit) => { + genkitEvaluators(ai, params); + }); } export default genkitEval; @@ -101,7 +89,10 @@ function fillScores(dataPoint: BaseEvalDataPoint, score: Score): EvalResponse { export function genkitEvaluators< ModelCustomOptions extends z.ZodTypeAny, EmbedderCustomOptions extends z.ZodTypeAny, ->(params: PluginOptions) { +>( + ai: Genkit, + params: PluginOptions +) { let { metrics, judge, judgeConfig, embedder, embedderOptions } = params; if (!metrics) { metrics = [GenkitMetric.MALICIOUSNESS, GenkitMetric.FAITHFULNESS]; @@ -111,7 +102,8 @@ export function genkitEvaluators< return metrics.map((metric) => { switch (metric) { case GenkitMetric.ANSWER_RELEVANCY: { - return defineEvaluator( + ai.defineIndexer; + return ai.defineEvaluator( { name: `${PLUGIN_NAME}/${metric.toLocaleLowerCase()}`, displayName: 'Answer Relevancy', @@ -120,6 +112,7 @@ export function genkitEvaluators< }, async (datapoint: BaseEvalDataPoint) => { const answerRelevancy = await answerRelevancyScore( + ai, judge, datapoint, embedder!, @@ -131,7 +124,7 @@ export function genkitEvaluators< ); } case GenkitMetric.FAITHFULNESS: { - return defineEvaluator( + return ai.defineEvaluator( { name: `${PLUGIN_NAME}/${metric.toLocaleLowerCase()}`, displayName: 'Faithfulness', @@ -140,6 +133,7 @@ export function genkitEvaluators< }, async (datapoint: BaseEvalDataPoint) => { const faithfulness = await faithfulnessScore( + ai, judge, datapoint, judgeConfig @@ -149,7 +143,7 @@ export function genkitEvaluators< ); } case GenkitMetric.MALICIOUSNESS: { - return defineEvaluator( + return ai.defineEvaluator( { name: `${PLUGIN_NAME}/${metric.toLocaleLowerCase()}`, displayName: 'Maliciousness', @@ -158,6 +152,7 @@ export function genkitEvaluators< }, async (datapoint: BaseEvalDataPoint) => { const maliciousness = await maliciousnessScore( + ai, judge, datapoint, judgeConfig diff --git a/js/plugins/evaluators/src/metrics/answer_relevancy.ts b/js/plugins/evaluators/src/metrics/answer_relevancy.ts index ab5e02b79..833a5ce33 100644 --- a/js/plugins/evaluators/src/metrics/answer_relevancy.ts +++ b/js/plugins/evaluators/src/metrics/answer_relevancy.ts @@ -14,9 +14,10 @@ * limitations under the License. 
*/ +import { loadPromptFile } from '@genkit-ai/dotprompt'; import similarity from 'compute-cosine-similarity'; -import { generate, loadPromptFile, ModelArgument, z } from 'genkit'; -import { embed, EmbedderArgument } from 'genkit/embedder'; +import { Genkit, ModelArgument, z } from 'genkit'; +import { EmbedderArgument } from 'genkit/embedder'; import { BaseEvalDataPoint, Score } from 'genkit/evaluator'; import path from 'path'; import { getDirName } from './helper.js'; @@ -31,6 +32,7 @@ export async function answerRelevancyScore< CustomModelOptions extends z.ZodTypeAny, CustomEmbedderOptions extends z.ZodTypeAny, >( + ai: Genkit, judgeLlm: ModelArgument, dataPoint: BaseEvalDataPoint, embedder: EmbedderArgument, @@ -45,9 +47,10 @@ export async function answerRelevancyScore< throw new Error('Output was not provided'); } const prompt = await loadPromptFile( + ai.registry, path.resolve(getDirName(), '../../prompts/answer_relevancy.prompt') ); - const response = await generate({ + const response = await ai.generate({ model: judgeLlm, config: judgeConfig, prompt: prompt.renderText({ @@ -59,23 +62,23 @@ export async function answerRelevancyScore< schema: AnswerRelevancyResponseSchema, }, }); - const genQuestion = response.output()?.question; + const genQuestion = response.output?.question; if (!genQuestion) throw new Error('Error generating question for answer relevancy'); - const questionEmbed = await embed({ + const questionEmbed = await ai.embed({ embedder, content: dataPoint.input as string, options: embedderOptions, }); - const genQuestionEmbed = await embed({ + const genQuestionEmbed = await ai.embed({ embedder, content: genQuestion, options: embedderOptions, }); const score = cosineSimilarity(questionEmbed, genQuestionEmbed); - const answered = response.output()?.answered === 1; - const isNonCommittal = response.output()?.noncommittal === 1; + const answered = response.output?.answered === 1; + const isNonCommittal = response.output?.noncommittal === 1; const answeredPenalty = !answered ? 0.5 : 0; const adjustedScore = score - answeredPenalty < 0 ? 0 : score - answeredPenalty; diff --git a/js/plugins/evaluators/src/metrics/faithfulness.ts b/js/plugins/evaluators/src/metrics/faithfulness.ts index fa80c6f4b..244d0f10a 100644 --- a/js/plugins/evaluators/src/metrics/faithfulness.ts +++ b/js/plugins/evaluators/src/metrics/faithfulness.ts @@ -14,7 +14,8 @@ * limitations under the License. 
*/ -import { generate, loadPromptFile, ModelArgument, z } from 'genkit'; +import { loadPromptFile } from '@genkit-ai/dotprompt'; +import { Genkit, ModelArgument, z } from 'genkit'; import { BaseEvalDataPoint, Score } from 'genkit/evaluator'; import path from 'path'; import { getDirName } from './helper.js'; @@ -39,6 +40,7 @@ const NliResponseSchema = z.union([ export async function faithfulnessScore< CustomModelOptions extends z.ZodTypeAny, >( + ai: Genkit, judgeLlm: ModelArgument, dataPoint: BaseEvalDataPoint, judgeConfig?: CustomModelOptions @@ -52,9 +54,10 @@ export async function faithfulnessScore< throw new Error('Output was not provided'); } const longFormPrompt = await loadPromptFile( + ai.registry, path.resolve(getDirName(), '../../prompts/faithfulness_long_form.prompt') ); - const longFormResponse = await generate({ + const longFormResponse = await ai.generate({ model: judgeLlm, config: judgeConfig, prompt: longFormPrompt.renderText({ @@ -65,7 +68,7 @@ export async function faithfulnessScore< schema: LongFormResponseSchema, }, }); - const parsedLongFormResponse = longFormResponse.output(); + const parsedLongFormResponse = longFormResponse.output; let statements = parsedLongFormResponse?.statements ?? []; if (statements.length === 0) { throw new Error('No statements returned'); @@ -73,9 +76,10 @@ export async function faithfulnessScore< const allStatements = statements.map((s) => `statement: ${s}`).join('\n'); const allContext = context.join('\n'); const nliPrompt = await loadPromptFile( + ai.registry, path.resolve(getDirName(), '../../prompts/faithfulness_nli.prompt') ); - const response = await generate({ + const response = await ai.generate({ model: judgeLlm, prompt: nliPrompt.renderText({ context: allContext, @@ -85,7 +89,7 @@ export async function faithfulnessScore< schema: NliResponseSchema, }, }); - const parsedResponse = response.output(); + const parsedResponse = response.output; return nliResponseToScore(parsedResponse); } catch (err) { console.debug( diff --git a/js/plugins/evaluators/src/metrics/maliciousness.ts b/js/plugins/evaluators/src/metrics/maliciousness.ts index ad166d2b2..5538cbc25 100644 --- a/js/plugins/evaluators/src/metrics/maliciousness.ts +++ b/js/plugins/evaluators/src/metrics/maliciousness.ts @@ -13,7 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -import { generate, loadPromptFile, ModelArgument, z } from 'genkit'; +import { loadPromptFile } from '@genkit-ai/dotprompt'; +import { Genkit, ModelArgument, z } from 'genkit'; import { BaseEvalDataPoint, Score } from 'genkit/evaluator'; import path from 'path'; import { getDirName } from './helper.js'; @@ -26,6 +27,7 @@ const MaliciousnessResponseSchema = z.object({ export async function maliciousnessScore< CustomModelOptions extends z.ZodTypeAny, >( + ai: Genkit, judgeLlm: ModelArgument, dataPoint: BaseEvalDataPoint, judgeConfig?: CustomModelOptions @@ -37,10 +39,11 @@ export async function maliciousnessScore< } const prompt = await loadPromptFile( + ai.registry, path.resolve(getDirName(), '../../prompts/maliciousness.prompt') ); //TODO: safetySettings are gemini specific - pull these out so they are tied to the LLM - const response = await generate({ + const response = await ai.generate({ model: judgeLlm, config: judgeConfig, prompt: prompt.renderText({ @@ -51,9 +54,9 @@ export async function maliciousnessScore< schema: MaliciousnessResponseSchema, }, }); - const parsedResponse = response.output(); + const parsedResponse = response.output; if (!parsedResponse) { - throw new Error(`Unable to parse evaluator response: ${response.text()}`); + throw new Error(`Unable to parse evaluator response: ${response.text}`); } return { score: 1.0 * (parsedResponse.verdict ? 1 : 0), diff --git a/js/plugins/firebase/jest.config.ts b/js/plugins/firebase/jest.config.ts index cd67fb165..ec7188b86 100644 --- a/js/plugins/firebase/jest.config.ts +++ b/js/plugins/firebase/jest.config.ts @@ -36,7 +36,7 @@ const config: Config = { // A map from regular expressions to paths to transformers transform: { - '^.+\\.[jt]s$': 'ts-jest', + '^.+\\.ts$': 'ts-jest', }, // An array of regexp pattern strings that are matched against all source file paths, matched files will skip transformation diff --git a/js/plugins/firebase/package.json b/js/plugins/firebase/package.json index 0dce7761d..cfce903c1 100644 --- a/js/plugins/firebase/package.json +++ b/js/plugins/firebase/package.json @@ -13,7 +13,7 @@ "genai", "generative-ai" ], - "version": "0.6.0-dev.2", + "version": "0.9.0-dev.1", "type": "commonjs", "scripts": { "check": "tsc", @@ -37,8 +37,8 @@ }, "peerDependencies": { "@google-cloud/firestore": "^7.6.0", - "firebase-admin": "^12.2.0", - "firebase-functions": "^4.8.0 || ^5.0.0", + "firebase-admin": ">=12.2", + "firebase-functions": ">=4.8", "genkit": "workspace:*" }, "devDependencies": { diff --git a/js/plugins/firebase/src/firestoreRetriever.ts b/js/plugins/firebase/src/firestoreRetriever.ts index cffbc9810..c0a6e0f35 100644 --- a/js/plugins/firebase/src/firestoreRetriever.ts +++ b/js/plugins/firebase/src/firestoreRetriever.ts @@ -20,8 +20,8 @@ import { QueryDocumentSnapshot, VectorQuerySnapshot, } from '@google-cloud/firestore'; -import { EmbedderArgument, RetrieverAction, embed, z } from 'genkit'; -import { DocumentData, Part, defineRetriever } from 'genkit/retriever'; +import { EmbedderArgument, Genkit, RetrieverAction, z } from 'genkit'; +import { DocumentData, Part } from 'genkit/retriever'; function toContent( d: QueryDocumentSnapshot, @@ -69,32 +69,35 @@ function toDocuments( * You must create a vector index on the associated field before you can perform nearest-neighbor * search. **/ -export function defineFirestoreRetriever(config: { - /** The name of the retriever. */ - name: string; - /** Optional label for display in Developer UI. 
*/ - label?: string; - /** The Firestore database instance from which to query. */ - firestore: Firestore; - /** The name of the collection from which to query. */ - collection?: string; - /** The embedder to use with this retriever. */ - embedder: EmbedderArgument; - /** The name of the field within the collection containing the vector data. */ - vectorField: string; - /** The name of the field containing the document content you wish to return. */ - contentField: string | ((snap: QueryDocumentSnapshot) => Part[]); - /** The distance measure to use when comparing vectors. Defaults to 'COSINE'. */ - distanceMeasure?: 'EUCLIDEAN' | 'COSINE' | 'DOT_PRODUCT'; - /** - * A list of fields to include in the returned document metadata. If not supplied, all fields other - * than the vector are included. Alternatively, provide a transform function to extract the desired - * metadata fields from a snapshot. - **/ - metadataFields?: - | string[] - | ((snap: QueryDocumentSnapshot) => Record); -}): RetrieverAction { +export function defineFirestoreRetriever( + ai: Genkit, + config: { + /** The name of the retriever. */ + name: string; + /** Optional label for display in Developer UI. */ + label?: string; + /** The Firestore database instance from which to query. */ + firestore: Firestore; + /** The name of the collection from which to query. */ + collection?: string; + /** The embedder to use with this retriever. */ + embedder: EmbedderArgument; + /** The name of the field within the collection containing the vector data. */ + vectorField: string; + /** The name of the field containing the document content you wish to return. */ + contentField: string | ((snap: QueryDocumentSnapshot) => Part[]); + /** The distance measure to use when comparing vectors. Defaults to 'COSINE'. */ + distanceMeasure?: 'EUCLIDEAN' | 'COSINE' | 'DOT_PRODUCT'; + /** + * A list of fields to include in the returned document metadata. If not supplied, all fields other + * than the vector are included. Alternatively, provide a transform function to extract the desired + * metadata fields from a snapshot. + **/ + metadataFields?: + | string[] + | ((snap: QueryDocumentSnapshot) => Record); + } +): RetrieverAction { const { name, label, @@ -106,7 +109,7 @@ export function defineFirestoreRetriever(config: { contentField, distanceMeasure, } = config; - return defineRetriever( + return ai.defineRetriever( { name, info: { @@ -120,7 +123,7 @@ export function defineFirestoreRetriever(config: { }), }, async (input, options) => { - const embedding = await embed({ embedder, content: input }); + const embedding = await ai.embed({ embedder, content: input }); if (!options.collection && !collection) { throw new Error( 'Must specify a collection to query in Firestore retriever.' 
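For illustration only, not part of the patch: a minimal sketch of the updated defineFirestoreRetriever call, which now takes the Genkit instance as its first argument and registers the retriever on it. The retriever name, collection, field names, and embedder below are placeholders, and the import of defineFirestoreRetriever from '@genkit-ai/firebase' is assumed.

    import { Firestore } from '@google-cloud/firestore';
    import { genkit, type EmbedderArgument } from 'genkit';
    import { defineFirestoreRetriever } from '@genkit-ai/firebase';

    // Placeholder: any embedder configured on the Genkit instance.
    declare const menuEmbedder: EmbedderArgument;

    const ai = genkit({ plugins: [] });

    // The retriever is registered via ai.defineRetriever under the hood.
    const menuRetriever = defineFirestoreRetriever(ai, {
      name: 'menuRetriever',        // placeholder retriever name
      firestore: new Firestore(),
      collection: 'menuItems',      // placeholder collection
      contentField: 'description',  // field holding the document text
      vectorField: 'embedding',     // field holding the vector data
      embedder: menuEmbedder,
    });

    // Retrieval itself is unchanged and goes through the same instance.
    const docs = await ai.retrieve({ retriever: menuRetriever, query: 'vegetarian options' });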
diff --git a/js/plugins/firebase/src/functions.ts b/js/plugins/firebase/src/functions.ts index 1e0e64870..89248274f 100644 --- a/js/plugins/firebase/src/functions.ts +++ b/js/plugins/firebase/src/functions.ts @@ -131,7 +131,7 @@ function wrapHttpsFlow< } await config.authPolicy.provider(req, res, () => - flow.expressHandler(genkit.registry, req, res) + flow.expressHandler(req, res) ); } ); diff --git a/js/plugins/google-cloud/jest.config.ts b/js/plugins/google-cloud/jest.config.ts new file mode 100644 index 000000000..adf4d1ffb --- /dev/null +++ b/js/plugins/google-cloud/jest.config.ts @@ -0,0 +1,48 @@ +/** + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * For a detailed explanation regarding each configuration property, visit: + * https://jestjs.io/docs/configuration + */ + +import type { Config } from 'jest'; + +const config: Config = { + // Automatically clear mock calls, instances, contexts and results before every test + clearMocks: true, + + // A preset that is used as a base for Jest's configuration + preset: 'ts-jest', + + // The glob patterns Jest uses to detect test files + testMatch: ['**/tests/**/*_test.ts'], + + // An array of regexp pattern strings that are matched against all test paths, matched tests are skipped + testPathIgnorePatterns: ['/node_modules/'], + + // A map from regular expressions to paths to transformers + transform: {}, // disabled for ESM + + // An array of regexp pattern strings that are matched against all source file paths, matched files will skip transformation + transformIgnorePatterns: ['/node_modules/'], + + moduleNameMapper: { + '^(\\.{1,2}/.*)\\.js$': '$1', + }, +}; + +export default config; diff --git a/js/plugins/google-cloud/package.json b/js/plugins/google-cloud/package.json index 711d4178a..419d5d21a 100644 --- a/js/plugins/google-cloud/package.json +++ b/js/plugins/google-cloud/package.json @@ -13,7 +13,7 @@ "genai", "generative-ai" ], - "version": "0.6.0-dev.2", + "version": "0.9.0-dev.1", "type": "commonjs", "scripts": { "check": "tsc", @@ -21,7 +21,7 @@ "build:clean": "rimraf ./lib", "build": "npm-run-all build:clean check compile", "build:watch": "tsup-node --watch", - "test": "tsx --test ./tests/*_test.ts" + "test": "node --experimental-vm-modules node_modules/jest/bin/jest --runInBand --verbose" }, "repository": { "type": "git", @@ -47,17 +47,19 @@ "@opentelemetry/sdk-trace-base": "^1.25.0", "google-auth-library": "^9.6.3", "node-fetch": "^3.3.2", - "prettier-plugin-organize-imports": "^3.2.4", "winston": "^3.12.0" }, "peerDependencies": { - "genkit": "workspace:*", - "@genkit-ai/core": "workspace:*" + "@genkit-ai/core": "workspace:*", + "genkit": "workspace:*" }, "devDependencies": { + "@jest/globals": "^29.7.0", "@types/node": "^20.11.16", + "jest": "^29.7.0", "npm-run-all": "^4.1.5", "rimraf": "^6.0.1", + "ts-jest": "^29.1.2", "tsup": "^8.0.2", "tsx": "^4.7.0", "typescript": "^4.9.0" diff --git a/js/plugins/google-cloud/src/auth.ts b/js/plugins/google-cloud/src/auth.ts index 
d100437e2..a75c94fa0 100644 --- a/js/plugins/google-cloud/src/auth.ts +++ b/js/plugins/google-cloud/src/auth.ts @@ -14,12 +14,22 @@ * limitations under the License. */ import { logger } from 'genkit/logging'; -import { GoogleAuth } from 'google-auth-library'; -import { GcpTelemetryConfig } from './types'; +import { auth, GoogleAuth } from 'google-auth-library'; +import { GcpPrincipal, GcpTelemetryConfig } from './types'; /** - * Allow customers to pass in cloud credentials from environment variables - * following: https://github.com/googleapis/google-auth-library-nodejs?tab=readme-ov-file#loading-credentials-from-environment-variables + * Allows Google Cloud credentials to be passed in "raw" as an environment + * variable. This is helpful in environments where the developer has limited + * ability to configure their compute environment, but does have the ability to + * set environment variables. + * + * This is different from the GOOGLE_APPLICATION_CREDENTIALS used by ADC, which + * represents a path to a credential file on disk. In *most* cases, even for + * 3rd party cloud providers, developers *should* attempt to use ADC, which + * searches for credential files in standard locations, before using this + * method. + * + * See also: https://github.com/googleapis/google-auth-library-nodejs?tab=readme-ov-file#loading-credentials-from-environment-variables */ export async function credentialsFromEnvironment(): Promise< Partial<GcpTelemetryConfig> @@ -47,3 +57,27 @@ } return options; } + +/** + * Resolve the currently configured principal, either from the Genkit specific + * GCLOUD_SERVICE_ACCOUNT_CREDS environment variable, or from ADC. + * + * Since the Google Cloud Telemetry Exporter will discover credentials on its + * own, we don't immediately have access to the current principal. This method + * can be handy to get access to the current credential for logging debugging + * information or other purposes. + **/ +export async function resolveCurrentPrincipal(): Promise<GcpPrincipal> { + const envCredentials = await credentialsFromEnvironment(); + const adcCredentials = await auth.getCredentials(); + + // TODO(michaeldoyle): How to look up if the user provided credentials in the + // plugin config (i.e. GcpTelemetryOptions) + let serviceAccountEmail = + envCredentials.credentials?.client_email ?? adcCredentials.client_email; + + return { + projectId: envCredentials.projectId, + serviceAccountEmail, + }; +} diff --git a/js/plugins/google-cloud/src/gcpLogger.ts b/js/plugins/google-cloud/src/gcpLogger.ts index 9f8ffac70..4a2bd7e17 100644 --- a/js/plugins/google-cloud/src/gcpLogger.ts +++ b/js/plugins/google-cloud/src/gcpLogger.ts @@ -15,8 +15,10 @@ */ import { LoggingWinston } from '@google-cloud/logging-winston'; +import { logger } from 'genkit/logging'; import { Writable } from 'stream'; import { GcpTelemetryConfig } from './types'; +import { loggingDenied, loggingDeniedHelpText } from './utils'; /** * Additional streams for writing log data to. Useful for unit testing.
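For illustration only, not part of the patch: a rough sketch of the credential flow the new resolveCurrentPrincipal helper in auth.ts supports. The JSON field values are placeholders, and the relative import path assumes usage from within the plugin source.

    import { resolveCurrentPrincipal } from './auth.js';

    // Service-account credentials may be supplied "raw" through the Genkit-specific
    // environment variable, instead of a key-file path in GOOGLE_APPLICATION_CREDENTIALS.
    process.env.GCLOUD_SERVICE_ACCOUNT_CREDS = JSON.stringify({
      client_email: 'telemetry-sa@my-project.iam.gserviceaccount.com', // placeholder
      private_key: '-----BEGIN PRIVATE KEY-----\n...',                 // placeholder
      project_id: 'my-project',                                        // placeholder
    });

    // Env credentials take precedence; otherwise the principal is read from ADC.
    // Handy when logging why a telemetry or log export was denied.
    const { projectId, serviceAccountEmail } = await resolveCurrentPrincipal();
    console.log(`Exporting telemetry as ${serviceAccountEmail} (project: ${projectId})`);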
@@ -54,6 +56,7 @@ export class GcpLogger { prefix: 'genkit', logName: 'genkit_log', credentials: this.config.credentials, + defaultCallback: await this.getErrorHandler(), }) : new winston.transports.Console() ); @@ -68,6 +71,28 @@ export class GcpLogger { }); } + private async getErrorHandler(): Promise<(err: Error | null) => void> { + // only log the first time + let instructionsLogged = false; + let helpInstructions = await loggingDeniedHelpText(); + + return (err: Error | null) => { + // Use the defaultLogger so that logs don't get swallowed by + // the open telemetry exporter + const defaultLogger = logger.defaultLogger; + if (err && loggingDenied(err)) { + if (!instructionsLogged) { + instructionsLogged = true; + defaultLogger.error( + `Unable to send logs to Google Cloud: ${err.message}\n\n${helpInstructions}\n` + ); + } + } else if (err) { + defaultLogger.error(`Unable to send logs to Google Cloud: ${err}`); + } + }; + } + private shouldExport(env?: string) { return this.config.export; } diff --git a/js/plugins/google-cloud/src/gcpOpenTelemetry.ts b/js/plugins/google-cloud/src/gcpOpenTelemetry.ts index 202a20620..428ba720e 100644 --- a/js/plugins/google-cloud/src/gcpOpenTelemetry.ts +++ b/js/plugins/google-cloud/src/gcpOpenTelemetry.ts @@ -14,7 +14,11 @@ * limitations under the License. */ -import { MetricExporter } from '@google-cloud/opentelemetry-cloud-monitoring-exporter'; +import { logger } from '@genkit-ai/core/logging'; +import { + ExporterOptions, + MetricExporter, +} from '@google-cloud/opentelemetry-cloud-monitoring-exporter'; import { TraceExporter } from '@google-cloud/opentelemetry-cloud-trace-exporter'; import { GcpDetectorSync } from '@google-cloud/opentelemetry-resource-util'; import { Span, SpanStatusCode, TraceFlags } from '@opentelemetry/api'; @@ -36,6 +40,7 @@ import { InstrumentType, PeriodicExportingMetricReader, PushMetricExporter, + ResourceMetrics, } from '@opentelemetry/sdk-metrics'; import { NodeSDKConfiguration } from '@opentelemetry/sdk-node'; import { @@ -52,7 +57,13 @@ import { featuresTelemetry } from './telemetry/feature.js'; import { generateTelemetry } from './telemetry/generate.js'; import { pathsTelemetry } from './telemetry/path.js'; import { GcpTelemetryConfig } from './types'; -import { extractErrorName } from './utils'; +import { + extractErrorName, + metricsDenied, + metricsDeniedHelpText, + tracingDenied, + tracingDeniedHelpText, +} from './utils'; let metricExporter: PushMetricExporter; let spanProcessor: BatchSpanProcessor; @@ -88,26 +99,34 @@ export class GcpOpenTelemetry { record['logging.googleapis.com/spanId'] ??= spanContext.spanId; }; - getConfig(): Partial { - spanProcessor = new BatchSpanProcessor(this.createSpanExporter()); + async getConfig(): Promise> { + spanProcessor = new BatchSpanProcessor(await this.createSpanExporter()); return { resource: this.resource, spanProcessor: spanProcessor, sampler: this.config.sampler, instrumentations: this.getInstrumentations(), - metricReader: this.createMetricReader(), + metricReader: await this.createMetricReader(), }; } - private createSpanExporter(): SpanExporter { + private async createSpanExporter(): Promise { spanExporter = new AdjustingTraceExporter( this.shouldExportTraces() ? 
new TraceExporter({ + // Creds for non-GCP environments; otherwise credentials will be + // automatically detected via ADC credentials: this.config.credentials, }) : new InMemorySpanExporter(), this.config.exportIO, - this.config.projectId + this.config.projectId, + getErrorHandler( + (err) => { + return tracingDenied(err); + }, + await tracingDeniedHelpText() + ) ); return spanExporter; } @@ -115,8 +134,8 @@ export class GcpOpenTelemetry { /** * Creates a {MetricReader} for pushing metrics out to GCP via OpenTelemetry. */ - private createMetricReader(): PeriodicExportingMetricReader { - metricExporter = this.buildMetricExporter(); + private async createMetricReader(): Promise<PeriodicExportingMetricReader> { + metricExporter = await this.buildMetricExporter(); return new PeriodicExportingMetricReader({ exportIntervalMillis: this.config.metricExportIntervalMillis, exportTimeoutMillis: this.config.metricExportTimeoutMillis, @@ -150,15 +169,25 @@ export class GcpOpenTelemetry { ]; } - private buildMetricExporter(): PushMetricExporter { + private async buildMetricExporter(): Promise<PushMetricExporter> { const exporter: PushMetricExporter = this.shouldExportMetrics() - ? new MetricExporter({ - userAgent: { - product: 'genkit', - version: GENKIT_VERSION, + ? new MetricExporterWrapper( + { + userAgent: { + product: 'genkit', + version: GENKIT_VERSION, + }, + // Creds for non-GCP environments; otherwise credentials will be + // automatically detected via ADC + credentials: this.config.credentials, }, - credentials: this.config.credentials, - }) + getErrorHandler( + (err) => { + return metricsDenied(err); + }, + await metricsDeniedHelpText() + ) + ) : new InMemoryMetricExporter(AggregationTemporality.DELTA); exporter.selectAggregation = (instrumentType: InstrumentType) => { if (instrumentType === InstrumentType.HISTOGRAM) { @@ -175,6 +204,31 @@ export class GcpOpenTelemetry { } } +/** + * Rewrites the export method to include an error handler which logs + * helpful information about how to set up metrics/telemetry in GCP. + */ +class MetricExporterWrapper extends MetricExporter { + constructor( + private options?: ExporterOptions, + private errorHandler?: (error: Error) => void + ) { + super(options); + } + + export( + metrics: ResourceMetrics, + resultCallback: (result: ExportResult) => void + ): void { + super.export(metrics, (result) => { + if (this.errorHandler && result.error) { + this.errorHandler(result.error); + } + resultCallback(result); + }); + } +} + /** * Adjusts spans before exporting to GCP. Redacts model input * and output content, and augments span attributes before sending to GCP.
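The MetricExporterWrapper added above and the AdjustingTraceExporter change that follows share one technique: wrap the exporter's export() callback so a failed ExportResult can be surfaced before being forwarded unchanged. A minimal, self-contained sketch of that pattern (the helper name and the simplified ExportResult shape here are illustrative, not part of the patch):

// Generic callback-wrapping helper; the patch inlines this per exporter.
type ExportResult = { code: number; error?: Error };

function wrapExport<T>(
  exportFn: (data: T, done: (result: ExportResult) => void) => void,
  onError: (err: Error) => void
): (data: T, done: (result: ExportResult) => void) => void {
  return (data, done) =>
    exportFn(data, (result) => {
      // Report the failure (e.g. a one-time IAM permissions hint)...
      if (result.error) onError(result.error);
      // ...then always forward the original result so the OpenTelemetry
      // SDK's own retry and bookkeeping logic is unaffected.
      done(result);
    });
}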
@@ -183,14 +237,20 @@ class AdjustingTraceExporter implements SpanExporter { constructor( private exporter: SpanExporter, private logIO: boolean, - private projectId?: string + private projectId?: string, + private errorHandler?: (error: Error) => void ) {} export( spans: ReadableSpan[], resultCallback: (result: ExportResult) => void ): void { - this.exporter?.export(this.adjust(spans), resultCallback); + this.exporter?.export(this.adjust(spans), (result) => { + if (this.errorHandler && result.error) { + this.errorHandler(result.error); + } + resultCallback(result); + }); } shutdown(): Promise { @@ -362,6 +422,30 @@ class AdjustingTraceExporter implements SpanExporter { } } +function getErrorHandler( + shouldLogFn: (err: Error) => boolean, + helpText: string +): (err: Error) => void { + // only log the first time + let instructionsLogged = false; + + return (err) => { + // Use the defaultLogger so that logs don't get swallowed by the open + // telemetry exporter + const defaultLogger = logger.defaultLogger; + if (err && shouldLogFn(err)) { + if (!instructionsLogged) { + instructionsLogged = true; + defaultLogger.error( + `Unable to send telemetry to Google Cloud: ${err.message}\n\n${helpText}\n` + ); + } + } else if (err) { + defaultLogger.error(`Unable to send telemetry to Google Cloud: ${err}`); + } + }; +} + export function __getMetricExporterForTesting(): InMemoryMetricExporter { return metricExporter as InMemoryMetricExporter; } diff --git a/js/plugins/google-cloud/src/index.ts b/js/plugins/google-cloud/src/index.ts index bfeae189a..8112d6ff4 100644 --- a/js/plugins/google-cloud/src/index.ts +++ b/js/plugins/google-cloud/src/index.ts @@ -28,7 +28,7 @@ export async function enableGoogleCloudTelemetry( ) { const pluginConfig = await configureGcpPlugin(options); - enableTelemetry(new GcpOpenTelemetry(pluginConfig).getConfig()); + enableTelemetry(await new GcpOpenTelemetry(pluginConfig).getConfig()); logger.init(await new GcpLogger(pluginConfig).getLogger(getCurrentEnv())); } diff --git a/js/plugins/google-cloud/src/telemetry/action.ts b/js/plugins/google-cloud/src/telemetry/action.ts index 9eca330bd..6b2c0b131 100644 --- a/js/plugins/google-cloud/src/telemetry/action.ts +++ b/js/plugins/google-cloud/src/telemetry/action.ts @@ -55,8 +55,7 @@ class ActionTelemetry implements Telemetry { const actionName = (attributes['genkit:name'] as string) || ''; const path = (attributes['genkit:path'] as string) || ''; - let featureName = (attributes['genkit:metadata:flow:name'] || - extractOuterFeatureNameFromPath(path)) as string; + let featureName = extractOuterFeatureNameFromPath(path); if (!featureName || featureName === '') { featureName = actionName; } @@ -68,13 +67,11 @@ class ActionTelemetry implements Telemetry { if (state === 'success') { this.writeSuccess(actionName, featureName, path, latencyMs); - return; - } - if (state === 'error') { + } else if (state === 'error') { this.writeFailure(actionName, featureName, path, latencyMs, errorName); + } else { + logger.warn(`Unknown action state; ${state}`); } - - logger.warn(`Unknown action state; ${state}`); } private writeSuccess( diff --git a/js/plugins/google-cloud/src/types.ts b/js/plugins/google-cloud/src/types.ts index 94684f084..961c8637c 100644 --- a/js/plugins/google-cloud/src/types.ts +++ b/js/plugins/google-cloud/src/types.ts @@ -70,3 +70,8 @@ export interface GcpTelemetryConfig { exportIO: boolean; export: boolean; } + +export interface GcpPrincipal { + projectId?: string; + serviceAccountEmail?: string; +} diff --git 
a/js/plugins/google-cloud/src/utils.ts b/js/plugins/google-cloud/src/utils.ts index 075370abd..2d9c85f77 100644 --- a/js/plugins/google-cloud/src/utils.ts +++ b/js/plugins/google-cloud/src/utils.ts @@ -16,6 +16,7 @@ import { TraceFlags } from '@opentelemetry/api'; import { ReadableSpan, TimedEvent } from '@opentelemetry/sdk-trace-base'; +import { resolveCurrentPrincipal } from './auth'; export function extractOuterFlowNameFromPath(path: string) { if (!path || path === '') { @@ -88,3 +89,63 @@ export function createCommonLogAttributes( 'logging.googleapis.com/trace_sampled': isSampled ? '1' : '0', }; } + +export function requestDenied( + err: Error & { + code?: number; + statusDetails?: Record<string, any>[]; + } +) { + return err.code === 7; +} + +export function loggingDenied( + err: Error & { + code?: number; + statusDetails?: Record<string, any>[]; + } +) { + return ( + requestDenied(err) && + err.statusDetails?.some((details) => { + return details?.metadata?.permission === 'logging.logEntries.create'; + }) + ); +} + +export function tracingDenied( + err: Error & { + code?: number; + statusDetails?: Record<string, any>[]; + } +) { + // Looks like we don't get status details like we do with logging + return requestDenied(err); +} + +export function metricsDenied( + err: Error & { + code?: number; + statusDetails?: Record<string, any>[]; + } +) { + // Looks like we don't get status details like we do with logging + return requestDenied(err); +} + +export async function permissionDeniedHelpText(role: string) { + const principal = await resolveCurrentPrincipal(); + return `Add the role '${role}' to your Service Account in the IAM & Admin page on the Google Cloud console, or use the following command:\n\ngcloud projects add-iam-policy-binding ${principal.projectId ?? '${PROJECT_ID}'} \\\n --member=serviceAccount:${principal.serviceAccountEmail || '${SERVICE_ACCT}'} \\\n --role=${role}`; +} + +export async function loggingDeniedHelpText() { + return permissionDeniedHelpText('roles/logging.logWriter'); +} + +export async function tracingDeniedHelpText() { + return permissionDeniedHelpText('roles/cloudtrace.agent'); +} + +export async function metricsDeniedHelpText() { + return permissionDeniedHelpText('roles/monitoring.metricWriter'); +} diff --git a/js/plugins/google-cloud/tests/logs_no_io_test.ts b/js/plugins/google-cloud/tests/logs_no_io_test.ts index c50648c3f..a7111a610 100644 --- a/js/plugins/google-cloud/tests/logs_no_io_test.ts +++ b/js/plugins/google-cloud/tests/logs_no_io_test.ts @@ -14,12 +14,17 @@ * limitations under the License.
*/ +import { + afterAll, + beforeAll, + beforeEach, + describe, + it, + jest, +} from '@jest/globals'; import { ReadableSpan } from '@opentelemetry/sdk-trace-base'; -import { generate, Genkit, genkit, run, z } from 'genkit'; -import { defineModel, GenerateResponseData } from 'genkit/model'; -import { runWithRegistry } from 'genkit/registry'; +import { GenerateResponseData, Genkit, genkit, run, z } from 'genkit'; import assert from 'node:assert'; -import { after, before, beforeEach, describe, it } from 'node:test'; import { Writable } from 'stream'; import { __addTransportStreamForTesting, @@ -28,6 +33,28 @@ import { enableGoogleCloudTelemetry, } from '../src/index.js'; +jest.mock('../src/auth.js', () => { + const original = jest.requireActual('../src/auth.js'); + return { + ...(original || {}), + resolveCurrentPrincipal: jest.fn().mockImplementation(() => { + return Promise.resolve({ + projectId: 'test', + serviceAccountEmail: 'test@test.com', + }); + }), + credentialsFromEnvironment: jest.fn().mockImplementation(() => { + return Promise.resolve({ + projectId: 'test', + credentials: { + client_email: 'test@genkit.com', + private_key: '-----BEGIN PRIVATE KEY-----', + }, + }); + }), + }; +}); + describe('GoogleCloudLogs no I/O', () => { let logLines = ''; const logStream = new Writable(); @@ -38,7 +65,8 @@ describe('GoogleCloudLogs no I/O', () => { let ai: Genkit; - before(async () => { + beforeAll(async () => { + process.env.GCLOUD_PROJECT = 'test'; process.env.GENKIT_ENV = 'dev'; __addTransportStreamForTesting(logStream); await enableGoogleCloudTelemetry({ @@ -56,7 +84,7 @@ describe('GoogleCloudLogs no I/O', () => { logLines = ''; __getSpanExporterForTesting().reset(); }); - after(async () => { + afterAll(async () => { await ai.stopServers(); }); @@ -91,7 +119,7 @@ describe('GoogleCloudLogs no I/O', () => { ), true ); - }); + }, 10000); //timeout it('writes generate logs', async () => { const testModel = createModel(ai, 'testModel', async () => { @@ -118,7 +146,7 @@ describe('GoogleCloudLogs no I/O', () => { const testFlow = createFlowWithInput(ai, 'testFlow', async (input) => { return await run('sub1', async () => { return await run('sub2', async () => { - return await generate({ + return await ai.generate({ model: testModel, prompt: `${input} prompt`, config: { @@ -206,9 +234,7 @@ function createModel( name: string, respFn: () => Promise ) { - return runWithRegistry(genkit.registry, () => - defineModel({ name }, (req) => respFn()) - ); + return genkit.defineModel({ name }, (req) => respFn()); } async function waitForLogsInit(genkit: Genkit, logLines: any) { diff --git a/js/plugins/google-cloud/tests/logs_test.ts b/js/plugins/google-cloud/tests/logs_test.ts index e34d79452..ab07c8f6d 100644 --- a/js/plugins/google-cloud/tests/logs_test.ts +++ b/js/plugins/google-cloud/tests/logs_test.ts @@ -14,13 +14,18 @@ * limitations under the License. 
*/ +import { + afterAll, + beforeAll, + beforeEach, + describe, + it, + jest, +} from '@jest/globals'; import { ReadableSpan } from '@opentelemetry/sdk-trace-base'; -import { generate, Genkit, genkit, run, z } from 'genkit'; -import { defineModel, GenerateResponseData } from 'genkit/model'; -import { runWithRegistry } from 'genkit/registry'; -import { appendSpan, SPAN_TYPE_ATTR } from 'genkit/tracing'; +import { GenerateResponseData, Genkit, genkit, run, z } from 'genkit'; +import { SPAN_TYPE_ATTR, appendSpan } from 'genkit/tracing'; import assert from 'node:assert'; -import { after, before, beforeEach, describe, it } from 'node:test'; import { Writable } from 'stream'; import { __addTransportStreamForTesting, @@ -29,6 +34,28 @@ import { enableGoogleCloudTelemetry, } from '../src/index.js'; +jest.mock('../src/auth.js', () => { + const original = jest.requireActual('../src/auth.js'); + return { + ...(original || {}), + resolveCurrentPrincipal: jest.fn().mockImplementation(() => { + return Promise.resolve({ + projectId: 'test', + serviceAccountEmail: 'test@test.com', + }); + }), + credentialsFromEnvironment: jest.fn().mockImplementation(() => { + return Promise.resolve({ + projectId: 'test', + credentials: { + client_email: 'test@genkit.com', + private_key: '-----BEGIN PRIVATE KEY-----', + }, + }); + }), + }; +}); + describe('GoogleCloudLogs', () => { let logLines = ''; const logStream = new Writable(); @@ -39,7 +66,8 @@ describe('GoogleCloudLogs', () => { let ai: Genkit; - before(async () => { + beforeAll(async () => { + process.env.GCLOUD_PROJECT = 'test'; process.env.GENKIT_ENV = 'dev'; __addTransportStreamForTesting(logStream); await enableGoogleCloudTelemetry({ @@ -58,7 +86,7 @@ describe('GoogleCloudLogs', () => { logLines = ''; __getSpanExporterForTesting().reset(); }); - after(async () => { + afterAll(async () => { await ai.stopServers(); }); @@ -94,7 +122,7 @@ describe('GoogleCloudLogs', () => { ), true ); - }); + }, 10000); //timeout it('writes generate logs', async () => { const testModel = createModel(ai, 'testModel', async () => { @@ -121,7 +149,7 @@ describe('GoogleCloudLogs', () => { const testFlow = createFlowWithInput(ai, 'testFlow', async (input) => { return await run('sub1', async () => { return await run('sub2', async () => { - return await generate({ + return await ai.generate({ model: testModel, prompt: `${input} prompt`, config: { @@ -247,9 +275,7 @@ function createModel( name: string, respFn: () => Promise ) { - return runWithRegistry(genkit.registry, () => - defineModel({ name }, (req) => respFn()) - ); + return genkit.defineModel({ name }, (req) => respFn()); } async function waitForLogsInit(genkit: Genkit, logLines: any) { diff --git a/js/plugins/google-cloud/tests/metrics_test.ts b/js/plugins/google-cloud/tests/metrics_test.ts index d00c061d6..047cbe75b 100644 --- a/js/plugins/google-cloud/tests/metrics_test.ts +++ b/js/plugins/google-cloud/tests/metrics_test.ts @@ -15,12 +15,13 @@ */ import { - GcpOpenTelemetry, - __forceFlushSpansForTesting, - __getMetricExporterForTesting, - __getSpanExporterForTesting, - enableGoogleCloudTelemetry, -} from '@genkit-ai/google-cloud'; + afterAll, + beforeAll, + beforeEach, + describe, + it, + jest, +} from '@jest/globals'; import { DataPoint, Histogram, @@ -29,25 +30,44 @@ import { SumMetricData, } from '@opentelemetry/sdk-metrics'; import { ReadableSpan } from '@opentelemetry/sdk-trace-base'; -import { - GenerateResponseData, - Genkit, - defineAction, - generate, - genkit, - run, - z, -} from 'genkit'; -import { defineModel } 
from 'genkit/model'; -import { runWithRegistry } from 'genkit/registry'; +import { GenerateResponseData, Genkit, genkit, run, z } from 'genkit'; import { SPAN_TYPE_ATTR, appendSpan } from 'genkit/tracing'; import assert from 'node:assert'; -import { after, before, beforeEach, describe, it } from 'node:test'; +import { + GcpOpenTelemetry, + __forceFlushSpansForTesting, + __getMetricExporterForTesting, + __getSpanExporterForTesting, + enableGoogleCloudTelemetry, +} from '../src/index.js'; + +jest.mock('../src/auth.js', () => { + const original = jest.requireActual('../src/auth.js'); + return { + ...(original || {}), + resolveCurrentPrincipal: jest.fn().mockImplementation(() => { + return Promise.resolve({ + projectId: 'test', + serviceAccountEmail: 'test@test.com', + }); + }), + credentialsFromEnvironment: jest.fn().mockImplementation(() => { + return Promise.resolve({ + projectId: 'test', + credentials: { + client_email: 'test@genkit.com', + private_key: '-----BEGIN PRIVATE KEY-----', + }, + }); + }), + }; +}); describe('GoogleCloudMetrics', () => { let ai: Genkit; - before(async () => { + beforeAll(async () => { + process.env.GCLOUD_PROJECT = 'test'; process.env.GENKIT_ENV = 'dev'; await enableGoogleCloudTelemetry({ projectId: 'test', @@ -61,7 +81,7 @@ describe('GoogleCloudMetrics', () => { __getMetricExporterForTesting().reset(); __getSpanExporterForTesting().reset(); }); - after(async () => { + afterAll(async () => { await ai.stopServers(); }); @@ -89,7 +109,7 @@ describe('GoogleCloudMetrics', () => { assert.equal(actionLatencyHistogram.attributes.source, 'ts'); assert.equal(actionLatencyHistogram.attributes.status, 'success'); assert.ok(actionLatencyHistogram.attributes.sourceVersion); - }); + }, 10000); //timeout it('writes action metrics for a failing flow', async () => { const testFlow = createFlow(ai, 'testFlow', async () => { @@ -110,7 +130,7 @@ describe('GoogleCloudMetrics', () => { assert.equal(requestCounter.attributes.source, 'ts'); assert.equal(requestCounter.attributes.error, 'TypeError'); assert.equal(requestCounter.attributes.status, 'failure'); - }); + }, 10000); //timeout it('writes feature metrics for a successful flow', async () => { const testFlow = createFlow(ai, 'testFlow'); @@ -153,9 +173,10 @@ describe('GoogleCloudMetrics', () => { assert.equal(requestCounter.attributes.source, 'ts'); assert.equal(requestCounter.attributes.error, 'TypeError'); assert.equal(requestCounter.attributes.status, 'failure'); - }); + }, 10000); //timeout - it('writes action metrics for an action inside a flow', async () => { + // SKIPPED -- we don't allow defining arbitrary actions anymore.... 
+ it.skip('writes action metrics', async () => { const testAction = createAction(ai, 'testAction'); const testFlow = createFlow(ai, 'testFlowWithActions', async () => { await Promise.all([ @@ -184,20 +205,20 @@ describe('GoogleCloudMetrics', () => { assert.equal(requestCounter.attributes.source, 'ts'); assert.equal(requestCounter.attributes.status, 'success'); assert.ok(requestCounter.attributes.sourceVersion); + assert.equal(requestCounter.attributes.featureName, 'testFlowWithActions'); assert.equal(latencyHistogram.value.count, 6); assert.equal(latencyHistogram.attributes.name, 'testAction'); assert.equal(latencyHistogram.attributes.source, 'ts'); assert.equal(latencyHistogram.attributes.status, 'success'); assert.ok(latencyHistogram.attributes.sourceVersion); + assert.equal(requestCounter.attributes.featureName, 'testFlowWithActions'); }); it('writes feature metrics for an action', async () => { const testAction = createAction(ai, 'featureAction'); - await runWithRegistry(ai.registry, async () => { - await testAction(null); - await testAction(null); - }); + await testAction(null); + await testAction(null); await getExportedSpans(); @@ -219,11 +240,9 @@ describe('GoogleCloudMetrics', () => { // after PR #1029 it('writes feature metrics for generate', async () => { - await runWithRegistry(ai.registry, async () => { - const testModel = createTestModel(ai, 'helloModel'); - await generate({ model: testModel, prompt: 'Hi' }); - await generate({ model: testModel, prompt: 'Yo' }); - }); + const testModel = createTestModel(ai, 'helloModel'); + await ai.generate({ model: testModel, prompt: 'Hi' }); + await ai.generate({ model: testModel, prompt: 'Yo' }); const spans = await getExportedSpans(); @@ -260,7 +279,8 @@ describe('GoogleCloudMetrics', () => { ); }); - it('writes action metrics for a failed action', async () => { + // SKIPPED -- we don't allow defining arbitrary actions anymore.... + it.skip('writes action failure metrics', async () => { const testAction = createAction(ai, 'testActionWithFailure', async () => { const nothing: { missing?: any } = { missing: 1 }; delete nothing.missing; @@ -268,9 +288,7 @@ describe('GoogleCloudMetrics', () => { }); assert.rejects(async () => { - return await runWithRegistry(ai.registry, async () => { - return testAction(null); - }); + return testAction(null); }); await getExportedSpans(); @@ -280,7 +298,7 @@ describe('GoogleCloudMetrics', () => { assert.equal(requestCounter.attributes.source, 'ts'); assert.equal(requestCounter.attributes.status, 'failure'); assert.equal(requestCounter.attributes.error, 'TypeError'); - }); + }, 10000); //timeout it('writes generate metrics for a successful model action', async () => { const testModel = createTestModel(ai, 'testModel'); @@ -386,9 +404,10 @@ describe('GoogleCloudMetrics', () => { assert.equal(requestCounter.attributes.status, 'failure'); assert.equal(requestCounter.attributes.error, 'TypeError'); assert.ok(requestCounter.attributes.sourceVersion); - }); + }, 10000); //timeout - it('writes feature label to action metrics when running inside a flow', async () => { + // SKIPPED -- we don't allow defining arbitrary actions anymore.... 
+ it.skip('writes flow label to action metrics when running inside flow', async () => { const testAction = createAction(ai, 'testAction'); const flow = createFlow(ai, 'flowNameLabelTestFlow', async () => { return await testAction(undefined); @@ -420,9 +439,7 @@ describe('GoogleCloudMetrics', () => { }); }); - await runWithRegistry(ai.registry, async () => { - testAction(null); - }); + testAction(null); await getExportedSpans(); @@ -437,7 +454,7 @@ describe('GoogleCloudMetrics', () => { generateRequestCounter.attributes.featureName, 'testGenerateAction' ); - }); + }, 10000); //timeout it('writes feature label to generate metrics when running inside a flow', async () => { const testModel = createModel(ai, 'testModel', async () => { @@ -462,7 +479,7 @@ describe('GoogleCloudMetrics', () => { }; }); const flow = createFlow(ai, 'testFlow', async () => { - return await generate({ + return await ai.generate({ model: testModel, prompt: 'test prompt', }); @@ -568,7 +585,7 @@ describe('GoogleCloudMetrics', () => { ['/{testFlow,t:flow}/{sub-action,t:flowStep}', 'success'], ['/{testFlow,t:flow}', 'failure'], ]); - }); + }, 10000); //timeout it('writes path metrics for a failing flow with exception in subaction', async () => { const flow = createFlow(ai, 'testFlow', async () => { @@ -613,7 +630,7 @@ describe('GoogleCloudMetrics', () => { 'failure', ], ]); - }); + }, 10000); //timeout it('writes path metrics for a flow with exception in action', async () => { const flow = createFlow(ai, 'testFlow', async () => { @@ -660,7 +677,7 @@ describe('GoogleCloudMetrics', () => { ], ['/{testFlow,t:flow}/{sub-action-1,t:flowStep}', 'failure'], ]); - }); + }, 10000); //timeout it('writes path metrics for a flow with an exception in a serial action', async () => { const flow = createFlow(ai, 'testFlow', async () => { @@ -701,7 +718,7 @@ describe('GoogleCloudMetrics', () => { ['/{testFlow,t:flow}/{sub-action-1,t:flowStep}', 'success'], ['/{testFlow,t:flow}/{sub-action-2,t:flowStep}', 'failure'], ]); - }); + }, 10000); //timeout it('writes user feedback metrics', async () => { appendSpan( @@ -910,14 +927,11 @@ describe('GoogleCloudMetrics', () => { name: string, fn: () => Promise = async () => {} ) { - return runWithRegistry(ai.registry, () => - defineAction( - { - name, - actionType: 'custom', - }, - fn - ) + return ai.defineFlow( + { + name, + }, + fn ); } @@ -928,9 +942,7 @@ describe('GoogleCloudMetrics', () => { name: string, respFn: () => Promise ) { - return runWithRegistry(ai.registry, () => - defineModel({ name }, (req) => respFn()) - ); + return ai.defineModel({ name }, (req) => respFn()); } function createTestModel(ai: Genkit, name: string) { diff --git a/js/plugins/google-cloud/tests/traces_test.ts b/js/plugins/google-cloud/tests/traces_test.ts index a15bbfa33..470cb1251 100644 --- a/js/plugins/google-cloud/tests/traces_test.ts +++ b/js/plugins/google-cloud/tests/traces_test.ts @@ -14,23 +14,51 @@ * limitations under the License. 
*/ +import { + afterAll, + beforeAll, + beforeEach, + describe, + it, + jest, +} from '@jest/globals'; import { ReadableSpan } from '@opentelemetry/sdk-trace-base'; -import { generate, Genkit, genkit, run, z } from 'genkit'; -import { defineModel } from 'genkit/model'; -import { runWithRegistry } from 'genkit/registry'; +import { Genkit, genkit, run, z } from 'genkit'; import { appendSpan } from 'genkit/tracing'; import assert from 'node:assert'; -import { after, before, beforeEach, describe, it } from 'node:test'; import { __forceFlushSpansForTesting, __getSpanExporterForTesting, } from '../src/gcpOpenTelemetry.js'; import { enableGoogleCloudTelemetry } from '../src/index.js'; +jest.mock('../src/auth.js', () => { + const original = jest.requireActual('../src/auth.js'); + return { + ...(original || {}), + resolveCurrentPrincipal: jest.fn().mockImplementation(() => { + return Promise.resolve({ + projectId: 'test', + serviceAccountEmail: 'test@test.com', + }); + }), + credentialsFromEnvironment: jest.fn().mockImplementation(() => { + return Promise.resolve({ + projectId: 'test', + credentials: { + client_email: 'test@genkit.com', + private_key: '-----BEGIN PRIVATE KEY-----', + }, + }); + }), + }; +}); + describe('GoogleCloudTracing', () => { let ai: Genkit; - before(async () => { + beforeAll(async () => { + process.env.GCLOUD_PROJECT = 'test'; process.env.GENKIT_ENV = 'dev'; await enableGoogleCloudTelemetry({ projectId: 'test', @@ -41,7 +69,7 @@ describe('GoogleCloudTracing', () => { beforeEach(async () => { __getSpanExporterForTesting().reset(); }); - after(async () => { + afterAll(async () => { await ai.stopServers(); }); @@ -136,33 +164,31 @@ describe('GoogleCloudTracing', () => { }); it('adds the genkit/model label for model actions', async () => { - const echoModel = runWithRegistry(ai.registry, () => - defineModel( - { - name: 'echoModel', - }, - async (request) => { - return { - message: { - role: 'model', - content: [ - { - text: - 'Echo: ' + - request.messages - .map((m) => m.content.map((c) => c.text).join()) - .join(), - }, - ], - }, - finishReason: 'stop', - }; - } - ) + const echoModel = ai.defineModel( + { + name: 'echoModel', + }, + async (request) => { + return { + message: { + role: 'model', + content: [ + { + text: + 'Echo: ' + + request.messages + .map((m) => m.content.map((c) => c.text).join()) + .join(), + }, + ], + }, + finishReason: 'stop', + }; + } ); const testFlow = createFlow(ai, 'modelFlow', async () => { return run('runFlow', async () => { - generate({ + ai.generate({ model: echoModel, prompt: 'Testing model telemetry', }); diff --git a/js/plugins/googleai/package.json b/js/plugins/googleai/package.json index 46e56bc46..3ee9417f1 100644 --- a/js/plugins/googleai/package.json +++ b/js/plugins/googleai/package.json @@ -13,7 +13,7 @@ "genai", "generative-ai" ], - "version": "0.6.0-dev.2", + "version": "0.9.0-dev.1", "type": "commonjs", "scripts": { "check": "tsc", diff --git a/js/plugins/googleai/src/embedder.ts b/js/plugins/googleai/src/embedder.ts index 331a68ff4..9e0818488 100644 --- a/js/plugins/googleai/src/embedder.ts +++ b/js/plugins/googleai/src/embedder.ts @@ -15,8 +15,8 @@ */ import { EmbedContentRequest, GoogleGenerativeAI } from '@google/generative-ai'; -import { z } from 'genkit'; -import { defineEmbedder, embedderRef } from 'genkit/embedder'; +import { EmbedderReference, Genkit, z } from 'genkit'; +import { embedderRef } from 'genkit/embedder'; import { PluginOptions } from './index.js'; export const TaskTypeSchema = z.enum([ @@ -28,22 +28,21 @@ export 
const TaskTypeSchema = z.enum([ ]); export type TaskType = z.infer; -export const TextEmbeddingGeckoConfigSchema = z.object({ +export const GeminiEmbeddingConfigSchema = z.object({ /** * The `task_type` parameter is defined as the intended downstream application to help the model * produce better quality embeddings. **/ taskType: TaskTypeSchema.optional(), title: z.string().optional(), + version: z.string().optional(), }); -export type TextEmbeddingGeckoConfig = z.infer< - typeof TextEmbeddingGeckoConfigSchema ->; +export type GeminiEmbeddingConfig = z.infer; export const textEmbeddingGecko001 = embedderRef({ name: 'googleai/embedding-001', - configSchema: TextEmbeddingGeckoConfigSchema, + configSchema: GeminiEmbeddingConfigSchema, info: { dimensions: 768, label: 'Google Gen AI - Text Embedding Gecko (Legacy)', @@ -57,7 +56,8 @@ export const SUPPORTED_MODELS = { 'embedding-001': textEmbeddingGecko001, }; -export function textEmbeddingGeckoEmbedder( +export function defineGoogleAIEmbedder( + ai: Genkit, name: string, options: PluginOptions ) { @@ -70,17 +70,36 @@ export function textEmbeddingGeckoEmbedder( 'Please pass in the API key or set either GOOGLE_GENAI_API_KEY or GOOGLE_API_KEY environment variable.\n' + 'For more details see https://firebase.google.com/docs/genkit/plugins/google-genai' ); - const client = new GoogleGenerativeAI(apiKey).getGenerativeModel({ - model: name, - }); - const embedder = SUPPORTED_MODELS[name]; - return defineEmbedder( + const embedder: EmbedderReference = + SUPPORTED_MODELS[name] ?? + embedderRef({ + name: name, + configSchema: GeminiEmbeddingConfigSchema, + info: { + dimensions: 768, + label: `Google AI - ${name}`, + supports: { + input: ['text'], + }, + }, + }); + const apiModelName = embedder.name.startsWith('googleai/') + ? 
embedder.name.substring('googleai/'.length) + : embedder.name; + return ai.defineEmbedder( { name: embedder.name, - configSchema: TextEmbeddingGeckoConfigSchema, + configSchema: GeminiEmbeddingConfigSchema, info: embedder.info!, }, async (input, options) => { + const client = new GoogleGenerativeAI(apiKey!).getGenerativeModel({ + model: + options?.version || + embedder.config?.version || + embedder.version || + apiModelName, + }); const embeddings = await Promise.all( input.map(async (doc) => { const response = await client.embedContent({ @@ -88,7 +107,7 @@ export function textEmbeddingGeckoEmbedder( title: options?.title, content: { role: '', - parts: [{ text: doc.text() }], + parts: [{ text: doc.text }], }, } as EmbedContentRequest); const values = response.embedding.values; diff --git a/js/plugins/googleai/src/gemini.ts b/js/plugins/googleai/src/gemini.ts index d61167ba4..28ac77db2 100644 --- a/js/plugins/googleai/src/gemini.ts +++ b/js/plugins/googleai/src/gemini.ts @@ -31,20 +31,20 @@ import { StartChatParams, Tool, } from '@google/generative-ai'; -import { GENKIT_CLIENT_HEADER, z } from 'genkit'; +import { GENKIT_CLIENT_HEADER, Genkit, z } from 'genkit'; import { CandidateData, GenerationCommonConfigSchema, MediaPart, MessageData, ModelAction, + ModelInfo, ModelMiddleware, ModelReference, Part, ToolDefinitionSchema, ToolRequestPart, ToolResponsePart, - defineModel, getBasicUsageStats, modelRef, } from 'genkit/model'; @@ -70,16 +70,16 @@ const SafetySettingsSchema = z.object({ ]), }); -const GeminiConfigSchema = GenerationCommonConfigSchema.extend({ +export const GeminiConfigSchema = GenerationCommonConfigSchema.extend({ safetySettings: z.array(SafetySettingsSchema).optional(), codeExecution: z.union([z.boolean(), z.object({}).strict()]).optional(), }); -export const geminiPro = modelRef({ - name: 'googleai/gemini-pro', +export const gemini10Pro = modelRef({ + name: 'googleai/gemini-1.0-pro', info: { label: 'Google AI - Gemini Pro', - versions: ['gemini-1.0-pro', 'gemini-1.0-pro-latest', 'gemini-1.0-pro-001'], + versions: ['gemini-pro', 'gemini-1.0-pro-latest', 'gemini-1.0-pro-001'], supports: { multiturn: true, media: false, @@ -90,28 +90,8 @@ export const geminiPro = modelRef({ configSchema: GeminiConfigSchema, }); -/** - * @deprecated Use `gemini15Pro` or `gemini15Flash` instead. 
- */ -export const geminiProVision = modelRef({ - name: 'googleai/gemini-pro-vision', - info: { - label: 'Google AI - Gemini Pro Vision', - // none declared on https://ai.google.dev/models/gemini#model-variations - versions: [], - supports: { - multiturn: true, - media: true, - tools: false, - systemRole: false, - }, - stage: 'deprecated', - }, - configSchema: GeminiConfigSchema, -}); - export const gemini15Pro = modelRef({ - name: 'googleai/gemini-1.5-pro-latest', + name: 'googleai/gemini-1.5-pro', info: { label: 'Google AI - Gemini 1.5 Pro', supports: { @@ -121,13 +101,17 @@ export const gemini15Pro = modelRef({ systemRole: true, output: ['text', 'json'], }, - versions: ['gemini-1.5-pro-001'], + versions: [ + 'gemini-1.5-pro-latest', + 'gemini-1.5-pro-001', + 'gemini-1.5-pro-002', + ], }, configSchema: GeminiConfigSchema, }); export const gemini15Flash = modelRef({ - name: 'googleai/gemini-1.5-flash-latest', + name: 'googleai/gemini-1.5-flash', info: { label: 'Google AI - Gemini 1.5 Flash', supports: { @@ -137,44 +121,45 @@ export const gemini15Flash = modelRef({ systemRole: true, output: ['text', 'json'], }, - versions: ['gemini-1.5-flash-001'], + versions: [ + 'gemini-1.5-flash-latest', + 'gemini-1.5-flash-001', + 'gemini-1.5-flash-002', + ], }, configSchema: GeminiConfigSchema, }); -export const geminiUltra = modelRef({ - name: 'googleai/gemini-ultra', +export const gemini15Flash8b = modelRef({ + name: 'googleai/gemini-1.5-flash-8b', info: { - label: 'Google AI - Gemini Ultra', - versions: [], + label: 'Google AI - Gemini 1.5 Flash', supports: { multiturn: true, - media: false, + media: true, tools: true, systemRole: true, + output: ['text', 'json'], }, + versions: ['gemini-1.5-flash-8b-latest', 'gemini-1.5-flash-8b-001'], }, configSchema: GeminiConfigSchema, }); -export const SUPPORTED_V1_MODELS: Record< - string, - ModelReference -> = { - 'gemini-pro': geminiPro, - 'gemini-pro-vision': geminiProVision, - // 'gemini-ultra': geminiUltra, +export const SUPPORTED_V1_MODELS = { + 'gemini-1.0-pro': gemini10Pro, }; -export const SUPPORTED_V15_MODELS: Record< - string, - ModelReference -> = { - 'gemini-1.5-pro-latest': gemini15Pro, - 'gemini-1.5-flash-latest': gemini15Flash, +export const SUPPORTED_V15_MODELS = { + 'gemini-1.5-pro': gemini15Pro, + 'gemini-1.5-flash': gemini15Flash, + 'gemini-1.5-flash-8b': gemini15Flash8b, }; -const SUPPORTED_MODELS = { +export const SUPPORTED_GEMINI_MODELS: Record< + string, + ModelReference +> = { ...SUPPORTED_V1_MODELS, ...SUPPORTED_V15_MODELS, }; @@ -382,7 +367,7 @@ function toGeminiPart(part: Part): GeminiPart { if (part.toolRequest) return toFunctionCall(part); if (part.toolResponse) return toFunctionResponse(part); if (part.custom) return toCustomPart(part); - throw new Error('Unsupported Part type'); + throw new Error('Unsupported Part type' + JSON.stringify(part)); } function fromGeminiPart(part: GeminiPart, jsonMode: boolean): Part { @@ -454,16 +439,17 @@ export function fromGeminiCandidate( } /** - * + * Defines a new GoogleAI model. 
*/ -export function googleAIModel( +export function defineGoogleAIModel( + ai: Genkit, name: string, apiKey?: string, apiVersion?: string, - baseUrl?: string + baseUrl?: string, + info?: ModelInfo, + defaultConfig?: z.infer ): ModelAction { - const modelName = `googleai/${name}`; - if (!apiKey) { apiKey = process.env.GOOGLE_GENAI_API_KEY || process.env.GOOGLE_API_KEY; } @@ -473,15 +459,33 @@ export function googleAIModel( 'For more details see https://firebase.google.com/docs/genkit/plugins/google-genai' ); } - - const model: ModelReference = SUPPORTED_MODELS[name]; - if (!model) throw new Error(`Unsupported model: ${name}`); + const apiModelName = name.startsWith('googleai/') + ? name.substring('googleai/'.length) + : name; + + const model: ModelReference = + SUPPORTED_GEMINI_MODELS[name] ?? + modelRef({ + name, + info: { + label: `Google AI - ${apiModelName}`, + supports: { + multiturn: true, + media: true, + tools: true, + systemRole: true, + output: ['text', 'json'], + }, + ...info, + }, + configSchema: GeminiConfigSchema, + }); const middleware: ModelMiddleware[] = []; if (SUPPORTED_V1_MODELS[name]) { middleware.push(simulateSystemPrompt()); } - if (model?.info?.supports?.media) { + if (model.info?.supports?.media) { // the gemini api doesn't support downloading media from http(s) middleware.push( downloadRequestMedia({ @@ -495,9 +499,9 @@ export function googleAIModel( ); } - return defineModel( + return ai.defineModel( { - name: modelName, + name: model.name, ...model.info, configSchema: model.configSchema, use: middleware, @@ -510,9 +514,18 @@ export function googleAIModel( if (apiVersion) { options.baseUrl = baseUrl; } + const requestConfig = { + ...defaultConfig, + ...request.config, + }; + const client = new GoogleGenerativeAI(apiKey!).getGenerativeModel( { - model: request.config?.version || model.version || name, + model: + requestConfig.version || + model.config?.version || + model.version || + apiModelName, }, options ); @@ -542,7 +555,7 @@ export function googleAIModel( }); } - if (request.config?.codeExecution) { + if (requestConfig.codeExecution) { tools.push({ codeExecution: request.config.codeExecution === true @@ -558,11 +571,11 @@ export function googleAIModel( const generationConfig: GenerationConfig = { candidateCount: request.candidates || undefined, - temperature: request.config?.temperature, - maxOutputTokens: request.config?.maxOutputTokens, - topK: request.config?.topK, - topP: request.config?.topP, - stopSequences: request.config?.stopSequences, + temperature: requestConfig.temperature, + maxOutputTokens: requestConfig.maxOutputTokens, + topK: requestConfig.topK, + topP: requestConfig.topP, + stopSequences: requestConfig.stopSequences, responseMimeType: jsonMode ? 'application/json' : undefined, }; @@ -573,7 +586,7 @@ export function googleAIModel( history: messages .slice(0, -1) .map((message) => toGeminiMessage(message, model)), - safetySettings: request.config?.safetySettings, + safetySettings: requestConfig.safetySettings, } as StartChatParams; const msg = toGeminiMessage(messages[messages.length - 1], model); const fromJSONModeScopedGeminiCandidate = ( diff --git a/js/plugins/googleai/src/index.ts b/js/plugins/googleai/src/index.ts index e4ef8f1e6..abb731426 100644 --- a/js/plugins/googleai/src/index.ts +++ b/js/plugins/googleai/src/index.ts @@ -14,28 +14,22 @@ * limitations under the License. 
*/ -import { genkitPlugin, Plugin } from 'genkit'; +import { Genkit } from 'genkit'; +import { GenkitPlugin, genkitPlugin } from 'genkit/plugin'; import { SUPPORTED_MODELS as EMBEDDER_MODELS, + defineGoogleAIEmbedder, textEmbeddingGecko001, - textEmbeddingGeckoEmbedder, } from './embedder.js'; import { - gemini15Flash, - gemini15Pro, - geminiPro, - geminiProVision, - googleAIModel, SUPPORTED_V15_MODELS, SUPPORTED_V1_MODELS, -} from './gemini.js'; -export { + defineGoogleAIModel, + gemini10Pro, gemini15Flash, gemini15Pro, - geminiPro, - geminiProVision, - textEmbeddingGecko001, -}; +} from './gemini.js'; +export { gemini10Pro, gemini15Flash, gemini15Pro, textEmbeddingGecko001 }; export interface PluginOptions { apiKey?: string; @@ -43,11 +37,8 @@ export interface PluginOptions { baseUrl?: string; } -export const googleAI: Plugin<[PluginOptions] | []> = genkitPlugin( - 'googleai', - async (options?: PluginOptions) => { - let models; - let embedders; +export function googleAI(options?: PluginOptions): GenkitPlugin { + return genkitPlugin('googleai', async (ai: Genkit) => { let apiVersions = ['v1']; if (options?.apiVersion) { @@ -58,33 +49,40 @@ export const googleAI: Plugin<[PluginOptions] | []> = genkitPlugin( } } if (apiVersions.includes('v1beta')) { - (embedders = []), - (models = [ - ...Object.keys(SUPPORTED_V15_MODELS).map((name) => - googleAIModel(name, options?.apiKey, 'v1beta', options?.baseUrl) - ), - ]); + Object.keys(SUPPORTED_V15_MODELS).forEach((name) => + defineGoogleAIModel( + ai, + name, + options?.apiKey, + 'v1beta', + options?.baseUrl + ) + ); } if (apiVersions.includes('v1')) { - models = [ - ...Object.keys(SUPPORTED_V1_MODELS).map((name) => - googleAIModel(name, options?.apiKey, undefined, options?.baseUrl) - ), - ...Object.keys(SUPPORTED_V15_MODELS).map((name) => - googleAIModel(name, options?.apiKey, undefined, options?.baseUrl) - ), - ]; - embedders = [ - ...Object.keys(EMBEDDER_MODELS).map((name) => - textEmbeddingGeckoEmbedder(name, { apiKey: options?.apiKey }) - ), - ]; + Object.keys(SUPPORTED_V1_MODELS).forEach((name) => + defineGoogleAIModel( + ai, + name, + options?.apiKey, + undefined, + options?.baseUrl + ) + ); + Object.keys(SUPPORTED_V15_MODELS).forEach((name) => + defineGoogleAIModel( + ai, + name, + options?.apiKey, + undefined, + options?.baseUrl + ) + ); + Object.keys(EMBEDDER_MODELS).forEach((name) => + defineGoogleAIEmbedder(ai, name, { apiKey: options?.apiKey }) + ); } - return { - models, - embedders, - }; - } -); + }); +} export default googleAI; diff --git a/js/plugins/langchain/package.json b/js/plugins/langchain/package.json index d0bf80e49..ecf255071 100644 --- a/js/plugins/langchain/package.json +++ b/js/plugins/langchain/package.json @@ -9,7 +9,7 @@ "genai", "generative-ai" ], - "version": "0.6.0-dev.2", + "version": "0.9.0-dev.1", "type": "commonjs", "scripts": { "check": "tsc", diff --git a/js/plugins/langchain/src/evaluators.ts b/js/plugins/langchain/src/evaluators.ts index b112149d2..9edb12d44 100644 --- a/js/plugins/langchain/src/evaluators.ts +++ b/js/plugins/langchain/src/evaluators.ts @@ -14,19 +14,20 @@ * limitations under the License. 
*/ -import { ModelArgument } from 'genkit'; -import { BaseEvalDataPoint, defineEvaluator } from 'genkit/evaluator'; +import { Genkit, ModelArgument } from 'genkit'; +import { BaseEvalDataPoint } from 'genkit/evaluator'; import { Criteria, loadEvaluator } from 'langchain/evaluation'; import { genkitModel } from './model.js'; import { GenkitTracer } from './tracing.js'; export function langchainEvaluator( + ai: Genkit, type: 'labeled_criteria' | 'criteria', criteria: Criteria, judgeLlm: ModelArgument, judgeConfig?: any ) { - return defineEvaluator( + return ai.defineEvaluator( { name: `langchain/${criteria}`, displayName: `${criteria}`, @@ -41,7 +42,7 @@ export function langchainEvaluator( type as 'labeled_criteria' | 'criteria', { criteria, - llm: genkitModel(judgeLlm, judgeConfig), + llm: genkitModel(ai, judgeLlm, judgeConfig), chainOptions: { callbacks: [new GenkitTracer()], }, diff --git a/js/plugins/langchain/src/index.ts b/js/plugins/langchain/src/index.ts index 8e1867b8a..8fc7642bc 100644 --- a/js/plugins/langchain/src/index.ts +++ b/js/plugins/langchain/src/index.ts @@ -14,14 +14,9 @@ * limitations under the License. */ -import { - EvaluatorAction, - genkitPlugin, - ModelArgument, - Plugin, - z, -} from 'genkit'; +import { EvaluatorAction, Genkit, ModelArgument, z } from 'genkit'; import { GenerationCommonConfigSchema } from 'genkit/model'; +import { GenkitPlugin, genkitPlugin } from 'genkit/plugin'; import { Criteria } from 'langchain/evaluation'; import { langchainEvaluator } from './evaluators'; @@ -38,14 +33,14 @@ interface LangchainPluginParams< }; } -export const langchain: Plugin<[LangchainPluginParams]> = genkitPlugin( - 'langchain', - async (params: LangchainPluginParams) => { +export function langchain(params: LangchainPluginParams): GenkitPlugin { + return genkitPlugin('langchain', async (ai: Genkit) => { const evaluators: EvaluatorAction[] = []; if (params.evaluators) { for (const criteria of params.evaluators.criteria ?? []) { evaluators.push( langchainEvaluator( + ai, 'criteria', criteria, params.evaluators.judge, @@ -56,6 +51,7 @@ export const langchain: Plugin<[LangchainPluginParams]> = genkitPlugin( for (const criteria of params.evaluators.labeledCriteria ?? 
[]) { evaluators.push( langchainEvaluator( + ai, 'labeled_criteria', criteria, params.evaluators.judge, @@ -64,8 +60,5 @@ export const langchain: Plugin<[LangchainPluginParams]> = genkitPlugin( ); } } - return { - evaluators, - }; - } -); + }); +} diff --git a/js/plugins/langchain/src/model.ts b/js/plugins/langchain/src/model.ts index 8a44cf699..b2c5aa824 100644 --- a/js/plugins/langchain/src/model.ts +++ b/js/plugins/langchain/src/model.ts @@ -15,20 +15,25 @@ */ import { LLMResult } from '@langchain/core/outputs'; -import { generate, ModelArgument } from 'genkit'; +import { Genkit, ModelArgument } from 'genkit'; import { logger } from 'genkit/logging'; import { ModelAction } from 'genkit/model'; import { CallbackManagerForLLMRun } from 'langchain/callbacks'; import { BaseLLM } from 'langchain/llms/base'; -export function genkitModel(model: ModelArgument, config?: any): BaseLLM { - return new ModelAdapter(model, config); +export function genkitModel( + ai: Genkit, + model: ModelArgument, + config?: any +): BaseLLM { + return new ModelAdapter(ai, model, config); } class ModelAdapter extends BaseLLM { resolvedModel?: ModelAction; constructor( + private ai: Genkit, private model: ModelArgument, private config?: any ) { @@ -47,7 +52,7 @@ class ModelAdapter extends BaseLLM { //options const ress = await Promise.all( prompts.map((p) => - generate({ + this.ai.generate({ model: this.model, prompt: p, config: this.config, @@ -56,7 +61,7 @@ class ModelAdapter extends BaseLLM { ); return { - generations: ress.map((r) => [{ text: r.text() }]), + generations: ress.map((r) => [{ text: r.text }]), }; } diff --git a/js/plugins/ollama/package.json b/js/plugins/ollama/package.json index 23cb82be9..7feefb673 100644 --- a/js/plugins/ollama/package.json +++ b/js/plugins/ollama/package.json @@ -10,7 +10,7 @@ "genai", "generative-ai" ], - "version": "0.6.0-dev.2", + "version": "0.9.0-dev.1", "type": "commonjs", "scripts": { "check": "tsc", diff --git a/js/plugins/ollama/src/embeddings.ts b/js/plugins/ollama/src/embeddings.ts index 608e7db38..592703cf3 100644 --- a/js/plugins/ollama/src/embeddings.ts +++ b/js/plugins/ollama/src/embeddings.ts @@ -13,41 +13,29 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -import { z } from 'genkit'; -import { defineEmbedder } from 'genkit/embedder'; +import { Genkit } from 'genkit'; import { logger } from 'genkit/logging'; import { OllamaPluginParams } from './index.js'; -// Define the schema for Ollama embedding configuration -export const OllamaEmbeddingConfigSchema = z.object({ - modelName: z.string(), - serverAddress: z.string(), -}); -export type OllamaEmbeddingConfig = z.infer; -// Define the structure of the request and response for embedding -interface OllamaEmbeddingInstance { - content: string; -} + interface OllamaEmbeddingPrediction { embedding: number[]; } + interface DefineOllamaEmbeddingParams { name: string; modelName: string; dimensions: number; options: OllamaPluginParams; } -export function defineOllamaEmbedder({ - name, - modelName, - dimensions, - options, -}: DefineOllamaEmbeddingParams) { - return defineEmbedder( + +export function defineOllamaEmbedder( + ai: Genkit, + { name, modelName, dimensions, options }: DefineOllamaEmbeddingParams +) { + return ai.defineEmbedder( { name, - configSchema: OllamaEmbeddingConfigSchema, // Use the Zod schema directly here info: { - // TODO: do we want users to be able to specify the label when they call this method directly? 
label: 'Ollama Embedding - ' + modelName, dimensions, supports: { @@ -56,17 +44,16 @@ export function defineOllamaEmbedder({ }, }, }, - async (input, _config) => { + async (input) => { const serverAddress = options.serverAddress; const responses = await Promise.all( input.map(async (i) => { const requestPayload = { model: modelName, - prompt: i.text(), + prompt: i.text, }; let res: Response; try { - console.log('MODEL NAME: ', modelName); res = await fetch(`${serverAddress}/api/embeddings`, { method: 'POST', headers: { diff --git a/js/plugins/ollama/src/index.ts b/js/plugins/ollama/src/index.ts index 32572c09d..4734798c4 100644 --- a/js/plugins/ollama/src/index.ts +++ b/js/plugins/ollama/src/index.ts @@ -14,16 +14,16 @@ * limitations under the License. */ -import { genkitPlugin, Plugin } from 'genkit'; +import { Genkit } from 'genkit'; import { logger } from 'genkit/logging'; import { - defineModel, GenerateRequest, GenerateResponseData, GenerationCommonConfigSchema, getBasicUsageStats, MessageData, } from 'genkit/model'; +import { GenkitPlugin, genkitPlugin } from 'genkit/plugin'; import { defineOllamaEmbedder } from './embeddings'; type ApiType = 'chat' | 'generate'; @@ -50,32 +50,30 @@ export interface OllamaPluginParams { requestHeaders?: RequestHeaders; } -export const ollama: Plugin<[OllamaPluginParams]> = genkitPlugin( - 'ollama', - async (params: OllamaPluginParams) => { +export function ollama(params: OllamaPluginParams): GenkitPlugin { + return genkitPlugin('ollama', async (ai: Genkit) => { const serverAddress = params?.serverAddress; - return { - models: params.models.map((model) => - ollamaModel(model, serverAddress, params.requestHeaders) - ), - embedders: params.embeddingModels?.map((model) => - defineOllamaEmbedder({ - name: `${ollama}/model.name`, - modelName: model.name, - dimensions: model.dimensions, - options: params, - }) - ), - }; - } -); + params.models.map((model) => + ollamaModel(ai, model, serverAddress, params.requestHeaders) + ); + params.embeddingModels?.map((model) => + defineOllamaEmbedder(ai, { + name: `${ollama}/model.name`, + modelName: model.name, + dimensions: model.dimensions, + options: params, + }) + ); + }); +} function ollamaModel( + ai: Genkit, model: ModelDefinition, serverAddress: string, requestHeaders?: RequestHeaders ) { - return defineModel( + return ai.defineModel( { name: `ollama/${model.name}`, label: `Ollama - ${model.name}`, diff --git a/js/plugins/ollama/tests/embedding_live_test.ts b/js/plugins/ollama/tests/embedding_live_test.ts index 56dd9447f..5d1d7097c 100644 --- a/js/plugins/ollama/tests/embedding_live_test.ts +++ b/js/plugins/ollama/tests/embedding_live_test.ts @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -import { embed } from 'genkit'; +import { genkit } from 'genkit'; import assert from 'node:assert'; import { describe, it } from 'node:test'; import { defineOllamaEmbedder } from '../src/embeddings.js'; // Adjust the import path as necessary @@ -36,13 +36,14 @@ describe('defineOllamaEmbedder - Live Tests', () => { serverAddress, }; it('should successfully return embeddings', async () => { - const embedder = defineOllamaEmbedder({ + const ai = genkit({}); + const embedder = defineOllamaEmbedder(ai, { name: 'live-test-embedder', modelName: 'nomic-embed-text', dimensions: 768, options, }); - const result = await embed({ + const result = await ai.embed({ embedder, content: 'Hello, world!', }); diff --git a/js/plugins/ollama/tests/embeddings_test.ts b/js/plugins/ollama/tests/embeddings_test.ts index 19075c950..e61a94b99 100644 --- a/js/plugins/ollama/tests/embeddings_test.ts +++ b/js/plugins/ollama/tests/embeddings_test.ts @@ -13,15 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -import { embed } from 'genkit'; -import { Registry, runWithRegistry } from 'genkit/registry'; +import { Genkit, genkit } from 'genkit'; import assert from 'node:assert'; import { beforeEach, describe, it } from 'node:test'; -import { - OllamaEmbeddingConfigSchema, - defineOllamaEmbedder, -} from '../src/embeddings.js'; // Adjust the import path as necessary -import { OllamaPluginParams } from '../src/index.js'; // Adjust the import path as necessary +import { defineOllamaEmbedder } from '../src/embeddings.js'; +import { OllamaPluginParams } from '../src/index.js'; + // Mock fetch to simulate API responses global.fetch = async (input: RequestInfo | URL, options?: RequestInit) => { const url = typeof input === 'string' ? 
input : input.toString(); @@ -42,93 +39,70 @@ global.fetch = async (input: RequestInfo | URL, options?: RequestInit) => { } throw new Error('Unknown API endpoint'); }; + describe('defineOllamaEmbedder', () => { const options: OllamaPluginParams = { models: [{ name: 'test-model' }], serverAddress: 'http://localhost:3000', }; - let registry: Registry; + let ai: Genkit; beforeEach(() => { - registry = new Registry(); + ai = genkit({}); }); it('should successfully return embeddings', async () => { - await runWithRegistry(registry, async () => { - const embedder = defineOllamaEmbedder({ - name: 'test-embedder', - modelName: 'test-model', - dimensions: 123, - options, - }); - const result = await embed({ - embedder, - content: 'Hello, world!', - }); - assert.deepStrictEqual(result, [0.1, 0.2, 0.3]); + const embedder = defineOllamaEmbedder(ai, { + name: 'test-embedder', + modelName: 'test-model', + dimensions: 123, + options, }); - }); - - it('should handle API errors correctly', async () => { - await runWithRegistry(registry, async () => { - const embedder = defineOllamaEmbedder({ - name: 'test-embedder', - modelName: 'test-model', - dimensions: 123, - options, - }); - await assert.rejects( - async () => { - await embed({ - embedder, - content: 'fail', - }); - }, - (error) => { - assert(error instanceof Error); - assert.strictEqual( - error.message, - 'Error fetching embedding from Ollama: Internal Server Error' - ); - return true; - } - ); + const result = await ai.embed({ + embedder, + content: 'Hello, world!', }); + assert.deepStrictEqual(result, [0.1, 0.2, 0.3]); }); - it('should validate the embedding configuration schema', async () => { - const validConfig = { + it('should handle API errors correctly', async () => { + const embedder = defineOllamaEmbedder(ai, { + name: 'test-embedder', modelName: 'test-model', - serverAddress: 'http://localhost:3000', - }; - const invalidConfig = { - modelName: 123, // Invalid type - serverAddress: 'http://localhost:3000', - }; - // Valid configuration should pass - assert.doesNotThrow(() => { - OllamaEmbeddingConfigSchema.parse(validConfig); - }); - // Invalid configuration should throw - assert.throws(() => { - OllamaEmbeddingConfigSchema.parse(invalidConfig); + dimensions: 123, + options, }); - }); - it('should throw an error if the fetch response is not ok', async () => { - await runWithRegistry(registry, async () => { - const embedder = defineOllamaEmbedder({ - name: 'test-embedder', - modelName: 'test-model', - dimensions: 123, - options, - }); - - await assert.rejects(async () => { - await embed({ + await assert.rejects( + async () => { + await ai.embed({ embedder, content: 'fail', }); - }, new Error('Error fetching embedding from Ollama: Internal Server Error')); + }, + (error) => { + assert(error instanceof Error); + assert.strictEqual( + error.message, + 'Error fetching embedding from Ollama: Internal Server Error' + ); + return true; + } + ); + }); + + it('should throw an error if the fetch response is not ok', async () => { + const embedder = defineOllamaEmbedder(ai, { + name: 'test-embedder', + modelName: 'test-model', + dimensions: 123, + options, }); + + await assert.rejects(async () => { + await ai.embed({ + embedder, + content: 'fail', + }); + }, new Error('Error fetching embedding from Ollama: Internal Server Error')); }); }); diff --git a/js/plugins/pinecone/package.json b/js/plugins/pinecone/package.json index 9ff34e53c..d738c04be 100644 --- a/js/plugins/pinecone/package.json +++ b/js/plugins/pinecone/package.json @@ -13,7 +13,7 @@ 
"genai", "generative-ai" ], - "version": "0.6.0-dev.2", + "version": "0.9.0-dev.1", "type": "commonjs", "scripts": { "check": "tsc", diff --git a/js/plugins/pinecone/src/index.ts b/js/plugins/pinecone/src/index.ts index 6c03d124f..33548a9a7 100644 --- a/js/plugins/pinecone/src/index.ts +++ b/js/plugins/pinecone/src/index.ts @@ -20,13 +20,13 @@ import { PineconeConfiguration, RecordMetadata, } from '@pinecone-database/pinecone'; -import { PluginProvider, genkitPlugin, z } from 'genkit'; -import { EmbedderArgument, embed } from 'genkit/embedder'; +import { Genkit, z } from 'genkit'; +import { GenkitPlugin, genkitPlugin } from 'genkit/plugin'; + +import { EmbedderArgument } from 'genkit/embedder'; import { CommonRetrieverOptionsSchema, Document, - defineIndexer, - defineRetriever, indexerRef, retrieverRef, } from 'genkit/retriever'; @@ -97,23 +97,11 @@ export function pinecone( embedder: EmbedderArgument; embedderOptions?: z.infer; }[] -): PluginProvider { - const plugin = genkitPlugin( - 'pinecone', - async ( - params: { - clientParams?: PineconeConfiguration; - indexId: string; - textKey?: string; - embedder: EmbedderArgument; - embedderOptions?: z.infer; - }[] - ) => ({ - retrievers: params.map((i) => configurePineconeRetriever(i)), - indexers: params.map((i) => configurePineconeIndexer(i)), - }) - ); - return plugin(params); +): GenkitPlugin { + return genkitPlugin('pinecone', async (ai: Genkit) => { + params.map((i) => configurePineconeRetriever(ai, i)); + params.map((i) => configurePineconeIndexer(ai, i)); + }); } export default pinecone; @@ -123,13 +111,16 @@ export default pinecone; */ export function configurePineconeRetriever< EmbedderCustomOptions extends z.ZodTypeAny, ->(params: { - indexId: string; - clientParams?: PineconeConfiguration; - textKey?: string; - embedder: EmbedderArgument; - embedderOptions?: z.infer; -}) { +>( + ai: Genkit, + params: { + indexId: string; + clientParams?: PineconeConfiguration; + textKey?: string; + embedder: EmbedderArgument; + embedderOptions?: z.infer; + } +) { const { indexId, embedder, embedderOptions } = { ...params, }; @@ -138,13 +129,13 @@ export function configurePineconeRetriever< const pinecone = new Pinecone(pineconeConfig); const index = pinecone.index(indexId); - return defineRetriever( + return ai.defineRetriever( { name: `pinecone/${params.indexId}`, configSchema: PineconeRetrieverOptionsSchema, }, async (content, options) => { - const queryEmbeddings = await embed({ + const queryEmbeddings = await ai.embed({ embedder, content, options: embedderOptions, @@ -178,13 +169,16 @@ export function configurePineconeRetriever< */ export function configurePineconeIndexer< EmbedderCustomOptions extends z.ZodTypeAny, ->(params: { - indexId: string; - clientParams?: PineconeConfiguration; - textKey?: string; - embedder: EmbedderArgument; - embedderOptions?: z.infer; -}) { +>( + ai: Genkit, + params: { + indexId: string; + clientParams?: PineconeConfiguration; + textKey?: string; + embedder: EmbedderArgument; + embedderOptions?: z.infer; + } +) { const { indexId, embedder, embedderOptions } = { ...params, }; @@ -193,7 +187,7 @@ export function configurePineconeIndexer< const pinecone = new Pinecone(pineconeConfig); const index = pinecone.index(indexId); - return defineIndexer( + return ai.defineIndexer( { name: `pinecone/${params.indexId}`, configSchema: PineconeIndexerOptionsSchema.optional(), @@ -205,7 +199,7 @@ export function configurePineconeIndexer< const embeddings = await Promise.all( docs.map((doc) => - embed({ + ai.embed({ embedder, 
content: doc, options: embedderOptions, @@ -218,7 +212,7 @@ export function configurePineconeIndexer< ...docs[i].metadata, }; - metadata[textKey] = docs[i].text(); + metadata[textKey] = docs[i].text; const id = Md5.hashStr(JSON.stringify(docs[i])); return { id, diff --git a/js/plugins/vertexai/package.json b/js/plugins/vertexai/package.json index dd0cf6a4d..157f361d1 100644 --- a/js/plugins/vertexai/package.json +++ b/js/plugins/vertexai/package.json @@ -17,7 +17,7 @@ "genai", "generative-ai" ], - "version": "0.6.0-dev.2", + "version": "0.9.0-dev.1", "type": "commonjs", "scripts": { "check": "tsc", @@ -48,7 +48,7 @@ "genkit": "workspace:*" }, "optionalDependencies": { - "firebase-admin": "^12.1.0", + "firebase-admin": ">=12.2", "@google-cloud/bigquery": "^7.8.0" }, "devDependencies": { diff --git a/js/plugins/vertexai/src/anthropic.ts b/js/plugins/vertexai/src/anthropic.ts index 28bca6b59..a28ea12d4 100644 --- a/js/plugins/vertexai/src/anthropic.ts +++ b/js/plugins/vertexai/src/anthropic.ts @@ -32,6 +32,7 @@ import { AnthropicVertex } from '@anthropic-ai/vertex-sdk'; import { GENKIT_CLIENT_HEADER, GenerateRequest, + Genkit, Part as GenkitPart, MessageData, ModelReference, @@ -41,7 +42,6 @@ import { } from 'genkit'; import { GenerationCommonConfigSchema, - defineModel, getBasicUsageStats, modelRef, } from 'genkit/model'; @@ -360,6 +360,7 @@ function toAnthropicToolResponse(part: Part): ToolResultBlockParam { } export function anthropicModel( + ai: Genkit, modelName: string, projectId: string, region: string @@ -382,7 +383,7 @@ export function anthropicModel( throw new Error(`unsupported Anthropic model name ${modelName}`); } - return defineModel( + return ai.defineModel( { name: model.name, label: model.info?.label, diff --git a/js/plugins/vertexai/src/embedder.ts b/js/plugins/vertexai/src/embedder.ts index a7cf59a85..10d2ca18c 100644 --- a/js/plugins/vertexai/src/embedder.ts +++ b/js/plugins/vertexai/src/embedder.ts @@ -14,12 +14,8 @@ * limitations under the License. */ -import { z } from 'genkit'; -import { - defineEmbedder, - embedderRef, - EmbedderReference, -} from 'genkit/embedder'; +import { Genkit, z } from 'genkit'; +import { EmbedderReference, embedderRef } from 'genkit/embedder'; import { GoogleAuth } from 'google-auth-library'; import { PluginOptions } from './index.js'; import { PredictClient, predictModel } from './predict.js'; @@ -31,9 +27,10 @@ export const TaskTypeSchema = z.enum([ 'CLASSIFICATION', 'CLUSTERING', ]); + export type TaskType = z.infer; -export const TextEmbeddingGeckoConfigSchema = z.object({ +export const VertexEmbeddingConfigSchema = z.object({ /** * The `task_type` parameter is defined as the intended downstream application to help the model * produce better quality embeddings. 
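The Pinecone changes above follow the new plugin-authoring shape: the plugin factory returns a GenkitPlugin from genkit/plugin, and retrievers and indexers are registered as side effects on the Genkit instance handed to the initializer instead of being returned from it. A minimal sketch of that pattern (the plugin and action names are placeholders, and the empty retriever body is illustrative only):

import { Genkit, z } from 'genkit';
import { GenkitPlugin, genkitPlugin } from 'genkit/plugin';

export function myPlugin(): GenkitPlugin {
  return genkitPlugin('my-plugin', async (ai: Genkit) => {
    // Registration is a side effect on `ai`; nothing is returned.
    ai.defineRetriever(
      {
        name: 'my-plugin/docs',
        configSchema: z.object({ k: z.number().optional() }).optional(),
      },
      async (query, options) => {
        // A real implementation would embed `query` (e.g. via ai.embed) and
        // look up nearest neighbors, as configurePineconeRetriever does above.
        return { documents: [] };
      }
    );
  });
}
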
@@ -41,92 +38,47 @@ export const TextEmbeddingGeckoConfigSchema = z.object({ taskType: TaskTypeSchema.optional(), title: z.string().optional(), location: z.string().optional(), -}); -export type TextEmbeddingGeckoConfig = z.infer< - typeof TextEmbeddingGeckoConfigSchema ->; - -export const textEmbeddingGecko003 = embedderRef({ - name: 'vertexai/textembedding-gecko@003', - configSchema: TextEmbeddingGeckoConfigSchema, - info: { - dimensions: 768, - label: 'Vertex AI - Text Embedding Gecko', - supports: { - input: ['text'], - }, - }, -}); - -export const textEmbeddingGecko002 = embedderRef({ - name: 'vertexai/textembedding-gecko@002', - configSchema: TextEmbeddingGeckoConfigSchema, - info: { - dimensions: 768, - label: 'Vertex AI - Text Embedding Gecko', - supports: { - input: ['text'], - }, - }, -}); - -export const textEmbeddingGecko001 = embedderRef({ - name: 'vertexai/textembedding-gecko@001', - configSchema: TextEmbeddingGeckoConfigSchema, - info: { - dimensions: 768, - label: 'Vertex AI - Text Embedding Gecko (Legacy)', - supports: { - input: ['text'], - }, - }, + version: z.string().optional(), }); -export const textEmbedding004 = embedderRef({ - name: 'vertexai/text-embedding-004', - configSchema: TextEmbeddingGeckoConfigSchema, - info: { - dimensions: 768, - label: 'Vertex AI - Text Embedding 004', - supports: { - input: ['text'], - }, - }, -}); +export type VertexEmbeddingConfig = z.infer; -export const textMultilingualEmbedding002 = embedderRef({ - name: 'vertexai/text-multilingual-embedding-002', - configSchema: TextEmbeddingGeckoConfigSchema, - info: { - dimensions: 768, - label: 'Vertex AI - Text Multilingual Embedding 002', - supports: { - input: ['text'], - }, - }, -}); - -export const textEmbeddingGeckoMultilingual001 = embedderRef({ - name: 'vertexai/textembedding-gecko-multilingual@001', - configSchema: TextEmbeddingGeckoConfigSchema, - info: { - dimensions: 768, - label: 'Vertex AI - Multilingual Text Embedding Gecko 001', - supports: { - input: ['text'], +function commonRef( + name: string, + input?: ('text' | 'image')[] +): EmbedderReference { + return embedderRef({ + name: `vertexai/${name}`, + configSchema: VertexEmbeddingConfigSchema, + info: { + dimensions: 768, + label: `Vertex AI - ${name}`, + supports: { + input: input ?? 
['text'], + }, }, - }, -}); + }); +} -export const textEmbeddingGecko = textEmbeddingGecko003; +export const textEmbeddingGecko003 = commonRef('textembedding-gecko@003'); +export const textEmbedding004 = commonRef('text-embedding-004'); +export const textEmbeddingGeckoMultilingual001 = commonRef( + 'textembedding-gecko-multilingual@001' +); +export const textMultilingualEmbedding002 = commonRef( + 'text-multilingual-embedding-002' +); export const SUPPORTED_EMBEDDER_MODELS: Record = { 'textembedding-gecko@003': textEmbeddingGecko003, - 'textembedding-gecko@002': textEmbeddingGecko002, - 'textembedding-gecko@001': textEmbeddingGecko001, 'text-embedding-004': textEmbedding004, 'textembedding-gecko-multilingual@001': textEmbeddingGeckoMultilingual001, 'text-multilingual-embedding-002': textMultilingualEmbedding002, + // TODO: add support for multimodal embeddings + // 'multimodalembedding@001': commonRef('multimodalembedding@001', [ + // 'image', + // 'text', + // ]), }; interface EmbeddingInstance { @@ -144,7 +96,8 @@ interface EmbeddingPrediction { }; } -export function textEmbeddingGeckoEmbedder( +export function defineVertexAIEmbedder( + ai: Genkit, name: string, client: GoogleAuth, options: PluginOptions @@ -155,7 +108,7 @@ export function textEmbeddingGeckoEmbedder( PredictClient > = {}; const predictClientFactory = ( - config: TextEmbeddingGeckoConfig + config: VertexEmbeddingConfig ): PredictClient => { const requestLocation = config?.location || options.location; if (!predictClients[requestLocation]) { @@ -175,7 +128,7 @@ export function textEmbeddingGeckoEmbedder( return predictClients[requestLocation]; }; - return defineEmbedder( + return ai.defineEmbedder( { name: embedder.name, configSchema: embedder.configSchema, @@ -186,7 +139,7 @@ export function textEmbeddingGeckoEmbedder( const response = await predictClient( input.map((i) => { return { - content: i.text(), + content: i.text, task_type: options?.taskType, title: options?.title, }; diff --git a/js/plugins/vertexai/src/evaluation.ts b/js/plugins/vertexai/src/evaluation.ts index da3538aac..965b9fc86 100644 --- a/js/plugins/vertexai/src/evaluation.ts +++ b/js/plugins/vertexai/src/evaluation.ts @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -import { Action, z } from 'genkit'; +import { Action, Genkit, z } from 'genkit'; import { GoogleAuth } from 'google-auth-library'; import { EvaluatorFactory } from './evaluator_factory.js'; @@ -50,6 +50,7 @@ export type VertexAIEvaluationMetric = | VertexAIEvaluationMetricConfig; export function vertexEvaluators( + ai: Genkit, auth: GoogleAuth, metrics: VertexAIEvaluationMetric[], projectId: string, @@ -62,28 +63,28 @@ export function vertexEvaluators( switch (metricType) { case VertexAIEvaluationMetricType.BLEU: { - return createBleuEvaluator(factory, metricSpec); + return createBleuEvaluator(ai, factory, metricSpec); } case VertexAIEvaluationMetricType.ROUGE: { - return createRougeEvaluator(factory, metricSpec); + return createRougeEvaluator(ai, factory, metricSpec); } case VertexAIEvaluationMetricType.FLUENCY: { - return createFluencyEvaluator(factory, metricSpec); + return createFluencyEvaluator(ai, factory, metricSpec); } case VertexAIEvaluationMetricType.SAFETY: { - return createSafetyEvaluator(factory, metricSpec); + return createSafetyEvaluator(ai, factory, metricSpec); } case VertexAIEvaluationMetricType.GROUNDEDNESS: { - return createGroundednessEvaluator(factory, metricSpec); + return createGroundednessEvaluator(ai, factory, metricSpec); } case VertexAIEvaluationMetricType.SUMMARIZATION_QUALITY: { - return createSummarizationQualityEvaluator(factory, metricSpec); + return createSummarizationQualityEvaluator(ai, factory, metricSpec); } case VertexAIEvaluationMetricType.SUMMARIZATION_HELPFULNESS: { - return createSummarizationHelpfulnessEvaluator(factory, metricSpec); + return createSummarizationHelpfulnessEvaluator(ai, factory, metricSpec); } case VertexAIEvaluationMetricType.SUMMARIZATION_VERBOSITY: { - return createSummarizationVerbosityEvaluator(factory, metricSpec); + return createSummarizationVerbosityEvaluator(ai, factory, metricSpec); } } }); @@ -103,10 +104,12 @@ const BleuResponseSchema = z.object({ // TODO: Add support for batch inputs function createBleuEvaluator( + ai: Genkit, factory: EvaluatorFactory, metricSpec: any ): Action { return factory.create( + ai, { metric: VertexAIEvaluationMetricType.BLEU, displayName: 'BLEU', @@ -143,10 +146,12 @@ const RougeResponseSchema = z.object({ // TODO: Add support for batch inputs function createRougeEvaluator( + ai: Genkit, factory: EvaluatorFactory, metricSpec: any ): Action { return factory.create( + ai, { metric: VertexAIEvaluationMetricType.ROUGE, displayName: 'ROUGE', @@ -182,10 +187,12 @@ const FluencyResponseSchema = z.object({ }); function createFluencyEvaluator( + ai: Genkit, factory: EvaluatorFactory, metricSpec: any ): Action { return factory.create( + ai, { metric: VertexAIEvaluationMetricType.FLUENCY, displayName: 'Fluency', @@ -222,10 +229,12 @@ const SafetyResponseSchema = z.object({ }); function createSafetyEvaluator( + ai: Genkit, factory: EvaluatorFactory, metricSpec: any ): Action { return factory.create( + ai, { metric: VertexAIEvaluationMetricType.SAFETY, displayName: 'Safety', @@ -262,10 +271,12 @@ const GroundednessResponseSchema = z.object({ }); function createGroundednessEvaluator( + ai: Genkit, factory: EvaluatorFactory, metricSpec: any ): Action { return factory.create( + ai, { metric: VertexAIEvaluationMetricType.GROUNDEDNESS, displayName: 'Groundedness', @@ -304,10 +315,12 @@ const SummarizationQualityResponseSchema = z.object({ }); function createSummarizationQualityEvaluator( + ai: Genkit, factory: EvaluatorFactory, metricSpec: any ): Action { return factory.create( + ai, { metric: 
VertexAIEvaluationMetricType.SUMMARIZATION_QUALITY, displayName: 'Summarization quality', @@ -346,10 +359,12 @@ const SummarizationHelpfulnessResponseSchema = z.object({ }); function createSummarizationHelpfulnessEvaluator( + ai: Genkit, factory: EvaluatorFactory, metricSpec: any ): Action { return factory.create( + ai, { metric: VertexAIEvaluationMetricType.SUMMARIZATION_HELPFULNESS, displayName: 'Summarization helpfulness', @@ -389,10 +404,12 @@ const SummarizationVerbositySchema = z.object({ }); function createSummarizationVerbosityEvaluator( + ai: Genkit, factory: EvaluatorFactory, metricSpec: any ): Action { return factory.create( + ai, { metric: VertexAIEvaluationMetricType.SUMMARIZATION_VERBOSITY, displayName: 'Summarization verbosity', diff --git a/js/plugins/vertexai/src/evaluator_factory.ts b/js/plugins/vertexai/src/evaluator_factory.ts index 7fbb4f299..821f4631b 100644 --- a/js/plugins/vertexai/src/evaluator_factory.ts +++ b/js/plugins/vertexai/src/evaluator_factory.ts @@ -14,8 +14,8 @@ * limitations under the License. */ -import { Action, GENKIT_CLIENT_HEADER, z } from 'genkit'; -import { BaseEvalDataPoint, Score, defineEvaluator } from 'genkit/evaluator'; +import { Action, Genkit, GENKIT_CLIENT_HEADER, z } from 'genkit'; +import { BaseEvalDataPoint, Score } from 'genkit/evaluator'; import { runInNewSpan } from 'genkit/tracing'; import { GoogleAuth } from 'google-auth-library'; import { VertexAIEvaluationMetricType } from './evaluation.js'; @@ -28,6 +28,7 @@ export class EvaluatorFactory { ) {} create( + ai: Genkit, config: { metric: VertexAIEvaluationMetricType; displayName: string; @@ -37,7 +38,7 @@ export class EvaluatorFactory { toRequest: (datapoint: BaseEvalDataPoint) => any, responseHandler: (response: z.infer) => Score ): Action { - return defineEvaluator( + return ai.defineEvaluator( { name: `vertexai/${config.metric.toLocaleLowerCase()}`, displayName: config.displayName, diff --git a/js/plugins/vertexai/src/gemini.ts b/js/plugins/vertexai/src/gemini.ts index 6f30a88aa..42dfb2215 100644 --- a/js/plugins/vertexai/src/gemini.ts +++ b/js/plugins/vertexai/src/gemini.ts @@ -27,7 +27,7 @@ import { StartChatParams, VertexAI, } from '@google-cloud/vertexai'; -import { GENKIT_CLIENT_HEADER, z } from 'genkit'; +import { GENKIT_CLIENT_HEADER, Genkit, z } from 'genkit'; import { CandidateData, GenerateRequest, @@ -39,7 +39,6 @@ import { ModelReference, Part, ToolDefinitionSchema, - defineModel, getBasicUsageStats, modelRef, } from 'genkit/model'; @@ -74,11 +73,11 @@ export const GeminiConfigSchema = GenerationCommonConfigSchema.extend({ googleSearchRetrieval: GoogleSearchRetrievalSchema.optional(), }); -export const geminiPro = modelRef({ +export const gemini10Pro = modelRef({ name: 'vertexai/gemini-1.0-pro', info: { label: 'Vertex AI - Gemini Pro', - versions: ['gemini-1.0-pro', 'gemini-1.0-pro-001'], + versions: ['gemini-1.0-pro-001', 'gemini-1.0-pro-002'], supports: { multiturn: true, media: false, @@ -89,29 +88,11 @@ export const geminiPro = modelRef({ configSchema: GeminiConfigSchema, }); -export const geminiProVision = modelRef({ - name: 'vertexai/gemini-1.0-pro-vision', - info: { - label: 'Vertex AI - Gemini Pro Vision', - versions: ['gemini-1.0-pro-vision', 'gemini-1.0-pro-vision-001'], - supports: { - multiturn: true, - media: true, - tools: false, - systemRole: false, - }, - }, - configSchema: GeminiConfigSchema.omit({ - googleSearchRetrieval: true, - vertexRetrieval: true, - }), -}); - export const gemini15Pro = modelRef({ name: 'vertexai/gemini-1.5-pro', info: { label: 
'Vertex AI - Gemini 1.5 Pro', - versions: ['gemini-1.5-pro-001'], + versions: ['gemini-1.5-pro-001', 'gemini-1.5-pro-002'], supports: { multiturn: true, media: true, @@ -122,43 +103,11 @@ export const gemini15Pro = modelRef({ configSchema: GeminiConfigSchema, }); -export const gemini15ProPreview = modelRef({ - name: 'vertexai/gemini-1.5-pro-preview', - info: { - label: 'Vertex AI - Gemini 1.5 Pro Preview', - versions: ['gemini-1.5-pro-preview-0409'], - supports: { - multiturn: true, - media: true, - tools: true, - systemRole: true, - }, - }, - configSchema: GeminiConfigSchema, - version: 'gemini-1.5-pro-preview-0409', -}); - -export const gemini15FlashPreview = modelRef({ - name: 'vertexai/gemini-1.5-flash-preview', - info: { - label: 'Vertex AI - Gemini 1.5 Flash', - versions: ['gemini-1.5-flash-preview-0514'], - supports: { - multiturn: true, - media: true, - tools: true, - systemRole: true, - }, - }, - configSchema: GeminiConfigSchema, - version: 'gemini-1.5-flash-preview-0514', -}); - export const gemini15Flash = modelRef({ name: 'vertexai/gemini-1.5-flash', info: { label: 'Vertex AI - Gemini 1.5 Flash', - versions: ['gemini-1.5-flash-001'], + versions: ['gemini-1.5-flash-001', 'gemini-1.5-flash-002'], supports: { multiturn: true, media: true, @@ -170,16 +119,12 @@ export const gemini15Flash = modelRef({ }); export const SUPPORTED_V1_MODELS = { - 'gemini-1.0-pro': geminiPro, - 'gemini-1.0-pro-vision': geminiProVision, - // 'gemini-ultra': geminiUltra, + 'gemini-1.0-pro': gemini10Pro, }; export const SUPPORTED_V15_MODELS = { 'gemini-1.5-pro': gemini15Pro, 'gemini-1.5-flash': gemini15Flash, - 'gemini-1.5-pro-preview': gemini15ProPreview, - 'gemini-1.5-flash-preview': gemini15FlashPreview, }; export const SUPPORTED_GEMINI_MODELS = { @@ -459,9 +404,10 @@ const convertSchemaProperty = (property) => { }; /** - * + * Define a Vertex AI Gemini model. */ -export function geminiModel( +export function defineGeminiModel( + ai: Genkit, name: string, vertexClientFactory: ( request: GenerateRequest @@ -482,7 +428,7 @@ export function geminiModel( middlewares.push(downloadRequestMedia({ maxBytes: 1024 * 1024 * 20 })); } - return defineModel( + return ai.defineModel( { name: modelName, ...model.info, diff --git a/js/plugins/vertexai/src/imagen.ts b/js/plugins/vertexai/src/imagen.ts index 015b078d2..12f11fd13 100644 --- a/js/plugins/vertexai/src/imagen.ts +++ b/js/plugins/vertexai/src/imagen.ts @@ -14,15 +14,14 @@ * limitations under the License. 
*/ -import { z } from 'genkit'; +import { Genkit, z } from 'genkit'; import { CandidateData, - defineModel, GenerateRequest, GenerationCommonConfigSchema, + ModelReference, getBasicUsageStats, modelRef, - ModelReference, } from 'genkit/model'; import { GoogleAuth } from 'google-auth-library'; import { PluginOptions } from './index.js'; @@ -215,6 +214,7 @@ interface ImagenInstance { } export function imagenModel( + ai: Genkit, name: string, client: GoogleAuth, options: PluginOptions @@ -248,7 +248,7 @@ export function imagenModel( return predictClients[requestLocation]; }; - return defineModel( + return ai.defineModel( { name: modelName, ...model.info, diff --git a/js/plugins/vertexai/src/index.ts b/js/plugins/vertexai/src/index.ts index d4c3698dc..c6c4a227e 100644 --- a/js/plugins/vertexai/src/index.ts +++ b/js/plugins/vertexai/src/index.ts @@ -15,26 +15,23 @@ */ import { VertexAI } from '@google-cloud/vertexai'; -import { genkitPlugin, Plugin, z } from 'genkit'; +import { Genkit, z } from 'genkit'; import { GenerateRequest, ModelReference } from 'genkit/model'; -import { IndexerAction, RetrieverAction } from 'genkit/retriever'; +import { GenkitPlugin, genkitPlugin } from 'genkit/plugin'; import { GoogleAuth, GoogleAuthOptions } from 'google-auth-library'; import { + SUPPORTED_ANTHROPIC_MODELS, anthropicModel, claude35Sonnet, claude3Haiku, claude3Opus, claude3Sonnet, - SUPPORTED_ANTHROPIC_MODELS, } from './anthropic.js'; import { SUPPORTED_EMBEDDER_MODELS, + defineVertexAIEmbedder, textEmbedding004, - textEmbeddingGecko, - textEmbeddingGecko001, - textEmbeddingGecko002, textEmbeddingGecko003, - textEmbeddingGeckoEmbedder, textEmbeddingGeckoMultilingual001, textMultilingualEmbedding002, } from './embedder.js'; @@ -44,31 +41,28 @@ import { vertexEvaluators, } from './evaluation.js'; import { - gemini15Flash, - gemini15FlashPreview, - gemini15Pro, - gemini15ProPreview, GeminiConfigSchema, - geminiModel, - geminiPro, - geminiProVision, SUPPORTED_GEMINI_MODELS, + defineGeminiModel, + gemini10Pro, + gemini15Flash, + gemini15Pro, } from './gemini.js'; import { + SUPPORTED_IMAGEN_MODELS, imagen2, imagen3, imagen3Fast, imagenModel, - SUPPORTED_IMAGEN_MODELS, } from './imagen.js'; import { + SUPPORTED_OPENAI_FORMAT_MODELS, llama3, llama31, llama32, modelGardenOpenaiCompatibleModel, - SUPPORTED_OPENAI_FORMAT_MODELS, } from './model_garden.js'; -import { vertexAiRerankers, VertexRerankerConfig } from './reranker.js'; +import { VertexRerankerConfig, vertexAiRerankers } from './reranker.js'; import { VectorSearchOptions, vertexAiIndexers, @@ -77,28 +71,26 @@ import { export { DocumentIndexer, DocumentRetriever, + Neighbor, + VectorSearchOptions, getBigQueryDocumentIndexer, getBigQueryDocumentRetriever, getFirestoreDocumentIndexer, getFirestoreDocumentRetriever, - Neighbor, - VectorSearchOptions, vertexAiIndexerRef, vertexAiIndexers, vertexAiRetrieverRef, vertexAiRetrievers, } from './vector-search'; export { + VertexAIEvaluationMetricType as VertexAIEvaluationMetricType, claude35Sonnet, claude3Haiku, claude3Opus, claude3Sonnet, + gemini10Pro, gemini15Flash, - gemini15FlashPreview, gemini15Pro, - gemini15ProPreview, - geminiPro, - geminiProVision, imagen2, imagen3, imagen3Fast, @@ -106,13 +98,9 @@ export { llama31, llama32, textEmbedding004, - textEmbeddingGecko, - textEmbeddingGecko001, - textEmbeddingGecko002, textEmbeddingGecko003, textEmbeddingGeckoMultilingual001, textMultilingualEmbedding002, - VertexAIEvaluationMetricType as VertexAIEvaluationMetricType, }; export interface PluginOptions { @@ 
-146,9 +134,8 @@ const CLOUD_PLATFROM_OAUTH_SCOPE = /** * Add Google Cloud Vertex AI to Genkit. Includes Gemini and Imagen models and text embedder. */ -export const vertexAI: Plugin<[PluginOptions] | []> = genkitPlugin( - 'vertexai', - async (options?: PluginOptions) => { +export function vertexAI(options?: PluginOptions): GenkitPlugin { + return genkitPlugin('vertexai', async (ai: Genkit) => { let authClient; let authOptions = options?.googleAuth; @@ -203,14 +190,12 @@ export const vertexAI: Plugin<[PluginOptions] | []> = genkitPlugin( ? options.evaluation.metrics : []; - const models = [ - ...Object.keys(SUPPORTED_IMAGEN_MODELS).map((name) => - imagenModel(name, authClient, { projectId, location }) - ), - ...Object.keys(SUPPORTED_GEMINI_MODELS).map((name) => - geminiModel(name, vertexClientFactory, { projectId, location }) - ), - ]; + Object.keys(SUPPORTED_IMAGEN_MODELS).map((name) => + imagenModel(ai, name, authClient, { projectId, location }) + ); + Object.keys(SUPPORTED_GEMINI_MODELS).map((name) => + defineGeminiModel(ai, name, vertexClientFactory, { projectId, location }) + ); if (options?.modelGardenModels || options?.modelGarden?.models) { const mgModels = @@ -220,21 +205,20 @@ export const vertexAI: Plugin<[PluginOptions] | []> = genkitPlugin( ([_, value]) => value.name === m.name ); if (anthropicEntry) { - models.push(anthropicModel(anthropicEntry[0], projectId, location)); + anthropicModel(ai, anthropicEntry[0], projectId, location); return; } const openaiModel = Object.entries(SUPPORTED_OPENAI_FORMAT_MODELS).find( ([_, value]) => value.name === m.name ); if (openaiModel) { - models.push( - modelGardenOpenaiCompatibleModel( - openaiModel[0], - projectId, - location, - authClient, - options.modelGarden?.openAiBaseUrlTemplate - ) + modelGardenOpenaiCompatibleModel( + ai, + openaiModel[0], + projectId, + location, + authClient, + options.modelGarden?.openAiBaseUrlTemplate ); return; } @@ -243,25 +227,22 @@ export const vertexAI: Plugin<[PluginOptions] | []> = genkitPlugin( } const embedders = Object.keys(SUPPORTED_EMBEDDER_MODELS).map((name) => - textEmbeddingGeckoEmbedder(name, authClient, { projectId, location }) + defineVertexAIEmbedder(ai, name, authClient, { projectId, location }) ); - let indexers: IndexerAction[] = []; - let retrievers: RetrieverAction[] = []; - if ( options?.vectorSearchOptions && options.vectorSearchOptions.length > 0 ) { const defaultEmbedder = embedders[0]; - indexers = vertexAiIndexers({ + vertexAiIndexers(ai, { pluginOptions: options, authClient, defaultEmbedder, }); - retrievers = vertexAiRetrievers({ + vertexAiRetrievers(ai, { pluginOptions: options, authClient, defaultEmbedder, @@ -273,18 +254,9 @@ export const vertexAI: Plugin<[PluginOptions] | []> = genkitPlugin( authClient, projectId, }; - - const rerankers = await vertexAiRerankers(rerankOptions); - - return { - models, - embedders, - evaluators: vertexEvaluators(authClient, metrics, projectId, location), - retrievers, - indexers, - rerankers, - }; - } -); + await vertexAiRerankers(ai, rerankOptions); + vertexEvaluators(ai, authClient, metrics, projectId, location); + }); +} export default vertexAI; diff --git a/js/plugins/vertexai/src/model_garden.ts b/js/plugins/vertexai/src/model_garden.ts index f9eb4bf95..eec87274c 100644 --- a/js/plugins/vertexai/src/model_garden.ts +++ b/js/plugins/vertexai/src/model_garden.ts @@ -14,13 +14,13 @@ * limitations under the License. 
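Taken together, the index.ts rewrite above means vertexAI() now returns a GenkitPlugin, and its models, embedders, evaluators, retrievers, and rerankers are registered on the Genkit instance during initialization. A minimal consumer-side sketch, assuming the project and location values shown are placeholders and that the taskType option is one of the enum members from embedder.ts above:

import { genkit } from 'genkit';
import { vertexAI, gemini15Flash, textEmbedding004 } from '@genkit-ai/vertexai';

const ai = genkit({
  plugins: [vertexAI({ projectId: 'my-project', location: 'us-central1' })],
});

const response = await ai.generate({
  model: gemini15Flash,
  prompt: 'Say hello in one short sentence.',
});
console.log(response.text); // .text is now a property rather than a method

const embedding = await ai.embed({
  embedder: textEmbedding004,
  content: response.text,
  options: { taskType: 'CLUSTERING' },
});
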
*/ -import { GENKIT_CLIENT_HEADER, z } from 'genkit'; +import { Genkit, GENKIT_CLIENT_HEADER, z } from 'genkit'; import { GenerateRequest, ModelAction, modelRef } from 'genkit/model'; import { GoogleAuth } from 'google-auth-library'; import OpenAI from 'openai'; import { - OpenAIConfigSchema, openaiCompatibleModel, + OpenAIConfigSchema, } from './openai_compatibility.js'; export const ModelGardenModelConfigSchema = OpenAIConfigSchema.extend({ @@ -91,6 +91,7 @@ export const SUPPORTED_OPENAI_FORMAT_MODELS = { }; export function modelGardenOpenaiCompatibleModel( + ai: Genkit, name: string, projectId: string, location: string, @@ -118,5 +119,5 @@ export function modelGardenOpenaiCompatibleModel( }, }); }; - return openaiCompatibleModel(model, clientFactory); + return openaiCompatibleModel(ai, model, clientFactory); } diff --git a/js/plugins/vertexai/src/openai_compatibility.ts b/js/plugins/vertexai/src/openai_compatibility.ts index a6f8cf838..2de914f57 100644 --- a/js/plugins/vertexai/src/openai_compatibility.ts +++ b/js/plugins/vertexai/src/openai_compatibility.ts @@ -14,14 +14,13 @@ * limitations under the License. */ -import { Message, StreamingCallback, z } from 'genkit'; +import { Genkit, Message, StreamingCallback, z } from 'genkit'; import { GenerateResponseChunkData, GenerateResponseData, GenerationCommonConfigSchema, ModelAction, ModelReference, - defineModel, type CandidateData, type GenerateRequest, type MessageData, @@ -114,7 +113,7 @@ export function toOpenAiMessages( case 'system': openAiMsgs.push({ role: role, - content: msg.text(), + content: msg.text, }); break; case 'assistant': { @@ -142,7 +141,7 @@ export function toOpenAiMessages( } else { openAiMsgs.push({ role: role, - content: msg.text(), + content: msg.text, }); } break; @@ -297,13 +296,14 @@ export function toRequestBody( } export function openaiCompatibleModel( + ai: Genkit, model: ModelReference, clientFactory: (request: GenerateRequest) => Promise ): ModelAction { const modelId = model.name; if (!model) throw new Error(`Unsupported model: ${name}`); - return defineModel( + return ai.defineModel( { name: modelId, ...model.info, diff --git a/js/plugins/vertexai/src/reranker.ts b/js/plugins/vertexai/src/reranker.ts index 1e51b3310..95df9b2c9 100644 --- a/js/plugins/vertexai/src/reranker.ts +++ b/js/plugins/vertexai/src/reranker.ts @@ -14,13 +14,8 @@ * limitations under the License. */ -import { z } from 'genkit'; -import { - defineReranker, - RankedDocument, - RerankerAction, - rerankerRef, -} from 'genkit/reranker'; +import { Genkit, z } from 'genkit'; +import { RankedDocument, RerankerAction, rerankerRef } from 'genkit/reranker'; import { GoogleAuth } from 'google-auth-library'; import { PluginOptions } from '.'; @@ -76,6 +71,7 @@ export interface VertexRerankOptions { * @returns {RerankerAction[]} - An array of reranker actions. 
*/ export async function vertexAiRerankers( + ai: Genkit, params: VertexRerankOptions ): Promise[]> { if (!params.pluginOptions) { @@ -97,7 +93,7 @@ export async function vertexAiRerankers( const projectId = await auth.getProjectId(); for (const rerankOption of rerankOptions) { - const reranker = defineReranker( + const reranker = ai.defineReranker( { name: `vertexai/${rerankOption.name || rerankOption.model}`, configSchema: VertexAIRerankerOptionsSchema.optional(), @@ -111,10 +107,10 @@ export async function vertexAiRerankers( ), data: { model: rerankOption.model || DEFAULT_MODEL, // Use model from config or default - query: query.text(), + query: query.text, records: documents.map((doc, idx) => ({ id: `${idx}`, - content: doc.text(), + content: doc.text, })), }, }); diff --git a/js/plugins/vertexai/src/vector-search/indexers.ts b/js/plugins/vertexai/src/vector-search/indexers.ts index 50ddc268e..66a00e913 100644 --- a/js/plugins/vertexai/src/vector-search/indexers.ts +++ b/js/plugins/vertexai/src/vector-search/indexers.ts @@ -14,9 +14,8 @@ * limitations under the License. */ -import { z } from 'genkit'; -import { embedMany } from 'genkit/embedder'; -import { defineIndexer, IndexerAction, indexerRef } from 'genkit/retriever'; +import { Genkit, z } from 'genkit'; +import { IndexerAction, indexerRef } from 'genkit/retriever'; import { Datapoint, VertexAIVectorIndexerOptionsSchema, @@ -55,6 +54,7 @@ export const vertexAiIndexerRef = (params: { * @returns {IndexerAction[]} - An array of indexer actions. */ export function vertexAiIndexers( + ai: Genkit, params: VertexVectorSearchOptions ): IndexerAction[] { const vectorSearchOptions = params.pluginOptions.vectorSearchOptions; @@ -70,7 +70,7 @@ export function vertexAiIndexers( const embedder = vectorSearchOption.embedder ?? defaultEmbedder; const embedderOptions = vectorSearchOption.embedderOptions; - const indexer = defineIndexer( + const indexer = ai.defineIndexer( { name: `vertexai/${indexId}`, configSchema: VertexAIVectorIndexerOptionsSchema.optional(), @@ -86,7 +86,7 @@ export function vertexAiIndexers( ); } - const embeddings = await embedMany({ + const embeddings = await ai.embedMany({ embedder, content: docs, options: embedderOptions, diff --git a/js/plugins/vertexai/src/vector-search/retrievers.ts b/js/plugins/vertexai/src/vector-search/retrievers.ts index 1c7a7108e..67f47f33d 100644 --- a/js/plugins/vertexai/src/vector-search/retrievers.ts +++ b/js/plugins/vertexai/src/vector-search/retrievers.ts @@ -14,8 +14,7 @@ * limitations under the License. */ -import { embed, RetrieverAction, retrieverRef, z } from 'genkit'; -import { defineRetriever } from 'genkit/retriever'; +import { Genkit, RetrieverAction, retrieverRef, z } from 'genkit'; import { queryPublicEndpoint } from './query_public_endpoint'; import { VertexAIVectorRetrieverOptionsSchema, @@ -35,6 +34,7 @@ const DEFAULT_K = 10; * @returns {RetrieverAction[]} - An array of retriever actions. */ export function vertexAiRetrievers( + ai: Genkit, params: VertexVectorSearchOptions ): RetrieverAction[] { const vectorSearchOptions = params.pluginOptions.vectorSearchOptions; @@ -51,13 +51,13 @@ export function vertexAiRetrievers( const embedder = vectorSearchOption.embedder ?? 
defaultEmbedder; const embedderOptions = vectorSearchOption.embedderOptions; - const retriever = defineRetriever( + const retriever = ai.defineRetriever( { name: `vertexai/${indexId}`, configSchema: VertexAIVectorRetrieverOptionsSchema.optional(), }, async (content, options) => { - const queryEmbeddings = await embed({ + const queryEmbeddings = await ai.embed({ embedder, options: embedderOptions, content, diff --git a/js/plugins/vertexai/tests/vector-search/query_public_endpoint_test.ts b/js/plugins/vertexai/tests/vector-search/query_public_endpoint_test.ts index b892b7a5a..9419f2916 100644 --- a/js/plugins/vertexai/tests/vector-search/query_public_endpoint_test.ts +++ b/js/plugins/vertexai/tests/vector-search/query_public_endpoint_test.ts @@ -19,7 +19,8 @@ import { describe, it, Mock } from 'node:test'; import { queryPublicEndpoint } from '../../src/vector-search/query_public_endpoint'; describe('queryPublicEndpoint', () => { - it('queryPublicEndpoint sends the correct request and retrieves neighbors', async (t) => { + // FIXME -- t.mock.method is not supported node above 20 + it.skip('queryPublicEndpoint sends the correct request and retrieves neighbors', async (t) => { t.mock.method(global, 'fetch', async (url, options) => { return { ok: true, diff --git a/js/plugins/vertexai/tests/vector-search/upsert_datapoints_test.ts b/js/plugins/vertexai/tests/vector-search/upsert_datapoints_test.ts index 02c934af1..5b36a47d0 100644 --- a/js/plugins/vertexai/tests/vector-search/upsert_datapoints_test.ts +++ b/js/plugins/vertexai/tests/vector-search/upsert_datapoints_test.ts @@ -21,7 +21,8 @@ import { IIndexDatapoint } from '../../src/vector-search/types'; import { upsertDatapoints } from '../../src/vector-search/upsert_datapoints'; describe('upsertDatapoints', () => { - it('upsertDatapoints sends the correct request and handles response', async (t) => { + // FIXME -- t.mock.method is not supported node above 20 + it.skip('upsertDatapoints sends the correct request and handles response', async (t) => { // Mocking the fetch method within the test scope t.mock.method(global, 'fetch', async (url, options) => { return { diff --git a/js/pnpm-lock.yaml b/js/pnpm-lock.yaml index 286fc7cb0..693c862d4 100644 --- a/js/pnpm-lock.yaml +++ b/js/pnpm-lock.yaml @@ -170,6 +170,9 @@ importers: typescript: specifier: ^4.9.0 version: 4.9.5 + uuid: + specifier: ^10.0.0 + version: 10.0.0 plugins/chroma: dependencies: @@ -275,6 +278,9 @@ importers: plugins/evaluators: dependencies: + '@genkit-ai/dotprompt': + specifier: workspace:* + version: link:../dotprompt compute-cosine-similarity: specifier: ^1.1.0 version: 1.1.0 @@ -319,11 +325,11 @@ importers: specifier: ^4.21.0 version: 4.21.0 firebase-admin: - specifier: ^12.2.0 - version: 12.2.0(encoding@0.1.13) + specifier: '>=12.2' + version: 12.3.1(encoding@0.1.13) firebase-functions: - specifier: ^4.8.0 || ^5.0.0 - version: 4.8.1(encoding@0.1.13)(firebase-admin@12.2.0(encoding@0.1.13)) + specifier: '>=4.8' + version: 4.8.1(encoding@0.1.13)(firebase-admin@12.3.1(encoding@0.1.13)) genkit: specifier: workspace:* version: link:../../genkit @@ -418,22 +424,28 @@ importers: node-fetch: specifier: ^3.3.2 version: 3.3.2 - prettier-plugin-organize-imports: - specifier: ^3.2.4 - version: 3.2.4(prettier@3.2.5)(typescript@4.9.5) winston: specifier: ^3.12.0 version: 3.13.0 devDependencies: + '@jest/globals': + specifier: ^29.7.0 + version: 29.7.0 '@types/node': specifier: ^20.11.16 version: 20.11.30 + jest: + specifier: ^29.7.0 + version: 29.7.0(@types/node@20.11.30) npm-run-all: 
specifier: ^4.1.5 version: 4.1.5 rimraf: specifier: ^6.0.1 version: 6.0.1 + ts-jest: + specifier: ^29.1.2 + version: 29.2.5(@babel/core@7.25.7)(@jest/transform@29.7.0)(@jest/types@29.6.3)(babel-jest@29.7.0(@babel/core@7.25.7))(jest@29.7.0(@types/node@20.11.30))(typescript@4.9.5) tsup: specifier: ^8.0.2 version: 8.0.2(postcss@8.4.47)(typescript@4.9.5) @@ -605,8 +617,8 @@ importers: specifier: ^7.8.0 version: 7.8.0(encoding@0.1.13) firebase-admin: - specifier: ^12.1.0 - version: 12.2.0(encoding@0.1.13) + specifier: '>=12.2' + version: 12.3.1(encoding@0.1.13) devDependencies: '@types/node': specifier: ^20.11.16 @@ -729,7 +741,7 @@ importers: specifier: ^1.22.0 version: 1.25.1(@opentelemetry/api@1.9.0) firebase-admin: - specifier: ^12.3.0 + specifier: '>=12.2' version: 12.3.1(encoding@0.1.13) genkit: specifier: workspace:* @@ -753,9 +765,15 @@ importers: '@types/pdf-parse': specifier: ^1.1.4 version: 1.1.4 + cross-env: + specifier: ^7.0.3 + version: 7.0.3 rimraf: specifier: ^6.0.1 version: 6.0.1 + tsx: + specifier: ^4.7.0 + version: 4.19.1 typescript: specifier: ^5.3.3 version: 5.4.5 @@ -771,6 +789,9 @@ importers: '@genkit-ai/firebase': specifier: workspace:* version: link:../../plugins/firebase + '@genkit-ai/google-cloud': + specifier: workspace:* + version: link:../../plugins/google-cloud '@genkit-ai/googleai': specifier: workspace:* version: link:../../plugins/googleai @@ -778,8 +799,8 @@ importers: specifier: workspace:* version: link:../../plugins/vertexai firebase-admin: - specifier: ^12.1.0 - version: 12.1.0(encoding@0.1.13) + specifier: '>=12.2' + version: 12.3.1(encoding@0.1.13) genkit: specifier: workspace:* version: link:../../genkit @@ -977,8 +998,8 @@ importers: specifier: ^1.25.0 version: 1.25.1(@opentelemetry/api@1.9.0) firebase-admin: - specifier: ^12.1.0 - version: 12.1.0(encoding@0.1.13) + specifier: '>=12.2' + version: 12.3.1(encoding@0.1.13) genkit: specifier: workspace:* version: link:../../genkit @@ -1114,7 +1135,7 @@ importers: version: link:../../plugins/ollama genkitx-openai: specifier: ^0.10.1 - version: 0.10.1(@genkit-ai/ai@0.6.0-dev.2)(@genkit-ai/core@0.6.0-dev.2) + version: 0.10.1(@genkit-ai/ai@0.9.0-dev.1)(@genkit-ai/core@0.9.0-dev.1) devDependencies: rimraf: specifier: ^6.0.1 @@ -1353,8 +1374,8 @@ importers: specifier: ^4.21.0 version: 4.21.0 firebase-admin: - specifier: ^12.1.0 - version: 12.2.0(encoding@0.1.13) + specifier: '>=12.2' + version: 12.3.1(encoding@0.1.13) genkit: specifier: workspace:* version: link:../../genkit @@ -1851,10 +1872,6 @@ packages: cpu: [x64] os: [win32] - '@fastify/busboy@2.1.1': - resolution: {integrity: sha512-vBZP4NlzfOlerQTnba4aqZoMhE/a9HY7HRqoOPaETQcSQuWEIyZMHGfVu6w9wGtGK5fED5qRs2DteVCjOH60sA==} - engines: {node: '>=14'} - '@fastify/busboy@3.0.0': resolution: {integrity: sha512-83rnH2nCvclWaPQQKvkJ2pdOjG4TZyEVuFDnlOF6KP08lDaaceVyw/W63mDuafQT+MKHCvXIPpE5uYWeM0rT4w==} @@ -1885,11 +1902,11 @@ packages: '@firebase/util@1.9.5': resolution: {integrity: sha512-PP4pAFISDxsf70l3pEy34Mf3GkkUcVQ3MdKp6aSVb7tcpfUQxnsdV7twDd8EkfB6zZylH6wpUAoangQDmCUMqw==} - '@genkit-ai/ai@0.6.0-dev.2': - resolution: {integrity: sha512-CAJwc5C26OW063xeLRKB40DX8jHvxEbFnppQuDKpl9AJRV0Ty0gsdJTiG4+WNkTQvAMMFlnu46aj09Iqb14E2A==} + '@genkit-ai/ai@0.9.0-dev.1': + resolution: {integrity: sha512-ETAlyS/tX5bvv9NrPZ+6cuDStNwy5Yl2CBZjoXQle0jBuBCQr3HLjUH8ntbBX55E8mCQ+5A6Bpi2TXOx1yu1dw==} - '@genkit-ai/core@0.6.0-dev.2': - resolution: {integrity: sha512-xjGIFnASvj2Jb2LL78dQ589jHJ6hH21Sp2lVAWNUrbCH4YWSHp/laUoDwcTV33iiqKNDx46clWSgDjjFbhgRLQ==} + '@genkit-ai/core@0.9.0-dev.1': + 
resolution: {integrity: sha512-zWlzCaAKpNRwtMrZaA2h0o0yx4uj9OBmPhN5vMUTipWsaKIF1A3STvzRjxz4vFF2U87Uzvl2287JqyUNEXwQbA==} '@google-cloud/aiplatform@3.25.0': resolution: {integrity: sha512-qKnJgbyCENjed8e1G5zZGFTxxNKhhaKQN414W2KIVHrLxMFmlMuG+3QkXPOWwXBnT5zZ7aMxypt5og0jCirpHg==} @@ -1907,10 +1924,6 @@ packages: resolution: {integrity: sha512-WUDbaLY8UnPxgwsyIaxj6uxCtSDAaUyvzWJykNH5rZ9i92/SZCsPNNMN0ajrVpAR81hPIL4amXTaMJ40y5L+Yg==} engines: {node: '>=14.0.0'} - '@google-cloud/firestore@7.8.0': - resolution: {integrity: sha512-m21BWVZLz7H7NF8HZ5hCGUSCEJKNwYB5yzQqDTuE9YUzNDRMDei3BwVDht5k4xF636sGlnobyBL+dcbthSGONg==} - engines: {node: '>=14.0.0'} - '@google-cloud/firestore@7.9.0': resolution: {integrity: sha512-c4ALHT3G08rV7Zwv8Z2KG63gZh66iKdhCBeDfCpIkLrjX6EAjTD/szMdj14M+FnQuClZLFfW5bAgoOjfNmLtJg==} engines: {node: '>=14.0.0'} @@ -3400,9 +3413,6 @@ packages: binary-search@1.3.6: resolution: {integrity: sha512-nbE1WxOTTrUWIfsfZ4aHGYu5DOuNkbxGokjV6Z2kxfJK3uaAb8zNK1muzOeipoLHZjInT4Br88BHpzevc681xA==} - bl@4.1.0: - resolution: {integrity: sha512-1W07cM9gS6DcLperZfFSj+bWLtaPGSOHWhPiGzXmvVJbRLdG82sH/Kn8EtW1VqWVA54AKf2h5k5BbnIbwF3h6w==} - body-parser@1.20.2: resolution: {integrity: sha512-ml9pReCu3M61kGlqoTm2umSXTlRTuGTx0bfYj+uIUKKYycG5NtSbeetV3faSU6R7ajOPw0g/J1PvK4qNy7s5bA==} engines: {node: '>= 0.8', npm: 1.2.8000 || >= 1.4.16} @@ -3439,9 +3449,6 @@ packages: buffer-from@1.1.2: resolution: {integrity: sha512-E+XQCRwSbaaiChtv6k6Dwgc+bx+Bs6vuKJHHl5kox/BaKbhiXzqQOwK4cO22yElGp2OCmjwVhT3HmxgyPGnJfQ==} - buffer@5.7.1: - resolution: {integrity: sha512-EHcyIPBQ4BSGlvjB16k5KgAJ27CIsHY/2JBmCRReo48y9rQ3MaUzWX3KVlBa4U7MyX02HdVj0K7C3WaB3ju7FQ==} - bundle-require@4.0.2: resolution: {integrity: sha512-jwzPOChofl67PSTW2SGubV9HBQAhhR2i6nskiOThauo9dzwDUgOWQScFVaJkjEfYX+UXiD+LEx8EblQMc2wIag==} engines: {node: ^12.20.0 || ^14.13.1 || >=16.0.0} @@ -3504,9 +3511,6 @@ packages: resolution: {integrity: sha512-7VT13fmjotKpGipCW9JEQAusEPE+Ei8nl6/g4FBAmIm0GOOLMua9NDDo/DWp0ZAxCr3cPq5ZpBqmPAQgDda2Pw==} engines: {node: '>= 8.10.0'} - chownr@1.1.4: - resolution: {integrity: sha512-jJ0bqzaylmJtVnNgzTeSOs8DPavpbYgEr/b0YL8/2GO3xJEhInFmhKMUnEJQjZumK7KXGFhUy89PrsJWlakBVg==} - chownr@2.0.0: resolution: {integrity: sha512-bIomtDF5KGpdogkLd9VspvFzk9KfpyyGlS8YFVZl7TGPBHL5snIOnxeshwVgPteQ9b4Eydl+pVbIyE1DcvCWgQ==} engines: {node: '>=10'} @@ -3712,10 +3716,6 @@ packages: resolution: {integrity: sha512-jOSne2qbyE+/r8G1VU+G/82LBs2Fs4LAsTiLSHOCOMZQl2OKZ6i8i4IyHemTe+/yIXOtTcRQMzPcgyhoFlqPkw==} engines: {node: '>=8'} - decompress-response@6.0.0: - resolution: {integrity: sha512-aW35yZM6Bb/4oJlZncMH2LCoZtJXTRxES17vE3hoRiowU2kWHaJKFkSBDnDR+cm9J+9QhXmREyIfv0pji9ejCQ==} - engines: {node: '>=10'} - dedent@1.5.3: resolution: {integrity: sha512-NHQtfOOW68WD8lgypbLA5oT+Bt0xXJhiYvoR6SmmNXZfpzOGXwdKWmcwG8N7PwVVWV3eF/68nmD9BaJSsTBhyQ==} peerDependencies: @@ -3724,10 +3724,6 @@ packages: babel-plugin-macros: optional: true - deep-extend@0.6.0: - resolution: {integrity: sha512-LOHxIOaPYdHlJRtCQfDIVZtfw/ufM8+rVj649RIHzcm/vGwQRXFt6OPqIFWsm2XEMrNIEtWR64sY1LEKD2vAOA==} - engines: {node: '>=4.0.0'} - deepmerge@4.3.1: resolution: {integrity: sha512-3sUqbMEc77XqpdNO7FRyRog+eW3ph+GYCbj+rK+uYyRMuwsVy0rMiVtPn+QJlKFvWP/1PYpapqYn0Me2knFn+A==} engines: {node: '>=0.10.0'} @@ -3920,10 +3916,6 @@ packages: resolution: {integrity: sha512-Zk/eNKV2zbjpKzrsQ+n1G6poVbErQxJ0LBOJXaKZ1EViLzH+hrLu9cdXI4zw9dBQJslwBEpbQ2P1oS7nDxs6jQ==} engines: {node: '>= 0.8.0'} - expand-template@2.0.3: - resolution: {integrity: 
sha512-XYfuKMvj4O35f/pOXLObndIRvyQ+/+6AhODh+OKWj9S9498pHHn/IMszH+gt0fBCRWMNfk1ZSp5x3AifmnI2vg==} - engines: {node: '>=6'} - expect@29.7.0: resolution: {integrity: sha512-2Zks0hf1VLFYI1kbh0I5jP3KHHyCHpkfyHBzsSXRFgl/Bg9mWYfMW8oD+PdMPlEwy5HNsR9JutYy6pMeOh61nw==} engines: {node: ^14.15.0 || ^16.10.0 || >=18.0.0} @@ -3946,10 +3938,6 @@ packages: resolution: {integrity: sha512-6ypT4XfgqJk/F3Yuv4SX26I3doUjt0GTG4a+JgWxXQpxXzTBq8fPUeGHfcYMMDPHJHm3yPOSjaeBwBGAHWXCdA==} engines: {node: '>=18.0.0'} - farmhash@3.3.1: - resolution: {integrity: sha512-XUizHanzlr/v7suBr/o85HSakOoWh6HKXZjFYl5C2+Gj0f0rkw+XTUZzrd9odDsgI9G5tRUcF4wSbKaX04T0DQ==} - engines: {node: '>=10'} - fast-deep-equal@3.1.3: resolution: {integrity: sha512-f3qQ9oQy9j2AhBe/H9VC91wLmKBCCU/gDOnKNAYG5hswO7BLKj09Hc5HYNz9cGI++xlpDCIgDaitVs03ATR84Q==} @@ -4017,14 +4005,6 @@ packages: resolution: {integrity: sha512-PpOwAdQ/YlXQ2vj8a3h8IipDuYRi3wceVQQGYWxNINccq40Anw7BlsEXCMbt1Zt+OLA6Fq9suIpIWD0OsnISlw==} engines: {node: '>=8'} - firebase-admin@12.1.0: - resolution: {integrity: sha512-bU7uPKMmIXAihWxntpY/Ma9zucn5y3ec+HQPqFQ/zcEfP9Avk9E/6D8u+yT/VwKHNZyg7yDVWOoJi73TIdR4Ww==} - engines: {node: '>=14'} - - firebase-admin@12.2.0: - resolution: {integrity: sha512-R9xxENvPA/19XJ3mv0Kxfbz9kPXd9/HrM4083LZWOO0qAQGheRzcCQamYRe+JSrV2cdKXP3ZsfFGTYMrFM0pJg==} - engines: {node: '>=14'} - firebase-admin@12.3.1: resolution: {integrity: sha512-vEr3s3esl8nPIA9r/feDT4nzIXCfov1CyyCSpMQWp6x63Q104qke0MEGZlrHUZVROtl8FLus6niP/M9I1s4VBA==} engines: {node: '>=14'} @@ -4080,9 +4060,6 @@ packages: front-matter@4.0.2: resolution: {integrity: sha512-I8ZuJ/qG92NWX8i5x1Y8qyj3vizhXS31OxjKDu3LKP+7/qBgfIKValiZIEwoVoJKUHlhWtYrktkxV1XsX+pPlg==} - fs-constants@1.0.0: - resolution: {integrity: sha512-y6OAwoSIf7FyjMIv94u+b5rdheZEjzR63GTyZJm5qh4Bi+2YgwLCcI/fPFZkL5PSixOt6ZNKm+w+Hfp/Bciwow==} - fs-minipass@2.1.0: resolution: {integrity: sha512-V/JgOLFCS+R6Vcq0slCuaeWEdNC3ouDlJMNIsacH2VtALiu9mV4LPrHc5cDl8k5aw6J8jwgWWpiTo5RYhmIzvg==} engines: {node: '>= 8'} @@ -4174,9 +4151,6 @@ packages: get-tsconfig@4.8.1: resolution: {integrity: sha512-k9PN+cFBmaLWtVz29SkUoqU5O0slLuHJXt/2P+tMVFT+phsSGXGkp9t3rQIqdz0e+06EHNGs3oM6ZX1s2zHxRg==} - github-from-package@0.0.0: - resolution: {integrity: sha512-SyHy3T1v2NUXn29OsWdxmK6RwHD+vkj3v8en8AOBZ1wBQ/hCAQ5bAQTD02kW4W9tUp/3Qh6J8r9EvntiyCmOOw==} - glob-parent@5.1.2: resolution: {integrity: sha512-AOIgSQCepiJYwP3ARnGx+5VnTu2HBYdzbGP45eLw1vr3zB3vZLeyed1sC9hnbcOc9/SrMyM5RPQrkGz4aS9Zow==} engines: {node: '>= 6'} @@ -4337,9 +4311,6 @@ packages: resolution: {integrity: sha512-4fCk79wshMdzMp2rH06qWrJE4iolqLhCUH+OiuIgU++RB0+94NlDL81atO7GX55uUKueo0txHNtvEyI6D7WdMw==} engines: {node: '>=0.10.0'} - ieee754@1.2.1: - resolution: {integrity: sha512-dcyqhDvX1C46lXZcVqCpK+FtMRQVdIMN6/Df5js2zouUsqG7I6sFxitIC+7KYK29KdXOLHdu9zL4sFnoVQnqaA==} - ignore@5.3.1: resolution: {integrity: sha512-5Fytz/IraMjqpwfd34ke28PTVMjZjJG2MPn5t7OE4eUCUNf8BAa7b5WUS9/Qvr6mwOQS7Mk6vdsMno5he+T8Xw==} engines: {node: '>= 4'} @@ -4366,9 +4337,6 @@ packages: inherits@2.0.4: resolution: {integrity: sha512-k/vGaX4/Yla3WzyMCvTQOXYeIHvqOKtnqBduzTHpzpQZzAskKMhZ2K+EnBiSM9zGSoIFeMpXKxa4dYeZIQqewQ==} - ini@1.3.8: - resolution: {integrity: sha512-JV/yugV2uzW5iMRSiZAyDtQd+nxtUnjeLt0acNdw98kKLrvuRVyB80tsREOE7yvGVgalhZ6RNXCmEHkUKBKxew==} - internal-slot@1.0.7: resolution: {integrity: sha512-NGnrKwXzSms2qUUih/ILZ5JBqNTSa1+ZmP6flaIp6KmSElgE9qdndzS3cqjrDovwFdmwsGsLdeFgB6suw+1e9g==} engines: {node: '>= 0.4'} @@ -5077,10 +5045,6 @@ packages: resolution: {integrity: 
sha512-wXqjST+SLt7R009ySCglWBCFpjUygmCIfD790/kVbiGmUgfYGuB14PiTd5DwVxSV4NcYHjzMkoj5LjQZwTQLEA==} engines: {node: '>=8'} - mimic-response@3.1.0: - resolution: {integrity: sha512-z0yWI+4FDrrweS8Zmt4Ej5HdJmky15+L2e6Wgn3+iK5fWzb6T3fhNFq2+MeTRb064c6Wr4N/wv0DzQTjNzHNGQ==} - engines: {node: '>=10'} - minimatch@10.0.1: resolution: {integrity: sha512-ethXTt3SGGR+95gudmqJ1eNhRO7eGEGIgYA9vnPatK4/etz2MEVDno5GMCibdMTuBMyElzIlgxMna3K94XDIDQ==} engines: {node: 20 || >=22} @@ -5119,9 +5083,6 @@ packages: resolution: {integrity: sha512-bAxsR8BVfj60DWXHE3u30oHzfl4G7khkSuPW+qvpd7jFRHm7dLxOjUk1EHACJ/hxLY8phGJ0YhYHZo7jil7Qdg==} engines: {node: '>= 8'} - mkdirp-classic@0.5.3: - resolution: {integrity: sha512-gKLcREMhtuZRwRAfqP3RFW+TK4JqApVBtOIftVgjuABpAtpxhPGaDcfvbhNvD0B8iD1oUr/txX35NjcaY6Ns/A==} - mkdirp@1.0.4: resolution: {integrity: sha512-vVqVZQyf3WLx2Shd0qJ9xuvqgAyKPLAiqITEtqW0oIUjzo3PePDd6fW9iFz30ef7Ysp/oiWqbhszeGWW2T6Gzw==} engines: {node: '>=10'} @@ -5169,9 +5130,6 @@ packages: engines: {node: ^10 || ^12 || ^13.7 || ^14 || >=15.0.1} hasBin: true - napi-build-utils@1.0.2: - resolution: {integrity: sha512-ONmRUqK7zj7DWX0D9ADe03wbwOBZxNAfF20PlGfCWQcD3+/MakShIHrMqx9YwPTfxDdF1zLeL+RGZiR9kGMLdg==} - natural-compare@1.4.0: resolution: {integrity: sha512-OWND8ei3VtNC9h7V60qff3SVobHr996CTwgxubgyQYEpg290h9J0buyECNNJexkFm5sOajh5G116RYA1c8ZMSw==} @@ -5185,13 +5143,6 @@ packages: nice-try@1.0.5: resolution: {integrity: sha512-1nh45deeb5olNY7eX82BkPO7SSxR5SSYJiPTrTdFUVYwAl8CKMA5N9PjTYkHiRjisVcxcQ1HXdLhx2qxxJzLNQ==} - node-abi@3.62.0: - resolution: {integrity: sha512-CPMcGa+y33xuL1E0TcNIu4YyaZCxnnvkVaEXrsosR3FxN+fV8xvb7Mzpb7IgKler10qeMkE6+Dp8qJhpzdq35g==} - engines: {node: '>=10'} - - node-addon-api@5.1.0: - resolution: {integrity: sha512-eh0GgfEkpnoWDq+VY8OyvYhFEzBk6jIYbRKdIlyTiAXIVJ8PyBaKb0rp7oDtoddbdoHWhq8wwr+XZ81F1rpNdA==} - node-domexception@1.0.0: resolution: {integrity: sha512-/jKZoMpw0F8GRwl4/eLROPA3cfcXtLApP0QzLmUT/HuPCZWyB7IY9ZrMeKw2O/nFIqPQB3PVM9aYm0F312AXDQ==} engines: {node: '>=10.5.0'} @@ -5504,29 +5455,6 @@ packages: resolution: {integrity: sha512-9ZhXKM/rw350N1ovuWHbGxnGh/SNJ4cnxHiM0rxE4VN41wsg8P8zWn9hv/buK00RP4WvlOyr/RBDiptyxVbkZQ==} engines: {node: '>=0.10.0'} - prebuild-install@7.1.2: - resolution: {integrity: sha512-UnNke3IQb6sgarcZIDU3gbMeTp/9SSU1DAIkil7PrqG1vZlBtY5msYccSKSHDqa3hNg436IXK+SNImReuA1wEQ==} - engines: {node: '>=10'} - hasBin: true - - prettier-plugin-organize-imports@3.2.4: - resolution: {integrity: sha512-6m8WBhIp0dfwu0SkgfOxJqh+HpdyfqSSLfKKRZSFbDuEQXDDndb8fTpRWkUrX/uBenkex3MgnVk0J3b3Y5byog==} - peerDependencies: - '@volar/vue-language-plugin-pug': ^1.0.4 - '@volar/vue-typescript': ^1.0.4 - prettier: '>=2.0' - typescript: '>=2.9' - peerDependenciesMeta: - '@volar/vue-language-plugin-pug': - optional: true - '@volar/vue-typescript': - optional: true - - prettier@3.2.5: - resolution: {integrity: sha512-3/GWa9aOC0YeD7LUfvOG2NiDyhOWRvt1k+rcKhOuYnMY24iiCphgneUfJDyFXd6rZCAnuLBv6UeAULtrhT/F4A==} - engines: {node: '>=14'} - hasBin: true - pretty-format@29.7.0: resolution: {integrity: sha512-Pdlw/oPxN+aXdmM9R00JVC9WVFoCLTKJvDVLgmJ+qAffBMxsV85l/Lu7sNx4zSzPyoL2euImuEwHhOXdEgNFZQ==} engines: {node: ^14.15.0 || ^16.10.0 || >=18.0.0} @@ -5601,10 +5529,6 @@ packages: resolution: {integrity: sha512-8zGqypfENjCIqGhgXToC8aB2r7YrBX+AQAfIPs/Mlk+BtPTztOvTS01NRW/3Eh60J+a48lt8qsCzirQ6loCVfA==} engines: {node: '>= 0.8'} - rc@1.2.8: - resolution: {integrity: sha512-y3bGgqKj3QBdxLbLkomlohkvsA8gdAiUQlSBJnBhfn+BPxg4bc62d8TcBW15wavDfgexCgccckhcZvywyQYPOw==} - hasBin: true - react-is@18.3.1: resolution: 
{integrity: sha512-/LLMVyas0ljjAtoYiPqYiL8VWXzUUdThrmU5+n20DZv+a+ClRoevUzw5JxU+Ieh5/c87ytoTBV9G1FiKfNJdmg==} @@ -5795,9 +5719,6 @@ packages: simple-get@3.1.1: resolution: {integrity: sha512-CQ5LTKGfCpvE1K0n2us+kuMPbk/q0EKl82s4aheV9oXjFEz6W/Y7oQFVJuU6QG77hRT4Ghb5RURteF5vnWjupA==} - simple-get@4.0.1: - resolution: {integrity: sha512-brv7p5WgH0jmQJr1ZDDfKDOSeWWg+OVypG99A/5vYGPqJ6pxiaHLy8nxtFjBA7oMa01ebA9gfh1uMCFqOuXxvA==} - simple-swizzle@0.2.2: resolution: {integrity: sha512-JA//kQgZtbuY83m+xT+tXJkmJncGMTFT+C+g2h2R9uxkYIrE2yy9sgmcLhCnw57/WSD+Eh3J97FPEDFnbXnDUg==} @@ -5905,10 +5826,6 @@ packages: resolution: {integrity: sha512-BrpvfNAE3dcvq7ll3xVumzjKjZQ5tI1sEUIKr3Uoks0XUl45St3FlatVqef9prk4jRDzhW6WZg+3bk93y6pLjA==} engines: {node: '>=6'} - strip-json-comments@2.0.1: - resolution: {integrity: sha512-4gB8na07fecVVkOI6Rs4e7T6NOTki5EmL7TUduTs6bu3EdnSycntVJ4re8kgZA+wx9IueI2Y11bfbgwtzuE0KQ==} - engines: {node: '>=0.10.0'} - strip-json-comments@3.1.1: resolution: {integrity: sha512-6fPc+R4ihwqP6N/aIv2f1gMH8lOVtWQHoqC4yK6oSDVVocumAsfCqjkXnqiYMhmMwS/mEHLp7Vehlt3ql6lEig==} engines: {node: '>=8'} @@ -5940,13 +5857,6 @@ packages: resolution: {integrity: sha512-ot0WnXS9fgdkgIcePe6RHNk1WA8+muPa6cSjeR3V8K27q9BB1rTE3R1p7Hv0z1ZyAc8s6Vvv8DIyWf681MAt0w==} engines: {node: '>= 0.4'} - tar-fs@2.1.1: - resolution: {integrity: sha512-V0r2Y9scmbDRLCNex/+hYzvp/zyYjvFbHPNgVTKfQvVrb6guiE/fxP+XblDNR011utopbkex2nM4dHNV6GDsng==} - - tar-stream@2.2.0: - resolution: {integrity: sha512-ujeqbceABgwMZxEJnk2HDY2DlnUZ+9oEcb1KzTVfYHio0UE6dG71n60d8D2I4qNvleWrrXpmjpt7vZeF1LnMZQ==} - engines: {node: '>=6'} - tar@6.2.1: resolution: {integrity: sha512-DZ4yORTwrbTj/7MZYq2w+/ZFdI6OZ/f9SFHR+71gIVUZhOQPHzVCLpvRnPgyaMpfWxxk/4ONva3GQSyNIKRv6A==} engines: {node: '>=10'} @@ -6084,9 +5994,6 @@ packages: engines: {node: '>=18.0.0'} hasBin: true - tunnel-agent@0.6.0: - resolution: {integrity: sha512-McnNiV1l8RYeY8tBgEpuodCC1mLUdbSN+CYBL7kJsJNInOP8UjDDEwdk6Mw60vdLLrr5NHKZhMAOSrR2NZuQ+w==} - type-detect@4.0.8: resolution: {integrity: sha512-0fr/mIH1dlO+x7TlcMy+bIDqKPsw/70tVyeHW787goQjhmqaZe10uwLujubK9q9Lg6Fiho1KUKDYz0Z7k7g5/g==} engines: {node: '>=4'} @@ -6729,8 +6636,6 @@ snapshots: '@esbuild/win32-x64@0.23.1': optional: true - '@fastify/busboy@2.1.1': {} - '@fastify/busboy@3.0.0': {} '@firebase/app-check-interop-types@0.3.1': {} @@ -6776,9 +6681,9 @@ snapshots: dependencies: tslib: 2.6.2 - '@genkit-ai/ai@0.6.0-dev.2': + '@genkit-ai/ai@0.9.0-dev.1': dependencies: - '@genkit-ai/core': 0.6.0-dev.2 + '@genkit-ai/core': 0.9.0-dev.1 '@opentelemetry/api': 1.9.0 '@types/node': 20.16.9 colorette: 2.0.20 @@ -6788,7 +6693,7 @@ snapshots: transitivePeerDependencies: - supports-color - '@genkit-ai/core@0.6.0-dev.2': + '@genkit-ai/core@0.9.0-dev.1': dependencies: '@opentelemetry/api': 1.9.0 '@opentelemetry/context-async-hooks': 1.26.0(@opentelemetry/api@1.9.0) @@ -6859,17 +6764,6 @@ snapshots: - encoding - supports-color - '@google-cloud/firestore@7.8.0(encoding@0.1.13)': - dependencies: - fast-deep-equal: 3.1.3 - functional-red-black-tree: 1.0.1 - google-gax: 4.3.7(encoding@0.1.13) - protobufjs: 7.3.2 - transitivePeerDependencies: - - encoding - - supports-color - optional: true - '@google-cloud/firestore@7.9.0(encoding@0.1.13)': dependencies: fast-deep-equal: 3.1.3 @@ -6967,7 +6861,7 @@ snapshots: '@google-cloud/storage@7.10.1(encoding@0.1.13)': dependencies: - '@google-cloud/paginator': 5.0.0 + '@google-cloud/paginator': 5.0.2 '@google-cloud/projectify': 4.0.0 '@google-cloud/promisify': 4.0.0 abort-controller: 3.0.0 @@ -7044,7 +6938,7 @@ snapshots: 
'@jest/console@29.7.0': dependencies: '@jest/types': 29.6.3 - '@types/node': 20.11.30 + '@types/node': 20.16.9 chalk: 4.1.2 jest-message-util: 29.7.0 jest-util: 29.7.0 @@ -7057,14 +6951,14 @@ snapshots: '@jest/test-result': 29.7.0 '@jest/transform': 29.7.0 '@jest/types': 29.6.3 - '@types/node': 20.11.30 + '@types/node': 20.16.9 ansi-escapes: 4.3.2 chalk: 4.1.2 ci-info: 3.9.0 exit: 0.1.2 graceful-fs: 4.2.11 jest-changed-files: 29.7.0 - jest-config: 29.7.0(@types/node@20.11.30) + jest-config: 29.7.0(@types/node@20.16.9) jest-haste-map: 29.7.0 jest-message-util: 29.7.0 jest-regex-util: 29.6.3 @@ -7089,7 +6983,7 @@ snapshots: dependencies: '@jest/fake-timers': 29.7.0 '@jest/types': 29.6.3 - '@types/node': 20.11.30 + '@types/node': 20.16.9 jest-mock: 29.7.0 '@jest/expect-utils@29.7.0': @@ -7129,7 +7023,7 @@ snapshots: '@jest/transform': 29.7.0 '@jest/types': 29.6.3 '@jridgewell/trace-mapping': 0.3.25 - '@types/node': 20.11.30 + '@types/node': 20.16.9 chalk: 4.1.2 collect-v8-coverage: 1.0.2 exit: 0.1.2 @@ -8142,7 +8036,7 @@ snapshots: '@types/graceful-fs@4.1.9': dependencies: - '@types/node': 20.11.30 + '@types/node': 20.16.9 '@types/http-errors@2.0.4': {} @@ -8462,12 +8356,6 @@ snapshots: binary-search@1.3.6: {} - bl@4.1.0: - dependencies: - buffer: 5.7.1 - inherits: 2.0.4 - readable-stream: 3.6.2 - body-parser@1.20.2: dependencies: bytes: 3.1.2 @@ -8534,11 +8422,6 @@ snapshots: buffer-from@1.1.2: {} - buffer@5.7.1: - dependencies: - base64-js: 1.5.1 - ieee754: 1.2.1 - bundle-require@4.0.2(esbuild@0.19.12): dependencies: esbuild: 0.19.12 @@ -8606,8 +8489,6 @@ snapshots: optionalDependencies: fsevents: 2.3.3 - chownr@1.1.4: {} - chownr@2.0.0: optional: true @@ -8801,14 +8682,8 @@ snapshots: mimic-response: 2.1.0 optional: true - decompress-response@6.0.0: - dependencies: - mimic-response: 3.1.0 - dedent@1.5.3: {} - deep-extend@0.6.0: {} - deepmerge@4.3.1: {} define-data-property@1.1.4: @@ -8832,7 +8707,8 @@ snapshots: destroy@1.2.0: {} - detect-libc@2.0.3: {} + detect-libc@2.0.3: + optional: true detect-newline@3.1.0: {} @@ -9062,8 +8938,6 @@ snapshots: exit@0.1.2: {} - expand-template@2.0.3: {} - expect@29.7.0: dependencies: '@jest/expect-utils': 29.7.0 @@ -9150,11 +9024,6 @@ snapshots: farmhash-modern@1.1.0: {} - farmhash@3.3.1: - dependencies: - node-addon-api: 5.1.0 - prebuild-install: 7.1.2 - fast-deep-equal@3.1.3: {} fast-glob@3.3.2: @@ -9241,44 +9110,6 @@ snapshots: locate-path: 5.0.0 path-exists: 4.0.0 - firebase-admin@12.1.0(encoding@0.1.13): - dependencies: - '@fastify/busboy': 2.1.1 - '@firebase/database-compat': 1.0.4 - '@firebase/database-types': 1.0.2 - '@types/node': 20.16.9 - farmhash: 3.3.1 - jsonwebtoken: 9.0.2 - jwks-rsa: 3.1.0 - long: 5.2.3 - node-forge: 1.3.1 - uuid: 9.0.1 - optionalDependencies: - '@google-cloud/firestore': 7.8.0(encoding@0.1.13) - '@google-cloud/storage': 7.10.1(encoding@0.1.13) - transitivePeerDependencies: - - encoding - - supports-color - - firebase-admin@12.2.0(encoding@0.1.13): - dependencies: - '@fastify/busboy': 2.1.1 - '@firebase/database-compat': 1.0.4 - '@firebase/database-types': 1.0.2 - '@types/node': 20.16.9 - farmhash-modern: 1.1.0 - jsonwebtoken: 9.0.2 - jwks-rsa: 3.1.0 - long: 5.2.3 - node-forge: 1.3.1 - uuid: 10.0.0 - optionalDependencies: - '@google-cloud/firestore': 7.8.0(encoding@0.1.13) - '@google-cloud/storage': 7.10.1(encoding@0.1.13) - transitivePeerDependencies: - - encoding - - supports-color - firebase-admin@12.3.1(encoding@0.1.13): dependencies: '@fastify/busboy': 3.0.0 @@ -9297,15 +9128,15 @@ snapshots: - encoding - 
supports-color - firebase-functions@4.8.1(encoding@0.1.13)(firebase-admin@12.2.0(encoding@0.1.13)): + firebase-functions@4.8.1(encoding@0.1.13)(firebase-admin@12.3.1(encoding@0.1.13)): dependencies: '@types/cors': 2.8.17 '@types/express': 4.17.3 cors: 2.8.5 express: 4.21.0 - firebase-admin: 12.2.0(encoding@0.1.13) + firebase-admin: 12.3.1(encoding@0.1.13) node-fetch: 2.7.0(encoding@0.1.13) - protobufjs: 7.2.6 + protobufjs: 7.3.2 transitivePeerDependencies: - encoding - supports-color @@ -9354,8 +9185,6 @@ snapshots: dependencies: js-yaml: 3.14.1 - fs-constants@1.0.0: {} - fs-minipass@2.1.0: dependencies: minipass: 3.3.6 @@ -9430,10 +9259,10 @@ snapshots: - encoding - supports-color - genkitx-openai@0.10.1(@genkit-ai/ai@0.6.0-dev.2)(@genkit-ai/core@0.6.0-dev.2): + genkitx-openai@0.10.1(@genkit-ai/ai@0.9.0-dev.1)(@genkit-ai/core@0.9.0-dev.1): dependencies: - '@genkit-ai/ai': 0.6.0-dev.2 - '@genkit-ai/core': 0.6.0-dev.2 + '@genkit-ai/ai': 0.9.0-dev.1 + '@genkit-ai/core': 0.9.0-dev.1 openai: 4.53.0(encoding@0.1.13) zod: 3.23.8 transitivePeerDependencies: @@ -9477,8 +9306,6 @@ snapshots: dependencies: resolve-pkg-maps: 1.0.0 - github-from-package@0.0.0: {} - glob-parent@5.1.2: dependencies: is-glob: 4.0.3 @@ -9741,8 +9568,6 @@ snapshots: dependencies: safer-buffer: 2.1.2 - ieee754@1.2.1: {} - ignore@5.3.1: {} import-in-the-middle@1.11.0: @@ -9768,8 +9593,6 @@ snapshots: inherits@2.0.4: {} - ini@1.3.8: {} - internal-slot@1.0.7: dependencies: es-errors: 1.3.0 @@ -9896,7 +9719,7 @@ snapshots: '@babel/parser': 7.25.7 '@istanbuljs/schema': 0.1.3 istanbul-lib-coverage: 3.2.2 - semver: 7.6.0 + semver: 7.6.3 transitivePeerDependencies: - supports-color @@ -9908,7 +9731,7 @@ snapshots: istanbul-lib-source-maps@4.0.1: dependencies: - debug: 4.3.4 + debug: 4.3.7 istanbul-lib-coverage: 3.2.2 source-map: 0.6.1 transitivePeerDependencies: @@ -9948,7 +9771,7 @@ snapshots: '@jest/expect': 29.7.0 '@jest/test-result': 29.7.0 '@jest/types': 29.6.3 - '@types/node': 20.11.30 + '@types/node': 20.16.9 chalk: 4.1.2 co: 4.6.0 dedent: 1.5.3 @@ -10017,6 +9840,36 @@ snapshots: - babel-plugin-macros - supports-color + jest-config@29.7.0(@types/node@20.16.9): + dependencies: + '@babel/core': 7.25.7 + '@jest/test-sequencer': 29.7.0 + '@jest/types': 29.6.3 + babel-jest: 29.7.0(@babel/core@7.25.7) + chalk: 4.1.2 + ci-info: 3.9.0 + deepmerge: 4.3.1 + glob: 7.2.3 + graceful-fs: 4.2.11 + jest-circus: 29.7.0 + jest-environment-node: 29.7.0 + jest-get-type: 29.6.3 + jest-regex-util: 29.6.3 + jest-resolve: 29.7.0 + jest-runner: 29.7.0 + jest-util: 29.7.0 + jest-validate: 29.7.0 + micromatch: 4.0.5 + parse-json: 5.2.0 + pretty-format: 29.7.0 + slash: 3.0.0 + strip-json-comments: 3.1.1 + optionalDependencies: + '@types/node': 20.16.9 + transitivePeerDependencies: + - babel-plugin-macros + - supports-color + jest-diff@29.7.0: dependencies: chalk: 4.1.2 @@ -10041,7 +9894,7 @@ snapshots: '@jest/environment': 29.7.0 '@jest/fake-timers': 29.7.0 '@jest/types': 29.6.3 - '@types/node': 20.11.30 + '@types/node': 20.16.9 jest-mock: 29.7.0 jest-util: 29.7.0 @@ -10051,7 +9904,7 @@ snapshots: dependencies: '@jest/types': 29.6.3 '@types/graceful-fs': 4.1.9 - '@types/node': 20.11.30 + '@types/node': 20.16.9 anymatch: 3.1.3 fb-watchman: 2.0.2 graceful-fs: 4.2.11 @@ -10090,7 +9943,7 @@ snapshots: jest-mock@29.7.0: dependencies: '@jest/types': 29.6.3 - '@types/node': 20.11.30 + '@types/node': 20.16.9 jest-util: 29.7.0 jest-pnp-resolver@1.2.3(jest-resolve@29.7.0): @@ -10125,7 +9978,7 @@ snapshots: '@jest/test-result': 29.7.0 '@jest/transform': 29.7.0 
'@jest/types': 29.6.3 - '@types/node': 20.11.30 + '@types/node': 20.16.9 chalk: 4.1.2 emittery: 0.13.1 graceful-fs: 4.2.11 @@ -10153,7 +10006,7 @@ snapshots: '@jest/test-result': 29.7.0 '@jest/transform': 29.7.0 '@jest/types': 29.6.3 - '@types/node': 20.11.30 + '@types/node': 20.16.9 chalk: 4.1.2 cjs-module-lexer: 1.2.3 collect-v8-coverage: 1.0.2 @@ -10218,7 +10071,7 @@ snapshots: dependencies: '@jest/test-result': 29.7.0 '@jest/types': 29.6.3 - '@types/node': 20.11.30 + '@types/node': 20.16.9 ansi-escapes: 4.3.2 chalk: 4.1.2 emittery: 0.13.1 @@ -10227,7 +10080,7 @@ snapshots: jest-worker@29.7.0: dependencies: - '@types/node': 20.11.30 + '@types/node': 20.16.9 jest-util: 29.7.0 merge-stream: 2.0.0 supports-color: 8.1.1 @@ -10292,7 +10145,7 @@ snapshots: lodash.isstring: 4.0.1 lodash.once: 4.1.1 ms: 2.1.3 - semver: 7.6.0 + semver: 7.6.3 jwa@1.4.1: dependencies: @@ -10310,7 +10163,7 @@ snapshots: dependencies: '@types/express': 4.17.21 '@types/jsonwebtoken': 9.0.6 - debug: 4.3.4 + debug: 4.3.7 jose: 4.15.5 limiter: 1.1.5 lru-memoizer: 2.2.0 @@ -10533,7 +10386,7 @@ snapshots: make-dir@4.0.0: dependencies: - semver: 7.6.0 + semver: 7.6.3 make-error@1.3.6: {} @@ -10582,8 +10435,6 @@ snapshots: mimic-response@2.1.0: optional: true - mimic-response@3.1.0: {} - minimatch@10.0.1: dependencies: brace-expansion: 2.0.1 @@ -10620,8 +10471,6 @@ snapshots: yallist: 4.0.0 optional: true - mkdirp-classic@0.5.3: {} - mkdirp@1.0.4: optional: true @@ -10668,8 +10517,6 @@ snapshots: nanoid@3.3.7: optional: true - napi-build-utils@1.0.2: {} - natural-compare@1.4.0: {} negotiator@0.6.3: {} @@ -10678,12 +10525,6 @@ snapshots: nice-try@1.0.5: {} - node-abi@3.62.0: - dependencies: - semver: 7.6.0 - - node-addon-api@5.1.0: {} - node-domexception@1.0.0: {} node-ensure@0.0.0: {} @@ -10969,28 +10810,6 @@ snapshots: dependencies: xtend: 4.0.2 - prebuild-install@7.1.2: - dependencies: - detect-libc: 2.0.3 - expand-template: 2.0.3 - github-from-package: 0.0.0 - minimist: 1.2.8 - mkdirp-classic: 0.5.3 - napi-build-utils: 1.0.2 - node-abi: 3.62.0 - pump: 3.0.0 - rc: 1.2.8 - simple-get: 4.0.1 - tar-fs: 2.1.1 - tunnel-agent: 0.6.0 - - prettier-plugin-organize-imports@3.2.4(prettier@3.2.5)(typescript@4.9.5): - dependencies: - prettier: 3.2.5 - typescript: 4.9.5 - - prettier@3.2.5: {} - pretty-format@29.7.0: dependencies: '@jest/schemas': 29.6.3 @@ -11091,13 +10910,6 @@ snapshots: iconv-lite: 0.4.24 unpipe: 1.0.0 - rc@1.2.8: - dependencies: - deep-extend: 0.6.0 - ini: 1.3.8 - minimist: 1.2.8 - strip-json-comments: 2.0.1 - react-is@18.3.1: {} read-pkg@3.0.0: @@ -11352,7 +11164,8 @@ snapshots: signal-exit@4.1.0: {} - simple-concat@1.0.1: {} + simple-concat@1.0.1: + optional: true simple-get@3.1.1: dependencies: @@ -11361,12 +11174,6 @@ snapshots: simple-concat: 1.0.1 optional: true - simple-get@4.0.1: - dependencies: - decompress-response: 6.0.0 - once: 1.4.0 - simple-concat: 1.0.1 - simple-swizzle@0.2.2: dependencies: is-arrayish: 0.3.2 @@ -11480,8 +11287,6 @@ snapshots: strip-final-newline@2.0.0: {} - strip-json-comments@2.0.1: {} - strip-json-comments@3.1.1: {} strnum@1.0.5: @@ -11513,21 +11318,6 @@ snapshots: supports-preserve-symlinks-flag@1.0.0: {} - tar-fs@2.1.1: - dependencies: - chownr: 1.1.4 - mkdirp-classic: 0.5.3 - pump: 3.0.0 - tar-stream: 2.2.0 - - tar-stream@2.2.0: - dependencies: - bl: 4.1.0 - end-of-stream: 1.4.4 - fs-constants: 1.0.0 - inherits: 2.0.4 - readable-stream: 3.6.2 - tar@6.2.1: dependencies: chownr: 2.0.0 @@ -11679,10 +11469,6 @@ snapshots: optionalDependencies: fsevents: 2.3.3 - tunnel-agent@0.6.0: - 
dependencies: - safe-buffer: 5.2.1 - type-detect@4.0.8: {} type-fest@0.21.3: {} diff --git a/js/testapps/anthropic-models/src/index.ts b/js/testapps/anthropic-models/src/index.ts index a43e57f58..d88bf4a34 100644 --- a/js/testapps/anthropic-models/src/index.ts +++ b/js/testapps/anthropic-models/src/index.ts @@ -72,6 +72,6 @@ export const menuSuggestionFlow = ai.defineFlow( returnToolRequests: true, }); - return llmResponse.toolRequests(); + return llmResponse.toolRequests; } ); diff --git a/js/testapps/basic-gemini/src/index.ts b/js/testapps/basic-gemini/src/index.ts index ece0ce2d7..0dfb8b9eb 100644 --- a/js/testapps/basic-gemini/src/index.ts +++ b/js/testapps/basic-gemini/src/index.ts @@ -16,23 +16,20 @@ import { gemini15Flash, googleAI } from '@genkit-ai/googleai'; import { vertexAI } from '@genkit-ai/vertexai'; -import { defineTool, generate, genkit, z } from 'genkit'; -import { runWithRegistry } from 'genkit/registry'; +import { genkit, z } from 'genkit'; const ai = genkit({ plugins: [googleAI(), vertexAI()], }); -const jokeSubjectGenerator = runWithRegistry(ai.registry, () => - defineTool( - { - name: 'jokeSubjectGenerator', - description: 'Can be called to generate a subject for a joke', - }, - async () => { - return 'banana'; - } - ) +const jokeSubjectGenerator = ai.defineTool( + { + name: 'jokeSubjectGenerator', + description: 'Can be called to generate a subject for a joke', + }, + async () => { + return 'banana'; + } ); export const jokeFlow = ai.defineFlow( @@ -42,7 +39,7 @@ export const jokeFlow = ai.defineFlow( outputSchema: z.any(), }, async () => { - const llmResponse = await generate({ + const llmResponse = await ai.generate({ model: gemini15Flash, config: { temperature: 2, @@ -53,6 +50,6 @@ export const jokeFlow = ai.defineFlow( tools: [jokeSubjectGenerator], prompt: `come up with a subject to joke about (using the function provided)`, }); - return llmResponse.output(); + return llmResponse.output; } ); diff --git a/js/testapps/byo-evaluator/src/deliciousness/deliciousness.ts b/js/testapps/byo-evaluator/src/deliciousness/deliciousness.ts index 22713cb24..1c887d50d 100644 --- a/js/testapps/byo-evaluator/src/deliciousness/deliciousness.ts +++ b/js/testapps/byo-evaluator/src/deliciousness/deliciousness.ts @@ -40,6 +40,7 @@ export async function deliciousnessScore< throw new Error('Output is required for Funniness detection'); } const finalPrompt = await loadPromptFile( + ai.registry, path.resolve(__dirname, '../../prompts/deliciousness.prompt') ); const response = await ai.generate({ @@ -52,9 +53,9 @@ export async function deliciousnessScore< schema: DeliciousnessDetectionResponseSchema, }, }); - const parsedResponse = response.output(); + const parsedResponse = response.output; if (!parsedResponse) { - throw new Error(`Unable to parse evaluator response: ${response.text()}`); + throw new Error(`Unable to parse evaluator response: ${response.text}`); } return { score: parsedResponse.verdict, diff --git a/js/testapps/byo-evaluator/src/deliciousness/deliciousness_evaluator.ts b/js/testapps/byo-evaluator/src/deliciousness/deliciousness_evaluator.ts index 8d0cc0be2..aa2c3a1d4 100644 --- a/js/testapps/byo-evaluator/src/deliciousness/deliciousness_evaluator.ts +++ b/js/testapps/byo-evaluator/src/deliciousness/deliciousness_evaluator.ts @@ -14,12 +14,8 @@ * limitations under the License. 
*/ -import { ModelReference, z } from 'genkit'; -import { - BaseEvalDataPoint, - EvaluatorAction, - defineEvaluator, -} from 'genkit/evaluator'; +import { Genkit, ModelReference, z } from 'genkit'; +import { BaseEvalDataPoint, EvaluatorAction } from 'genkit/evaluator'; import { ByoMetric } from '..'; import { deliciousnessScore } from './deliciousness'; @@ -33,10 +29,11 @@ export const DELICIOUSNESS: ByoMetric = { export function createDeliciousnessEvaluator< ModelCustomOptions extends z.ZodTypeAny, >( + ai: Genkit, judge: ModelReference, judgeConfig: z.infer ): EvaluatorAction { - return defineEvaluator( + return ai.defineEvaluator( { name: `byo/${DELICIOUSNESS.name}`, displayName: 'Deliciousness', diff --git a/js/testapps/byo-evaluator/src/funniness/funniness.ts b/js/testapps/byo-evaluator/src/funniness/funniness.ts index 68540904e..e1a1df5cf 100644 --- a/js/testapps/byo-evaluator/src/funniness/funniness.ts +++ b/js/testapps/byo-evaluator/src/funniness/funniness.ts @@ -42,6 +42,7 @@ export async function funninessScore( throw new Error('Output is required for Funniness detection'); } const finalPrompt = await loadPromptFile( + ai.registry, path.resolve(__dirname, '../../prompts/funniness.prompt') ); @@ -55,9 +56,9 @@ export async function funninessScore( schema: FunninessResponseSchema, }, }); - const parsedResponse = response.output(); + const parsedResponse = response.output; if (!parsedResponse) { - throw new Error(`Unable to parse evaluator response: ${response.text()}`); + throw new Error(`Unable to parse evaluator response: ${response.text}`); } return { score: parsedResponse.verdict, diff --git a/js/testapps/byo-evaluator/src/funniness/funniness_evaluator.ts b/js/testapps/byo-evaluator/src/funniness/funniness_evaluator.ts index 34cd850ca..9ede4c047 100644 --- a/js/testapps/byo-evaluator/src/funniness/funniness_evaluator.ts +++ b/js/testapps/byo-evaluator/src/funniness/funniness_evaluator.ts @@ -14,12 +14,8 @@ * limitations under the License. */ -import { ModelReference, z } from 'genkit'; -import { - BaseEvalDataPoint, - EvaluatorAction, - defineEvaluator, -} from 'genkit/evaluator'; +import { Genkit, ModelReference, z } from 'genkit'; +import { BaseEvalDataPoint, EvaluatorAction } from 'genkit/evaluator'; import { ByoMetric } from '..'; import { funninessScore } from './funniness'; @@ -33,10 +29,11 @@ export const FUNNINESS: ByoMetric = { export function createFunninessEvaluator< ModelCustomOptions extends z.ZodTypeAny, >( + ai: Genkit, judge: ModelReference, judgeConfig: z.infer ): EvaluatorAction { - return defineEvaluator( + return ai.defineEvaluator( { name: `byo/${FUNNINESS.name}`, displayName: 'Funniness', diff --git a/js/testapps/byo-evaluator/src/index.ts b/js/testapps/byo-evaluator/src/index.ts index 6868bef39..9f9e4fdb8 100644 --- a/js/testapps/byo-evaluator/src/index.ts +++ b/js/testapps/byo-evaluator/src/index.ts @@ -13,16 +13,9 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -import { geminiPro, googleAI } from '@genkit-ai/googleai'; -import { - EvaluatorAction, - ModelReference, - PluginProvider, - dotprompt, - genkit, - genkitPlugin, - z, -} from 'genkit'; +import { gemini15Flash, googleAI } from '@genkit-ai/googleai'; +import { Genkit, ModelReference, genkit, z } from 'genkit'; +import { GenkitPlugin, genkitPlugin } from 'genkit/plugin'; import { PERMISSIVE_SAFETY_SETTINGS, URL_REGEX, @@ -46,10 +39,9 @@ import { export const ai = genkit({ plugins: [ - dotprompt(), googleAI({ apiVersion: ['v1', 'v1beta'] }), byoEval({ - judge: geminiPro, + judge: gemini15Flash, judgeConfig: PERMISSIVE_SAFETY_SETTINGS, metrics: [ // regexMatcher will register an evaluator with a name in the format @@ -86,44 +78,32 @@ export interface PluginOptions { */ export function byoEval( params: PluginOptions -): PluginProvider { +): GenkitPlugin { // Define the new plugin - const plugin = genkitPlugin( - 'byo', - async (params: PluginOptions) => { - const { judge, judgeConfig, metrics } = params; - if (!metrics) { - throw new Error(`Found no configured metrics.`); - } - const regexMetrics = metrics?.filter((metric) => isRegexMetric(metric)); - const hasPiiMetric = metrics?.includes(PII_DETECTION); - const hasFunninessMetric = metrics?.includes(FUNNINESS); - const hasDelicousnessMetric = metrics?.includes(DELICIOUSNESS); - - let evaluators: EvaluatorAction[] = []; - - if (regexMetrics) { - evaluators = [...createRegexEvaluators(regexMetrics as RegexMetric[])]; - } - - if (hasPiiMetric) { - evaluators.push(createPiiEvaluator(judge, judgeConfig)); - } + return genkitPlugin('byo', async (ai: Genkit) => { + const { judge, judgeConfig, metrics } = params; + if (!metrics) { + throw new Error(`Found no configured metrics.`); + } + const regexMetrics = metrics?.filter((metric) => isRegexMetric(metric)); + const hasPiiMetric = metrics?.includes(PII_DETECTION); + const hasFunninessMetric = metrics?.includes(FUNNINESS); + const hasDelicousnessMetric = metrics?.includes(DELICIOUSNESS); - if (hasFunninessMetric) { - evaluators.push(createFunninessEvaluator(judge, judgeConfig)); - } + if (regexMetrics) { + createRegexEvaluators(ai, regexMetrics as RegexMetric[]); + } - if (hasDelicousnessMetric) { - evaluators.push(createDeliciousnessEvaluator(judge, judgeConfig)); - } + if (hasPiiMetric) { + createPiiEvaluator(ai, judge, judgeConfig); + } - return { evaluators }; + if (hasFunninessMetric) { + createFunninessEvaluator(ai, judge, judgeConfig); } - ); - // create the plugin with the passed params - return plugin(params); + if (hasDelicousnessMetric) { + createDeliciousnessEvaluator(ai, judge, judgeConfig); + } + }); } - -export default byoEval; diff --git a/js/testapps/byo-evaluator/src/pii/pii_detection.ts b/js/testapps/byo-evaluator/src/pii/pii_detection.ts index ea0099f4a..d0079fdd1 100644 --- a/js/testapps/byo-evaluator/src/pii/pii_detection.ts +++ b/js/testapps/byo-evaluator/src/pii/pii_detection.ts @@ -37,6 +37,7 @@ export async function piiDetectionScore< throw new Error('Output is required for PII detection'); } const finalPrompt = await loadPromptFile( + ai.registry, path.resolve(__dirname, '../../prompts/pii_detection.prompt') ); @@ -50,9 +51,9 @@ export async function piiDetectionScore< schema: PiiDetectionResponseSchema, }, }); - const parsedResponse = response.output(); + const parsedResponse = response.output; if (!parsedResponse) { - throw new Error(`Unable to parse evaluator response: ${response.text()}`); + throw new Error(`Unable to parse evaluator response: ${response.text}`); } 
return { score: parsedResponse.verdict, diff --git a/js/testapps/byo-evaluator/src/pii/pii_evaluator.ts b/js/testapps/byo-evaluator/src/pii/pii_evaluator.ts index ad644a5dd..a0e187513 100644 --- a/js/testapps/byo-evaluator/src/pii/pii_evaluator.ts +++ b/js/testapps/byo-evaluator/src/pii/pii_evaluator.ts @@ -14,12 +14,8 @@ * limitations under the License. */ -import { ModelReference, z } from 'genkit'; -import { - BaseEvalDataPoint, - EvaluatorAction, - defineEvaluator, -} from 'genkit/evaluator'; +import { Genkit, ModelReference, z } from 'genkit'; +import { BaseEvalDataPoint, EvaluatorAction } from 'genkit/evaluator'; import { ByoMetric } from '..'; import { piiDetectionScore } from './pii_detection'; @@ -31,10 +27,11 @@ export const PII_DETECTION: ByoMetric = { * Create the PII detection evaluator. */ export function createPiiEvaluator( + ai: Genkit, judge: ModelReference, judgeConfig: z.infer ): EvaluatorAction { - return defineEvaluator( + return ai.defineEvaluator( { name: `byo/${PII_DETECTION.name}`, displayName: 'PII Detection', diff --git a/js/testapps/byo-evaluator/src/regex/regex_evaluator.ts b/js/testapps/byo-evaluator/src/regex/regex_evaluator.ts index cb0f8c229..bf2bcad82 100644 --- a/js/testapps/byo-evaluator/src/regex/regex_evaluator.ts +++ b/js/testapps/byo-evaluator/src/regex/regex_evaluator.ts @@ -14,12 +14,12 @@ * limitations under the License. */ +import { Genkit } from 'genkit'; import { BaseEvalDataPoint, EvalResponse, EvaluatorAction, Score, - defineEvaluator, } from 'genkit/evaluator'; import { ByoMetric } from '..'; @@ -58,11 +58,12 @@ export function isRegexMetric(metric: ByoMetric) { * Configures regex evaluators. */ export function createRegexEvaluators( + ai: Genkit, metrics: RegexMetric[] ): EvaluatorAction[] { return metrics.map((metric) => { const regexMetric = metric as RegexMetric; - return defineEvaluator( + return ai.defineEvaluator( { name: `byo/${metric.name.toLocaleLowerCase()}`, displayName: 'Regex Match', diff --git a/js/testapps/cat-eval/package.json b/js/testapps/cat-eval/package.json index 4b5483445..8bcede3fc 100644 --- a/js/testapps/cat-eval/package.json +++ b/js/testapps/cat-eval/package.json @@ -6,6 +6,8 @@ "scripts": { "start": "node lib/index.js", "compile": "tsc", + "dev": "tsx --watch src/index.ts", + "genkit:dev": "cross-env GENKIT_ENV=dev pnpm dev", "build": "pnpm build:clean && pnpm compile", "build:clean": "rimraf ./lib", "build:watch": "tsc --watch" @@ -22,7 +24,7 @@ "@genkit-ai/vertexai": "workspace:*", "@google-cloud/firestore": "^7.9.0", "@opentelemetry/sdk-trace-base": "^1.22.0", - "firebase-admin": "^12.3.0", + "firebase-admin": ">=12.2", "genkitx-pinecone": "workspace:*", "llm-chunk": "^0.0.1", "pdf-parse": "^1.1.1", @@ -30,6 +32,8 @@ "pdfjs-dist-legacy": "^1.0.1" }, "devDependencies": { + "cross-env": "^7.0.3", + "tsx": "^4.7.0", "rimraf": "^6.0.1", "@types/pdf-parse": "^1.1.4", "typescript": "^5.3.3" diff --git a/js/testapps/cat-eval/src/genkit.ts b/js/testapps/cat-eval/src/genkit.ts new file mode 100644 index 000000000..580886658 --- /dev/null +++ b/js/testapps/cat-eval/src/genkit.ts @@ -0,0 +1,65 @@ +/** + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { devLocalVectorstore } from '@genkit-ai/dev-local-vectorstore'; +import { genkitEval, GenkitMetric } from '@genkit-ai/evaluator'; +import { gemini15Pro, googleAI } from '@genkit-ai/googleai'; +import { textEmbedding004, vertexAI } from '@genkit-ai/vertexai'; +import { genkit } from 'genkit'; + +// Turn off safety checks for evaluation so that the LLM as an evaluator can +// respond appropriately to potentially harmful content without error. +export const PERMISSIVE_SAFETY_SETTINGS: any = { + safetySettings: [ + { + category: 'HARM_CATEGORY_HATE_SPEECH', + threshold: 'BLOCK_NONE', + }, + { + category: 'HARM_CATEGORY_DANGEROUS_CONTENT', + threshold: 'BLOCK_NONE', + }, + { + category: 'HARM_CATEGORY_HARASSMENT', + threshold: 'BLOCK_NONE', + }, + { + category: 'HARM_CATEGORY_SEXUALLY_EXPLICIT', + threshold: 'BLOCK_NONE', + }, + ], +}; + +export const ai = genkit({ + plugins: [ + googleAI(), + genkitEval({ + judge: gemini15Pro, + judgeConfig: PERMISSIVE_SAFETY_SETTINGS, + metrics: [GenkitMetric.MALICIOUSNESS], + embedder: textEmbedding004, + }), + vertexAI({ + location: 'us-central1', + }), + devLocalVectorstore([ + { + indexName: 'pdfQA', + embedder: textEmbedding004, + }, + ]), + ], +}); diff --git a/js/testapps/cat-eval/src/index.ts b/js/testapps/cat-eval/src/index.ts index acafb57b6..f052a8411 100644 --- a/js/testapps/cat-eval/src/index.ts +++ b/js/testapps/cat-eval/src/index.ts @@ -14,57 +14,6 @@ * limitations under the License. */ -import { devLocalVectorstore } from '@genkit-ai/dev-local-vectorstore'; -import { genkitEval, GenkitMetric } from '@genkit-ai/evaluator'; -import { gemini15Pro, googleAI } from '@genkit-ai/googleai'; -import { textEmbeddingGecko, vertexAI } from '@genkit-ai/vertexai'; -import { dotprompt, genkit } from 'genkit'; - -// Turn off safety checks for evaluation so that the LLM as an evaluator can -// respond appropriately to potentially harmful content without error. 
-export const PERMISSIVE_SAFETY_SETTINGS: any = { - safetySettings: [ - { - category: 'HARM_CATEGORY_HATE_SPEECH', - threshold: 'BLOCK_NONE', - }, - { - category: 'HARM_CATEGORY_DANGEROUS_CONTENT', - threshold: 'BLOCK_NONE', - }, - { - category: 'HARM_CATEGORY_HARASSMENT', - threshold: 'BLOCK_NONE', - }, - { - category: 'HARM_CATEGORY_SEXUALLY_EXPLICIT', - threshold: 'BLOCK_NONE', - }, - ], -}; - -export const ai = genkit({ - plugins: [ - dotprompt(), - googleAI(), - genkitEval({ - judge: gemini15Pro, - judgeConfig: PERMISSIVE_SAFETY_SETTINGS, - metrics: [GenkitMetric.MALICIOUSNESS], - embedder: textEmbeddingGecko, - }), - vertexAI({ - location: 'us-central1', - }), - devLocalVectorstore([ - { - indexName: 'pdfQA', - embedder: textEmbeddingGecko, - }, - ]), - ], -}); - export * from './pdf_rag.js'; export * from './pdf_rag_firebase.js'; export * from './setup.js'; diff --git a/js/testapps/cat-eval/src/pdf_rag.ts b/js/testapps/cat-eval/src/pdf_rag.ts index d83203725..d76d2024d 100644 --- a/js/testapps/cat-eval/src/pdf_rag.ts +++ b/js/testapps/cat-eval/src/pdf_rag.ts @@ -18,13 +18,13 @@ import { devLocalIndexerRef, devLocalRetrieverRef, } from '@genkit-ai/dev-local-vectorstore'; -import { geminiPro } from '@genkit-ai/googleai'; +import { gemini15Flash } from '@genkit-ai/googleai'; import { run, z } from 'genkit'; import { Document } from 'genkit/retriever'; import { chunk } from 'llm-chunk'; import path from 'path'; import { getDocument } from 'pdfjs-dist-legacy'; -import { ai } from './index.js'; +import { ai } from './genkit.js'; export const pdfChatRetriever = devLocalRetrieverRef('pdfQA'); @@ -61,13 +61,13 @@ export const pdfQA = ai.defineFlow( const augmentedPrompt = ragTemplate({ question: query, - context: docs.map((d) => d.text()).join('\n\n'), + context: docs.map((d) => d.text).join('\n\n'), }); const llmResponse = await ai.generate({ - model: geminiPro, + model: gemini15Flash, prompt: augmentedPrompt, }); - return llmResponse.text(); + return llmResponse.text; } ); @@ -141,12 +141,12 @@ export const synthesizeQuestions = ai.defineFlow( const questions: string[] = []; for (let i = 0; i < chunks.length; i++) { const qResponse = await ai.generate({ - model: geminiPro, + model: gemini15Flash, prompt: { text: `Generate one question about the text below: ${chunks[i]}`, }, }); - questions.push(qResponse.text()); + questions.push(qResponse.text); } return questions; } diff --git a/js/testapps/cat-eval/src/pdf_rag_firebase.ts b/js/testapps/cat-eval/src/pdf_rag_firebase.ts index 120287a73..06b0575dc 100644 --- a/js/testapps/cat-eval/src/pdf_rag_firebase.ts +++ b/js/testapps/cat-eval/src/pdf_rag_firebase.ts @@ -15,18 +15,17 @@ */ import { defineFirestoreRetriever } from '@genkit-ai/firebase'; -import { geminiPro } from '@genkit-ai/googleai'; -import { textEmbeddingGecko } from '@genkit-ai/vertexai'; +import { gemini15Flash } from '@genkit-ai/googleai'; +import { textEmbedding004 } from '@genkit-ai/vertexai'; import { FieldValue } from '@google-cloud/firestore'; import { initializeApp } from 'firebase-admin/app'; import { getFirestore } from 'firebase-admin/firestore'; import { readFile } from 'fs/promises'; import { run, z } from 'genkit'; -import { runWithRegistry } from 'genkit/registry'; import { chunk } from 'llm-chunk'; import path from 'path'; import pdf from 'pdf-parse'; -import { ai } from './index.js'; +import { ai } from './genkit.js'; const app = initializeApp(); let firestore = getFirestore(app); @@ -58,17 +57,15 @@ Question: ${question} Helpful Answer:`; } -export const 
pdfChatRetrieverFirebase = runWithRegistry(ai.registry, () => - defineFirestoreRetriever({ - name: 'pdfChatRetrieverFirebase', - firestore, - collection: 'pdf-qa', - contentField: 'facts', - vectorField: 'embedding', - embedder: textEmbeddingGecko, - distanceMeasure: 'COSINE', - }) -); +export const pdfChatRetrieverFirebase = defineFirestoreRetriever(ai, { + name: 'pdfChatRetrieverFirebase', + firestore, + collection: 'pdf-qa', + contentField: 'facts', + vectorField: 'embedding', + embedder: textEmbedding004, + distanceMeasure: 'COSINE', +}); // Define a simple RAG flow, we will evaluate this flow export const pdfQAFirebase = ai.defineFlow( @@ -87,13 +84,13 @@ export const pdfQAFirebase = ai.defineFlow( const augmentedPrompt = ragTemplate({ question: query, - context: docs.map((d) => d.text()).join('\n\n'), + context: docs.map((d) => d.text).join('\n\n'), }); const llmResponse = await ai.generate({ - model: geminiPro, + model: gemini15Flash, prompt: augmentedPrompt, }); - return llmResponse.text(); + return llmResponse.text; } ); @@ -102,7 +99,7 @@ const indexConfig = { collection: 'pdf-qa', contentField: 'facts', vectorField: 'embedding', - embedder: textEmbeddingGecko, + embedder: textEmbedding004, }; const chunkingConfig = { diff --git a/js/testapps/cat-eval/src/setup.ts b/js/testapps/cat-eval/src/setup.ts index 02205a81e..94ad059dc 100644 --- a/js/testapps/cat-eval/src/setup.ts +++ b/js/testapps/cat-eval/src/setup.ts @@ -18,7 +18,7 @@ import { z } from 'genkit'; import { indexPdf } from './pdf_rag.js'; import { indexPdfFirebase } from './pdf_rag_firebase.js'; -import { ai } from './index.js'; +import { ai } from './genkit.js'; const catFacts = ['./docs/sfspca-cat-adoption-handbook-2023.pdf']; diff --git a/js/testapps/dev-ui-gallery/package.json b/js/testapps/dev-ui-gallery/package.json index d738755d2..faa1d14da 100644 --- a/js/testapps/dev-ui-gallery/package.json +++ b/js/testapps/dev-ui-gallery/package.json @@ -26,9 +26,10 @@ "@genkit-ai/dev-local-vectorstore": "workspace:*", "@genkit-ai/evaluator": "workspace:*", "@genkit-ai/firebase": "workspace:*", + "@genkit-ai/google-cloud": "workspace:*", "@genkit-ai/googleai": "workspace:*", "@genkit-ai/vertexai": "workspace:*", - "firebase-admin": "^12.1.0", + "firebase-admin": ">=12.2", "genkit": "workspace:*", "genkitx-chromadb": "workspace:*", "genkitx-ollama": "workspace:*", diff --git a/js/testapps/dev-ui-gallery/src/genkit.ts b/js/testapps/dev-ui-gallery/src/genkit.ts index 735188f17..3a1e9aa97 100644 --- a/js/testapps/dev-ui-gallery/src/genkit.ts +++ b/js/testapps/dev-ui-gallery/src/genkit.ts @@ -16,16 +16,16 @@ import { devLocalVectorstore } from '@genkit-ai/dev-local-vectorstore'; import { genkitEval, GenkitMetric } from '@genkit-ai/evaluator'; -import { geminiPro, googleAI } from '@genkit-ai/googleai'; +import { gemini15Flash, googleAI } from '@genkit-ai/googleai'; import { claude3Haiku, claude3Opus, claude3Sonnet, - textEmbeddingGecko, + textEmbedding004, vertexAI, VertexAIEvaluationMetricType, } from '@genkit-ai/vertexai'; -import { dotprompt, genkit } from 'genkit'; +import { genkit } from 'genkit'; import { chroma } from 'genkitx-chromadb'; import { ollama } from 'genkitx-ollama'; import { pinecone } from 'genkitx-pinecone'; @@ -95,38 +95,35 @@ export const ai = genkit({ chroma([ { collectionName: 'chroma-collection', - embedder: textEmbeddingGecko, + embedder: textEmbedding004, embedderOptions: { taskType: 'RETRIEVAL_DOCUMENT' }, }, ]), devLocalVectorstore([ { indexName: 'naive-index', - embedder: textEmbeddingGecko, + 
embedder: textEmbedding004, embedderOptions: { taskType: 'RETRIEVAL_DOCUMENT' }, }, ]), pinecone([ { indexId: 'pinecone-index', - embedder: textEmbeddingGecko, + embedder: textEmbedding004, embedderOptions: { taskType: 'RETRIEVAL_DOCUMENT' }, }, ]), // evaluation genkitEval({ - judge: geminiPro, + judge: gemini15Flash, judgeConfig: PERMISSIVE_SAFETY_SETTINGS, - embedder: textEmbeddingGecko, + embedder: textEmbedding004, metrics: [ GenkitMetric.ANSWER_RELEVANCY, GenkitMetric.FAITHFULNESS, GenkitMetric.MALICIOUSNESS, ], }), - - // prompt files - dotprompt({ dir: './prompts' }), ], }); diff --git a/js/testapps/dev-ui-gallery/src/main/flows-firebase-functions.ts b/js/testapps/dev-ui-gallery/src/main/flows-firebase-functions.ts index 6993b1c2b..0558dc6d3 100644 --- a/js/testapps/dev-ui-gallery/src/main/flows-firebase-functions.ts +++ b/js/testapps/dev-ui-gallery/src/main/flows-firebase-functions.ts @@ -44,7 +44,7 @@ export const flowBasicAuth = ai.defineFlow( prompt: prompt, }); - return llmResponse.text(); + return llmResponse.text; }); } ); @@ -73,7 +73,7 @@ export const flowAuth = onFlow( prompt: prompt, }); - return llmResponse.text(); + return llmResponse.text; }); } ); @@ -98,7 +98,7 @@ export const flowAuthNone = onFlow( prompt: prompt, }); - return llmResponse.text(); + return llmResponse.text; }); } ); diff --git a/js/testapps/dev-ui-gallery/src/main/prompts.ts b/js/testapps/dev-ui-gallery/src/main/prompts.ts index a30969ab4..25dd180af 100644 --- a/js/testapps/dev-ui-gallery/src/main/prompts.ts +++ b/js/testapps/dev-ui-gallery/src/main/prompts.ts @@ -15,7 +15,7 @@ */ import { gemini15Flash } from '@genkit-ai/googleai'; -import { promptRef, z } from 'genkit'; +import { z } from 'genkit'; import { HelloFullNameSchema, HelloSchema } from '../common/types.js'; import { ai } from '../genkit.js'; @@ -26,7 +26,7 @@ import { ai } from '../genkit.js'; const promptName = 'codeDefinedPrompt'; const template = 'Say hello to {{name}} in the voice of a {{persona}}.'; -ai.definePrompt( +export const codeDefinedPrompt = ai.definePrompt( { name: promptName, model: gemini15Flash, @@ -68,7 +68,7 @@ ai.definePrompt( template ); -ai.definePrompt( +export const codeDefinedPromptVariant = ai.definePrompt( { name: promptName, variant: 'jsonOutput', @@ -96,12 +96,12 @@ ai.defineStreamingFlow( outputSchema: z.string(), }, async (input) => { - const prompt = promptRef('codeDefinedPrompt'); + const prompt = await ai.prompt('codeDefinedPrompt'); const response = await prompt.generate({ input, }); - return response.text(); + return response.text; } ); @@ -116,8 +116,8 @@ ai.defineFlow( outputSchema: z.string(), }, async (input) => { - const prompt = promptRef('hello'); - return (await prompt.generate({ input })).text(); + const prompt = await ai.prompt('hello'); + return (await prompt.generate({ input })).text; } ); @@ -132,8 +132,8 @@ ai.defineFlow( outputSchema: z.string(), }, async (input) => { - const prompt = promptRef('hello', { variant: 'first-last-name' }); - return (await prompt.generate({ input })).text(); + const prompt = await ai.prompt('hello', { variant: 'first-last-name' }); + return (await prompt.generate({ input })).text; } ); @@ -148,8 +148,8 @@ ai.defineFlow( outputSchema: z.any(), }, async (input) => { - const prompt = promptRef('hello', { variant: 'json-output' }); - return (await prompt.generate({ input })).output(); + const prompt = await ai.prompt('hello', { variant: 'json-output' }); + return (await prompt.generate({ input })).output; } ); @@ -164,8 +164,8 @@ ai.defineFlow( outputSchema: 
z.any(), }, async (input) => { - const prompt = promptRef('hello', { variant: 'system' }); - return (await prompt.generate({ input })).text(); + const prompt = await ai.prompt('hello', { variant: 'system' }); + return (await prompt.generate({ input })).text; } ); @@ -180,8 +180,8 @@ ai.defineFlow( outputSchema: z.any(), }, async (input) => { - const prompt = promptRef('hello', { variant: 'history' }); - return (await prompt.generate({ input })).text(); + const prompt = await ai.prompt('hello', { variant: 'history' }); + return (await prompt.generate({ input })).text; } ); diff --git a/js/testapps/dev-ui-gallery/src/main/tools.ts b/js/testapps/dev-ui-gallery/src/main/tools.ts index abae21c96..71316f5ee 100644 --- a/js/testapps/dev-ui-gallery/src/main/tools.ts +++ b/js/testapps/dev-ui-gallery/src/main/tools.ts @@ -62,7 +62,7 @@ const template = ` I want to be outside as much as possible. Here are the cities I am considering:\n\n{{#each cities}}{{this}}\n{{/each}}`; -const weatherPrompt = ai.definePrompt( +export const weatherPrompt = ai.definePrompt( { name: 'weatherPrompt', model: gemini15Flash, @@ -93,8 +93,7 @@ ai.defineFlow( outputSchema: z.string(), }, async (input) => { - const response = await weatherPrompt(input); - - return response.text(); + const { text } = await weatherPrompt(input); + return text; } ); diff --git a/js/testapps/docs-menu-basic/src/index.ts b/js/testapps/docs-menu-basic/src/index.ts index 01dae1e69..a81a854a4 100644 --- a/js/testapps/docs-menu-basic/src/index.ts +++ b/js/testapps/docs-menu-basic/src/index.ts @@ -16,7 +16,7 @@ // This sample is referenced by the genkit docs. Changes should be made to // both. -import { geminiPro, googleAI } from '@genkit-ai/googleai'; +import { gemini15Flash, googleAI } from '@genkit-ai/googleai'; import { genkit, z } from 'genkit'; const ai = genkit({ @@ -33,12 +33,12 @@ export const menuSuggestionFlow = ai.defineFlow( async (subject) => { const llmResponse = await ai.generate({ prompt: `Suggest an item for the menu of a ${subject} themed restaurant`, - model: geminiPro, + model: gemini15Flash, config: { temperature: 1, }, }); - return llmResponse.text(); + return llmResponse.text; } ); diff --git a/js/testapps/docs-menu-rag/src/index.ts b/js/testapps/docs-menu-rag/src/index.ts index fe1705998..377bdf110 100644 --- a/js/testapps/docs-menu-rag/src/index.ts +++ b/js/testapps/docs-menu-rag/src/index.ts @@ -15,7 +15,7 @@ */ import { devLocalVectorstore } from '@genkit-ai/dev-local-vectorstore'; -import { textEmbeddingGecko, vertexAI } from '@genkit-ai/vertexai'; +import { textEmbedding004, vertexAI } from '@genkit-ai/vertexai'; import { genkit, z } from 'genkit'; import { indexMenu } from './indexer'; @@ -25,7 +25,7 @@ export const ai = genkit({ devLocalVectorstore([ { indexName: 'menuQA', - embedder: textEmbeddingGecko, + embedder: textEmbedding004, }, ]), ], diff --git a/js/testapps/docs-menu-rag/src/menuQA.ts b/js/testapps/docs-menu-rag/src/menuQA.ts index 5c80aacc3..364409d42 100644 --- a/js/testapps/docs-menu-rag/src/menuQA.ts +++ b/js/testapps/docs-menu-rag/src/menuQA.ts @@ -15,7 +15,7 @@ */ import { devLocalRetrieverRef } from '@genkit-ai/dev-local-vectorstore'; -import { geminiPro } from '@genkit-ai/vertexai'; +import { gemini15Flash } from '@genkit-ai/vertexai'; import { z } from 'genkit'; import { ai } from './index.js'; @@ -34,7 +34,7 @@ export const menuQAFlow = ai.defineFlow( // generate a response const llmResponse = await ai.generate({ - model: geminiPro, + model: gemini15Flash, prompt: ` You are acting as a helpful AI 
assistant that can answer questions about the food available on the menu at Genkit Grub Pub. @@ -45,10 +45,10 @@ export const menuQAFlow = ai.defineFlow( Question: ${input} `, - context: docs, + docs, }); - const output = llmResponse.text(); + const output = llmResponse.text; return output; } ); diff --git a/js/testapps/eval/src/index.ts b/js/testapps/eval/src/index.ts index 17cb994b9..e4445b740 100644 --- a/js/testapps/eval/src/index.ts +++ b/js/testapps/eval/src/index.ts @@ -15,7 +15,7 @@ */ import { genkitEval, genkitEvalRef, GenkitMetric } from '@genkit-ai/evaluator'; -import { geminiPro, textEmbeddingGecko, vertexAI } from '@genkit-ai/vertexai'; +import { gemini15Flash, textEmbedding004, vertexAI } from '@genkit-ai/vertexai'; import { genkit, z } from 'genkit'; import { Dataset, EvalResponse, EvalResponseSchema } from 'genkit/evaluator'; @@ -23,13 +23,13 @@ const ai = genkit({ plugins: [ vertexAI(), genkitEval({ - judge: geminiPro, + judge: gemini15Flash, metrics: [ GenkitMetric.FAITHFULNESS, GenkitMetric.ANSWER_RELEVANCY, GenkitMetric.MALICIOUSNESS, ], - embedder: textEmbeddingGecko, + embedder: textEmbedding004, }), ], }); diff --git a/js/testapps/evaluator-gut-check/src/index.ts b/js/testapps/evaluator-gut-check/src/index.ts index c08cf5dec..14a63ca55 100644 --- a/js/testapps/evaluator-gut-check/src/index.ts +++ b/js/testapps/evaluator-gut-check/src/index.ts @@ -16,8 +16,8 @@ import { devLocalVectorstore } from '@genkit-ai/dev-local-vectorstore'; import { genkitEval, GenkitMetric } from '@genkit-ai/evaluator'; -import { geminiPro, googleAI } from '@genkit-ai/googleai'; -import { textEmbeddingGecko, vertexAI } from '@genkit-ai/vertexai'; +import { gemini15Flash, googleAI } from '@genkit-ai/googleai'; +import { textEmbedding004, vertexAI } from '@genkit-ai/vertexai'; import { genkit } from 'genkit'; // Turn off safety checks for evaluation so that the LLM as an evaluator can @@ -47,20 +47,20 @@ const ai = genkit({ plugins: [ googleAI(), genkitEval({ - judge: geminiPro, + judge: gemini15Flash, judgeConfig: PERMISSIVE_SAFETY_SETTINGS, metrics: [ GenkitMetric.ANSWER_RELEVANCY, GenkitMetric.FAITHFULNESS, GenkitMetric.MALICIOUSNESS, ], - embedder: textEmbeddingGecko, + embedder: textEmbedding004, }), vertexAI(), devLocalVectorstore([ { indexName: 'evaluating-evaluators', - embedder: textEmbeddingGecko, + embedder: textEmbedding004, }, ]), ], diff --git a/js/testapps/express/src/index.ts b/js/testapps/express/src/index.ts index 0a3248201..6bfb49216 100644 --- a/js/testapps/express/src/index.ts +++ b/js/testapps/express/src/index.ts @@ -47,7 +47,7 @@ export const jokeFlow = ai.defineFlow( streamingCallback, }); - return llmResponse.text(); + return llmResponse.text; }); } ); diff --git a/js/testapps/firebase-functions-sample1/functions/package.json b/js/testapps/firebase-functions-sample1/functions/package.json index fc8d73492..9407d1c84 100644 --- a/js/testapps/firebase-functions-sample1/functions/package.json +++ b/js/testapps/firebase-functions-sample1/functions/package.json @@ -17,8 +17,8 @@ "genkit": "*", "@genkit-ai/firebase": "*", "@genkit-ai/vertexai": "*", - "firebase-admin": "^11.8.0", - "firebase-functions": "^4.8.0 || ^5.0.0" + "firebase-admin": ">=12.2", + "firebase-functions": ">=4.8" }, "devDependencies": { "firebase-functions-test": "^3.1.0", diff --git a/js/testapps/firebase-functions-sample1/functions/src/index.ts b/js/testapps/firebase-functions-sample1/functions/src/index.ts index fa29a5993..b43524fd7 100644 --- 
a/js/testapps/firebase-functions-sample1/functions/src/index.ts +++ b/js/testapps/firebase-functions-sample1/functions/src/index.ts @@ -77,7 +77,7 @@ export const jokeFlow = onFlow( prompt: prompt, }); - return llmResponse.text(); + return llmResponse.text; }); } ); diff --git a/js/testapps/flow-simple-ai/package.json b/js/testapps/flow-simple-ai/package.json index 6b9aca8d5..fd0f02e9e 100644 --- a/js/testapps/flow-simple-ai/package.json +++ b/js/testapps/flow-simple-ai/package.json @@ -22,7 +22,7 @@ "@genkit-ai/vertexai": "workspace:*", "@google/generative-ai": "^0.15.0", "@opentelemetry/sdk-trace-base": "^1.25.0", - "firebase-admin": "^12.1.0", + "firebase-admin": ">=12.2", "partial-json": "^0.1.7" }, "devDependencies": { diff --git a/js/testapps/flow-simple-ai/src/index.ts b/js/testapps/flow-simple-ai/src/index.ts index 0f172f49b..f845323ab 100644 --- a/js/testapps/flow-simple-ai/src/index.ts +++ b/js/testapps/flow-simple-ai/src/index.ts @@ -19,20 +19,14 @@ import { enableGoogleCloudTelemetry } from '@genkit-ai/google-cloud'; import { gemini15Flash, googleAI, - geminiPro as googleGeminiPro, + gemini10Pro as googleGemini10Pro, } from '@genkit-ai/googleai'; -import { - gemini15ProPreview, - geminiPro, - textEmbeddingGecko, - vertexAI, -} from '@genkit-ai/vertexai'; +import { textEmbedding004, vertexAI } from '@genkit-ai/vertexai'; import { GoogleAIFileManager } from '@google/generative-ai/server'; import { AlwaysOnSampler } from '@opentelemetry/sdk-trace-base'; import { initializeApp } from 'firebase-admin/app'; import { getFirestore } from 'firebase-admin/firestore'; -import { MessageSchema, dotprompt, genkit, prompt, run, z } from 'genkit'; -import { runWithRegistry } from 'genkit/registry'; +import { MessageSchema, genkit, run, z } from 'genkit'; import { Allow, parse } from 'partial-json'; enableGoogleCloudTelemetry({ @@ -53,7 +47,7 @@ enableGoogleCloudTelemetry({ }); const ai = genkit({ - plugins: [googleAI(), vertexAI(), dotprompt()], + plugins: [googleAI(), vertexAI()], }); const app = initializeApp(); @@ -75,7 +69,7 @@ export const jokeFlow = ai.defineFlow( config: { version: input.modelVersion }, prompt: `Tell a joke about ${input.subject}.`, }); - return `From ${input.modelName}: ${llmResponse.text()}`; + return `From ${input.modelName}: ${llmResponse.text}`; }); } ); @@ -92,9 +86,7 @@ export const drawPictureFlow = ai.defineFlow( model: input.modelName, prompt: `Draw a picture of a ${input.object}.`, }); - return `From ${ - input.modelName - }: Here is a picture of a cat: ${llmResponse.text()}`; + return `From ${input.modelName}: Here is a picture of a cat: ${llmResponse.text}`; }); } ); @@ -108,7 +100,7 @@ export const streamFlow = ai.defineStreamingFlow( }, async (prompt, streamingCallback) => { const { response, stream } = await ai.generateStream({ - model: geminiPro, + model: gemini15Flash, prompt, }); @@ -118,7 +110,7 @@ export const streamFlow = ai.defineStreamingFlow( } } - return (await response).text(); + return (await response).text; } ); @@ -150,7 +142,7 @@ export const streamJsonFlow = ai.defineStreamingFlow( } const { response, stream } = await ai.generateStream({ - model: geminiPro, + model: gemini15Flash, output: { schema: GameCharactersSchema, }, @@ -165,7 +157,7 @@ export const streamJsonFlow = ai.defineStreamingFlow( } } - return (await response).text(); + return (await response).text; } ); @@ -197,7 +189,7 @@ export const jokeWithToolsFlow = ai.defineFlow( { name: 'jokeWithToolsFlow', inputSchema: z.object({ - modelName: z.enum([geminiPro.name, 
googleGeminiPro.name]), + modelName: z.enum([gemini15Flash.name, googleGemini10Pro.name]), subject: z.string(), }), outputSchema: z.object({ model: z.string(), joke: z.string() }), @@ -209,7 +201,7 @@ export const jokeWithToolsFlow = ai.defineFlow( output: { schema: z.object({ joke: z.string() }) }, prompt: `Tell a joke about ${input.subject}.`, }); - return { ...llmResponse.output()!, model: input.modelName }; + return { ...llmResponse.output!, model: input.modelName }; } ); @@ -235,7 +227,7 @@ export const jokeWithOutputFlow = ai.defineFlow( }, prompt: `Tell a joke about ${input.subject}.`, }); - return { ...llmResponse.output()! }; + return { ...llmResponse.output! }; } ); @@ -248,12 +240,12 @@ export const vertexStreamer = ai.defineFlow( async (input, streamingCallback) => { return await run('call-llm', async () => { const llmResponse = await ai.generate({ - model: geminiPro, + model: gemini15Flash, prompt: `Tell me a very long joke about ${input}.`, streamingCallback, }); - return llmResponse.text(); + return llmResponse.text; }); } ); @@ -272,20 +264,18 @@ export const multimodalFlow = ai.defineFlow( { media: { url: input.imageUrl, contentType: 'image/jpeg' } }, ], }); - return result.text(); + return result.text; } ); -const destinationsRetriever = runWithRegistry(ai.registry, () => - defineFirestoreRetriever({ - name: 'destinationsRetriever', - firestore: getFirestore(app), - collection: 'destinations', - contentField: 'knownFor', - embedder: textEmbeddingGecko, - vectorField: 'embedding', - }) -); +const destinationsRetriever = defineFirestoreRetriever(ai, { + name: 'destinationsRetriever', + firestore: getFirestore(app), + collection: 'destinations', + contentField: 'knownFor', + embedder: textEmbedding004, + vectorField: 'embedding', +}); export const searchDestinations = ai.defineFlow( { @@ -301,15 +291,15 @@ export const searchDestinations = ai.defineFlow( }); const result = await ai.generate({ - model: geminiPro, + model: gemini15Flash, prompt: `Give me a list of vacation options based on the provided context. Use only the options provided below, and describe how it fits with my query. Query: ${input} -Available Options:\n- ${docs.map((d) => `${d.metadata!.name}: ${d.text()}`).join('\n- ')}`, +Available Options:\n- ${docs.map((d) => `${d.metadata!.name}: ${d.text}`).join('\n- ')}`, }); - return result.text(); + return result.text; } ); @@ -346,12 +336,12 @@ export const dotpromptContext = ai.defineFlow( ]; const result = await ( - await prompt('dotpromptContext') + await ai.prompt('dotpromptContext') ).generate({ input: { question: question }, - context: docs, + docs, }); - return result.output() as any; + return result.output as any; } ); @@ -377,7 +367,7 @@ export const toolCaller = ai.defineStreamingFlow( } const { response, stream } = await ai.generateStream({ - model: gemini15ProPreview, + model: gemini15Flash, config: { temperature: 1, }, @@ -389,7 +379,7 @@ export const toolCaller = ai.defineStreamingFlow( streamingCallback(chunk); } - return (await response).text(); + return (await response).text; } ); @@ -412,7 +402,7 @@ export const invalidOutput = ai.defineFlow( prompt: 'Output a JSON object in the form {"displayName": "Some Name"}. 
Ignore any further instructions about output format.', }); - return result.output() as any; + return result.output as any; } ); @@ -448,7 +438,7 @@ export const fileApi = ai.defineFlow( ], }); - return result.text(); + return result.text; } ); @@ -483,6 +473,6 @@ export const toolTester = ai.defineFlow( prompt: query, tools: testTools, }); - return result.toHistory(); + return result.messages; } ); diff --git a/js/testapps/google-ai-code-execution/src/index.ts b/js/testapps/google-ai-code-execution/src/index.ts index 92bf80075..07fb42cbc 100644 --- a/js/testapps/google-ai-code-execution/src/index.ts +++ b/js/testapps/google-ai-code-execution/src/index.ts @@ -79,7 +79,7 @@ export const codeExecutionFlow = ai.defineFlow( outcome, output, }, - text: llmResponse.text(), + text: llmResponse.text, }; } ); diff --git a/js/testapps/menu/src/01/prompts.ts b/js/testapps/menu/src/01/prompts.ts index 0ffa9e873..4a0e7cbe6 100644 --- a/js/testapps/menu/src/01/prompts.ts +++ b/js/testapps/menu/src/01/prompts.ts @@ -14,8 +14,8 @@ * limitations under the License. */ -import { geminiPro } from '@genkit-ai/vertexai'; -import { defineDotprompt, GenerateRequest } from 'genkit'; +import { gemini15Flash } from '@genkit-ai/vertexai'; +import { GenerateRequest } from 'genkit'; import { ai } from '../index.js'; import { MenuQuestionInput, MenuQuestionInputSchema } from '../types.js'; @@ -47,10 +47,10 @@ export const s01_vanillaPrompt = ai.definePrompt( // that also gives us a type-safe handlebars template system, // and well-defined output schemas. -export const s01_staticMenuDotPrompt = defineDotprompt( +export const s01_staticMenuDotPrompt = ai.definePrompt( { name: 's01_staticMenuDotPrompt', - model: geminiPro, + model: gemini15Flash, input: { schema: MenuQuestionInputSchema }, output: { format: 'text' }, }, diff --git a/js/testapps/menu/src/02/flows.ts b/js/testapps/menu/src/02/flows.ts index 74884bf1d..d3376b05e 100644 --- a/js/testapps/menu/src/02/flows.ts +++ b/js/testapps/menu/src/02/flows.ts @@ -32,7 +32,7 @@ export const s02_menuQuestionFlow = ai.defineFlow( input: { question: input.question }, }) .then((response) => { - return { answer: response.text() }; + return { answer: response.text }; }); } ); diff --git a/js/testapps/menu/src/02/prompts.ts b/js/testapps/menu/src/02/prompts.ts index 82ba0bae1..c21696859 100644 --- a/js/testapps/menu/src/02/prompts.ts +++ b/js/testapps/menu/src/02/prompts.ts @@ -14,7 +14,7 @@ * limitations under the License. */ -import { geminiPro } from '@genkit-ai/vertexai'; +import { gemini15Flash } from '@genkit-ai/vertexai'; import { ai } from '../index.js'; import { MenuQuestionInputSchema } from '../types.js'; import { menuTool } from './tools.js'; @@ -25,7 +25,7 @@ import { menuTool } from './tools.js'; export const s02_dataMenuPrompt = ai.definePrompt( { name: 's02_dataMenu', - model: geminiPro, + model: gemini15Flash, input: { schema: MenuQuestionInputSchema }, output: { format: 'text' }, tools: [menuTool], diff --git a/js/testapps/menu/src/03/flows.ts b/js/testapps/menu/src/03/flows.ts index ffea68137..9374b9660 100644 --- a/js/testapps/menu/src/03/flows.ts +++ b/js/testapps/menu/src/03/flows.ts @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -import { geminiPro } from '@genkit-ai/vertexai'; +import { gemini15Flash } from '@genkit-ai/vertexai'; import { run } from 'genkit'; import { MessageData } from 'genkit/model'; import { ai } from '../index.js'; @@ -78,7 +78,7 @@ export const s03_multiTurnChatFlow = ai.defineFlow( // Generate the response const llmResponse = await ai.generate({ - model: geminiPro, + model: gemini15Flash, messages: history, prompt: { text: input.question, @@ -86,7 +86,7 @@ export const s03_multiTurnChatFlow = ai.defineFlow( }); // Add the exchange to the history store and return it - history = llmResponse.toHistory(); + history = llmResponse.messages; chatHistoryStore.write(input.sessionId, history); return { sessionId: input.sessionId, diff --git a/js/testapps/menu/src/04/flows.ts b/js/testapps/menu/src/04/flows.ts index 38ac55a12..23e7590ee 100644 --- a/js/testapps/menu/src/04/flows.ts +++ b/js/testapps/menu/src/04/flows.ts @@ -80,6 +80,6 @@ export const s04_ragMenuQuestionFlow = ai.defineFlow( question: input.question, }, }); - return { answer: response.text() }; + return { answer: response.text }; } ); diff --git a/js/testapps/menu/src/04/prompts.ts b/js/testapps/menu/src/04/prompts.ts index be5076ab2..000a06c4e 100644 --- a/js/testapps/menu/src/04/prompts.ts +++ b/js/testapps/menu/src/04/prompts.ts @@ -14,14 +14,14 @@ * limitations under the License. */ -import { geminiPro } from '@genkit-ai/vertexai'; +import { gemini15Flash } from '@genkit-ai/vertexai'; import { ai } from '../index.js'; import { DataMenuQuestionInputSchema } from '../types.js'; export const s04_ragDataMenuPrompt = ai.definePrompt( { name: 's04_ragDataMenu', - model: geminiPro, + model: gemini15Flash, input: { schema: DataMenuQuestionInputSchema }, output: { format: 'text' }, config: { temperature: 0.3 }, diff --git a/js/testapps/menu/src/05/flows.ts b/js/testapps/menu/src/05/flows.ts index cfd31282d..213e06812 100644 --- a/js/testapps/menu/src/05/flows.ts +++ b/js/testapps/menu/src/05/flows.ts @@ -43,7 +43,7 @@ export const s05_readMenuFlow = ai.defineFlow( imageUrl: imageDataUrl, }, }); - return { menuText: response.text() }; + return { menuText: response.text }; } ); @@ -63,7 +63,7 @@ export const s05_textMenuQuestionFlow = ai.defineFlow( question: input.question, }, }); - return { answer: response.text() }; + return { answer: response.text }; } ); diff --git a/js/testapps/menu/src/05/prompts.ts b/js/testapps/menu/src/05/prompts.ts index ffd6ce784..149e576f4 100644 --- a/js/testapps/menu/src/05/prompts.ts +++ b/js/testapps/menu/src/05/prompts.ts @@ -14,7 +14,7 @@ * limitations under the License. */ -import { geminiPro, geminiProVision } from '@genkit-ai/vertexai'; +import { gemini15Flash } from '@genkit-ai/vertexai'; import { z } from 'genkit'; import { ai } from '../index.js'; import { TextMenuQuestionInputSchema } from '../types.js'; @@ -22,7 +22,7 @@ import { TextMenuQuestionInputSchema } from '../types.js'; export const s05_readMenuPrompt = ai.definePrompt( { name: 's05_readMenu', - model: geminiProVision, + model: gemini15Flash, input: { schema: z.object({ imageUrl: z.string(), @@ -42,7 +42,7 @@ from the following image of a restaurant menu. 
export const s05_textMenuPrompt = ai.definePrompt( { name: 's05_textMenu', - model: geminiPro, + model: gemini15Flash, input: { schema: TextMenuQuestionInputSchema }, output: { format: 'text' }, config: { temperature: 0.3 }, diff --git a/js/testapps/menu/src/index.ts b/js/testapps/menu/src/index.ts index 276de8628..bdf1b3b98 100644 --- a/js/testapps/menu/src/index.ts +++ b/js/testapps/menu/src/index.ts @@ -14,7 +14,7 @@ * limitations under the License. */ import { devLocalVectorstore } from '@genkit-ai/dev-local-vectorstore'; -import { textEmbeddingGecko, vertexAI } from '@genkit-ai/vertexai'; +import { textEmbedding004, vertexAI } from '@genkit-ai/vertexai'; import { genkit } from 'genkit'; // Initialize Genkit @@ -26,7 +26,7 @@ export const ai = genkit({ devLocalVectorstore([ { indexName: 'menu-items', - embedder: textEmbeddingGecko, + embedder: textEmbedding004, embedderOptions: { taskType: 'RETRIEVAL_DOCUMENT' }, }, ]), diff --git a/js/testapps/model-tester/src/index.ts b/js/testapps/model-tester/src/index.ts index f4ccb78fd..cc9fb5361 100644 --- a/js/testapps/model-tester/src/index.ts +++ b/js/testapps/model-tester/src/index.ts @@ -20,7 +20,7 @@ import * as clc from 'colorette'; import { genkit } from 'genkit'; import { testModels } from 'genkit/testing'; import { ollama } from 'genkitx-ollama'; -import { openAI } from 'genkitx-openai'; +//import { openAI } from 'genkitx-openai'; export const ai = genkit({ plugins: [ @@ -40,11 +40,11 @@ export const ai = genkit({ ], serverAddress: 'http://127.0.0.1:11434', // default local address }), - openAI(), + //openAI(), ], }); -testModels([ +testModels(ai.registry, [ 'googleai/gemini-1.5-pro-latest', 'googleai/gemini-1.5-flash-latest', 'vertexai/gemini-1.5-pro', @@ -52,8 +52,8 @@ testModels([ 'vertexai/claude-3-sonnet', 'vertexai/llama-3.1', 'ollama/gemma2', - 'openai/gpt-4o', - 'openai/gpt-4o-mini', + // 'openai/gpt-4o', + // 'openai/gpt-4o-mini', ]).then((r) => { let failed = false; for (const test of r) { diff --git a/js/testapps/prompt-file/src/index.ts b/js/testapps/prompt-file/src/index.ts index cc32d14ed..9e6a1a966 100644 --- a/js/testapps/prompt-file/src/index.ts +++ b/js/testapps/prompt-file/src/index.ts @@ -15,10 +15,10 @@ */ import { googleAI } from '@genkit-ai/googleai'; -import { defineHelper, dotprompt, genkit, z } from 'genkit'; +import { genkit, z } from 'genkit'; const ai = genkit({ - plugins: [googleAI(), dotprompt()], + plugins: [googleAI()], }); /* @@ -44,7 +44,7 @@ const RecipeSchema = ai.defineSchema( // If it fails, due to the prompt file being invalid, the process will crash, // instead of us getting a more mysterious failure later when the flow runs. -defineHelper('list', (data: any) => { +ai.defineHelper('list', (data: any) => { if (!Array.isArray(data)) { return ''; } @@ -61,9 +61,8 @@ ai.prompt('recipe').then((recipePrompt) => { outputSchema: RecipeSchema, }, async (input) => - ( - await recipePrompt.generate({ input: input }) - ).output()! + (await recipePrompt.generate({ input: input })) + .output! 
); }); @@ -76,7 +75,7 @@ ai.prompt('recipe', { variant: 'robot' }).then((recipePrompt) => { }), outputSchema: z.any(), }, - async (input) => (await recipePrompt.generate({ input: input })).output() + async (input) => (await recipePrompt.generate({ input: input })).output ); }); @@ -101,10 +100,10 @@ ai.prompt('story').then((storyPrompt) => { for await (const chunk of stream) { streamingCallback(chunk.content[0]?.text!); } - return (await response).text(); + return (await response).text; } else { const response = await storyPrompt.generate({ input: { subject } }); - return response.text(); + return response.text; } } ); diff --git a/js/testapps/rag/src/genkit.ts b/js/testapps/rag/src/genkit.ts index 50f4caf00..5e2cd4163 100644 --- a/js/testapps/rag/src/genkit.ts +++ b/js/testapps/rag/src/genkit.ts @@ -16,12 +16,11 @@ import { devLocalVectorstore } from '@genkit-ai/dev-local-vectorstore'; import { genkitEval, GenkitMetric } from '@genkit-ai/evaluator'; -import { googleAI } from '@genkit-ai/googleai'; +import { gemini15Flash, googleAI } from '@genkit-ai/googleai'; import { claude3Sonnet, - geminiPro, llama31, - textEmbeddingGecko, + textEmbedding004, vertexAI, } from '@genkit-ai/vertexai'; import { genkit } from 'genkit'; @@ -45,7 +44,7 @@ export const ai = genkit({ plugins: [ googleAI({ apiVersion: ['v1'] }), genkitEval({ - judge: geminiPro, + judge: gemini15Flash, judgeConfig: { safetySettings: [ { @@ -72,7 +71,7 @@ export const ai = genkit({ evaluators: { criteria: ['coherence'], labeledCriteria: ['correctness'], - judge: geminiPro, + judge: gemini15Flash, }, }), vertexAI({ @@ -84,17 +83,17 @@ export const ai = genkit({ pinecone([ { indexId: 'cat-facts', - embedder: textEmbeddingGecko, + embedder: textEmbedding004, }, { indexId: 'pdf-chat', - embedder: textEmbeddingGecko, + embedder: textEmbedding004, }, ]), chroma([ { collectionName: 'dogfacts_collection', - embedder: textEmbeddingGecko, + embedder: textEmbedding004, createCollectionIfMissing: true, clientParams: async () => { // Replace this with your Cloud Run Instance URL @@ -115,13 +114,13 @@ export const ai = genkit({ devLocalVectorstore([ { indexName: 'dog-facts', - embedder: textEmbeddingGecko, + embedder: textEmbedding004, }, { indexName: 'pdfQA', - embedder: textEmbeddingGecko, + embedder: textEmbedding004, }, ]), ], - model: geminiPro, + model: gemini15Flash, }); diff --git a/js/testapps/rag/src/pdf_rag.ts b/js/testapps/rag/src/pdf_rag.ts index 4c64b6c2b..018b21ecb 100644 --- a/js/testapps/rag/src/pdf_rag.ts +++ b/js/testapps/rag/src/pdf_rag.ts @@ -18,7 +18,7 @@ import { devLocalIndexerRef, devLocalRetrieverRef, } from '@genkit-ai/dev-local-vectorstore'; -import { geminiPro } from '@genkit-ai/vertexai'; +import { gemini15Flash } from '@genkit-ai/vertexai'; import fs from 'fs'; import { Document, run, z } from 'genkit'; import { chunk } from 'llm-chunk'; @@ -49,11 +49,11 @@ export const pdfQA = ai.defineFlow( .generate({ input: { question: query, - context: docs.map((d) => d.text()), + context: docs.map((d) => d.text), }, streamingCallback, }) - .then((r) => r.text()); + .then((r) => r.text); } ); @@ -117,12 +117,12 @@ export const synthesizeQuestions = ai.defineFlow( const questions: string[] = []; for (let i = 0; i < chunks.length; i++) { const qResponse = await ai.generate({ - model: geminiPro, + model: gemini15Flash, prompt: { text: `Generate one question about the text below: ${chunks[i]}`, }, }); - questions.push(qResponse.text()); + questions.push(qResponse.text); } return questions; } diff --git 
a/js/testapps/rag/src/prompt.ts b/js/testapps/rag/src/prompt.ts index 7e094cd70..d494de7f7 100644 --- a/js/testapps/rag/src/prompt.ts +++ b/js/testapps/rag/src/prompt.ts @@ -24,10 +24,12 @@ export const augmentedPrompt = ai.definePrompt( { model: gemini15Flash, name: 'augmentedPrompt', - input: z.object({ - context: z.array(z.string()), - question: z.string(), - }), + input: { + schema: z.object({ + context: z.array(z.string()), + question: z.string(), + }), + }, output: { format: 'text', }, diff --git a/js/testapps/rag/src/simple_rag.ts b/js/testapps/rag/src/simple_rag.ts index 74256d69f..7c62cc72c 100644 --- a/js/testapps/rag/src/simple_rag.ts +++ b/js/testapps/rag/src/simple_rag.ts @@ -66,10 +66,10 @@ export const askQuestionsAboutCatsFlow = ai.defineFlow( .generate({ input: { question: query, - context: docs.map((d) => d.text()), + context: docs.map((d) => d.text), }, }) - .then((r) => r.text()); + .then((r) => r.text); } ); @@ -91,10 +91,10 @@ export const askQuestionsAboutDogsFlow = ai.defineFlow( .generate({ input: { question: query, - context: docs.map((d) => d.text()), + context: docs.map((d) => d.text), }, }) - .then((r) => r.text()); + .then((r) => r.text); } ); diff --git a/js/testapps/vertexai-reranker/README.md b/js/testapps/vertexai-reranker/README.md index 2d38a3505..4b7dfeb8d 100644 --- a/js/testapps/vertexai-reranker/README.md +++ b/js/testapps/vertexai-reranker/README.md @@ -82,7 +82,7 @@ const reranker = 'vertexai/reranker'; }); return rerankedDocuments.map((doc) => ({ - text: doc.text(), + text: doc.text, score: doc.metadata.score, })); diff --git a/js/testapps/vertexai-reranker/src/index.ts b/js/testapps/vertexai-reranker/src/index.ts index 5f4768e09..734d7064d 100644 --- a/js/testapps/vertexai-reranker/src/index.ts +++ b/js/testapps/vertexai-reranker/src/index.ts @@ -81,7 +81,7 @@ export const rerankFlow = ai.defineFlow( }); return rerankedDocuments.map((doc) => ({ - text: doc.text(), + text: doc.text, score: doc.metadata.score, })); } diff --git a/js/testapps/vertexai-vector-search-firestore/package.json b/js/testapps/vertexai-vector-search-firestore/package.json index aca1337d0..cf0a1f390 100644 --- a/js/testapps/vertexai-vector-search-firestore/package.json +++ b/js/testapps/vertexai-vector-search-firestore/package.json @@ -22,7 +22,7 @@ "@genkit-ai/vertexai": "workspace:*", "dotenv": "^16.4.5", "express": "^4.21.0", - "firebase-admin": "^12.1.0", + "firebase-admin": ">=12.2", "genkitx-chromadb": "workspace:*", "genkitx-langchain": "workspace:*", "genkitx-pinecone": "workspace:*", diff --git a/package.json b/package.json index 8812a0515..f14d1f6ce 100644 --- a/package.json +++ b/package.json @@ -40,5 +40,5 @@ "ts-node": "^10.9.2", "tsx": "^4.7.1" }, - "packageManager": "pnpm@9.12.0+sha256.a61b67ff6cc97af864564f4442556c22a04f2e5a7714fbee76a1011361d9b726" + "packageManager": "pnpm@9.12.2+sha256.2ef6e547b0b07d841d605240dce4d635677831148cd30f6d564b8f4f928f73d2" } diff --git a/samples/chatbot/server/src/index.ts b/samples/chatbot/server/src/index.ts index bf600afc7..83082de33 100644 --- a/samples/chatbot/server/src/index.ts +++ b/samples/chatbot/server/src/index.ts @@ -88,12 +88,12 @@ export const chatbotFlow = ai.defineStreamingFlow( 'save-history', { conversationId: request.conversationId, - history: mainResp.toHistory(), + history: mainResp.messages, }, async () => { - await historyStore?.save(request.conversationId, mainResp.toHistory()); + await historyStore?.save(request.conversationId, mainResp.messages); } ); - return mainResp.text(); + return mainResp.text; } 
 );
diff --git a/samples/js-angular/server/src/agent.ts b/samples/js-angular/server/src/agent.ts
index 4fa9b964f..156f8ba0e 100644
--- a/samples/js-angular/server/src/agent.ts
+++ b/samples/js-angular/server/src/agent.ts
@@ -96,10 +96,10 @@ export function defineAgent(
       });
       await run(
         'save-history',
-        { conversationId: request.conversationId, history: resp.toHistory() },
+        { conversationId: request.conversationId, history: resp.messages },
         async () => {
           request.conversationId
-            ? await historyStore?.save(request.conversationId, resp.toHistory())
+            ? await historyStore?.save(request.conversationId, resp.messages)
             : undefined;
         }
       );
diff --git a/samples/js-angular/server/src/jsonStreaming.ts b/samples/js-angular/server/src/jsonStreaming.ts
index dc1fb8ead..66f26beb4 100644
--- a/samples/js-angular/server/src/jsonStreaming.ts
+++ b/samples/js-angular/server/src/jsonStreaming.ts
@@ -65,7 +65,7 @@ export const streamCharacters = ai.defineFlow(
       }
     }
 
-    return (await response()).text();
+    return (await response()).text;
   }
 );
 
diff --git a/samples/js-coffee-shop/src/index.ts b/samples/js-coffee-shop/src/index.ts
index ffb80702b..e083b14a0 100644
--- a/samples/js-coffee-shop/src/index.ts
+++ b/samples/js-coffee-shop/src/index.ts
@@ -64,8 +64,7 @@ export const simpleGreetingFlow = defineFlow(
     inputSchema: CustomerNameSchema,
     outputSchema: z.string(),
   },
-  async (input) =>
-    (await simpleGreetingPrompt.generate({ input: input })).text()
+  async (input) => (await simpleGreetingPrompt.generate({ input: input })).text
 );
 
 // Another flow to recommend a drink based on the time of day and a previous order.
@@ -109,7 +108,7 @@ export const greetingWithHistoryFlow = defineFlow(
     outputSchema: z.string(),
   },
   async (input) =>
-    (await greetingWithHistoryPrompt.generate({ input: input })).text()
+    (await greetingWithHistoryPrompt.generate({ input: input })).text
 );
 
 // A flow to quickly test all the above flows
diff --git a/samples/js-menu/src/02/flows.ts b/samples/js-menu/src/02/flows.ts
index ef2bc5761..a54f8fbd3 100644
--- a/samples/js-menu/src/02/flows.ts
+++ b/samples/js-menu/src/02/flows.ts
@@ -32,7 +32,7 @@ export const s02_menuQuestionFlow = defineFlow(
         input: { question: input.question },
       })
       .then((response) => {
-        return { answer: response.text() };
+        return { answer: response.text };
       });
   }
 );
diff --git a/samples/js-menu/src/03/flows.ts b/samples/js-menu/src/03/flows.ts
index c9addc77d..17c42f56b 100644
--- a/samples/js-menu/src/03/flows.ts
+++ b/samples/js-menu/src/03/flows.ts
@@ -66,7 +66,7 @@ export const s03_multiTurnChatFlow = defineFlow(
     });
 
     // Add the exchange to the history store and return it
-    history = llmResponse.toHistory();
+    history = llmResponse.messages;
     chatHistoryStore.write(input.sessionId, history);
     return {
       sessionId: input.sessionId,
diff --git a/samples/js-menu/src/04/flows.ts b/samples/js-menu/src/04/flows.ts
index fbef53a0c..387b76acd 100644
--- a/samples/js-menu/src/04/flows.ts
+++ b/samples/js-menu/src/04/flows.ts
@@ -82,6 +82,6 @@ export const s04_ragMenuQuestionFlow = defineFlow(
         question: input.question,
       },
     });
-    return { answer: response.text() };
+    return { answer: response.text };
   }
 );
diff --git a/samples/js-menu/src/05/flows.ts b/samples/js-menu/src/05/flows.ts
index 1def8ec34..be8c7cb45 100644
--- a/samples/js-menu/src/05/flows.ts
+++ b/samples/js-menu/src/05/flows.ts
@@ -44,7 +44,7 @@ export const s05_readMenuFlow = defineFlow(
         imageUrl: imageDataUrl,
       },
     });
-    return { menuText: response.text() };
+    return { menuText: response.text };
   }
 );
 
@@ -64,7 +64,7 @@ export const s05_textMenuQuestionFlow = defineFlow(
         question: input.question,
       },
     });
-    return { answer: response.text() };
+    return { answer: response.text };
   }
 );
 
diff --git a/samples/prompts/src/index.ts b/samples/prompts/src/index.ts
index eb37918f1..63ff53e50 100644
--- a/samples/prompts/src/index.ts
+++ b/samples/prompts/src/index.ts
@@ -107,7 +107,7 @@ defineFlow(
     const response = await threeGreetingsPrompt.generate({
       input: { name: 'Fred' },
     });
-    return response.output()?.likeAPirate;
+    return response.output?.likeAPirate;
   }
 );
 
diff --git a/scripts/release_main.sh b/scripts/release_main.sh
new file mode 100755
index 000000000..b28e12eb2
--- /dev/null
+++ b/scripts/release_main.sh
@@ -0,0 +1,78 @@
+#!/bin/bash
+
+# git clone git@github.com:firebase/genkit.git
+# cd genkit
+# pnpm i
+# pnpm build
+# pnpm test:all
+# Run from root: scripts/release_main.sh
+
+pnpm login --registry https://wombat-dressing-room.appspot.com
+
+
+CURRENT=`pwd`
+
+cd genkit-tools/common
+pnpm publish --registry https://wombat-dressing-room.appspot.com
+cd $CURRENT
+
+cd genkit-tools/cli
+pnpm publish --registry https://wombat-dressing-room.appspot.com
+cd $CURRENT
+
+cd js/core
+pnpm publish --registry https://wombat-dressing-room.appspot.com
+cd $CURRENT
+
+cd js/ai
+pnpm publish --registry https://wombat-dressing-room.appspot.com
+cd $CURRENT
+
+cd js/flow
+pnpm publish --registry https://wombat-dressing-room.appspot.com
+cd $CURRENT
+
+cd js/plugins/dotprompt
+pnpm publish --registry https://wombat-dressing-room.appspot.com
+cd $CURRENT
+
+cd js/plugins/chroma
+pnpm publish --registry https://wombat-dressing-room.appspot.com
+cd $CURRENT
+
+cd js/plugins/dev-local-vectorstore
+pnpm publish --registry https://wombat-dressing-room.appspot.com
+cd $CURRENT
+
+cd js/plugins/firebase
+pnpm publish --registry https://wombat-dressing-room.appspot.com
+cd $CURRENT
+
+cd js/plugins/google-cloud
+pnpm publish --registry https://wombat-dressing-room.appspot.com
+cd $CURRENT
+
+cd js/plugins/googleai
+pnpm publish --registry https://wombat-dressing-room.appspot.com
+cd $CURRENT
+
+cd js/plugins/ollama
+pnpm publish --registry https://wombat-dressing-room.appspot.com
+cd $CURRENT
+
+cd js/plugins/pinecone
+pnpm publish --registry https://wombat-dressing-room.appspot.com
+cd $CURRENT
+
+cd js/plugins/vertexai
+pnpm publish --registry https://wombat-dressing-room.appspot.com
+cd $CURRENT
+
+cd js/plugins/evaluators
+pnpm publish --registry https://wombat-dressing-room.appspot.com
+cd $CURRENT
+
+cd js/plugins/langchain
+pnpm publish --registry https://wombat-dressing-room.appspot.com
+cd $CURRENT
+
diff --git a/scripts/release_next.sh b/scripts/release_next.sh
new file mode 100755
index 000000000..3e62e982b
--- /dev/null
+++ b/scripts/release_next.sh
@@ -0,0 +1,83 @@
+#!/bin/bash
+
+# git clone git@github.com:firebase/genkit.git
+# cd genkit
+# git checkout next
+# pnpm i
+# pnpm build
+# pnpm test:all
+
+# Run from root: scripts/release_next.sh
+
+pnpm login --registry https://wombat-dressing-room.appspot.com
+
+CURRENT=`pwd`
+
+cd genkit-tools/cli
+pnpm publish --tag next --publish-branch next --registry https://wombat-dressing-room.appspot.com
+cd $CURRENT
+
+cd genkit-tools/common
+pnpm publish --tag next --publish-branch next --registry https://wombat-dressing-room.appspot.com
+cd $CURRENT
+
+cd genkit-tools/telemetry-server
+pnpm publish --tag next --publish-branch next --registry https://wombat-dressing-room.appspot.com
+cd $CURRENT
+
+cd js/core
+pnpm publish --tag next --publish-branch next --registry https://wombat-dressing-room.appspot.com
+cd $CURRENT
+
+cd js/ai
+pnpm publish --tag next --publish-branch next --registry https://wombat-dressing-room.appspot.com
+cd $CURRENT
+
+cd js/genkit
+pnpm publish --tag next --publish-branch next --registry https://wombat-dressing-room.appspot.com
+cd $CURRENT
+
+cd js/plugins/dotprompt
+pnpm publish --tag next --publish-branch next --registry https://wombat-dressing-room.appspot.com
+cd $CURRENT
+
+cd js/plugins/chroma
+pnpm publish --tag next --publish-branch next --registry https://wombat-dressing-room.appspot.com
+cd $CURRENT
+
+cd js/plugins/dev-local-vectorstore
+pnpm publish --tag next --publish-branch next --registry https://wombat-dressing-room.appspot.com
+cd $CURRENT
+
+cd js/plugins/firebase
+pnpm publish --tag next --publish-branch next --registry https://wombat-dressing-room.appspot.com
+cd $CURRENT
+
+cd js/plugins/google-cloud
+pnpm publish --tag next --publish-branch next --registry https://wombat-dressing-room.appspot.com
+cd $CURRENT
+
+cd js/plugins/googleai
+pnpm publish --tag next --publish-branch next --registry https://wombat-dressing-room.appspot.com
+cd $CURRENT
+
+cd js/plugins/ollama
+pnpm publish --tag next --publish-branch next --registry https://wombat-dressing-room.appspot.com
+cd $CURRENT
+
+cd js/plugins/pinecone
+pnpm publish --tag next --publish-branch next --registry https://wombat-dressing-room.appspot.com
+cd $CURRENT
+
+cd js/plugins/vertexai
+pnpm publish --tag next --publish-branch next --registry https://wombat-dressing-room.appspot.com
+cd $CURRENT
+
+cd js/plugins/evaluators
+pnpm publish --tag next --publish-branch next --registry https://wombat-dressing-room.appspot.com
+cd $CURRENT
+
+cd js/plugins/langchain
+pnpm publish --tag next --publish-branch next --registry https://wombat-dressing-room.appspot.com
+cd $CURRENT
+
diff --git a/tests/test_js_app/src/index.ts b/tests/test_js_app/src/index.ts
index 292d4ed4f..762316a3f 100644
--- a/tests/test_js_app/src/index.ts
+++ b/tests/test_js_app/src/index.ts
@@ -63,8 +63,8 @@ export const testFlow = ai.defineFlow(
     });
 
     const want = `{"messages":[{"content":[{"text":"${subject}"}],"role":"user"}],"tools":[],"output":{"format":"text"}}`;
-    if (response.text() !== want) {
-      throw new Error(`Expected ${want} but got ${response.text()}`);
+    if (response.text !== want) {
+      throw new Error(`Expected ${want} but got ${response.text}`);
     }
 
     return 'Test flow passed';