From b3660071fd70696bcea21f6a10a9a53cfcb24b8d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20Pokorn=C3=BD?= Date: Fri, 18 Oct 2024 15:32:08 +0200 Subject: [PATCH] fix(code-interpreter): prevent agent from misusing IDs (#98) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Jan Pokorný --- src/tools/python/python.ts | 65 ++++++++------------------------------ 1 file changed, 13 insertions(+), 52 deletions(-) diff --git a/src/tools/python/python.ts b/src/tools/python/python.ts index 4c2c8131..9ffed22f 100644 --- a/src/tools/python/python.ts +++ b/src/tools/python/python.ts @@ -14,13 +14,7 @@ * limitations under the License. */ -import { - BaseToolOptions, - BaseToolRunOptions, - Tool, - ToolInput, - ToolInputValidationError, -} from "@/tools/base.js"; +import { BaseToolOptions, BaseToolRunOptions, Tool, ToolInput } from "@/tools/base.js"; import { createGrpcTransport } from "@connectrpc/connect-node"; import { PromiseClient, createPromiseClient } from "@connectrpc/connect"; import { CodeInterpreterService } from "bee-proto/code_interpreter/v1/code_interpreter_service_connect"; @@ -28,12 +22,11 @@ import { z } from "zod"; import { BaseLLMOutput } from "@/llms/base.js"; import { LLM } from "@/llms/llm.js"; import { PromptTemplate } from "@/template.js"; -import { differenceWith, isShallowEqual, isTruthy, map, unique } from "remeda"; +import { filter, isIncludedIn, isTruthy, map, pipe, unique, uniqueBy } from "remeda"; import { PythonFile, PythonStorage } from "@/tools/python/storage.js"; import { PythonToolOutput } from "@/tools/python/output.js"; import { ValidationError } from "ajv"; import { ConnectionOptions } from "node:tls"; -import { AnySchemaLike } from "@/internals/helpers/schema.js"; import { RunContext } from "@/context.js"; import { hasMinLength } from "@/internals/helpers/array.js"; @@ -74,34 +67,18 @@ export class PythonTool extends Tool { ].join(" "); public readonly storage: PythonStorage; + protected files: PythonFile[] = []; async inputSchema() { - const files = await this.storage.list(); - const fileIds = unique(map(files, ({ id }) => id)); - - const zodFileId = hasMinLength(files, 2) - ? z.union(map(files, (file) => z.literal(file.id).describe(file.filename))) - : hasMinLength(files, 1) - ? z.literal(files[0].id).describe(files[0].filename) - : z.undefined(); - + this.files = await this.storage.list(); + const fileNames = unique(map(this.files, ({ filename }) => filename)); return z.object({ language: z.enum(["python", "shell"]).describe("Use shell for ffmpeg, pandoc, yt-dlp"), code: z.string().describe("full source code file that will be executed"), - ...(hasMinLength(fileIds, 1) + ...(hasMinLength(fileNames, 1) ? { inputFiles: z - .array( - z.object({ - id: zodFileId, - filename: z - .string() - .min(1) - .describe( - "name under which the file will be available to the Python code in the working directory", - ), - }), - ) + .array(z.enum(fileNames)) .describe( "To access an existing file, you must specify it; otherwise, the file will not be accessible. IMPORTANT: If the file is not provided in the input, it will not be accessible.", ), @@ -110,27 +87,6 @@ export class PythonTool extends Tool { }); } - protected validateInput( - schema: AnySchemaLike, - rawInput: unknown, - ): asserts rawInput is ToolInput { - super.validateInput(schema, rawInput); - - const fileNames: string[] = - (rawInput.inputFiles as { filename: string }[]) - ?.map(({ filename }) => filename) - .filter(Boolean) ?? []; - const diff = differenceWith(fileNames, unique(fileNames), isShallowEqual); - if (diff.length > 0) { - throw new ToolInputValidationError( - [ - `All 'inputFiles' must have a unique filenames.`, - `Duplicated filenames: ${diff.join(",")}`, - ].join("\n"), - ); - } - } - protected readonly client: PromiseClient; protected readonly preprocess; @@ -170,7 +126,12 @@ export class PythonTool extends Tool { _options: BaseToolRunOptions | undefined, run: RunContext, ) { - const inputFiles = await this.storage.upload((input.inputFiles as PythonFile[]) ?? []); + const inputFiles = await pipe( + this.files ?? (await this.storage.list()), + uniqueBy((f) => f.filename), + filter((file) => isIncludedIn(file.filename, (input.inputFiles ?? []) as string[])), + (files) => this.storage.upload(files), + ); // replace relative paths in "files" with absolute paths by prepending "/workspace" const getSourceCode = async () => {