Skip to content

Commit

Permalink
Add automatic JSON enforcement to jsonToolCallPrompt.
Browse files — browse the repository at this point in the history
  • Loading branch information
lgrammel committed Jan 28, 2024
1 parent fa2bb67 commit a000cc3
Show file tree
Hide file tree
Showing 9 changed files with 163 additions and 55 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import { jsonToolCallPrompt, llamacpp, runTool } from "modelfusion";
import { calculator } from "../../tool/tools/calculator-tool";

/**
 * Demo: run the calculator tool against a local llama.cpp server using
 * plain-text JSON tool calling (jsonToolCallPrompt.text()).
 *
 * Requires llama.cpp serving a chat model, e.g.
 * https://huggingface.co/TheBloke/Mistral-7B-Instruct-v0.2-GGUF
 */
async function main() {
  // Completion model wrapped so that instruction prompts are rendered with
  // the ChatML template, then adapted into a tool-call generation model.
  const model = llamacpp
    .CompletionTextGenerator({
      promptTemplate: llamacpp.prompt.ChatML,
      temperature: 0, // deterministic output for tool-call extraction
    })
    .withInstructionPrompt()
    .asToolCallGenerationModel(jsonToolCallPrompt.text());

  const { tool, toolCall, args, ok, result } = await runTool({
    model,
    tool: calculator,
    prompt: "What's fourteen times twelve?",
    logging: "detailed-object",
  });

  console.log(`Tool call:`, toolCall);
  console.log(`Tool:`, tool);
  console.log(`Arguments:`, args);
  console.log(`Ok:`, ok);
  console.log(`Result or Error:`, result);
}

main().catch(console.error);
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import {

const DEFAULT_SCHEMA_PREFIX = "JSON schema:";
const DEFAULT_SCHEMA_SUFFIX =
"\nYou MUST answer with a JSON object matches the above schema.";
"\nYou MUST answer with a JSON object that matches the JSON schema above.";

export const jsonObjectPrompt = {
custom<SOURCE_PROMPT, TARGET_PROMPT>(
Expand Down Expand Up @@ -85,6 +85,7 @@ function createSystemPrompt({
}) {
return [
originalSystemPrompt,
originalSystemPrompt != null ? "" : null,
schemaPrefix,
JSON.stringify(schema.getJsonSchema()),
schemaSuffix,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
import { FunctionCallOptions } from "../../core/FunctionOptions";
import { JsonSchemaProducer } from "../../core/schema/JsonSchemaProducer";
import { Schema } from "../../core/schema/Schema";
import {
TextGenerationToolCallModel,
ToolCallPromptTemplate,
} from "../../tool/generate-tool-call/TextGenerationToolCallModel";
import { TextGenerationToolCallModel } from "../../tool/generate-tool-call/TextGenerationToolCallModel";
import { ToolCallPromptTemplate } from "../../tool/generate-tool-call/ToolCallPromptTemplate";
import { TextGenerationToolCallsModel } from "../../tool/generate-tool-calls/TextGenerationToolCallsModel";
import { ToolCallsPromptTemplate } from "../../tool/generate-tool-calls/ToolCallsPromptTemplate";
import { ObjectFromTextGenerationModel } from "../generate-object/ObjectFromTextGenerationModel";
Expand Down Expand Up @@ -91,7 +89,7 @@ export class PromptTemplateTextGenerationModel<
) {
return new TextGenerationToolCallModel({
model: this,
format: promptTemplate,
template: promptTemplate,
});
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,8 @@ import {
textGenerationModelProperties,
} from "../../model-function/generate-text/TextGenerationModel";
import { TextGenerationPromptTemplate } from "../../model-function/generate-text/TextGenerationPromptTemplate";
import {
TextGenerationToolCallModel,
ToolCallPromptTemplate,
} from "../../tool/generate-tool-call/TextGenerationToolCallModel";
import { TextGenerationToolCallModel } from "../../tool/generate-tool-call/TextGenerationToolCallModel";
import { ToolCallPromptTemplate } from "../../tool/generate-tool-call/ToolCallPromptTemplate";
import { TextGenerationToolCallsModel } from "../../tool/generate-tool-calls/TextGenerationToolCallsModel";
import { ToolCallsPromptTemplate } from "../../tool/generate-tool-calls/ToolCallsPromptTemplate";
import { createJsonStreamResponseHandler } from "../../util/streaming/createJsonStreamResponseHandler";
Expand Down Expand Up @@ -198,7 +196,7 @@ export class OllamaChatModel
) {
return new TextGenerationToolCallModel({
model: this,
format: promptTemplate,
template: promptTemplate,
});
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,8 @@ import { TextGenerationPromptTemplate } from "../../model-function/generate-text
import { ChatPrompt } from "../../model-function/generate-text/prompt-template/ChatPrompt";
import { InstructionPrompt } from "../../model-function/generate-text/prompt-template/InstructionPrompt";
import { TextGenerationPromptTemplateProvider } from "../../model-function/generate-text/prompt-template/PromptTemplateProvider";
import {
TextGenerationToolCallModel,
ToolCallPromptTemplate,
} from "../../tool/generate-tool-call/TextGenerationToolCallModel";
import { TextGenerationToolCallModel } from "../../tool/generate-tool-call/TextGenerationToolCallModel";
import { ToolCallPromptTemplate } from "../../tool/generate-tool-call/ToolCallPromptTemplate";
import { TextGenerationToolCallsModel } from "../../tool/generate-tool-calls/TextGenerationToolCallsModel";
import { ToolCallsPromptTemplate } from "../../tool/generate-tool-calls/ToolCallsPromptTemplate";
import { createJsonStreamResponseHandler } from "../../util/streaming/createJsonStreamResponseHandler";
Expand Down Expand Up @@ -262,7 +260,7 @@ export class OllamaCompletionModel<
) {
return new TextGenerationToolCallModel({
model: this,
format: promptTemplate,
template: promptTemplate,
});
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,23 +1,15 @@
import { FunctionOptions } from "../../core/FunctionOptions";
import { JsonSchemaProducer } from "../../core/schema/JsonSchemaProducer";
import { Schema } from "../../core/schema/Schema";
import {
TextGenerationModel,
TextGenerationModelSettings,
} from "../../model-function/generate-text/TextGenerationModel";
import { generateText } from "../../model-function/generate-text/generateText";
import { ToolCallParseError } from "./ToolCallParseError";
import { ToolDefinition } from "../ToolDefinition";
import { ToolCallGenerationModel } from "./ToolCallGenerationModel";

export interface ToolCallPromptTemplate<SOURCE_PROMPT, TARGET_PROMPT> {
createPrompt: (
prompt: SOURCE_PROMPT,
tool: ToolDefinition<string, unknown>
) => TARGET_PROMPT;
extractToolCall: (
response: string,
tool: ToolDefinition<string, unknown>
) => { id: string; args: unknown } | null;
}
import { ToolCallParseError } from "./ToolCallParseError";
import { ToolCallPromptTemplate } from "./ToolCallPromptTemplate";

export class TextGenerationToolCallModel<
SOURCE_PROMPT,
Expand All @@ -26,17 +18,20 @@ export class TextGenerationToolCallModel<
> implements ToolCallGenerationModel<SOURCE_PROMPT, MODEL["settings"]>
{
private readonly model: MODEL;
private readonly format: ToolCallPromptTemplate<SOURCE_PROMPT, TARGET_PROMPT>;
private readonly template: ToolCallPromptTemplate<
SOURCE_PROMPT,
TARGET_PROMPT
>;

constructor({
model,
format,
template,
}: {
model: MODEL;
format: ToolCallPromptTemplate<SOURCE_PROMPT, TARGET_PROMPT>;
template: ToolCallPromptTemplate<SOURCE_PROMPT, TARGET_PROMPT>;
}) {
this.model = model;
this.format = format;
this.template = template;
}

get modelInformation() {
Expand All @@ -51,22 +46,33 @@ export class TextGenerationToolCallModel<
return this.model.settingsForEvent;
}

getModelWithJsonOutput(schema: Schema<unknown> & JsonSchemaProducer) {
if (this.template.withJsonOutput != null) {
return this.template.withJsonOutput({
model: this.model,
schema,
}) as MODEL;
}

return this.model;
}

async doGenerateToolCall(
tool: ToolDefinition<string, unknown>,
prompt: SOURCE_PROMPT,
options?: FunctionOptions
) {
const { rawResponse, text, metadata } = await generateText({
model: this.model,
prompt: this.format.createPrompt(prompt, tool),
model: this.getModelWithJsonOutput(tool.parameters),
prompt: this.template.createPrompt(prompt, tool),
fullResponse: true,
...options,
});

try {
return {
rawResponse,
toolCall: this.format.extractToolCall(text, tool),
toolCall: this.template.extractToolCall(text, tool),
usage: metadata?.usage as
| {
promptTokens: number;
Expand All @@ -87,7 +93,7 @@ export class TextGenerationToolCallModel<
withSettings(additionalSettings: Partial<MODEL["settings"]>): this {
return new TextGenerationToolCallModel({
model: this.model.withSettings(additionalSettings),
format: this.format,
template: this.template,
}) as this;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import { JsonSchemaProducer } from "../../core/schema/JsonSchemaProducer";
import { Schema } from "../../core/schema/Schema";
import { ToolDefinition } from "../ToolDefinition";

/**
 * Adapts a plain text generation model for tool calling: renders the tool
 * definition into a model prompt and parses the model's text response back
 * into tool-call arguments.
 */
export interface ToolCallPromptTemplate<SOURCE_PROMPT, TARGET_PROMPT> {
// Renders the caller's prompt together with the tool definition into the
// prompt format the underlying model expects.
createPrompt(
prompt: SOURCE_PROMPT,
tool: ToolDefinition<string, unknown>
): TARGET_PROMPT;

// Parses the model's raw text response into a tool call.
// Returns null when no tool call could be extracted from the response.
extractToolCall(
response: string,
tool: ToolDefinition<string, unknown>
): { id: string; args: unknown } | null;

// Optional hook: lets the template switch the underlying model into
// JSON-enforced output mode for the given schema (used with the tool's
// parameter schema). Templates that don't support enforcement omit it.
withJsonOutput?({
model,
schema,
}: {
model: {
withJsonOutput(
schema: Schema<unknown> & JsonSchemaProducer
): typeof model;
};
schema: Schema<unknown> & JsonSchemaProducer;
}): typeof model;
}
1 change: 1 addition & 0 deletions packages/modelfusion/src/tool/generate-tool-call/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@ export * from "./TextGenerationToolCallModel";
export * from "./ToolCallGenerationEvent";
export * from "./ToolCallGenerationModel";
export * from "./ToolCallParseError";
export * from "./ToolCallPromptTemplate";
export * from "./generateToolCall";
export * from "./jsonToolCallPrompt";
Original file line number Diff line number Diff line change
Expand Up @@ -2,33 +2,84 @@ import { nanoid } from "nanoid";
import { parseJSON } from "../../core/schema/parseJSON";
import { InstructionPrompt } from "../../model-function/generate-text/prompt-template/InstructionPrompt";
import { ToolDefinition } from "../../tool/ToolDefinition";
import { ToolCallPromptTemplate } from "./TextGenerationToolCallModel";
import { ToolCallPromptTemplate } from "./ToolCallPromptTemplate";

/**
 * Default system prompt for JSON tool calling: names the tool, includes its
 * description (when present) and parameter JSON schema, and instructs the
 * model to answer with matching JSON.
 *
 * Note: filters with `line != null` instead of `.filter(Boolean)` — the
 * latter would also drop the intentional `""` entry, losing the blank
 * separator line before the final instruction.
 */
const DEFAULT_TOOL_PROMPT = (tool: ToolDefinition<string, unknown>) =>
  [
    `You are calling the function "${tool.name}".`,
    tool.description != null
      ? `Function description: ${tool.description}`
      : null,
    `Function parameters JSON schema: ${JSON.stringify(
      tool.parameters.getJsonSchema()
    )}`,
    ``,
    `You MUST answer with a JSON object that matches the JSON schema above.`,
  ]
    .filter((line) => line != null)
    .join("\n");

export const jsonToolCallPrompt = {
text(): ToolCallPromptTemplate<string, InstructionPrompt> {
text({
toolPrompt,
}: {
toolPrompt?: (tool: ToolDefinition<string, unknown>) => string;
} = {}): ToolCallPromptTemplate<string, InstructionPrompt> {
return {
createPrompt(instruction: string, tool: ToolDefinition<string, unknown>) {
createPrompt(prompt: string, tool: ToolDefinition<string, unknown>) {
return {
system: [
`You are calling a function "${tool.name}".`,
tool.description != null
? ` Function description: ${tool.description}`
: null,
` Function parameters JSON schema: ${JSON.stringify(
tool.parameters.getJsonSchema()
)}`,
``,
`You MUST answer with a JSON object matches the above schema for the arguments.`,
]
.filter(Boolean)
.join("\n"),
instruction,
system: createSystemPrompt({ tool, toolPrompt }),
instruction: prompt,
};
},
extractToolCall,
withJsonOutput: ({ model, schema }) => model.withJsonOutput(schema),
};
},

extractToolCall(response) {
return { id: nanoid(), args: parseJSON({ text: response }) };
instruction({
toolPrompt,
}: {
toolPrompt?: (tool: ToolDefinition<string, unknown>) => string;
} = {}): ToolCallPromptTemplate<InstructionPrompt, InstructionPrompt> {
return {
createPrompt(
prompt: InstructionPrompt,
tool: ToolDefinition<string, unknown>
): InstructionPrompt {
return {
system: createSystemPrompt({
originalSystemPrompt: prompt.system,
tool,
toolPrompt,
}),
instruction: prompt.instruction,
};
},
extractToolCall,
withJsonOutput: ({ model, schema }) => model.withJsonOutput(schema),
};
},
};

/**
 * Builds the system prompt for a tool call: the original system prompt
 * (when present), a blank separator line, then the rendered tool prompt.
 *
 * Fix: the original used `.filter(Boolean)`, which removed the `""`
 * separator entry, making `originalSystemPrompt != null ? "" : null`
 * dead code. Filtering on `line != null` keeps the intended blank line.
 */
function createSystemPrompt({
  originalSystemPrompt,
  toolPrompt = DEFAULT_TOOL_PROMPT,
  tool,
}: {
  originalSystemPrompt?: string;
  toolPrompt?: (tool: ToolDefinition<string, unknown>) => string;
  tool: ToolDefinition<string, unknown>;
}) {
  return [
    originalSystemPrompt,
    // blank line separating the original system prompt from the tool prompt
    originalSystemPrompt != null ? "" : null,
    toolPrompt(tool),
  ]
    .filter((line) => line != null)
    .join("\n");
}

// Parses the raw model response as JSON and wraps it as a tool call
// with a freshly generated id.
function extractToolCall(response: string) {
  const args = parseJSON({ text: response });
  return { args, id: nanoid() };
}

0 comments on commit a000cc3

Please sign in to comment.