From 5a9950b4e17470f2c17d5cf3208f6b52926fc49a Mon Sep 17 00:00:00 2001 From: Aurelien Franky Date: Thu, 23 Nov 2023 20:57:56 +0100 Subject: [PATCH] update checklist node definition --- .../adapters/document_check/document_check.ts | 51 ++++-- .../__snapshots__/document_check.test.ts.snap | 58 ++++--- .../document_check/document_check.test.ts | 159 ++++++++++++++++-- .../adapters/document_check/document_check.ts | 76 ++++++++- .../@pufflig/ps-types/src/types/params.ts | 5 +- 5 files changed, 288 insertions(+), 61 deletions(-) diff --git a/packages/@pufflig/ps-nodes-config/src/adapters/document_check/document_check.ts b/packages/@pufflig/ps-nodes-config/src/adapters/document_check/document_check.ts index 3a5ed39..8c02a1a 100644 --- a/packages/@pufflig/ps-nodes-config/src/adapters/document_check/document_check.ts +++ b/packages/@pufflig/ps-nodes-config/src/adapters/document_check/document_check.ts @@ -4,7 +4,7 @@ import { NodeConfig } from "@pufflig/ps-types"; export const documentCheckNodeType = "modifier/document_check" as const; export const documentCheck: NodeConfig = { - name: "Document Checklist", + name: "Checklist", description: "Run a checklist on a document.", tags: ["adapter", "document", "text"], status: "stable", @@ -23,9 +23,9 @@ export const documentCheck: NodeConfig = { }, outputs: [ { - id: "result", - name: "Result", - description: "A list, checklist or other information about the document", + id: "checklist", + name: "Checklist", + description: "A checklist of items to run on the document", type: "text", defaultValue: "", }, @@ -43,25 +43,44 @@ export const documentCheck: NodeConfig = { }, }, { - id: "prompt", - name: "Prompt", - description: "Prompt to check the document with", + id: "instructions", + name: "Instructions", + description: "Instructions for the AI", type: "text", - defaultValue: `Extract information in the document below and insert them in the csv table, don't overwrite existing values and keep things empty if you cannot find information in the document:\n\nTABLE EXAMPLE:\ncharacters, age\nmickey mouse, 10\ndonald duck, -\n\nTABLE:\n{{table}}\n\nDOCUMENT:\n{{document}}\n\nTABLE:\n`, - }, - { - id: "table", - name: "Table", - description: "The list, table or checklist to parse the document with.", - type: "text", - defaultValue: "", + defaultValue: `Run the checklist below on the document.`, }, { id: "document", name: "Document", - description: "Document to be processed", + description: "Document to be checked", type: "text", defaultValue: "", }, + { + id: "checklist", + name: "Checklist", + description: "The checklist to run on the document", + type: "object", + editableSchema: true, + defaultValue: [], + }, + { + id: "format", + name: "Format", + description: "The format in which to return the cheklist results", + type: "selection", + defaultValue: "markdown", + options: [ + { id: "csv", name: "CSV" }, + { id: "markdown", name: "Markdown" }, + ], + }, + { + id: "fields", + name: "Fields", + description: "Custom fields to include in the output for each checklist item", + type: "list", + defaultValue: ["ok"], + }, ], }; diff --git a/packages/@pufflig/ps-nodes/src/adapters/document_check/__snapshots__/document_check.test.ts.snap b/packages/@pufflig/ps-nodes/src/adapters/document_check/__snapshots__/document_check.test.ts.snap index 29f261b..e9519f5 100644 --- a/packages/@pufflig/ps-nodes/src/adapters/document_check/__snapshots__/document_check.test.ts.snap +++ b/packages/@pufflig/ps-nodes/src/adapters/document_check/__snapshots__/document_check.test.ts.snap @@ -4,7 +4,7 @@ exports[`documentCheck should extract variables correctly 1`] = ` [ { "defaultValue": { - "modelId": "test_model", + "modelId": "gpt-3.5-turbo-instruct", "parameters": {}, }, "definition": { @@ -213,32 +213,52 @@ exports[`documentCheck should extract variables correctly 1`] = ` "type": "model", }, { - "defaultValue": "Hello, {{myVariable}}!", - "description": "Prompt to check the document with", - "id": "prompt", - "name": "Prompt", + "defaultValue": "Run the checklist below on the document.", + "description": "Instructions for the AI", + "id": "instructions", + "name": "Instructions", "type": "text", }, { - "defaultValue": "test_table", - "description": "The list, table or checklist to parse the document with.", - "id": "table", - "name": "Table", - "type": "text", - }, - { - "defaultValue": "This is a test document.", - "description": "Document to be processed", + "defaultValue": "", + "description": "Document to be checked", "id": "document", "name": "Document", "type": "text", }, { - "defaultValue": "myValue", - "description": "", - "id": "myVariable", - "name": "myVariable", - "type": "text", + "defaultValue": [], + "description": "The checklist to run on the document", + "editableSchema": true, + "id": "checklist", + "name": "Checklist", + "type": "object", + }, + { + "defaultValue": "markdown", + "description": "The format in which to return the cheklist results", + "id": "format", + "name": "Format", + "options": [ + { + "id": "csv", + "name": "CSV", + }, + { + "id": "markdown", + "name": "Markdown", + }, + ], + "type": "selection", + }, + { + "defaultValue": [ + "ok", + ], + "description": "Custom fields to include in the output for each checklist item", + "id": "fields", + "name": "Fields", + "type": "list", }, ] `; diff --git a/packages/@pufflig/ps-nodes/src/adapters/document_check/document_check.test.ts b/packages/@pufflig/ps-nodes/src/adapters/document_check/document_check.test.ts index aeb34f2..8378d81 100644 --- a/packages/@pufflig/ps-nodes/src/adapters/document_check/document_check.test.ts +++ b/packages/@pufflig/ps-nodes/src/adapters/document_check/document_check.test.ts @@ -1,5 +1,5 @@ -import { execute, getInputDefinition, LLMCompletionInput } from "./document_check"; import axios from "axios"; +import { execute, getInputDefinition, LLMCompletionInput } from "./document_check"; jest.mock("axios"); @@ -8,15 +8,58 @@ describe("documentCheck", () => { jest.resetAllMocks(); }); - it("should return the completion string", async () => { + it("should return the resulting checklist", async () => { const input: LLMCompletionInput = { - prompt: "Hello, world!", + instructions: "Hello, world!", model: { modelId: "test_model", parameters: {}, }, document: "This is a test document.", - table: "test_table", + checklist: [ + { id: "is_greeting", defaultValue: "test_table", type: "text", name: "is_greeting", description: "" }, + ], + fields: ["ok"], + format: "csv", + }; + + const expectedOutput = { result: "mock checklist" }; + const mockedAxiosResponse = { data: expectedOutput }; + (axios.post as jest.MockedFunction).mockResolvedValueOnce(mockedAxiosResponse); + + const output = await execute(input); + + expect(output).toEqual({ checklist: expectedOutput.result }); + expect(axios.post).toHaveBeenCalledTimes(1); + }); + + it("should parse input variables", async () => { + const input: LLMCompletionInput = { + instructions: "Hello, {{world}}! Run a checklist on the following document: {{document}}", + model: { + modelId: "test_model", + parameters: {}, + }, + document: "This is a test document.", + checklist: [ + { + id: "is_greeting", + defaultValue: "is the text a greeting?", + type: "text", + name: "is_greeting", + description: "", + }, + { + id: "is_formal", + defaultValue: "is the greeting formal?", + type: "text", + name: "is_formal", + description: "", + }, + ], + fields: ["ok"], + format: "csv", + world: "test", }; const expectedOutput = { result: "This is a test completion." }; @@ -25,20 +68,70 @@ describe("documentCheck", () => { const output = await execute(input); - expect(output).toEqual(expectedOutput); + expect(output).toEqual({ checklist: expectedOutput.result }); expect(axios.post).toHaveBeenCalledTimes(1); + expect(axios.post).toHaveBeenCalledWith( + expect.any(String), + { + document: "This is a test document.", + format: `check,ok +is_greeting, +is_formal,`, + modelId: "test_model", + options: { + cache: true, + track: true, + }, + parameters: {}, + prompt: `Hello, test! Run a checklist on the following document: {{document}} + +CHECKLIST DESCRIPTION: +check,description +is_greeting, is the text a greeting? +is_formal, is the greeting formal? + +CHECKLIST FORMAT: +{{table}} + +CHECKLIST IN CSV FORMAT: +`, + }, + { + headers: { + Authorization: "Bearer undefined", + "Content-Type": "application/json", + }, + } + ); }); it("should parse input variables", async () => { const input: LLMCompletionInput = { - prompt: "Hello, {{myVariable}}!", + instructions: "Hello, {{world}}! Run a checklist on the following document: {{document}}", model: { modelId: "test_model", parameters: {}, }, document: "This is a test document.", - table: "test_table", - myVariable: "myValue", + checklist: [ + { + id: "is_greeting", + defaultValue: "is the text a greeting?", + type: "text", + name: "is_greeting", + description: "", + }, + { + id: "is_formal", + defaultValue: "is the greeting formal?", + type: "text", + name: "is_formal", + description: "", + }, + ], + fields: ["ok"], + format: "markdown", + world: "test", }; const expectedOutput = { result: "This is a test completion." }; @@ -47,20 +140,33 @@ describe("documentCheck", () => { const output = await execute(input); - expect(output).toEqual(expectedOutput); + expect(output).toEqual({ checklist: expectedOutput.result }); expect(axios.post).toHaveBeenCalledTimes(1); expect(axios.post).toHaveBeenCalledWith( expect.any(String), { document: "This is a test document.", - format: "test_table", + format: `|check|ok| +|is_greeting|| +|is_formal||`, modelId: "test_model", options: { cache: true, track: true, }, parameters: {}, - prompt: "Hello, myValue!", + prompt: `Hello, test! Run a checklist on the following document: {{document}} + +CHECKLIST DESCRIPTION: +|check|description| +|is_greeting|is the text a greeting?| +|is_formal|is the greeting formal?| + +CHECKLIST FORMAT: +{{table}} + +CHECKLIST IN MARKDOWN FORMAT: +`, }, { headers: { @@ -73,14 +179,31 @@ describe("documentCheck", () => { it("should extract variables correctly", async () => { const output = getInputDefinition({ - prompt: "Hello, {{myVariable}}!", + instructions: "Hello, {{world}}!", model: { modelId: "test_model", parameters: {}, }, document: "This is a test document.", - table: "test_table", - myVariable: "myValue", + checklist: [ + { + id: "is_greeting", + defaultValue: "is the text a greeting?", + type: "text", + name: "is_greeting", + description: "", + }, + { + id: "is_formal", + defaultValue: "is the greeting formal?", + type: "text", + name: "is_formal", + description: "", + }, + ], + fields: ["ok"], + format: "csv", + world: "test", }); expect(output).toMatchSnapshot(); @@ -88,13 +211,17 @@ describe("documentCheck", () => { it("should throw an error if the API call fails", async () => { const input: LLMCompletionInput = { - prompt: "Hello, world!", + instructions: "Hello, world!", model: { modelId: "test_model", parameters: {}, }, document: "This is a test document.", - table: "test_table", + checklist: [ + { id: "is_greeting", defaultValue: "test_table", type: "text", name: "is_greeting", description: "" }, + ], + fields: ["ok"], + format: "csv", }; const expectedError = new Error("API call failed."); diff --git a/packages/@pufflig/ps-nodes/src/adapters/document_check/document_check.ts b/packages/@pufflig/ps-nodes/src/adapters/document_check/document_check.ts index 63b6c7e..8fc1041 100644 --- a/packages/@pufflig/ps-nodes/src/adapters/document_check/document_check.ts +++ b/packages/@pufflig/ps-nodes/src/adapters/document_check/document_check.ts @@ -1,36 +1,96 @@ import { nodeTypes, nodes } from "@pufflig/ps-nodes-config"; import { refineCompletion } from "@pufflig/ps-sdk"; -import { Execute, GetInputDefinition, ModelValue, Node, Param } from "@pufflig/ps-types"; +import { Execute, GetInputDefinition, ModelValue, Node, ObjectDefinition, Param } from "@pufflig/ps-types"; import { getPromptStudioKey } from "../../utils/getPromptStudioKey"; import { extractVariables } from "../../utils/extractVariables"; import Mustache from "mustache"; export interface LLMCompletionInput { - prompt: string; + instructions: string; model: ModelValue; document: string; - table: string; + checklist: ObjectDefinition; + format: string; + fields: string[]; [key: string]: any; } export interface LLMCompletionOutput { - result: string; + checklist: string; } +const makeCSVChecklist = (checklist: ObjectDefinition, fields: string[]) => { + const header = "check," + fields.join(","); + const rows = checklist.map((row) => { + return row.id + "," + fields.map(() => "").join(","); + }); + + return `${header}\n${rows.join("\n")}`; +}; + +const makeMarkdownChecklist = (checklist: ObjectDefinition, fields: string[]) => { + const header = "|check|" + fields.join("|") + "|"; + const rows = checklist.map((row) => { + return "|" + row.id + "|" + fields.map(() => "").join("|") + "|"; + }); + + return `${header}\n${rows.join("\n")}`; +}; + +const makeCSVDescription = (checklist: ObjectDefinition) => { + const header = "check,description"; + const rows = checklist.map((item) => { + return `${item.id}, ${(item.defaultValue as string).replace(/,/g, "")}`; + }); + return `${header}\n${rows.join("\n")}`; +}; + +const makeMarkdownDescription = (checklist: ObjectDefinition) => { + const header = "|check|description|"; + const rows = checklist.map((item) => { + return `|${item.id}|${(item.defaultValue as string).replace(/,/g, "")}|`; + }); + return `${header}\n${rows.join("\n")}`; +}; + export const execute: Execute = async (input, options = {}) => { - const { prompt, model, document, table, ...variables } = input; + const { instructions, model, document, checklist, fields, format, ...variables } = input; const { modelId, parameters } = model; const { globals } = options; + const isCSV = format === "csv"; + + // checklist format + const checkListFormat = isCSV ? makeCSVChecklist(checklist, fields) : makeMarkdownChecklist(checklist, fields); + + // checklist description + const description = isCSV ? makeCSVDescription(checklist) : makeMarkdownDescription(checklist); + + // TODO: move into the API + const instructionsWithChecklist = `${instructions} + +CHECKLIST DESCRIPTION: +${description} + +CHECKLIST FORMAT: +{{table}} + +CHECKLIST IN ${format.toUpperCase()} FORMAT: +`; + // render the prompt without overwriting the document and table variables - const renderedPrompt = Mustache.render(prompt, { ...variables, document: "{{document}}", table: "{{table}}" }); + const renderedPrompt = Mustache.render(instructionsWithChecklist, { + ...variables, + document: "{{document}}", + table: "{{table}}", + }); const { result } = await refineCompletion({ apiKey: getPromptStudioKey(globals || {}), modelId, prompt: renderedPrompt, document: document, - format: table, + format: checkListFormat, parameters, config: globals, options: { @@ -40,7 +100,7 @@ export const execute: Execute = async ( }); return { - result: result || "", + checklist: result || "", }; }; diff --git a/packages/@pufflig/ps-types/src/types/params.ts b/packages/@pufflig/ps-types/src/types/params.ts index ae5c845..d687a5d 100644 --- a/packages/@pufflig/ps-types/src/types/params.ts +++ b/packages/@pufflig/ps-types/src/types/params.ts @@ -86,7 +86,8 @@ export interface ObjectParam extends BaseParam { export interface ListParam extends BaseParam { type: "list"; - defaultValue: []; + files?: boolean; + defaultValue: string[]; } export interface VectorParam extends BaseParam { @@ -102,7 +103,7 @@ export type ParamValue = | Chat | ChatMessage | null - | (NumberParam | TextParam)[] + | ObjectDefinition | Array | Array;