From d68cdc862b1f32a841faf9ad9f40745b0292a339 Mon Sep 17 00:00:00 2001 From: jeadie Date: Wed, 16 Oct 2024 10:50:56 +1000 Subject: [PATCH 1/2] add better methods for using tools in and update examples --- mistralrs/examples/lower_level/tools/main.rs | 24 +++++---- mistralrs/examples/tools/main.rs | 11 +--- mistralrs/src/messages.rs | 55 +++++++++++++++++++- 3 files changed, 71 insertions(+), 19 deletions(-) diff --git a/mistralrs/examples/lower_level/tools/main.rs b/mistralrs/examples/lower_level/tools/main.rs index 64197e6d2..4fc8b1b08 100644 --- a/mistralrs/examples/lower_level/tools/main.rs +++ b/mistralrs/examples/lower_level/tools/main.rs @@ -125,21 +125,27 @@ fn main() -> anyhow::Result<()> { let result = get_weather(input); // Add tool call message from assistant so it knows what it called messages.push(IndexMap::from([ - ("role".to_string(), Either::Left("asistant".to_string())), + ("role".to_string(), Either::Left("assistant".to_string())), + ("content".to_string(), Either::Left("".to_string())), ( - "content".to_string(), - Either::Right( - json!({ - "name": called.function.name, - "parameters": called.function.arguments, - }) - .to_string(), - ), + "tool_calls".to_string(), + Either::Right(Vec![IndexMap::from([ + ("id".to_string(), tool_call.id), + ( + "function".to_string(), + json!({ + "name": called.function.name, + "arguments": called.function.arguments, + }) + ), + ("type".to_string(), "function".to_string()), + ])]), ), ])); // Add message from the tool messages.push(IndexMap::from([ ("role".to_string(), Either::Left("tool".to_string())), + ("tool_call_id".to_string(), Either::Left(tool_call.id)), ("content".to_string(), Either::Left(result)), ])); diff --git a/mistralrs/examples/tools/main.rs b/mistralrs/examples/tools/main.rs index 10b102047..76f67166b 100644 --- a/mistralrs/examples/tools/main.rs +++ b/mistralrs/examples/tools/main.rs @@ -66,15 +66,8 @@ async fn main() -> Result<()> { // Add tool call message from assistant so it knows what it called // Then, add message from the tool messages = messages - .add_message( - TextMessageRole::Assistant, - json!({ - "name": called.function.name, - "parameters": called.function.arguments, - }) - .to_string(), - ) - .add_message(TextMessageRole::Tool, result) + .add_message_with_tool_call(TextMessageRole::Assistant, String::new(), vec![called]) + .add_tool_message(result, called.id) .set_tool_choice(ToolChoice::None); let response = model.send_chat_request(messages.clone()).await?; diff --git a/mistralrs/src/messages.rs b/mistralrs/src/messages.rs index c0433d56f..fdeb73543 100644 --- a/mistralrs/src/messages.rs +++ b/mistralrs/src/messages.rs @@ -1,4 +1,4 @@ -use std::{collections::HashMap, fmt::Display, sync::Arc}; +use std::{collections::HashMap, fmt::Display, ops::Index, sync::Arc}; use super::*; use either::Either; @@ -331,6 +331,10 @@ impl RequestBuilder { } } + /// Add a message to the request. + /// + /// For messages with tool calls, use [`Self::add_message_with_tool_call`]. + /// For messages with tool outputs, use [`Self::add_tool_message`]. pub fn add_message(mut self, role: TextMessageRole, text: impl ToString) -> Self { self.messages.push(IndexMap::from([ ("role".to_string(), Either::Left(role.to_string())), @@ -339,6 +343,55 @@ impl RequestBuilder { self } + /// Add a message with the output of a tool call. + pub fn add_tool_message(mut self, tool_content: impl ToString, tool_id: impl ToString) -> Self { + self.messages.push(IndexMap::from([ + ( + "role".to_string(), + Either::Left(TextMessageRole::Tool.to_string()), + ), + ( + "content".to_string(), + Either::Left(tool_content.to_string()), + ), + ( + "tool_call_id".to_string(), + Either::Left(tool_id.to_string()), + ), + ])); + self + } + + pub fn add_message_with_tool_call( + mut self, + role: TextMessageRole, + text: impl ToString, + tool_calls: Vec, + ) -> Self { + let tool_messages = tool_calls + .iter() + .map(|t| { + IndexMap::from([ + ("id".to_string(), Either::Left(t.id)), + ("type".to_string(), Either::Left(t.tp.to_string())), + ( + "function".to_string(), + Either::Right(IndexMap::from([ + ("name".to_string(), Value::String(t.function.name)), + ("arguments".to_string(), Value::String(t.function.arguments)), + ])), + ), + ]) + }) + .collect(); + self.messages.push(IndexMap::from([ + ("role".to_string(), Either::Left(role.to_string())), + ("content".to_string(), Either::Left(text.into())), + ("function".to_string(), Either::Right(tool_messages)), + ])); + self + } + pub fn add_image_message( mut self, role: TextMessageRole, From 23d7df071f0adc78541005521183f61a42520de2 Mon Sep 17 00:00:00 2001 From: jeadie Date: Wed, 16 Oct 2024 11:31:32 +1000 Subject: [PATCH 2/2] fixes --- mistralrs-core/src/tools/response.rs | 8 ++++++++ mistralrs/examples/tools/main.rs | 8 ++++++-- mistralrs/src/messages.rs | 18 +++++++++--------- 3 files changed, 23 insertions(+), 11 deletions(-) diff --git a/mistralrs-core/src/tools/response.rs b/mistralrs-core/src/tools/response.rs index 4eecd5caf..1e082a900 100644 --- a/mistralrs-core/src/tools/response.rs +++ b/mistralrs-core/src/tools/response.rs @@ -6,6 +6,14 @@ pub enum ToolCallType { Function, } +impl std::fmt::Display for ToolCallType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ToolCallType::Function => write!(f, "function"), + } + } +} + #[cfg_attr(feature = "pyo3_macros", pyo3::pyclass)] #[cfg_attr(feature = "pyo3_macros", pyo3(get_all))] #[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] diff --git a/mistralrs/examples/tools/main.rs b/mistralrs/examples/tools/main.rs index 76f67166b..02c002ca0 100644 --- a/mistralrs/examples/tools/main.rs +++ b/mistralrs/examples/tools/main.rs @@ -66,8 +66,12 @@ async fn main() -> Result<()> { // Add tool call message from assistant so it knows what it called // Then, add message from the tool messages = messages - .add_message_with_tool_call(TextMessageRole::Assistant, String::new(), vec![called]) - .add_tool_message(result, called.id) + .add_message_with_tool_call( + TextMessageRole::Assistant, + String::new(), + vec![called.clone()], + ) + .add_tool_message(result, called.id.clone()) .set_tool_choice(ToolChoice::None); let response = model.send_chat_request(messages.clone()).await?; diff --git a/mistralrs/src/messages.rs b/mistralrs/src/messages.rs index fdeb73543..b31b9c4a1 100644 --- a/mistralrs/src/messages.rs +++ b/mistralrs/src/messages.rs @@ -1,10 +1,10 @@ -use std::{collections::HashMap, fmt::Display, ops::Index, sync::Arc}; +use std::{collections::HashMap, fmt::Display, sync::Arc}; use super::*; use either::Either; use image::DynamicImage; use indexmap::IndexMap; -use serde_json::Value; +use serde_json::{json, Value}; /// A type which can be used as a request. pub trait RequestLike { @@ -372,21 +372,21 @@ impl RequestBuilder { .iter() .map(|t| { IndexMap::from([ - ("id".to_string(), Either::Left(t.id)), - ("type".to_string(), Either::Left(t.tp.to_string())), + ("id".to_string(), Value::String(t.id.clone())), + ("type".to_string(), Value::String(t.tp.to_string())), ( "function".to_string(), - Either::Right(IndexMap::from([ - ("name".to_string(), Value::String(t.function.name)), - ("arguments".to_string(), Value::String(t.function.arguments)), - ])), + json!({ + "name": t.function.name, + "arguments": t.function.arguments, + }), ), ]) }) .collect(); self.messages.push(IndexMap::from([ ("role".to_string(), Either::Left(role.to_string())), - ("content".to_string(), Either::Left(text.into())), + ("content".to_string(), Either::Left(text.to_string())), ("function".to_string(), Either::Right(tool_messages)), ])); self