diff --git a/examples/basic/src/model-provider/llamacpp/llamacpp-run-tool-mistral-example.ts b/examples/basic/src/model-provider/llamacpp/llamacpp-run-tool-mistral-example.ts new file mode 100644 index 00000000..2798a17b --- /dev/null +++ b/examples/basic/src/model-provider/llamacpp/llamacpp-run-tool-mistral-example.ts @@ -0,0 +1,28 @@ +import { jsonToolCallPrompt, llamacpp, runTool } from "modelfusion"; +import { calculator } from "../../tool/tools/calculator-tool"; + +async function main() { + const { tool, toolCall, args, ok, result } = await runTool({ + model: llamacpp + .CompletionTextGenerator({ + // run https://huggingface.co/TheBloke/Mistral-7B-Instruct-v0.2-GGUF with llama.cpp + promptTemplate: llamacpp.prompt.ChatML, + temperature: 0, + }) + .withInstructionPrompt() + .asToolCallGenerationModel(jsonToolCallPrompt.text()), + + tool: calculator, + prompt: "What's fourteen times twelve?", + + logging: "detailed-object", + }); + + console.log(`Tool call:`, toolCall); + console.log(`Tool:`, tool); + console.log(`Arguments:`, args); + console.log(`Ok:`, ok); + console.log(`Result or Error:`, result); +} + +main().catch(console.error); diff --git a/packages/modelfusion/src/model-function/generate-object/jsonObjectPrompt.ts b/packages/modelfusion/src/model-function/generate-object/jsonObjectPrompt.ts index a5707c40..144421e1 100644 --- a/packages/modelfusion/src/model-function/generate-object/jsonObjectPrompt.ts +++ b/packages/modelfusion/src/model-function/generate-object/jsonObjectPrompt.ts @@ -9,7 +9,7 @@ import { const DEFAULT_SCHEMA_PREFIX = "JSON schema:"; const DEFAULT_SCHEMA_SUFFIX = - "\nYou MUST answer with a JSON object matches the above schema."; + "\nYou MUST answer with a JSON object that matches the JSON schema above."; export const jsonObjectPrompt = { custom( @@ -85,6 +85,7 @@ function createSystemPrompt({ }) { return [ originalSystemPrompt, + originalSystemPrompt != null ? 
"" : null, schemaPrefix, JSON.stringify(schema.getJsonSchema()), schemaSuffix, diff --git a/packages/modelfusion/src/model-function/generate-text/PromptTemplateTextGenerationModel.ts b/packages/modelfusion/src/model-function/generate-text/PromptTemplateTextGenerationModel.ts index 554df53f..f1c46c78 100644 --- a/packages/modelfusion/src/model-function/generate-text/PromptTemplateTextGenerationModel.ts +++ b/packages/modelfusion/src/model-function/generate-text/PromptTemplateTextGenerationModel.ts @@ -1,10 +1,8 @@ import { FunctionCallOptions } from "../../core/FunctionOptions"; import { JsonSchemaProducer } from "../../core/schema/JsonSchemaProducer"; import { Schema } from "../../core/schema/Schema"; -import { - TextGenerationToolCallModel, - ToolCallPromptTemplate, -} from "../../tool/generate-tool-call/TextGenerationToolCallModel"; +import { TextGenerationToolCallModel } from "../../tool/generate-tool-call/TextGenerationToolCallModel"; +import { ToolCallPromptTemplate } from "../../tool/generate-tool-call/ToolCallPromptTemplate"; import { TextGenerationToolCallsModel } from "../../tool/generate-tool-calls/TextGenerationToolCallsModel"; import { ToolCallsPromptTemplate } from "../../tool/generate-tool-calls/ToolCallsPromptTemplate"; import { ObjectFromTextGenerationModel } from "../generate-object/ObjectFromTextGenerationModel"; @@ -91,7 +89,7 @@ export class PromptTemplateTextGenerationModel< ) { return new TextGenerationToolCallModel({ model: this, - format: promptTemplate, + template: promptTemplate, }); } diff --git a/packages/modelfusion/src/model-provider/ollama/OllamaChatModel.ts b/packages/modelfusion/src/model-provider/ollama/OllamaChatModel.ts index b592d960..8277703e 100644 --- a/packages/modelfusion/src/model-provider/ollama/OllamaChatModel.ts +++ b/packages/modelfusion/src/model-provider/ollama/OllamaChatModel.ts @@ -20,10 +20,8 @@ import { textGenerationModelProperties, } from "../../model-function/generate-text/TextGenerationModel"; import { 
TextGenerationPromptTemplate } from "../../model-function/generate-text/TextGenerationPromptTemplate"; -import { - TextGenerationToolCallModel, - ToolCallPromptTemplate, -} from "../../tool/generate-tool-call/TextGenerationToolCallModel"; +import { TextGenerationToolCallModel } from "../../tool/generate-tool-call/TextGenerationToolCallModel"; +import { ToolCallPromptTemplate } from "../../tool/generate-tool-call/ToolCallPromptTemplate"; import { TextGenerationToolCallsModel } from "../../tool/generate-tool-calls/TextGenerationToolCallsModel"; import { ToolCallsPromptTemplate } from "../../tool/generate-tool-calls/ToolCallsPromptTemplate"; import { createJsonStreamResponseHandler } from "../../util/streaming/createJsonStreamResponseHandler"; @@ -198,7 +196,7 @@ export class OllamaChatModel ) { return new TextGenerationToolCallModel({ model: this, - format: promptTemplate, + template: promptTemplate, }); } diff --git a/packages/modelfusion/src/model-provider/ollama/OllamaCompletionModel.ts b/packages/modelfusion/src/model-provider/ollama/OllamaCompletionModel.ts index a5b469f0..79dfde94 100644 --- a/packages/modelfusion/src/model-provider/ollama/OllamaCompletionModel.ts +++ b/packages/modelfusion/src/model-provider/ollama/OllamaCompletionModel.ts @@ -23,10 +23,8 @@ import { TextGenerationPromptTemplate } from "../../model-function/generate-text import { ChatPrompt } from "../../model-function/generate-text/prompt-template/ChatPrompt"; import { InstructionPrompt } from "../../model-function/generate-text/prompt-template/InstructionPrompt"; import { TextGenerationPromptTemplateProvider } from "../../model-function/generate-text/prompt-template/PromptTemplateProvider"; -import { - TextGenerationToolCallModel, - ToolCallPromptTemplate, -} from "../../tool/generate-tool-call/TextGenerationToolCallModel"; +import { TextGenerationToolCallModel } from "../../tool/generate-tool-call/TextGenerationToolCallModel"; +import { ToolCallPromptTemplate } from 
"../../tool/generate-tool-call/ToolCallPromptTemplate"; import { TextGenerationToolCallsModel } from "../../tool/generate-tool-calls/TextGenerationToolCallsModel"; import { ToolCallsPromptTemplate } from "../../tool/generate-tool-calls/ToolCallsPromptTemplate"; import { createJsonStreamResponseHandler } from "../../util/streaming/createJsonStreamResponseHandler"; @@ -262,7 +260,7 @@ export class OllamaCompletionModel< ) { return new TextGenerationToolCallModel({ model: this, - format: promptTemplate, + template: promptTemplate, }); } diff --git a/packages/modelfusion/src/tool/generate-tool-call/TextGenerationToolCallModel.ts b/packages/modelfusion/src/tool/generate-tool-call/TextGenerationToolCallModel.ts index 6a7e9bd1..cede023a 100644 --- a/packages/modelfusion/src/tool/generate-tool-call/TextGenerationToolCallModel.ts +++ b/packages/modelfusion/src/tool/generate-tool-call/TextGenerationToolCallModel.ts @@ -1,23 +1,15 @@ import { FunctionOptions } from "../../core/FunctionOptions"; +import { JsonSchemaProducer } from "../../core/schema/JsonSchemaProducer"; +import { Schema } from "../../core/schema/Schema"; import { TextGenerationModel, TextGenerationModelSettings, } from "../../model-function/generate-text/TextGenerationModel"; import { generateText } from "../../model-function/generate-text/generateText"; -import { ToolCallParseError } from "./ToolCallParseError"; import { ToolDefinition } from "../ToolDefinition"; import { ToolCallGenerationModel } from "./ToolCallGenerationModel"; - -export interface ToolCallPromptTemplate { - createPrompt: ( - prompt: SOURCE_PROMPT, - tool: ToolDefinition - ) => TARGET_PROMPT; - extractToolCall: ( - response: string, - tool: ToolDefinition - ) => { id: string; args: unknown } | null; -} +import { ToolCallParseError } from "./ToolCallParseError"; +import { ToolCallPromptTemplate } from "./ToolCallPromptTemplate"; export class TextGenerationToolCallModel< SOURCE_PROMPT, @@ -26,17 +18,20 @@ export class 
TextGenerationToolCallModel< > implements ToolCallGenerationModel { private readonly model: MODEL; - private readonly format: ToolCallPromptTemplate; + private readonly template: ToolCallPromptTemplate< + SOURCE_PROMPT, + TARGET_PROMPT + >; constructor({ model, - format, + template, }: { model: MODEL; - format: ToolCallPromptTemplate; + template: ToolCallPromptTemplate; }) { this.model = model; - this.format = format; + this.template = template; } get modelInformation() { @@ -51,14 +46,25 @@ export class TextGenerationToolCallModel< return this.model.settingsForEvent; } + getModelWithJsonOutput(schema: Schema & JsonSchemaProducer) { + if (this.template.withJsonOutput != null) { + return this.template.withJsonOutput({ + model: this.model, + schema, + }) as MODEL; + } + + return this.model; + } + async doGenerateToolCall( tool: ToolDefinition, prompt: SOURCE_PROMPT, options?: FunctionOptions ) { const { rawResponse, text, metadata } = await generateText({ - model: this.model, - prompt: this.format.createPrompt(prompt, tool), + model: this.getModelWithJsonOutput(tool.parameters), + prompt: this.template.createPrompt(prompt, tool), fullResponse: true, ...options, }); @@ -66,7 +72,7 @@ export class TextGenerationToolCallModel< try { return { rawResponse, - toolCall: this.format.extractToolCall(text, tool), + toolCall: this.template.extractToolCall(text, tool), usage: metadata?.usage as | { promptTokens: number; @@ -87,7 +93,7 @@ export class TextGenerationToolCallModel< withSettings(additionalSettings: Partial): this { return new TextGenerationToolCallModel({ model: this.model.withSettings(additionalSettings), - format: this.format, + template: this.template, }) as this; } } diff --git a/packages/modelfusion/src/tool/generate-tool-call/ToolCallPromptTemplate.ts b/packages/modelfusion/src/tool/generate-tool-call/ToolCallPromptTemplate.ts new file mode 100644 index 00000000..8736f233 --- /dev/null +++ 
b/packages/modelfusion/src/tool/generate-tool-call/ToolCallPromptTemplate.ts @@ -0,0 +1,27 @@ +import { JsonSchemaProducer } from "../../core/schema/JsonSchemaProducer"; +import { Schema } from "../../core/schema/Schema"; +import { ToolDefinition } from "../ToolDefinition"; + +export interface ToolCallPromptTemplate { + createPrompt( + prompt: SOURCE_PROMPT, + tool: ToolDefinition + ): TARGET_PROMPT; + + extractToolCall( + response: string, + tool: ToolDefinition + ): { id: string; args: unknown } | null; + + withJsonOutput?({ + model, + schema, + }: { + model: { + withJsonOutput( + schema: Schema & JsonSchemaProducer + ): typeof model; + }; + schema: Schema & JsonSchemaProducer; + }): typeof model; +} diff --git a/packages/modelfusion/src/tool/generate-tool-call/index.ts b/packages/modelfusion/src/tool/generate-tool-call/index.ts index 52fc5ab4..977bbea3 100644 --- a/packages/modelfusion/src/tool/generate-tool-call/index.ts +++ b/packages/modelfusion/src/tool/generate-tool-call/index.ts @@ -2,5 +2,6 @@ export * from "./TextGenerationToolCallModel"; export * from "./ToolCallGenerationEvent"; export * from "./ToolCallGenerationModel"; export * from "./ToolCallParseError"; +export * from "./ToolCallPromptTemplate"; export * from "./generateToolCall"; export * from "./jsonToolCallPrompt"; diff --git a/packages/modelfusion/src/tool/generate-tool-call/jsonToolCallPrompt.ts b/packages/modelfusion/src/tool/generate-tool-call/jsonToolCallPrompt.ts index 6a131f68..6638c6e5 100644 --- a/packages/modelfusion/src/tool/generate-tool-call/jsonToolCallPrompt.ts +++ b/packages/modelfusion/src/tool/generate-tool-call/jsonToolCallPrompt.ts @@ -2,33 +2,84 @@ import { nanoid } from "nanoid"; import { parseJSON } from "../../core/schema/parseJSON"; import { InstructionPrompt } from "../../model-function/generate-text/prompt-template/InstructionPrompt"; import { ToolDefinition } from "../../tool/ToolDefinition"; -import { ToolCallPromptTemplate } from "./TextGenerationToolCallModel"; 
+import { ToolCallPromptTemplate } from "./ToolCallPromptTemplate"; + +const DEFAULT_TOOL_PROMPT = (tool: ToolDefinition) => + [ + `You are calling the function "${tool.name}".`, + tool.description != null + ? `Function description: ${tool.description}` + : null, + `Function parameters JSON schema: ${JSON.stringify( + tool.parameters.getJsonSchema() + )}`, + ``, + `You MUST answer with a JSON object that matches the JSON schema above.`, + ] + .filter(Boolean) + .join("\n"); export const jsonToolCallPrompt = { - text(): ToolCallPromptTemplate { + text({ + toolPrompt, + }: { + toolPrompt?: (tool: ToolDefinition) => string; + } = {}): ToolCallPromptTemplate { return { - createPrompt(instruction: string, tool: ToolDefinition) { + createPrompt(prompt: string, tool: ToolDefinition) { return { - system: [ - `You are calling a function "${tool.name}".`, - tool.description != null - ? ` Function description: ${tool.description}` - : null, - ` Function parameters JSON schema: ${JSON.stringify( - tool.parameters.getJsonSchema() - )}`, - ``, - `You MUST answer with a JSON object matches the above schema for the arguments.`, - ] - .filter(Boolean) - .join("\n"), - instruction, + system: createSystemPrompt({ tool, toolPrompt }), + instruction: prompt, }; }, + extractToolCall, + withJsonOutput: ({ model, schema }) => model.withJsonOutput(schema), + }; + }, - extractToolCall(response) { - return { id: nanoid(), args: parseJSON({ text: response }) }; + instruction({ + toolPrompt, + }: { + toolPrompt?: (tool: ToolDefinition) => string; + } = {}): ToolCallPromptTemplate { + return { + createPrompt( + prompt: InstructionPrompt, + tool: ToolDefinition + ): InstructionPrompt { + return { + system: createSystemPrompt({ + originalSystemPrompt: prompt.system, + tool, + toolPrompt, + }), + instruction: prompt.instruction, + }; }, + extractToolCall, + withJsonOutput: ({ model, schema }) => model.withJsonOutput(schema), }; }, }; + +function createSystemPrompt({ + originalSystemPrompt, + 
toolPrompt = DEFAULT_TOOL_PROMPT, + tool, +}: { + originalSystemPrompt?: string; + toolPrompt?: (tool: ToolDefinition) => string; + tool: ToolDefinition; +}) { + return [ + originalSystemPrompt, + originalSystemPrompt != null ? "" : null, + toolPrompt(tool), + ] + .filter((line) => line != null) + .join("\n"); +} + +function extractToolCall(response: string) { + return { id: nanoid(), args: parseJSON({ text: response }) }; +}