From 3556474116db2fc7dbfbb34bfc351490360f4d85 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20Cruz?= Date: Thu, 21 Sep 2023 16:50:03 +0100 Subject: [PATCH] wrangler: Add AI binding (#3992) * wrangler: Add AI binding Added binding for the AI project. * Workers AI: added example --- .changeset/brave-stingrays-shout.md | 36 +++++++++++ .../src/__tests__/configuration.test.ts | 59 +++++++++++++++++++ .../wrangler/src/__tests__/deploy.test.ts | 37 ++++++++++++ .../pages/create-worker-bundle-contents.ts | 1 + packages/wrangler/src/config/environment.ts | 9 +++ packages/wrangler/src/config/index.ts | 8 +++ .../wrangler/src/config/validation-helpers.ts | 6 ++ packages/wrangler/src/config/validation.ts | 37 ++++++++++++ packages/wrangler/src/deploy/deploy.ts | 1 + .../create-worker-upload-form.ts | 8 +++ .../wrangler/src/deployment-bundle/worker.ts | 9 +++ packages/wrangler/src/dev.tsx | 1 + packages/wrangler/src/init.ts | 7 +++ packages/wrangler/src/secret/index.ts | 1 + 14 files changed, 220 insertions(+) create mode 100644 .changeset/brave-stingrays-shout.md diff --git a/.changeset/brave-stingrays-shout.md b/.changeset/brave-stingrays-shout.md new file mode 100644 index 000000000000..56864f118e44 --- /dev/null +++ b/.changeset/brave-stingrays-shout.md @@ -0,0 +1,36 @@ +--- +"wrangler": patch +--- + +Add AI binding that will be used to interact with the AI project. + +Example `wrangler.toml` + + name = "ai-worker" + main = "src/index.ts" + + [ai] + binding = "AI" + +Example script: + + import Ai from "@cloudflare/ai" + + export default { + async fetch(request: Request, env: Env): Promise { + const ai = new Ai(env.AI); + + const story = await ai.run({ + model: 'llama-2', + input: { + prompt: 'Tell me a story about the future of the Cloudflare dev platform' + } + }); + + return new Response(JSON.stringify(story)); + }, + }; + + export interface Env { + AI: any; + } diff --git a/packages/wrangler/src/__tests__/configuration.test.ts b/packages/wrangler/src/__tests__/configuration.test.ts index 0584e0b34ac9..8b9c29115d66 100644 --- a/packages/wrangler/src/__tests__/configuration.test.ts +++ b/packages/wrangler/src/__tests__/configuration.test.ts @@ -64,6 +64,7 @@ describe("normalizeAndValidateConfig()", () => { site: undefined, text_blobs: undefined, browser: undefined, + ai: undefined, triggers: { crons: [], }, @@ -1620,6 +1621,64 @@ describe("normalizeAndValidateConfig()", () => { }); }); + describe("[ai]", () => { + it("should error if ai is an array", () => { + const { diagnostics } = normalizeAndValidateConfig( + { ai: [] } as unknown as RawConfig, + undefined, + { env: undefined } + ); + + expect(diagnostics.hasWarnings()).toBe(false); + expect(diagnostics.renderErrors()).toMatchInlineSnapshot(` + "Processing wrangler configuration: + - The field \\"ai\\" should be an object but got []." + `); + }); + + it("should error if ai is a string", () => { + const { diagnostics } = normalizeAndValidateConfig( + { ai: "BAD" } as unknown as RawConfig, + undefined, + { env: undefined } + ); + + expect(diagnostics.hasWarnings()).toBe(false); + expect(diagnostics.renderErrors()).toMatchInlineSnapshot(` + "Processing wrangler configuration: + - The field \\"ai\\" should be an object but got \\"BAD\\"." + `); + }); + + it("should error if ai is a number", () => { + const { diagnostics } = normalizeAndValidateConfig( + { ai: 999 } as unknown as RawConfig, + undefined, + { env: undefined } + ); + + expect(diagnostics.hasWarnings()).toBe(false); + expect(diagnostics.renderErrors()).toMatchInlineSnapshot(` + "Processing wrangler configuration: + - The field \\"ai\\" should be an object but got 999." + `); + }); + + it("should error if ai is null", () => { + const { diagnostics } = normalizeAndValidateConfig( + { ai: null } as unknown as RawConfig, + undefined, + { env: undefined } + ); + + expect(diagnostics.hasWarnings()).toBe(false); + expect(diagnostics.renderErrors()).toMatchInlineSnapshot(` + "Processing wrangler configuration: + - The field \\"ai\\" should be an object but got null." + `); + }); + }); + describe("[kv_namespaces]", () => { it("should error if kv_namespaces is an object", () => { const { diagnostics } = normalizeAndValidateConfig( diff --git a/packages/wrangler/src/__tests__/deploy.test.ts b/packages/wrangler/src/__tests__/deploy.test.ts index ad8dd9cb6b45..23710bc55f43 100644 --- a/packages/wrangler/src/__tests__/deploy.test.ts +++ b/packages/wrangler/src/__tests__/deploy.test.ts @@ -8367,6 +8367,43 @@ export default{ }); }); + describe("ai", () => { + it("should upload ai bindings", async () => { + writeWranglerToml({ + ai: { binding: "AI_BIND" }, + browser: { binding: "MYBROWSER" }, + }); + await fs.promises.writeFile("index.js", `export default {};`); + mockSubDomainRequest(); + mockUploadWorkerRequest({ + expectedBindings: [ + { + type: "browser", + name: "MYBROWSER", + }, + { + type: "ai", + name: "AI_BIND", + }, + ], + }); + + await runWrangler("deploy index.js"); + expect(std.out).toMatchInlineSnapshot(` + "Total Upload: xx KiB / gzip: xx KiB + Your worker has access to the following bindings: + - Browser: + - Name: MYBROWSER + - AI: + - Name: AI_BIND + Uploaded test-name (TIMINGS) + Published test-name (TIMINGS) + https://test-name.test-sub-domain.workers.dev + Current Deployment ID: Galaxy-Class" + `); + }); + }); + describe("mtls_certificates", () => { it("should upload mtls_certificate bindings", async () => { writeWranglerToml({ diff --git a/packages/wrangler/src/api/pages/create-worker-bundle-contents.ts b/packages/wrangler/src/api/pages/create-worker-bundle-contents.ts index abfd8c076bff..e343aa2ef6fa 100644 --- a/packages/wrangler/src/api/pages/create-worker-bundle-contents.ts +++ b/packages/wrangler/src/api/pages/create-worker-bundle-contents.ts @@ -54,6 +54,7 @@ function createWorkerBundleFormData(workerBundle: BundleResult): FormData { wasm_modules: undefined, text_blobs: undefined, browser: undefined, + ai: undefined, data_blobs: undefined, durable_objects: undefined, queues: undefined, diff --git a/packages/wrangler/src/config/environment.ts b/packages/wrangler/src/config/environment.ts index 6e37bd594bb6..4fbc007d8602 100644 --- a/packages/wrangler/src/config/environment.ts +++ b/packages/wrangler/src/config/environment.ts @@ -519,6 +519,15 @@ interface EnvironmentNonInheritable { } | undefined; + /** + * Binding to the AI project. + */ + ai: + | { + binding: string; + } + | undefined; + /** * "Unsafe" tables for features that aren't directly supported by wrangler. * diff --git a/packages/wrangler/src/config/index.ts b/packages/wrangler/src/config/index.ts index c6e54ace34c1..df1af96ed3bc 100644 --- a/packages/wrangler/src/config/index.ts +++ b/packages/wrangler/src/config/index.ts @@ -106,6 +106,7 @@ export function printBindings(bindings: CfWorkerInit["bindings"]) { analytics_engine_datasets, text_blobs, browser, + ai, unsafe, vars, wasm_modules, @@ -296,6 +297,13 @@ export function printBindings(bindings: CfWorkerInit["bindings"]) { }); } + if (ai !== undefined) { + output.push({ + type: "AI", + entries: [{ key: "Name", value: ai.binding }], + }); + } + if (unsafe?.bindings !== undefined && unsafe.bindings.length > 0) { output.push({ type: "Unsafe", diff --git a/packages/wrangler/src/config/validation-helpers.ts b/packages/wrangler/src/config/validation-helpers.ts index 265079c89a2b..a2b72c7be0ca 100644 --- a/packages/wrangler/src/config/validation-helpers.ts +++ b/packages/wrangler/src/config/validation-helpers.ts @@ -571,6 +571,12 @@ export const getBindingNames = (value: unknown): string[] => { } else if (isNamespaceList(value)) { return value.map(({ binding }) => binding); } else if (isRecord(value)) { + // browser and AI bindings are single values with a similar shape + // { binding = "name" } + if (value["binding"] !== undefined) { + return [value["binding"] as string]; + } + return Object.keys(value).filter((k) => value[k] !== undefined); } else { return []; diff --git a/packages/wrangler/src/config/validation.ts b/packages/wrangler/src/config/validation.ts index 4183bc48a4c1..7f5b1bbcba2b 100644 --- a/packages/wrangler/src/config/validation.ts +++ b/packages/wrangler/src/config/validation.ts @@ -1314,6 +1314,16 @@ function normalizeAndValidateEnvironment( validateBrowserBinding(envName), undefined ), + ai: notInheritable( + diagnostics, + topLevelEnv, + rawConfig, + rawEnv, + envName, + "ai", + validateAIBinding(envName), + undefined + ), zone_id: rawEnv.zone_id, logfwdr: inheritable( diagnostics, @@ -1893,6 +1903,30 @@ const validateBrowserBinding = return isValid; }; +const validateAIBinding = + (envName: string): ValidatorFn => + (diagnostics, field, value, config) => { + const fieldPath = + config === undefined ? `${field}` : `env.${envName}.${field}`; + + if (typeof value !== "object" || value === null || Array.isArray(value)) { + diagnostics.errors.push( + `The field "${fieldPath}" should be an object but got ${JSON.stringify( + value + )}.` + ); + return false; + } + + let isValid = true; + if (!isRequiredProperty(value, "binding", "string")) { + diagnostics.errors.push(`binding should have a string "binding" field.`); + isValid = false; + } + + return isValid; + }; + /** * Check that the given field is a valid "unsafe" binding object. * @@ -1920,6 +1954,7 @@ const validateUnsafeBinding: ValidatorFn = (diagnostics, field, value) => { "data_blob", "text_blob", "browser", + "ai", "kv_namespace", "durable_object_namespace", "d1_database", @@ -2278,6 +2313,7 @@ const validateBindingsHaveUniqueNames = ( analytics_engine_datasets, text_blobs, browser, + ai, unsafe, vars, define, @@ -2294,6 +2330,7 @@ const validateBindingsHaveUniqueNames = ( "Analytics Engine Dataset": getBindingNames(analytics_engine_datasets), "Text Blob": getBindingNames(text_blobs), Browser: getBindingNames(browser), + AI: getBindingNames(ai), Unsafe: getBindingNames(unsafe), "Environment Variable": getBindingNames(vars), Definition: getBindingNames(define), diff --git a/packages/wrangler/src/deploy/deploy.ts b/packages/wrangler/src/deploy/deploy.ts index 8b8ab271ebd7..3e3518c3a1e7 100644 --- a/packages/wrangler/src/deploy/deploy.ts +++ b/packages/wrangler/src/deploy/deploy.ts @@ -529,6 +529,7 @@ See https://developers.cloudflare.com/workers/platform/compatibility-dates for m vars: { ...config.vars, ...props.vars }, wasm_modules: config.wasm_modules, browser: config.browser, + ai: config.ai, text_blobs: { ...config.text_blobs, ...(assets.manifest && diff --git a/packages/wrangler/src/deployment-bundle/create-worker-upload-form.ts b/packages/wrangler/src/deployment-bundle/create-worker-upload-form.ts index 48ff02e2ad78..1034eee8e9df 100644 --- a/packages/wrangler/src/deployment-bundle/create-worker-upload-form.ts +++ b/packages/wrangler/src/deployment-bundle/create-worker-upload-form.ts @@ -35,6 +35,7 @@ export type WorkerMetadataBinding = | { type: "wasm_module"; name: string; part: string } | { type: "text_blob"; name: string; part: string } | { type: "browser"; name: string } + | { type: "ai"; name: string } | { type: "data_blob"; name: string; part: string } | { type: "kv_namespace"; name: string; namespace_id: string } | { @@ -268,6 +269,13 @@ export function createWorkerUploadForm(worker: CfWorkerInit): FormData { }); } + if (bindings.ai !== undefined) { + metadataBindings.push({ + name: bindings.ai.binding, + type: "ai", + }); + } + for (const [name, filePath] of Object.entries(bindings.text_blobs || {})) { metadataBindings.push({ name, diff --git a/packages/wrangler/src/deployment-bundle/worker.ts b/packages/wrangler/src/deployment-bundle/worker.ts index febcf82ce06c..bef3ca376166 100644 --- a/packages/wrangler/src/deployment-bundle/worker.ts +++ b/packages/wrangler/src/deployment-bundle/worker.ts @@ -103,6 +103,14 @@ export interface CfBrowserBinding { binding: string; } +/** + * A binding to the AI project + */ + +export interface CfAIBinding { + binding: string; +} + /** * A binding to a data blob (in service-worker format) */ @@ -256,6 +264,7 @@ export interface CfWorkerInit { wasm_modules: CfWasmModuleBindings | undefined; text_blobs: CfTextBlobBindings | undefined; browser: CfBrowserBinding | undefined; + ai: CfAIBinding | undefined; data_blobs: CfDataBlobBindings | undefined; durable_objects: { bindings: CfDurableObject[] } | undefined; queues: CfQueue[] | undefined; diff --git a/packages/wrangler/src/dev.tsx b/packages/wrangler/src/dev.tsx index b15faf28a951..24c14660a172 100644 --- a/packages/wrangler/src/dev.tsx +++ b/packages/wrangler/src/dev.tsx @@ -876,6 +876,7 @@ function getBindings( wasm_modules: configParam.wasm_modules, text_blobs: configParam.text_blobs, browser: configParam.browser, + ai: configParam.ai, data_blobs: configParam.data_blobs, durable_objects: { bindings: [ diff --git a/packages/wrangler/src/init.ts b/packages/wrangler/src/init.ts index ef06cb2d2f6f..c7d17306f16b 100644 --- a/packages/wrangler/src/init.ts +++ b/packages/wrangler/src/init.ts @@ -1046,6 +1046,13 @@ export function mapBindings(bindings: WorkerMetadataBinding[]): RawConfig { }; } break; + case "ai": + { + configObj.ai = { + binding: binding.name, + }; + } + break; case "r2_bucket": { configObj.r2_buckets = [ diff --git a/packages/wrangler/src/secret/index.ts b/packages/wrangler/src/secret/index.ts index e5458636fd24..84ca30d59f2d 100644 --- a/packages/wrangler/src/secret/index.ts +++ b/packages/wrangler/src/secret/index.ts @@ -113,6 +113,7 @@ export const secret = (secretYargs: CommonYargsArgv) => { analytics_engine_datasets: [], wasm_modules: {}, browser: undefined, + ai: undefined, text_blobs: {}, data_blobs: {}, dispatch_namespaces: [],