Skip to content

Commit

Permalink
fix(code-interpreter): prevent agent from misusing IDs (#98)
Browse files Browse the repository at this point in the history
Signed-off-by: Jan Pokorný <[email protected]>
  • Loading branch information
JanPokorny authored Oct 18, 2024
1 parent 241d427 commit b366007
Showing 1 changed file with 13 additions and 52 deletions.
65 changes: 13 additions & 52 deletions src/tools/python/python.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,26 +14,19 @@
* limitations under the License.
*/

import {
BaseToolOptions,
BaseToolRunOptions,
Tool,
ToolInput,
ToolInputValidationError,
} from "@/tools/base.js";
import { BaseToolOptions, BaseToolRunOptions, Tool, ToolInput } from "@/tools/base.js";
import { createGrpcTransport } from "@connectrpc/connect-node";
import { PromiseClient, createPromiseClient } from "@connectrpc/connect";
import { CodeInterpreterService } from "bee-proto/code_interpreter/v1/code_interpreter_service_connect";
import { z } from "zod";
import { BaseLLMOutput } from "@/llms/base.js";
import { LLM } from "@/llms/llm.js";
import { PromptTemplate } from "@/template.js";
import { differenceWith, isShallowEqual, isTruthy, map, unique } from "remeda";
import { filter, isIncludedIn, isTruthy, map, pipe, unique, uniqueBy } from "remeda";
import { PythonFile, PythonStorage } from "@/tools/python/storage.js";
import { PythonToolOutput } from "@/tools/python/output.js";
import { ValidationError } from "ajv";
import { ConnectionOptions } from "node:tls";
import { AnySchemaLike } from "@/internals/helpers/schema.js";
import { RunContext } from "@/context.js";
import { hasMinLength } from "@/internals/helpers/array.js";

Expand Down Expand Up @@ -74,34 +67,18 @@ export class PythonTool extends Tool<PythonToolOutput, PythonToolOptions> {
].join(" ");

public readonly storage: PythonStorage;
protected files: PythonFile[] = [];

async inputSchema() {
const files = await this.storage.list();
const fileIds = unique(map(files, ({ id }) => id));

const zodFileId = hasMinLength(files, 2)
? z.union(map(files, (file) => z.literal(file.id).describe(file.filename)))
: hasMinLength(files, 1)
? z.literal(files[0].id).describe(files[0].filename)
: z.undefined();

this.files = await this.storage.list();
const fileNames = unique(map(this.files, ({ filename }) => filename));
return z.object({
language: z.enum(["python", "shell"]).describe("Use shell for ffmpeg, pandoc, yt-dlp"),
code: z.string().describe("full source code file that will be executed"),
...(hasMinLength(fileIds, 1)
...(hasMinLength(fileNames, 1)
? {
inputFiles: z
.array(
z.object({
id: zodFileId,
filename: z
.string()
.min(1)
.describe(
"name under which the file will be available to the Python code in the working directory",
),
}),
)
.array(z.enum(fileNames))
.describe(
"To access an existing file, you must specify it; otherwise, the file will not be accessible. IMPORTANT: If the file is not provided in the input, it will not be accessible.",
),
Expand All @@ -110,27 +87,6 @@ export class PythonTool extends Tool<PythonToolOutput, PythonToolOptions> {
});
}

/**
 * Validates the raw tool input against the schema (via the base class) and
 * additionally rejects inputs whose `inputFiles` entries share a filename.
 *
 * @param schema - schema forwarded unchanged to the base-class validation
 * @param rawInput - unvalidated input; narrowed to `ToolInput<this>` on success
 * @throws ToolInputValidationError when two or more `inputFiles` entries use the same filename
 */
protected validateInput(
  schema: AnySchemaLike,
  rawInput: unknown,
): asserts rawInput is ToolInput<this> {
  super.validateInput(schema, rawInput);

  // `inputFiles` may be absent; empty filenames are ignored here.
  const fileNames: string[] =
    (rawInput.inputFiles as { filename: string }[] | undefined)
      ?.map(({ filename }) => filename)
      .filter(Boolean) ?? [];

  // A set difference between the list and its own de-duplicated copy is always
  // empty, so duplicates are found by tracking which names were already seen.
  const seen = new Set<string>();
  const duplicates = new Set<string>();
  for (const name of fileNames) {
    if (seen.has(name)) {
      duplicates.add(name);
    } else {
      seen.add(name);
    }
  }

  if (duplicates.size > 0) {
    throw new ToolInputValidationError(
      [
        `All 'inputFiles' must have a unique filenames.`,
        `Duplicated filenames: ${Array.from(duplicates).join(",")}`,
      ].join("\n"),
    );
  }
}

protected readonly client: PromiseClient<typeof CodeInterpreterService>;
protected readonly preprocess;

Expand Down Expand Up @@ -170,7 +126,12 @@ export class PythonTool extends Tool<PythonToolOutput, PythonToolOptions> {
_options: BaseToolRunOptions | undefined,
run: RunContext<this>,
) {
const inputFiles = await this.storage.upload((input.inputFiles as PythonFile[]) ?? []);
const inputFiles = await pipe(
this.files ?? (await this.storage.list()),
uniqueBy((f) => f.filename),
filter((file) => isIncludedIn(file.filename, (input.inputFiles ?? []) as string[])),
(files) => this.storage.upload(files),
);

// replace relative paths in "files" with absolute paths by prepending "/workspace"
const getSourceCode = async () => {
Expand Down

0 comments on commit b366007

Please sign in to comment.